From e1eda90e29497aecea3ea0bb7a09e62536233498 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 12 Feb 2017 21:46:15 -0600 Subject: [PATCH 0001/1158] Add ChunkReader --- chunkreader.go | 106 ++++++++++++++++++++++++++++++++++++ chunkreader_test.go | 128 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 234 insertions(+) create mode 100644 chunkreader.go create mode 100644 chunkreader_test.go diff --git a/chunkreader.go b/chunkreader.go new file mode 100644 index 00000000..f9d6555c --- /dev/null +++ b/chunkreader.go @@ -0,0 +1,106 @@ +package chunkreader + +import ( + "io" +) + +type ChunkReader struct { + r io.Reader + + buf []byte + rp, wp int // buf read position and write position + taken bool + + options Options +} + +type Options struct { + MinBufLen int // Minimum buffer length + BlockLen int // Increments to expand buffer (e.g. a 8000 byte request with a BlockLen of 1024 would yield a buffer len of 8192) +} + +func NewChunkReader(r io.Reader) *ChunkReader { + cr, err := NewChunkReaderEx(r, Options{}) + if err != nil { + panic("default options can't be bad") + } + + return cr +} + +func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { + if options.MinBufLen == 0 { + options.MinBufLen = 4096 + } + if options.BlockLen == 0 { + options.BlockLen = 512 + } + + return &ChunkReader{ + r: r, + buf: make([]byte, options.MinBufLen), + options: options, + }, nil +} + +// Next returns buf filled with the next n bytes. buf is only valid until the +// next call to Next. If an error occurs, buf will be nil. +func (r *ChunkReader) Next(n int) (buf []byte, err error) { + // n bytes already in buf + if (r.wp - r.rp) >= n { + buf = r.buf[r.rp : r.rp+n] + r.rp += n + return buf, err + } + + // available space in buf is less than n + if len(r.buf) < n { + r.copyBufContents(r.newBuf(n)) + r.taken = false + } + + // buf is large enough, but need to shift filled area to start to make enough contiguous space + minReadCount := n - (r.wp - r.rp) + if (len(r.buf) - r.wp) < minReadCount { + newBuf := r.buf + if r.taken { + newBuf = r.newBuf(n) + r.taken = false + } + r.copyBufContents(newBuf) + } + + if err := r.appendAtLeast(minReadCount); err != nil { + return nil, err + } + + buf = r.buf[r.rp : r.rp+n] + r.rp += n + return buf, nil +} + +// KeepLast prevents the last data retrieved by Next from being reused by the +// ChunkReader. +func (r *ChunkReader) KeepLast() { + r.taken = true +} + +func (r *ChunkReader) appendAtLeast(fillLen int) error { + n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen) + r.wp += n + return err +} + +func (r *ChunkReader) newBuf(min int) []byte { + size := ((min / r.options.BlockLen) + 1) * r.options.BlockLen + if size < r.options.MinBufLen { + size = r.options.MinBufLen + } + return make([]byte, size) +} + +func (r *ChunkReader) copyBufContents(dest []byte) { + r.wp = copy(dest, r.buf[r.rp:r.wp]) + r.rp = 0 + r.buf = dest +} diff --git a/chunkreader_test.go b/chunkreader_test.go new file mode 100644 index 00000000..9c19ff4a --- /dev/null +++ b/chunkreader_test.go @@ -0,0 +1,128 @@ +package chunkreader + +import ( + "bytes" + "testing" +) + +func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { + server := &bytes.Buffer{} + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 2}) + if err != nil { + t.Fatal(err) + } + + src := []byte{1, 2, 3, 4} + server.Write(src) + + n1, err := r.Next(2) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n1, src[0:2]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1) + } + + n2, err := r.Next(2) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n2, src[2:4]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2) + } + + if bytes.Compare(r.buf, src) != 0 { + t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf) + } + if r.rp != 4 { + t.Fatalf("Expected r.rp to be %v, but it was %v", 4, r.rp) + } + if r.wp != 4 { + t.Fatalf("Expected r.wp to be %v, but it was %v", 4, r.wp) + } +} + +func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { + server := &bytes.Buffer{} + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 2}) + if err != nil { + t.Fatal(err) + } + + src := []byte{1, 2, 3, 4, 5, 6, 7, 8} + server.Write(src) + + n1, err := r.Next(5) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n1, src[0:5]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1) + } + if len(r.buf) != 6 { + t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 6, len(r.buf)) + } +} + +func TestChunkReaderNextReusesBuf(t *testing.T) { + server := &bytes.Buffer{} + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 1}) + if err != nil { + t.Fatal(err) + } + + src := []byte{1, 2, 3, 4, 5, 6, 7, 8} + server.Write(src) + + n1, err := r.Next(4) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n1, src[0:4]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1) + } + + n2, err := r.Next(4) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n2, src[4:8]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) + } + + if bytes.Compare(n1, src[4:8]) != 0 { + t.Fatalf("Expected Next to have reused buf, %v found instead of %v", src[4:8], n1) + } +} + +func TestChunkReaderKeepLastPreventsBufReuse(t *testing.T) { + server := &bytes.Buffer{} + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 1}) + if err != nil { + t.Fatal(err) + } + + src := []byte{1, 2, 3, 4, 5, 6, 7, 8} + server.Write(src) + + n1, err := r.Next(4) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n1, src[0:4]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1) + } + r.KeepLast() + + n2, err := r.Next(4) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n2, src[4:8]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) + } + + if bytes.Compare(n1, src[0:4]) != 0 { + t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1) + } +} From ac2414449c1bc2ddf447d256da3c9fa0ad4f5c36 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 15:33:34 -0600 Subject: [PATCH 0002/1158] Initial proof-of-concept for pgtype Squashed commit of the following: commit c19454582b335ce5bdda6320f7e4e8c76cfeaf44 Author: Jack Christensen Date: Fri Mar 3 15:24:47 2017 -0600 Add AssignTo to pgtype.Timestamptz Also handle infinity for pgtype.Date commit 7329933610b38f4bc15731b1f7c55c520b49e300 Author: Jack Christensen Date: Fri Mar 3 15:12:18 2017 -0600 Implement AssignTo for most pgtypes commit cc3d1e4af896d34ec98c3bf2e982d0367451f21c Author: Jack Christensen Date: Thu Mar 2 21:19:07 2017 -0600 Use pgtype.Int2Array in pgx commit 36da5cc2178d1a31a56dc6e6f128843bd80dea0b Author: Jack Christensen Date: Tue Feb 28 21:45:33 2017 -0600 Add text array transcoding commit 1b0f18d99f38b69f8c2db26388815e67b2b03d59 Author: Jack Christensen Date: Mon Feb 27 19:28:55 2017 -0600 Add ParseUntypedTextArray commit 0f50ce3e833fc38495d333228daf04f5142be676 Author: Jack Christensen Date: Mon Feb 27 18:54:20 2017 -0600 wip commit d934f273627d79997035c282416db922f2fbe87a Author: Jack Christensen Date: Sun Feb 26 17:14:32 2017 -0600 WIP - beginning text format array parsing commit 7276ad33ce7fa9c250745a3ed909998f3dae4a32 Author: Jack Christensen Date: Sat Feb 25 22:50:11 2017 -0600 Beginning binary arrays commit 917faa5a3175d376222423c10aca297a20f96448 Author: Jack Christensen Date: Sat Feb 25 19:36:35 2017 -0600 Fix incomplete tests commit de8c140cfb98b7b047d53c5718ccbf12eaf813a1 Author: Jack Christensen Date: Sat Feb 25 19:32:22 2017 -0600 Add timestamptz null and infinity commit 7d9f954de4e071a1eccac762248079b90dbeb53f Author: Jack Christensen Date: Sat Feb 25 18:19:38 2017 -0600 Add infinity to pgtype.Date commit 7bf783ae20ba05571c2fb9f661183233c95eab41 Author: Jack Christensen Date: Sat Feb 25 17:19:55 2017 -0600 Add Status to pgtype.Date commit 984500455c9b9a4b6221758540d248e6410d93a4 Author: Jack Christensen Date: Sat Feb 25 16:54:01 2017 -0600 Add status to Int4 and Int8 commit 6fe76fcfc2de31552790db3b093480a9d5b2a742 Author: Jack Christensen Date: Sat Feb 25 16:40:27 2017 -0600 Extract testSuccessfulTranscode commit 001647c1da03f796014cf21f41c9a7fd2cfadfde Author: Jack Christensen Date: Sat Feb 25 16:15:51 2017 -0600 Add Status to pgtype.Int2 commit 720451f06d13d9c9fa2a0482e010f24bf4627c2a Author: Jack Christensen Date: Sat Feb 25 15:56:44 2017 -0600 Add status to pgtype.Bool commit 325f700b6edff215a692b10bc5b94cdfe1100769 Author: Jack Christensen Date: Fri Feb 24 17:28:15 2017 -0600 Add date to conversion system commit 4a9343e45d3897f59eab98a0009d2ddbe07e02d7 Author: Jack Christensen Date: Fri Feb 24 16:28:35 2017 -0600 Add bool to oid based encoding commit d984fcafab1476cf84852485b6711f4b2069eb6d Author: Jack Christensen Date: Fri Feb 24 16:15:38 2017 -0600 Add pgtype interfaces commit 0f93bfc2de4023b069b966c0988bf7f0469d1809 Author: Jack Christensen Date: Fri Feb 24 14:48:34 2017 -0600 Begin introduction of Convert commit e5707023cac7c07342b8c910e480d09a1caaf6ee Author: Jack Christensen Date: Fri Feb 24 14:10:56 2017 -0600 Move bool to pgtype commit bb764d2129efe7fb21e841dbb35e6d0dc7586d37 Author: Jack Christensen Date: Fri Feb 24 13:45:05 2017 -0600 Add Int2 test commit 08c49437f455a32f7c3f0a524cd21a895d440301 Author: Jack Christensen Date: Fri Feb 24 13:44:09 2017 -0600 Add Int4 test commit 16722952222fd15c53c8fa84974645504a6d0dc0 Author: Jack Christensen Date: Fri Feb 24 08:56:59 2017 -0600 Add int8 tests commit 83a5447cd2c46b58d0880023cc4e9af0c84988a2 Author: Jack Christensen Date: Wed Feb 22 18:08:05 2017 -0600 wip commit 0ca0ee72068a72b016729b01fccef22474595285 Author: Jack Christensen Date: Mon Feb 20 18:56:52 2017 -0600 wip commit d2c2baf4ea2cd0793d68c7094c425217df952bec Author: Jack Christensen Date: Mon Feb 20 18:46:10 2017 -0600 wip commit f78371da0098356527b193fd496a338da5fe414b Author: Jack Christensen Date: Mon Feb 20 17:43:39 2017 -0600 wip commit 3366699bea62ec0110db05f3cb2998d58ac9ce5d Author: Jack Christensen Date: Mon Feb 20 14:07:47 2017 -0600 wip commit 66b79e940870fd0133ebb10ac1547e1d4d7d0b51 Author: Jack Christensen Date: Mon Feb 20 13:35:37 2017 -0600 Extract pgio commit 8b07d97d1305ed98fd76db6e306a289c0af92d56 Author: Jack Christensen Date: Mon Feb 20 13:20:00 2017 -0600 wip commit 62f1adb3427f4317b708da075dce50c4d4daff7b Author: Jack Christensen Date: Mon Feb 20 12:08:46 2017 -0600 wip commit a712d2546933a5a8433c65eef0ff2ee135077c87 Author: Jack Christensen Date: Mon Feb 20 09:30:52 2017 -0600 wip commit 4faf97cc588126dda160fc360680719572a23105 Author: Jack Christensen Date: Fri Feb 17 22:20:18 2017 -0600 wip --- array.go | 375 ++++++++++++++++++++++++++++++++++++++++++++ array_test.go | 98 ++++++++++++ bool.go | 166 ++++++++++++++++++++ bool_test.go | 43 +++++ convert.go | 239 ++++++++++++++++++++++++++++ date.go | 191 ++++++++++++++++++++++ date_test.go | 51 ++++++ extra-interface.txt | 3 + int2.go | 167 ++++++++++++++++++++ int2_test.go | 55 +++++++ int2array.go | 308 ++++++++++++++++++++++++++++++++++++ int2array_test.go | 87 ++++++++++ int4.go | 158 +++++++++++++++++++ int4_test.go | 55 +++++++ int8.go | 149 ++++++++++++++++++ int8_test.go | 55 +++++++ pgtype.go | 102 ++++++++++++ pgtype_test.go | 108 +++++++++++++ text_element.go | 112 +++++++++++++ timestamptz.go | 203 ++++++++++++++++++++++++ timestamptz_test.go | 60 +++++++ 21 files changed, 2785 insertions(+) create mode 100644 array.go create mode 100644 array_test.go create mode 100644 bool.go create mode 100644 bool_test.go create mode 100644 convert.go create mode 100644 date.go create mode 100644 date_test.go create mode 100644 extra-interface.txt create mode 100644 int2.go create mode 100644 int2_test.go create mode 100644 int2array.go create mode 100644 int2array_test.go create mode 100644 int4.go create mode 100644 int4_test.go create mode 100644 int8.go create mode 100644 int8_test.go create mode 100644 pgtype.go create mode 100644 pgtype_test.go create mode 100644 text_element.go create mode 100644 timestamptz.go create mode 100644 timestamptz_test.go diff --git a/array.go b/array.go new file mode 100644 index 00000000..75d2e440 --- /dev/null +++ b/array.go @@ -0,0 +1,375 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + "strconv" + "unicode" + + "github.com/jackc/pgx/pgio" +) + +// Information on the internals of PostgreSQL arrays can be found in +// src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of +// particular interest is the array_send function. + +type ArrayHeader struct { + ContainsNull bool + ElementOID int32 + Dimensions []ArrayDimension +} + +type ArrayDimension struct { + Length int32 + LowerBound int32 +} + +func (ah *ArrayHeader) DecodeBinary(r io.Reader) error { + numDims, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if numDims > 0 { + ah.Dimensions = make([]ArrayDimension, numDims) + } + + containsNull, err := pgio.ReadInt32(r) + if err != nil { + return err + } + ah.ContainsNull = containsNull == 1 + + ah.ElementOID, err = pgio.ReadInt32(r) + if err != nil { + return err + } + + for i := range ah.Dimensions { + ah.Dimensions[i].Length, err = pgio.ReadInt32(r) + if err != nil { + return err + } + + ah.Dimensions[i].LowerBound, err = pgio.ReadInt32(r) + if err != nil { + return err + } + } + + return nil +} + +func (ah *ArrayHeader) EncodeBinary(w io.Writer) error { + _, err := pgio.WriteInt32(w, int32(len(ah.Dimensions))) + if err != nil { + return err + } + + var containsNull int32 + if ah.ContainsNull { + containsNull = 1 + } + _, err = pgio.WriteInt32(w, containsNull) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, ah.ElementOID) + if err != nil { + return err + } + + for i := range ah.Dimensions { + _, err = pgio.WriteInt32(w, ah.Dimensions[i].Length) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, ah.Dimensions[i].LowerBound) + if err != nil { + return err + } + } + + return nil +} + +type UntypedTextArray struct { + Elements []string + Dimensions []ArrayDimension +} + +func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { + uta := &UntypedTextArray{} + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + var explicitDimensions []ArrayDimension + + // Array has explicit dimensions + if r == '[' { + buf.UnreadRune() + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '=' { + break + } else if r != '[' { + return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r) + } + + lower, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ':' { + return nil, fmt.Errorf("invalid array, expected ':' got %v", r) + } + + upper, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ']' { + return nil, fmt.Errorf("invalid array, expected ']' got %v", r) + } + + explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + } + + if r != '{' { + return nil, fmt.Errorf("invalid array, expected '{': %v", err) + } + + implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} + + // Consume all initial opening brackets. This provides number of dimensions. + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '{' { + implicitDimensions[len(implicitDimensions)-1].Length = 1 + implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1}) + } else { + buf.UnreadRune() + break + } + } + currentDim := len(implicitDimensions) - 1 + counterDim := currentDim + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + switch r { + case '{': + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + currentDim++ + case ',': + case '}': + currentDim-- + if currentDim < counterDim { + counterDim = currentDim + } + default: + buf.UnreadRune() + value, err := arrayParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid array value: %v", err) + } + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + uta.Elements = append(uta.Elements, value) + } + + if currentDim < 0 { + break + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + if len(uta.Elements) == 0 { + uta.Dimensions = nil + } else if len(explicitDimensions) > 0 { + uta.Dimensions = explicitDimensions + } else { + uta.Dimensions = implicitDimensions + } + + return uta, nil +} + +func skipWhitespace(buf *bytes.Buffer) { + var r rune + var err error + for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() { + } + + if err != io.EOF { + buf.UnreadRune() + } +} + +func arrayParseValue(buf *bytes.Buffer) (string, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + if r == '"' { + return arrayParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case ',', '}': + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +func arrayParseQuotedValue(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + buf.UnreadRune() + return s.String(), nil + } + s.WriteRune(r) + } +} + +func arrayParseInteger(buf *bytes.Buffer) (int32, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return 0, err + } + + if '0' <= r && r <= '9' { + s.WriteRune(r) + } else { + buf.UnreadRune() + n, err := strconv.ParseInt(s.String(), 10, 32) + if err != nil { + return 0, err + } + return int32(n), nil + } + } +} + +func EncodeTextArrayDimensions(w io.Writer, dimensions []ArrayDimension) error { + var customDimensions bool + for _, dim := range dimensions { + if dim.LowerBound != 1 { + customDimensions = true + } + } + + if !customDimensions { + return nil + } + + for _, dim := range dimensions { + err := pgio.WriteByte(w, '[') + if err != nil { + return err + } + + _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound), 10)) + if err != nil { + return err + } + + err = pgio.WriteByte(w, ':') + if err != nil { + return err + } + + _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)) + if err != nil { + return err + } + + err = pgio.WriteByte(w, ']') + if err != nil { + return err + } + } + + return pgio.WriteByte(w, '=') +} diff --git a/array_test.go b/array_test.go new file mode 100644 index 00000000..5e5f00e7 --- /dev/null +++ b/array_test.go @@ -0,0 +1,98 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestParseUntypedTextArray(t *testing.T) { + tests := []struct { + source string + result pgtype.UntypedTextArray + }{ + { + source: "{}", + result: pgtype.UntypedTextArray{ + Elements: nil, + Dimensions: nil, + }, + }, + { + source: "{1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{a,b}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b"}, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + }, + }, + { + source: `{"NULL"}`, + result: pgtype.UntypedTextArray{ + Elements: []string{"NULL"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: `{"He said, \"Hello.\""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{`He said, "Hello."`}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{{a,b},{c,d},{e,f}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f"}, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + }, + }, + { + source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 1}, + {Length: 3, LowerBound: 1}, + {Length: 2, LowerBound: 1}, + }, + }, + }, + { + source: "[4:4]={1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 4}}, + }, + }, + { + source: "[4:5][2:3]={{a,b},{c,d}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d"}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + }, + }, + } + + for i, tt := range tests { + r, err := pgtype.ParseUntypedTextArray(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(*r, tt.result) { + t.Errorf("%d: expected %+v to be parsed to %+v, but it was %+v", i, tt.source, tt.result, *r) + } + } +} diff --git a/bool.go b/bool.go new file mode 100644 index 00000000..81c72472 --- /dev/null +++ b/bool.go @@ -0,0 +1,166 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Bool struct { + Bool bool + Status Status +} + +func (b *Bool) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Bool: + *b = value + case bool: + *b = Bool{Bool: value, Status: Present} + case string: + bb, err := strconv.ParseBool(value) + if err != nil { + return err + } + *b = Bool{Bool: bb, Status: Present} + default: + if originalSrc, ok := underlyingBoolType(src); ok { + return b.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bool", value) + } + + return nil +} + +func (b *Bool) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *bool: + if b.Status != Present { + return fmt.Errorf("cannot assign %v to %T", b, dst) + } + *v = b.Bool + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if b.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return b.AssignTo(el.Interface()) + case reflect.Bool: + if b.Status != Present { + return fmt.Errorf("cannot assign %v to %T", b, dst) + } + el.SetBool(b.Bool) + return nil + } + } + return fmt.Errorf("cannot put decode %v into %T", b, dst) + } + + return nil +} + +func (b *Bool) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *b = Bool{Status: Null} + return nil + } + + if size != 1 { + return fmt.Errorf("invalid length for bool: %v", size) + } + + byt, err := pgio.ReadByte(r) + if err != nil { + return err + } + + *b = Bool{Bool: byt == 't', Status: Present} + return nil +} + +func (b *Bool) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *b = Bool{Status: Null} + return nil + } + + if size != 1 { + return fmt.Errorf("invalid length for bool: %v", size) + } + + byt, err := pgio.ReadByte(r) + if err != nil { + return err + } + + *b = Bool{Bool: byt == 1, Status: Present} + return nil +} + +func (b Bool) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, b.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 1) + if err != nil { + return nil + } + + var buf []byte + if b.Bool { + buf = []byte{'t'} + } else { + buf = []byte{'f'} + } + + _, err = w.Write(buf) + return err +} + +func (b Bool) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, b.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 1) + if err != nil { + return nil + } + + var buf []byte + if b.Bool { + buf = []byte{1} + } else { + buf = []byte{0} + } + + _, err = w.Write(buf) + return err +} diff --git a/bool_test.go b/bool_test.go new file mode 100644 index 00000000..53df1747 --- /dev/null +++ b/bool_test.go @@ -0,0 +1,43 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestBoolTranscode(t *testing.T) { + testSuccessfulTranscode(t, "bool", []interface{}{ + pgtype.Bool{Bool: false, Status: pgtype.Present}, + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: false, Status: pgtype.Null}, + }) +} + +func TestBoolConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Bool + }{ + {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Bool + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/convert.go b/convert.go new file mode 100644 index 00000000..3f3d9e5f --- /dev/null +++ b/convert.go @@ -0,0 +1,239 @@ +package pgtype + +import ( + "fmt" + "math" + "reflect" + "time" +) + +const maxUint = ^uint(0) +const maxInt = int(maxUint >> 1) +const minInt = -maxInt - 1 + +// underlyingIntType gets the underlying type that can be converted to Int2, Int4, or Int8 +func underlyingIntType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Int: + convVal := int(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int8: + convVal := int8(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int16: + convVal := int16(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int32: + convVal := int32(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int64: + convVal := int64(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint: + convVal := uint(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint8: + convVal := uint8(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint16: + convVal := uint16(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint32: + convVal := uint32(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint64: + convVal := uint64(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBoolType gets the underlying type that can be converted to Bool +func underlyingBoolType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Bool: + convVal := refVal.Bool() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingTimeType gets the underlying type that can be converted to time.Time +func underlyingTimeType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return time.Time{}, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + timeType := reflect.TypeOf(time.Time{}) + if refVal.Type().ConvertibleTo(timeType) { + return refVal.Convert(timeType).Interface(), true + } + + return time.Time{}, false +} + +// underlyingSliceType gets the underlying slice type +func underlyingSliceType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + baseSliceType := reflect.SliceOf(refVal.Type().Elem()) + if refVal.Type().ConvertibleTo(baseSliceType) { + convVal := refVal.Convert(baseSliceType) + return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() + } + } + + return nil, false +} + +func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { + if srcStatus == Present { + switch v := dst.(type) { + case *int: + if srcVal < int64(minInt) { + return fmt.Errorf("%d is less than minimum value for int", srcVal) + } else if srcVal > int64(maxInt) { + return fmt.Errorf("%d is greater than maximum value for int", srcVal) + } + *v = int(srcVal) + case *int8: + if srcVal < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", srcVal) + } else if srcVal > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", srcVal) + } + *v = int8(srcVal) + case *int16: + if srcVal < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", srcVal) + } else if srcVal > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", srcVal) + } + *v = int16(srcVal) + case *int32: + if srcVal < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", srcVal) + } else if srcVal > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", srcVal) + } + *v = int32(srcVal) + case *int64: + if srcVal < math.MinInt64 { + return fmt.Errorf("%d is less than minimum value for int64", srcVal) + } else if srcVal > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for int64", srcVal) + } + *v = int64(srcVal) + case *uint: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint", srcVal) + } else if uint64(srcVal) > uint64(maxUint) { + return fmt.Errorf("%d is greater than maximum value for uint", srcVal) + } + *v = uint(srcVal) + case *uint8: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint8", srcVal) + } else if srcVal > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) + } + *v = uint8(srcVal) + case *uint16: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) + } + *v = uint16(srcVal) + case *uint32: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) + } + *v = uint32(srcVal) + case *uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint64", srcVal) + } + *v = uint64(srcVal) + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return int64AssignTo(srcVal, srcStatus, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if el.OverflowInt(int64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetInt(int64(srcVal)) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for %T", srcVal, dst) + } + if el.OverflowUint(uint64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetUint(uint64(srcVal)) + return nil + } + } + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Present, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) +} diff --git a/date.go b/date.go new file mode 100644 index 00000000..f3e3e4c6 --- /dev/null +++ b/date.go @@ -0,0 +1,191 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "time" + + "github.com/jackc/pgx/pgio" +) + +type Date struct { + Time time.Time + Status Status + InfinityModifier +} + +const ( + negativeInfinityDayOffset = -2147483648 + infinityDayOffset = 2147483647 +) + +func (d *Date) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Date: + *d = value + case time.Time: + *d = Date{Time: value, Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return d.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Date", value) + } + + return nil +} + +func (d *Date) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *time.Time: + if d.Status != Present || d.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", d, dst) + } + *v = d.Time + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if d.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return d.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot decode %v into %T", d, dst) + } + + return nil +} + +func (d *Date) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *d = Date{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + sbuf := string(buf) + switch sbuf { + case "infinity": + *d = Date{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *d = Date{Status: Present, InfinityModifier: -Infinity} + default: + t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) + if err != nil { + return err + } + + *d = Date{Time: t, Status: Present} + } + + return nil +} + +func (d *Date) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *d = Date{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for date: %v", size) + } + + dayOffset, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + switch dayOffset { + case infinityDayOffset: + *d = Date{Status: Present, InfinityModifier: Infinity} + case negativeInfinityDayOffset: + *d = Date{Status: Present, InfinityModifier: -Infinity} + default: + t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) + *d = Date{Time: t, Status: Present} + } + + return nil +} + +func (d Date) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, d.Status); done { + return err + } + + var s string + + switch d.InfinityModifier { + case None: + s = d.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + + _, err = w.Write([]byte(s)) + return err +} + +func (d Date) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, d.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + var daysSinceDateEpoch int32 + switch d.InfinityModifier { + case None: + tUnix := time.Date(d.Time.Year(), d.Time.Month(), d.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() + dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + + secSinceDateEpoch := tUnix - dateEpoch + daysSinceDateEpoch = int32(secSinceDateEpoch / 86400) + case Infinity: + daysSinceDateEpoch = infinityDayOffset + case NegativeInfinity: + daysSinceDateEpoch = negativeInfinityDayOffset + } + + _, err = pgio.WriteInt32(w, daysSinceDateEpoch) + return err +} diff --git a/date_test.go b/date_test.go new file mode 100644 index 00000000..c3e971d0 --- /dev/null +++ b/date_test.go @@ -0,0 +1,51 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestDateTranscode(t *testing.T) { + testSuccessfulTranscode(t, "date", []interface{}{ + pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Status: pgtype.Null}, + pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }) +} + +func TestDateConvertFrom(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Date + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.Date + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} diff --git a/extra-interface.txt b/extra-interface.txt new file mode 100644 index 00000000..16453823 --- /dev/null +++ b/extra-interface.txt @@ -0,0 +1,3 @@ +Can pass function to get inet data and function to get oid/name mapping as optional interface with io.Reader or io.Writer + +Could be useful for arrays of types without defined OIDs like hstore. diff --git a/int2.go b/int2.go new file mode 100644 index 00000000..2da8a96d --- /dev/null +++ b/int2.go @@ -0,0 +1,167 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int2 struct { + Int int16 + Status Status +} + +func (i *Int2) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int2: + *i = value + case int8: + *i = Int2{Int: int16(value), Status: Present} + case uint8: + *i = Int2{Int: int16(value), Status: Present} + case int16: + *i = Int2{Int: int16(value), Status: Present} + case uint16: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case int32: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case uint32: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case int64: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case uint64: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case int: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case uint: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 16) + if err != nil { + return err + } + *i = Int2{Int: int16(num), Status: Present} + default: + if originalSrc, ok := underlyingIntType(src); ok { + return i.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int2", value) + } + + return nil +} + +func (i *Int2) AssignTo(dst interface{}) error { + return int64AssignTo(int64(i.Int), i.Status, dst) +} + +func (i *Int2) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int2{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 16) + if err != nil { + return err + } + + *i = Int2{Int: int16(n), Status: Present} + return nil +} + +func (i *Int2) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int2{Status: Null} + return nil + } + + if size != 2 { + return fmt.Errorf("invalid length for int2: %v", size) + } + + n, err := pgio.ReadInt16(r) + if err != nil { + return err + } + + *i = Int2{Int: int16(n), Status: Present} + return nil +} + +func (i Int2) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + s := strconv.FormatInt(int64(i.Int), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int2) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = pgio.WriteInt16(w, i.Int) + return err +} diff --git a/int2_test.go b/int2_test.go new file mode 100644 index 00000000..a8493a16 --- /dev/null +++ b/int2_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt2Transcode(t *testing.T) { + testSuccessfulTranscode(t, "int2", []interface{}{ + pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, + pgtype.Int2{Int: -1, Status: pgtype.Present}, + pgtype.Int2{Int: 0, Status: pgtype.Present}, + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}, + pgtype.Int2{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt2ConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Int2 + }{ + {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int2 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/int2array.go b/int2array.go new file mode 100644 index 00000000..86375516 --- /dev/null +++ b/int2array.go @@ -0,0 +1,308 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Int2Array struct { + Elements []Int2 + Dimensions []ArrayDimension + Status Status +} + +func (a *Int2Array) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int2Array: + *a = value + case []int16: + if value == nil { + *a = Int2Array{Status: Null} + } else if len(value) == 0 { + *a = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *a = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint16: + if value == nil { + *a = Int2Array{Status: Null} + } else if len(value) == 0 { + *a = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *a = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return a.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int2", value) + } + + return nil +} + +func (a *Int2Array) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *[]int16: + if a.Status == Present { + *v = make([]int16, len(a.Elements)) + for i := range a.Elements { + if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + case *[]uint16: + if a.Status == Present { + *v = make([]uint16, len(a.Elements)) + for i := range a.Elements { + if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + default: + return fmt.Errorf("cannot put decode %v into %T", a, dst) + } + + return nil +} + +func (a *Int2Array) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *a = Int2Array{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Int2 + + if len(uta.Elements) > 0 { + elements = make([]Int2, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int2 + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *a = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (a *Int2Array) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *a = Int2Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *a = Int2Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int2, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *a = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (a *Int2Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, a.Status); done { + return err + } + + if len(a.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, a.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(a.Dimensions)) + dimElemCounts[len(a.Dimensions)-1] = int(a.Dimensions[len(a.Dimensions)-1].Length) + for i := len(a.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(a.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range a.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (a *Int2Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, a.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range a.Elements { + err := a.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if a.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = Int2OID + arrayHeader.Dimensions = a.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/int2array_test.go b/int2array_test.go new file mode 100644 index 00000000..5ea81990 --- /dev/null +++ b/int2array_test.go @@ -0,0 +1,87 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt2ArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "int2[]", []interface{}{ + &pgtype.Int2Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int2Array{Status: pgtype.Null}, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: 2, Status: pgtype.Present}, + pgtype.Int2{Int: 3, Status: pgtype.Present}, + pgtype.Int2{Int: 4, Status: pgtype.Present}, + pgtype.Int2{Status: pgtype.Null}, + pgtype.Int2{Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: 2, Status: pgtype.Present}, + pgtype.Int2{Int: 3, Status: pgtype.Present}, + pgtype.Int2{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +// func TestInt2ConvertFrom(t *testing.T) { +// type _int8 int8 + +// successfulTests := []struct { +// source interface{} +// result pgtype.Int2 +// }{ +// {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// } + +// for i, tt := range successfulTests { +// var r pgtype.Int2 +// err := r.ConvertFrom(tt.source) +// if err != nil { +// t.Errorf("%d: %v", i, err) +// } + +// if r != tt.result { +// t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) +// } +// } +// } diff --git a/int4.go b/int4.go new file mode 100644 index 00000000..84c45522 --- /dev/null +++ b/int4.go @@ -0,0 +1,158 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int4 struct { + Int int32 + Status Status +} + +func (i *Int4) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int4: + *i = value + case int8: + *i = Int4{Int: int32(value), Status: Present} + case uint8: + *i = Int4{Int: int32(value), Status: Present} + case int16: + *i = Int4{Int: int32(value), Status: Present} + case uint16: + *i = Int4{Int: int32(value), Status: Present} + case int32: + *i = Int4{Int: int32(value), Status: Present} + case uint32: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case int64: + if value < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case uint64: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case int: + if value < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case uint: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 32) + if err != nil { + return err + } + *i = Int4{Int: int32(num), Status: Present} + default: + if originalSrc, ok := underlyingIntType(src); ok { + return i.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func (i *Int4) AssignTo(dst interface{}) error { + return int64AssignTo(int64(i.Int), i.Status, dst) +} + +func (i *Int4) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int4{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 32) + if err != nil { + return err + } + + *i = Int4{Int: int32(n), Status: Present} + return nil +} + +func (i *Int4) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int4{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for int4: %v", size) + } + + n, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + *i = Int4{Int: n, Status: Present} + return nil +} + +func (i Int4) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + s := strconv.FormatInt(int64(i.Int), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int4) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, i.Int) + return err +} diff --git a/int4_test.go b/int4_test.go new file mode 100644 index 00000000..04411849 --- /dev/null +++ b/int4_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt4Transcode(t *testing.T) { + testSuccessfulTranscode(t, "int4", []interface{}{ + pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, + pgtype.Int4{Int: -1, Status: pgtype.Present}, + pgtype.Int4{Int: 0, Status: pgtype.Present}, + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: math.MaxInt32, Status: pgtype.Present}, + pgtype.Int4{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt4ConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Int4 + }{ + {source: int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int4 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/int8.go b/int8.go new file mode 100644 index 00000000..c0e14e44 --- /dev/null +++ b/int8.go @@ -0,0 +1,149 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int8 struct { + Int int64 + Status Status +} + +func (i *Int8) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int8: + *i = value + case int8: + *i = Int8{Int: int64(value), Status: Present} + case uint8: + *i = Int8{Int: int64(value), Status: Present} + case int16: + *i = Int8{Int: int64(value), Status: Present} + case uint16: + *i = Int8{Int: int64(value), Status: Present} + case int32: + *i = Int8{Int: int64(value), Status: Present} + case uint32: + *i = Int8{Int: int64(value), Status: Present} + case int64: + *i = Int8{Int: int64(value), Status: Present} + case uint64: + if value > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *i = Int8{Int: int64(value), Status: Present} + case int: + if int64(value) < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + if int64(value) > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *i = Int8{Int: int64(value), Status: Present} + case uint: + if uint64(value) > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *i = Int8{Int: int64(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + *i = Int8{Int: num, Status: Present} + default: + if originalSrc, ok := underlyingIntType(src); ok { + return i.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func (i *Int8) AssignTo(dst interface{}) error { + return int64AssignTo(int64(i.Int), i.Status, dst) +} + +func (i *Int8) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int8{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 64) + if err != nil { + return err + } + + *i = Int8{Int: n, Status: Present} + return nil +} + +func (i *Int8) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int8{Status: Null} + return nil + } + + if size != 8 { + return fmt.Errorf("invalid length for int8: %v", size) + } + + n, err := pgio.ReadInt64(r) + if err != nil { + return err + } + + *i = Int8{Int: n, Status: Present} + return nil +} + +func (i Int8) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + s := strconv.FormatInt(i.Int, 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int8) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 8) + if err != nil { + return err + } + + _, err = pgio.WriteInt64(w, i.Int) + return err +} diff --git a/int8_test.go b/int8_test.go new file mode 100644 index 00000000..ba246224 --- /dev/null +++ b/int8_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt8Transcode(t *testing.T) { + testSuccessfulTranscode(t, "int8", []interface{}{ + pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, + pgtype.Int8{Int: -1, Status: pgtype.Present}, + pgtype.Int8{Int: 0, Status: pgtype.Present}, + pgtype.Int8{Int: 1, Status: pgtype.Present}, + pgtype.Int8{Int: math.MaxInt64, Status: pgtype.Present}, + pgtype.Int8{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt8ConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Int8 + }{ + {source: int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int8 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype.go b/pgtype.go new file mode 100644 index 00000000..f9833363 --- /dev/null +++ b/pgtype.go @@ -0,0 +1,102 @@ +package pgtype + +import ( + "errors" + "io" + + "github.com/jackc/pgx/pgio" +) + +// PostgreSQL oids for common types +const ( + BoolOID = 16 + ByteaOID = 17 + CharOID = 18 + NameOID = 19 + Int8OID = 20 + Int2OID = 21 + Int4OID = 23 + TextOID = 25 + OIDOID = 26 + TidOID = 27 + XidOID = 28 + CidOID = 29 + JSONOID = 114 + CidrOID = 650 + CidrArrayOID = 651 + Float4OID = 700 + Float8OID = 701 + UnknownOID = 705 + InetOID = 869 + BoolArrayOID = 1000 + Int2ArrayOID = 1005 + Int4ArrayOID = 1007 + TextArrayOID = 1009 + ByteaArrayOID = 1001 + VarcharArrayOID = 1015 + Int8ArrayOID = 1016 + Float4ArrayOID = 1021 + Float8ArrayOID = 1022 + AclItemOID = 1033 + AclItemArrayOID = 1034 + InetArrayOID = 1041 + VarcharOID = 1043 + DateOID = 1082 + TimestampOID = 1114 + TimestampArrayOID = 1115 + TimestampTzOID = 1184 + TimestampTzArrayOID = 1185 + RecordOID = 2249 + UUIDOID = 2950 + JSONBOID = 3802 +) + +type Status byte + +const ( + Undefined Status = iota + Null + Present +) + +type InfinityModifier int8 + +const ( + Infinity InfinityModifier = 1 + None InfinityModifier = 0 + NegativeInfinity InfinityModifier = -Infinity +) + +type Value interface { + ConvertFrom(src interface{}) error + AssignTo(dst interface{}) error +} + +type BinaryDecoder interface { + DecodeBinary(r io.Reader) error +} + +type TextDecoder interface { + DecodeText(r io.Reader) error +} + +type BinaryEncoder interface { + EncodeBinary(w io.Writer) error +} + +type TextEncoder interface { + EncodeText(w io.Writer) error +} + +var errUndefined = errors.New("cannot encode status undefined") + +func encodeNotPresent(w io.Writer, status Status) (done bool, err error) { + switch status { + case Undefined: + return true, errUndefined + case Null: + _, err = pgio.WriteInt32(w, -1) + return true, err + } + return false, nil +} diff --git a/pgtype_test.go b/pgtype_test.go new file mode 100644 index 00000000..a1a575f7 --- /dev/null +++ b/pgtype_test.go @@ -0,0 +1,108 @@ +package pgtype_test + +import ( + "fmt" + "io" + "os" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func mustConnectPgx(t testing.TB) *pgx.Conn { + config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + conn, err := pgx.Connect(config) + if err != nil { + t.Fatal(err) + } + + return conn +} + +func mustClose(t testing.TB, conn interface { + Close() error +}) { + err := conn.Close() + if err != nil { + t.Fatal(err) + } +} + +type forceTextEncoder struct { + e pgtype.TextEncoder +} + +func (f forceTextEncoder) EncodeText(w io.Writer) error { + return f.e.EncodeText(w) +} + +type forceBinaryEncoder struct { + e pgtype.BinaryEncoder +} + +func (f forceBinaryEncoder) EncodeBinary(w io.Writer) error { + return f.e.EncodeBinary(w) +} + +func forceEncoder(e interface{}, formatCode int16) interface{} { + switch formatCode { + case pgx.TextFormatCode: + return forceTextEncoder{e: e.(pgtype.TextEncoder)} + case pgx.BinaryFormatCode: + return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} + default: + panic("bad encoder") + } +} + +func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { + testSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} diff --git a/text_element.go b/text_element.go new file mode 100644 index 00000000..1a585d08 --- /dev/null +++ b/text_element.go @@ -0,0 +1,112 @@ +package pgtype + +import ( + "bytes" + "errors" + "io" + + "github.com/jackc/pgx/pgio" +) + +// TextElementWriter is a wrapper that makes TextEncoders composable into other +// TextEncoders. TextEncoder first writes the length of the subsequent value. +// This is not necessary when the value is part of another value such as an +// array. TextElementWriter requires one int32 to be written first which it +// ignores. No other integer writes are valid. +type TextElementWriter struct { + w io.Writer + lengthHeaderIgnored bool +} + +func NewTextElementWriter(w io.Writer) *TextElementWriter { + return &TextElementWriter{w: w} +} + +func (w *TextElementWriter) WriteUint16(n uint16) (int, error) { + return 0, errors.New("WriteUint16 should never be called on TextElementWriter") +} + +func (w *TextElementWriter) WriteUint32(n uint32) (int, error) { + if !w.lengthHeaderIgnored { + w.lengthHeaderIgnored = true + + if int32(n) == -1 { + return io.WriteString(w.w, "NULL") + } + + return 4, nil + } + + return 0, errors.New("WriteUint32 should only be called once on TextElementWriter") +} + +func (w *TextElementWriter) WriteUint64(n uint64) (int, error) { + if w.lengthHeaderIgnored { + return pgio.WriteUint64(w.w, n) + } + + return 0, errors.New("WriteUint64 should never be called on TextElementWriter") +} + +func (w *TextElementWriter) Write(buf []byte) (int, error) { + if w.lengthHeaderIgnored { + return w.w.Write(buf) + } + + return 0, errors.New("int32 must be written first") +} + +func (w *TextElementWriter) Reset() { + w.lengthHeaderIgnored = false +} + +// TextElementReader is a wrapper that makes TextDecoders composable into other +// TextDecoders. TextEncoders first read the length of the subsequent value. +// This length value is not present when the value is part of another value such +// as an array. TextElementReader provides a substitute length value from the +// length of the string. No other integer reads are valid. Each time DecodeText +// is called with a TextElementReader as the source the TextElementReader must +// first have Reset called with the new element string data. +type TextElementReader struct { + buf *bytes.Buffer + lengthHeaderIgnored bool +} + +func NewTextElementReader(r io.Reader) *TextElementReader { + return &TextElementReader{buf: &bytes.Buffer{}} +} + +func (r *TextElementReader) ReadUint16() (uint16, error) { + return 0, errors.New("ReadUint16 should never be called on TextElementReader") +} + +func (r *TextElementReader) ReadUint32() (uint32, error) { + if !r.lengthHeaderIgnored { + r.lengthHeaderIgnored = true + if r.buf.String() == "NULL" { + n32 := int32(-1) + return uint32(n32), nil + } + return uint32(r.buf.Len()), nil + } + + return 0, errors.New("ReadUint32 should only be called once on TextElementReader") +} + +func (r *TextElementReader) WriteUint64(n uint64) (int, error) { + return 0, errors.New("ReadUint64 should never be called on TextElementReader") +} + +func (r *TextElementReader) Read(buf []byte) (int, error) { + if r.lengthHeaderIgnored { + return r.buf.Read(buf) + } + + return 0, errors.New("int32 must be read first") +} + +func (r *TextElementReader) Reset(s string) { + r.lengthHeaderIgnored = false + r.buf.Reset() + r.buf.WriteString(s) +} diff --git a/timestamptz.go b/timestamptz.go new file mode 100644 index 00000000..cc33b296 --- /dev/null +++ b/timestamptz.go @@ -0,0 +1,203 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "time" + + "github.com/jackc/pgx/pgio" +) + +const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" +const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" +const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" +const microsecFromUnixEpochToY2K = 946684800 * 1000000 + +const ( + negativeInfinityMicrosecondOffset = -9223372036854775808 + infinityMicrosecondOffset = 9223372036854775807 +) + +type Timestamptz struct { + Time time.Time + Status Status + InfinityModifier +} + +func (t *Timestamptz) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Timestamptz: + *t = value + case time.Time: + *t = Timestamptz{Time: value, Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return t.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamptz", value) + } + + return nil +} + +func (t *Timestamptz) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *time.Time: + if t.Status != Present || t.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", t, dst) + } + *v = t.Time + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if t.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return t.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot assign %v into %T", t, dst) + } + + return nil +} + +func (t *Timestamptz) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *t = Timestamptz{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + sbuf := string(buf) + switch sbuf { + case "infinity": + *t = Timestamptz{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *t = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + var format string + if sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+' { + format = pgTimestamptzSecondFormat + } else if sbuf[len(sbuf)-6] == '-' || sbuf[len(sbuf)-6] == '+' { + format = pgTimestamptzMinuteFormat + } else { + format = pgTimestamptzHourFormat + } + + tim, err := time.Parse(format, sbuf) + if err != nil { + return err + } + + *t = Timestamptz{Time: tim, Status: Present} + } + + return nil +} + +func (t *Timestamptz) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *t = Timestamptz{Status: Null} + return nil + } + + if size != 8 { + return fmt.Errorf("invalid length for timestamptz: %v", size) + } + + microsecSinceY2K, err := pgio.ReadInt64(r) + if err != nil { + return err + } + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + *t = Timestamptz{Status: Present, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + *t = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + *t = Timestamptz{Time: tim, Status: Present} + } + + return nil +} + +func (t Timestamptz) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, t.Status); done { + return err + } + + var s string + + switch t.InfinityModifier { + case None: + s = t.Time.UTC().Format(pgTimestamptzSecondFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + + _, err = w.Write([]byte(s)) + return err +} + +func (t Timestamptz) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, t.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 8) + if err != nil { + return err + } + + var microsecSinceY2K int64 + switch t.InfinityModifier { + case None: + microsecSinceUnixEpoch := t.Time.Unix()*1000000 + int64(t.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + _, err = pgio.WriteInt64(w, microsecSinceY2K) + return err +} diff --git a/timestamptz_test.go b/timestamptz_test.go new file mode 100644 index 00000000..795195f8 --- /dev/null +++ b/timestamptz_test.go @@ -0,0 +1,60 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTimestamptzTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ + pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Status: pgtype.Null}, + pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Timestamptz) + bt := b.(pgtype.Timestamptz) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + }) +} + +func TestTimestamptzConvertFrom(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Timestamptz + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Timestamptz + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} From 579b6cd612600fb7a72d93cbe917dee794f3a2c7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 15:33:34 -0600 Subject: [PATCH 0003/1158] Initial proof-of-concept for pgtype Squashed commit of the following: commit c19454582b335ce5bdda6320f7e4e8c76cfeaf44 Author: Jack Christensen Date: Fri Mar 3 15:24:47 2017 -0600 Add AssignTo to pgtype.Timestamptz Also handle infinity for pgtype.Date commit 7329933610b38f4bc15731b1f7c55c520b49e300 Author: Jack Christensen Date: Fri Mar 3 15:12:18 2017 -0600 Implement AssignTo for most pgtypes commit cc3d1e4af896d34ec98c3bf2e982d0367451f21c Author: Jack Christensen Date: Thu Mar 2 21:19:07 2017 -0600 Use pgtype.Int2Array in pgx commit 36da5cc2178d1a31a56dc6e6f128843bd80dea0b Author: Jack Christensen Date: Tue Feb 28 21:45:33 2017 -0600 Add text array transcoding commit 1b0f18d99f38b69f8c2db26388815e67b2b03d59 Author: Jack Christensen Date: Mon Feb 27 19:28:55 2017 -0600 Add ParseUntypedTextArray commit 0f50ce3e833fc38495d333228daf04f5142be676 Author: Jack Christensen Date: Mon Feb 27 18:54:20 2017 -0600 wip commit d934f273627d79997035c282416db922f2fbe87a Author: Jack Christensen Date: Sun Feb 26 17:14:32 2017 -0600 WIP - beginning text format array parsing commit 7276ad33ce7fa9c250745a3ed909998f3dae4a32 Author: Jack Christensen Date: Sat Feb 25 22:50:11 2017 -0600 Beginning binary arrays commit 917faa5a3175d376222423c10aca297a20f96448 Author: Jack Christensen Date: Sat Feb 25 19:36:35 2017 -0600 Fix incomplete tests commit de8c140cfb98b7b047d53c5718ccbf12eaf813a1 Author: Jack Christensen Date: Sat Feb 25 19:32:22 2017 -0600 Add timestamptz null and infinity commit 7d9f954de4e071a1eccac762248079b90dbeb53f Author: Jack Christensen Date: Sat Feb 25 18:19:38 2017 -0600 Add infinity to pgtype.Date commit 7bf783ae20ba05571c2fb9f661183233c95eab41 Author: Jack Christensen Date: Sat Feb 25 17:19:55 2017 -0600 Add Status to pgtype.Date commit 984500455c9b9a4b6221758540d248e6410d93a4 Author: Jack Christensen Date: Sat Feb 25 16:54:01 2017 -0600 Add status to Int4 and Int8 commit 6fe76fcfc2de31552790db3b093480a9d5b2a742 Author: Jack Christensen Date: Sat Feb 25 16:40:27 2017 -0600 Extract testSuccessfulTranscode commit 001647c1da03f796014cf21f41c9a7fd2cfadfde Author: Jack Christensen Date: Sat Feb 25 16:15:51 2017 -0600 Add Status to pgtype.Int2 commit 720451f06d13d9c9fa2a0482e010f24bf4627c2a Author: Jack Christensen Date: Sat Feb 25 15:56:44 2017 -0600 Add status to pgtype.Bool commit 325f700b6edff215a692b10bc5b94cdfe1100769 Author: Jack Christensen Date: Fri Feb 24 17:28:15 2017 -0600 Add date to conversion system commit 4a9343e45d3897f59eab98a0009d2ddbe07e02d7 Author: Jack Christensen Date: Fri Feb 24 16:28:35 2017 -0600 Add bool to oid based encoding commit d984fcafab1476cf84852485b6711f4b2069eb6d Author: Jack Christensen Date: Fri Feb 24 16:15:38 2017 -0600 Add pgtype interfaces commit 0f93bfc2de4023b069b966c0988bf7f0469d1809 Author: Jack Christensen Date: Fri Feb 24 14:48:34 2017 -0600 Begin introduction of Convert commit e5707023cac7c07342b8c910e480d09a1caaf6ee Author: Jack Christensen Date: Fri Feb 24 14:10:56 2017 -0600 Move bool to pgtype commit bb764d2129efe7fb21e841dbb35e6d0dc7586d37 Author: Jack Christensen Date: Fri Feb 24 13:45:05 2017 -0600 Add Int2 test commit 08c49437f455a32f7c3f0a524cd21a895d440301 Author: Jack Christensen Date: Fri Feb 24 13:44:09 2017 -0600 Add Int4 test commit 16722952222fd15c53c8fa84974645504a6d0dc0 Author: Jack Christensen Date: Fri Feb 24 08:56:59 2017 -0600 Add int8 tests commit 83a5447cd2c46b58d0880023cc4e9af0c84988a2 Author: Jack Christensen Date: Wed Feb 22 18:08:05 2017 -0600 wip commit 0ca0ee72068a72b016729b01fccef22474595285 Author: Jack Christensen Date: Mon Feb 20 18:56:52 2017 -0600 wip commit d2c2baf4ea2cd0793d68c7094c425217df952bec Author: Jack Christensen Date: Mon Feb 20 18:46:10 2017 -0600 wip commit f78371da0098356527b193fd496a338da5fe414b Author: Jack Christensen Date: Mon Feb 20 17:43:39 2017 -0600 wip commit 3366699bea62ec0110db05f3cb2998d58ac9ce5d Author: Jack Christensen Date: Mon Feb 20 14:07:47 2017 -0600 wip commit 66b79e940870fd0133ebb10ac1547e1d4d7d0b51 Author: Jack Christensen Date: Mon Feb 20 13:35:37 2017 -0600 Extract pgio commit 8b07d97d1305ed98fd76db6e306a289c0af92d56 Author: Jack Christensen Date: Mon Feb 20 13:20:00 2017 -0600 wip commit 62f1adb3427f4317b708da075dce50c4d4daff7b Author: Jack Christensen Date: Mon Feb 20 12:08:46 2017 -0600 wip commit a712d2546933a5a8433c65eef0ff2ee135077c87 Author: Jack Christensen Date: Mon Feb 20 09:30:52 2017 -0600 wip commit 4faf97cc588126dda160fc360680719572a23105 Author: Jack Christensen Date: Fri Feb 17 22:20:18 2017 -0600 wip --- doc.go | 8 +++++ read.go | 104 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ write.go | 97 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 209 insertions(+) create mode 100644 doc.go create mode 100644 read.go create mode 100644 write.go diff --git a/doc.go b/doc.go new file mode 100644 index 00000000..36233a47 --- /dev/null +++ b/doc.go @@ -0,0 +1,8 @@ +// Package pgio a extremely low-level IO toolkit for the PostgreSQL wire protocol. +/* +pgio provides functions for reading and writing integers from io.Reader and +io.Writer while doing byte order conversion. It publishes interfaces which +readers and writers may implement to decode and encode messages with the minimum +of memory allocations. +*/ +package pgio diff --git a/read.go b/read.go new file mode 100644 index 00000000..7c39162c --- /dev/null +++ b/read.go @@ -0,0 +1,104 @@ +package pgio + +import ( + "encoding/binary" + "io" +) + +type Uint16Reader interface { + ReadUint16() (n uint16, err error) +} + +type Uint32Reader interface { + ReadUint32() (n uint32, err error) +} + +type Uint64Reader interface { + ReadUint64() (n uint64, err error) +} + +// ReadByte reads a byte from r. +func ReadByte(r io.Reader) (byte, error) { + if r, ok := r.(io.ByteReader); ok { + return r.ReadByte() + } + + buf := make([]byte, 1) + _, err := r.Read(buf) + return buf[0], err +} + +// ReadUint16 reads an uint16 from r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint16 +// method. +func ReadUint16(r io.Reader) (uint16, error) { + if r, ok := r.(Uint16Reader); ok { + return r.ReadUint16() + } + + buf := make([]byte, 2) + _, err := io.ReadFull(r, buf) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint16(buf), nil +} + +// ReadInt16 reads an int16 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint16 +// method. +func ReadInt16(r io.Reader) (int16, error) { + n, err := ReadUint16(r) + return int16(n), err +} + +// ReadUint32 reads an uint32 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint32 +// method. +func ReadUint32(r io.Reader) (uint32, error) { + if r, ok := r.(Uint32Reader); ok { + return r.ReadUint32() + } + + buf := make([]byte, 4) + _, err := io.ReadFull(r, buf) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint32(buf), nil +} + +// ReadInt32 reads an int32 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint32 +// method. +func ReadInt32(r io.Reader) (int32, error) { + n, err := ReadUint32(r) + return int32(n), err +} + +// ReadUint64 reads an uint64 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint64 +// method. +func ReadUint64(r io.Reader) (uint64, error) { + if r, ok := r.(Uint64Reader); ok { + return r.ReadUint64() + } + + buf := make([]byte, 8) + _, err := io.ReadFull(r, buf) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint64(buf), nil +} + +// ReadInt64 reads an int64 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint64 +// method. +func ReadInt64(r io.Reader) (int64, error) { + n, err := ReadUint64(r) + return int64(n), err +} diff --git a/write.go b/write.go new file mode 100644 index 00000000..823fbd00 --- /dev/null +++ b/write.go @@ -0,0 +1,97 @@ +package pgio + +import ( + "encoding/binary" + "io" +) + +type Uint16Writer interface { + WriteUint16(uint16) (n int, err error) +} + +type Uint32Writer interface { + WriteUint32(uint32) (n int, err error) +} + +type Uint64Writer interface { + WriteUint64(uint64) (n int, err error) +} + +// WriteByte writes b to w. +func WriteByte(w io.Writer, b byte) error { + if w, ok := w.(io.ByteWriter); ok { + return w.WriteByte(b) + } + _, err := w.Write([]byte{b}) + return err +} + +// WriteUint16 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint16 +// method. +func WriteUint16(w io.Writer, n uint16) (int, error) { + if w, ok := w.(Uint16Writer); ok { + return w.WriteUint16(n) + } + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, n) + return w.Write(b) +} + +// WriteInt16 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint16 +// method. +func WriteInt16(w io.Writer, n int16) (int, error) { + return WriteUint16(w, uint16(n)) +} + +// WriteUint32 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint32 +// method. +func WriteUint32(w io.Writer, n uint32) (int, error) { + if w, ok := w.(Uint32Writer); ok { + return w.WriteUint32(n) + } + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, n) + return w.Write(b) +} + +// WriteInt32 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint32 +// method. +func WriteInt32(w io.Writer, n int32) (int, error) { + return WriteUint32(w, uint32(n)) +} + +// WriteUint64 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint64 +// method. +func WriteUint64(w io.Writer, n uint64) (int, error) { + if w, ok := w.(Uint64Writer); ok { + return w.WriteUint64(n) + } + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, n) + return w.Write(b) +} + +// WriteInt64 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint64 +// method. +func WriteInt64(w io.Writer, n int64) (int, error) { + return WriteUint64(w, uint64(n)) +} + +// WriteCString writes s to w followed by a null byte. +func WriteCString(w io.Writer, s string) (int, error) { + n, err := io.WriteString(w, s) + if err != nil { + return n, err + } + err = WriteByte(w, 0) + if err != nil { + return n, err + } + return n + 1, nil +} From a1e4efe14e77a72a8924a4cc95b9cb6ae6109cc4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 17:15:05 -0600 Subject: [PATCH 0004/1158] Add more tests for pgtype.Bool --- bool.go | 5 +---- bool_test.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/bool.go b/bool.go index 81c72472..14bc2d6e 100644 --- a/bool.go +++ b/bool.go @@ -50,10 +50,7 @@ func (b *Bool) AssignTo(dst interface{}) error { // if dst is a pointer to pointer, strip the pointer and try again case reflect.Ptr: if b.Status == Null { - if !el.IsNil() { - // if the destination pointer is not nil, nil it out - el.Set(reflect.Zero(el.Type())) - } + el.Set(reflect.Zero(el.Type())) return nil } if el.IsNil() { diff --git a/bool_test.go b/bool_test.go index 53df1747..74140b5e 100644 --- a/bool_test.go +++ b/bool_test.go @@ -1,11 +1,14 @@ package pgtype_test import ( + "reflect" "testing" "github.com/jackc/pgx/pgtype" ) +type _bool bool + func TestBoolTranscode(t *testing.T) { testSuccessfulTranscode(t, "bool", []interface{}{ pgtype.Bool{Bool: false, Status: pgtype.Present}, @@ -15,18 +18,19 @@ func TestBoolTranscode(t *testing.T) { } func TestBoolConvertFrom(t *testing.T) { - type _int8 int8 - successfulTests := []struct { source interface{} result pgtype.Bool }{ + {source: pgtype.Bool{Bool: false, Status: pgtype.Null}, result: pgtype.Bool{Bool: false, Status: pgtype.Null}}, {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, {source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, {source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, {source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: _bool(true), result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: _bool(false), result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, } for i, tt := range successfulTests { @@ -41,3 +45,54 @@ func TestBoolConvertFrom(t *testing.T) { } } } + +func TestBoolAssignTo(t *testing.T) { + var b bool + var _b _bool + var pb *bool + var _pb *_bool + + simpleTests := []struct { + src pgtype.Bool + dst interface{} + expected interface{} + }{ + {src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &b, expected: false}, + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &b, expected: true}, + {src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &_b, expected: _bool(false)}, + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_b, expected: _bool(true)}, + {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &pb, expected: ((*bool)(nil))}, + {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &_pb, expected: ((*_bool)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Bool + dst interface{} + expected interface{} + }{ + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &pb, expected: true}, + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_pb, expected: _bool(true)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} From 890708967c12d684cef07a0060e6a0c04df44e9f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 17:35:02 -0600 Subject: [PATCH 0005/1158] Standardize receiver variable name for pgtype Conversion functions now use standardized src and dst depending on their role. --- array.go | 42 +++++++++++++------------- bool.go | 54 ++++++++++++++++----------------- date.go | 58 +++++++++++++++++------------------ int2.go | 56 +++++++++++++++++----------------- int2array.go | 82 +++++++++++++++++++++++++------------------------- int4.go | 56 +++++++++++++++++----------------- int8.go | 56 +++++++++++++++++----------------- timestamptz.go | 58 +++++++++++++++++------------------ 8 files changed, 231 insertions(+), 231 deletions(-) diff --git a/array.go b/array.go index 75d2e440..76492c61 100644 --- a/array.go +++ b/array.go @@ -25,34 +25,34 @@ type ArrayDimension struct { LowerBound int32 } -func (ah *ArrayHeader) DecodeBinary(r io.Reader) error { +func (dst *ArrayHeader) DecodeBinary(r io.Reader) error { numDims, err := pgio.ReadInt32(r) if err != nil { return err } if numDims > 0 { - ah.Dimensions = make([]ArrayDimension, numDims) + dst.Dimensions = make([]ArrayDimension, numDims) } containsNull, err := pgio.ReadInt32(r) if err != nil { return err } - ah.ContainsNull = containsNull == 1 + dst.ContainsNull = containsNull == 1 - ah.ElementOID, err = pgio.ReadInt32(r) + dst.ElementOID, err = pgio.ReadInt32(r) if err != nil { return err } - for i := range ah.Dimensions { - ah.Dimensions[i].Length, err = pgio.ReadInt32(r) + for i := range dst.Dimensions { + dst.Dimensions[i].Length, err = pgio.ReadInt32(r) if err != nil { return err } - ah.Dimensions[i].LowerBound, err = pgio.ReadInt32(r) + dst.Dimensions[i].LowerBound, err = pgio.ReadInt32(r) if err != nil { return err } @@ -61,14 +61,14 @@ func (ah *ArrayHeader) DecodeBinary(r io.Reader) error { return nil } -func (ah *ArrayHeader) EncodeBinary(w io.Writer) error { - _, err := pgio.WriteInt32(w, int32(len(ah.Dimensions))) +func (src *ArrayHeader) EncodeBinary(w io.Writer) error { + _, err := pgio.WriteInt32(w, int32(len(src.Dimensions))) if err != nil { return err } var containsNull int32 - if ah.ContainsNull { + if src.ContainsNull { containsNull = 1 } _, err = pgio.WriteInt32(w, containsNull) @@ -76,18 +76,18 @@ func (ah *ArrayHeader) EncodeBinary(w io.Writer) error { return err } - _, err = pgio.WriteInt32(w, ah.ElementOID) + _, err = pgio.WriteInt32(w, src.ElementOID) if err != nil { return err } - for i := range ah.Dimensions { - _, err = pgio.WriteInt32(w, ah.Dimensions[i].Length) + for i := range src.Dimensions { + _, err = pgio.WriteInt32(w, src.Dimensions[i].Length) if err != nil { return err } - _, err = pgio.WriteInt32(w, ah.Dimensions[i].LowerBound) + _, err = pgio.WriteInt32(w, src.Dimensions[i].LowerBound) if err != nil { return err } @@ -102,7 +102,7 @@ type UntypedTextArray struct { } func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { - uta := &UntypedTextArray{} + dst := &UntypedTextArray{} buf := bytes.NewBufferString(src) @@ -219,7 +219,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { if currentDim == counterDim { implicitDimensions[currentDim].Length++ } - uta.Elements = append(uta.Elements, value) + dst.Elements = append(dst.Elements, value) } if currentDim < 0 { @@ -233,15 +233,15 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) } - if len(uta.Elements) == 0 { - uta.Dimensions = nil + if len(dst.Elements) == 0 { + dst.Dimensions = nil } else if len(explicitDimensions) > 0 { - uta.Dimensions = explicitDimensions + dst.Dimensions = explicitDimensions } else { - uta.Dimensions = implicitDimensions + dst.Dimensions = implicitDimensions } - return uta, nil + return dst, nil } func skipWhitespace(buf *bytes.Buffer) { diff --git a/bool.go b/bool.go index 14bc2d6e..2889b787 100644 --- a/bool.go +++ b/bool.go @@ -14,21 +14,21 @@ type Bool struct { Status Status } -func (b *Bool) ConvertFrom(src interface{}) error { +func (dst *Bool) ConvertFrom(src interface{}) error { switch value := src.(type) { case Bool: - *b = value + *dst = value case bool: - *b = Bool{Bool: value, Status: Present} + *dst = Bool{Bool: value, Status: Present} case string: bb, err := strconv.ParseBool(value) if err != nil { return err } - *b = Bool{Bool: bb, Status: Present} + *dst = Bool{Bool: bb, Status: Present} default: if originalSrc, ok := underlyingBoolType(src); ok { - return b.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Bool", value) } @@ -36,20 +36,20 @@ func (b *Bool) ConvertFrom(src interface{}) error { return nil } -func (b *Bool) AssignTo(dst interface{}) error { +func (src *Bool) AssignTo(dst interface{}) error { switch v := dst.(type) { case *bool: - if b.Status != Present { - return fmt.Errorf("cannot assign %v to %T", b, dst) + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) } - *v = b.Bool + *v = src.Bool default: if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { el := v.Elem() switch el.Kind() { // if dst is a pointer to pointer, strip the pointer and try again case reflect.Ptr: - if b.Status == Null { + if src.Status == Null { el.Set(reflect.Zero(el.Type())) return nil } @@ -57,29 +57,29 @@ func (b *Bool) AssignTo(dst interface{}) error { // allocate destination el.Set(reflect.New(el.Type().Elem())) } - return b.AssignTo(el.Interface()) + return src.AssignTo(el.Interface()) case reflect.Bool: - if b.Status != Present { - return fmt.Errorf("cannot assign %v to %T", b, dst) + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) } - el.SetBool(b.Bool) + el.SetBool(src.Bool) return nil } } - return fmt.Errorf("cannot put decode %v into %T", b, dst) + return fmt.Errorf("cannot put decode %v into %T", src, dst) } return nil } -func (b *Bool) DecodeText(r io.Reader) error { +func (dst *Bool) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *b = Bool{Status: Null} + *dst = Bool{Status: Null} return nil } @@ -92,18 +92,18 @@ func (b *Bool) DecodeText(r io.Reader) error { return err } - *b = Bool{Bool: byt == 't', Status: Present} + *dst = Bool{Bool: byt == 't', Status: Present} return nil } -func (b *Bool) DecodeBinary(r io.Reader) error { +func (dst *Bool) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *b = Bool{Status: Null} + *dst = Bool{Status: Null} return nil } @@ -116,12 +116,12 @@ func (b *Bool) DecodeBinary(r io.Reader) error { return err } - *b = Bool{Bool: byt == 1, Status: Present} + *dst = Bool{Bool: byt == 1, Status: Present} return nil } -func (b Bool) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, b.Status); done { +func (src Bool) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -131,7 +131,7 @@ func (b Bool) EncodeText(w io.Writer) error { } var buf []byte - if b.Bool { + if src.Bool { buf = []byte{'t'} } else { buf = []byte{'f'} @@ -141,8 +141,8 @@ func (b Bool) EncodeText(w io.Writer) error { return err } -func (b Bool) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, b.Status); done { +func (src Bool) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -152,7 +152,7 @@ func (b Bool) EncodeBinary(w io.Writer) error { } var buf []byte - if b.Bool { + if src.Bool { buf = []byte{1} } else { buf = []byte{0} diff --git a/date.go b/date.go index f3e3e4c6..6cd8e499 100644 --- a/date.go +++ b/date.go @@ -20,15 +20,15 @@ const ( infinityDayOffset = 2147483647 ) -func (d *Date) ConvertFrom(src interface{}) error { +func (dst *Date) ConvertFrom(src interface{}) error { switch value := src.(type) { case Date: - *d = value + *dst = value case time.Time: - *d = Date{Time: value, Status: Present} + *dst = Date{Time: value, Status: Present} default: if originalSrc, ok := underlyingTimeType(src); ok { - return d.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Date", value) } @@ -36,20 +36,20 @@ func (d *Date) ConvertFrom(src interface{}) error { return nil } -func (d *Date) AssignTo(dst interface{}) error { +func (src *Date) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: - if d.Status != Present || d.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", d, dst) + if src.Status != Present || src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) } - *v = d.Time + *v = src.Time default: if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { el := v.Elem() switch el.Kind() { // if dst is a pointer to pointer, strip the pointer and try again case reflect.Ptr: - if d.Status == Null { + if src.Status == Null { if !el.IsNil() { // if the destination pointer is not nil, nil it out el.Set(reflect.Zero(el.Type())) @@ -60,23 +60,23 @@ func (d *Date) AssignTo(dst interface{}) error { // allocate destination el.Set(reflect.New(el.Type().Elem())) } - return d.AssignTo(el.Interface()) + return src.AssignTo(el.Interface()) } } - return fmt.Errorf("cannot decode %v into %T", d, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil } -func (d *Date) DecodeText(r io.Reader) error { +func (dst *Date) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *d = Date{Status: Null} + *dst = Date{Status: Null} return nil } @@ -89,29 +89,29 @@ func (d *Date) DecodeText(r io.Reader) error { sbuf := string(buf) switch sbuf { case "infinity": - *d = Date{Status: Present, InfinityModifier: Infinity} + *dst = Date{Status: Present, InfinityModifier: Infinity} case "-infinity": - *d = Date{Status: Present, InfinityModifier: -Infinity} + *dst = Date{Status: Present, InfinityModifier: -Infinity} default: t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) if err != nil { return err } - *d = Date{Time: t, Status: Present} + *dst = Date{Time: t, Status: Present} } return nil } -func (d *Date) DecodeBinary(r io.Reader) error { +func (dst *Date) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *d = Date{Status: Null} + *dst = Date{Status: Null} return nil } @@ -126,27 +126,27 @@ func (d *Date) DecodeBinary(r io.Reader) error { switch dayOffset { case infinityDayOffset: - *d = Date{Status: Present, InfinityModifier: Infinity} + *dst = Date{Status: Present, InfinityModifier: Infinity} case negativeInfinityDayOffset: - *d = Date{Status: Present, InfinityModifier: -Infinity} + *dst = Date{Status: Present, InfinityModifier: -Infinity} default: t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) - *d = Date{Time: t, Status: Present} + *dst = Date{Time: t, Status: Present} } return nil } -func (d Date) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, d.Status); done { +func (src Date) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } var s string - switch d.InfinityModifier { + switch src.InfinityModifier { case None: - s = d.Time.Format("2006-01-02") + s = src.Time.Format("2006-01-02") case Infinity: s = "infinity" case NegativeInfinity: @@ -162,8 +162,8 @@ func (d Date) EncodeText(w io.Writer) error { return err } -func (d Date) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, d.Status); done { +func (src Date) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -173,9 +173,9 @@ func (d Date) EncodeBinary(w io.Writer) error { } var daysSinceDateEpoch int32 - switch d.InfinityModifier { + switch src.InfinityModifier { case None: - tUnix := time.Date(d.Time.Year(), d.Time.Month(), d.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() + tUnix := time.Date(src.Time.Year(), src.Time.Month(), src.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() secSinceDateEpoch := tUnix - dateEpoch diff --git a/int2.go b/int2.go index 2da8a96d..fb6a8ccc 100644 --- a/int2.go +++ b/int2.go @@ -14,21 +14,21 @@ type Int2 struct { Status Status } -func (i *Int2) ConvertFrom(src interface{}) error { +func (dst *Int2) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int2: - *i = value + *dst = value case int8: - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case uint8: - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case int16: - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case uint16: if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case int32: if value < math.MinInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) @@ -36,12 +36,12 @@ func (i *Int2) ConvertFrom(src interface{}) error { if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case uint32: if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case int64: if value < math.MinInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) @@ -49,12 +49,12 @@ func (i *Int2) ConvertFrom(src interface{}) error { if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case uint64: if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case int: if value < math.MinInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) @@ -62,21 +62,21 @@ func (i *Int2) ConvertFrom(src interface{}) error { if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case uint: if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case string: num, err := strconv.ParseInt(value, 10, 16) if err != nil { return err } - *i = Int2{Int: int16(num), Status: Present} + *dst = Int2{Int: int16(num), Status: Present} default: if originalSrc, ok := underlyingIntType(src); ok { - return i.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int2", value) } @@ -84,18 +84,18 @@ func (i *Int2) ConvertFrom(src interface{}) error { return nil } -func (i *Int2) AssignTo(dst interface{}) error { - return int64AssignTo(int64(i.Int), i.Status, dst) +func (src *Int2) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) } -func (i *Int2) DecodeText(r io.Reader) error { +func (dst *Int2) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *i = Int2{Status: Null} + *dst = Int2{Status: Null} return nil } @@ -110,18 +110,18 @@ func (i *Int2) DecodeText(r io.Reader) error { return err } - *i = Int2{Int: int16(n), Status: Present} + *dst = Int2{Int: int16(n), Status: Present} return nil } -func (i *Int2) DecodeBinary(r io.Reader) error { +func (dst *Int2) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *i = Int2{Status: Null} + *dst = Int2{Status: Null} return nil } @@ -134,16 +134,16 @@ func (i *Int2) DecodeBinary(r io.Reader) error { return err } - *i = Int2{Int: int16(n), Status: Present} + *dst = Int2{Int: int16(n), Status: Present} return nil } -func (i Int2) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, i.Status); done { +func (src Int2) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } - s := strconv.FormatInt(int64(i.Int), 10) + s := strconv.FormatInt(int64(src.Int), 10) _, err := pgio.WriteInt32(w, int32(len(s))) if err != nil { return nil @@ -152,8 +152,8 @@ func (i Int2) EncodeText(w io.Writer) error { return err } -func (i Int2) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, i.Status); done { +func (src Int2) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -162,6 +162,6 @@ func (i Int2) EncodeBinary(w io.Writer) error { return err } - _, err = pgio.WriteInt16(w, i.Int) + _, err = pgio.WriteInt16(w, src.Int) return err } diff --git a/int2array.go b/int2array.go index 86375516..4ac0c409 100644 --- a/int2array.go +++ b/int2array.go @@ -14,15 +14,15 @@ type Int2Array struct { Status Status } -func (a *Int2Array) ConvertFrom(src interface{}) error { +func (dst *Int2Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int2Array: - *a = value + *dst = value case []int16: if value == nil { - *a = Int2Array{Status: Null} + *dst = Int2Array{Status: Null} } else if len(value) == 0 { - *a = Int2Array{Status: Present} + *dst = Int2Array{Status: Present} } else { elements := make([]Int2, len(value)) for i := range value { @@ -30,7 +30,7 @@ func (a *Int2Array) ConvertFrom(src interface{}) error { return err } } - *a = Int2Array{ + *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, Status: Present, @@ -38,9 +38,9 @@ func (a *Int2Array) ConvertFrom(src interface{}) error { } case []uint16: if value == nil { - *a = Int2Array{Status: Null} + *dst = Int2Array{Status: Null} } else if len(value) == 0 { - *a = Int2Array{Status: Present} + *dst = Int2Array{Status: Present} } else { elements := make([]Int2, len(value)) for i := range value { @@ -48,7 +48,7 @@ func (a *Int2Array) ConvertFrom(src interface{}) error { return err } } - *a = Int2Array{ + *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, Status: Present, @@ -56,7 +56,7 @@ func (a *Int2Array) ConvertFrom(src interface{}) error { } default: if originalSrc, ok := underlyingSliceType(src); ok { - return a.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int2", value) } @@ -64,13 +64,13 @@ func (a *Int2Array) ConvertFrom(src interface{}) error { return nil } -func (a *Int2Array) AssignTo(dst interface{}) error { +func (src *Int2Array) AssignTo(dst interface{}) error { switch v := dst.(type) { case *[]int16: - if a.Status == Present { - *v = make([]int16, len(a.Elements)) - for i := range a.Elements { - if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + if src.Status == Present { + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } @@ -78,10 +78,10 @@ func (a *Int2Array) AssignTo(dst interface{}) error { *v = nil } case *[]uint16: - if a.Status == Present { - *v = make([]uint16, len(a.Elements)) - for i := range a.Elements { - if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + if src.Status == Present { + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } @@ -89,20 +89,20 @@ func (a *Int2Array) AssignTo(dst interface{}) error { *v = nil } default: - return fmt.Errorf("cannot put decode %v into %T", a, dst) + return fmt.Errorf("cannot put decode %v into %T", src, dst) } return nil } -func (a *Int2Array) DecodeText(r io.Reader) error { +func (dst *Int2Array) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *a = Int2Array{Status: Null} + *dst = Int2Array{Status: Null} return nil } @@ -135,19 +135,19 @@ func (a *Int2Array) DecodeText(r io.Reader) error { } } - *a = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} return nil } -func (a *Int2Array) DecodeBinary(r io.Reader) error { +func (dst *Int2Array) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *a = Int2Array{Status: Null} + *dst = Int2Array{Status: Null} return nil } @@ -158,7 +158,7 @@ func (a *Int2Array) DecodeBinary(r io.Reader) error { } if len(arrayHeader.Dimensions) == 0 { - *a = Int2Array{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Int2Array{Dimensions: arrayHeader.Dimensions, Status: Present} return nil } @@ -176,16 +176,16 @@ func (a *Int2Array) DecodeBinary(r io.Reader) error { } } - *a = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} return nil } -func (a *Int2Array) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, a.Status); done { +func (src *Int2Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } - if len(a.Dimensions) == 0 { + if len(src.Dimensions) == 0 { _, err := pgio.WriteInt32(w, 2) if err != nil { return err @@ -197,7 +197,7 @@ func (a *Int2Array) EncodeText(w io.Writer) error { buf := &bytes.Buffer{} - err := EncodeTextArrayDimensions(buf, a.Dimensions) + err := EncodeTextArrayDimensions(buf, src.Dimensions) if err != nil { return err } @@ -207,15 +207,15 @@ func (a *Int2Array) EncodeText(w io.Writer) error { // [4]. A multi-dimensional array of lengths [3,5,2] would have a // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' // or '}'. - dimElemCounts := make([]int, len(a.Dimensions)) - dimElemCounts[len(a.Dimensions)-1] = int(a.Dimensions[len(a.Dimensions)-1].Length) - for i := len(a.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(a.Dimensions[i].Length) * dimElemCounts[i+1] + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } textElementWriter := NewTextElementWriter(buf) - for i, elem := range a.Elements { + for i, elem := range src.Elements { if i > 0 { err = pgio.WriteByte(buf, ',') if err != nil { @@ -257,8 +257,8 @@ func (a *Int2Array) EncodeText(w io.Writer) error { return err } -func (a *Int2Array) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, a.Status); done { +func (src *Int2Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -268,18 +268,18 @@ func (a *Int2Array) EncodeBinary(w io.Writer) error { // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} - for i := range a.Elements { - err := a.Elements[i].EncodeBinary(elemBuf) + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { return err } - if a.Elements[i].Status == Null { + if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true } } arrayHeader.ElementOID = Int2OID - arrayHeader.Dimensions = a.Dimensions + arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - // or how not pay allocations for the byte order conversions. diff --git a/int4.go b/int4.go index 84c45522..1a4733b0 100644 --- a/int4.go +++ b/int4.go @@ -14,25 +14,25 @@ type Int4 struct { Status Status } -func (i *Int4) ConvertFrom(src interface{}) error { +func (dst *Int4) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int4: - *i = value + *dst = value case int8: - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case uint8: - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case int16: - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case uint16: - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case int32: - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case uint32: if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case int64: if value < math.MinInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) @@ -40,12 +40,12 @@ func (i *Int4) ConvertFrom(src interface{}) error { if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case uint64: if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case int: if value < math.MinInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) @@ -53,21 +53,21 @@ func (i *Int4) ConvertFrom(src interface{}) error { if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case uint: if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case string: num, err := strconv.ParseInt(value, 10, 32) if err != nil { return err } - *i = Int4{Int: int32(num), Status: Present} + *dst = Int4{Int: int32(num), Status: Present} default: if originalSrc, ok := underlyingIntType(src); ok { - return i.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) } @@ -75,18 +75,18 @@ func (i *Int4) ConvertFrom(src interface{}) error { return nil } -func (i *Int4) AssignTo(dst interface{}) error { - return int64AssignTo(int64(i.Int), i.Status, dst) +func (src *Int4) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) } -func (i *Int4) DecodeText(r io.Reader) error { +func (dst *Int4) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *i = Int4{Status: Null} + *dst = Int4{Status: Null} return nil } @@ -101,18 +101,18 @@ func (i *Int4) DecodeText(r io.Reader) error { return err } - *i = Int4{Int: int32(n), Status: Present} + *dst = Int4{Int: int32(n), Status: Present} return nil } -func (i *Int4) DecodeBinary(r io.Reader) error { +func (dst *Int4) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *i = Int4{Status: Null} + *dst = Int4{Status: Null} return nil } @@ -125,16 +125,16 @@ func (i *Int4) DecodeBinary(r io.Reader) error { return err } - *i = Int4{Int: n, Status: Present} + *dst = Int4{Int: n, Status: Present} return nil } -func (i Int4) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, i.Status); done { +func (src Int4) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } - s := strconv.FormatInt(int64(i.Int), 10) + s := strconv.FormatInt(int64(src.Int), 10) _, err := pgio.WriteInt32(w, int32(len(s))) if err != nil { return nil @@ -143,8 +143,8 @@ func (i Int4) EncodeText(w io.Writer) error { return err } -func (i Int4) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, i.Status); done { +func (src Int4) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -153,6 +153,6 @@ func (i Int4) EncodeBinary(w io.Writer) error { return err } - _, err = pgio.WriteInt32(w, i.Int) + _, err = pgio.WriteInt32(w, src.Int) return err } diff --git a/int8.go b/int8.go index c0e14e44..7f307f18 100644 --- a/int8.go +++ b/int8.go @@ -14,29 +14,29 @@ type Int8 struct { Status Status } -func (i *Int8) ConvertFrom(src interface{}) error { +func (dst *Int8) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int8: - *i = value + *dst = value case int8: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case uint8: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case int16: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case uint16: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case int32: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case uint32: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case int64: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case uint64: if value > math.MaxInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", value) } - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case int: if int64(value) < math.MinInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", value) @@ -44,21 +44,21 @@ func (i *Int8) ConvertFrom(src interface{}) error { if int64(value) > math.MaxInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", value) } - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case uint: if uint64(value) > math.MaxInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", value) } - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case string: num, err := strconv.ParseInt(value, 10, 64) if err != nil { return err } - *i = Int8{Int: num, Status: Present} + *dst = Int8{Int: num, Status: Present} default: if originalSrc, ok := underlyingIntType(src); ok { - return i.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) } @@ -66,18 +66,18 @@ func (i *Int8) ConvertFrom(src interface{}) error { return nil } -func (i *Int8) AssignTo(dst interface{}) error { - return int64AssignTo(int64(i.Int), i.Status, dst) +func (src *Int8) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) } -func (i *Int8) DecodeText(r io.Reader) error { +func (dst *Int8) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *i = Int8{Status: Null} + *dst = Int8{Status: Null} return nil } @@ -92,18 +92,18 @@ func (i *Int8) DecodeText(r io.Reader) error { return err } - *i = Int8{Int: n, Status: Present} + *dst = Int8{Int: n, Status: Present} return nil } -func (i *Int8) DecodeBinary(r io.Reader) error { +func (dst *Int8) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *i = Int8{Status: Null} + *dst = Int8{Status: Null} return nil } @@ -116,16 +116,16 @@ func (i *Int8) DecodeBinary(r io.Reader) error { return err } - *i = Int8{Int: n, Status: Present} + *dst = Int8{Int: n, Status: Present} return nil } -func (i Int8) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, i.Status); done { +func (src Int8) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } - s := strconv.FormatInt(i.Int, 10) + s := strconv.FormatInt(src.Int, 10) _, err := pgio.WriteInt32(w, int32(len(s))) if err != nil { return nil @@ -134,8 +134,8 @@ func (i Int8) EncodeText(w io.Writer) error { return err } -func (i Int8) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, i.Status); done { +func (src Int8) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -144,6 +144,6 @@ func (i Int8) EncodeBinary(w io.Writer) error { return err } - _, err = pgio.WriteInt64(w, i.Int) + _, err = pgio.WriteInt64(w, src.Int) return err } diff --git a/timestamptz.go b/timestamptz.go index cc33b296..4f08cd2a 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -25,15 +25,15 @@ type Timestamptz struct { InfinityModifier } -func (t *Timestamptz) ConvertFrom(src interface{}) error { +func (dst *Timestamptz) ConvertFrom(src interface{}) error { switch value := src.(type) { case Timestamptz: - *t = value + *dst = value case time.Time: - *t = Timestamptz{Time: value, Status: Present} + *dst = Timestamptz{Time: value, Status: Present} default: if originalSrc, ok := underlyingTimeType(src); ok { - return t.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Timestamptz", value) } @@ -41,20 +41,20 @@ func (t *Timestamptz) ConvertFrom(src interface{}) error { return nil } -func (t *Timestamptz) AssignTo(dst interface{}) error { +func (src *Timestamptz) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: - if t.Status != Present || t.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", t, dst) + if src.Status != Present || src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) } - *v = t.Time + *v = src.Time default: if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { el := v.Elem() switch el.Kind() { // if dst is a pointer to pointer, strip the pointer and try again case reflect.Ptr: - if t.Status == Null { + if src.Status == Null { if !el.IsNil() { // if the destination pointer is not nil, nil it out el.Set(reflect.Zero(el.Type())) @@ -65,23 +65,23 @@ func (t *Timestamptz) AssignTo(dst interface{}) error { // allocate destination el.Set(reflect.New(el.Type().Elem())) } - return t.AssignTo(el.Interface()) + return src.AssignTo(el.Interface()) } } - return fmt.Errorf("cannot assign %v into %T", t, dst) + return fmt.Errorf("cannot assign %v into %T", src, dst) } return nil } -func (t *Timestamptz) DecodeText(r io.Reader) error { +func (dst *Timestamptz) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *t = Timestamptz{Status: Null} + *dst = Timestamptz{Status: Null} return nil } @@ -94,9 +94,9 @@ func (t *Timestamptz) DecodeText(r io.Reader) error { sbuf := string(buf) switch sbuf { case "infinity": - *t = Timestamptz{Status: Present, InfinityModifier: Infinity} + *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} case "-infinity": - *t = Timestamptz{Status: Present, InfinityModifier: -Infinity} + *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} default: var format string if sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+' { @@ -112,20 +112,20 @@ func (t *Timestamptz) DecodeText(r io.Reader) error { return err } - *t = Timestamptz{Time: tim, Status: Present} + *dst = Timestamptz{Time: tim, Status: Present} } return nil } -func (t *Timestamptz) DecodeBinary(r io.Reader) error { +func (dst *Timestamptz) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *t = Timestamptz{Status: Null} + *dst = Timestamptz{Status: Null} return nil } @@ -140,28 +140,28 @@ func (t *Timestamptz) DecodeBinary(r io.Reader) error { switch microsecSinceY2K { case infinityMicrosecondOffset: - *t = Timestamptz{Status: Present, InfinityModifier: Infinity} + *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} case negativeInfinityMicrosecondOffset: - *t = Timestamptz{Status: Present, InfinityModifier: -Infinity} + *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} default: microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) - *t = Timestamptz{Time: tim, Status: Present} + *dst = Timestamptz{Time: tim, Status: Present} } return nil } -func (t Timestamptz) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, t.Status); done { +func (src Timestamptz) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } var s string - switch t.InfinityModifier { + switch src.InfinityModifier { case None: - s = t.Time.UTC().Format(pgTimestamptzSecondFormat) + s = src.Time.UTC().Format(pgTimestamptzSecondFormat) case Infinity: s = "infinity" case NegativeInfinity: @@ -177,8 +177,8 @@ func (t Timestamptz) EncodeText(w io.Writer) error { return err } -func (t Timestamptz) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, t.Status); done { +func (src Timestamptz) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -188,9 +188,9 @@ func (t Timestamptz) EncodeBinary(w io.Writer) error { } var microsecSinceY2K int64 - switch t.InfinityModifier { + switch src.InfinityModifier { case None: - microsecSinceUnixEpoch := t.Time.Unix()*1000000 + int64(t.Time.Nanosecond())/1000 + microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K case Infinity: microsecSinceY2K = infinityMicrosecondOffset From 3d54c9a9588e4961a382109bfdc709fa5b4812ac Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 17:59:26 -0600 Subject: [PATCH 0006/1158] Add test for pgtype.Int2.AssignTo --- bool_test.go | 2 -- int2_test.go | 70 ++++++++++++++++++++++++++++++++++++++++++++++++-- pgtype_test.go | 5 ++++ 3 files changed, 73 insertions(+), 4 deletions(-) diff --git a/bool_test.go b/bool_test.go index 74140b5e..374f07da 100644 --- a/bool_test.go +++ b/bool_test.go @@ -7,8 +7,6 @@ import ( "github.com/jackc/pgx/pgtype" ) -type _bool bool - func TestBoolTranscode(t *testing.T) { testSuccessfulTranscode(t, "bool", []interface{}{ pgtype.Bool{Bool: false, Status: pgtype.Present}, diff --git a/int2_test.go b/int2_test.go index a8493a16..1074c9b5 100644 --- a/int2_test.go +++ b/int2_test.go @@ -2,6 +2,7 @@ package pgtype_test import ( "math" + "reflect" "testing" "github.com/jackc/pgx/pgtype" @@ -19,8 +20,6 @@ func TestInt2Transcode(t *testing.T) { } func TestInt2ConvertFrom(t *testing.T) { - type _int8 int8 - successfulTests := []struct { source interface{} result pgtype.Int2 @@ -53,3 +52,70 @@ func TestInt2ConvertFrom(t *testing.T) { } } } + +func TestInt2AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int2 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int2 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype_test.go b/pgtype_test.go index a1a575f7..32ebebfe 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -11,6 +11,11 @@ import ( "github.com/jackc/pgx/pgtype" ) +// Test for renamed types +type _bool bool +type _int8 int8 +type _int16 int16 + func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) if err != nil { From db69aa6f720cdd3f6ff203ec05930e947e1b1cbe Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 18:23:26 -0600 Subject: [PATCH 0007/1158] Add tests to more pgtypes Int4, Int8, Date, Timestamptz --- date.go | 5 +--- date_test.go | 46 +++++++++++++++++++++++++++++ int4_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++-- int8_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++-- timestamptz_test.go | 46 +++++++++++++++++++++++++++++ 5 files changed, 229 insertions(+), 8 deletions(-) diff --git a/date.go b/date.go index 6cd8e499..307f1e59 100644 --- a/date.go +++ b/date.go @@ -50,10 +50,7 @@ func (src *Date) AssignTo(dst interface{}) error { // if dst is a pointer to pointer, strip the pointer and try again case reflect.Ptr: if src.Status == Null { - if !el.IsNil() { - // if the destination pointer is not nil, nil it out - el.Set(reflect.Zero(el.Type())) - } + el.Set(reflect.Zero(el.Type())) return nil } if el.IsNil() { diff --git a/date_test.go b/date_test.go index c3e971d0..65d743e9 100644 --- a/date_test.go +++ b/date_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "reflect" "testing" "time" @@ -28,6 +29,7 @@ func TestDateConvertFrom(t *testing.T) { source interface{} result pgtype.Date }{ + {source: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, @@ -49,3 +51,47 @@ func TestDateConvertFrom(t *testing.T) { } } } + +func TestDateAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Date + dst interface{} + expected interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Date{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Date + dst interface{} + expected interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/int4_test.go b/int4_test.go index 04411849..cd57e2c9 100644 --- a/int4_test.go +++ b/int4_test.go @@ -2,6 +2,7 @@ package pgtype_test import ( "math" + "reflect" "testing" "github.com/jackc/pgx/pgtype" @@ -19,8 +20,6 @@ func TestInt4Transcode(t *testing.T) { } func TestInt4ConvertFrom(t *testing.T) { - type _int8 int8 - successfulTests := []struct { source interface{} result pgtype.Int4 @@ -53,3 +52,70 @@ func TestInt4ConvertFrom(t *testing.T) { } } } + +func TestInt4AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/int8_test.go b/int8_test.go index ba246224..f9d8646f 100644 --- a/int8_test.go +++ b/int8_test.go @@ -2,6 +2,7 @@ package pgtype_test import ( "math" + "reflect" "testing" "github.com/jackc/pgx/pgtype" @@ -19,8 +20,6 @@ func TestInt8Transcode(t *testing.T) { } func TestInt8ConvertFrom(t *testing.T) { - type _int8 int8 - successfulTests := []struct { source interface{} result pgtype.Int8 @@ -53,3 +52,70 @@ func TestInt8ConvertFrom(t *testing.T) { } } } + +func TestInt8AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/timestamptz_test.go b/timestamptz_test.go index 795195f8..adb72620 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "reflect" "testing" "time" @@ -37,6 +38,7 @@ func TestTimestamptzConvertFrom(t *testing.T) { source interface{} result pgtype.Timestamptz }{ + {source: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Status: pgtype.Present}}, @@ -58,3 +60,47 @@ func TestTimestamptzConvertFrom(t *testing.T) { } } } + +func TestTimestamptzAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Timestamptz + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Timestamptz{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Timestamptz + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} From c4e08dab42cad1f1fee6d785ca898617e64d8f3b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 18:39:52 -0600 Subject: [PATCH 0008/1158] Add pgtype error cases --- date_test.go | 16 ++++++++++++++++ int2_test.go | 20 ++++++++++++++++++++ int4_test.go | 21 +++++++++++++++++++++ int8_test.go | 22 ++++++++++++++++++++++ timestamptz.go | 5 +---- timestamptz_test.go | 16 ++++++++++++++++ 6 files changed, 96 insertions(+), 4 deletions(-) diff --git a/date_test.go b/date_test.go index 65d743e9..3a473b6a 100644 --- a/date_test.go +++ b/date_test.go @@ -94,4 +94,20 @@ func TestDateAssignTo(t *testing.T) { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } } + + errorTests := []struct { + src pgtype.Date + dst interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } } diff --git a/int2_test.go b/int2_test.go index 1074c9b5..8601309d 100644 --- a/int2_test.go +++ b/int2_test.go @@ -118,4 +118,24 @@ func TestInt2AssignTo(t *testing.T) { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } } + + errorTests := []struct { + src pgtype.Int2 + dst interface{} + }{ + {src: pgtype.Int2{Int: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &i16}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } } diff --git a/int4_test.go b/int4_test.go index cd57e2c9..0ac2e5b5 100644 --- a/int4_test.go +++ b/int4_test.go @@ -118,4 +118,25 @@ func TestInt4AssignTo(t *testing.T) { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } } + + errorTests := []struct { + src pgtype.Int4 + dst interface{} + }{ + {src: pgtype.Int4{Int: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Int4{Int: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } } diff --git a/int8_test.go b/int8_test.go index f9d8646f..15762a50 100644 --- a/int8_test.go +++ b/int8_test.go @@ -118,4 +118,26 @@ func TestInt8AssignTo(t *testing.T) { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } } + + errorTests := []struct { + src pgtype.Int8 + dst interface{} + }{ + {src: pgtype.Int8{Int: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Int8{Int: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Int8{Int: 5000000000, Status: pgtype.Present}, dst: &i32}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &i64}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } } diff --git a/timestamptz.go b/timestamptz.go index 4f08cd2a..721c8084 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -55,10 +55,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { // if dst is a pointer to pointer, strip the pointer and try again case reflect.Ptr: if src.Status == Null { - if !el.IsNil() { - // if the destination pointer is not nil, nil it out - el.Set(reflect.Zero(el.Type())) - } + el.Set(reflect.Zero(el.Type())) return nil } if el.IsNil() { diff --git a/timestamptz_test.go b/timestamptz_test.go index adb72620..8f80ca81 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -103,4 +103,20 @@ func TestTimestamptzAssignTo(t *testing.T) { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } } + + errorTests := []struct { + src pgtype.Timestamptz + dst interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } } From a2843aba531dbdd3d53638238f02cf8403851cae Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 19:19:31 -0600 Subject: [PATCH 0009/1158] Add tests for pgtype.Int2Array --- convert.go | 22 +++++++ int2array.go | 3 + int2array_test.go | 153 ++++++++++++++++++++++++++++++++++++---------- pgtype_test.go | 1 + 4 files changed, 147 insertions(+), 32 deletions(-) diff --git a/convert.go b/convert.go index 3f3d9e5f..e35e2310 100644 --- a/convert.go +++ b/convert.go @@ -122,6 +122,28 @@ func underlyingSliceType(val interface{}) (interface{}, bool) { return nil, false } +func underlyingPtrSliceType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + if refVal.Kind() != reflect.Ptr { + return nil, false + } + if refVal.IsNil() { + return nil, false + } + + sliceVal := refVal.Elem().Interface() + baseSliceType := reflect.SliceOf(reflect.TypeOf(sliceVal).Elem()) + ptrBaseSliceType := reflect.PtrTo(baseSliceType) + + if refVal.Type().ConvertibleTo(ptrBaseSliceType) { + convVal := refVal.Convert(ptrBaseSliceType) + return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() + } + + return nil, false +} + func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { if srcStatus == Present { switch v := dst.(type) { diff --git a/int2array.go b/int2array.go index 4ac0c409..e6809c1e 100644 --- a/int2array.go +++ b/int2array.go @@ -89,6 +89,9 @@ func (src *Int2Array) AssignTo(dst interface{}) error { *v = nil } default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } return fmt.Errorf("cannot put decode %v into %T", src, dst) } diff --git a/int2array_test.go b/int2array_test.go index 5ea81990..ced0eab4 100644 --- a/int2array_test.go +++ b/int2array_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "reflect" "testing" "github.com/jackc/pgx/pgtype" @@ -50,38 +51,126 @@ func TestInt2ArrayTranscode(t *testing.T) { }) } -// func TestInt2ConvertFrom(t *testing.T) { -// type _int8 int8 +func TestInt2ArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int2Array + }{ + { + source: []int16{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint16{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]int16)(nil)), + result: pgtype.Int2Array{Status: pgtype.Null}, + }, + } -// successfulTests := []struct { -// source interface{} -// result pgtype.Int2 -// }{ -// {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, -// {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, -// {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, -// {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, -// {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// } + for i, tt := range successfulTests { + var r pgtype.Int2Array + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } -// for i, tt := range successfulTests { -// var r pgtype.Int2 -// err := r.ConvertFrom(tt.source) -// if err != nil { -// t.Errorf("%d: %v", i, err) -// } + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} -// if r != tt.result { -// t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) -// } -// } -// } +func TestInt2ArrayAssignTo(t *testing.T) { + var int16Slice []int16 + var uint16Slice []uint16 + var namedInt16Slice _int16Slice + + simpleTests := []struct { + src pgtype.Int2Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int16Slice, + expected: []int16{1}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint16Slice, + expected: []uint16{1}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedInt16Slice, + expected: _int16Slice{1}, + }, + { + src: pgtype.Int2Array{Status: pgtype.Null}, + dst: &int16Slice, + expected: (([]int16)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int2Array + dst interface{} + }{ + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int16Slice, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: -1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint16Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype_test.go b/pgtype_test.go index 32ebebfe..a727e2e5 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -15,6 +15,7 @@ import ( type _bool bool type _int8 int8 type _int16 int16 +type _int16Slice []int16 func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) From 34c5070371fc7bc2f97014455f87e6f674105403 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 11:48:53 -0600 Subject: [PATCH 0010/1158] Add arrays to all other pgtypes --- boolarray.go | 286 +++++++++++++++++++++++++++++++++++ boolarray_test.go | 152 +++++++++++++++++++ datearray.go | 287 +++++++++++++++++++++++++++++++++++ datearray_test.go | 142 ++++++++++++++++++ int2array.go | 6 + int4array.go | 317 +++++++++++++++++++++++++++++++++++++++ int4array_test.go | 176 ++++++++++++++++++++++ int8array.go | 317 +++++++++++++++++++++++++++++++++++++++ int8array_test.go | 176 ++++++++++++++++++++++ pgtype.go | 5 +- pgtype_test.go | 2 + timestamptzarray.go | 287 +++++++++++++++++++++++++++++++++++ timestamptzarray_test.go | 158 +++++++++++++++++++ typed_array.go.erb | 286 +++++++++++++++++++++++++++++++++++ typed_array_gen.sh | 6 + 15 files changed, 2601 insertions(+), 2 deletions(-) create mode 100644 boolarray.go create mode 100644 boolarray_test.go create mode 100644 datearray.go create mode 100644 datearray_test.go create mode 100644 int4array.go create mode 100644 int4array_test.go create mode 100644 int8array.go create mode 100644 int8array_test.go create mode 100644 timestamptzarray.go create mode 100644 timestamptzarray_test.go create mode 100644 typed_array.go.erb create mode 100644 typed_array_gen.sh diff --git a/boolarray.go b/boolarray.go new file mode 100644 index 00000000..8dd68dc2 --- /dev/null +++ b/boolarray.go @@ -0,0 +1,286 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type BoolArray struct { + Elements []Bool + Dimensions []ArrayDimension + Status Status +} + +func (dst *BoolArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case BoolArray: + *dst = value + + case []bool: + if value == nil { + *dst = BoolArray{Status: Null} + } else if len(value) == 0 { + *dst = BoolArray{Status: Present} + } else { + elements := make([]Bool, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = BoolArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bool", value) + } + + return nil +} + +func (src *BoolArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]bool: + if src.Status == Present { + *v = make([]bool, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *BoolArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = BoolArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Bool + + if len(uta.Elements) > 0 { + elements = make([]Bool, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Bool + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = BoolArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *BoolArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = BoolArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = BoolArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Bool, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = BoolArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *BoolArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *BoolArray) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = BoolOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/boolarray_test.go b/boolarray_test.go new file mode 100644 index 00000000..c5f15f97 --- /dev/null +++ b/boolarray_test.go @@ -0,0 +1,152 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestBoolArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "bool[]", []interface{}{ + &pgtype.BoolArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.BoolArray{Status: pgtype.Null}, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: false, Status: pgtype.Present}, + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Status: pgtype.Null}, + pgtype.Bool{Bool: false, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: false, Status: pgtype.Present}, + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: false, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestBoolArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.BoolArray + }{ + { + source: []bool{true}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]bool)(nil)), + result: pgtype.BoolArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.BoolArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestBoolArrayAssignTo(t *testing.T) { + var boolSlice []bool + type _boolSlice []bool + var namedBoolSlice _boolSlice + + simpleTests := []struct { + src pgtype.BoolArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &boolSlice, + expected: []bool{true}, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedBoolSlice, + expected: _boolSlice{true}, + }, + { + src: pgtype.BoolArray{Status: pgtype.Null}, + dst: &boolSlice, + expected: (([]bool)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.BoolArray + dst interface{} + }{ + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &boolSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/datearray.go b/datearray.go new file mode 100644 index 00000000..877f328e --- /dev/null +++ b/datearray.go @@ -0,0 +1,287 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + "time" + + "github.com/jackc/pgx/pgio" +) + +type DateArray struct { + Elements []Date + Dimensions []ArrayDimension + Status Status +} + +func (dst *DateArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case DateArray: + *dst = value + + case []time.Time: + if value == nil { + *dst = DateArray{Status: Null} + } else if len(value) == 0 { + *dst = DateArray{Status: Present} + } else { + elements := make([]Date, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = DateArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Date", value) + } + + return nil +} + +func (src *DateArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]time.Time: + if src.Status == Present { + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *DateArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = DateArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Date + + if len(uta.Elements) > 0 { + elements = make([]Date, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Date + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = DateArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *DateArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = DateArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = DateArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Date, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = DateArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *DateArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *DateArray) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = DateOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/datearray_test.go b/datearray_test.go new file mode 100644 index 00000000..60f15983 --- /dev/null +++ b/datearray_test.go @@ -0,0 +1,142 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestDateArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "date[]", []interface{}{ + &pgtype.DateArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.DateArray{Status: pgtype.Null}, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Status: pgtype.Null}, + pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestDateArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.DateArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.DateArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.DateArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestDateArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + + simpleTests := []struct { + src pgtype.DateArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.DateArray{Status: pgtype.Null}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.DateArray + dst interface{} + }{ + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/int2array.go b/int2array.go index e6809c1e..4fc6d882 100644 --- a/int2array.go +++ b/int2array.go @@ -18,6 +18,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int2Array: *dst = value + case []int16: if value == nil { *dst = Int2Array{Status: Null} @@ -36,6 +37,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { Status: Present, } } + case []uint16: if value == nil { *dst = Int2Array{Status: Null} @@ -54,6 +56,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { Status: Present, } } + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -66,6 +69,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { func (src *Int2Array) AssignTo(dst interface{}) error { switch v := dst.(type) { + case *[]int16: if src.Status == Present { *v = make([]int16, len(src.Elements)) @@ -77,6 +81,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } else { *v = nil } + case *[]uint16: if src.Status == Present { *v = make([]uint16, len(src.Elements)) @@ -88,6 +93,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } else { *v = nil } + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) diff --git a/int4array.go b/int4array.go new file mode 100644 index 00000000..40e1490d --- /dev/null +++ b/int4array.go @@ -0,0 +1,317 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Int4Array struct { + Elements []Int4 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Int4Array) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int4Array: + *dst = value + + case []int32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int4", value) + } + + return nil +} + +func (src *Int4Array) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]int32: + if src.Status == Present { + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + case *[]uint32: + if src.Status == Present { + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Int4Array) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Int4Array{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Int4 + + if len(uta.Elements) > 0 { + elements = make([]Int4, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int4 + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int4Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Int4Array) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Int4Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Int4Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int4, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = Int4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *Int4Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *Int4Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = Int4OID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/int4array_test.go b/int4array_test.go new file mode 100644 index 00000000..38ba27cb --- /dev/null +++ b/int4array_test.go @@ -0,0 +1,176 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt4ArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "int4[]", []interface{}{ + &pgtype.Int4Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4Array{Status: pgtype.Null}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: 2, Status: pgtype.Present}, + pgtype.Int4{Int: 3, Status: pgtype.Present}, + pgtype.Int4{Int: 4, Status: pgtype.Present}, + pgtype.Int4{Status: pgtype.Null}, + pgtype.Int4{Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: 2, Status: pgtype.Present}, + pgtype.Int4{Int: 3, Status: pgtype.Present}, + pgtype.Int4{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInt4ArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int4Array + }{ + { + source: []int32{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint32{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]int32)(nil)), + result: pgtype.Int4Array{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Int4Array + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt4ArrayAssignTo(t *testing.T) { + var int32Slice []int32 + var uint32Slice []uint32 + var namedInt32Slice _int32Slice + + simpleTests := []struct { + src pgtype.Int4Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int32Slice, + expected: []int32{1}, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint32Slice, + expected: []uint32{1}, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedInt32Slice, + expected: _int32Slice{1}, + }, + { + src: pgtype.Int4Array{Status: pgtype.Null}, + dst: &int32Slice, + expected: (([]int32)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int4Array + dst interface{} + }{ + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int32Slice, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: -1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint32Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/int8array.go b/int8array.go new file mode 100644 index 00000000..35ecf946 --- /dev/null +++ b/int8array.go @@ -0,0 +1,317 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Int8Array struct { + Elements []Int8 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Int8Array) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int8Array: + *dst = value + + case []int64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func (src *Int8Array) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]int64: + if src.Status == Present { + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + case *[]uint64: + if src.Status == Present { + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Int8Array) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Int8Array{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Int8 + + if len(uta.Elements) > 0 { + elements = make([]Int8, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int8 + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int8Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Int8Array) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Int8Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Int8Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int8, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = Int8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *Int8Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *Int8Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = Int8OID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/int8array_test.go b/int8array_test.go new file mode 100644 index 00000000..137768c6 --- /dev/null +++ b/int8array_test.go @@ -0,0 +1,176 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt8ArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "int8[]", []interface{}{ + &pgtype.Int8Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + pgtype.Int8{Int: 1, Status: pgtype.Present}, + pgtype.Int8{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int8Array{Status: pgtype.Null}, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + pgtype.Int8{Int: 1, Status: pgtype.Present}, + pgtype.Int8{Int: 2, Status: pgtype.Present}, + pgtype.Int8{Int: 3, Status: pgtype.Present}, + pgtype.Int8{Int: 4, Status: pgtype.Present}, + pgtype.Int8{Status: pgtype.Null}, + pgtype.Int8{Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + pgtype.Int8{Int: 1, Status: pgtype.Present}, + pgtype.Int8{Int: 2, Status: pgtype.Present}, + pgtype.Int8{Int: 3, Status: pgtype.Present}, + pgtype.Int8{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInt8ArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int8Array + }{ + { + source: []int64{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint64{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]int64)(nil)), + result: pgtype.Int8Array{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Int8Array + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt8ArrayAssignTo(t *testing.T) { + var int64Slice []int64 + var uint64Slice []uint64 + var namedInt64Slice _int64Slice + + simpleTests := []struct { + src pgtype.Int8Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int64Slice, + expected: []int64{1}, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint64Slice, + expected: []uint64{1}, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedInt64Slice, + expected: _int64Slice{1}, + }, + { + src: pgtype.Int8Array{Status: pgtype.Null}, + dst: &int64Slice, + expected: (([]int64)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int8Array + dst interface{} + }{ + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int64Slice, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: -1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint64Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype.go b/pgtype.go index f9833363..5722c8ab 100644 --- a/pgtype.go +++ b/pgtype.go @@ -44,8 +44,9 @@ const ( DateOID = 1082 TimestampOID = 1114 TimestampArrayOID = 1115 - TimestampTzOID = 1184 - TimestampTzArrayOID = 1185 + DateArrayOID = 1182 + TimestamptzOID = 1184 + TimestamptzArrayOID = 1185 RecordOID = 2249 UUIDOID = 2950 JSONBOID = 3802 diff --git a/pgtype_test.go b/pgtype_test.go index a727e2e5..97afc249 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -16,6 +16,8 @@ type _bool bool type _int8 int8 type _int16 int16 type _int16Slice []int16 +type _int32Slice []int32 +type _int64Slice []int64 func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) diff --git a/timestamptzarray.go b/timestamptzarray.go new file mode 100644 index 00000000..72b28e43 --- /dev/null +++ b/timestamptzarray.go @@ -0,0 +1,287 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + "time" + + "github.com/jackc/pgx/pgio" +) + +type TimestamptzArray struct { + Elements []Timestamptz + Dimensions []ArrayDimension + Status Status +} + +func (dst *TimestamptzArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case TimestamptzArray: + *dst = value + + case []time.Time: + if value == nil { + *dst = TimestamptzArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestamptzArray{Status: Present} + } else { + elements := make([]Timestamptz, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = TimestamptzArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamptz", value) + } + + return nil +} + +func (src *TimestamptzArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]time.Time: + if src.Status == Present { + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *TimestamptzArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TimestamptzArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Timestamptz + + if len(uta.Elements) > 0 { + elements = make([]Timestamptz, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Timestamptz + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TimestamptzArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TimestamptzArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TimestamptzArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TimestamptzArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Timestamptz, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = TimestamptzArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *TimestamptzArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *TimestamptzArray) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = TimestamptzOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/timestamptzarray_test.go b/timestamptzarray_test.go new file mode 100644 index 00000000..af2c004b --- /dev/null +++ b/timestamptzarray_test.go @@ -0,0 +1,158 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTimestamptzArrayTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "timestamptz[]", []interface{}{ + &pgtype.TimestamptzArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestamptzArray{Status: pgtype.Null}, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Status: pgtype.Null}, + pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }, func(a, b interface{}) bool { + ata := a.(pgtype.TimestamptzArray) + bta := b.(pgtype.TimestamptzArray) + + if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { + return false + } + + for i := range ata.Elements { + ae, be := ata.Elements[i], bta.Elements[i] + if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { + return false + } + } + + return true + }) +} + +func TestTimestamptzArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TimestamptzArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.TimestamptzArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TimestamptzArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestamptzArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + + simpleTests := []struct { + src pgtype.TimestamptzArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.TimestamptzArray{Status: pgtype.Null}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TimestamptzArray + dst interface{} + }{ + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/typed_array.go.erb b/typed_array.go.erb new file mode 100644 index 00000000..e6e480b0 --- /dev/null +++ b/typed_array.go.erb @@ -0,0 +1,286 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type <%= pgtype_array_type %> struct { + Elements []<%= pgtype_element_type %> + Dimensions []ArrayDimension + Status Status +} + +func (dst *<%= pgtype_array_type %>) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case <%= pgtype_array_type %>: + *dst = value + <% go_array_types.split(",").each do |t| %> + case <%= t %>: + if value == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + } else if len(value) == 0 { + *dst = <%= pgtype_array_type %>{Status: Present} + } else { + elements := make([]<%= pgtype_element_type %>, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = <%= pgtype_array_type %>{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + <% end %> + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to <%= pgtype_element_type %>", value) + } + + return nil +} + +func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { + switch v := dst.(type) { + <% go_array_types.split(",").each do |t| %> + case *<%= t %>: + if src.Status == Present { + *v = make(<%= t %>, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + <% end %> + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *<%= pgtype_array_type %>) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []<%= pgtype_element_type %> + + if len(uta.Elements) > 0 { + elements = make([]<%= pgtype_element_type %>, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem <%= pgtype_element_type %> + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *<%= pgtype_array_type %>) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = <%= pgtype_array_type %>{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]<%= pgtype_element_type %>, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = <%= element_oid %> + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/typed_array_gen.sh b/typed_array_gen.sh new file mode 100644 index 00000000..9fec58e8 --- /dev/null +++ b/typed_array_gen.sh @@ -0,0 +1,6 @@ +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID typed_array.go.erb > int2array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID typed_array.go.erb > int4array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID typed_array.go.erb > int2array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID typed_array.go.erb > boolarray.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID typed_array.go.erb > datearray.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID typed_array.go.erb > timestamptzarray.go From 39b60605ae3eff87adfb684e176eeec64a0ea610 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 12:36:24 -0600 Subject: [PATCH 0011/1158] Add timestamp to pgtype --- timestamp.go | 204 +++++++++++++++++++++++++++++ timestamp_test.go | 123 ++++++++++++++++++ timestamparray.go | 287 +++++++++++++++++++++++++++++++++++++++++ timestamparray_test.go | 158 +++++++++++++++++++++++ typed_array_gen.sh | 1 + 5 files changed, 773 insertions(+) create mode 100644 timestamp.go create mode 100644 timestamp_test.go create mode 100644 timestamparray.go create mode 100644 timestamparray_test.go diff --git a/timestamp.go b/timestamp.go new file mode 100644 index 00000000..c6933988 --- /dev/null +++ b/timestamp.go @@ -0,0 +1,204 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "time" + + "github.com/jackc/pgx/pgio" +) + +const pgTimestampFormat = "2006-01-02 15:04:05.999999999" + +// Timestamp represents the PostgreSQL timestamp type. The PostgreSQL +// timestamp does not have a time zone. This presents a problem when +// translating to and from time.Time which requires a time zone. It is highly +// recommended to use timestamptz whenever possible. Timestamp methods either +// convert to UTC or return an error on non-UTC times. +type Timestamp struct { + Time time.Time // Time must always be in UTC. + Status Status + InfinityModifier +} + +// ConvertFrom converts src into a Timestamp and stores in dst. If src is a +// time.Time in a non-UTC time zone, the time zone is discarded. +func (dst *Timestamp) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Timestamp: + *dst = value + case time.Time: + *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamp", value) + } + + return nil +} + +func (src *Timestamp) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *time.Time: + if src.Status != Present || src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot assign %v into %T", src, dst) + } + + return nil +} + +// DecodeText decodes from src into dst. The decoded time is considered to +// be in UTC. +func (dst *Timestamp) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Timestamp{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + sbuf := string(buf) + switch sbuf { + case "infinity": + *dst = Timestamp{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Timestamp{Status: Present, InfinityModifier: -Infinity} + default: + tim, err := time.Parse(pgTimestampFormat, sbuf) + if err != nil { + return err + } + + *dst = Timestamp{Time: tim, Status: Present} + } + + return nil +} + +// DecodeBinary decodes from src into dst. The decoded time is considered to +// be in UTC. +func (dst *Timestamp) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Timestamp{Status: Null} + return nil + } + + if size != 8 { + return fmt.Errorf("invalid length for timestamp: %v", size) + } + + microsecSinceY2K, err := pgio.ReadInt64(r) + if err != nil { + return err + } + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + *dst = Timestamp{Status: Present, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + *dst = Timestamp{Status: Present, InfinityModifier: -Infinity} + default: + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000).UTC() + *dst = Timestamp{Time: tim, Status: Present} + } + + return nil +} + +// EncodeText writes the text encoding of src into w. If src.Time is not in +// the UTC time zone it returns an error. +func (src Timestamp) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + if src.Time.Location() != time.UTC { + return fmt.Errorf("cannot encode non-UTC time into timestamp") + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format(pgTimestampFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + + _, err = w.Write([]byte(s)) + return err +} + +// EncodeBinary writes the binary encoding of src into w. If src.Time is not in +// the UTC time zone it returns an error. +func (src Timestamp) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + if src.Time.Location() != time.UTC { + return fmt.Errorf("cannot encode non-UTC time into timestamp") + } + + _, err := pgio.WriteInt32(w, 8) + if err != nil { + return err + } + + var microsecSinceY2K int64 + switch src.InfinityModifier { + case None: + microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + _, err = pgio.WriteInt64(w, microsecSinceY2K) + return err +} diff --git a/timestamp_test.go b/timestamp_test.go new file mode 100644 index 00000000..6d6e738c --- /dev/null +++ b/timestamp_test.go @@ -0,0 +1,123 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTimestampTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ + pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Status: pgtype.Null}, + pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Timestamp) + bt := b.(pgtype.Timestamp) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + }) +} + +func TestTimestampConvertFrom(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Timestamp + }{ + {source: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Timestamp + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestampAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Timestamp + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, + {src: pgtype.Timestamp{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Timestamp + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Timestamp + dst interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/timestamparray.go b/timestamparray.go new file mode 100644 index 00000000..f1b1d003 --- /dev/null +++ b/timestamparray.go @@ -0,0 +1,287 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + "time" + + "github.com/jackc/pgx/pgio" +) + +type TimestampArray struct { + Elements []Timestamp + Dimensions []ArrayDimension + Status Status +} + +func (dst *TimestampArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case TimestampArray: + *dst = value + + case []time.Time: + if value == nil { + *dst = TimestampArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestampArray{Status: Present} + } else { + elements := make([]Timestamp, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = TimestampArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamp", value) + } + + return nil +} + +func (src *TimestampArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]time.Time: + if src.Status == Present { + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *TimestampArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TimestampArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Timestamp + + if len(uta.Elements) > 0 { + elements = make([]Timestamp, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Timestamp + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TimestampArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TimestampArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TimestampArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TimestampArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Timestamp, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = TimestampArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *TimestampArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *TimestampArray) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = TimestampOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/timestamparray_test.go b/timestamparray_test.go new file mode 100644 index 00000000..68189cc7 --- /dev/null +++ b/timestamparray_test.go @@ -0,0 +1,158 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTimestampArrayTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "timestamp[]", []interface{}{ + &pgtype.TimestampArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestampArray{Status: pgtype.Null}, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Status: pgtype.Null}, + pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }, func(a, b interface{}) bool { + ata := a.(pgtype.TimestampArray) + bta := b.(pgtype.TimestampArray) + + if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { + return false + } + + for i := range ata.Elements { + ae, be := ata.Elements[i], bta.Elements[i] + if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { + return false + } + } + + return true + }) +} + +func TestTimestampArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TimestampArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.TimestampArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TimestampArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestampArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + + simpleTests := []struct { + src pgtype.TimestampArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.TimestampArray{Status: pgtype.Null}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TimestampArray + dst interface{} + }{ + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 9fec58e8..9f4e1ce0 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -4,3 +4,4 @@ erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64, erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID typed_array.go.erb > boolarray.go erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID typed_array.go.erb > datearray.go erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID typed_array.go.erb > timestamptzarray.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID typed_array.go.erb > timestamparray.go From 0f115477de91f36be3cab6d0f20bd91ea8bdbcda Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 13:29:04 -0600 Subject: [PATCH 0012/1158] Add float4, float8 and arrays --- convert.go | 52 +++++++- float4.go | 171 ++++++++++++++++++++++++++ float4_test.go | 148 +++++++++++++++++++++++ float4array.go | 286 ++++++++++++++++++++++++++++++++++++++++++++ float4array_test.go | 151 +++++++++++++++++++++++ float8.go | 161 +++++++++++++++++++++++++ float8_test.go | 148 +++++++++++++++++++++++ float8array.go | 286 ++++++++++++++++++++++++++++++++++++++++++++ float8array_test.go | 151 +++++++++++++++++++++++ int2.go | 2 +- int4.go | 2 +- int8.go | 2 +- pgtype_test.go | 2 + typed_array_gen.sh | 2 + 14 files changed, 1559 insertions(+), 5 deletions(-) create mode 100644 float4.go create mode 100644 float4_test.go create mode 100644 float4array.go create mode 100644 float4array_test.go create mode 100644 float8.go create mode 100644 float8_test.go create mode 100644 float8array.go create mode 100644 float8array_test.go diff --git a/convert.go b/convert.go index e35e2310..c4b52322 100644 --- a/convert.go +++ b/convert.go @@ -11,8 +11,8 @@ const maxUint = ^uint(0) const maxInt = int(maxUint >> 1) const minInt = -maxInt - 1 -// underlyingIntType gets the underlying type that can be converted to Int2, Int4, or Int8 -func underlyingIntType(val interface{}) (interface{}, bool) { +// underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 +func underlyingNumberType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) switch refVal.Kind() { @@ -52,6 +52,12 @@ func underlyingIntType(val interface{}) (interface{}, bool) { case reflect.Uint64: convVal := uint64(refVal.Uint()) return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float32: + convVal := float32(refVal.Float()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float64: + convVal := refVal.Float() + return convVal, reflect.TypeOf(convVal) != refVal.Type() case reflect.String: convVal := refVal.String() return convVal, reflect.TypeOf(convVal) != refVal.Type() @@ -259,3 +265,45 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } + +func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { + if srcStatus == Present { + switch v := dst.(type) { + case *float32: + *v = float32(srcVal) + case *float64: + *v = srcVal + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return float64AssignTo(srcVal, srcStatus, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i64 := int64(srcVal) + if float64(i64) == srcVal { + return int64AssignTo(i64, srcStatus, dst) + } + } + } + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Present, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) +} diff --git a/float4.go b/float4.go new file mode 100644 index 00000000..a1e5aa18 --- /dev/null +++ b/float4.go @@ -0,0 +1,171 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Float4 struct { + Float float32 + Status Status +} + +func (dst *Float4) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Float4: + *dst = value + case float32: + *dst = Float4{Float: value, Status: Present} + case float64: + *dst = Float4{Float: float32(value), Status: Present} + case int8: + *dst = Float4{Float: float32(value), Status: Present} + case uint8: + *dst = Float4{Float: float32(value), Status: Present} + case int16: + *dst = Float4{Float: float32(value), Status: Present} + case uint16: + *dst = Float4{Float: float32(value), Status: Present} + case int32: + f32 := float32(value) + if int32(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case uint32: + f32 := float32(value) + if uint32(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case int64: + f32 := float32(value) + if int64(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case uint64: + f32 := float32(value) + if uint64(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case int: + f32 := float32(value) + if int(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case uint: + f32 := float32(value) + if uint(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case string: + num, err := strconv.ParseFloat(value, 32) + if err != nil { + return err + } + *dst = Float4{Float: float32(num), Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float8", value) + } + + return nil +} + +func (src *Float4) AssignTo(dst interface{}) error { + return float64AssignTo(float64(src.Float), src.Status, dst) +} + +func (dst *Float4) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float4{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseFloat(string(buf), 32) + if err != nil { + return err + } + + *dst = Float4{Float: float32(n), Status: Present} + return nil +} + +func (dst *Float4) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float4{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for float4: %v", size) + } + + n, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + *dst = Float4{Float: math.Float32frombits(uint32(n)), Status: Present} + return nil +} + +func (src Float4) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + s := strconv.FormatFloat(float64(src.Float), 'f', -1, 32) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (src Float4) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(math.Float32bits(src.Float))) + return err +} diff --git a/float4_test.go b/float4_test.go new file mode 100644 index 00000000..62420b8d --- /dev/null +++ b/float4_test.go @@ -0,0 +1,148 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestFloat4Transcode(t *testing.T) { + testSuccessfulTranscode(t, "float4", []interface{}{ + pgtype.Float4{Float: -1, Status: pgtype.Present}, + pgtype.Float4{Float: 0, Status: pgtype.Present}, + pgtype.Float4{Float: 0.00001, Status: pgtype.Present}, + pgtype.Float4{Float: 1, Status: pgtype.Present}, + pgtype.Float4{Float: 9999.99, Status: pgtype.Present}, + pgtype.Float4{Float: 0, Status: pgtype.Null}, + }) +} + +func TestFloat4ConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float4 + }{ + {source: float32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Float4 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat4AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src pgtype.Float4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Float4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float4 + dst interface{} + }{ + {src: pgtype.Float4{Float: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Float4{Float: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/float4array.go b/float4array.go new file mode 100644 index 00000000..c06490cf --- /dev/null +++ b/float4array.go @@ -0,0 +1,286 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Float4Array struct { + Elements []Float4 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Float4Array) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Float4Array: + *dst = value + + case []float32: + if value == nil { + *dst = Float4Array{Status: Null} + } else if len(value) == 0 { + *dst = Float4Array{Status: Present} + } else { + elements := make([]Float4, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = Float4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float4", value) + } + + return nil +} + +func (src *Float4Array) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]float32: + if src.Status == Present { + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Float4Array) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float4Array{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Float4 + + if len(uta.Elements) > 0 { + elements = make([]Float4, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Float4 + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Float4Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Float4Array) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float4Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Float4Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Float4, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = Float4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *Float4Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *Float4Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = Float4OID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/float4array_test.go b/float4array_test.go new file mode 100644 index 00000000..b22f4fbc --- /dev/null +++ b/float4array_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestFloat4ArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "float4[]", []interface{}{ + &pgtype.Float4Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + pgtype.Float4{Float: 1, Status: pgtype.Present}, + pgtype.Float4{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float4Array{Status: pgtype.Null}, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + pgtype.Float4{Float: 1, Status: pgtype.Present}, + pgtype.Float4{Float: 2, Status: pgtype.Present}, + pgtype.Float4{Float: 3, Status: pgtype.Present}, + pgtype.Float4{Float: 4, Status: pgtype.Present}, + pgtype.Float4{Status: pgtype.Null}, + pgtype.Float4{Float: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + pgtype.Float4{Float: 1, Status: pgtype.Present}, + pgtype.Float4{Float: 2, Status: pgtype.Present}, + pgtype.Float4{Float: 3, Status: pgtype.Present}, + pgtype.Float4{Float: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestFloat4ArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float4Array + }{ + { + source: []float32{1}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]float32)(nil)), + result: pgtype.Float4Array{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Float4Array + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat4ArrayAssignTo(t *testing.T) { + var float32Slice []float32 + var namedFloat32Slice _float32Slice + + simpleTests := []struct { + src pgtype.Float4Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + expected: []float32{1.23}, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedFloat32Slice, + expected: _float32Slice{1.23}, + }, + { + src: pgtype.Float4Array{Status: pgtype.Null}, + dst: &float32Slice, + expected: (([]float32)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float4Array + dst interface{} + }{ + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/float8.go b/float8.go new file mode 100644 index 00000000..c1347cb2 --- /dev/null +++ b/float8.go @@ -0,0 +1,161 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Float8 struct { + Float float64 + Status Status +} + +func (dst *Float8) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Float8: + *dst = value + case float32: + *dst = Float8{Float: float64(value), Status: Present} + case float64: + *dst = Float8{Float: value, Status: Present} + case int8: + *dst = Float8{Float: float64(value), Status: Present} + case uint8: + *dst = Float8{Float: float64(value), Status: Present} + case int16: + *dst = Float8{Float: float64(value), Status: Present} + case uint16: + *dst = Float8{Float: float64(value), Status: Present} + case int32: + *dst = Float8{Float: float64(value), Status: Present} + case uint32: + *dst = Float8{Float: float64(value), Status: Present} + case int64: + f64 := float64(value) + if int64(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case uint64: + f64 := float64(value) + if uint64(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case int: + f64 := float64(value) + if int(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case uint: + f64 := float64(value) + if uint(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case string: + num, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + *dst = Float8{Float: float64(num), Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float8", value) + } + + return nil +} + +func (src *Float8) AssignTo(dst interface{}) error { + return float64AssignTo(src.Float, src.Status, dst) +} + +func (dst *Float8) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float8{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseFloat(string(buf), 64) + if err != nil { + return err + } + + *dst = Float8{Float: n, Status: Present} + return nil +} + +func (dst *Float8) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float8{Status: Null} + return nil + } + + if size != 8 { + return fmt.Errorf("invalid length for float4: %v", size) + } + + n, err := pgio.ReadInt64(r) + if err != nil { + return err + } + + *dst = Float8{Float: math.Float64frombits(uint64(n)), Status: Present} + return nil +} + +func (src Float8) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + s := strconv.FormatFloat(float64(src.Float), 'f', -1, 64) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (src Float8) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 8) + if err != nil { + return err + } + + _, err = pgio.WriteInt64(w, int64(math.Float64bits(src.Float))) + return err +} diff --git a/float8_test.go b/float8_test.go new file mode 100644 index 00000000..748ffd25 --- /dev/null +++ b/float8_test.go @@ -0,0 +1,148 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestFloat8Transcode(t *testing.T) { + testSuccessfulTranscode(t, "float8", []interface{}{ + pgtype.Float8{Float: -1, Status: pgtype.Present}, + pgtype.Float8{Float: 0, Status: pgtype.Present}, + pgtype.Float8{Float: 0.00001, Status: pgtype.Present}, + pgtype.Float8{Float: 1, Status: pgtype.Present}, + pgtype.Float8{Float: 9999.99, Status: pgtype.Present}, + pgtype.Float8{Float: 0, Status: pgtype.Null}, + }) +} + +func TestFloat8ConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float8 + }{ + {source: float32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Float8 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat8AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src pgtype.Float8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Float8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float8 + dst interface{} + }{ + {src: pgtype.Float8{Float: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Float8{Float: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/float8array.go b/float8array.go new file mode 100644 index 00000000..776fc1e6 --- /dev/null +++ b/float8array.go @@ -0,0 +1,286 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Float8Array struct { + Elements []Float8 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Float8Array) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Float8Array: + *dst = value + + case []float64: + if value == nil { + *dst = Float8Array{Status: Null} + } else if len(value) == 0 { + *dst = Float8Array{Status: Present} + } else { + elements := make([]Float8, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = Float8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float8", value) + } + + return nil +} + +func (src *Float8Array) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]float64: + if src.Status == Present { + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Float8Array) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float8Array{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Float8 + + if len(uta.Elements) > 0 { + elements = make([]Float8, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Float8 + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Float8Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Float8Array) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float8Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Float8Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Float8, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = Float8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *Float8Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *Float8Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = Float8OID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/float8array_test.go b/float8array_test.go new file mode 100644 index 00000000..d4402281 --- /dev/null +++ b/float8array_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestFloat8ArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "float8[]", []interface{}{ + &pgtype.Float8Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + pgtype.Float8{Float: 1, Status: pgtype.Present}, + pgtype.Float8{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float8Array{Status: pgtype.Null}, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + pgtype.Float8{Float: 1, Status: pgtype.Present}, + pgtype.Float8{Float: 2, Status: pgtype.Present}, + pgtype.Float8{Float: 3, Status: pgtype.Present}, + pgtype.Float8{Float: 4, Status: pgtype.Present}, + pgtype.Float8{Status: pgtype.Null}, + pgtype.Float8{Float: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + pgtype.Float8{Float: 1, Status: pgtype.Present}, + pgtype.Float8{Float: 2, Status: pgtype.Present}, + pgtype.Float8{Float: 3, Status: pgtype.Present}, + pgtype.Float8{Float: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestFloat8ArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float8Array + }{ + { + source: []float64{1}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]float64)(nil)), + result: pgtype.Float8Array{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Float8Array + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat8ArrayAssignTo(t *testing.T) { + var float64Slice []float64 + var namedFloat64Slice _float64Slice + + simpleTests := []struct { + src pgtype.Float8Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float64Slice, + expected: []float64{1.23}, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedFloat64Slice, + expected: _float64Slice{1.23}, + }, + { + src: pgtype.Float8Array{Status: pgtype.Null}, + dst: &float64Slice, + expected: (([]float64)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float8Array + dst interface{} + }{ + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float64Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/int2.go b/int2.go index fb6a8ccc..8057550b 100644 --- a/int2.go +++ b/int2.go @@ -75,7 +75,7 @@ func (dst *Int2) ConvertFrom(src interface{}) error { } *dst = Int2{Int: int16(num), Status: Present} default: - if originalSrc, ok := underlyingIntType(src); ok { + if originalSrc, ok := underlyingNumberType(src); ok { return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int2", value) diff --git a/int4.go b/int4.go index 1a4733b0..43691bb6 100644 --- a/int4.go +++ b/int4.go @@ -66,7 +66,7 @@ func (dst *Int4) ConvertFrom(src interface{}) error { } *dst = Int4{Int: int32(num), Status: Present} default: - if originalSrc, ok := underlyingIntType(src); ok { + if originalSrc, ok := underlyingNumberType(src); ok { return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) diff --git a/int8.go b/int8.go index 7f307f18..b87bb85a 100644 --- a/int8.go +++ b/int8.go @@ -57,7 +57,7 @@ func (dst *Int8) ConvertFrom(src interface{}) error { } *dst = Int8{Int: num, Status: Present} default: - if originalSrc, ok := underlyingIntType(src); ok { + if originalSrc, ok := underlyingNumberType(src); ok { return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) diff --git a/pgtype_test.go b/pgtype_test.go index 97afc249..a1dcd11b 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -18,6 +18,8 @@ type _int16 int16 type _int16Slice []int16 type _int32Slice []int32 type _int64Slice []int64 +type _float32Slice []float32 +type _float64Slice []float64 func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 9f4e1ce0..4ce6c3b5 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -5,3 +5,5 @@ erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool e erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID typed_array.go.erb > datearray.go erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID typed_array.go.erb > timestamptzarray.go erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID typed_array.go.erb > timestamparray.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID typed_array.go.erb > float4array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID typed_array.go.erb > float8array.go From 93e1715082540b6b67424a850c13ba1a75e12cce Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 17:33:41 -0600 Subject: [PATCH 0013/1158] Add inet and cidr to pgtype --- cidrarray.go | 31 +++++ convert.go | 16 +++ inet.go | 240 ++++++++++++++++++++++++++++++++++ inet_test.go | 115 ++++++++++++++++ inetarray.go | 320 +++++++++++++++++++++++++++++++++++++++++++++ inetarray_test.go | 164 +++++++++++++++++++++++ pgtype_test.go | 10 ++ typed_array_gen.sh | 1 + 8 files changed, 897 insertions(+) create mode 100644 cidrarray.go create mode 100644 inet.go create mode 100644 inet_test.go create mode 100644 inetarray.go create mode 100644 inetarray_test.go diff --git a/cidrarray.go b/cidrarray.go new file mode 100644 index 00000000..66dd20d0 --- /dev/null +++ b/cidrarray.go @@ -0,0 +1,31 @@ +package pgtype + +import ( + "io" +) + +type CidrArray InetArray + +func (dst *CidrArray) ConvertFrom(src interface{}) error { + return (*InetArray)(dst).ConvertFrom(src) +} + +func (src *CidrArray) AssignTo(dst interface{}) error { + return (*InetArray)(src).AssignTo(dst) +} + +func (dst *CidrArray) DecodeText(r io.Reader) error { + return (*InetArray)(dst).DecodeText(r) +} + +func (dst *CidrArray) DecodeBinary(r io.Reader) error { + return (*InetArray)(dst).DecodeBinary(r) +} + +func (src *CidrArray) EncodeText(w io.Writer) error { + return (*InetArray)(src).EncodeText(w) +} + +func (src *CidrArray) EncodeBinary(w io.Writer) error { + return (*InetArray)(src).encodeBinary(w, CidrOID) +} diff --git a/convert.go b/convert.go index c4b52322..7111f8bc 100644 --- a/convert.go +++ b/convert.go @@ -85,6 +85,22 @@ func underlyingBoolType(val interface{}) (interface{}, bool) { return nil, false } +// underlyingPtrType dereferences a pointer +func underlyingPtrType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + return nil, false +} + // underlyingTimeType gets the underlying type that can be converted to time.Time func underlyingTimeType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) diff --git a/inet.go b/inet.go new file mode 100644 index 00000000..e47c64b0 --- /dev/null +++ b/inet.go @@ -0,0 +1,240 @@ +package pgtype + +import ( + "fmt" + "io" + "net" + "reflect" + + "github.com/jackc/pgx/pgio" +) + +// Network address family is dependent on server socket.h value for AF_INET. +// In practice, all platforms appear to have the same value. See +// src/include/utils/inet.h for more information. +const ( + defaultAFInet = 2 + defaultAFInet6 = 3 +) + +// Inet represents both inet and cidr PostgreSQL types. +type Inet struct { + IPNet *net.IPNet + Status Status +} + +func (dst *Inet) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Inet: + *dst = value + case net.IPNet: + *dst = Inet{IPNet: &value, Status: Present} + case *net.IPNet: + *dst = Inet{IPNet: value, Status: Present} + case net.IP: + bitCount := len(value) * 8 + mask := net.CIDRMask(bitCount, bitCount) + *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Status: Present} + case string: + _, ipnet, err := net.ParseCIDR(value) + if err != nil { + return err + } + *dst = Inet{IPNet: ipnet, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Inet", value) + } + + return nil +} + +func (src *Inet) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *net.IPNet: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = *src.IPNet + case *net.IP: + if src.Status == Present { + + if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.IPNet.IP + } else { + *v = nil + } + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Inet) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Inet{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + var ipnet *net.IPNet + + if ip := net.ParseIP(string(buf)); ip != nil { + ipv4 := ip.To4() + if ipv4 != nil { + ip = ipv4 + } + bitCount := len(ip) * 8 + mask := net.CIDRMask(bitCount, bitCount) + ipnet = &net.IPNet{Mask: mask, IP: ip} + } else { + _, ipnet, err = net.ParseCIDR(string(buf)) + if err != nil { + return err + } + } + + *dst = Inet{IPNet: ipnet, Status: Present} + return nil +} + +func (dst *Inet) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Inet{Status: Null} + return nil + } + + if size != 8 && size != 20 { + return fmt.Errorf("Received an invalid size for a inet: %d", size) + } + + // ignore family + _, err = pgio.ReadByte(r) + if err != nil { + return err + } + + bits, err := pgio.ReadByte(r) + if err != nil { + return err + } + + // ignore is_cidr + _, err = pgio.ReadByte(r) + if err != nil { + return err + } + + addressLength, err := pgio.ReadByte(r) + if err != nil { + return err + } + + var ipnet net.IPNet + ipnet.IP = make(net.IP, int(addressLength)) + _, err = r.Read(ipnet.IP) + if err != nil { + return err + } + + ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) + + *dst = Inet{IPNet: &ipnet, Status: Present} + + return nil +} + +func (src Inet) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + s := src.IPNet.String() + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +// EncodeBinary encodes src into w. +func (src Inet) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var size int32 + var family byte + switch len(src.IPNet.IP) { + case net.IPv4len: + size = 8 + family = defaultAFInet + case net.IPv6len: + size = 20 + family = defaultAFInet6 + default: + return fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) + } + + if _, err := pgio.WriteInt32(w, size); err != nil { + return err + } + + if err := pgio.WriteByte(w, family); err != nil { + return err + } + + ones, _ := src.IPNet.Mask.Size() + if err := pgio.WriteByte(w, byte(ones)); err != nil { + return err + } + + // is_cidr is ignored on server + if err := pgio.WriteByte(w, 0); err != nil { + return err + } + + if err := pgio.WriteByte(w, byte(len(src.IPNet.IP))); err != nil { + return err + } + + _, err := w.Write(src.IPNet.IP) + return err +} diff --git a/inet_test.go b/inet_test.go new file mode 100644 index 00000000..5e86376b --- /dev/null +++ b/inet_test.go @@ -0,0 +1,115 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInetTranscode(t *testing.T) { + for _, pgTypeName := range []string{"inet", "cidr"} { + testSuccessfulTranscode(t, pgTypeName, []interface{}{ + pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{Status: pgtype.Null}, + }) + } +} + +func TestInetConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Inet + }{ + {source: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Null}, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Null}}, + {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Inet + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInetAssignTo(t *testing.T) { + var ipnet net.IPNet + var pipnet *net.IPNet + var ip net.IP + var pip *net.IP + + simpleTests := []struct { + src pgtype.Inet + dst interface{} + expected interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &pip, expected: ((*net.IP)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %#v, but result was %#v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Inet + dst interface{} + expected interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Inet + dst interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &ipnet}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/inetarray.go b/inetarray.go new file mode 100644 index 00000000..eb5a4c88 --- /dev/null +++ b/inetarray.go @@ -0,0 +1,320 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + "net" + + "github.com/jackc/pgx/pgio" +) + +type InetArray struct { + Elements []Inet + Dimensions []ArrayDimension + Status Status +} + +func (dst *InetArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case InetArray: + *dst = value + case CidrArray: + *dst = InetArray(value) + case []*net.IPNet: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []net.IP: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Inet", value) + } + + return nil +} + +func (src *InetArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]*net.IPNet: + if src.Status == Present { + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + case *[]net.IP: + if src.Status == Present { + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *InetArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = InetArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Inet + + if len(uta.Elements) > 0 { + elements = make([]Inet, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Inet + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = InetArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *InetArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = InetArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = InetArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Inet, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = InetArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *InetArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *InetArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, InetOID) +} + +func (src *InetArray) encodeBinary(w io.Writer, elementOID int32) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = elementOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/inetarray_test.go b/inetarray_test.go new file mode 100644 index 00000000..8cab5355 --- /dev/null +++ b/inetarray_test.go @@ -0,0 +1,164 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInetArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "inet[]", []interface{}{ + &pgtype.InetArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.InetArray{Status: pgtype.Null}, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{Status: pgtype.Null}, + pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInetArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.InetArray + }{ + { + source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.InetArray{Status: pgtype.Null}, + }, + { + source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.InetArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.InetArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInetArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + + simpleTests := []struct { + src pgtype.InetArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.InetArray{Status: pgtype.Null}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.InetArray{Status: pgtype.Null}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype_test.go b/pgtype_test.go index a1dcd11b..7d34ae34 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -3,6 +3,7 @@ package pgtype_test import ( "fmt" "io" + "net" "os" "reflect" "testing" @@ -44,6 +45,15 @@ func mustClose(t testing.TB, conn interface { } } +func mustParseCIDR(t testing.TB, s string) *net.IPNet { + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + t.Fatal(err) + } + + return ipnet +} + type forceTextEncoder struct { e pgtype.TextEncoder } diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 4ce6c3b5..47afdf1d 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -7,3 +7,4 @@ erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID typed_array.go.erb > timestamparray.go erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID typed_array.go.erb > float4array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID typed_array.go.erb > float8array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID typed_array.go.erb > inetarray.go From 4254e5f2d274b98809ef2321b7210b2c380962ac Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 21:20:56 -0600 Subject: [PATCH 0014/1158] Add text to pgtype --- array_test.go | 7 ++ bool.go | 2 +- boolarray.go | 10 +- convert.go | 19 +++ datearray.go | 2 +- float4array.go | 10 +- float8array.go | 10 +- inetarray.go | 2 +- int2array.go | 14 +-- int4array.go | 14 +-- int8array.go | 14 +-- pgtype_test.go | 1 + text.go | 115 +++++++++++++++++ text_test.go | 100 +++++++++++++++ textarray.go | 297 ++++++++++++++++++++++++++++++++++++++++++++ textarray_test.go | 151 ++++++++++++++++++++++ timestamparray.go | 2 +- timestamptzarray.go | 2 +- typed_array.go.erb | 2 +- typed_array_gen.sh | 1 + varchararray.go | 31 +++++ 21 files changed, 764 insertions(+), 42 deletions(-) create mode 100644 text.go create mode 100644 text_test.go create mode 100644 textarray.go create mode 100644 textarray_test.go create mode 100644 varchararray.go diff --git a/array_test.go b/array_test.go index 5e5f00e7..d1cdb4c5 100644 --- a/array_test.go +++ b/array_test.go @@ -40,6 +40,13 @@ func TestParseUntypedTextArray(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, + { + source: `{""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{""}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, { source: `{"He said, \"Hello.\""}`, result: pgtype.UntypedTextArray{ diff --git a/bool.go b/bool.go index 2889b787..076403f9 100644 --- a/bool.go +++ b/bool.go @@ -66,7 +66,7 @@ func (src *Bool) AssignTo(dst interface{}) error { return nil } } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/boolarray.go b/boolarray.go index 8dd68dc2..b6b5db02 100644 --- a/boolarray.go +++ b/boolarray.go @@ -18,7 +18,7 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { switch value := src.(type) { case BoolArray: *dst = value - + case []bool: if value == nil { *dst = BoolArray{Status: Null} @@ -37,7 +37,7 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -50,7 +50,7 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { func (src *BoolArray) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]bool: if src.Status == Present { *v = make([]bool, len(src.Elements)) @@ -62,12 +62,12 @@ func (src *BoolArray) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/convert.go b/convert.go index 7111f8bc..31bbf060 100644 --- a/convert.go +++ b/convert.go @@ -85,6 +85,25 @@ func underlyingBoolType(val interface{}) (interface{}, bool) { return nil, false } +// underlyingStringType gets the underlying type that can be converted to String +func underlyingStringType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + // underlyingPtrType dereferences a pointer func underlyingPtrType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) diff --git a/datearray.go b/datearray.go index 877f328e..5e93501e 100644 --- a/datearray.go +++ b/datearray.go @@ -68,7 +68,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/float4array.go b/float4array.go index c06490cf..8834d213 100644 --- a/float4array.go +++ b/float4array.go @@ -18,7 +18,7 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Float4Array: *dst = value - + case []float32: if value == nil { *dst = Float4Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -50,7 +50,7 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { func (src *Float4Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]float32: if src.Status == Present { *v = make([]float32, len(src.Elements)) @@ -62,12 +62,12 @@ func (src *Float4Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/float8array.go b/float8array.go index 776fc1e6..bad9ed9f 100644 --- a/float8array.go +++ b/float8array.go @@ -18,7 +18,7 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Float8Array: *dst = value - + case []float64: if value == nil { *dst = Float8Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -50,7 +50,7 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { func (src *Float8Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]float64: if src.Status == Present { *v = make([]float64, len(src.Elements)) @@ -62,12 +62,12 @@ func (src *Float8Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/inetarray.go b/inetarray.go index eb5a4c88..cd12e917 100644 --- a/inetarray.go +++ b/inetarray.go @@ -97,7 +97,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/int2array.go b/int2array.go index 4fc6d882..a989347d 100644 --- a/int2array.go +++ b/int2array.go @@ -18,7 +18,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int2Array: *dst = value - + case []int16: if value == nil { *dst = Int2Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { Status: Present, } } - + case []uint16: if value == nil { *dst = Int2Array{Status: Null} @@ -56,7 +56,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -69,7 +69,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { func (src *Int2Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]int16: if src.Status == Present { *v = make([]int16, len(src.Elements)) @@ -81,7 +81,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } else { *v = nil } - + case *[]uint16: if src.Status == Present { *v = make([]uint16, len(src.Elements)) @@ -93,12 +93,12 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/int4array.go b/int4array.go index 40e1490d..89caf263 100644 --- a/int4array.go +++ b/int4array.go @@ -18,7 +18,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int4Array: *dst = value - + case []int32: if value == nil { *dst = Int4Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { Status: Present, } } - + case []uint32: if value == nil { *dst = Int4Array{Status: Null} @@ -56,7 +56,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -69,7 +69,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { func (src *Int4Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]int32: if src.Status == Present { *v = make([]int32, len(src.Elements)) @@ -81,7 +81,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } else { *v = nil } - + case *[]uint32: if src.Status == Present { *v = make([]uint32, len(src.Elements)) @@ -93,12 +93,12 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/int8array.go b/int8array.go index 35ecf946..003ed055 100644 --- a/int8array.go +++ b/int8array.go @@ -18,7 +18,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int8Array: *dst = value - + case []int64: if value == nil { *dst = Int8Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { Status: Present, } } - + case []uint64: if value == nil { *dst = Int8Array{Status: Null} @@ -56,7 +56,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -69,7 +69,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { func (src *Int8Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]int64: if src.Status == Present { *v = make([]int64, len(src.Elements)) @@ -81,7 +81,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } else { *v = nil } - + case *[]uint64: if src.Status == Present { *v = make([]uint64, len(src.Elements)) @@ -93,12 +93,12 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype_test.go b/pgtype_test.go index 7d34ae34..304fd0ea 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -13,6 +13,7 @@ import ( ) // Test for renamed types +type _string string type _bool bool type _int8 int8 type _int16 int16 diff --git a/text.go b/text.go new file mode 100644 index 00000000..c9054468 --- /dev/null +++ b/text.go @@ -0,0 +1,115 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + + "github.com/jackc/pgx/pgio" +) + +type Text struct { + String string + Status Status +} + +func (dst *Text) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Text: + *dst = value + case string: + *dst = Text{String: value, Status: Present} + case *string: + if value == nil { + *dst = Text{Status: Null} + } else { + *dst = Text{String: *value, Status: Present} + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Text", value) + } + + return nil +} + +func (src *Text) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *string: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.String + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + case reflect.String: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + el.SetString(src.String) + return nil + } + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Text) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Text{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + *dst = Text{String: string(buf), Status: Present} + return nil +} + +func (dst *Text) DecodeBinary(r io.Reader) error { + return dst.DecodeText(r) +} + +func (src Text) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, int32(len(src.String))) + if err != nil { + return nil + } + + _, err = io.WriteString(w, src.String) + return err +} + +func (src Text) EncodeBinary(w io.Writer) error { + return src.EncodeText(w) +} diff --git a/text_test.go b/text_test.go new file mode 100644 index 00000000..6e944857 --- /dev/null +++ b/text_test.go @@ -0,0 +1,100 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestTextTranscode(t *testing.T) { + for _, pgTypeName := range []string{"text", "varchar"} { + testSuccessfulTranscode(t, pgTypeName, []interface{}{ + pgtype.Text{String: "", Status: pgtype.Present}, + pgtype.Text{String: "foo", Status: pgtype.Present}, + pgtype.Text{Status: pgtype.Null}, + }) + } +} + +func TestTextConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Text + }{ + {source: pgtype.Text{String: "foo", Status: pgtype.Present}, result: pgtype.Text{String: "foo", Status: pgtype.Present}}, + {source: "foo", result: pgtype.Text{String: "foo", Status: pgtype.Present}}, + {source: _string("bar"), result: pgtype.Text{String: "bar", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.Text{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.Text + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestTextAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.Text + dst interface{} + expected interface{} + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, + {src: pgtype.Text{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Text + dst interface{} + expected interface{} + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Text + dst interface{} + }{ + {src: pgtype.Text{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/textarray.go b/textarray.go new file mode 100644 index 00000000..c420e5c9 --- /dev/null +++ b/textarray.go @@ -0,0 +1,297 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type TextArray struct { + Elements []Text + Dimensions []ArrayDimension + Status Status +} + +func (dst *TextArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case TextArray: + *dst = value + + case []string: + if value == nil { + *dst = TextArray{Status: Null} + } else if len(value) == 0 { + *dst = TextArray{Status: Present} + } else { + elements := make([]Text, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = TextArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Text", value) + } + + return nil +} + +func (src *TextArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]string: + if src.Status == Present { + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *TextArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TextArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Text + + if len(uta.Elements) > 0 { + elements = make([]Text, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Text + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TextArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TextArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TextArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TextArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Text, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = TextArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *TextArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + if elem.String == "" && elem.Status == Present { + _, err := io.WriteString(buf, `""`) + if err != nil { + return err + } + } else { + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *TextArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, TextOID) +} + +func (src *TextArray) encodeBinary(w io.Writer, elementOID int32) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = elementOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/textarray_test.go b/textarray_test.go new file mode 100644 index 00000000..29e3a6c7 --- /dev/null +++ b/textarray_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestTextArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "text[]", []interface{}{ + &pgtype.TextArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + pgtype.Text{String: "foo", Status: pgtype.Present}, + pgtype.Text{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TextArray{Status: pgtype.Null}, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + pgtype.Text{String: "bar", Status: pgtype.Present}, + pgtype.Text{String: "baz", Status: pgtype.Present}, + pgtype.Text{String: "quz", Status: pgtype.Present}, + pgtype.Text{String: "", Status: pgtype.Present}, + pgtype.Text{Status: pgtype.Null}, + pgtype.Text{String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + pgtype.Text{String: "bar", Status: pgtype.Present}, + pgtype.Text{String: "baz", Status: pgtype.Present}, + pgtype.Text{String: "quz", Status: pgtype.Present}, + pgtype.Text{String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestTextArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TextArray + }{ + { + source: []string{"foo"}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.TextArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TextArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTextArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.TextArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.TextArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TextArray + dst interface{} + }{ + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/timestamparray.go b/timestamparray.go index f1b1d003..3acbb35f 100644 --- a/timestamparray.go +++ b/timestamparray.go @@ -68,7 +68,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/timestamptzarray.go b/timestamptzarray.go index 72b28e43..9df746e6 100644 --- a/timestamptzarray.go +++ b/timestamptzarray.go @@ -68,7 +68,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/typed_array.go.erb b/typed_array.go.erb index e6e480b0..647ed7c0 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -67,7 +67,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 47afdf1d..f984e12e 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -8,3 +8,4 @@ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_type erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID typed_array.go.erb > float4array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID typed_array.go.erb > float8array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID typed_array.go.erb > inetarray.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID typed_array.go.erb > textarray.go diff --git a/varchararray.go b/varchararray.go new file mode 100644 index 00000000..13d94bc0 --- /dev/null +++ b/varchararray.go @@ -0,0 +1,31 @@ +package pgtype + +import ( + "io" +) + +type VarcharArray TextArray + +func (dst *VarcharArray) ConvertFrom(src interface{}) error { + return (*TextArray)(dst).ConvertFrom(src) +} + +func (src *VarcharArray) AssignTo(dst interface{}) error { + return (*TextArray)(src).AssignTo(dst) +} + +func (dst *VarcharArray) DecodeText(r io.Reader) error { + return (*TextArray)(dst).DecodeText(r) +} + +func (dst *VarcharArray) DecodeBinary(r io.Reader) error { + return (*TextArray)(dst).DecodeBinary(r) +} + +func (src *VarcharArray) EncodeText(w io.Writer) error { + return (*TextArray)(src).EncodeText(w) +} + +func (src *VarcharArray) EncodeBinary(w io.Writer) error { + return (*TextArray)(src).encodeBinary(w, VarcharOID) +} From 0437c9f5d63d9f0569311f257485156c834317c1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 22:12:03 -0600 Subject: [PATCH 0015/1158] Move cid to pgtype --- cid.go | 141 ++++++++++++++++++++++++++++++++++++++++++++++++++++ cid_test.go | 94 +++++++++++++++++++++++++++++++++++ pgtype.go | 2 +- 3 files changed, 236 insertions(+), 1 deletion(-) create mode 100644 cid.go create mode 100644 cid_test.go diff --git a/cid.go b/cid.go new file mode 100644 index 00000000..9f8c87d8 --- /dev/null +++ b/cid.go @@ -0,0 +1,141 @@ +package pgtype + +import ( + "fmt" + "io" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +// CID is PostgreSQL's Command Identifier type. +// +// When one does +// +// select cmin, cmax, * from some_table; +// +// it is the data type of the cmin and cmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/c.h as CommandId +// in the PostgreSQL sources. +type CID struct { + Uint uint32 + Status Status +} + +// ConvertFrom converts from src to dst. Note that as CID is not a general +// number type ConvertFrom does not do automatic type conversion as other number +// types do. +func (dst *CID) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case CID: + *dst = value + case uint32: + *dst = CID{Uint: value, Status: Present} + default: + return fmt.Errorf("cannot convert %v to CID", value) + } + + return nil +} + +// AssignTo assigns from src to dst. Note that as CID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *CID) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *uint32: + if src.Status == Present { + *v = src.Uint + } else { + return fmt.Errorf("cannot assign %v into %T", src, dst) + } + case **uint32: + if src.Status == Present { + n := src.Uint + *v = &n + } else { + *v = nil + } + } + + return nil +} + +func (dst *CID) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = CID{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseUint(string(buf), 10, 32) + if err != nil { + return err + } + + *dst = CID{Uint: uint32(n), Status: Present} + return nil +} + +func (dst *CID) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = CID{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for cid: %v", size) + } + + n, err := pgio.ReadUint32(r) + if err != nil { + return err + } + + *dst = CID{Uint: n, Status: Present} + return nil +} + +func (src CID) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + s := strconv.FormatUint(uint64(src.Uint), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (src CID) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = pgio.WriteUint32(w, src.Uint) + return err +} diff --git a/cid_test.go b/cid_test.go new file mode 100644 index 00000000..72f5dfea --- /dev/null +++ b/cid_test.go @@ -0,0 +1,94 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestCIDTranscode(t *testing.T) { + testSuccessfulTranscode(t, "cid", []interface{}{ + pgtype.CID{Uint: 42, Status: pgtype.Present}, + pgtype.CID{Status: pgtype.Null}, + }) +} + +func TestCIDConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CID + }{ + {source: uint32(1), result: pgtype.CID{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.CID + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.CID + dst interface{} + expected interface{} + }{ + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.CID + dst interface{} + expected interface{} + }{ + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.CID + dst interface{} + }{ + {src: pgtype.CID{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype.go b/pgtype.go index 5722c8ab..1200bf12 100644 --- a/pgtype.go +++ b/pgtype.go @@ -20,7 +20,7 @@ const ( OIDOID = 26 TidOID = 27 XidOID = 28 - CidOID = 29 + CIDOID = 29 JSONOID = 114 CidrOID = 650 CidrArrayOID = 651 From 3aad9c08d58e31093c573e748bbc9b346bc94f77 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 5 Mar 2017 08:59:26 -0600 Subject: [PATCH 0016/1158] Generalize array template --- typed_array.go.erb | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/typed_array.go.erb b/typed_array.go.erb index 647ed7c0..8c18073b 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -211,9 +211,16 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) error { } textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) - if err != nil { - return err + if elem.String == "" && elem.Status == Present { + _, err := io.WriteString(buf, `""`) + if err != nil { + return err + } + } else { + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } } for _, dec := range dimElemCounts { @@ -236,6 +243,10 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) error { } func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, <%= element_oid %>) +} + +func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -256,7 +267,7 @@ func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = <%= element_oid %> + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - From 8922421ad60854c5387cd7a863b1a1d0b21b0e14 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 5 Mar 2017 09:07:07 -0600 Subject: [PATCH 0017/1158] Move XID to pgypte --- pgtype.go | 2 +- xid.go | 45 +++++++++++++++++++++++++ xid_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 xid.go create mode 100644 xid_test.go diff --git a/pgtype.go b/pgtype.go index 1200bf12..15c0cc76 100644 --- a/pgtype.go +++ b/pgtype.go @@ -19,7 +19,7 @@ const ( TextOID = 25 OIDOID = 26 TidOID = 27 - XidOID = 28 + XIDOID = 28 CIDOID = 29 JSONOID = 114 CidrOID = 650 diff --git a/xid.go b/xid.go new file mode 100644 index 00000000..f4d087a5 --- /dev/null +++ b/xid.go @@ -0,0 +1,45 @@ +package pgtype + +import ( + "io" +) + +// Xid is PostgreSQL's Transaction ID type. +// +// In later versions of PostgreSQL, it is the type used for the backend_xid +// and backend_xmin columns of the pg_stat_activity system view. +// +// Also, when one does +// +// select xmin, xmax, * from some_table; +// +// it is the data type of the xmin and xmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/postgres_ext.h as TransactionId +// in the PostgreSQL sources. +type XID CID + +func (dst *XID) ConvertFrom(src interface{}) error { + return (*CID)(dst).ConvertFrom(src) +} + +func (src *XID) AssignTo(dst interface{}) error { + return (*CID)(src).AssignTo(dst) +} + +func (dst *XID) DecodeText(r io.Reader) error { + return (*CID)(dst).DecodeText(r) +} + +func (dst *XID) DecodeBinary(r io.Reader) error { + return (*CID)(dst).DecodeBinary(r) +} + +func (src XID) EncodeText(w io.Writer) error { + return (CID)(src).EncodeText(w) +} + +func (src XID) EncodeBinary(w io.Writer) error { + return (CID)(src).EncodeBinary(w) +} diff --git a/xid_test.go b/xid_test.go new file mode 100644 index 00000000..664920bc --- /dev/null +++ b/xid_test.go @@ -0,0 +1,94 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestXIDTranscode(t *testing.T) { + testSuccessfulTranscode(t, "xid", []interface{}{ + pgtype.XID{Uint: 42, Status: pgtype.Present}, + pgtype.XID{Status: pgtype.Null}, + }) +} + +func TestXIDConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.XID + }{ + {source: uint32(1), result: pgtype.XID{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.XID + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestXIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.XID + dst interface{} + expected interface{} + }{ + {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.XID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.XID + dst interface{} + expected interface{} + }{ + {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.XID + dst interface{} + }{ + {src: pgtype.XID{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} From 603d829611fa3bfe43c35efe6f35bd8e75f5d666 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 5 Mar 2017 09:13:25 -0600 Subject: [PATCH 0018/1158] Extract pguint32 --- cid.go | 108 +++---------------------------------------- pguint32.go | 130 ++++++++++++++++++++++++++++++++++++++++++++++++++++ xid.go | 19 +++++--- 3 files changed, 149 insertions(+), 108 deletions(-) create mode 100644 pguint32.go diff --git a/cid.go b/cid.go index 9f8c87d8..21d6fb80 100644 --- a/cid.go +++ b/cid.go @@ -1,11 +1,7 @@ package pgtype import ( - "fmt" "io" - "strconv" - - "github.com/jackc/pgx/pgio" ) // CID is PostgreSQL's Command Identifier type. @@ -19,123 +15,33 @@ import ( // It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/c.h as CommandId // in the PostgreSQL sources. -type CID struct { - Uint uint32 - Status Status -} +type CID pguint32 // ConvertFrom converts from src to dst. Note that as CID is not a general // number type ConvertFrom does not do automatic type conversion as other number // types do. func (dst *CID) ConvertFrom(src interface{}) error { - switch value := src.(type) { - case CID: - *dst = value - case uint32: - *dst = CID{Uint: value, Status: Present} - default: - return fmt.Errorf("cannot convert %v to CID", value) - } - - return nil + return (*pguint32)(dst).ConvertFrom(src) } // AssignTo assigns from src to dst. Note that as CID is not a general number // type AssignTo does not do automatic type conversion as other number types do. func (src *CID) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *uint32: - if src.Status == Present { - *v = src.Uint - } else { - return fmt.Errorf("cannot assign %v into %T", src, dst) - } - case **uint32: - if src.Status == Present { - n := src.Uint - *v = &n - } else { - *v = nil - } - } - - return nil + return (*pguint32)(src).AssignTo(dst) } func (dst *CID) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { - *dst = CID{Status: Null} - return nil - } - - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseUint(string(buf), 10, 32) - if err != nil { - return err - } - - *dst = CID{Uint: uint32(n), Status: Present} - return nil + return (*pguint32)(dst).DecodeText(r) } func (dst *CID) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { - *dst = CID{Status: Null} - return nil - } - - if size != 4 { - return fmt.Errorf("invalid length for cid: %v", size) - } - - n, err := pgio.ReadUint32(r) - if err != nil { - return err - } - - *dst = CID{Uint: n, Status: Present} - return nil + return (*pguint32)(dst).DecodeBinary(r) } func (src CID) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err - } - - s := strconv.FormatUint(uint64(src.Uint), 10) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + return (pguint32)(src).EncodeText(w) } func (src CID) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err - } - - _, err := pgio.WriteInt32(w, 4) - if err != nil { - return err - } - - _, err = pgio.WriteUint32(w, src.Uint) - return err + return (pguint32)(src).EncodeBinary(w) } diff --git a/pguint32.go b/pguint32.go new file mode 100644 index 00000000..66b385fb --- /dev/null +++ b/pguint32.go @@ -0,0 +1,130 @@ +package pgtype + +import ( + "fmt" + "io" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +// pguint32 is the core type that is used to implement PostgreSQL types such as +// CID and XID. +type pguint32 struct { + Uint uint32 + Status Status +} + +// ConvertFrom converts from src to dst. Note that as pguint32 is not a general +// number type ConvertFrom does not do automatic type conversion as other number +// types do. +func (dst *pguint32) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case uint32: + *dst = pguint32{Uint: value, Status: Present} + default: + return fmt.Errorf("cannot convert %v to pguint32", value) + } + + return nil +} + +// AssignTo assigns from src to dst. Note that as pguint32 is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *pguint32) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *uint32: + if src.Status == Present { + *v = src.Uint + } else { + return fmt.Errorf("cannot assign %v into %T", src, dst) + } + case **uint32: + if src.Status == Present { + n := src.Uint + *v = &n + } else { + *v = nil + } + } + + return nil +} + +func (dst *pguint32) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = pguint32{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseUint(string(buf), 10, 32) + if err != nil { + return err + } + + *dst = pguint32{Uint: uint32(n), Status: Present} + return nil +} + +func (dst *pguint32) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = pguint32{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for cid: %v", size) + } + + n, err := pgio.ReadUint32(r) + if err != nil { + return err + } + + *dst = pguint32{Uint: n, Status: Present} + return nil +} + +func (src pguint32) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + s := strconv.FormatUint(uint64(src.Uint), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (src pguint32) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = pgio.WriteUint32(w, src.Uint) + return err +} diff --git a/xid.go b/xid.go index f4d087a5..b311cbfb 100644 --- a/xid.go +++ b/xid.go @@ -18,28 +18,33 @@ import ( // It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/postgres_ext.h as TransactionId // in the PostgreSQL sources. -type XID CID +type XID pguint32 +// ConvertFrom converts from src to dst. Note that as XID is not a general +// number type ConvertFrom does not do automatic type conversion as other number +// types do. func (dst *XID) ConvertFrom(src interface{}) error { - return (*CID)(dst).ConvertFrom(src) + return (*pguint32)(dst).ConvertFrom(src) } +// AssignTo assigns from src to dst. Note that as XID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. func (src *XID) AssignTo(dst interface{}) error { - return (*CID)(src).AssignTo(dst) + return (*pguint32)(src).AssignTo(dst) } func (dst *XID) DecodeText(r io.Reader) error { - return (*CID)(dst).DecodeText(r) + return (*pguint32)(dst).DecodeText(r) } func (dst *XID) DecodeBinary(r io.Reader) error { - return (*CID)(dst).DecodeBinary(r) + return (*pguint32)(dst).DecodeBinary(r) } func (src XID) EncodeText(w io.Writer) error { - return (CID)(src).EncodeText(w) + return (pguint32)(src).EncodeText(w) } func (src XID) EncodeBinary(w io.Writer) error { - return (CID)(src).EncodeBinary(w) + return (pguint32)(src).EncodeBinary(w) } From 6f9aef67c7c16c6f9fb5ffdd42bf6dc823c9b74a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 5 Mar 2017 09:18:50 -0600 Subject: [PATCH 0019/1158] Fix comment on XID --- xid.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xid.go b/xid.go index b311cbfb..d4003b5d 100644 --- a/xid.go +++ b/xid.go @@ -4,7 +4,7 @@ import ( "io" ) -// Xid is PostgreSQL's Transaction ID type. +// XID is PostgreSQL's Transaction ID type. // // In later versions of PostgreSQL, it is the type used for the backend_xid // and backend_xmin columns of the pg_stat_activity system view. From b139307f5bdfab58712841628e701959749523ff Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 5 Mar 2017 13:05:49 -0600 Subject: [PATCH 0020/1158] Move OID to pgtype --- oid.go | 41 +++++++++++++++++++++++ oid_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++++++++++++ pguint32.go | 2 +- 3 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 oid.go create mode 100644 oid_test.go diff --git a/oid.go b/oid.go new file mode 100644 index 00000000..d137f352 --- /dev/null +++ b/oid.go @@ -0,0 +1,41 @@ +package pgtype + +import ( + "io" +) + +// OID (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-oid.html, used +// internally by PostgreSQL as a primary key for various system tables. It is +// currently implemented as an unsigned four-byte integer. Its definition can be +// found in src/include/postgres_ext.h in the PostgreSQL sources. +type OID pguint32 + +// ConvertFrom converts from src to dst. Note that as OID is not a general +// number type ConvertFrom does not do automatic type conversion as other number +// types do. +func (dst *OID) ConvertFrom(src interface{}) error { + return (*pguint32)(dst).ConvertFrom(src) +} + +// AssignTo assigns from src to dst. Note that as OID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *OID) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *OID) DecodeText(r io.Reader) error { + return (*pguint32)(dst).DecodeText(r) +} + +func (dst *OID) DecodeBinary(r io.Reader) error { + return (*pguint32)(dst).DecodeBinary(r) +} + +func (src OID) EncodeText(w io.Writer) error { + return (pguint32)(src).EncodeText(w) +} + +func (src OID) EncodeBinary(w io.Writer) error { + return (pguint32)(src).EncodeBinary(w) +} diff --git a/oid_test.go b/oid_test.go new file mode 100644 index 00000000..c8e0b2d6 --- /dev/null +++ b/oid_test.go @@ -0,0 +1,94 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestOIDTranscode(t *testing.T) { + testSuccessfulTranscode(t, "oid", []interface{}{ + pgtype.OID{Uint: 42, Status: pgtype.Present}, + pgtype.OID{Status: pgtype.Null}, + }) +} + +func TestOIDConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.OID + }{ + {source: uint32(1), result: pgtype.OID{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.OID + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestOIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.OID + dst interface{} + expected interface{} + }{ + {src: pgtype.OID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.OID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.OID + dst interface{} + expected interface{} + }{ + {src: pgtype.OID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.OID + dst interface{} + }{ + {src: pgtype.OID{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pguint32.go b/pguint32.go index 66b385fb..9c1ccd6c 100644 --- a/pguint32.go +++ b/pguint32.go @@ -89,7 +89,7 @@ func (dst *pguint32) DecodeBinary(r io.Reader) error { } if size != 4 { - return fmt.Errorf("invalid length for cid: %v", size) + return fmt.Errorf("invalid length: %v", size) } n, err := pgio.ReadUint32(r) From 94612427ed655a8216a8259318133864d88d5d8d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 6 Mar 2017 17:55:20 -0600 Subject: [PATCH 0021/1158] Move Name to pgtype --- name.go | 44 ++++++++++++++++++++++++ name_test.go | 97 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+) create mode 100644 name.go create mode 100644 name_test.go diff --git a/name.go b/name.go new file mode 100644 index 00000000..3ff81f12 --- /dev/null +++ b/name.go @@ -0,0 +1,44 @@ +package pgtype + +import ( + "io" +) + +// Name is a type used for PostgreSQL's special 63-byte +// name data type, used for identifiers like table names. +// The pg_class.relname column is a good example of where the +// name data type is used. +// +// Note that the underlying Go data type of pgx.Name is string, +// so there is no way to enforce the 63-byte length. Inputting +// a longer name into PostgreSQL will result in silent truncation +// to 63 bytes. +// +// Also, if you have custom-compiled PostgreSQL and set +// NAMEDATALEN to a different value, obviously that number of +// bytes applies, rather than the default 63. +type Name Text + +func (dst *Name) ConvertFrom(src interface{}) error { + return (*Text)(dst).ConvertFrom(src) +} + +func (src *Name) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Name) DecodeText(r io.Reader) error { + return (*Text)(dst).DecodeText(r) +} + +func (dst *Name) DecodeBinary(r io.Reader) error { + return (*Text)(dst).DecodeBinary(r) +} + +func (src Name) EncodeText(w io.Writer) error { + return (Text)(src).EncodeText(w) +} + +func (src Name) EncodeBinary(w io.Writer) error { + return (Text)(src).EncodeBinary(w) +} diff --git a/name_test.go b/name_test.go new file mode 100644 index 00000000..c5f7de17 --- /dev/null +++ b/name_test.go @@ -0,0 +1,97 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestNameTranscode(t *testing.T) { + testSuccessfulTranscode(t, "name", []interface{}{ + pgtype.Name{String: "", Status: pgtype.Present}, + pgtype.Name{String: "foo", Status: pgtype.Present}, + pgtype.Name{Status: pgtype.Null}, + }) +} + +func TestNameConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Name + }{ + {source: "foo", result: pgtype.Name{String: "foo", Status: pgtype.Present}}, + {source: _string("bar"), result: pgtype.Name{String: "bar", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.Name{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.Name + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestNameAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.Name + dst interface{} + expected interface{} + }{ + {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, + {src: pgtype.Name{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Name + dst interface{} + expected interface{} + }{ + {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Name + dst interface{} + }{ + {src: pgtype.Name{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} From eea6e5a64c6474f5d7d45e6c1e633209453c3d1e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 Mar 2017 19:39:57 -0600 Subject: [PATCH 0022/1158] Move "char" to pgtype --- pgtype_test.go | 22 +++++--- qchar.go | 144 +++++++++++++++++++++++++++++++++++++++++++++++++ qchar_test.go | 140 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 299 insertions(+), 7 deletions(-) create mode 100644 qchar.go create mode 100644 qchar_test.go diff --git a/pgtype_test.go b/pgtype_test.go index 304fd0ea..c1dba383 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -74,12 +74,15 @@ func (f forceBinaryEncoder) EncodeBinary(w io.Writer) error { func forceEncoder(e interface{}, formatCode int16) interface{} { switch formatCode { case pgx.TextFormatCode: - return forceTextEncoder{e: e.(pgtype.TextEncoder)} + if e, ok := e.(pgtype.TextEncoder); ok { + return forceTextEncoder{e: e} + } case pgx.BinaryFormatCode: - return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} - default: - panic("bad encoder") + if e, ok := e.(pgtype.BinaryEncoder); ok { + return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} + } } + return nil } func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { @@ -105,9 +108,14 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, } - for _, fc := range formats { - ps.FieldDescriptions[0].FormatCode = fc.formatCode - for i, v := range values { + for i, v := range values { + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + vEncoder := forceEncoder(v, fc.formatCode) + if vEncoder == nil { + t.Logf("%v does not implement %v", fc.name) + continue + } // Derefence value if it is a pointer derefV := v refVal := reflect.ValueOf(v) diff --git a/qchar.go b/qchar.go new file mode 100644 index 00000000..6dd14625 --- /dev/null +++ b/qchar.go @@ -0,0 +1,144 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +// QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C +// language's char type, or Go's byte type. (Note that the name in PostgreSQL +// itself is "char", in double-quotes, and not char.) It gets used a lot in +// PostgreSQL's system tables to hold a single ASCII character value (eg +// pg_class.relkind). It is named Qchar for quoted char to disambiguate from SQL +// standard type char. +// +// Not all possible values of QChar are representable in the text format. +// Therefore, QChar does not implement TextEncoder and TextDecoder. +type QChar struct { + Int int8 + Status Status +} + +func (dst *QChar) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case QChar: + *dst = value + case int8: + *dst = QChar{Int: value, Status: Present} + case uint8: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int16: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint16: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int32: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint32: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int64: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint64: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 8) + if err != nil { + return err + } + *dst = QChar{Int: int8(num), Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to QChar", value) + } + + return nil +} + +func (src *QChar) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) +} + +func (dst *QChar) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = QChar{Status: Null} + return nil + } + + if size != 1 { + return fmt.Errorf(`invalid length for "char": %v`, size) + } + + byt, err := pgio.ReadByte(r) + if err != nil { + return err + } + + *dst = QChar{Int: int8(byt), Status: Present} + return nil +} + +func (src QChar) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 1) + if err != nil { + return nil + } + + return pgio.WriteByte(w, byte(src.Int)) +} diff --git a/qchar_test.go b/qchar_test.go new file mode 100644 index 00000000..ea7b56a8 --- /dev/null +++ b/qchar_test.go @@ -0,0 +1,140 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestQCharTranscode(t *testing.T) { + testSuccessfulTranscode(t, `"char"`, []interface{}{ + pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, + pgtype.QChar{Int: -1, Status: pgtype.Present}, + pgtype.QChar{Int: 0, Status: pgtype.Present}, + pgtype.QChar{Int: 1, Status: pgtype.Present}, + pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, + pgtype.QChar{Int: 0, Status: pgtype.Null}, + }) +} + +func TestQCharConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.QChar + }{ + {source: int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.QChar + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestQCharAssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.QChar + dst interface{} + expected interface{} + }{ + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.QChar + dst interface{} + expected interface{} + }{ + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.QChar + dst interface{} + }{ + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &i16}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} From 8fa9afbb365bdcd85b2dd07ec32c1da4dc5d4a1f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 Mar 2017 21:07:40 -0600 Subject: [PATCH 0023/1158] Add bytea --- bytea.go | 160 +++++++++++++++++++++++++++++++++++++++++++++++++ bytea_test.go | 73 ++++++++++++++++++++++ convert.go | 21 +++++++ pgtype_test.go | 1 + 4 files changed, 255 insertions(+) create mode 100644 bytea.go create mode 100644 bytea_test.go diff --git a/bytea.go b/bytea.go new file mode 100644 index 00000000..2532182f --- /dev/null +++ b/bytea.go @@ -0,0 +1,160 @@ +package pgtype + +import ( + "encoding/hex" + "fmt" + "io" + "reflect" + + "github.com/jackc/pgx/pgio" +) + +type Bytea struct { + Bytes []byte + Status Status +} + +func (dst *Bytea) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Bytea: + *dst = value + case []byte: + if value != nil { + *dst = Bytea{Bytes: value, Status: Present} + } else { + *dst = Bytea{Status: Null} + } + default: + if originalSrc, ok := underlyingBytesType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bytea", value) + } + + return nil +} + +func (src *Bytea) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *[]byte: + if src.Status == Present { + *v = src.Bytes + } else { + *v = nil + } + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + } + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +// DecodeText only supports the hex format. This has been the default since +// PostgreSQL 9.0. +func (dst *Bytea) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Bytea{Status: Null} + return nil + } + + sbuf := make([]byte, int(size)) + _, err = io.ReadFull(r, sbuf) + if err != nil { + return err + } + + if len(sbuf) < 2 || sbuf[0] != '\\' || sbuf[1] != 'x' { + return fmt.Errorf("invalid hex format") + } + + buf := make([]byte, (len(sbuf)-2)/2) + _, err = hex.Decode(buf, sbuf[2:]) + if err != nil { + return err + } + + *dst = Bytea{Bytes: buf, Status: Present} + return nil +} + +func (dst *Bytea) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Bytea{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + *dst = Bytea{Bytes: buf, Status: Present} + return nil +} + +func (src Bytea) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + str := hex.EncodeToString(src.Bytes) + + _, err := pgio.WriteInt32(w, int32(len(str)+2)) + if err != nil { + return nil + } + + _, err = io.WriteString(w, `\x`) + if err != nil { + return nil + } + + _, err = io.WriteString(w, str) + return err +} + +func (src Bytea) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, int32(len(src.Bytes))) + if err != nil { + return nil + } + + _, err = w.Write(src.Bytes) + return err +} diff --git a/bytea_test.go b/bytea_test.go new file mode 100644 index 00000000..51941387 --- /dev/null +++ b/bytea_test.go @@ -0,0 +1,73 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestByteaTranscode(t *testing.T) { + testSuccessfulTranscode(t, "bytea", []interface{}{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, + }) +} + +func TestByteaConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Bytea + }{ + {source: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Null}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Null}}, + {source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + {source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}}, + {source: []byte(nil), result: pgtype.Bytea{Status: pgtype.Null}}, + {source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + {source: _byteSlice(nil), result: pgtype.Bytea{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var r pgtype.Bytea + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestByteaAssignTo(t *testing.T) { + var buf []byte + var _buf _byteSlice + var pbuf *[]byte + var _pbuf *_byteSlice + + simpleTests := []struct { + src pgtype.Bytea + dst interface{} + expected interface{} + }{ + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &buf, expected: []byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_buf, expected: _byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &pbuf, expected: &[]byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{Status: pgtype.Null}, dst: &pbuf, expected: ((*[]byte)(nil))}, + {src: pgtype.Bytea{Status: pgtype.Null}, dst: &_pbuf, expected: ((*_byteSlice)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/convert.go b/convert.go index 31bbf060..648209f5 100644 --- a/convert.go +++ b/convert.go @@ -85,6 +85,27 @@ func underlyingBoolType(val interface{}) (interface{}, bool) { return nil, false } +// underlyingBytesType gets the underlying type that can be converted to []byte +func underlyingBytesType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + if refVal.Type().Elem().Kind() == reflect.Uint8 { + convVal := refVal.Bytes() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + } + + return nil, false +} + // underlyingStringType gets the underlying type that can be converted to String func underlyingStringType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) diff --git a/pgtype_test.go b/pgtype_test.go index c1dba383..6e173cbe 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -22,6 +22,7 @@ type _int32Slice []int32 type _int64Slice []int64 type _float32Slice []float32 type _float64Slice []float64 +type _byteSlice []byte func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) From bb7122d4a8a9e2da89fe5edabfcfab0a03fb853e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 Mar 2017 21:09:36 -0600 Subject: [PATCH 0024/1158] Fix typed_array_gen.sh typo --- typed_array_gen.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/typed_array_gen.sh b/typed_array_gen.sh index f984e12e..1e2dce64 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,6 +1,6 @@ erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID typed_array.go.erb > int2array.go erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID typed_array.go.erb > int4array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID typed_array.go.erb > int2array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID typed_array.go.erb > int8array.go erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID typed_array.go.erb > boolarray.go erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID typed_array.go.erb > datearray.go erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID typed_array.go.erb > timestamptzarray.go From 361a54abb7410fdd25c9744e7aee55cb8714dc49 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 10 Mar 2017 16:08:47 -0600 Subject: [PATCH 0025/1158] Decode(Text|Binary) now accepts []byte instead of io.Reader --- array.go | 46 +++++++++++++++++----------------- bool.go | 40 ++++++++---------------------- boolarray.go | 52 +++++++++++++++++++-------------------- bytea.go | 39 ++++++----------------------- cid.go | 8 +++--- cidrarray.go | 8 +++--- date.go | 36 +++++++-------------------- datearray.go | 52 +++++++++++++++++++-------------------- float4.go | 36 +++++++-------------------- float4array.go | 52 +++++++++++++++++++-------------------- float8.go | 36 +++++++-------------------- float8array.go | 52 +++++++++++++++++++-------------------- inet.go | 60 +++++++++------------------------------------ inetarray.go | 52 ++++++++++++++++++--------------------- int2.go | 39 ++++++++--------------------- int2array.go | 52 +++++++++++++++++++-------------------- int4.go | 37 +++++++--------------------- int4array.go | 52 +++++++++++++++++++-------------------- int8.go | 36 +++++++-------------------- int8array.go | 52 +++++++++++++++++++-------------------- name.go | 8 +++--- oid.go | 8 +++--- pgtype.go | 4 +-- pguint32.go | 37 +++++++--------------------- qchar.go | 20 ++++----------- text.go | 21 ++++------------ textarray.go | 53 ++++++++++++++++++++------------------- timestamp.go | 36 +++++++-------------------- timestamparray.go | 52 +++++++++++++++++++-------------------- timestamptz.go | 36 +++++++-------------------- timestamptzarray.go | 52 +++++++++++++++++++-------------------- to-consider.txt | 9 +++++++ typed_array.go.erb | 58 +++++++++++++++++-------------------------- varchararray.go | 8 +++--- xid.go | 8 +++--- 35 files changed, 476 insertions(+), 771 deletions(-) create mode 100644 to-consider.txt diff --git a/array.go b/array.go index 76492c61..6b705103 100644 --- a/array.go +++ b/array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "strconv" @@ -25,40 +26,37 @@ type ArrayDimension struct { LowerBound int32 } -func (dst *ArrayHeader) DecodeBinary(r io.Reader) error { - numDims, err := pgio.ReadInt32(r) - if err != nil { - return err +func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { + if len(src) < 12 { + return 0, fmt.Errorf("array header too short: %d", len(src)) } + rp := 0 + + numDims := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 + rp += 4 + + dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + if numDims > 0 { dst.Dimensions = make([]ArrayDimension, numDims) } - - containsNull, err := pgio.ReadInt32(r) - if err != nil { - return err + if len(src) < 12+numDims*8 { + return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) } - dst.ContainsNull = containsNull == 1 - - dst.ElementOID, err = pgio.ReadInt32(r) - if err != nil { - return err - } - for i := range dst.Dimensions { - dst.Dimensions[i].Length, err = pgio.ReadInt32(r) - if err != nil { - return err - } + dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 - dst.Dimensions[i].LowerBound, err = pgio.ReadInt32(r) - if err != nil { - return err - } + dst.Dimensions[i].LowerBound = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 } - return nil + return rp, nil } func (src *ArrayHeader) EncodeBinary(w io.Writer) error { diff --git a/bool.go b/bool.go index 076403f9..b7bc14d0 100644 --- a/bool.go +++ b/bool.go @@ -72,51 +72,31 @@ func (src *Bool) AssignTo(dst interface{}) error { return nil } -func (dst *Bool) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Bool) DecodeText(src []byte) error { + if src == nil { *dst = Bool{Status: Null} return nil } - if size != 1 { - return fmt.Errorf("invalid length for bool: %v", size) + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) } - byt, err := pgio.ReadByte(r) - if err != nil { - return err - } - - *dst = Bool{Bool: byt == 't', Status: Present} + *dst = Bool{Bool: src[0] == 't', Status: Present} return nil } -func (dst *Bool) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Bool) DecodeBinary(src []byte) error { + if src == nil { *dst = Bool{Status: Null} return nil } - if size != 1 { - return fmt.Errorf("invalid length for bool: %v", size) + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) } - byt, err := pgio.ReadByte(r) - if err != nil { - return err - } - - *dst = Bool{Bool: byt == 1, Status: Present} + *dst = Bool{Bool: src[0] == 1, Status: Present} return nil } diff --git a/boolarray.go b/boolarray.go index b6b5db02..a9b8bf50 100644 --- a/boolarray.go +++ b/boolarray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -73,29 +74,17 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return nil } -func (dst *BoolArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *BoolArray) DecodeText(src []byte) error { + if src == nil { *dst = BoolArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Bool if len(uta.Elements) > 0 { @@ -103,8 +92,11 @@ func (dst *BoolArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Bool - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +110,14 @@ func (dst *BoolArray) DecodeText(r io.Reader) error { return nil } -func (dst *BoolArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *BoolArray) DecodeBinary(src []byte) error { + if src == nil { *dst = BoolArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +135,14 @@ func (dst *BoolArray) DecodeBinary(r io.Reader) error { elements := make([]Bool, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -236,6 +230,10 @@ func (src *BoolArray) EncodeText(w io.Writer) error { } func (src *BoolArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, BoolOID) +} + +func (src *BoolArray) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -256,7 +254,7 @@ func (src *BoolArray) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = BoolOID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/bytea.go b/bytea.go index 2532182f..db20482f 100644 --- a/bytea.go +++ b/bytea.go @@ -71,29 +71,18 @@ func (src *Bytea) AssignTo(dst interface{}) error { // DecodeText only supports the hex format. This has been the default since // PostgreSQL 9.0. -func (dst *Bytea) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Bytea) DecodeText(src []byte) error { + if src == nil { *dst = Bytea{Status: Null} return nil } - sbuf := make([]byte, int(size)) - _, err = io.ReadFull(r, sbuf) - if err != nil { - return err - } - - if len(sbuf) < 2 || sbuf[0] != '\\' || sbuf[1] != 'x' { + if len(src) < 2 || src[0] != '\\' || src[1] != 'x' { return fmt.Errorf("invalid hex format") } - buf := make([]byte, (len(sbuf)-2)/2) - _, err = hex.Decode(buf, sbuf[2:]) + buf := make([]byte, (len(src)-2)/2) + _, err := hex.Decode(buf, src[2:]) if err != nil { return err } @@ -102,25 +91,13 @@ func (dst *Bytea) DecodeText(r io.Reader) error { return nil } -func (dst *Bytea) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Bytea) DecodeBinary(src []byte) error { + if src == nil { *dst = Bytea{Status: Null} return nil } - buf := make([]byte, int(size)) - - _, err = io.ReadFull(r, buf) - if err != nil { - return err - } - - *dst = Bytea{Bytes: buf, Status: Present} + *dst = Bytea{Bytes: src, Status: Present} return nil } diff --git a/cid.go b/cid.go index 21d6fb80..f8d706d0 100644 --- a/cid.go +++ b/cid.go @@ -30,12 +30,12 @@ func (src *CID) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *CID) DecodeText(r io.Reader) error { - return (*pguint32)(dst).DecodeText(r) +func (dst *CID) DecodeText(src []byte) error { + return (*pguint32)(dst).DecodeText(src) } -func (dst *CID) DecodeBinary(r io.Reader) error { - return (*pguint32)(dst).DecodeBinary(r) +func (dst *CID) DecodeBinary(src []byte) error { + return (*pguint32)(dst).DecodeBinary(src) } func (src CID) EncodeText(w io.Writer) error { diff --git a/cidrarray.go b/cidrarray.go index 66dd20d0..d95eef4a 100644 --- a/cidrarray.go +++ b/cidrarray.go @@ -14,12 +14,12 @@ func (src *CidrArray) AssignTo(dst interface{}) error { return (*InetArray)(src).AssignTo(dst) } -func (dst *CidrArray) DecodeText(r io.Reader) error { - return (*InetArray)(dst).DecodeText(r) +func (dst *CidrArray) DecodeText(src []byte) error { + return (*InetArray)(dst).DecodeText(src) } -func (dst *CidrArray) DecodeBinary(r io.Reader) error { - return (*InetArray)(dst).DecodeBinary(r) +func (dst *CidrArray) DecodeBinary(src []byte) error { + return (*InetArray)(dst).DecodeBinary(src) } func (src *CidrArray) EncodeText(w io.Writer) error { diff --git a/date.go b/date.go index 307f1e59..1bb81d35 100644 --- a/date.go +++ b/date.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "reflect" @@ -66,24 +67,13 @@ func (src *Date) AssignTo(dst interface{}) error { return nil } -func (dst *Date) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Date) DecodeText(src []byte) error { + if src == nil { *dst = Date{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - sbuf := string(buf) + sbuf := string(src) switch sbuf { case "infinity": *dst = Date{Status: Present, InfinityModifier: Infinity} @@ -101,25 +91,17 @@ func (dst *Date) DecodeText(r io.Reader) error { return nil } -func (dst *Date) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Date) DecodeBinary(src []byte) error { + if src == nil { *dst = Date{Status: Null} return nil } - if size != 4 { - return fmt.Errorf("invalid length for date: %v", size) + if len(src) != 4 { + return fmt.Errorf("invalid length for date: %v", len(src)) } - dayOffset, err := pgio.ReadInt32(r) - if err != nil { - return err - } + dayOffset := int32(binary.BigEndian.Uint32(src)) switch dayOffset { case infinityDayOffset: diff --git a/datearray.go b/datearray.go index 5e93501e..e9ad1f62 100644 --- a/datearray.go +++ b/datearray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "time" @@ -74,29 +75,17 @@ func (src *DateArray) AssignTo(dst interface{}) error { return nil } -func (dst *DateArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *DateArray) DecodeText(src []byte) error { + if src == nil { *dst = DateArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Date if len(uta.Elements) > 0 { @@ -104,8 +93,11 @@ func (dst *DateArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Date - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -119,19 +111,14 @@ func (dst *DateArray) DecodeText(r io.Reader) error { return nil } -func (dst *DateArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *DateArray) DecodeBinary(src []byte) error { + if src == nil { *dst = DateArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -149,7 +136,14 @@ func (dst *DateArray) DecodeBinary(r io.Reader) error { elements := make([]Date, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -237,6 +231,10 @@ func (src *DateArray) EncodeText(w io.Writer) error { } func (src *DateArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, DateOID) +} + +func (src *DateArray) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -257,7 +255,7 @@ func (src *DateArray) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = DateOID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/float4.go b/float4.go index a1e5aa18..fb0415e5 100644 --- a/float4.go +++ b/float4.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -92,24 +93,13 @@ func (src *Float4) AssignTo(dst interface{}) error { return float64AssignTo(float64(src.Float), src.Status, dst) } -func (dst *Float4) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float4) DecodeText(src []byte) error { + if src == nil { *dst = Float4{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseFloat(string(buf), 32) + n, err := strconv.ParseFloat(string(src), 32) if err != nil { return err } @@ -118,25 +108,17 @@ func (dst *Float4) DecodeText(r io.Reader) error { return nil } -func (dst *Float4) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float4) DecodeBinary(src []byte) error { + if src == nil { *dst = Float4{Status: Null} return nil } - if size != 4 { - return fmt.Errorf("invalid length for float4: %v", size) + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) } - n, err := pgio.ReadInt32(r) - if err != nil { - return err - } + n := int32(binary.BigEndian.Uint32(src)) *dst = Float4{Float: math.Float32frombits(uint32(n)), Status: Present} return nil diff --git a/float4array.go b/float4array.go index 8834d213..a4a72146 100644 --- a/float4array.go +++ b/float4array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -73,29 +74,17 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float4Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float4Array) DecodeText(src []byte) error { + if src == nil { *dst = Float4Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Float4 if len(uta.Elements) > 0 { @@ -103,8 +92,11 @@ func (dst *Float4Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Float4 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +110,14 @@ func (dst *Float4Array) DecodeText(r io.Reader) error { return nil } -func (dst *Float4Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float4Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Float4Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +135,14 @@ func (dst *Float4Array) DecodeBinary(r io.Reader) error { elements := make([]Float4, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -236,6 +230,10 @@ func (src *Float4Array) EncodeText(w io.Writer) error { } func (src *Float4Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Float4OID) +} + +func (src *Float4Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -256,7 +254,7 @@ func (src *Float4Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Float4OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/float8.go b/float8.go index c1347cb2..a53de5e3 100644 --- a/float8.go +++ b/float8.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -82,24 +83,13 @@ func (src *Float8) AssignTo(dst interface{}) error { return float64AssignTo(src.Float, src.Status, dst) } -func (dst *Float8) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float8) DecodeText(src []byte) error { + if src == nil { *dst = Float8{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseFloat(string(buf), 64) + n, err := strconv.ParseFloat(string(src), 64) if err != nil { return err } @@ -108,25 +98,17 @@ func (dst *Float8) DecodeText(r io.Reader) error { return nil } -func (dst *Float8) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float8) DecodeBinary(src []byte) error { + if src == nil { *dst = Float8{Status: Null} return nil } - if size != 8 { - return fmt.Errorf("invalid length for float4: %v", size) + if len(src) != 8 { + return fmt.Errorf("invalid length for float4: %v", len(src)) } - n, err := pgio.ReadInt64(r) - if err != nil { - return err - } + n := int64(binary.BigEndian.Uint64(src)) *dst = Float8{Float: math.Float64frombits(uint64(n)), Status: Present} return nil diff --git a/float8array.go b/float8array.go index bad9ed9f..082e817d 100644 --- a/float8array.go +++ b/float8array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -73,29 +74,17 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float8Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float8Array) DecodeText(src []byte) error { + if src == nil { *dst = Float8Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Float8 if len(uta.Elements) > 0 { @@ -103,8 +92,11 @@ func (dst *Float8Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Float8 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +110,14 @@ func (dst *Float8Array) DecodeText(r io.Reader) error { return nil } -func (dst *Float8Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float8Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Float8Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +135,14 @@ func (dst *Float8Array) DecodeBinary(r io.Reader) error { elements := make([]Float8, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -236,6 +230,10 @@ func (src *Float8Array) EncodeText(w io.Writer) error { } func (src *Float8Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Float8OID) +} + +func (src *Float8Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -256,7 +254,7 @@ func (src *Float8Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Float8OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/inet.go b/inet.go index e47c64b0..132a876a 100644 --- a/inet.go +++ b/inet.go @@ -91,26 +91,16 @@ func (src *Inet) AssignTo(dst interface{}) error { return nil } -func (dst *Inet) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Inet) DecodeText(src []byte) error { + if src == nil { *dst = Inet{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) - if err != nil { - return err - } - var ipnet *net.IPNet + var err error - if ip := net.ParseIP(string(buf)); ip != nil { + if ip := net.ParseIP(string(src)); ip != nil { ipv4 := ip.To4() if ipv4 != nil { ip = ipv4 @@ -119,7 +109,7 @@ func (dst *Inet) DecodeText(r io.Reader) error { mask := net.CIDRMask(bitCount, bitCount) ipnet = &net.IPNet{Mask: mask, IP: ip} } else { - _, ipnet, err = net.ParseCIDR(string(buf)) + _, ipnet, err = net.ParseCIDR(string(src)) if err != nil { return err } @@ -129,50 +119,24 @@ func (dst *Inet) DecodeText(r io.Reader) error { return nil } -func (dst *Inet) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Inet) DecodeBinary(src []byte) error { + if src == nil { *dst = Inet{Status: Null} return nil } - if size != 8 && size != 20 { - return fmt.Errorf("Received an invalid size for a inet: %d", size) + if len(src) != 8 && len(src) != 20 { + return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) } // ignore family - _, err = pgio.ReadByte(r) - if err != nil { - return err - } - - bits, err := pgio.ReadByte(r) - if err != nil { - return err - } - + bits := src[1] // ignore is_cidr - _, err = pgio.ReadByte(r) - if err != nil { - return err - } - - addressLength, err := pgio.ReadByte(r) - if err != nil { - return err - } + addressLength := src[3] var ipnet net.IPNet ipnet.IP = make(net.IP, int(addressLength)) - _, err = r.Read(ipnet.IP) - if err != nil { - return err - } - + copy(ipnet.IP, src[4:]) ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) *dst = Inet{IPNet: &ipnet, Status: Present} diff --git a/inetarray.go b/inetarray.go index cd12e917..28de736f 100644 --- a/inetarray.go +++ b/inetarray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "net" @@ -19,8 +20,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { switch value := src.(type) { case InetArray: *dst = value - case CidrArray: - *dst = InetArray(value) + case []*net.IPNet: if value == nil { *dst = InetArray{Status: Null} @@ -39,6 +39,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { Status: Present, } } + case []net.IP: if value == nil { *dst = InetArray{Status: Null} @@ -57,6 +58,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { Status: Present, } } + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -81,6 +83,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { } else { *v = nil } + case *[]net.IP: if src.Status == Present { *v = make([]net.IP, len(src.Elements)) @@ -103,29 +106,17 @@ func (src *InetArray) AssignTo(dst interface{}) error { return nil } -func (dst *InetArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *InetArray) DecodeText(src []byte) error { + if src == nil { *dst = InetArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Inet if len(uta.Elements) > 0 { @@ -133,8 +124,11 @@ func (dst *InetArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Inet - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -148,19 +142,14 @@ func (dst *InetArray) DecodeText(r io.Reader) error { return nil } -func (dst *InetArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *InetArray) DecodeBinary(src []byte) error { + if src == nil { *dst = InetArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -178,7 +167,14 @@ func (dst *InetArray) DecodeBinary(r io.Reader) error { elements := make([]Inet, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } diff --git a/int2.go b/int2.go index 8057550b..51346a43 100644 --- a/int2.go +++ b/int2.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -88,24 +89,13 @@ func (src *Int2) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int2) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int2) DecodeText(src []byte) error { + if src == nil { *dst = Int2{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseInt(string(buf), 10, 16) + n, err := strconv.ParseInt(string(src), 10, 16) if err != nil { return err } @@ -114,27 +104,18 @@ func (dst *Int2) DecodeText(r io.Reader) error { return nil } -func (dst *Int2) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int2) DecodeBinary(src []byte) error { + if src == nil { *dst = Int2{Status: Null} return nil } - if size != 2 { - return fmt.Errorf("invalid length for int2: %v", size) + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) } - n, err := pgio.ReadInt16(r) - if err != nil { - return err - } - - *dst = Int2{Int: int16(n), Status: Present} + n := int16(binary.BigEndian.Uint16(src)) + *dst = Int2{Int: n, Status: Present} return nil } diff --git a/int2array.go b/int2array.go index a989347d..71760e1e 100644 --- a/int2array.go +++ b/int2array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -104,29 +105,17 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int2Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int2Array) DecodeText(src []byte) error { + if src == nil { *dst = Int2Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Int2 if len(uta.Elements) > 0 { @@ -134,8 +123,11 @@ func (dst *Int2Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Int2 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -149,19 +141,14 @@ func (dst *Int2Array) DecodeText(r io.Reader) error { return nil } -func (dst *Int2Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int2Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Int2Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -179,7 +166,14 @@ func (dst *Int2Array) DecodeBinary(r io.Reader) error { elements := make([]Int2, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -267,6 +261,10 @@ func (src *Int2Array) EncodeText(w io.Writer) error { } func (src *Int2Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Int2OID) +} + +func (src *Int2Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -287,7 +285,7 @@ func (src *Int2Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Int2OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/int4.go b/int4.go index 43691bb6..8a53d454 100644 --- a/int4.go +++ b/int4.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -79,24 +80,13 @@ func (src *Int4) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int4) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int4) DecodeText(src []byte) error { + if src == nil { *dst = Int4{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseInt(string(buf), 10, 32) + n, err := strconv.ParseInt(string(src), 10, 32) if err != nil { return err } @@ -105,26 +95,17 @@ func (dst *Int4) DecodeText(r io.Reader) error { return nil } -func (dst *Int4) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int4) DecodeBinary(src []byte) error { + if src == nil { *dst = Int4{Status: Null} return nil } - if size != 4 { - return fmt.Errorf("invalid length for int4: %v", size) - } - - n, err := pgio.ReadInt32(r) - if err != nil { - return err + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) } + n := int32(binary.BigEndian.Uint32(src)) *dst = Int4{Int: n, Status: Present} return nil } diff --git a/int4array.go b/int4array.go index 89caf263..6a202b08 100644 --- a/int4array.go +++ b/int4array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -104,29 +105,17 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int4Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int4Array) DecodeText(src []byte) error { + if src == nil { *dst = Int4Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Int4 if len(uta.Elements) > 0 { @@ -134,8 +123,11 @@ func (dst *Int4Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Int4 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -149,19 +141,14 @@ func (dst *Int4Array) DecodeText(r io.Reader) error { return nil } -func (dst *Int4Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int4Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Int4Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -179,7 +166,14 @@ func (dst *Int4Array) DecodeBinary(r io.Reader) error { elements := make([]Int4, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -267,6 +261,10 @@ func (src *Int4Array) EncodeText(w io.Writer) error { } func (src *Int4Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Int4OID) +} + +func (src *Int4Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -287,7 +285,7 @@ func (src *Int4Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Int4OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/int8.go b/int8.go index b87bb85a..c6bedaa6 100644 --- a/int8.go +++ b/int8.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -70,24 +71,13 @@ func (src *Int8) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int8) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int8) DecodeText(src []byte) error { + if src == nil { *dst = Int8{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseInt(string(buf), 10, 64) + n, err := strconv.ParseInt(string(src), 10, 64) if err != nil { return err } @@ -96,25 +86,17 @@ func (dst *Int8) DecodeText(r io.Reader) error { return nil } -func (dst *Int8) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int8) DecodeBinary(src []byte) error { + if src == nil { *dst = Int8{Status: Null} return nil } - if size != 8 { - return fmt.Errorf("invalid length for int8: %v", size) + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) } - n, err := pgio.ReadInt64(r) - if err != nil { - return err - } + n := int64(binary.BigEndian.Uint64(src)) *dst = Int8{Int: n, Status: Present} return nil diff --git a/int8array.go b/int8array.go index 003ed055..f621618e 100644 --- a/int8array.go +++ b/int8array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -104,29 +105,17 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int8Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int8Array) DecodeText(src []byte) error { + if src == nil { *dst = Int8Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Int8 if len(uta.Elements) > 0 { @@ -134,8 +123,11 @@ func (dst *Int8Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Int8 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -149,19 +141,14 @@ func (dst *Int8Array) DecodeText(r io.Reader) error { return nil } -func (dst *Int8Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int8Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Int8Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -179,7 +166,14 @@ func (dst *Int8Array) DecodeBinary(r io.Reader) error { elements := make([]Int8, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -267,6 +261,10 @@ func (src *Int8Array) EncodeText(w io.Writer) error { } func (src *Int8Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Int8OID) +} + +func (src *Int8Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -287,7 +285,7 @@ func (src *Int8Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Int8OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/name.go b/name.go index 3ff81f12..4bbc43c1 100644 --- a/name.go +++ b/name.go @@ -27,12 +27,12 @@ func (src *Name) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } -func (dst *Name) DecodeText(r io.Reader) error { - return (*Text)(dst).DecodeText(r) +func (dst *Name) DecodeText(src []byte) error { + return (*Text)(dst).DecodeText(src) } -func (dst *Name) DecodeBinary(r io.Reader) error { - return (*Text)(dst).DecodeBinary(r) +func (dst *Name) DecodeBinary(src []byte) error { + return (*Text)(dst).DecodeBinary(src) } func (src Name) EncodeText(w io.Writer) error { diff --git a/oid.go b/oid.go index d137f352..2ea9c2d1 100644 --- a/oid.go +++ b/oid.go @@ -24,12 +24,12 @@ func (src *OID) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *OID) DecodeText(r io.Reader) error { - return (*pguint32)(dst).DecodeText(r) +func (dst *OID) DecodeText(src []byte) error { + return (*pguint32)(dst).DecodeText(src) } -func (dst *OID) DecodeBinary(r io.Reader) error { - return (*pguint32)(dst).DecodeBinary(r) +func (dst *OID) DecodeBinary(src []byte) error { + return (*pguint32)(dst).DecodeBinary(src) } func (src OID) EncodeText(w io.Writer) error { diff --git a/pgtype.go b/pgtype.go index 15c0cc76..7928e1cc 100644 --- a/pgtype.go +++ b/pgtype.go @@ -74,11 +74,11 @@ type Value interface { } type BinaryDecoder interface { - DecodeBinary(r io.Reader) error + DecodeBinary(src []byte) error } type TextDecoder interface { - DecodeText(r io.Reader) error + DecodeText(src []byte) error } type BinaryEncoder interface { diff --git a/pguint32.go b/pguint32.go index 9c1ccd6c..9bf1eef6 100644 --- a/pguint32.go +++ b/pguint32.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "strconv" @@ -51,24 +52,13 @@ func (src *pguint32) AssignTo(dst interface{}) error { return nil } -func (dst *pguint32) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *pguint32) DecodeText(src []byte) error { + if src == nil { *dst = pguint32{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseUint(string(buf), 10, 32) + n, err := strconv.ParseUint(string(src), 10, 32) if err != nil { return err } @@ -77,26 +67,17 @@ func (dst *pguint32) DecodeText(r io.Reader) error { return nil } -func (dst *pguint32) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *pguint32) DecodeBinary(src []byte) error { + if src == nil { *dst = pguint32{Status: Null} return nil } - if size != 4 { - return fmt.Errorf("invalid length: %v", size) - } - - n, err := pgio.ReadUint32(r) - if err != nil { - return err + if len(src) != 4 { + return fmt.Errorf("invalid length: %v", len(src)) } + n := binary.BigEndian.Uint32(src) *dst = pguint32{Uint: n, Status: Present} return nil } diff --git a/qchar.go b/qchar.go index 6dd14625..8abec935 100644 --- a/qchar.go +++ b/qchar.go @@ -106,27 +106,17 @@ func (src *QChar) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *QChar) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *QChar) DecodeBinary(src []byte) error { + if src == nil { *dst = QChar{Status: Null} return nil } - if size != 1 { - return fmt.Errorf(`invalid length for "char": %v`, size) + if len(src) != 1 { + return fmt.Errorf(`invalid length for "char": %v`, len(src)) } - byt, err := pgio.ReadByte(r) - if err != nil { - return err - } - - *dst = QChar{Int: int8(byt), Status: Present} + *dst = QChar{Int: int8(src[0]), Status: Present} return nil } diff --git a/text.go b/text.go index c9054468..2951b5ad 100644 --- a/text.go +++ b/text.go @@ -71,29 +71,18 @@ func (src *Text) AssignTo(dst interface{}) error { return nil } -func (dst *Text) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Text) DecodeText(src []byte) error { + if src == nil { *dst = Text{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - *dst = Text{String: string(buf), Status: Present} + *dst = Text{String: string(src), Status: Present} return nil } -func (dst *Text) DecodeBinary(r io.Reader) error { - return dst.DecodeText(r) +func (dst *Text) DecodeBinary(src []byte) error { + return dst.DecodeText(src) } func (src Text) EncodeText(w io.Writer) error { diff --git a/textarray.go b/textarray.go index c420e5c9..e7ca3578 100644 --- a/textarray.go +++ b/textarray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -73,29 +74,17 @@ func (src *TextArray) AssignTo(dst interface{}) error { return nil } -func (dst *TextArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TextArray) DecodeText(src []byte) error { + if src == nil { *dst = TextArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Text if len(uta.Elements) > 0 { @@ -103,8 +92,11 @@ func (dst *TextArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Text - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +110,14 @@ func (dst *TextArray) DecodeText(r io.Reader) error { return nil } -func (dst *TextArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TextArray) DecodeBinary(src []byte) error { + if src == nil { *dst = TextArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +135,14 @@ func (dst *TextArray) DecodeBinary(r io.Reader) error { elements := make([]Text, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -211,7 +205,12 @@ func (src *TextArray) EncodeText(w io.Writer) error { } textElementWriter.Reset() - if elem.String == "" && elem.Status == Present { + if elem.Status == Null { + _, err := io.WriteString(buf, `"NULL"`) + if err != nil { + return err + } + } else if elem.String == "" { _, err := io.WriteString(buf, `""`) if err != nil { return err diff --git a/timestamp.go b/timestamp.go index c6933988..ca5eb738 100644 --- a/timestamp.go +++ b/timestamp.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "reflect" @@ -72,24 +73,13 @@ func (src *Timestamp) AssignTo(dst interface{}) error { // DecodeText decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Timestamp) DecodeText(src []byte) error { + if src == nil { *dst = Timestamp{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - sbuf := string(buf) + sbuf := string(src) switch sbuf { case "infinity": *dst = Timestamp{Status: Present, InfinityModifier: Infinity} @@ -109,25 +99,17 @@ func (dst *Timestamp) DecodeText(r io.Reader) error { // DecodeBinary decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Timestamp) DecodeBinary(src []byte) error { + if src == nil { *dst = Timestamp{Status: Null} return nil } - if size != 8 { - return fmt.Errorf("invalid length for timestamp: %v", size) + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamp: %v", len(src)) } - microsecSinceY2K, err := pgio.ReadInt64(r) - if err != nil { - return err - } + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) switch microsecSinceY2K { case infinityMicrosecondOffset: diff --git a/timestamparray.go b/timestamparray.go index 3acbb35f..695559ac 100644 --- a/timestamparray.go +++ b/timestamparray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "time" @@ -74,29 +75,17 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestampArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TimestampArray) DecodeText(src []byte) error { + if src == nil { *dst = TimestampArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Timestamp if len(uta.Elements) > 0 { @@ -104,8 +93,11 @@ func (dst *TimestampArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Timestamp - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -119,19 +111,14 @@ func (dst *TimestampArray) DecodeText(r io.Reader) error { return nil } -func (dst *TimestampArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TimestampArray) DecodeBinary(src []byte) error { + if src == nil { *dst = TimestampArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -149,7 +136,14 @@ func (dst *TimestampArray) DecodeBinary(r io.Reader) error { elements := make([]Timestamp, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -237,6 +231,10 @@ func (src *TimestampArray) EncodeText(w io.Writer) error { } func (src *TimestampArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, TimestampOID) +} + +func (src *TimestampArray) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -257,7 +255,7 @@ func (src *TimestampArray) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = TimestampOID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/timestamptz.go b/timestamptz.go index 721c8084..7255bb06 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "reflect" @@ -71,24 +72,13 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { return nil } -func (dst *Timestamptz) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Timestamptz) DecodeText(src []byte) error { + if src == nil { *dst = Timestamptz{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - sbuf := string(buf) + sbuf := string(src) switch sbuf { case "infinity": *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} @@ -115,25 +105,17 @@ func (dst *Timestamptz) DecodeText(r io.Reader) error { return nil } -func (dst *Timestamptz) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Timestamptz) DecodeBinary(src []byte) error { + if src == nil { *dst = Timestamptz{Status: Null} return nil } - if size != 8 { - return fmt.Errorf("invalid length for timestamptz: %v", size) + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamptz: %v", len(src)) } - microsecSinceY2K, err := pgio.ReadInt64(r) - if err != nil { - return err - } + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) switch microsecSinceY2K { case infinityMicrosecondOffset: diff --git a/timestamptzarray.go b/timestamptzarray.go index 9df746e6..ca416c97 100644 --- a/timestamptzarray.go +++ b/timestamptzarray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "time" @@ -74,29 +75,17 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestamptzArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TimestamptzArray) DecodeText(src []byte) error { + if src == nil { *dst = TimestamptzArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Timestamptz if len(uta.Elements) > 0 { @@ -104,8 +93,11 @@ func (dst *TimestamptzArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Timestamptz - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -119,19 +111,14 @@ func (dst *TimestamptzArray) DecodeText(r io.Reader) error { return nil } -func (dst *TimestamptzArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TimestamptzArray) DecodeBinary(src []byte) error { + if src == nil { *dst = TimestamptzArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -149,7 +136,14 @@ func (dst *TimestamptzArray) DecodeBinary(r io.Reader) error { elements := make([]Timestamptz, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -237,6 +231,10 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) error { } func (src *TimestamptzArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, TimestamptzOID) +} + +func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -257,7 +255,7 @@ func (src *TimestamptzArray) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = TimestamptzOID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/to-consider.txt b/to-consider.txt new file mode 100644 index 00000000..ba4f3511 --- /dev/null +++ b/to-consider.txt @@ -0,0 +1,9 @@ +DecodeText and DecodeBinary take []byte instead of io.Reader +EncodeText and EncodeBinary do not write size +Add Nullable interface with IsNull() and SetNull() + +The above would keep types from needing to worry about writing their own size. Could make EncodeText and DecodeText easier to use with sql.Scanner and driver.Valuer. SetNull() could be removed as DecodeText and DecodeBinary could interpret a nil slice as null. + +EncodeText and EncodeBinary could return (null bool, err error). That would finish removing Nullable interface. + +Also, consider whether arrays and ranges could be represented as generic data types or more common code could be extracted instead of using code generation. diff --git a/typed_array.go.erb b/typed_array.go.erb index 8c18073b..316439ef 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -73,29 +73,17 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { + if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []<%= pgtype_element_type %> if len(uta.Elements) > 0 { @@ -103,8 +91,11 @@ func (dst *<%= pgtype_array_type %>) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem <%= pgtype_element_type %> - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +109,14 @@ func (dst *<%= pgtype_array_type %>) DecodeText(r io.Reader) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { + if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +134,14 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(r io.Reader) error { elements := make([]<%= pgtype_element_type %>, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp:rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -211,16 +204,9 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) error { } textElementWriter.Reset() - if elem.String == "" && elem.Status == Present { - _, err := io.WriteString(buf, `""`) - if err != nil { - return err - } - } else { - err = elem.EncodeText(textElementWriter) - if err != nil { - return err - } + err = elem.EncodeText(textElementWriter) + if err != nil { + return err } for _, dec := range dimElemCounts { diff --git a/varchararray.go b/varchararray.go index 13d94bc0..3a5d8536 100644 --- a/varchararray.go +++ b/varchararray.go @@ -14,12 +14,12 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { return (*TextArray)(src).AssignTo(dst) } -func (dst *VarcharArray) DecodeText(r io.Reader) error { - return (*TextArray)(dst).DecodeText(r) +func (dst *VarcharArray) DecodeText(src []byte) error { + return (*TextArray)(dst).DecodeText(src) } -func (dst *VarcharArray) DecodeBinary(r io.Reader) error { - return (*TextArray)(dst).DecodeBinary(r) +func (dst *VarcharArray) DecodeBinary(src []byte) error { + return (*TextArray)(dst).DecodeBinary(src) } func (src *VarcharArray) EncodeText(w io.Writer) error { diff --git a/xid.go b/xid.go index d4003b5d..389f93bc 100644 --- a/xid.go +++ b/xid.go @@ -33,12 +33,12 @@ func (src *XID) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *XID) DecodeText(r io.Reader) error { - return (*pguint32)(dst).DecodeText(r) +func (dst *XID) DecodeText(src []byte) error { + return (*pguint32)(dst).DecodeText(src) } -func (dst *XID) DecodeBinary(r io.Reader) error { - return (*pguint32)(dst).DecodeBinary(r) +func (dst *XID) DecodeBinary(src []byte) error { + return (*pguint32)(dst).DecodeBinary(src) } func (src XID) EncodeText(w io.Writer) error { From e654d1f0fc4ad76d2c2d59e1cbf7cab7cbec4d67 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 12:32:33 -0600 Subject: [PATCH 0026/1158] pgtype.Encode(Binary|Text) do not write length To aid in composability, these methods no longer write their own length. This is especially useful for text formatted arrays and may be useful for future database/sql compatibility. It also makes the code a little simpler as the types no longer have to compute their own size. Along with this, these methods cannot encode NULL. They now return a boolean if they are NULL. This also benefits text array encoding as numeric arrays require NULL to be exactly `NULL` while string arrays require NULL to be `"NULL"`. --- bool.go | 38 +++++------- boolarray.go | 148 +++++++++++++++++++++++--------------------- bytea.go | 44 ++++++------- cid.go | 4 +- cidrarray.go | 4 +- date.go | 36 +++++------ datearray.go | 148 +++++++++++++++++++++++--------------------- float4.go | 36 +++++------ float4array.go | 148 +++++++++++++++++++++++--------------------- float8.go | 36 +++++------ float8array.go | 148 +++++++++++++++++++++++--------------------- inet.go | 46 ++++++-------- inetarray.go | 148 +++++++++++++++++++++++--------------------- int2.go | 36 +++++------ int2array.go | 148 +++++++++++++++++++++++--------------------- int4.go | 36 +++++------ int4array.go | 148 +++++++++++++++++++++++--------------------- int8.go | 36 +++++------ int8array.go | 148 +++++++++++++++++++++++--------------------- name.go | 4 +- oid.go | 4 +- pgtype.go | 29 +++++---- pgtype_test.go | 4 +- pguint32.go | 36 +++++------ qchar.go | 16 +++-- text.go | 22 +++---- text_element.go | 112 --------------------------------- textarray.go | 148 +++++++++++++++++++++----------------------- timestamp.go | 40 ++++++------ timestamparray.go | 148 +++++++++++++++++++++++--------------------- timestamptz.go | 36 +++++------ timestamptzarray.go | 148 +++++++++++++++++++++++--------------------- to-consider.txt | 9 --- typed_array.go.erb | 148 +++++++++++++++++++++++--------------------- typed_array_gen.sh | 22 +++---- varchararray.go | 4 +- xid.go | 4 +- 37 files changed, 1185 insertions(+), 1285 deletions(-) delete mode 100644 text_element.go delete mode 100644 to-consider.txt diff --git a/bool.go b/bool.go index b7bc14d0..9764fafe 100644 --- a/bool.go +++ b/bool.go @@ -5,8 +5,6 @@ import ( "io" "reflect" "strconv" - - "github.com/jackc/pgx/pgio" ) type Bool struct { @@ -100,14 +98,12 @@ func (dst *Bool) DecodeBinary(src []byte) error { return nil } -func (src Bool) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err - } - - _, err := pgio.WriteInt32(w, 1) - if err != nil { - return nil +func (src Bool) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } var buf []byte @@ -117,18 +113,16 @@ func (src Bool) EncodeText(w io.Writer) error { buf = []byte{'f'} } - _, err = w.Write(buf) - return err + _, err := w.Write(buf) + return false, err } -func (src Bool) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err - } - - _, err := pgio.WriteInt32(w, 1) - if err != nil { - return nil +func (src Bool) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } var buf []byte @@ -138,6 +132,6 @@ func (src Bool) EncodeBinary(w io.Writer) error { buf = []byte{0} } - _, err = w.Write(buf) - return err + _, err := w.Write(buf) + return false, err } diff --git a/boolarray.go b/boolarray.go index a9b8bf50..f7323281 100644 --- a/boolarray.go +++ b/boolarray.go @@ -152,26 +152,22 @@ func (dst *BoolArray) DecodeBinary(src []byte) error { return nil } -func (src *BoolArray) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -185,100 +181,112 @@ func (src *BoolArray) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *BoolArray) EncodeBinary(w io.Writer) error { +func (src *BoolArray) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, BoolOID) } -func (src *BoolArray) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *BoolArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/bytea.go b/bytea.go index db20482f..709499d2 100644 --- a/bytea.go +++ b/bytea.go @@ -5,8 +5,6 @@ import ( "fmt" "io" "reflect" - - "github.com/jackc/pgx/pgio" ) type Bytea struct { @@ -101,37 +99,31 @@ func (dst *Bytea) DecodeBinary(src []byte) error { return nil } -func (src Bytea) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Bytea) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - str := hex.EncodeToString(src.Bytes) - - _, err := pgio.WriteInt32(w, int32(len(str)+2)) + _, err := io.WriteString(w, `\x`) if err != nil { - return nil + return false, err } - _, err = io.WriteString(w, `\x`) - if err != nil { - return nil - } - - _, err = io.WriteString(w, str) - return err + _, err = io.WriteString(w, hex.EncodeToString(src.Bytes)) + return false, err } -func (src Bytea) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Bytea) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, int32(len(src.Bytes))) - if err != nil { - return nil - } - - _, err = w.Write(src.Bytes) - return err + _, err := w.Write(src.Bytes) + return false, err } diff --git a/cid.go b/cid.go index f8d706d0..41b817bb 100644 --- a/cid.go +++ b/cid.go @@ -38,10 +38,10 @@ func (dst *CID) DecodeBinary(src []byte) error { return (*pguint32)(dst).DecodeBinary(src) } -func (src CID) EncodeText(w io.Writer) error { +func (src CID) EncodeText(w io.Writer) (bool, error) { return (pguint32)(src).EncodeText(w) } -func (src CID) EncodeBinary(w io.Writer) error { +func (src CID) EncodeBinary(w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(w) } diff --git a/cidrarray.go b/cidrarray.go index d95eef4a..cb81d2b9 100644 --- a/cidrarray.go +++ b/cidrarray.go @@ -22,10 +22,10 @@ func (dst *CidrArray) DecodeBinary(src []byte) error { return (*InetArray)(dst).DecodeBinary(src) } -func (src *CidrArray) EncodeText(w io.Writer) error { +func (src *CidrArray) EncodeText(w io.Writer) (bool, error) { return (*InetArray)(src).EncodeText(w) } -func (src *CidrArray) EncodeBinary(w io.Writer) error { +func (src *CidrArray) EncodeBinary(w io.Writer) (bool, error) { return (*InetArray)(src).encodeBinary(w, CidrOID) } diff --git a/date.go b/date.go index 1bb81d35..b0d16e64 100644 --- a/date.go +++ b/date.go @@ -116,9 +116,12 @@ func (dst *Date) DecodeBinary(src []byte) error { return nil } -func (src Date) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Date) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } var s string @@ -132,23 +135,16 @@ func (src Date) EncodeText(w io.Writer) error { s = "-infinity" } - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, s) + return false, err } -func (src Date) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err - } - - _, err := pgio.WriteInt32(w, 4) - if err != nil { - return err +func (src Date) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } var daysSinceDateEpoch int32 @@ -165,6 +161,6 @@ func (src Date) EncodeBinary(w io.Writer) error { daysSinceDateEpoch = negativeInfinityDayOffset } - _, err = pgio.WriteInt32(w, daysSinceDateEpoch) - return err + _, err := pgio.WriteInt32(w, daysSinceDateEpoch) + return false, err } diff --git a/datearray.go b/datearray.go index e9ad1f62..9552739b 100644 --- a/datearray.go +++ b/datearray.go @@ -153,26 +153,22 @@ func (dst *DateArray) DecodeBinary(src []byte) error { return nil } -func (src *DateArray) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *DateArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -186,100 +182,112 @@ func (src *DateArray) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *DateArray) EncodeBinary(w io.Writer) error { +func (src *DateArray) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, DateOID) } -func (src *DateArray) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *DateArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/float4.go b/float4.go index fb0415e5..26609ab2 100644 --- a/float4.go +++ b/float4.go @@ -124,30 +124,26 @@ func (dst *Float4) DecodeBinary(src []byte) error { return nil } -func (src Float4) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Float4) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := strconv.FormatFloat(float64(src.Float), 'f', -1, 32) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, strconv.FormatFloat(float64(src.Float), 'f', -1, 32)) + return false, err } -func (src Float4) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Float4) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 4) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(math.Float32bits(src.Float))) - return err + _, err := pgio.WriteInt32(w, int32(math.Float32bits(src.Float))) + return false, err } diff --git a/float4array.go b/float4array.go index a4a72146..9ab08dcc 100644 --- a/float4array.go +++ b/float4array.go @@ -152,26 +152,22 @@ func (dst *Float4Array) DecodeBinary(src []byte) error { return nil } -func (src *Float4Array) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -185,100 +181,112 @@ func (src *Float4Array) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *Float4Array) EncodeBinary(w io.Writer) error { +func (src *Float4Array) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, Float4OID) } -func (src *Float4Array) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Float4Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/float8.go b/float8.go index a53de5e3..9ec9a665 100644 --- a/float8.go +++ b/float8.go @@ -114,30 +114,26 @@ func (dst *Float8) DecodeBinary(src []byte) error { return nil } -func (src Float8) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Float8) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := strconv.FormatFloat(float64(src.Float), 'f', -1, 64) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, strconv.FormatFloat(float64(src.Float), 'f', -1, 64)) + return false, err } -func (src Float8) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Float8) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 8) - if err != nil { - return err - } - - _, err = pgio.WriteInt64(w, int64(math.Float64bits(src.Float))) - return err + _, err := pgio.WriteInt64(w, int64(math.Float64bits(src.Float))) + return false, err } diff --git a/float8array.go b/float8array.go index 082e817d..ce7e3b90 100644 --- a/float8array.go +++ b/float8array.go @@ -152,26 +152,22 @@ func (dst *Float8Array) DecodeBinary(src []byte) error { return nil } -func (src *Float8Array) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -185,100 +181,112 @@ func (src *Float8Array) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *Float8Array) EncodeBinary(w io.Writer) error { +func (src *Float8Array) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, Float8OID) } -func (src *Float8Array) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Float8Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/inet.go b/inet.go index 132a876a..f94622f4 100644 --- a/inet.go +++ b/inet.go @@ -144,61 +144,55 @@ func (dst *Inet) DecodeBinary(src []byte) error { return nil } -func (src Inet) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Inet) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := src.IPNet.String() - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, src.IPNet.String()) + return false, err } // EncodeBinary encodes src into w. -func (src Inet) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Inet) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var size int32 var family byte switch len(src.IPNet.IP) { case net.IPv4len: - size = 8 family = defaultAFInet case net.IPv6len: - size = 20 family = defaultAFInet6 default: - return fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) - } - - if _, err := pgio.WriteInt32(w, size); err != nil { - return err + return false, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) } if err := pgio.WriteByte(w, family); err != nil { - return err + return false, err } ones, _ := src.IPNet.Mask.Size() if err := pgio.WriteByte(w, byte(ones)); err != nil { - return err + return false, err } // is_cidr is ignored on server if err := pgio.WriteByte(w, 0); err != nil { - return err + return false, err } if err := pgio.WriteByte(w, byte(len(src.IPNet.IP))); err != nil { - return err + return false, err } _, err := w.Write(src.IPNet.IP) - return err + return false, err } diff --git a/inetarray.go b/inetarray.go index 28de736f..32cde554 100644 --- a/inetarray.go +++ b/inetarray.go @@ -184,26 +184,22 @@ func (dst *InetArray) DecodeBinary(src []byte) error { return nil } -func (src *InetArray) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *InetArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -217,100 +213,112 @@ func (src *InetArray) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *InetArray) EncodeBinary(w io.Writer) error { +func (src *InetArray) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, InetOID) } -func (src *InetArray) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *InetArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/int2.go b/int2.go index 51346a43..7bdbacfe 100644 --- a/int2.go +++ b/int2.go @@ -119,30 +119,26 @@ func (dst *Int2) DecodeBinary(src []byte) error { return nil } -func (src Int2) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Int2) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := strconv.FormatInt(int64(src.Int), 10) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, strconv.FormatInt(int64(src.Int), 10)) + return false, err } -func (src Int2) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Int2) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = pgio.WriteInt16(w, src.Int) - return err + _, err := pgio.WriteInt16(w, src.Int) + return false, err } diff --git a/int2array.go b/int2array.go index 71760e1e..f7cc2492 100644 --- a/int2array.go +++ b/int2array.go @@ -183,26 +183,22 @@ func (dst *Int2Array) DecodeBinary(src []byte) error { return nil } -func (src *Int2Array) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -216,100 +212,112 @@ func (src *Int2Array) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *Int2Array) EncodeBinary(w io.Writer) error { +func (src *Int2Array) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, Int2OID) } -func (src *Int2Array) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Int2Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/int4.go b/int4.go index 8a53d454..2d96ea48 100644 --- a/int4.go +++ b/int4.go @@ -110,30 +110,26 @@ func (dst *Int4) DecodeBinary(src []byte) error { return nil } -func (src Int4) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Int4) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := strconv.FormatInt(int64(src.Int), 10) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, strconv.FormatInt(int64(src.Int), 10)) + return false, err } -func (src Int4) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Int4) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 4) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, src.Int) - return err + _, err := pgio.WriteInt32(w, src.Int) + return false, err } diff --git a/int4array.go b/int4array.go index 6a202b08..fa710af7 100644 --- a/int4array.go +++ b/int4array.go @@ -183,26 +183,22 @@ func (dst *Int4Array) DecodeBinary(src []byte) error { return nil } -func (src *Int4Array) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -216,100 +212,112 @@ func (src *Int4Array) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *Int4Array) EncodeBinary(w io.Writer) error { +func (src *Int4Array) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, Int4OID) } -func (src *Int4Array) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Int4Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/int8.go b/int8.go index c6bedaa6..91f5b877 100644 --- a/int8.go +++ b/int8.go @@ -102,30 +102,26 @@ func (dst *Int8) DecodeBinary(src []byte) error { return nil } -func (src Int8) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Int8) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := strconv.FormatInt(src.Int, 10) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, strconv.FormatInt(src.Int, 10)) + return false, err } -func (src Int8) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Int8) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 8) - if err != nil { - return err - } - - _, err = pgio.WriteInt64(w, src.Int) - return err + _, err := pgio.WriteInt64(w, src.Int) + return false, err } diff --git a/int8array.go b/int8array.go index f621618e..65f42477 100644 --- a/int8array.go +++ b/int8array.go @@ -183,26 +183,22 @@ func (dst *Int8Array) DecodeBinary(src []byte) error { return nil } -func (src *Int8Array) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -216,100 +212,112 @@ func (src *Int8Array) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *Int8Array) EncodeBinary(w io.Writer) error { +func (src *Int8Array) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, Int8OID) } -func (src *Int8Array) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Int8Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/name.go b/name.go index 4bbc43c1..513abfc7 100644 --- a/name.go +++ b/name.go @@ -35,10 +35,10 @@ func (dst *Name) DecodeBinary(src []byte) error { return (*Text)(dst).DecodeBinary(src) } -func (src Name) EncodeText(w io.Writer) error { +func (src Name) EncodeText(w io.Writer) (bool, error) { return (Text)(src).EncodeText(w) } -func (src Name) EncodeBinary(w io.Writer) error { +func (src Name) EncodeBinary(w io.Writer) (bool, error) { return (Text)(src).EncodeBinary(w) } diff --git a/oid.go b/oid.go index 2ea9c2d1..e1bee4cf 100644 --- a/oid.go +++ b/oid.go @@ -32,10 +32,10 @@ func (dst *OID) DecodeBinary(src []byte) error { return (*pguint32)(dst).DecodeBinary(src) } -func (src OID) EncodeText(w io.Writer) error { +func (src OID) EncodeText(w io.Writer) (bool, error) { return (pguint32)(src).EncodeText(w) } -func (src OID) EncodeBinary(w io.Writer) error { +func (src OID) EncodeBinary(w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(w) } diff --git a/pgtype.go b/pgtype.go index 7928e1cc..d6cd53c1 100644 --- a/pgtype.go +++ b/pgtype.go @@ -3,8 +3,6 @@ package pgtype import ( "errors" "io" - - "github.com/jackc/pgx/pgio" ) // PostgreSQL oids for common types @@ -81,23 +79,24 @@ type TextDecoder interface { DecodeText(src []byte) error } +// BinaryEncoder is implemented by types that can encode themselves into the +// PostgreSQL binary wire format. type BinaryEncoder interface { - EncodeBinary(w io.Writer) error + // EncodeBinary should encode the binary format of self to w. If self is the + // SQL value NULL then write nothing and return (true, nil). The caller of + // EncodeBinary is responsible for writing the correct NULL value or the + // length of the data written. + EncodeBinary(w io.Writer) (null bool, err error) } +// TextEncoder is implemented by types that can encode themselves into the +// PostgreSQL text wire format. type TextEncoder interface { - EncodeText(w io.Writer) error + // EncodeText should encode the text format of self to w. If self is the SQL + // value NULL then write nothing and return (true, nil). The caller of + // EncodeText is responsible for writing the correct NULL value or the length + // of the data written. + EncodeText(w io.Writer) (null bool, err error) } var errUndefined = errors.New("cannot encode status undefined") - -func encodeNotPresent(w io.Writer, status Status) (done bool, err error) { - switch status { - case Undefined: - return true, errUndefined - case Null: - _, err = pgio.WriteInt32(w, -1) - return true, err - } - return false, nil -} diff --git a/pgtype_test.go b/pgtype_test.go index 6e173cbe..07a40160 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -60,7 +60,7 @@ type forceTextEncoder struct { e pgtype.TextEncoder } -func (f forceTextEncoder) EncodeText(w io.Writer) error { +func (f forceTextEncoder) EncodeText(w io.Writer) (bool, error) { return f.e.EncodeText(w) } @@ -68,7 +68,7 @@ type forceBinaryEncoder struct { e pgtype.BinaryEncoder } -func (f forceBinaryEncoder) EncodeBinary(w io.Writer) error { +func (f forceBinaryEncoder) EncodeBinary(w io.Writer) (bool, error) { return f.e.EncodeBinary(w) } diff --git a/pguint32.go b/pguint32.go index 9bf1eef6..df9e0d36 100644 --- a/pguint32.go +++ b/pguint32.go @@ -82,30 +82,26 @@ func (dst *pguint32) DecodeBinary(src []byte) error { return nil } -func (src pguint32) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src pguint32) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := strconv.FormatUint(uint64(src.Uint), 10) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, strconv.FormatUint(uint64(src.Uint), 10)) + return false, err } -func (src pguint32) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src pguint32) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 4) - if err != nil { - return err - } - - _, err = pgio.WriteUint32(w, src.Uint) - return err + _, err := pgio.WriteUint32(w, src.Uint) + return false, err } diff --git a/qchar.go b/qchar.go index 8abec935..0da1e88b 100644 --- a/qchar.go +++ b/qchar.go @@ -120,15 +120,13 @@ func (dst *QChar) DecodeBinary(src []byte) error { return nil } -func (src QChar) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src QChar) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 1) - if err != nil { - return nil - } - - return pgio.WriteByte(w, byte(src.Int)) + return false, pgio.WriteByte(w, byte(src.Int)) } diff --git a/text.go b/text.go index 2951b5ad..baf62d1e 100644 --- a/text.go +++ b/text.go @@ -4,8 +4,6 @@ import ( "fmt" "io" "reflect" - - "github.com/jackc/pgx/pgio" ) type Text struct { @@ -85,20 +83,18 @@ func (dst *Text) DecodeBinary(src []byte) error { return dst.DecodeText(src) } -func (src Text) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Text) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, int32(len(src.String))) - if err != nil { - return nil - } - - _, err = io.WriteString(w, src.String) - return err + _, err := io.WriteString(w, src.String) + return false, err } -func (src Text) EncodeBinary(w io.Writer) error { +func (src Text) EncodeBinary(w io.Writer) (bool, error) { return src.EncodeText(w) } diff --git a/text_element.go b/text_element.go deleted file mode 100644 index 1a585d08..00000000 --- a/text_element.go +++ /dev/null @@ -1,112 +0,0 @@ -package pgtype - -import ( - "bytes" - "errors" - "io" - - "github.com/jackc/pgx/pgio" -) - -// TextElementWriter is a wrapper that makes TextEncoders composable into other -// TextEncoders. TextEncoder first writes the length of the subsequent value. -// This is not necessary when the value is part of another value such as an -// array. TextElementWriter requires one int32 to be written first which it -// ignores. No other integer writes are valid. -type TextElementWriter struct { - w io.Writer - lengthHeaderIgnored bool -} - -func NewTextElementWriter(w io.Writer) *TextElementWriter { - return &TextElementWriter{w: w} -} - -func (w *TextElementWriter) WriteUint16(n uint16) (int, error) { - return 0, errors.New("WriteUint16 should never be called on TextElementWriter") -} - -func (w *TextElementWriter) WriteUint32(n uint32) (int, error) { - if !w.lengthHeaderIgnored { - w.lengthHeaderIgnored = true - - if int32(n) == -1 { - return io.WriteString(w.w, "NULL") - } - - return 4, nil - } - - return 0, errors.New("WriteUint32 should only be called once on TextElementWriter") -} - -func (w *TextElementWriter) WriteUint64(n uint64) (int, error) { - if w.lengthHeaderIgnored { - return pgio.WriteUint64(w.w, n) - } - - return 0, errors.New("WriteUint64 should never be called on TextElementWriter") -} - -func (w *TextElementWriter) Write(buf []byte) (int, error) { - if w.lengthHeaderIgnored { - return w.w.Write(buf) - } - - return 0, errors.New("int32 must be written first") -} - -func (w *TextElementWriter) Reset() { - w.lengthHeaderIgnored = false -} - -// TextElementReader is a wrapper that makes TextDecoders composable into other -// TextDecoders. TextEncoders first read the length of the subsequent value. -// This length value is not present when the value is part of another value such -// as an array. TextElementReader provides a substitute length value from the -// length of the string. No other integer reads are valid. Each time DecodeText -// is called with a TextElementReader as the source the TextElementReader must -// first have Reset called with the new element string data. -type TextElementReader struct { - buf *bytes.Buffer - lengthHeaderIgnored bool -} - -func NewTextElementReader(r io.Reader) *TextElementReader { - return &TextElementReader{buf: &bytes.Buffer{}} -} - -func (r *TextElementReader) ReadUint16() (uint16, error) { - return 0, errors.New("ReadUint16 should never be called on TextElementReader") -} - -func (r *TextElementReader) ReadUint32() (uint32, error) { - if !r.lengthHeaderIgnored { - r.lengthHeaderIgnored = true - if r.buf.String() == "NULL" { - n32 := int32(-1) - return uint32(n32), nil - } - return uint32(r.buf.Len()), nil - } - - return 0, errors.New("ReadUint32 should only be called once on TextElementReader") -} - -func (r *TextElementReader) WriteUint64(n uint64) (int, error) { - return 0, errors.New("ReadUint64 should never be called on TextElementReader") -} - -func (r *TextElementReader) Read(buf []byte) (int, error) { - if r.lengthHeaderIgnored { - return r.buf.Read(buf) - } - - return 0, errors.New("int32 must be read first") -} - -func (r *TextElementReader) Reset(s string) { - r.lengthHeaderIgnored = false - r.buf.Reset() - r.buf.WriteString(s) -} diff --git a/textarray.go b/textarray.go index e7ca3578..c3e595e0 100644 --- a/textarray.go +++ b/textarray.go @@ -152,26 +152,22 @@ func (dst *TextArray) DecodeBinary(src []byte) error { return nil } -func (src *TextArray) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *TextArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -185,112 +181,112 @@ func (src *TextArray) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - if elem.Status == Null { - _, err := io.WriteString(buf, `"NULL"`) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `"NULL"`) if err != nil { - return err + return false, err } - } else if elem.String == "" { - _, err := io.WriteString(buf, `""`) + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) if err != nil { - return err + return false, err } } else { - err = elem.EncodeText(textElementWriter) + _, err = elemBuf.WriteTo(w) if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *TextArray) EncodeBinary(w io.Writer) error { +func (src *TextArray) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, TextOID) } -func (src *TextArray) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *TextArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/timestamp.go b/timestamp.go index ca5eb738..a8b628e9 100644 --- a/timestamp.go +++ b/timestamp.go @@ -127,12 +127,15 @@ func (dst *Timestamp) DecodeBinary(src []byte) error { // EncodeText writes the text encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Timestamp) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if src.Time.Location() != time.UTC { - return fmt.Errorf("cannot encode non-UTC time into timestamp") + return false, fmt.Errorf("cannot encode non-UTC time into timestamp") } var s string @@ -146,28 +149,21 @@ func (src Timestamp) EncodeText(w io.Writer) error { s = "-infinity" } - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, s) + return false, err } // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Timestamp) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if src.Time.Location() != time.UTC { - return fmt.Errorf("cannot encode non-UTC time into timestamp") - } - - _, err := pgio.WriteInt32(w, 8) - if err != nil { - return err + return false, fmt.Errorf("cannot encode non-UTC time into timestamp") } var microsecSinceY2K int64 @@ -181,6 +177,6 @@ func (src Timestamp) EncodeBinary(w io.Writer) error { microsecSinceY2K = negativeInfinityMicrosecondOffset } - _, err = pgio.WriteInt64(w, microsecSinceY2K) - return err + _, err := pgio.WriteInt64(w, microsecSinceY2K) + return false, err } diff --git a/timestamparray.go b/timestamparray.go index 695559ac..21e4de98 100644 --- a/timestamparray.go +++ b/timestamparray.go @@ -153,26 +153,22 @@ func (dst *TimestampArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestampArray) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -186,100 +182,112 @@ func (src *TimestampArray) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *TimestampArray) EncodeBinary(w io.Writer) error { +func (src *TimestampArray) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, TimestampOID) } -func (src *TimestampArray) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *TimestampArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/timestamptz.go b/timestamptz.go index 7255bb06..f4c67b0b 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -131,9 +131,12 @@ func (dst *Timestamptz) DecodeBinary(src []byte) error { return nil } -func (src Timestamptz) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Timestamptz) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } var s string @@ -147,23 +150,16 @@ func (src Timestamptz) EncodeText(w io.Writer) error { s = "-infinity" } - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, s) + return false, err } -func (src Timestamptz) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err - } - - _, err := pgio.WriteInt32(w, 8) - if err != nil { - return err +func (src Timestamptz) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } var microsecSinceY2K int64 @@ -177,6 +173,6 @@ func (src Timestamptz) EncodeBinary(w io.Writer) error { microsecSinceY2K = negativeInfinityMicrosecondOffset } - _, err = pgio.WriteInt64(w, microsecSinceY2K) - return err + _, err := pgio.WriteInt64(w, microsecSinceY2K) + return false, err } diff --git a/timestamptzarray.go b/timestamptzarray.go index ca416c97..597b1842 100644 --- a/timestamptzarray.go +++ b/timestamptzarray.go @@ -153,26 +153,22 @@ func (dst *TimestamptzArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestamptzArray) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -186,100 +182,112 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *TimestamptzArray) EncodeBinary(w io.Writer) error { +func (src *TimestamptzArray) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, TimestamptzOID) } -func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/to-consider.txt b/to-consider.txt deleted file mode 100644 index ba4f3511..00000000 --- a/to-consider.txt +++ /dev/null @@ -1,9 +0,0 @@ -DecodeText and DecodeBinary take []byte instead of io.Reader -EncodeText and EncodeBinary do not write size -Add Nullable interface with IsNull() and SetNull() - -The above would keep types from needing to worry about writing their own size. Could make EncodeText and DecodeText easier to use with sql.Scanner and driver.Valuer. SetNull() could be removed as DecodeText and DecodeBinary could interpret a nil slice as null. - -EncodeText and EncodeBinary could return (null bool, err error). That would finish removing Nullable interface. - -Also, consider whether arrays and ranges could be represented as generic data types or more common code could be extracted instead of using code generation. diff --git a/typed_array.go.erb b/typed_array.go.erb index 316439ef..2e9b77ea 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -151,26 +151,22 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { return nil } -func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -184,100 +180,112 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `<%= text_null %>`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) error { +func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, <%= element_oid %>) } -func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 1e2dce64..43109700 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,11 +1,11 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID typed_array.go.erb > int2array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID typed_array.go.erb > int4array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID typed_array.go.erb > int8array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID typed_array.go.erb > boolarray.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID typed_array.go.erb > datearray.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID typed_array.go.erb > timestamptzarray.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID typed_array.go.erb > timestamparray.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID typed_array.go.erb > float4array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID typed_array.go.erb > float8array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID typed_array.go.erb > inetarray.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID typed_array.go.erb > textarray.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID text_null=NULL typed_array.go.erb > int2array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID text_null=NULL typed_array.go.erb > int4array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID text_null=NULL typed_array.go.erb > int8array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID text_null=NULL typed_array.go.erb > boolarray.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID text_null=NULL typed_array.go.erb > datearray.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID text_null=NULL typed_array.go.erb > timestamptzarray.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID text_null=NULL typed_array.go.erb > timestamparray.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID text_null=NULL typed_array.go.erb > float4array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID text_null=NULL typed_array.go.erb > float8array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inetarray.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > textarray.go diff --git a/varchararray.go b/varchararray.go index 3a5d8536..9c8829d0 100644 --- a/varchararray.go +++ b/varchararray.go @@ -22,10 +22,10 @@ func (dst *VarcharArray) DecodeBinary(src []byte) error { return (*TextArray)(dst).DecodeBinary(src) } -func (src *VarcharArray) EncodeText(w io.Writer) error { +func (src *VarcharArray) EncodeText(w io.Writer) (bool, error) { return (*TextArray)(src).EncodeText(w) } -func (src *VarcharArray) EncodeBinary(w io.Writer) error { +func (src *VarcharArray) EncodeBinary(w io.Writer) (bool, error) { return (*TextArray)(src).encodeBinary(w, VarcharOID) } diff --git a/xid.go b/xid.go index 389f93bc..6635b21e 100644 --- a/xid.go +++ b/xid.go @@ -41,10 +41,10 @@ func (dst *XID) DecodeBinary(src []byte) error { return (*pguint32)(dst).DecodeBinary(src) } -func (src XID) EncodeText(w io.Writer) error { +func (src XID) EncodeText(w io.Writer) (bool, error) { return (pguint32)(src).EncodeText(w) } -func (src XID) EncodeBinary(w io.Writer) error { +func (src XID) EncodeBinary(w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(w) } From 86620c5e91fc6cc990e72214c91fa54cf88014ec Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 13:32:32 -0600 Subject: [PATCH 0027/1158] Add pgtype.ByteaArray Also fix up quoting array elements for text arrays. --- array.go | 14 +++ boolarray.go | 7 +- byteaarray.go | 287 ++++++++++++++++++++++++++++++++++++++++++++ byteaarray_test.go | 119 ++++++++++++++++++ datearray.go | 7 +- float4array.go | 7 +- float8array.go | 7 +- inetarray.go | 7 +- int2array.go | 7 +- int4array.go | 7 +- int8array.go | 7 +- textarray.go | 7 +- textarray_test.go | 8 +- timestamparray.go | 7 +- timestamptzarray.go | 7 +- typed_array.go.erb | 7 +- typed_array_gen.sh | 1 + 17 files changed, 437 insertions(+), 76 deletions(-) create mode 100644 byteaarray.go create mode 100644 byteaarray_test.go diff --git a/array.go b/array.go index 6b705103..90092c8d 100644 --- a/array.go +++ b/array.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "strconv" + "strings" "unicode" "github.com/jackc/pgx/pgio" @@ -371,3 +372,16 @@ func EncodeTextArrayDimensions(w io.Writer, dimensions []ArrayDimension) error { return pgio.WriteByte(w, '=') } + +var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteArrayElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func QuoteArrayElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `{},"\`) { + return quoteArrayElement(src) + } + return src +} diff --git a/boolarray.go b/boolarray.go index f7323281..65a6bc9c 100644 --- a/boolarray.go +++ b/boolarray.go @@ -208,13 +208,8 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/byteaarray.go b/byteaarray.go new file mode 100644 index 00000000..7a4f1601 --- /dev/null +++ b/byteaarray.go @@ -0,0 +1,287 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type ByteaArray struct { + Elements []Bytea + Dimensions []ArrayDimension + Status Status +} + +func (dst *ByteaArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case ByteaArray: + *dst = value + + case [][]byte: + if value == nil { + *dst = ByteaArray{Status: Null} + } else if len(value) == 0 { + *dst = ByteaArray{Status: Present} + } else { + elements := make([]Bytea, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = ByteaArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bytea", value) + } + + return nil +} + +func (src *ByteaArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[][]byte: + if src.Status == Present { + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *ByteaArray) DecodeText(src []byte) error { + if src == nil { + *dst = ByteaArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Bytea + + if len(uta.Elements) > 0 { + elements = make([]Bytea, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Bytea + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = ByteaArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *ByteaArray) DecodeBinary(src []byte) error { + if src == nil { + *dst = ByteaArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = ByteaArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Bytea, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) + if err != nil { + return err + } + } + + *dst = ByteaArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil +} + +func (src *ByteaArray) EncodeBinary(w io.Writer) (bool, error) { + return src.encodeBinary(w, ByteaOID) +} + +func (src *ByteaArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err +} diff --git a/byteaarray_test.go b/byteaarray_test.go new file mode 100644 index 00000000..b39776d9 --- /dev/null +++ b/byteaarray_test.go @@ -0,0 +1,119 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestByteaArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "bytea[]", []interface{}{ + &pgtype.ByteaArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{Status: pgtype.Null}, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Status: pgtype.Null}, + pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestByteaArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ByteaArray + }{ + { + source: [][]byte{{1, 2, 3}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([][]byte)(nil)), + result: pgtype.ByteaArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.ByteaArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestByteaArrayAssignTo(t *testing.T) { + var byteByteSlice [][]byte + + simpleTests := []struct { + src pgtype.ByteaArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &byteByteSlice, + expected: [][]byte{{1, 2, 3}}, + }, + { + src: pgtype.ByteaArray{Status: pgtype.Null}, + dst: &byteByteSlice, + expected: (([][]byte)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/datearray.go b/datearray.go index 9552739b..623ff9b3 100644 --- a/datearray.go +++ b/datearray.go @@ -209,13 +209,8 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/float4array.go b/float4array.go index 9ab08dcc..c55f76d0 100644 --- a/float4array.go +++ b/float4array.go @@ -208,13 +208,8 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/float8array.go b/float8array.go index ce7e3b90..d08a5351 100644 --- a/float8array.go +++ b/float8array.go @@ -208,13 +208,8 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/inetarray.go b/inetarray.go index 32cde554..12d9493b 100644 --- a/inetarray.go +++ b/inetarray.go @@ -240,13 +240,8 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/int2array.go b/int2array.go index f7cc2492..37ee9926 100644 --- a/int2array.go +++ b/int2array.go @@ -239,13 +239,8 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/int4array.go b/int4array.go index fa710af7..f6f62e4b 100644 --- a/int4array.go +++ b/int4array.go @@ -239,13 +239,8 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/int8array.go b/int8array.go index 65f42477..92d8ec46 100644 --- a/int8array.go +++ b/int8array.go @@ -239,13 +239,8 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/textarray.go b/textarray.go index c3e595e0..182e76f5 100644 --- a/textarray.go +++ b/textarray.go @@ -208,13 +208,8 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/textarray_test.go b/textarray_test.go index 29e3a6c7..a22e003d 100644 --- a/textarray_test.go +++ b/textarray_test.go @@ -25,12 +25,12 @@ func TestTextArrayTranscode(t *testing.T) { &pgtype.TextArray{Status: pgtype.Null}, &pgtype.TextArray{ Elements: []pgtype.Text{ - pgtype.Text{String: "bar", Status: pgtype.Present}, - pgtype.Text{String: "baz", Status: pgtype.Present}, - pgtype.Text{String: "quz", Status: pgtype.Present}, + pgtype.Text{String: "bar ", Status: pgtype.Present}, + pgtype.Text{String: "NuLL", Status: pgtype.Present}, + pgtype.Text{String: `wow"quz\`, Status: pgtype.Present}, pgtype.Text{String: "", Status: pgtype.Present}, pgtype.Text{Status: pgtype.Null}, - pgtype.Text{String: "foo", Status: pgtype.Present}, + pgtype.Text{String: "null", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, diff --git a/timestamparray.go b/timestamparray.go index 21e4de98..b0fb25fa 100644 --- a/timestamparray.go +++ b/timestamparray.go @@ -209,13 +209,8 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/timestamptzarray.go b/timestamptzarray.go index 597b1842..25374717 100644 --- a/timestamptzarray.go +++ b/timestamptzarray.go @@ -209,13 +209,8 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/typed_array.go.erb b/typed_array.go.erb index 2e9b77ea..f9dba308 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -207,13 +207,8 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 43109700..c63414c8 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -9,3 +9,4 @@ erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]fl erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID text_null=NULL typed_array.go.erb > float8array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inetarray.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > textarray.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOID text_null=NULL typed_array.go.erb > byteaarray.go From 2f63514c47c9afd18866b8b3a30c18550a0cba69 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 16:13:05 -0600 Subject: [PATCH 0028/1158] Move ACLItem to pgtype --- aclitem.go | 104 ++++++++++++++++++++++++ aclitem_test.go | 97 ++++++++++++++++++++++ aclitemarray.go | 186 +++++++++++++++++++++++++++++++++++++++++++ aclitemarray_test.go | 151 +++++++++++++++++++++++++++++++++++ pgtype.go | 4 +- typed_array_gen.sh | 1 + 6 files changed, 541 insertions(+), 2 deletions(-) create mode 100644 aclitem.go create mode 100644 aclitem_test.go create mode 100644 aclitemarray.go create mode 100644 aclitemarray_test.go diff --git a/aclitem.go b/aclitem.go new file mode 100644 index 00000000..bd7b7d45 --- /dev/null +++ b/aclitem.go @@ -0,0 +1,104 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" +) + +// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem +// might look like this: +// +// postgres=arwdDxt/postgres +// +// Note, however, that because the user/role name part of an aclitem is +// an identifier, it follows all the usual formatting rules for SQL +// identifiers: if it contains spaces and other special characters, +// it should appear in double-quotes: +// +// postgres=arwdDxt/"role with spaces" +// +type ACLItem struct { + String string + Status Status +} + +func (dst *ACLItem) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case ACLItem: + *dst = value + case string: + *dst = ACLItem{String: value, Status: Present} + case *string: + if value == nil { + *dst = ACLItem{Status: Null} + } else { + *dst = ACLItem{String: *value, Status: Present} + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to ACLItem", value) + } + + return nil +} + +func (src *ACLItem) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *string: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.String + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + case reflect.String: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + el.SetString(src.String) + return nil + } + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *ACLItem) DecodeText(src []byte) error { + if src == nil { + *dst = ACLItem{Status: Null} + return nil + } + + *dst = ACLItem{String: string(src), Status: Present} + return nil +} + +func (src ACLItem) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, src.String) + return false, err +} diff --git a/aclitem_test.go b/aclitem_test.go new file mode 100644 index 00000000..0b2b6cfa --- /dev/null +++ b/aclitem_test.go @@ -0,0 +1,97 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestACLItemTranscode(t *testing.T) { + testSuccessfulTranscode(t, "aclitem", []interface{}{ + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, + }) +} + +func TestACLItemConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ACLItem + }{ + {source: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.ACLItem + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestACLItemAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.ACLItem + dst interface{} + expected interface{} + }{ + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.ACLItem + dst interface{} + expected interface{} + }{ + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.ACLItem + dst interface{} + }{ + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/aclitemarray.go b/aclitemarray.go new file mode 100644 index 00000000..d69cd83c --- /dev/null +++ b/aclitemarray.go @@ -0,0 +1,186 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type ACLItemArray struct { + Elements []ACLItem + Dimensions []ArrayDimension + Status Status +} + +func (dst *ACLItemArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case ACLItemArray: + *dst = value + + case []string: + if value == nil { + *dst = ACLItemArray{Status: Null} + } else if len(value) == 0 { + *dst = ACLItemArray{Status: Present} + } else { + elements := make([]ACLItem, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = ACLItemArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to ACLItem", value) + } + + return nil +} + +func (src *ACLItemArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]string: + if src.Status == Present { + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *ACLItemArray) DecodeText(src []byte) error { + if src == nil { + *dst = ACLItemArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []ACLItem + + if len(uta.Elements) > 0 { + elements = make([]ACLItem, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem ACLItem + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (src *ACLItemArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil +} diff --git a/aclitemarray_test.go b/aclitemarray_test.go new file mode 100644 index 00000000..8c01ac66 --- /dev/null +++ b/aclitemarray_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestACLItemArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "aclitem[]", []interface{}{ + &pgtype.ACLItemArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{Status: pgtype.Null}, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestACLItemArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ACLItemArray + }{ + { + source: []string{"=r/postgres"}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.ACLItemArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.ACLItemArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestACLItemArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.ACLItemArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"=r/postgres"}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"=r/postgres"}, + }, + { + src: pgtype.ACLItemArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.ACLItemArray + dst interface{} + }{ + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype.go b/pgtype.go index d6cd53c1..d72217ac 100644 --- a/pgtype.go +++ b/pgtype.go @@ -35,8 +35,8 @@ const ( Int8ArrayOID = 1016 Float4ArrayOID = 1021 Float8ArrayOID = 1022 - AclItemOID = 1033 - AclItemArrayOID = 1034 + ACLItemOID = 1033 + ACLItemArrayOID = 1034 InetArrayOID = 1041 VarcharOID = 1043 DateOID = 1082 diff --git a/typed_array_gen.sh b/typed_array_gen.sh index c63414c8..876f8a3c 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -10,3 +10,4 @@ erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]fl erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inetarray.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > textarray.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOID text_null=NULL typed_array.go.erb > byteaarray.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_oid=ACLItemOID text_null=NULL typed_array.go.erb > aclitemarray.go From a231c5461f67c40ca68cdb6e53663cd20a0d2374 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 16:48:37 -0600 Subject: [PATCH 0029/1158] Move Tid to pgtype --- pgtype.go | 9 ++++- tid.go | 104 ++++++++++++++++++++++++++++++++++++++++++++++++++++ tid_test.go | 15 ++++++++ 3 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 tid.go create mode 100644 tid_test.go diff --git a/pgtype.go b/pgtype.go index d72217ac..8c67c630 100644 --- a/pgtype.go +++ b/pgtype.go @@ -16,7 +16,7 @@ const ( Int4OID = 23 TextOID = 25 OIDOID = 26 - TidOID = 27 + TIDOID = 27 XIDOID = 28 CIDOID = 29 JSONOID = 114 @@ -66,8 +66,13 @@ const ( NegativeInfinity InfinityModifier = -Infinity ) -type Value interface { +type Value interface{} + +type ConverterFrom interface { ConvertFrom(src interface{}) error +} + +type AssignerTo interface { AssignTo(dst interface{}) error } diff --git a/tid.go b/tid.go new file mode 100644 index 00000000..804cced2 --- /dev/null +++ b/tid.go @@ -0,0 +1,104 @@ +package pgtype + +import ( + "encoding/binary" + "fmt" + "io" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +// TID is PostgreSQL's Tuple Identifier type. +// +// When one does +// +// select ctid, * from some_table; +// +// it is the data type of the ctid hidden system column. +// +// It is currently implemented as a pair unsigned two byte integers. +// Its conversion functions can be found in src/backend/utils/adt/tid.c +// in the PostgreSQL sources. +type TID struct { + BlockNumber uint32 + OffsetNumber uint16 + Status Status +} + +func (dst *TID) DecodeText(src []byte) error { + if src == nil { + *dst = TID{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for tid") + } + + blockNumber, err := strconv.ParseUint(parts[0], 10, 32) + if err != nil { + return err + } + + offsetNumber, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return err + } + + *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} + return nil +} + +func (dst *TID) DecodeBinary(src []byte) error { + if src == nil { + *dst = TID{Status: Null} + return nil + } + + if len(src) != 6 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + *dst = TID{ + BlockNumber: binary.BigEndian.Uint32(src), + OffsetNumber: binary.BigEndian.Uint16(src[4:]), + Status: Present, + } + return nil +} + +func (src TID) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber)) + return false, err +} + +func (src TID) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := pgio.WriteUint32(w, src.BlockNumber) + if err != nil { + return false, err + } + + _, err = pgio.WriteUint16(w, src.OffsetNumber) + return false, err +} diff --git a/tid_test.go b/tid_test.go new file mode 100644 index 00000000..a5aab8a3 --- /dev/null +++ b/tid_test.go @@ -0,0 +1,15 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestTIDTranscode(t *testing.T) { + testSuccessfulTranscode(t, "tid", []interface{}{ + pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, + pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, + pgtype.TID{Status: pgtype.Null}, + }) +} From 44e206ab5b6cdffe6cf638b5b90b5b780377a901 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 16:53:07 -0600 Subject: [PATCH 0030/1158] Rename array files --- aclitemarray.go => aclitem_array.go | 0 aclitemarray_test.go => aclitem_array_test.go | 0 boolarray.go => bool_array.go | 0 boolarray_test.go => bool_array_test.go | 0 byteaarray.go => bytea_array.go | 0 byteaarray_test.go => bytea_array_test.go | 0 cidrarray.go => cidr_array.go | 0 datearray.go => date_array.go | 0 datearray_test.go => date_array_test.go | 0 float4array.go => float4_array.go | 0 float4array_test.go => float4_array_test.go | 0 float8array.go => float8_array.go | 0 float8array_test.go => float8_array_test.go | 0 inetarray.go => inet_array.go | 0 inetarray_test.go => inet_array_test.go | 0 int2array.go => int2_array.go | 0 int2array_test.go => int2_array_test.go | 0 int4array.go => int4_array.go | 0 int4array_test.go => int4_array_test.go | 0 int8array.go => int8_array.go | 0 int8array_test.go => int8_array_test.go | 0 textarray.go => text_array.go | 0 textarray_test.go => text_array_test.go | 0 timestamparray.go => timestamp_array.go | 0 ...mparray_test.go => timestamp_array_test.go | 0 timestamptzarray.go => timestamptz_array.go | 0 ...array_test.go => timestamptz_array_test.go | 0 typed_array_gen.sh | 26 +++++++++---------- varchararray.go => varchar_array.go | 0 29 files changed, 13 insertions(+), 13 deletions(-) rename aclitemarray.go => aclitem_array.go (100%) rename aclitemarray_test.go => aclitem_array_test.go (100%) rename boolarray.go => bool_array.go (100%) rename boolarray_test.go => bool_array_test.go (100%) rename byteaarray.go => bytea_array.go (100%) rename byteaarray_test.go => bytea_array_test.go (100%) rename cidrarray.go => cidr_array.go (100%) rename datearray.go => date_array.go (100%) rename datearray_test.go => date_array_test.go (100%) rename float4array.go => float4_array.go (100%) rename float4array_test.go => float4_array_test.go (100%) rename float8array.go => float8_array.go (100%) rename float8array_test.go => float8_array_test.go (100%) rename inetarray.go => inet_array.go (100%) rename inetarray_test.go => inet_array_test.go (100%) rename int2array.go => int2_array.go (100%) rename int2array_test.go => int2_array_test.go (100%) rename int4array.go => int4_array.go (100%) rename int4array_test.go => int4_array_test.go (100%) rename int8array.go => int8_array.go (100%) rename int8array_test.go => int8_array_test.go (100%) rename textarray.go => text_array.go (100%) rename textarray_test.go => text_array_test.go (100%) rename timestamparray.go => timestamp_array.go (100%) rename timestamparray_test.go => timestamp_array_test.go (100%) rename timestamptzarray.go => timestamptz_array.go (100%) rename timestamptzarray_test.go => timestamptz_array_test.go (100%) rename varchararray.go => varchar_array.go (100%) diff --git a/aclitemarray.go b/aclitem_array.go similarity index 100% rename from aclitemarray.go rename to aclitem_array.go diff --git a/aclitemarray_test.go b/aclitem_array_test.go similarity index 100% rename from aclitemarray_test.go rename to aclitem_array_test.go diff --git a/boolarray.go b/bool_array.go similarity index 100% rename from boolarray.go rename to bool_array.go diff --git a/boolarray_test.go b/bool_array_test.go similarity index 100% rename from boolarray_test.go rename to bool_array_test.go diff --git a/byteaarray.go b/bytea_array.go similarity index 100% rename from byteaarray.go rename to bytea_array.go diff --git a/byteaarray_test.go b/bytea_array_test.go similarity index 100% rename from byteaarray_test.go rename to bytea_array_test.go diff --git a/cidrarray.go b/cidr_array.go similarity index 100% rename from cidrarray.go rename to cidr_array.go diff --git a/datearray.go b/date_array.go similarity index 100% rename from datearray.go rename to date_array.go diff --git a/datearray_test.go b/date_array_test.go similarity index 100% rename from datearray_test.go rename to date_array_test.go diff --git a/float4array.go b/float4_array.go similarity index 100% rename from float4array.go rename to float4_array.go diff --git a/float4array_test.go b/float4_array_test.go similarity index 100% rename from float4array_test.go rename to float4_array_test.go diff --git a/float8array.go b/float8_array.go similarity index 100% rename from float8array.go rename to float8_array.go diff --git a/float8array_test.go b/float8_array_test.go similarity index 100% rename from float8array_test.go rename to float8_array_test.go diff --git a/inetarray.go b/inet_array.go similarity index 100% rename from inetarray.go rename to inet_array.go diff --git a/inetarray_test.go b/inet_array_test.go similarity index 100% rename from inetarray_test.go rename to inet_array_test.go diff --git a/int2array.go b/int2_array.go similarity index 100% rename from int2array.go rename to int2_array.go diff --git a/int2array_test.go b/int2_array_test.go similarity index 100% rename from int2array_test.go rename to int2_array_test.go diff --git a/int4array.go b/int4_array.go similarity index 100% rename from int4array.go rename to int4_array.go diff --git a/int4array_test.go b/int4_array_test.go similarity index 100% rename from int4array_test.go rename to int4_array_test.go diff --git a/int8array.go b/int8_array.go similarity index 100% rename from int8array.go rename to int8_array.go diff --git a/int8array_test.go b/int8_array_test.go similarity index 100% rename from int8array_test.go rename to int8_array_test.go diff --git a/textarray.go b/text_array.go similarity index 100% rename from textarray.go rename to text_array.go diff --git a/textarray_test.go b/text_array_test.go similarity index 100% rename from textarray_test.go rename to text_array_test.go diff --git a/timestamparray.go b/timestamp_array.go similarity index 100% rename from timestamparray.go rename to timestamp_array.go diff --git a/timestamparray_test.go b/timestamp_array_test.go similarity index 100% rename from timestamparray_test.go rename to timestamp_array_test.go diff --git a/timestamptzarray.go b/timestamptz_array.go similarity index 100% rename from timestamptzarray.go rename to timestamptz_array.go diff --git a/timestamptzarray_test.go b/timestamptz_array_test.go similarity index 100% rename from timestamptzarray_test.go rename to timestamptz_array_test.go diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 876f8a3c..32c298cc 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,13 +1,13 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID text_null=NULL typed_array.go.erb > int2array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID text_null=NULL typed_array.go.erb > int4array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID text_null=NULL typed_array.go.erb > int8array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID text_null=NULL typed_array.go.erb > boolarray.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID text_null=NULL typed_array.go.erb > datearray.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID text_null=NULL typed_array.go.erb > timestamptzarray.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID text_null=NULL typed_array.go.erb > timestamparray.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID text_null=NULL typed_array.go.erb > float4array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID text_null=NULL typed_array.go.erb > float8array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inetarray.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > textarray.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOID text_null=NULL typed_array.go.erb > byteaarray.go -erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_oid=ACLItemOID text_null=NULL typed_array.go.erb > aclitemarray.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID text_null=NULL typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID text_null=NULL typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID text_null=NULL typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID text_null=NULL typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID text_null=NULL typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID text_null=NULL typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID text_null=NULL typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID text_null=NULL typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID text_null=NULL typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inet_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > text_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOID text_null=NULL typed_array.go.erb > bytea_array.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_oid=ACLItemOID text_null=NULL typed_array.go.erb > aclitem_array.go diff --git a/varchararray.go b/varchar_array.go similarity index 100% rename from varchararray.go rename to varchar_array.go From 666af9ead53cd436652e705ee7bb4cbfd4259d40 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 17:03:23 -0600 Subject: [PATCH 0031/1158] Name PG types as words Though this doesn't follow Go naming conventions exactly it makes names more consistent with PostgreSQL and it is easier to read. For example, TIDOID becomes TidOid. In addition this is one less breaking change in the move to V3. --- aclitem.go | 26 +++++++------- aclitem_array.go | 34 +++++++++--------- aclitem_array_test.go | 74 +++++++++++++++++++------------------- aclitem_test.go | 36 +++++++++---------- array.go | 6 ++-- bool_array.go | 6 ++-- bytea_array.go | 6 ++-- cid.go | 20 +++++------ cid_test.go | 30 ++++++++-------- cidr_array.go | 2 +- date_array.go | 6 ++-- extra-interface.txt | 2 +- float4_array.go | 6 ++-- float8_array.go | 6 ++-- inet_array.go | 6 ++-- inet_array_test.go | 36 +++++++++---------- inet_test.go | 38 ++++++++++---------- int2_array.go | 6 ++-- int4_array.go | 6 ++-- int8_array.go | 6 ++-- oid.go | 20 +++++------ oid_test.go | 30 ++++++++-------- pgtype.go | 82 +++++++++++++++++++++---------------------- pgtype_test.go | 2 +- pguint32.go | 2 +- text_array.go | 6 ++-- tid.go | 20 +++++------ tid_test.go | 8 ++--- timestamp_array.go | 6 ++-- timestamptz_array.go | 6 ++-- typed_array.go.erb | 4 +-- typed_array_gen.sh | 26 +++++++------- varchar_array.go | 2 +- xid.go | 20 +++++------ xid_test.go | 30 ++++++++-------- 35 files changed, 311 insertions(+), 311 deletions(-) diff --git a/aclitem.go b/aclitem.go index bd7b7d45..821c5001 100644 --- a/aclitem.go +++ b/aclitem.go @@ -6,7 +6,7 @@ import ( "reflect" ) -// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem +// Aclitem is used for PostgreSQL's aclitem data type. A sample aclitem // might look like this: // // postgres=arwdDxt/postgres @@ -18,34 +18,34 @@ import ( // // postgres=arwdDxt/"role with spaces" // -type ACLItem struct { +type Aclitem struct { String string Status Status } -func (dst *ACLItem) ConvertFrom(src interface{}) error { +func (dst *Aclitem) ConvertFrom(src interface{}) error { switch value := src.(type) { - case ACLItem: + case Aclitem: *dst = value case string: - *dst = ACLItem{String: value, Status: Present} + *dst = Aclitem{String: value, Status: Present} case *string: if value == nil { - *dst = ACLItem{Status: Null} + *dst = Aclitem{Status: Null} } else { - *dst = ACLItem{String: *value, Status: Present} + *dst = Aclitem{String: *value, Status: Present} } default: if originalSrc, ok := underlyingStringType(src); ok { return dst.ConvertFrom(originalSrc) } - return fmt.Errorf("cannot convert %v to ACLItem", value) + return fmt.Errorf("cannot convert %v to Aclitem", value) } return nil } -func (src *ACLItem) AssignTo(dst interface{}) error { +func (src *Aclitem) AssignTo(dst interface{}) error { switch v := dst.(type) { case *string: if src.Status != Present { @@ -81,17 +81,17 @@ func (src *ACLItem) AssignTo(dst interface{}) error { return nil } -func (dst *ACLItem) DecodeText(src []byte) error { +func (dst *Aclitem) DecodeText(src []byte) error { if src == nil { - *dst = ACLItem{Status: Null} + *dst = Aclitem{Status: Null} return nil } - *dst = ACLItem{String: string(src), Status: Present} + *dst = Aclitem{String: string(src), Status: Present} return nil } -func (src ACLItem) EncodeText(w io.Writer) (bool, error) { +func (src Aclitem) EncodeText(w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/aclitem_array.go b/aclitem_array.go index d69cd83c..48f5cd38 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -8,30 +8,30 @@ import ( "github.com/jackc/pgx/pgio" ) -type ACLItemArray struct { - Elements []ACLItem +type AclitemArray struct { + Elements []Aclitem Dimensions []ArrayDimension Status Status } -func (dst *ACLItemArray) ConvertFrom(src interface{}) error { +func (dst *AclitemArray) ConvertFrom(src interface{}) error { switch value := src.(type) { - case ACLItemArray: + case AclitemArray: *dst = value case []string: if value == nil { - *dst = ACLItemArray{Status: Null} + *dst = AclitemArray{Status: Null} } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} + *dst = AclitemArray{Status: Present} } else { - elements := make([]ACLItem, len(value)) + elements := make([]Aclitem, len(value)) for i := range value { if err := elements[i].ConvertFrom(value[i]); err != nil { return err } } - *dst = ACLItemArray{ + *dst = AclitemArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, Status: Present, @@ -42,13 +42,13 @@ func (dst *ACLItemArray) ConvertFrom(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) } - return fmt.Errorf("cannot convert %v to ACLItem", value) + return fmt.Errorf("cannot convert %v to Aclitem", value) } return nil } -func (src *ACLItemArray) AssignTo(dst interface{}) error { +func (src *AclitemArray) AssignTo(dst interface{}) error { switch v := dst.(type) { case *[]string: @@ -73,9 +73,9 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { return nil } -func (dst *ACLItemArray) DecodeText(src []byte) error { +func (dst *AclitemArray) DecodeText(src []byte) error { if src == nil { - *dst = ACLItemArray{Status: Null} + *dst = AclitemArray{Status: Null} return nil } @@ -84,13 +84,13 @@ func (dst *ACLItemArray) DecodeText(src []byte) error { return err } - var elements []ACLItem + var elements []Aclitem if len(uta.Elements) > 0 { - elements = make([]ACLItem, len(uta.Elements)) + elements = make([]Aclitem, len(uta.Elements)) for i, s := range uta.Elements { - var elem ACLItem + var elem Aclitem var elemSrc []byte if s != "NULL" { elemSrc = []byte(s) @@ -104,12 +104,12 @@ func (dst *ACLItemArray) DecodeText(src []byte) error { } } - *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = AclitemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} return nil } -func (src *ACLItemArray) EncodeText(w io.Writer) (bool, error) { +func (src *AclitemArray) EncodeText(w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/aclitem_array_test.go b/aclitem_array_test.go index 8c01ac66..e78f14c6 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -7,40 +7,40 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestACLItemArrayTranscode(t *testing.T) { +func TestAclitemArrayTranscode(t *testing.T) { testSuccessfulTranscode(t, "aclitem[]", []interface{}{ - &pgtype.ACLItemArray{ + &pgtype.AclitemArray{ Elements: nil, Dimensions: nil, Status: pgtype.Present, }, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{Status: pgtype.Null}, + &pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{ + pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.Aclitem{Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, - &pgtype.ACLItemArray{Status: pgtype.Null}, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{Status: pgtype.Null}, - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + &pgtype.AclitemArray{Status: pgtype.Null}, + &pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{ + pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.Aclitem{Status: pgtype.Null}, + pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + &pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{ + pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, @@ -51,26 +51,26 @@ func TestACLItemArrayTranscode(t *testing.T) { }) } -func TestACLItemArrayConvertFrom(t *testing.T) { +func TestAclitemArrayConvertFrom(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.ACLItemArray + result pgtype.AclitemArray }{ { source: []string{"=r/postgres"}, - result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + result: pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{{String: "=r/postgres", Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, { source: (([]string)(nil)), - result: pgtype.ACLItemArray{Status: pgtype.Null}, + result: pgtype.AclitemArray{Status: pgtype.Null}, }, } for i, tt := range successfulTests { - var r pgtype.ACLItemArray + var r pgtype.AclitemArray err := r.ConvertFrom(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -82,19 +82,19 @@ func TestACLItemArrayConvertFrom(t *testing.T) { } } -func TestACLItemArrayAssignTo(t *testing.T) { +func TestAclitemArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice simpleTests := []struct { - src pgtype.ACLItemArray + src pgtype.AclitemArray dst interface{} expected interface{} }{ { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + src: pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{{String: "=r/postgres", Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, @@ -102,8 +102,8 @@ func TestACLItemArrayAssignTo(t *testing.T) { expected: []string{"=r/postgres"}, }, { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + src: pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{{String: "=r/postgres", Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, @@ -111,7 +111,7 @@ func TestACLItemArrayAssignTo(t *testing.T) { expected: _stringSlice{"=r/postgres"}, }, { - src: pgtype.ACLItemArray{Status: pgtype.Null}, + src: pgtype.AclitemArray{Status: pgtype.Null}, dst: &stringSlice, expected: (([]string)(nil)), }, @@ -129,12 +129,12 @@ func TestACLItemArrayAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.ACLItemArray + src pgtype.AclitemArray dst interface{} }{ { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{Status: pgtype.Null}}, + src: pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{{Status: pgtype.Null}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, diff --git a/aclitem_test.go b/aclitem_test.go index 0b2b6cfa..fc429acc 100644 --- a/aclitem_test.go +++ b/aclitem_test.go @@ -7,26 +7,26 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestACLItemTranscode(t *testing.T) { +func TestAclitemTranscode(t *testing.T) { testSuccessfulTranscode(t, "aclitem", []interface{}{ - pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - pgtype.ACLItem{Status: pgtype.Null}, + pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + pgtype.Aclitem{Status: pgtype.Null}, }) } -func TestACLItemConvertFrom(t *testing.T) { +func TestAclitemConvertFrom(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.ACLItem + result pgtype.Aclitem }{ - {source: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - {source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}}, + {source: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, result: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: "postgres=arwdDxt/postgres", result: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.Aclitem{Status: pgtype.Null}}, } for i, tt := range successfulTests { - var d pgtype.ACLItem + var d pgtype.Aclitem err := d.ConvertFrom(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -38,17 +38,17 @@ func TestACLItemConvertFrom(t *testing.T) { } } -func TestACLItemAssignTo(t *testing.T) { +func TestAclitemAssignTo(t *testing.T) { var s string var ps *string simpleTests := []struct { - src pgtype.ACLItem + src pgtype.Aclitem dst interface{} expected interface{} }{ - {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, - {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + {src: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.Aclitem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, } for i, tt := range simpleTests { @@ -63,11 +63,11 @@ func TestACLItemAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.ACLItem + src pgtype.Aclitem dst interface{} expected interface{} }{ - {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, } for i, tt := range pointerAllocTests { @@ -82,10 +82,10 @@ func TestACLItemAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.ACLItem + src pgtype.Aclitem dst interface{} }{ - {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s}, + {src: pgtype.Aclitem{Status: pgtype.Null}, dst: &s}, } for i, tt := range errorTests { diff --git a/array.go b/array.go index 90092c8d..dff0fe81 100644 --- a/array.go +++ b/array.go @@ -18,7 +18,7 @@ import ( type ArrayHeader struct { ContainsNull bool - ElementOID int32 + ElementOid int32 Dimensions []ArrayDimension } @@ -40,7 +40,7 @@ func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 rp += 4 - dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:])) + dst.ElementOid = int32(binary.BigEndian.Uint32(src[rp:])) rp += 4 if numDims > 0 { @@ -75,7 +75,7 @@ func (src *ArrayHeader) EncodeBinary(w io.Writer) error { return err } - _, err = pgio.WriteInt32(w, src.ElementOID) + _, err = pgio.WriteInt32(w, src.ElementOid) if err != nil { return err } diff --git a/bool_array.go b/bool_array.go index 65a6bc9c..a74e9f90 100644 --- a/bool_array.go +++ b/bool_array.go @@ -229,10 +229,10 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { } func (src *BoolArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, BoolOID) + return src.encodeBinary(w, BoolOid) } -func (src *BoolArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -241,7 +241,7 @@ func (src *BoolArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/bytea_array.go b/bytea_array.go index 7a4f1601..9003eafd 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -229,10 +229,10 @@ func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { } func (src *ByteaArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, ByteaOID) + return src.encodeBinary(w, ByteaOid) } -func (src *ByteaArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -241,7 +241,7 @@ func (src *ByteaArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/cid.go b/cid.go index 41b817bb..be93a03e 100644 --- a/cid.go +++ b/cid.go @@ -4,7 +4,7 @@ import ( "io" ) -// CID is PostgreSQL's Command Identifier type. +// Cid is PostgreSQL's Command Identifier type. // // When one does // @@ -15,33 +15,33 @@ import ( // It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/c.h as CommandId // in the PostgreSQL sources. -type CID pguint32 +type Cid pguint32 -// ConvertFrom converts from src to dst. Note that as CID is not a general +// ConvertFrom converts from src to dst. Note that as Cid is not a general // number type ConvertFrom does not do automatic type conversion as other number // types do. -func (dst *CID) ConvertFrom(src interface{}) error { +func (dst *Cid) ConvertFrom(src interface{}) error { return (*pguint32)(dst).ConvertFrom(src) } -// AssignTo assigns from src to dst. Note that as CID is not a general number +// AssignTo assigns from src to dst. Note that as Cid is not a general number // type AssignTo does not do automatic type conversion as other number types do. -func (src *CID) AssignTo(dst interface{}) error { +func (src *Cid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *CID) DecodeText(src []byte) error { +func (dst *Cid) DecodeText(src []byte) error { return (*pguint32)(dst).DecodeText(src) } -func (dst *CID) DecodeBinary(src []byte) error { +func (dst *Cid) DecodeBinary(src []byte) error { return (*pguint32)(dst).DecodeBinary(src) } -func (src CID) EncodeText(w io.Writer) (bool, error) { +func (src Cid) EncodeText(w io.Writer) (bool, error) { return (pguint32)(src).EncodeText(w) } -func (src CID) EncodeBinary(w io.Writer) (bool, error) { +func (src Cid) EncodeBinary(w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(w) } diff --git a/cid_test.go b/cid_test.go index 72f5dfea..7d9fde34 100644 --- a/cid_test.go +++ b/cid_test.go @@ -7,23 +7,23 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestCIDTranscode(t *testing.T) { +func TestCidTranscode(t *testing.T) { testSuccessfulTranscode(t, "cid", []interface{}{ - pgtype.CID{Uint: 42, Status: pgtype.Present}, - pgtype.CID{Status: pgtype.Null}, + pgtype.Cid{Uint: 42, Status: pgtype.Present}, + pgtype.Cid{Status: pgtype.Null}, }) } -func TestCIDConvertFrom(t *testing.T) { +func TestCidConvertFrom(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.CID + result pgtype.Cid }{ - {source: uint32(1), result: pgtype.CID{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Cid{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.CID + var r pgtype.Cid err := r.ConvertFrom(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -35,17 +35,17 @@ func TestCIDConvertFrom(t *testing.T) { } } -func TestCIDAssignTo(t *testing.T) { +func TestCidAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.CID + src pgtype.Cid dst interface{} expected interface{} }{ - {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.Cid{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Cid{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -60,11 +60,11 @@ func TestCIDAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.CID + src pgtype.Cid dst interface{} expected interface{} }{ - {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.Cid{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -79,10 +79,10 @@ func TestCIDAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.CID + src pgtype.Cid dst interface{} }{ - {src: pgtype.CID{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.Cid{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/cidr_array.go b/cidr_array.go index cb81d2b9..e0219ee5 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -27,5 +27,5 @@ func (src *CidrArray) EncodeText(w io.Writer) (bool, error) { } func (src *CidrArray) EncodeBinary(w io.Writer) (bool, error) { - return (*InetArray)(src).encodeBinary(w, CidrOID) + return (*InetArray)(src).encodeBinary(w, CidrOid) } diff --git a/date_array.go b/date_array.go index 623ff9b3..8f7cba18 100644 --- a/date_array.go +++ b/date_array.go @@ -230,10 +230,10 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { } func (src *DateArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, DateOID) + return src.encodeBinary(w, DateOid) } -func (src *DateArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -242,7 +242,7 @@ func (src *DateArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/extra-interface.txt b/extra-interface.txt index 16453823..f07818bc 100644 --- a/extra-interface.txt +++ b/extra-interface.txt @@ -1,3 +1,3 @@ Can pass function to get inet data and function to get oid/name mapping as optional interface with io.Reader or io.Writer -Could be useful for arrays of types without defined OIDs like hstore. +Could be useful for arrays of types without defined Oids like hstore. diff --git a/float4_array.go b/float4_array.go index c55f76d0..632e7e4b 100644 --- a/float4_array.go +++ b/float4_array.go @@ -229,10 +229,10 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { } func (src *Float4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float4OID) + return src.encodeBinary(w, Float4Oid) } -func (src *Float4Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -241,7 +241,7 @@ func (src *Float4Array) encodeBinary(w io.Writer, elementOID int32) (bool, error } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/float8_array.go b/float8_array.go index d08a5351..68cf30f2 100644 --- a/float8_array.go +++ b/float8_array.go @@ -229,10 +229,10 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { } func (src *Float8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float8OID) + return src.encodeBinary(w, Float8Oid) } -func (src *Float8Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -241,7 +241,7 @@ func (src *Float8Array) encodeBinary(w io.Writer, elementOID int32) (bool, error } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/inet_array.go b/inet_array.go index 12d9493b..629cd51f 100644 --- a/inet_array.go +++ b/inet_array.go @@ -261,10 +261,10 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { } func (src *InetArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, InetOID) + return src.encodeBinary(w, InetOid) } -func (src *InetArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -273,7 +273,7 @@ func (src *InetArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/inet_array_test.go b/inet_array_test.go index 8cab5355..523a9f8d 100644 --- a/inet_array_test.go +++ b/inet_array_test.go @@ -17,7 +17,7 @@ func TestInetArrayTranscode(t *testing.T) { }, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, pgtype.Inet{Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, @@ -26,22 +26,22 @@ func TestInetArrayTranscode(t *testing.T) { &pgtype.InetArray{Status: pgtype.Null}, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, pgtype.Inet{Status: pgtype.Null}, - pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, @@ -58,9 +58,9 @@ func TestInetArrayConvertFrom(t *testing.T) { result pgtype.InetArray }{ { - source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + source: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, @@ -69,9 +69,9 @@ func TestInetArrayConvertFrom(t *testing.T) { result: pgtype.InetArray{Status: pgtype.Null}, }, { - source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + source: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, @@ -105,12 +105,12 @@ func TestInetArrayAssignTo(t *testing.T) { }{ { src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + expected: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, }, { src: pgtype.InetArray{ @@ -123,12 +123,12 @@ func TestInetArrayAssignTo(t *testing.T) { }, { src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, dst: &ipSlice, - expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + expected: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, }, { src: pgtype.InetArray{ diff --git a/inet_test.go b/inet_test.go index 5e86376b..5a326810 100644 --- a/inet_test.go +++ b/inet_test.go @@ -11,16 +11,16 @@ import ( func TestInetTranscode(t *testing.T) { for _, pgTypeName := range []string{"inet", "cidr"} { testSuccessfulTranscode(t, pgTypeName, []interface{}{ - pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "0.0.0.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "192.168.1.0/24"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "255.255.255.255/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "::/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "::/0"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "::1/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, pgtype.Inet{Status: pgtype.Null}, }) } @@ -31,10 +31,10 @@ func TestInetConvertFrom(t *testing.T) { source interface{} result pgtype.Inet }{ - {source: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Null}, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Null}}, - {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Null}, result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Null}}, + {source: mustParseCidr(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: mustParseCidr(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, } for i, tt := range successfulTests { @@ -61,8 +61,8 @@ func TestInetAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCidr(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCidr(t, "127.0.0.1/32").IP}, {src: pgtype.Inet{Status: pgtype.Null}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, {src: pgtype.Inet{Status: pgtype.Null}, dst: &pip, expected: ((*net.IP)(nil))}, } @@ -83,8 +83,8 @@ func TestInetAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCidr(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCidr(t, "127.0.0.1/32").IP}, } for i, tt := range pointerAllocTests { @@ -102,7 +102,7 @@ func TestInetAssignTo(t *testing.T) { src pgtype.Inet dst interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, + {src: pgtype.Inet{IPNet: mustParseCidr(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, {src: pgtype.Inet{Status: pgtype.Null}, dst: &ipnet}, } diff --git a/int2_array.go b/int2_array.go index 37ee9926..d8268c0a 100644 --- a/int2_array.go +++ b/int2_array.go @@ -260,10 +260,10 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { } func (src *Int2Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int2OID) + return src.encodeBinary(w, Int2Oid) } -func (src *Int2Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -272,7 +272,7 @@ func (src *Int2Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/int4_array.go b/int4_array.go index f6f62e4b..dcdb50c1 100644 --- a/int4_array.go +++ b/int4_array.go @@ -260,10 +260,10 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { } func (src *Int4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int4OID) + return src.encodeBinary(w, Int4Oid) } -func (src *Int4Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -272,7 +272,7 @@ func (src *Int4Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/int8_array.go b/int8_array.go index 92d8ec46..ed82f079 100644 --- a/int8_array.go +++ b/int8_array.go @@ -260,10 +260,10 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { } func (src *Int8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int8OID) + return src.encodeBinary(w, Int8Oid) } -func (src *Int8Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -272,7 +272,7 @@ func (src *Int8Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/oid.go b/oid.go index e1bee4cf..c77f3f10 100644 --- a/oid.go +++ b/oid.go @@ -4,38 +4,38 @@ import ( "io" ) -// OID (Object Identifier Type) is, according to +// Oid (Object Identifier Type) is, according to // https://www.postgresql.org/docs/current/static/datatype-oid.html, used // internally by PostgreSQL as a primary key for various system tables. It is // currently implemented as an unsigned four-byte integer. Its definition can be // found in src/include/postgres_ext.h in the PostgreSQL sources. -type OID pguint32 +type Oid pguint32 -// ConvertFrom converts from src to dst. Note that as OID is not a general +// ConvertFrom converts from src to dst. Note that as Oid is not a general // number type ConvertFrom does not do automatic type conversion as other number // types do. -func (dst *OID) ConvertFrom(src interface{}) error { +func (dst *Oid) ConvertFrom(src interface{}) error { return (*pguint32)(dst).ConvertFrom(src) } -// AssignTo assigns from src to dst. Note that as OID is not a general number +// AssignTo assigns from src to dst. Note that as Oid is not a general number // type AssignTo does not do automatic type conversion as other number types do. -func (src *OID) AssignTo(dst interface{}) error { +func (src *Oid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *OID) DecodeText(src []byte) error { +func (dst *Oid) DecodeText(src []byte) error { return (*pguint32)(dst).DecodeText(src) } -func (dst *OID) DecodeBinary(src []byte) error { +func (dst *Oid) DecodeBinary(src []byte) error { return (*pguint32)(dst).DecodeBinary(src) } -func (src OID) EncodeText(w io.Writer) (bool, error) { +func (src Oid) EncodeText(w io.Writer) (bool, error) { return (pguint32)(src).EncodeText(w) } -func (src OID) EncodeBinary(w io.Writer) (bool, error) { +func (src Oid) EncodeBinary(w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(w) } diff --git a/oid_test.go b/oid_test.go index c8e0b2d6..bbab6699 100644 --- a/oid_test.go +++ b/oid_test.go @@ -7,23 +7,23 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestOIDTranscode(t *testing.T) { +func TestOidTranscode(t *testing.T) { testSuccessfulTranscode(t, "oid", []interface{}{ - pgtype.OID{Uint: 42, Status: pgtype.Present}, - pgtype.OID{Status: pgtype.Null}, + pgtype.Oid{Uint: 42, Status: pgtype.Present}, + pgtype.Oid{Status: pgtype.Null}, }) } -func TestOIDConvertFrom(t *testing.T) { +func TestOidConvertFrom(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.OID + result pgtype.Oid }{ - {source: uint32(1), result: pgtype.OID{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Oid{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.OID + var r pgtype.Oid err := r.ConvertFrom(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -35,17 +35,17 @@ func TestOIDConvertFrom(t *testing.T) { } } -func TestOIDAssignTo(t *testing.T) { +func TestOidAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.OID + src pgtype.Oid dst interface{} expected interface{} }{ - {src: pgtype.OID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.OID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.Oid{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Oid{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -60,11 +60,11 @@ func TestOIDAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.OID + src pgtype.Oid dst interface{} expected interface{} }{ - {src: pgtype.OID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.Oid{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -79,10 +79,10 @@ func TestOIDAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.OID + src pgtype.Oid dst interface{} }{ - {src: pgtype.OID{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.Oid{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/pgtype.go b/pgtype.go index 8c67c630..cbcd6bd5 100644 --- a/pgtype.go +++ b/pgtype.go @@ -7,47 +7,47 @@ import ( // PostgreSQL oids for common types const ( - BoolOID = 16 - ByteaOID = 17 - CharOID = 18 - NameOID = 19 - Int8OID = 20 - Int2OID = 21 - Int4OID = 23 - TextOID = 25 - OIDOID = 26 - TIDOID = 27 - XIDOID = 28 - CIDOID = 29 - JSONOID = 114 - CidrOID = 650 - CidrArrayOID = 651 - Float4OID = 700 - Float8OID = 701 - UnknownOID = 705 - InetOID = 869 - BoolArrayOID = 1000 - Int2ArrayOID = 1005 - Int4ArrayOID = 1007 - TextArrayOID = 1009 - ByteaArrayOID = 1001 - VarcharArrayOID = 1015 - Int8ArrayOID = 1016 - Float4ArrayOID = 1021 - Float8ArrayOID = 1022 - ACLItemOID = 1033 - ACLItemArrayOID = 1034 - InetArrayOID = 1041 - VarcharOID = 1043 - DateOID = 1082 - TimestampOID = 1114 - TimestampArrayOID = 1115 - DateArrayOID = 1182 - TimestamptzOID = 1184 - TimestamptzArrayOID = 1185 - RecordOID = 2249 - UUIDOID = 2950 - JSONBOID = 3802 + BoolOid = 16 + ByteaOid = 17 + CharOid = 18 + NameOid = 19 + Int8Oid = 20 + Int2Oid = 21 + Int4Oid = 23 + TextOid = 25 + OidOid = 26 + TidOid = 27 + XidOid = 28 + CidOid = 29 + JsonOid = 114 + CidrOid = 650 + CidrArrayOid = 651 + Float4Oid = 700 + Float8Oid = 701 + UnknownOid = 705 + InetOid = 869 + BoolArrayOid = 1000 + Int2ArrayOid = 1005 + Int4ArrayOid = 1007 + TextArrayOid = 1009 + ByteaArrayOid = 1001 + VarcharArrayOid = 1015 + Int8ArrayOid = 1016 + Float4ArrayOid = 1021 + Float8ArrayOid = 1022 + AclitemOid = 1033 + AclitemArrayOid = 1034 + InetArrayOid = 1041 + VarcharOid = 1043 + DateOid = 1082 + TimestampOid = 1114 + TimestampArrayOid = 1115 + DateArrayOid = 1182 + TimestamptzOid = 1184 + TimestamptzArrayOid = 1185 + RecordOid = 2249 + UuidOid = 2950 + JsonbOid = 3802 ) type Status byte diff --git a/pgtype_test.go b/pgtype_test.go index 07a40160..f9b6f56d 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -47,7 +47,7 @@ func mustClose(t testing.TB, conn interface { } } -func mustParseCIDR(t testing.TB, s string) *net.IPNet { +func mustParseCidr(t testing.TB, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { t.Fatal(err) diff --git a/pguint32.go b/pguint32.go index df9e0d36..c636e1c4 100644 --- a/pguint32.go +++ b/pguint32.go @@ -10,7 +10,7 @@ import ( ) // pguint32 is the core type that is used to implement PostgreSQL types such as -// CID and XID. +// Cid and Xid. type pguint32 struct { Uint uint32 Status Status diff --git a/text_array.go b/text_array.go index 182e76f5..06e3c0df 100644 --- a/text_array.go +++ b/text_array.go @@ -229,10 +229,10 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { } func (src *TextArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TextOID) + return src.encodeBinary(w, TextOid) } -func (src *TextArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -241,7 +241,7 @@ func (src *TextArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/tid.go b/tid.go index 804cced2..b67892ff 100644 --- a/tid.go +++ b/tid.go @@ -10,7 +10,7 @@ import ( "github.com/jackc/pgx/pgio" ) -// TID is PostgreSQL's Tuple Identifier type. +// Tid is PostgreSQL's Tuple Identifier type. // // When one does // @@ -21,15 +21,15 @@ import ( // It is currently implemented as a pair unsigned two byte integers. // Its conversion functions can be found in src/backend/utils/adt/tid.c // in the PostgreSQL sources. -type TID struct { +type Tid struct { BlockNumber uint32 OffsetNumber uint16 Status Status } -func (dst *TID) DecodeText(src []byte) error { +func (dst *Tid) DecodeText(src []byte) error { if src == nil { - *dst = TID{Status: Null} + *dst = Tid{Status: Null} return nil } @@ -52,13 +52,13 @@ func (dst *TID) DecodeText(src []byte) error { return err } - *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} + *dst = Tid{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} return nil } -func (dst *TID) DecodeBinary(src []byte) error { +func (dst *Tid) DecodeBinary(src []byte) error { if src == nil { - *dst = TID{Status: Null} + *dst = Tid{Status: Null} return nil } @@ -66,7 +66,7 @@ func (dst *TID) DecodeBinary(src []byte) error { return fmt.Errorf("invalid length for tid: %v", len(src)) } - *dst = TID{ + *dst = Tid{ BlockNumber: binary.BigEndian.Uint32(src), OffsetNumber: binary.BigEndian.Uint16(src[4:]), Status: Present, @@ -74,7 +74,7 @@ func (dst *TID) DecodeBinary(src []byte) error { return nil } -func (src TID) EncodeText(w io.Writer) (bool, error) { +func (src Tid) EncodeText(w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -86,7 +86,7 @@ func (src TID) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src TID) EncodeBinary(w io.Writer) (bool, error) { +func (src Tid) EncodeBinary(w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/tid_test.go b/tid_test.go index a5aab8a3..56595ef4 100644 --- a/tid_test.go +++ b/tid_test.go @@ -6,10 +6,10 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestTIDTranscode(t *testing.T) { +func TestTidTranscode(t *testing.T) { testSuccessfulTranscode(t, "tid", []interface{}{ - pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, - pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, - pgtype.TID{Status: pgtype.Null}, + pgtype.Tid{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, + pgtype.Tid{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, + pgtype.Tid{Status: pgtype.Null}, }) } diff --git a/timestamp_array.go b/timestamp_array.go index b0fb25fa..1ea30ba4 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -230,10 +230,10 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { } func (src *TimestampArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestampOID) + return src.encodeBinary(w, TimestampOid) } -func (src *TimestampArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -242,7 +242,7 @@ func (src *TimestampArray) encodeBinary(w io.Writer, elementOID int32) (bool, er } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/timestamptz_array.go b/timestamptz_array.go index 25374717..fc3ce08c 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -230,10 +230,10 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { } func (src *TimestamptzArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestamptzOID) + return src.encodeBinary(w, TimestamptzOid) } -func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -242,7 +242,7 @@ func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOID int32) (bool, } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/typed_array.go.erb b/typed_array.go.erb index f9dba308..98c8d845 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -231,7 +231,7 @@ func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, <%= element_oid %>) } -func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -240,7 +240,7 @@ func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOID int32) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 32c298cc..41c1313f 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,13 +1,13 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID text_null=NULL typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID text_null=NULL typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID text_null=NULL typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID text_null=NULL typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID text_null=NULL typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID text_null=NULL typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID text_null=NULL typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID text_null=NULL typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID text_null=NULL typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inet_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > text_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOID text_null=NULL typed_array.go.erb > bytea_array.go -erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_oid=ACLItemOID text_null=NULL typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2Oid text_null=NULL typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4Oid text_null=NULL typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8Oid text_null=NULL typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOid text_null=NULL typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOid text_null=NULL typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOid text_null=NULL typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOid text_null=NULL typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4Oid text_null=NULL typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8Oid text_null=NULL typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOid text_null=NULL typed_array.go.erb > inet_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOid text_null='"NULL"' typed_array.go.erb > text_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOid text_null=NULL typed_array.go.erb > bytea_array.go +erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_oid=AclitemOid text_null=NULL typed_array.go.erb > aclitem_array.go diff --git a/varchar_array.go b/varchar_array.go index 9c8829d0..b9d87b7f 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -27,5 +27,5 @@ func (src *VarcharArray) EncodeText(w io.Writer) (bool, error) { } func (src *VarcharArray) EncodeBinary(w io.Writer) (bool, error) { - return (*TextArray)(src).encodeBinary(w, VarcharOID) + return (*TextArray)(src).encodeBinary(w, VarcharOid) } diff --git a/xid.go b/xid.go index 6635b21e..7deaa4f0 100644 --- a/xid.go +++ b/xid.go @@ -4,7 +4,7 @@ import ( "io" ) -// XID is PostgreSQL's Transaction ID type. +// Xid is PostgreSQL's Transaction ID type. // // In later versions of PostgreSQL, it is the type used for the backend_xid // and backend_xmin columns of the pg_stat_activity system view. @@ -18,33 +18,33 @@ import ( // It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/postgres_ext.h as TransactionId // in the PostgreSQL sources. -type XID pguint32 +type Xid pguint32 -// ConvertFrom converts from src to dst. Note that as XID is not a general +// ConvertFrom converts from src to dst. Note that as Xid is not a general // number type ConvertFrom does not do automatic type conversion as other number // types do. -func (dst *XID) ConvertFrom(src interface{}) error { +func (dst *Xid) ConvertFrom(src interface{}) error { return (*pguint32)(dst).ConvertFrom(src) } -// AssignTo assigns from src to dst. Note that as XID is not a general number +// AssignTo assigns from src to dst. Note that as Xid is not a general number // type AssignTo does not do automatic type conversion as other number types do. -func (src *XID) AssignTo(dst interface{}) error { +func (src *Xid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *XID) DecodeText(src []byte) error { +func (dst *Xid) DecodeText(src []byte) error { return (*pguint32)(dst).DecodeText(src) } -func (dst *XID) DecodeBinary(src []byte) error { +func (dst *Xid) DecodeBinary(src []byte) error { return (*pguint32)(dst).DecodeBinary(src) } -func (src XID) EncodeText(w io.Writer) (bool, error) { +func (src Xid) EncodeText(w io.Writer) (bool, error) { return (pguint32)(src).EncodeText(w) } -func (src XID) EncodeBinary(w io.Writer) (bool, error) { +func (src Xid) EncodeBinary(w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(w) } diff --git a/xid_test.go b/xid_test.go index 664920bc..a5c5df51 100644 --- a/xid_test.go +++ b/xid_test.go @@ -7,23 +7,23 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestXIDTranscode(t *testing.T) { +func TestXidTranscode(t *testing.T) { testSuccessfulTranscode(t, "xid", []interface{}{ - pgtype.XID{Uint: 42, Status: pgtype.Present}, - pgtype.XID{Status: pgtype.Null}, + pgtype.Xid{Uint: 42, Status: pgtype.Present}, + pgtype.Xid{Status: pgtype.Null}, }) } -func TestXIDConvertFrom(t *testing.T) { +func TestXidConvertFrom(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.XID + result pgtype.Xid }{ - {source: uint32(1), result: pgtype.XID{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Xid{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.XID + var r pgtype.Xid err := r.ConvertFrom(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -35,17 +35,17 @@ func TestXIDConvertFrom(t *testing.T) { } } -func TestXIDAssignTo(t *testing.T) { +func TestXidAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.XID + src pgtype.Xid dst interface{} expected interface{} }{ - {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.XID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.Xid{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Xid{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -60,11 +60,11 @@ func TestXIDAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.XID + src pgtype.Xid dst interface{} expected interface{} }{ - {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.Xid{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -79,10 +79,10 @@ func TestXIDAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.XID + src pgtype.Xid dst interface{} }{ - {src: pgtype.XID{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.Xid{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { From 7985ca5f8703514c601902763467c1917651b5fc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 18:46:51 -0600 Subject: [PATCH 0032/1158] Add json/jsonb to pgtype --- json.go | 102 ++++++++++++++++++++++++++++++++++++++ json_test.go | 135 ++++++++++++++++++++++++++++++++++++++++++++++++++ jsonb.go | 64 ++++++++++++++++++++++++ jsonb_test.go | 135 ++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 436 insertions(+) create mode 100644 json.go create mode 100644 json_test.go create mode 100644 jsonb.go create mode 100644 jsonb_test.go diff --git a/json.go b/json.go new file mode 100644 index 00000000..8a258ea4 --- /dev/null +++ b/json.go @@ -0,0 +1,102 @@ +package pgtype + +import ( + "encoding/json" + "io" +) + +type Json struct { + Bytes []byte + Status Status +} + +func (dst *Json) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case string: + *dst = Json{Bytes: []byte(value), Status: Present} + case *string: + if value == nil { + *dst = Json{Status: Null} + } else { + *dst = Json{Bytes: []byte(*value), Status: Present} + } + case []byte: + if value == nil { + *dst = Json{Status: Null} + } else { + *dst = Json{Bytes: value, Status: Present} + } + default: + buf, err := json.Marshal(value) + if err != nil { + return err + } + *dst = Json{Bytes: buf, Status: Present} + } + + return nil +} + +func (src *Json) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *string: + if src.Status != Present { + v = nil + } else { + *v = string(src.Bytes) + } + case **string: + *v = new(string) + return src.AssignTo(*v) + case *[]byte: + if src.Status != Present { + *v = nil + } else { + buf := make([]byte, len(src.Bytes)) + copy(buf, src.Bytes) + *v = buf + } + default: + data := src.Bytes + if data == nil || src.Status != Present { + data = []byte("null") + } + + return json.Unmarshal(data, dst) + } + + return nil +} + +func (dst *Json) DecodeText(src []byte) error { + if src == nil { + *dst = Json{Status: Null} + return nil + } + + buf := make([]byte, len(src)) + copy(buf, src) + + *dst = Json{Bytes: buf, Status: Present} + return nil +} + +func (dst *Json) DecodeBinary(src []byte) error { + return dst.DecodeText(src) +} + +func (src Json) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := w.Write(src.Bytes) + return false, err +} + +func (src Json) EncodeBinary(w io.Writer) (bool, error) { + return src.EncodeText(w) +} diff --git a/json_test.go b/json_test.go new file mode 100644 index 00000000..87770f31 --- /dev/null +++ b/json_test.go @@ -0,0 +1,135 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestJsonTranscode(t *testing.T) { + testSuccessfulTranscode(t, "json", []interface{}{ + pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, + pgtype.Json{Bytes: []byte("null"), Status: pgtype.Present}, + pgtype.Json{Bytes: []byte("42"), Status: pgtype.Present}, + pgtype.Json{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + pgtype.Json{Status: pgtype.Null}, + }) +} + +func TestJsonConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Json + }{ + {source: "{}", result: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: []byte("{}"), result: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: ([]byte)(nil), result: pgtype.Json{Status: pgtype.Null}}, + {source: (*string)(nil), result: pgtype.Json{Status: pgtype.Null}}, + {source: []int{1, 2, 3}, result: pgtype.Json{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.Json{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.Json + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(d, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestJsonAssignTo(t *testing.T) { + var s string + var ps *string + var b []byte + + rawStringTests := []struct { + src pgtype.Json + dst *string + expected string + }{ + {src: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + } + + for i, tt := range rawStringTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + rawBytesTests := []struct { + src pgtype.Json + dst *[]byte + expected []byte + }{ + {src: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, + {src: pgtype.Json{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + } + + for i, tt := range rawBytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(tt.expected, *tt.dst) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + var mapDst map[string]interface{} + type structDst struct { + Name string `json:"name"` + Age int `json:"age"` + } + var strDst structDst + + unmarshalTests := []struct { + src pgtype.Json + dst interface{} + expected interface{} + }{ + {src: pgtype.Json{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.Json{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + } + for i, tt := range unmarshalTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Json + dst **string + expected *string + }{ + {src: pgtype.Json{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst == tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/jsonb.go b/jsonb.go new file mode 100644 index 00000000..0739a468 --- /dev/null +++ b/jsonb.go @@ -0,0 +1,64 @@ +package pgtype + +import ( + "fmt" + "io" +) + +type Jsonb Json + +func (dst *Jsonb) ConvertFrom(src interface{}) error { + return (*Json)(dst).ConvertFrom(src) +} + +func (src *Jsonb) AssignTo(dst interface{}) error { + return (*Json)(src).AssignTo(dst) +} + +func (dst *Jsonb) DecodeText(src []byte) error { + return (*Json)(dst).DecodeText(src) +} + +func (dst *Jsonb) DecodeBinary(src []byte) error { + if src == nil { + *dst = Jsonb{Status: Null} + return nil + } + + if len(src) == 0 { + return fmt.Errorf("jsonb too short") + } + + if src[0] != 1 { + return fmt.Errorf("unknown jsonb version number %d", src[0]) + } + src = src[1:] + + buf := make([]byte, len(src)) + copy(buf, src) + + *dst = Jsonb{Bytes: buf, Status: Present} + return nil + +} + +func (src Jsonb) EncodeText(w io.Writer) (bool, error) { + return (Json)(src).EncodeText(w) +} + +func (src Jsonb) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := w.Write([]byte{1}) + if err != nil { + return false, err + } + + _, err = w.Write(src.Bytes) + return false, err +} diff --git a/jsonb_test.go b/jsonb_test.go new file mode 100644 index 00000000..e42931d5 --- /dev/null +++ b/jsonb_test.go @@ -0,0 +1,135 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestJsonbTranscode(t *testing.T) { + testSuccessfulTranscode(t, "jsonb", []interface{}{ + pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, + pgtype.Jsonb{Bytes: []byte("null"), Status: pgtype.Present}, + pgtype.Jsonb{Bytes: []byte("42"), Status: pgtype.Present}, + pgtype.Jsonb{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + pgtype.Jsonb{Status: pgtype.Null}, + }) +} + +func TestJsonbConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Jsonb + }{ + {source: "{}", result: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: []byte("{}"), result: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: ([]byte)(nil), result: pgtype.Jsonb{Status: pgtype.Null}}, + {source: (*string)(nil), result: pgtype.Jsonb{Status: pgtype.Null}}, + {source: []int{1, 2, 3}, result: pgtype.Jsonb{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.Jsonb{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.Jsonb + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(d, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestJsonbAssignTo(t *testing.T) { + var s string + var ps *string + var b []byte + + rawStringTests := []struct { + src pgtype.Jsonb + dst *string + expected string + }{ + {src: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + } + + for i, tt := range rawStringTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + rawBytesTests := []struct { + src pgtype.Jsonb + dst *[]byte + expected []byte + }{ + {src: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, + {src: pgtype.Jsonb{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + } + + for i, tt := range rawBytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(tt.expected, *tt.dst) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + var mapDst map[string]interface{} + type structDst struct { + Name string `json:"name"` + Age int `json:"age"` + } + var strDst structDst + + unmarshalTests := []struct { + src pgtype.Jsonb + dst interface{} + expected interface{} + }{ + {src: pgtype.Jsonb{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.Jsonb{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + } + for i, tt := range unmarshalTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Jsonb + dst **string + expected *string + }{ + {src: pgtype.Jsonb{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst == tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} From 9b9361848dfe4f4b7bc17e4b1e2a403d6ab388e1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 19:53:02 -0600 Subject: [PATCH 0033/1158] Expand pgtype.Value interface - Include and rename ConvertFrom to Set - Add Get - Include AssignTo --- aclitem.go | 15 +++++++++++++-- aclitem_array.go | 17 ++++++++++++++--- aclitem_array_test.go | 4 ++-- aclitem_test.go | 4 ++-- bool.go | 15 +++++++++++++-- bool_array.go | 17 ++++++++++++++--- bool_array_test.go | 4 ++-- bool_test.go | 4 ++-- bytea.go | 15 +++++++++++++-- bytea_array.go | 17 ++++++++++++++--- bytea_array_test.go | 4 ++-- bytea_test.go | 4 ++-- cid.go | 12 ++++++++---- cid_test.go | 4 ++-- cidr_array.go | 8 ++++++-- date.go | 15 +++++++++++++-- date_array.go | 17 ++++++++++++++--- date_array_test.go | 4 ++-- date_test.go | 4 ++-- float4.go | 15 +++++++++++++-- float4_array.go | 17 ++++++++++++++--- float4_array_test.go | 4 ++-- float4_test.go | 4 ++-- float8.go | 15 +++++++++++++-- float8_array.go | 17 ++++++++++++++--- float8_array_test.go | 4 ++-- float8_test.go | 4 ++-- inet.go | 15 +++++++++++++-- inet_array.go | 19 +++++++++++++++---- inet_array_test.go | 4 ++-- inet_test.go | 4 ++-- int2.go | 15 +++++++++++++-- int2_array.go | 19 +++++++++++++++---- int2_array_test.go | 4 ++-- int2_test.go | 4 ++-- int4.go | 15 +++++++++++++-- int4_array.go | 19 +++++++++++++++---- int4_array_test.go | 4 ++-- int4_test.go | 4 ++-- int8.go | 15 +++++++++++++-- int8_array.go | 19 +++++++++++++++---- int8_array_test.go | 4 ++-- int8_test.go | 4 ++-- json.go | 18 +++++++++++++++++- json_test.go | 4 ++-- jsonb.go | 8 ++++++-- jsonb_test.go | 4 ++-- name.go | 8 ++++++-- name_test.go | 4 ++-- oid.go | 12 ++++++++---- oid_test.go | 4 ++-- pgtype.go | 13 ++++++++----- pguint32.go | 17 ++++++++++++++--- qchar.go | 15 +++++++++++++-- qchar_test.go | 4 ++-- text.go | 15 +++++++++++++-- text_array.go | 17 ++++++++++++++--- text_array_test.go | 4 ++-- text_test.go | 4 ++-- tid.go | 19 +++++++++++++++++++ timestamp.go | 20 +++++++++++++++++--- timestamp_array.go | 17 ++++++++++++++--- timestamp_array_test.go | 4 ++-- timestamp_test.go | 4 ++-- timestamptz.go | 18 ++++++++++++++++-- timestamptz_array.go | 17 ++++++++++++++--- timestamptz_array_test.go | 4 ++-- timestamptz_test.go | 4 ++-- typed_array.go.erb | 17 ++++++++++++++--- varchar_array.go | 8 ++++++-- xid.go | 12 ++++++++---- xid_test.go | 4 ++-- 72 files changed, 561 insertions(+), 170 deletions(-) diff --git a/aclitem.go b/aclitem.go index 821c5001..36cf3bbf 100644 --- a/aclitem.go +++ b/aclitem.go @@ -23,7 +23,7 @@ type Aclitem struct { Status Status } -func (dst *Aclitem) ConvertFrom(src interface{}) error { +func (dst *Aclitem) Set(src interface{}) error { switch value := src.(type) { case Aclitem: *dst = value @@ -37,7 +37,7 @@ func (dst *Aclitem) ConvertFrom(src interface{}) error { } default: if originalSrc, ok := underlyingStringType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Aclitem", value) } @@ -45,6 +45,17 @@ func (dst *Aclitem) ConvertFrom(src interface{}) error { return nil } +func (dst *Aclitem) Get() interface{} { + switch dst.Status { + case Present: + return dst.String + case Null: + return nil + default: + return dst.Status + } +} + func (src *Aclitem) AssignTo(dst interface{}) error { switch v := dst.(type) { case *string: diff --git a/aclitem_array.go b/aclitem_array.go index 48f5cd38..13952e5c 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -14,7 +14,7 @@ type AclitemArray struct { Status Status } -func (dst *AclitemArray) ConvertFrom(src interface{}) error { +func (dst *AclitemArray) Set(src interface{}) error { switch value := src.(type) { case AclitemArray: *dst = value @@ -27,7 +27,7 @@ func (dst *AclitemArray) ConvertFrom(src interface{}) error { } else { elements := make([]Aclitem, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -40,7 +40,7 @@ func (dst *AclitemArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Aclitem", value) } @@ -48,6 +48,17 @@ func (dst *AclitemArray) ConvertFrom(src interface{}) error { return nil } +func (dst *AclitemArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *AclitemArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/aclitem_array_test.go b/aclitem_array_test.go index e78f14c6..75c672bd 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -51,7 +51,7 @@ func TestAclitemArrayTranscode(t *testing.T) { }) } -func TestAclitemArrayConvertFrom(t *testing.T) { +func TestAclitemArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.AclitemArray @@ -71,7 +71,7 @@ func TestAclitemArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.AclitemArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/aclitem_test.go b/aclitem_test.go index fc429acc..47e6fa84 100644 --- a/aclitem_test.go +++ b/aclitem_test.go @@ -15,7 +15,7 @@ func TestAclitemTranscode(t *testing.T) { }) } -func TestAclitemConvertFrom(t *testing.T) { +func TestAclitemSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Aclitem @@ -27,7 +27,7 @@ func TestAclitemConvertFrom(t *testing.T) { for i, tt := range successfulTests { var d pgtype.Aclitem - err := d.ConvertFrom(tt.source) + err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/bool.go b/bool.go index 9764fafe..04a261c2 100644 --- a/bool.go +++ b/bool.go @@ -12,7 +12,7 @@ type Bool struct { Status Status } -func (dst *Bool) ConvertFrom(src interface{}) error { +func (dst *Bool) Set(src interface{}) error { switch value := src.(type) { case Bool: *dst = value @@ -26,7 +26,7 @@ func (dst *Bool) ConvertFrom(src interface{}) error { *dst = Bool{Bool: bb, Status: Present} default: if originalSrc, ok := underlyingBoolType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Bool", value) } @@ -34,6 +34,17 @@ func (dst *Bool) ConvertFrom(src interface{}) error { return nil } +func (dst *Bool) Get() interface{} { + switch dst.Status { + case Present: + return dst.Bool + case Null: + return nil + default: + return dst.Status + } +} + func (src *Bool) AssignTo(dst interface{}) error { switch v := dst.(type) { case *bool: diff --git a/bool_array.go b/bool_array.go index a74e9f90..fdcbf7a0 100644 --- a/bool_array.go +++ b/bool_array.go @@ -15,7 +15,7 @@ type BoolArray struct { Status Status } -func (dst *BoolArray) ConvertFrom(src interface{}) error { +func (dst *BoolArray) Set(src interface{}) error { switch value := src.(type) { case BoolArray: *dst = value @@ -28,7 +28,7 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { } else { elements := make([]Bool, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -41,7 +41,7 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Bool", value) } @@ -49,6 +49,17 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { return nil } +func (dst *BoolArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *BoolArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/bool_array_test.go b/bool_array_test.go index c5f15f97..a526d892 100644 --- a/bool_array_test.go +++ b/bool_array_test.go @@ -51,7 +51,7 @@ func TestBoolArrayTranscode(t *testing.T) { }) } -func TestBoolArrayConvertFrom(t *testing.T) { +func TestBoolArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.BoolArray @@ -71,7 +71,7 @@ func TestBoolArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.BoolArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/bool_test.go b/bool_test.go index 374f07da..773bd99b 100644 --- a/bool_test.go +++ b/bool_test.go @@ -15,7 +15,7 @@ func TestBoolTranscode(t *testing.T) { }) } -func TestBoolConvertFrom(t *testing.T) { +func TestBoolSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Bool @@ -33,7 +33,7 @@ func TestBoolConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Bool - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/bytea.go b/bytea.go index 709499d2..9d2e20f3 100644 --- a/bytea.go +++ b/bytea.go @@ -12,7 +12,7 @@ type Bytea struct { Status Status } -func (dst *Bytea) ConvertFrom(src interface{}) error { +func (dst *Bytea) Set(src interface{}) error { switch value := src.(type) { case Bytea: *dst = value @@ -24,7 +24,7 @@ func (dst *Bytea) ConvertFrom(src interface{}) error { } default: if originalSrc, ok := underlyingBytesType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Bytea", value) } @@ -32,6 +32,17 @@ func (dst *Bytea) ConvertFrom(src interface{}) error { return nil } +func (dst *Bytea) Get() interface{} { + switch dst.Status { + case Present: + return dst.Bytes + case Null: + return nil + default: + return dst.Status + } +} + func (src *Bytea) AssignTo(dst interface{}) error { switch v := dst.(type) { case *[]byte: diff --git a/bytea_array.go b/bytea_array.go index 9003eafd..5362944a 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -15,7 +15,7 @@ type ByteaArray struct { Status Status } -func (dst *ByteaArray) ConvertFrom(src interface{}) error { +func (dst *ByteaArray) Set(src interface{}) error { switch value := src.(type) { case ByteaArray: *dst = value @@ -28,7 +28,7 @@ func (dst *ByteaArray) ConvertFrom(src interface{}) error { } else { elements := make([]Bytea, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -41,7 +41,7 @@ func (dst *ByteaArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Bytea", value) } @@ -49,6 +49,17 @@ func (dst *ByteaArray) ConvertFrom(src interface{}) error { return nil } +func (dst *ByteaArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *ByteaArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/bytea_array_test.go b/bytea_array_test.go index b39776d9..22c6478b 100644 --- a/bytea_array_test.go +++ b/bytea_array_test.go @@ -51,7 +51,7 @@ func TestByteaArrayTranscode(t *testing.T) { }) } -func TestByteaArrayConvertFrom(t *testing.T) { +func TestByteaArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.ByteaArray @@ -71,7 +71,7 @@ func TestByteaArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.ByteaArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/bytea_test.go b/bytea_test.go index 51941387..4655a1c1 100644 --- a/bytea_test.go +++ b/bytea_test.go @@ -15,7 +15,7 @@ func TestByteaTranscode(t *testing.T) { }) } -func TestByteaConvertFrom(t *testing.T) { +func TestByteaSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Bytea @@ -30,7 +30,7 @@ func TestByteaConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Bytea - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/cid.go b/cid.go index be93a03e..20957f36 100644 --- a/cid.go +++ b/cid.go @@ -17,11 +17,15 @@ import ( // in the PostgreSQL sources. type Cid pguint32 -// ConvertFrom converts from src to dst. Note that as Cid is not a general -// number type ConvertFrom does not do automatic type conversion as other number +// Set converts from src to dst. Note that as Cid is not a general +// number type Set does not do automatic type conversion as other number // types do. -func (dst *Cid) ConvertFrom(src interface{}) error { - return (*pguint32)(dst).ConvertFrom(src) +func (dst *Cid) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst *Cid) Get() interface{} { + return (*pguint32)(dst).Get() } // AssignTo assigns from src to dst. Note that as Cid is not a general number diff --git a/cid_test.go b/cid_test.go index 7d9fde34..0d114cda 100644 --- a/cid_test.go +++ b/cid_test.go @@ -14,7 +14,7 @@ func TestCidTranscode(t *testing.T) { }) } -func TestCidConvertFrom(t *testing.T) { +func TestCidSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Cid @@ -24,7 +24,7 @@ func TestCidConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Cid - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/cidr_array.go b/cidr_array.go index e0219ee5..c30c53d3 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -6,8 +6,12 @@ import ( type CidrArray InetArray -func (dst *CidrArray) ConvertFrom(src interface{}) error { - return (*InetArray)(dst).ConvertFrom(src) +func (dst *CidrArray) Set(src interface{}) error { + return (*InetArray)(dst).Set(src) +} + +func (dst *CidrArray) Get() interface{} { + return (*InetArray)(dst).Get() } func (src *CidrArray) AssignTo(dst interface{}) error { diff --git a/date.go b/date.go index b0d16e64..a3b8d99f 100644 --- a/date.go +++ b/date.go @@ -21,7 +21,7 @@ const ( infinityDayOffset = 2147483647 ) -func (dst *Date) ConvertFrom(src interface{}) error { +func (dst *Date) Set(src interface{}) error { switch value := src.(type) { case Date: *dst = value @@ -29,7 +29,7 @@ func (dst *Date) ConvertFrom(src interface{}) error { *dst = Date{Time: value, Status: Present} default: if originalSrc, ok := underlyingTimeType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Date", value) } @@ -37,6 +37,17 @@ func (dst *Date) ConvertFrom(src interface{}) error { return nil } +func (dst *Date) Get() interface{} { + switch dst.Status { + case Present: + return dst.Time + case Null: + return nil + default: + return dst.Status + } +} + func (src *Date) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: diff --git a/date_array.go b/date_array.go index 8f7cba18..ce28e236 100644 --- a/date_array.go +++ b/date_array.go @@ -16,7 +16,7 @@ type DateArray struct { Status Status } -func (dst *DateArray) ConvertFrom(src interface{}) error { +func (dst *DateArray) Set(src interface{}) error { switch value := src.(type) { case DateArray: *dst = value @@ -29,7 +29,7 @@ func (dst *DateArray) ConvertFrom(src interface{}) error { } else { elements := make([]Date, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -42,7 +42,7 @@ func (dst *DateArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Date", value) } @@ -50,6 +50,17 @@ func (dst *DateArray) ConvertFrom(src interface{}) error { return nil } +func (dst *DateArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *DateArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/date_array_test.go b/date_array_test.go index 60f15983..a05f4254 100644 --- a/date_array_test.go +++ b/date_array_test.go @@ -52,7 +52,7 @@ func TestDateArrayTranscode(t *testing.T) { }) } -func TestDateArrayConvertFrom(t *testing.T) { +func TestDateArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.DateArray @@ -72,7 +72,7 @@ func TestDateArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.DateArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/date_test.go b/date_test.go index 3a473b6a..eff3a521 100644 --- a/date_test.go +++ b/date_test.go @@ -22,7 +22,7 @@ func TestDateTranscode(t *testing.T) { }) } -func TestDateConvertFrom(t *testing.T) { +func TestDateSet(t *testing.T) { type _time time.Time successfulTests := []struct { @@ -41,7 +41,7 @@ func TestDateConvertFrom(t *testing.T) { for i, tt := range successfulTests { var d pgtype.Date - err := d.ConvertFrom(tt.source) + err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/float4.go b/float4.go index 26609ab2..a38d24db 100644 --- a/float4.go +++ b/float4.go @@ -15,7 +15,7 @@ type Float4 struct { Status Status } -func (dst *Float4) ConvertFrom(src interface{}) error { +func (dst *Float4) Set(src interface{}) error { switch value := src.(type) { case Float4: *dst = value @@ -81,7 +81,7 @@ func (dst *Float4) ConvertFrom(src interface{}) error { *dst = Float4{Float: float32(num), Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Float8", value) } @@ -89,6 +89,17 @@ func (dst *Float4) ConvertFrom(src interface{}) error { return nil } +func (dst *Float4) Get() interface{} { + switch dst.Status { + case Present: + return dst.Float + case Null: + return nil + default: + return dst.Status + } +} + func (src *Float4) AssignTo(dst interface{}) error { return float64AssignTo(float64(src.Float), src.Status, dst) } diff --git a/float4_array.go b/float4_array.go index 632e7e4b..410a8b37 100644 --- a/float4_array.go +++ b/float4_array.go @@ -15,7 +15,7 @@ type Float4Array struct { Status Status } -func (dst *Float4Array) ConvertFrom(src interface{}) error { +func (dst *Float4Array) Set(src interface{}) error { switch value := src.(type) { case Float4Array: *dst = value @@ -28,7 +28,7 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { } else { elements := make([]Float4, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -41,7 +41,7 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Float4", value) } @@ -49,6 +49,17 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { return nil } +func (dst *Float4Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *Float4Array) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/float4_array_test.go b/float4_array_test.go index b22f4fbc..06a1d2e0 100644 --- a/float4_array_test.go +++ b/float4_array_test.go @@ -51,7 +51,7 @@ func TestFloat4ArrayTranscode(t *testing.T) { }) } -func TestFloat4ArrayConvertFrom(t *testing.T) { +func TestFloat4ArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Float4Array @@ -71,7 +71,7 @@ func TestFloat4ArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Float4Array - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/float4_test.go b/float4_test.go index 62420b8d..ea60cd3a 100644 --- a/float4_test.go +++ b/float4_test.go @@ -18,7 +18,7 @@ func TestFloat4Transcode(t *testing.T) { }) } -func TestFloat4ConvertFrom(t *testing.T) { +func TestFloat4Set(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Float4 @@ -43,7 +43,7 @@ func TestFloat4ConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Float4 - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/float8.go b/float8.go index 9ec9a665..9129e8ba 100644 --- a/float8.go +++ b/float8.go @@ -15,7 +15,7 @@ type Float8 struct { Status Status } -func (dst *Float8) ConvertFrom(src interface{}) error { +func (dst *Float8) Set(src interface{}) error { switch value := src.(type) { case Float8: *dst = value @@ -71,7 +71,7 @@ func (dst *Float8) ConvertFrom(src interface{}) error { *dst = Float8{Float: float64(num), Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Float8", value) } @@ -79,6 +79,17 @@ func (dst *Float8) ConvertFrom(src interface{}) error { return nil } +func (dst *Float8) Get() interface{} { + switch dst.Status { + case Present: + return dst.Float + case Null: + return nil + default: + return dst.Status + } +} + func (src *Float8) AssignTo(dst interface{}) error { return float64AssignTo(src.Float, src.Status, dst) } diff --git a/float8_array.go b/float8_array.go index 68cf30f2..b2f70f51 100644 --- a/float8_array.go +++ b/float8_array.go @@ -15,7 +15,7 @@ type Float8Array struct { Status Status } -func (dst *Float8Array) ConvertFrom(src interface{}) error { +func (dst *Float8Array) Set(src interface{}) error { switch value := src.(type) { case Float8Array: *dst = value @@ -28,7 +28,7 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { } else { elements := make([]Float8, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -41,7 +41,7 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Float8", value) } @@ -49,6 +49,17 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { return nil } +func (dst *Float8Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *Float8Array) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/float8_array_test.go b/float8_array_test.go index d4402281..635e249a 100644 --- a/float8_array_test.go +++ b/float8_array_test.go @@ -51,7 +51,7 @@ func TestFloat8ArrayTranscode(t *testing.T) { }) } -func TestFloat8ArrayConvertFrom(t *testing.T) { +func TestFloat8ArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Float8Array @@ -71,7 +71,7 @@ func TestFloat8ArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Float8Array - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/float8_test.go b/float8_test.go index 748ffd25..724e9350 100644 --- a/float8_test.go +++ b/float8_test.go @@ -18,7 +18,7 @@ func TestFloat8Transcode(t *testing.T) { }) } -func TestFloat8ConvertFrom(t *testing.T) { +func TestFloat8Set(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Float8 @@ -43,7 +43,7 @@ func TestFloat8ConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Float8 - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/inet.go b/inet.go index f94622f4..00bfb30c 100644 --- a/inet.go +++ b/inet.go @@ -23,7 +23,7 @@ type Inet struct { Status Status } -func (dst *Inet) ConvertFrom(src interface{}) error { +func (dst *Inet) Set(src interface{}) error { switch value := src.(type) { case Inet: *dst = value @@ -43,7 +43,7 @@ func (dst *Inet) ConvertFrom(src interface{}) error { *dst = Inet{IPNet: ipnet, Status: Present} default: if originalSrc, ok := underlyingPtrType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Inet", value) } @@ -51,6 +51,17 @@ func (dst *Inet) ConvertFrom(src interface{}) error { return nil } +func (dst *Inet) Get() interface{} { + switch dst.Status { + case Present: + return dst.IPNet + case Null: + return nil + default: + return dst.Status + } +} + func (src *Inet) AssignTo(dst interface{}) error { switch v := dst.(type) { case *net.IPNet: diff --git a/inet_array.go b/inet_array.go index 629cd51f..4d865b4f 100644 --- a/inet_array.go +++ b/inet_array.go @@ -16,7 +16,7 @@ type InetArray struct { Status Status } -func (dst *InetArray) ConvertFrom(src interface{}) error { +func (dst *InetArray) Set(src interface{}) error { switch value := src.(type) { case InetArray: *dst = value @@ -29,7 +29,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { } else { elements := make([]Inet, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -48,7 +48,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { } else { elements := make([]Inet, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -61,7 +61,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Inet", value) } @@ -69,6 +69,17 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { return nil } +func (dst *InetArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *InetArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/inet_array_test.go b/inet_array_test.go index 523a9f8d..fe22285d 100644 --- a/inet_array_test.go +++ b/inet_array_test.go @@ -52,7 +52,7 @@ func TestInetArrayTranscode(t *testing.T) { }) } -func TestInetArrayConvertFrom(t *testing.T) { +func TestInetArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.InetArray @@ -83,7 +83,7 @@ func TestInetArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.InetArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/inet_test.go b/inet_test.go index 5a326810..90b0723f 100644 --- a/inet_test.go +++ b/inet_test.go @@ -26,7 +26,7 @@ func TestInetTranscode(t *testing.T) { } } -func TestInetConvertFrom(t *testing.T) { +func TestInetSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Inet @@ -39,7 +39,7 @@ func TestInetConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Inet - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/int2.go b/int2.go index 7bdbacfe..525427c5 100644 --- a/int2.go +++ b/int2.go @@ -15,7 +15,7 @@ type Int2 struct { Status Status } -func (dst *Int2) ConvertFrom(src interface{}) error { +func (dst *Int2) Set(src interface{}) error { switch value := src.(type) { case Int2: *dst = value @@ -77,7 +77,7 @@ func (dst *Int2) ConvertFrom(src interface{}) error { *dst = Int2{Int: int16(num), Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Int2", value) } @@ -85,6 +85,17 @@ func (dst *Int2) ConvertFrom(src interface{}) error { return nil } +func (dst *Int2) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + func (src *Int2) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } diff --git a/int2_array.go b/int2_array.go index d8268c0a..28792fa5 100644 --- a/int2_array.go +++ b/int2_array.go @@ -15,7 +15,7 @@ type Int2Array struct { Status Status } -func (dst *Int2Array) ConvertFrom(src interface{}) error { +func (dst *Int2Array) Set(src interface{}) error { switch value := src.(type) { case Int2Array: *dst = value @@ -28,7 +28,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { } else { elements := make([]Int2, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -47,7 +47,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { } else { elements := make([]Int2, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -60,7 +60,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Int2", value) } @@ -68,6 +68,17 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { return nil } +func (dst *Int2Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *Int2Array) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/int2_array_test.go b/int2_array_test.go index ced0eab4..8af4523d 100644 --- a/int2_array_test.go +++ b/int2_array_test.go @@ -51,7 +51,7 @@ func TestInt2ArrayTranscode(t *testing.T) { }) } -func TestInt2ArrayConvertFrom(t *testing.T) { +func TestInt2ArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Int2Array @@ -78,7 +78,7 @@ func TestInt2ArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Int2Array - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/int2_test.go b/int2_test.go index 8601309d..2bd8e016 100644 --- a/int2_test.go +++ b/int2_test.go @@ -19,7 +19,7 @@ func TestInt2Transcode(t *testing.T) { }) } -func TestInt2ConvertFrom(t *testing.T) { +func TestInt2Set(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Int2 @@ -42,7 +42,7 @@ func TestInt2ConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Int2 - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/int4.go b/int4.go index 2d96ea48..b3203a28 100644 --- a/int4.go +++ b/int4.go @@ -15,7 +15,7 @@ type Int4 struct { Status Status } -func (dst *Int4) ConvertFrom(src interface{}) error { +func (dst *Int4) Set(src interface{}) error { switch value := src.(type) { case Int4: *dst = value @@ -68,7 +68,7 @@ func (dst *Int4) ConvertFrom(src interface{}) error { *dst = Int4{Int: int32(num), Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) } @@ -76,6 +76,17 @@ func (dst *Int4) ConvertFrom(src interface{}) error { return nil } +func (dst *Int4) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + func (src *Int4) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } diff --git a/int4_array.go b/int4_array.go index dcdb50c1..61cedb2e 100644 --- a/int4_array.go +++ b/int4_array.go @@ -15,7 +15,7 @@ type Int4Array struct { Status Status } -func (dst *Int4Array) ConvertFrom(src interface{}) error { +func (dst *Int4Array) Set(src interface{}) error { switch value := src.(type) { case Int4Array: *dst = value @@ -28,7 +28,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { } else { elements := make([]Int4, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -47,7 +47,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { } else { elements := make([]Int4, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -60,7 +60,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Int4", value) } @@ -68,6 +68,17 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { return nil } +func (dst *Int4Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *Int4Array) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/int4_array_test.go b/int4_array_test.go index 38ba27cb..111cb56b 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -51,7 +51,7 @@ func TestInt4ArrayTranscode(t *testing.T) { }) } -func TestInt4ArrayConvertFrom(t *testing.T) { +func TestInt4ArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Int4Array @@ -78,7 +78,7 @@ func TestInt4ArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Int4Array - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/int4_test.go b/int4_test.go index 0ac2e5b5..3e000182 100644 --- a/int4_test.go +++ b/int4_test.go @@ -19,7 +19,7 @@ func TestInt4Transcode(t *testing.T) { }) } -func TestInt4ConvertFrom(t *testing.T) { +func TestInt4Set(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Int4 @@ -42,7 +42,7 @@ func TestInt4ConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Int4 - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/int8.go b/int8.go index 91f5b877..15ad6715 100644 --- a/int8.go +++ b/int8.go @@ -15,7 +15,7 @@ type Int8 struct { Status Status } -func (dst *Int8) ConvertFrom(src interface{}) error { +func (dst *Int8) Set(src interface{}) error { switch value := src.(type) { case Int8: *dst = value @@ -59,7 +59,7 @@ func (dst *Int8) ConvertFrom(src interface{}) error { *dst = Int8{Int: num, Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) } @@ -67,6 +67,17 @@ func (dst *Int8) ConvertFrom(src interface{}) error { return nil } +func (dst *Int8) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + func (src *Int8) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } diff --git a/int8_array.go b/int8_array.go index ed82f079..9f4373e8 100644 --- a/int8_array.go +++ b/int8_array.go @@ -15,7 +15,7 @@ type Int8Array struct { Status Status } -func (dst *Int8Array) ConvertFrom(src interface{}) error { +func (dst *Int8Array) Set(src interface{}) error { switch value := src.(type) { case Int8Array: *dst = value @@ -28,7 +28,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { } else { elements := make([]Int8, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -47,7 +47,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { } else { elements := make([]Int8, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -60,7 +60,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) } @@ -68,6 +68,17 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { return nil } +func (dst *Int8Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *Int8Array) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/int8_array_test.go b/int8_array_test.go index 137768c6..349a1f7e 100644 --- a/int8_array_test.go +++ b/int8_array_test.go @@ -51,7 +51,7 @@ func TestInt8ArrayTranscode(t *testing.T) { }) } -func TestInt8ArrayConvertFrom(t *testing.T) { +func TestInt8ArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Int8Array @@ -78,7 +78,7 @@ func TestInt8ArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Int8Array - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/int8_test.go b/int8_test.go index 15762a50..e1fe69fb 100644 --- a/int8_test.go +++ b/int8_test.go @@ -19,7 +19,7 @@ func TestInt8Transcode(t *testing.T) { }) } -func TestInt8ConvertFrom(t *testing.T) { +func TestInt8Set(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Int8 @@ -42,7 +42,7 @@ func TestInt8ConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Int8 - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/json.go b/json.go index 8a258ea4..ecdb3dab 100644 --- a/json.go +++ b/json.go @@ -10,7 +10,7 @@ type Json struct { Status Status } -func (dst *Json) ConvertFrom(src interface{}) error { +func (dst *Json) Set(src interface{}) error { switch value := src.(type) { case string: *dst = Json{Bytes: []byte(value), Status: Present} @@ -37,6 +37,22 @@ func (dst *Json) ConvertFrom(src interface{}) error { return nil } +func (dst *Json) Get() interface{} { + switch dst.Status { + case Present: + var i interface{} + err := json.Unmarshal(dst.Bytes, &i) + if err != nil { + return dst + } + return i + case Null: + return nil + default: + return dst.Status + } +} + func (src *Json) AssignTo(dst interface{}) error { switch v := dst.(type) { case *string: diff --git a/json_test.go b/json_test.go index 87770f31..b0aa8c9b 100644 --- a/json_test.go +++ b/json_test.go @@ -18,7 +18,7 @@ func TestJsonTranscode(t *testing.T) { }) } -func TestJsonConvertFrom(t *testing.T) { +func TestJsonSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Json @@ -33,7 +33,7 @@ func TestJsonConvertFrom(t *testing.T) { for i, tt := range successfulTests { var d pgtype.Json - err := d.ConvertFrom(tt.source) + err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/jsonb.go b/jsonb.go index 0739a468..13062e8e 100644 --- a/jsonb.go +++ b/jsonb.go @@ -7,8 +7,12 @@ import ( type Jsonb Json -func (dst *Jsonb) ConvertFrom(src interface{}) error { - return (*Json)(dst).ConvertFrom(src) +func (dst *Jsonb) Set(src interface{}) error { + return (*Json)(dst).Set(src) +} + +func (dst *Jsonb) Get() interface{} { + return (*Json)(dst).Get() } func (src *Jsonb) AssignTo(dst interface{}) error { diff --git a/jsonb_test.go b/jsonb_test.go index e42931d5..3978b0d4 100644 --- a/jsonb_test.go +++ b/jsonb_test.go @@ -18,7 +18,7 @@ func TestJsonbTranscode(t *testing.T) { }) } -func TestJsonbConvertFrom(t *testing.T) { +func TestJsonbSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Jsonb @@ -33,7 +33,7 @@ func TestJsonbConvertFrom(t *testing.T) { for i, tt := range successfulTests { var d pgtype.Jsonb - err := d.ConvertFrom(tt.source) + err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/name.go b/name.go index 513abfc7..9eb12ece 100644 --- a/name.go +++ b/name.go @@ -19,8 +19,12 @@ import ( // bytes applies, rather than the default 63. type Name Text -func (dst *Name) ConvertFrom(src interface{}) error { - return (*Text)(dst).ConvertFrom(src) +func (dst *Name) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Name) Get() interface{} { + return (*Text)(dst).Get() } func (src *Name) AssignTo(dst interface{}) error { diff --git a/name_test.go b/name_test.go index c5f7de17..81a766b8 100644 --- a/name_test.go +++ b/name_test.go @@ -15,7 +15,7 @@ func TestNameTranscode(t *testing.T) { }) } -func TestNameConvertFrom(t *testing.T) { +func TestNameSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Name @@ -27,7 +27,7 @@ func TestNameConvertFrom(t *testing.T) { for i, tt := range successfulTests { var d pgtype.Name - err := d.ConvertFrom(tt.source) + err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/oid.go b/oid.go index c77f3f10..e57bb2e6 100644 --- a/oid.go +++ b/oid.go @@ -11,11 +11,15 @@ import ( // found in src/include/postgres_ext.h in the PostgreSQL sources. type Oid pguint32 -// ConvertFrom converts from src to dst. Note that as Oid is not a general -// number type ConvertFrom does not do automatic type conversion as other number +// Set converts from src to dst. Note that as Oid is not a general +// number type Set does not do automatic type conversion as other number // types do. -func (dst *Oid) ConvertFrom(src interface{}) error { - return (*pguint32)(dst).ConvertFrom(src) +func (dst *Oid) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst *Oid) Get() interface{} { + return (*pguint32)(dst).Get() } // AssignTo assigns from src to dst. Note that as Oid is not a general number diff --git a/oid_test.go b/oid_test.go index bbab6699..b3b96959 100644 --- a/oid_test.go +++ b/oid_test.go @@ -14,7 +14,7 @@ func TestOidTranscode(t *testing.T) { }) } -func TestOidConvertFrom(t *testing.T) { +func TestOidSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Oid @@ -24,7 +24,7 @@ func TestOidConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Oid - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype.go b/pgtype.go index cbcd6bd5..5a51172e 100644 --- a/pgtype.go +++ b/pgtype.go @@ -66,13 +66,16 @@ const ( NegativeInfinity InfinityModifier = -Infinity ) -type Value interface{} +type Value interface { + // Set converts and assigns src to itself. + Set(src interface{}) error -type ConverterFrom interface { - ConvertFrom(src interface{}) error -} + // Get returns the simplest representation of Value. If the Value is Null or + // Undefined that is the return value. If no simpler representation is + // possible, then Get() returns Value. + Get() interface{} -type AssignerTo interface { + // AssignTo converts and assigns the Value to dst. AssignTo(dst interface{}) error } diff --git a/pguint32.go b/pguint32.go index c636e1c4..05c79c0e 100644 --- a/pguint32.go +++ b/pguint32.go @@ -16,10 +16,10 @@ type pguint32 struct { Status Status } -// ConvertFrom converts from src to dst. Note that as pguint32 is not a general -// number type ConvertFrom does not do automatic type conversion as other number +// Set converts from src to dst. Note that as pguint32 is not a general +// number type Set does not do automatic type conversion as other number // types do. -func (dst *pguint32) ConvertFrom(src interface{}) error { +func (dst *pguint32) Set(src interface{}) error { switch value := src.(type) { case uint32: *dst = pguint32{Uint: value, Status: Present} @@ -30,6 +30,17 @@ func (dst *pguint32) ConvertFrom(src interface{}) error { return nil } +func (dst *pguint32) Get() interface{} { + switch dst.Status { + case Present: + return dst.Uint + case Null: + return nil + default: + return dst.Status + } +} + // AssignTo assigns from src to dst. Note that as pguint32 is not a general number // type AssignTo does not do automatic type conversion as other number types do. func (src *pguint32) AssignTo(dst interface{}) error { diff --git a/qchar.go b/qchar.go index 0da1e88b..b6392cf9 100644 --- a/qchar.go +++ b/qchar.go @@ -23,7 +23,7 @@ type QChar struct { Status Status } -func (dst *QChar) ConvertFrom(src interface{}) error { +func (dst *QChar) Set(src interface{}) error { switch value := src.(type) { case QChar: *dst = value @@ -94,7 +94,7 @@ func (dst *QChar) ConvertFrom(src interface{}) error { *dst = QChar{Int: int8(num), Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to QChar", value) } @@ -102,6 +102,17 @@ func (dst *QChar) ConvertFrom(src interface{}) error { return nil } +func (dst *QChar) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + func (src *QChar) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } diff --git a/qchar_test.go b/qchar_test.go index ea7b56a8..a1b6d22e 100644 --- a/qchar_test.go +++ b/qchar_test.go @@ -19,7 +19,7 @@ func TestQCharTranscode(t *testing.T) { }) } -func TestQCharConvertFrom(t *testing.T) { +func TestQCharSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.QChar @@ -42,7 +42,7 @@ func TestQCharConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.QChar - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/text.go b/text.go index baf62d1e..50db2349 100644 --- a/text.go +++ b/text.go @@ -11,7 +11,7 @@ type Text struct { Status Status } -func (dst *Text) ConvertFrom(src interface{}) error { +func (dst *Text) Set(src interface{}) error { switch value := src.(type) { case Text: *dst = value @@ -25,7 +25,7 @@ func (dst *Text) ConvertFrom(src interface{}) error { } default: if originalSrc, ok := underlyingStringType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Text", value) } @@ -33,6 +33,17 @@ func (dst *Text) ConvertFrom(src interface{}) error { return nil } +func (dst *Text) Get() interface{} { + switch dst.Status { + case Present: + return dst.String + case Null: + return nil + default: + return dst.Status + } +} + func (src *Text) AssignTo(dst interface{}) error { switch v := dst.(type) { case *string: diff --git a/text_array.go b/text_array.go index 06e3c0df..3a5a64ce 100644 --- a/text_array.go +++ b/text_array.go @@ -15,7 +15,7 @@ type TextArray struct { Status Status } -func (dst *TextArray) ConvertFrom(src interface{}) error { +func (dst *TextArray) Set(src interface{}) error { switch value := src.(type) { case TextArray: *dst = value @@ -28,7 +28,7 @@ func (dst *TextArray) ConvertFrom(src interface{}) error { } else { elements := make([]Text, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -41,7 +41,7 @@ func (dst *TextArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Text", value) } @@ -49,6 +49,17 @@ func (dst *TextArray) ConvertFrom(src interface{}) error { return nil } +func (dst *TextArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *TextArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/text_array_test.go b/text_array_test.go index a22e003d..5a78d7bc 100644 --- a/text_array_test.go +++ b/text_array_test.go @@ -51,7 +51,7 @@ func TestTextArrayTranscode(t *testing.T) { }) } -func TestTextArrayConvertFrom(t *testing.T) { +func TestTextArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.TextArray @@ -71,7 +71,7 @@ func TestTextArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.TextArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/text_test.go b/text_test.go index 6e944857..f5e20055 100644 --- a/text_test.go +++ b/text_test.go @@ -17,7 +17,7 @@ func TestTextTranscode(t *testing.T) { } } -func TestTextConvertFrom(t *testing.T) { +func TestTextSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Text @@ -30,7 +30,7 @@ func TestTextConvertFrom(t *testing.T) { for i, tt := range successfulTests { var d pgtype.Text - err := d.ConvertFrom(tt.source) + err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/tid.go b/tid.go index b67892ff..20d962df 100644 --- a/tid.go +++ b/tid.go @@ -27,6 +27,25 @@ type Tid struct { Status Status } +func (dst *Tid) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Tid", src) +} + +func (dst *Tid) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Tid) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + func (dst *Tid) DecodeText(src []byte) error { if src == nil { *dst = Tid{Status: Null} diff --git a/timestamp.go b/timestamp.go index a8b628e9..a84f3881 100644 --- a/timestamp.go +++ b/timestamp.go @@ -23,9 +23,9 @@ type Timestamp struct { InfinityModifier } -// ConvertFrom converts src into a Timestamp and stores in dst. If src is a +// Set converts src into a Timestamp and stores in dst. If src is a // time.Time in a non-UTC time zone, the time zone is discarded. -func (dst *Timestamp) ConvertFrom(src interface{}) error { +func (dst *Timestamp) Set(src interface{}) error { switch value := src.(type) { case Timestamp: *dst = value @@ -33,7 +33,7 @@ func (dst *Timestamp) ConvertFrom(src interface{}) error { *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} default: if originalSrc, ok := underlyingTimeType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Timestamp", value) } @@ -41,6 +41,20 @@ func (dst *Timestamp) ConvertFrom(src interface{}) error { return nil } +func (dst *Timestamp) Get() interface{} { + switch dst.Status { + case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time + case Null: + return nil + default: + return dst.Status + } +} + func (src *Timestamp) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: diff --git a/timestamp_array.go b/timestamp_array.go index 1ea30ba4..ec0facb2 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -16,7 +16,7 @@ type TimestampArray struct { Status Status } -func (dst *TimestampArray) ConvertFrom(src interface{}) error { +func (dst *TimestampArray) Set(src interface{}) error { switch value := src.(type) { case TimestampArray: *dst = value @@ -29,7 +29,7 @@ func (dst *TimestampArray) ConvertFrom(src interface{}) error { } else { elements := make([]Timestamp, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -42,7 +42,7 @@ func (dst *TimestampArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Timestamp", value) } @@ -50,6 +50,17 @@ func (dst *TimestampArray) ConvertFrom(src interface{}) error { return nil } +func (dst *TimestampArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *TimestampArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/timestamp_array_test.go b/timestamp_array_test.go index 68189cc7..a15d3696 100644 --- a/timestamp_array_test.go +++ b/timestamp_array_test.go @@ -68,7 +68,7 @@ func TestTimestampArrayTranscode(t *testing.T) { }) } -func TestTimestampArrayConvertFrom(t *testing.T) { +func TestTimestampArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.TimestampArray @@ -88,7 +88,7 @@ func TestTimestampArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.TimestampArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/timestamp_test.go b/timestamp_test.go index 6d6e738c..7297ed1f 100644 --- a/timestamp_test.go +++ b/timestamp_test.go @@ -31,7 +31,7 @@ func TestTimestampTranscode(t *testing.T) { }) } -func TestTimestampConvertFrom(t *testing.T) { +func TestTimestampSet(t *testing.T) { type _time time.Time successfulTests := []struct { @@ -51,7 +51,7 @@ func TestTimestampConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Timestamp - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/timestamptz.go b/timestamptz.go index f4c67b0b..a6922d5b 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -26,7 +26,7 @@ type Timestamptz struct { InfinityModifier } -func (dst *Timestamptz) ConvertFrom(src interface{}) error { +func (dst *Timestamptz) Set(src interface{}) error { switch value := src.(type) { case Timestamptz: *dst = value @@ -34,7 +34,7 @@ func (dst *Timestamptz) ConvertFrom(src interface{}) error { *dst = Timestamptz{Time: value, Status: Present} default: if originalSrc, ok := underlyingTimeType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Timestamptz", value) } @@ -42,6 +42,20 @@ func (dst *Timestamptz) ConvertFrom(src interface{}) error { return nil } +func (dst *Timestamptz) Get() interface{} { + switch dst.Status { + case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time + case Null: + return nil + default: + return dst.Status + } +} + func (src *Timestamptz) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: diff --git a/timestamptz_array.go b/timestamptz_array.go index fc3ce08c..775ec970 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -16,7 +16,7 @@ type TimestamptzArray struct { Status Status } -func (dst *TimestamptzArray) ConvertFrom(src interface{}) error { +func (dst *TimestamptzArray) Set(src interface{}) error { switch value := src.(type) { case TimestamptzArray: *dst = value @@ -29,7 +29,7 @@ func (dst *TimestamptzArray) ConvertFrom(src interface{}) error { } else { elements := make([]Timestamptz, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -42,7 +42,7 @@ func (dst *TimestamptzArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Timestamptz", value) } @@ -50,6 +50,17 @@ func (dst *TimestamptzArray) ConvertFrom(src interface{}) error { return nil } +func (dst *TimestamptzArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *TimestamptzArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/timestamptz_array_test.go b/timestamptz_array_test.go index af2c004b..e0017828 100644 --- a/timestamptz_array_test.go +++ b/timestamptz_array_test.go @@ -68,7 +68,7 @@ func TestTimestamptzArrayTranscode(t *testing.T) { }) } -func TestTimestamptzArrayConvertFrom(t *testing.T) { +func TestTimestamptzArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.TimestamptzArray @@ -88,7 +88,7 @@ func TestTimestamptzArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.TimestamptzArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/timestamptz_test.go b/timestamptz_test.go index 8f80ca81..242cd05f 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -31,7 +31,7 @@ func TestTimestamptzTranscode(t *testing.T) { }) } -func TestTimestamptzConvertFrom(t *testing.T) { +func TestTimestamptzSet(t *testing.T) { type _time time.Time successfulTests := []struct { @@ -50,7 +50,7 @@ func TestTimestamptzConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Timestamptz - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/typed_array.go.erb b/typed_array.go.erb index 98c8d845..c62e2896 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -14,7 +14,7 @@ type <%= pgtype_array_type %> struct { Status Status } -func (dst *<%= pgtype_array_type %>) ConvertFrom(src interface{}) error { +func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { switch value := src.(type) { case <%= pgtype_array_type %>: *dst = value @@ -27,7 +27,7 @@ func (dst *<%= pgtype_array_type %>) ConvertFrom(src interface{}) error { } else { elements := make([]<%= pgtype_element_type %>, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -40,7 +40,7 @@ func (dst *<%= pgtype_array_type %>) ConvertFrom(src interface{}) error { <% end %> default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to <%= pgtype_element_type %>", value) } @@ -48,6 +48,17 @@ func (dst *<%= pgtype_array_type %>) ConvertFrom(src interface{}) error { return nil } +func (dst *<%= pgtype_array_type %>) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { switch v := dst.(type) { <% go_array_types.split(",").each do |t| %> diff --git a/varchar_array.go b/varchar_array.go index b9d87b7f..693b9a61 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -6,8 +6,12 @@ import ( type VarcharArray TextArray -func (dst *VarcharArray) ConvertFrom(src interface{}) error { - return (*TextArray)(dst).ConvertFrom(src) +func (dst *VarcharArray) Set(src interface{}) error { + return (*TextArray)(dst).Set(src) +} + +func (dst *VarcharArray) Get() interface{} { + return (*TextArray)(dst).Get() } func (src *VarcharArray) AssignTo(dst interface{}) error { diff --git a/xid.go b/xid.go index 7deaa4f0..a53120de 100644 --- a/xid.go +++ b/xid.go @@ -20,11 +20,15 @@ import ( // in the PostgreSQL sources. type Xid pguint32 -// ConvertFrom converts from src to dst. Note that as Xid is not a general -// number type ConvertFrom does not do automatic type conversion as other number +// Set converts from src to dst. Note that as Xid is not a general +// number type Set does not do automatic type conversion as other number // types do. -func (dst *Xid) ConvertFrom(src interface{}) error { - return (*pguint32)(dst).ConvertFrom(src) +func (dst *Xid) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst *Xid) Get() interface{} { + return (*pguint32)(dst).Get() } // AssignTo assigns from src to dst. Note that as Xid is not a general number diff --git a/xid_test.go b/xid_test.go index a5c5df51..fecfb64b 100644 --- a/xid_test.go +++ b/xid_test.go @@ -14,7 +14,7 @@ func TestXidTranscode(t *testing.T) { }) } -func TestXidConvertFrom(t *testing.T) { +func TestXidSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Xid @@ -24,7 +24,7 @@ func TestXidConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Xid - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } From b94ccae4c9a6b3086c46d15d6253cb083c26fc3d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 20:12:47 -0600 Subject: [PATCH 0034/1158] Document that Decode* must not keep src - Also fix Bytea.DecodeBinary to not keep src. --- bytea.go | 5 ++++- pgtype.go | 6 ++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/bytea.go b/bytea.go index 9d2e20f3..a8ee55ae 100644 --- a/bytea.go +++ b/bytea.go @@ -106,7 +106,10 @@ func (dst *Bytea) DecodeBinary(src []byte) error { return nil } - *dst = Bytea{Bytes: src, Status: Present} + buf := make([]byte, len(src)) + copy(buf, src) + + *dst = Bytea{Bytes: buf, Status: Present} return nil } diff --git a/pgtype.go b/pgtype.go index 5a51172e..7b1470b7 100644 --- a/pgtype.go +++ b/pgtype.go @@ -80,10 +80,16 @@ type Value interface { } type BinaryDecoder interface { + // DecodeBinary decodes src into BinaryDecoder. If src is nil then the + // original SQL value is NULL. BinaryDecoder MUST not retain a reference to + // src. It MUST make a copy if it needs to retain the raw bytes. DecodeBinary(src []byte) error } type TextDecoder interface { + // DecodeText decodes src into TextDecoder. If src is nil then the original + // SQL value is NULL. TextDecoder MUST not retain a reference to src. It MUST + // make a copy if it needs to retain the raw bytes. DecodeText(src []byte) error } From a79b498533c17c032bc10da1dc631ae78be30c20 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 20:18:56 -0600 Subject: [PATCH 0035/1158] Remove Set self support from pgtype Set having the capability to assign an object of the same type was inconsistently implemented. Some places it was not implemented at all, some places it was a shallow copy, some places a deep copy. Given that it doesn't seem likely to ever be used, and if it is needed it is easy enough to do outside of the library this code has been removed. --- aclitem.go | 2 -- aclitem_array.go | 2 -- aclitem_test.go | 1 - bool.go | 2 -- bool_array.go | 2 -- bool_test.go | 1 - bytea.go | 2 -- bytea_array.go | 2 -- bytea_test.go | 1 - date.go | 2 -- date_array.go | 2 -- date_test.go | 1 - float4.go | 2 -- float4_array.go | 2 -- float8.go | 2 -- float8_array.go | 2 -- inet.go | 2 -- inet_array.go | 2 -- inet_test.go | 1 - int2.go | 2 -- int2_array.go | 2 -- int4.go | 2 -- int4_array.go | 2 -- int8.go | 2 -- int8_array.go | 2 -- qchar.go | 2 -- text.go | 2 -- text_array.go | 2 -- text_test.go | 1 - timestamp.go | 2 -- timestamp_array.go | 2 -- timestamp_test.go | 1 - timestamptz.go | 2 -- timestamptz_array.go | 2 -- timestamptz_test.go | 1 - typed_array.go.erb | 2 -- 36 files changed, 64 deletions(-) diff --git a/aclitem.go b/aclitem.go index 36cf3bbf..b8a1549e 100644 --- a/aclitem.go +++ b/aclitem.go @@ -25,8 +25,6 @@ type Aclitem struct { func (dst *Aclitem) Set(src interface{}) error { switch value := src.(type) { - case Aclitem: - *dst = value case string: *dst = Aclitem{String: value, Status: Present} case *string: diff --git a/aclitem_array.go b/aclitem_array.go index 13952e5c..5e3647b7 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -16,8 +16,6 @@ type AclitemArray struct { func (dst *AclitemArray) Set(src interface{}) error { switch value := src.(type) { - case AclitemArray: - *dst = value case []string: if value == nil { diff --git a/aclitem_test.go b/aclitem_test.go index 47e6fa84..1738025a 100644 --- a/aclitem_test.go +++ b/aclitem_test.go @@ -20,7 +20,6 @@ func TestAclitemSet(t *testing.T) { source interface{} result pgtype.Aclitem }{ - {source: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, result: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, {source: "postgres=arwdDxt/postgres", result: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, {source: (*string)(nil), result: pgtype.Aclitem{Status: pgtype.Null}}, } diff --git a/bool.go b/bool.go index 04a261c2..a8e9b8e1 100644 --- a/bool.go +++ b/bool.go @@ -14,8 +14,6 @@ type Bool struct { func (dst *Bool) Set(src interface{}) error { switch value := src.(type) { - case Bool: - *dst = value case bool: *dst = Bool{Bool: value, Status: Present} case string: diff --git a/bool_array.go b/bool_array.go index fdcbf7a0..4c5fc563 100644 --- a/bool_array.go +++ b/bool_array.go @@ -17,8 +17,6 @@ type BoolArray struct { func (dst *BoolArray) Set(src interface{}) error { switch value := src.(type) { - case BoolArray: - *dst = value case []bool: if value == nil { diff --git a/bool_test.go b/bool_test.go index 773bd99b..412e2fd0 100644 --- a/bool_test.go +++ b/bool_test.go @@ -20,7 +20,6 @@ func TestBoolSet(t *testing.T) { source interface{} result pgtype.Bool }{ - {source: pgtype.Bool{Bool: false, Status: pgtype.Null}, result: pgtype.Bool{Bool: false, Status: pgtype.Null}}, {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, {source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, diff --git a/bytea.go b/bytea.go index a8ee55ae..5df05360 100644 --- a/bytea.go +++ b/bytea.go @@ -14,8 +14,6 @@ type Bytea struct { func (dst *Bytea) Set(src interface{}) error { switch value := src.(type) { - case Bytea: - *dst = value case []byte: if value != nil { *dst = Bytea{Bytes: value, Status: Present} diff --git a/bytea_array.go b/bytea_array.go index 5362944a..c6f676a4 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -17,8 +17,6 @@ type ByteaArray struct { func (dst *ByteaArray) Set(src interface{}) error { switch value := src.(type) { - case ByteaArray: - *dst = value case [][]byte: if value == nil { diff --git a/bytea_test.go b/bytea_test.go index 4655a1c1..e21296c6 100644 --- a/bytea_test.go +++ b/bytea_test.go @@ -20,7 +20,6 @@ func TestByteaSet(t *testing.T) { source interface{} result pgtype.Bytea }{ - {source: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Null}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Null}}, {source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, {source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}}, {source: []byte(nil), result: pgtype.Bytea{Status: pgtype.Null}}, diff --git a/date.go b/date.go index a3b8d99f..d0481637 100644 --- a/date.go +++ b/date.go @@ -23,8 +23,6 @@ const ( func (dst *Date) Set(src interface{}) error { switch value := src.(type) { - case Date: - *dst = value case time.Time: *dst = Date{Time: value, Status: Present} default: diff --git a/date_array.go b/date_array.go index ce28e236..7f602d83 100644 --- a/date_array.go +++ b/date_array.go @@ -18,8 +18,6 @@ type DateArray struct { func (dst *DateArray) Set(src interface{}) error { switch value := src.(type) { - case DateArray: - *dst = value case []time.Time: if value == nil { diff --git a/date_test.go b/date_test.go index eff3a521..cfc3dd70 100644 --- a/date_test.go +++ b/date_test.go @@ -29,7 +29,6 @@ func TestDateSet(t *testing.T) { source interface{} result pgtype.Date }{ - {source: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, diff --git a/float4.go b/float4.go index a38d24db..053af44b 100644 --- a/float4.go +++ b/float4.go @@ -17,8 +17,6 @@ type Float4 struct { func (dst *Float4) Set(src interface{}) error { switch value := src.(type) { - case Float4: - *dst = value case float32: *dst = Float4{Float: value, Status: Present} case float64: diff --git a/float4_array.go b/float4_array.go index 410a8b37..0e815e0b 100644 --- a/float4_array.go +++ b/float4_array.go @@ -17,8 +17,6 @@ type Float4Array struct { func (dst *Float4Array) Set(src interface{}) error { switch value := src.(type) { - case Float4Array: - *dst = value case []float32: if value == nil { diff --git a/float8.go b/float8.go index 9129e8ba..635b7a09 100644 --- a/float8.go +++ b/float8.go @@ -17,8 +17,6 @@ type Float8 struct { func (dst *Float8) Set(src interface{}) error { switch value := src.(type) { - case Float8: - *dst = value case float32: *dst = Float8{Float: float64(value), Status: Present} case float64: diff --git a/float8_array.go b/float8_array.go index b2f70f51..811c5a1f 100644 --- a/float8_array.go +++ b/float8_array.go @@ -17,8 +17,6 @@ type Float8Array struct { func (dst *Float8Array) Set(src interface{}) error { switch value := src.(type) { - case Float8Array: - *dst = value case []float64: if value == nil { diff --git a/inet.go b/inet.go index 00bfb30c..87d675f9 100644 --- a/inet.go +++ b/inet.go @@ -25,8 +25,6 @@ type Inet struct { func (dst *Inet) Set(src interface{}) error { switch value := src.(type) { - case Inet: - *dst = value case net.IPNet: *dst = Inet{IPNet: &value, Status: Present} case *net.IPNet: diff --git a/inet_array.go b/inet_array.go index 4d865b4f..1d1cf3fd 100644 --- a/inet_array.go +++ b/inet_array.go @@ -18,8 +18,6 @@ type InetArray struct { func (dst *InetArray) Set(src interface{}) error { switch value := src.(type) { - case InetArray: - *dst = value case []*net.IPNet: if value == nil { diff --git a/inet_test.go b/inet_test.go index 90b0723f..16035fca 100644 --- a/inet_test.go +++ b/inet_test.go @@ -31,7 +31,6 @@ func TestInetSet(t *testing.T) { source interface{} result pgtype.Inet }{ - {source: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Null}, result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Null}}, {source: mustParseCidr(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: mustParseCidr(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, diff --git a/int2.go b/int2.go index 525427c5..62e1bc69 100644 --- a/int2.go +++ b/int2.go @@ -17,8 +17,6 @@ type Int2 struct { func (dst *Int2) Set(src interface{}) error { switch value := src.(type) { - case Int2: - *dst = value case int8: *dst = Int2{Int: int16(value), Status: Present} case uint8: diff --git a/int2_array.go b/int2_array.go index 28792fa5..3d06c018 100644 --- a/int2_array.go +++ b/int2_array.go @@ -17,8 +17,6 @@ type Int2Array struct { func (dst *Int2Array) Set(src interface{}) error { switch value := src.(type) { - case Int2Array: - *dst = value case []int16: if value == nil { diff --git a/int4.go b/int4.go index b3203a28..8eaf5094 100644 --- a/int4.go +++ b/int4.go @@ -17,8 +17,6 @@ type Int4 struct { func (dst *Int4) Set(src interface{}) error { switch value := src.(type) { - case Int4: - *dst = value case int8: *dst = Int4{Int: int32(value), Status: Present} case uint8: diff --git a/int4_array.go b/int4_array.go index 61cedb2e..5cd91c04 100644 --- a/int4_array.go +++ b/int4_array.go @@ -17,8 +17,6 @@ type Int4Array struct { func (dst *Int4Array) Set(src interface{}) error { switch value := src.(type) { - case Int4Array: - *dst = value case []int32: if value == nil { diff --git a/int8.go b/int8.go index 15ad6715..2416500d 100644 --- a/int8.go +++ b/int8.go @@ -17,8 +17,6 @@ type Int8 struct { func (dst *Int8) Set(src interface{}) error { switch value := src.(type) { - case Int8: - *dst = value case int8: *dst = Int8{Int: int64(value), Status: Present} case uint8: diff --git a/int8_array.go b/int8_array.go index 9f4373e8..5efc0f45 100644 --- a/int8_array.go +++ b/int8_array.go @@ -17,8 +17,6 @@ type Int8Array struct { func (dst *Int8Array) Set(src interface{}) error { switch value := src.(type) { - case Int8Array: - *dst = value case []int64: if value == nil { diff --git a/qchar.go b/qchar.go index b6392cf9..d46e716d 100644 --- a/qchar.go +++ b/qchar.go @@ -25,8 +25,6 @@ type QChar struct { func (dst *QChar) Set(src interface{}) error { switch value := src.(type) { - case QChar: - *dst = value case int8: *dst = QChar{Int: value, Status: Present} case uint8: diff --git a/text.go b/text.go index 50db2349..3dd082c9 100644 --- a/text.go +++ b/text.go @@ -13,8 +13,6 @@ type Text struct { func (dst *Text) Set(src interface{}) error { switch value := src.(type) { - case Text: - *dst = value case string: *dst = Text{String: value, Status: Present} case *string: diff --git a/text_array.go b/text_array.go index 3a5a64ce..1e6677a9 100644 --- a/text_array.go +++ b/text_array.go @@ -17,8 +17,6 @@ type TextArray struct { func (dst *TextArray) Set(src interface{}) error { switch value := src.(type) { - case TextArray: - *dst = value case []string: if value == nil { diff --git a/text_test.go b/text_test.go index f5e20055..39348bcc 100644 --- a/text_test.go +++ b/text_test.go @@ -22,7 +22,6 @@ func TestTextSet(t *testing.T) { source interface{} result pgtype.Text }{ - {source: pgtype.Text{String: "foo", Status: pgtype.Present}, result: pgtype.Text{String: "foo", Status: pgtype.Present}}, {source: "foo", result: pgtype.Text{String: "foo", Status: pgtype.Present}}, {source: _string("bar"), result: pgtype.Text{String: "bar", Status: pgtype.Present}}, {source: (*string)(nil), result: pgtype.Text{Status: pgtype.Null}}, diff --git a/timestamp.go b/timestamp.go index a84f3881..3bb8f080 100644 --- a/timestamp.go +++ b/timestamp.go @@ -27,8 +27,6 @@ type Timestamp struct { // time.Time in a non-UTC time zone, the time zone is discarded. func (dst *Timestamp) Set(src interface{}) error { switch value := src.(type) { - case Timestamp: - *dst = value case time.Time: *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} default: diff --git a/timestamp_array.go b/timestamp_array.go index ec0facb2..c955dc42 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -18,8 +18,6 @@ type TimestampArray struct { func (dst *TimestampArray) Set(src interface{}) error { switch value := src.(type) { - case TimestampArray: - *dst = value case []time.Time: if value == nil { diff --git a/timestamp_test.go b/timestamp_test.go index 7297ed1f..58828806 100644 --- a/timestamp_test.go +++ b/timestamp_test.go @@ -38,7 +38,6 @@ func TestTimestampSet(t *testing.T) { source interface{} result pgtype.Timestamp }{ - {source: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), Status: pgtype.Present}}, diff --git a/timestamptz.go b/timestamptz.go index a6922d5b..5b9f5038 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -28,8 +28,6 @@ type Timestamptz struct { func (dst *Timestamptz) Set(src interface{}) error { switch value := src.(type) { - case Timestamptz: - *dst = value case time.Time: *dst = Timestamptz{Time: value, Status: Present} default: diff --git a/timestamptz_array.go b/timestamptz_array.go index 775ec970..cd63e02e 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -18,8 +18,6 @@ type TimestamptzArray struct { func (dst *TimestamptzArray) Set(src interface{}) error { switch value := src.(type) { - case TimestamptzArray: - *dst = value case []time.Time: if value == nil { diff --git a/timestamptz_test.go b/timestamptz_test.go index 242cd05f..6ddfc1bc 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -38,7 +38,6 @@ func TestTimestamptzSet(t *testing.T) { source interface{} result pgtype.Timestamptz }{ - {source: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Status: pgtype.Present}}, diff --git a/typed_array.go.erb b/typed_array.go.erb index c62e2896..a56097c0 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -16,8 +16,6 @@ type <%= pgtype_array_type %> struct { func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { switch value := src.(type) { - case <%= pgtype_array_type %>: - *dst = value <% go_array_types.split(",").each do |t| %> case <%= t %>: if value == nil { From 45b33519d78213b1bc5d41c0c209e78addd682cc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 20:28:14 -0600 Subject: [PATCH 0036/1158] Add pgtype GenericText and GenericBinary Rows.Values uses this for unknown types. --- generic_binary.go | 29 +++++++++++++++++++++++++++++ generic_text.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 generic_binary.go create mode 100644 generic_text.go diff --git a/generic_binary.go b/generic_binary.go new file mode 100644 index 00000000..ac35ea60 --- /dev/null +++ b/generic_binary.go @@ -0,0 +1,29 @@ +package pgtype + +import ( + "io" +) + +// GenericBinary is a placeholder for binary format values that no other type exists +// to handle. +type GenericBinary Bytea + +func (dst *GenericBinary) Set(src interface{}) error { + return (*Bytea)(dst).Set(src) +} + +func (dst *GenericBinary) Get() interface{} { + return (*Bytea)(dst).Get() +} + +func (src *GenericBinary) AssignTo(dst interface{}) error { + return (*Bytea)(src).AssignTo(dst) +} + +func (dst *GenericBinary) DecodeBinary(src []byte) error { + return (*Bytea)(dst).DecodeBinary(src) +} + +func (src GenericBinary) EncodeBinary(w io.Writer) (bool, error) { + return (Bytea)(src).EncodeBinary(w) +} diff --git a/generic_text.go b/generic_text.go new file mode 100644 index 00000000..19f41059 --- /dev/null +++ b/generic_text.go @@ -0,0 +1,29 @@ +package pgtype + +import ( + "io" +) + +// GenericText is a placeholder for text format values that no other type exists +// to handle. +type GenericText Text + +func (dst *GenericText) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *GenericText) Get() interface{} { + return (*Text)(dst).Get() +} + +func (src *GenericText) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *GenericText) DecodeText(src []byte) error { + return (*Text)(dst).DecodeText(src) +} + +func (src GenericText) EncodeText(w io.Writer) (bool, error) { + return (Text)(src).EncodeText(w) +} From f9e58790729ad761d100ef18a614d0c9768dfbff Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 12 Mar 2017 17:06:06 -0500 Subject: [PATCH 0037/1158] Move hstore to pgtype Also implement binary format --- hstore.go | 438 +++++++++++++++++++++++++++++++++++++++++++++++++ hstore_test.go | 108 ++++++++++++ 2 files changed, 546 insertions(+) create mode 100644 hstore.go create mode 100644 hstore_test.go diff --git a/hstore.go b/hstore.go new file mode 100644 index 00000000..11bfb9a7 --- /dev/null +++ b/hstore.go @@ -0,0 +1,438 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "strings" + "unicode" + "unicode/utf8" + + "github.com/jackc/pgx/pgio" +) + +// Hstore represents an hstore column that can be null or have null values +// associated with its keys. +type Hstore struct { + Map map[string]Text + Status Status +} + +func (dst *Hstore) Set(src interface{}) error { + switch value := src.(type) { + case map[string]string: + m := make(map[string]Text, len(value)) + for k, v := range value { + m[k] = Text{String: v, Status: Present} + } + *dst = Hstore{Map: m, Status: Present} + default: + return fmt.Errorf("cannot convert %v to Tid", src) + } + + return nil +} + +func (dst *Hstore) Get() interface{} { + switch dst.Status { + case Present: + return dst.Map + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Hstore) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *map[string]string: + switch src.Status { + case Present: + *v = make(map[string]string, len(src.Map)) + for k, val := range src.Map { + if val.Status != Present { + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + (*v)[k] = val.String + } + case Null: + *v = nil + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Hstore) DecodeText(src []byte) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + keys, values, err := parseHstore(string(src)) + if err != nil { + return err + } + + m := make(map[string]Text, len(keys)) + for i := range keys { + m[keys[i]] = values[i] + } + + *dst = Hstore{Map: m, Status: Present} + return nil +} + +func (dst *Hstore) DecodeBinary(src []byte) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + rp := 0 + + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + m := make(map[string]Text, pairCount) + + for i := 0; i < pairCount; i++ { + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + if len(src[rp:]) < keyLen { + return fmt.Errorf("hstore incomplete %v", src) + } + key := string(src[rp : rp+keyLen]) + rp += keyLen + + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + var valueBuf []byte + if valueLen >= 0 { + valueBuf = src[rp : rp+valueLen] + } + rp += valueLen + + var value Text + err := value.DecodeBinary(valueBuf) + if err != nil { + return err + } + m[key] = value + } + + *dst = Hstore{Map: m, Status: Present} + + return nil +} + +func (src Hstore) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + firstPair := true + + for k, v := range src.Map { + if firstPair { + firstPair = false + } else { + err := pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + _, err := io.WriteString(w, quoteHstoreElementIfNeeded(k)) + if err != nil { + return false, err + } + + _, err = io.WriteString(w, "=>") + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + null, err := v.EncodeText(elemBuf) + if err != nil { + return false, err + } + + if null { + _, err = io.WriteString(w, "NULL") + if err != nil { + return false, err + } + } else { + _, err := io.WriteString(w, quoteHstoreElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + } + + return false, nil +} + +func (src Hstore) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := pgio.WriteInt32(w, int32(len(src.Map))) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + for k, v := range src.Map { + _, err := pgio.WriteInt32(w, int32(len(k))) + if err != nil { + return false, err + } + _, err = io.WriteString(w, k) + if err != nil { + return false, err + } + + null, err := v.EncodeText(elemBuf) + if err != nil { + return false, err + } + if null { + _, err := pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err := pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err +} + +var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteHstoreElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func quoteHstoreElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) { + return quoteArrayElement(src) + } + return src +} + +const ( + hsPre = iota + hsKey + hsSep + hsVal + hsNul + hsNext +) + +type hstoreParser struct { + str string + pos int +} + +func newHSP(in string) *hstoreParser { + return &hstoreParser{ + pos: 0, + str: in, + } +} + +func (p *hstoreParser) Consume() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, w := utf8.DecodeRuneInString(p.str[p.pos:]) + p.pos += w + return +} + +func (p *hstoreParser) Peek() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) + return +} + +// parseHstore parses the string representation of an hstore column (the same +// you would get from an ordinary SELECT) into two slices of keys and values. it +// is used internally in the default parsing of hstores. +func parseHstore(s string) (k []string, v []Text, err error) { + if s == "" { + return + } + + buf := bytes.Buffer{} + keys := []string{} + values := []Text{} + p := newHSP(s) + + r, end := p.Consume() + state := hsPre + + for !end { + switch state { + case hsPre: + if r == '"' { + state = hsKey + } else { + err = errors.New("String does not begin with \"") + } + case hsKey: + switch r { + case '"': //End of the key + if buf.Len() == 0 { + err = errors.New("Empty Key is invalid") + } else { + keys = append(keys, buf.String()) + buf = bytes.Buffer{} + state = hsSep + } + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsSep: + if r == '=' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=', expecting '>'") + case r == '>': + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") + case r == '"': + state = hsVal + case r == 'N': + state = hsNul + default: + err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) + } + default: + err = fmt.Errorf("Invalid character after '=', expecting '>'") + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) + } + case hsVal: + switch r { + case '"': //End of the value + values = append(values, Text{String: buf.String(), Status: Present}) + buf = bytes.Buffer{} + state = hsNext + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsNul: + nulBuf := make([]rune, 3) + nulBuf[0] = r + for i := 1; i < 3; i++ { + r, end = p.Consume() + if end { + err = errors.New("Found EOS in NULL value") + return + } + nulBuf[i] = r + } + if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { + values = append(values, Text{Status: Null}) + state = hsNext + } else { + err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + } + case hsNext: + if r == ',' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after ',', expcting space") + case (unicode.IsSpace(r)): + r, end = p.Consume() + state = hsKey + default: + err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) + } + } + + if err != nil { + return + } + r, end = p.Consume() + } + if state != hsNext { + err = errors.New("Improperly formatted hstore") + return + } + k = keys + v = values + return +} diff --git a/hstore_test.go b/hstore_test.go new file mode 100644 index 00000000..fbe8dee5 --- /dev/null +++ b/hstore_test.go @@ -0,0 +1,108 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestHstoreTranscode(t *testing.T) { + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Status: pgtype.Present} + } + + values := []interface{}{ + pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + pgtype.Hstore{Status: pgtype.Null}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + + // Special value values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + } + + testSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { + a := ai.(pgtype.Hstore) + b := bi.(pgtype.Hstore) + + if len(a.Map) != len(b.Map) || a.Status != b.Status { + return false + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + return false + } + } + + return true + }) +} + +func TestHstoreSet(t *testing.T) { + successfulTests := []struct { + src map[string]string + result pgtype.Hstore + }{ + {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var dst pgtype.Hstore + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + +func TestHstoreAssignTo(t *testing.T) { + var m map[string]string + + simpleTests := []struct { + src pgtype.Hstore + dst *map[string]string + expected map[string]string + }{ + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}, dst: &m, expected: map[string]string{"foo": "bar"}}, + {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(*tt.dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} From 937368fd5fdc6a5fd17f53f90fdbec3d2c669163 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 13 Mar 2017 20:23:17 -0500 Subject: [PATCH 0038/1158] Fix error message for hstore --- hstore.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hstore.go b/hstore.go index 11bfb9a7..c48ae6da 100644 --- a/hstore.go +++ b/hstore.go @@ -29,7 +29,7 @@ func (dst *Hstore) Set(src interface{}) error { } *dst = Hstore{Map: m, Status: Present} default: - return fmt.Errorf("cannot convert %v to Tid", src) + return fmt.Errorf("cannot convert %v to Hstore", src) } return nil From b31d409dc25775e19b142672d90d32e3ea3654a9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 13 Mar 2017 21:34:38 -0500 Subject: [PATCH 0039/1158] Move not null Oid to pgtype In preparation to ConnInfo implementation. --- oid.go | 58 +++++++++++++++++++------------- oid_value.go | 45 +++++++++++++++++++++++++ oid_test.go => oid_value_test.go | 30 ++++++++--------- 3 files changed, 95 insertions(+), 38 deletions(-) create mode 100644 oid_value.go rename oid_test.go => oid_value_test.go (66%) diff --git a/oid.go b/oid.go index e57bb2e6..eab1fbcb 100644 --- a/oid.go +++ b/oid.go @@ -1,45 +1,57 @@ package pgtype import ( + "encoding/binary" + "fmt" "io" + "strconv" + + "github.com/jackc/pgx/pgio" ) // Oid (Object Identifier Type) is, according to // https://www.postgresql.org/docs/current/static/datatype-oid.html, used // internally by PostgreSQL as a primary key for various system tables. It is // currently implemented as an unsigned four-byte integer. Its definition can be -// found in src/include/postgres_ext.h in the PostgreSQL sources. -type Oid pguint32 - -// Set converts from src to dst. Note that as Oid is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *Oid) Set(src interface{}) error { - return (*pguint32)(dst).Set(src) -} - -func (dst *Oid) Get() interface{} { - return (*pguint32)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as Oid is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *Oid) AssignTo(dst interface{}) error { - return (*pguint32)(src).AssignTo(dst) -} +// found in src/include/postgres_ext.h in the PostgreSQL sources. Because it is +// so frequently required to be in a NOT NULL condition Oid cannot be NULL. To +// allow for NULL Oids use OidValue. +type Oid uint32 func (dst *Oid) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) + if src == nil { + return fmt.Errorf("cannot decode nil into Oid") + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *dst = Oid(n) + return nil } func (dst *Oid) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) + if src == nil { + return fmt.Errorf("cannot decode nil into Oid") + } + + if len(src) != 4 { + return fmt.Errorf("invalid length: %v", len(src)) + } + + n := binary.BigEndian.Uint32(src) + *dst = Oid(n) + return nil } func (src Oid) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) + _, err := io.WriteString(w, strconv.FormatUint(uint64(src), 10)) + return false, err } func (src Oid) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) + _, err := pgio.WriteUint32(w, uint32(src)) + return false, err } diff --git a/oid_value.go b/oid_value.go new file mode 100644 index 00000000..a2b2dcbe --- /dev/null +++ b/oid_value.go @@ -0,0 +1,45 @@ +package pgtype + +import ( + "io" +) + +// OidValue (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-OidValue.html, used +// internally by PostgreSQL as a primary key for various system tables. It is +// currently implemented as an unsigned four-byte integer. Its definition can be +// found in src/include/postgres_ext.h in the PostgreSQL sources. +type OidValue pguint32 + +// Set converts from src to dst. Note that as OidValue is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *OidValue) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst *OidValue) Get() interface{} { + return (*pguint32)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as OidValue is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *OidValue) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *OidValue) DecodeText(src []byte) error { + return (*pguint32)(dst).DecodeText(src) +} + +func (dst *OidValue) DecodeBinary(src []byte) error { + return (*pguint32)(dst).DecodeBinary(src) +} + +func (src OidValue) EncodeText(w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(w) +} + +func (src OidValue) EncodeBinary(w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(w) +} diff --git a/oid_test.go b/oid_value_test.go similarity index 66% rename from oid_test.go rename to oid_value_test.go index b3b96959..21dd6f9d 100644 --- a/oid_test.go +++ b/oid_value_test.go @@ -7,23 +7,23 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestOidTranscode(t *testing.T) { +func TestOidValueTranscode(t *testing.T) { testSuccessfulTranscode(t, "oid", []interface{}{ - pgtype.Oid{Uint: 42, Status: pgtype.Present}, - pgtype.Oid{Status: pgtype.Null}, + pgtype.OidValue{Uint: 42, Status: pgtype.Present}, + pgtype.OidValue{Status: pgtype.Null}, }) } -func TestOidSet(t *testing.T) { +func TestOidValueSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Oid + result pgtype.OidValue }{ - {source: uint32(1), result: pgtype.Oid{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.OidValue{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.Oid + var r pgtype.OidValue err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -35,17 +35,17 @@ func TestOidSet(t *testing.T) { } } -func TestOidAssignTo(t *testing.T) { +func TestOidValueAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.Oid + src pgtype.OidValue dst interface{} expected interface{} }{ - {src: pgtype.Oid{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Oid{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.OidValue{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.OidValue{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -60,11 +60,11 @@ func TestOidAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.Oid + src pgtype.OidValue dst interface{} expected interface{} }{ - {src: pgtype.Oid{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.OidValue{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -79,10 +79,10 @@ func TestOidAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.Oid + src pgtype.OidValue dst interface{} }{ - {src: pgtype.Oid{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.OidValue{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { From 6e21cb00fe39e5a8be930d6b8afeb2c961b990e1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 12:01:16 -0500 Subject: [PATCH 0040/1158] Add pgtype.Record and prerequisite restructuring Because reading a record type requires the decoder to be able to look up oid to type mapping and types such as hstore have types that are not fixed between different PostgreSQL servers it was necessary to restructure the pgtype system so all encoders and decodes take a *ConnInfo that includes oid/name/type information. --- aclitem.go | 4 +- aclitem_array.go | 8 +- array.go | 4 +- bool.go | 8 +- bool_array.go | 24 ++-- bytea.go | 8 +- bytea_array.go | 24 ++-- cid.go | 16 +-- cidr.go | 35 +++++ cidr_array.go | 317 ++++++++++++++++++++++++++++++++++++++++-- cidr_array_test.go | 164 ++++++++++++++++++++++ database_sql.go | 66 +++++++++ date.go | 11 +- date_array.go | 24 ++-- float4.go | 8 +- float4_array.go | 24 ++-- float8.go | 8 +- float8_array.go | 24 ++-- generic_binary.go | 8 +- generic_text.go | 8 +- hstore.go | 14 +- inet.go | 8 +- inet_array.go | 24 ++-- int2.go | 8 +- int2_array.go | 24 ++-- int4.go | 8 +- int4_array.go | 24 ++-- int8.go | 8 +- int8_array.go | 24 ++-- json.go | 12 +- jsonb.go | 12 +- name.go | 16 +-- oid.go | 8 +- oid_value.go | 16 +-- pgtype.go | 129 ++++++++++++++++- pgtype_test.go | 10 +- pguint32.go | 8 +- qchar.go | 4 +- record.go | 123 ++++++++++++++++ record_test.go | 150 ++++++++++++++++++++ text.go | 12 +- text_array.go | 24 ++-- tid.go | 8 +- timestamp.go | 8 +- timestamp_array.go | 24 ++-- timestamptz.go | 8 +- timestamptz_array.go | 24 ++-- typed_array.go.erb | 24 ++-- typed_array_gen.sh | 2 + unknown.go | 32 +++++ varchar.go | 40 ++++++ varchar_array.go | 285 +++++++++++++++++++++++++++++++++++-- varchar_array_test.go | 151 ++++++++++++++++++++ xid.go | 16 +-- 54 files changed, 1761 insertions(+), 320 deletions(-) create mode 100644 cidr.go create mode 100644 cidr_array_test.go create mode 100644 database_sql.go create mode 100644 record.go create mode 100644 record_test.go create mode 100644 unknown.go create mode 100644 varchar.go create mode 100644 varchar_array_test.go diff --git a/aclitem.go b/aclitem.go index b8a1549e..f9faab20 100644 --- a/aclitem.go +++ b/aclitem.go @@ -90,7 +90,7 @@ func (src *Aclitem) AssignTo(dst interface{}) error { return nil } -func (dst *Aclitem) DecodeText(src []byte) error { +func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Aclitem{Status: Null} return nil @@ -100,7 +100,7 @@ func (dst *Aclitem) DecodeText(src []byte) error { return nil } -func (src Aclitem) EncodeText(w io.Writer) (bool, error) { +func (src Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/aclitem_array.go b/aclitem_array.go index 5e3647b7..f02d339e 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -82,7 +82,7 @@ func (src *AclitemArray) AssignTo(dst interface{}) error { return nil } -func (dst *AclitemArray) DecodeText(src []byte) error { +func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = AclitemArray{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *AclitemArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -118,7 +118,7 @@ func (dst *AclitemArray) DecodeText(src []byte) error { return nil } -func (src *AclitemArray) EncodeText(w io.Writer) (bool, error) { +func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -165,7 +165,7 @@ func (src *AclitemArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } diff --git a/array.go b/array.go index dff0fe81..9561afe5 100644 --- a/array.go +++ b/array.go @@ -27,7 +27,7 @@ type ArrayDimension struct { LowerBound int32 } -func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { +func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { if len(src) < 12 { return 0, fmt.Errorf("array header too short: %d", len(src)) } @@ -60,7 +60,7 @@ func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { return rp, nil } -func (src *ArrayHeader) EncodeBinary(w io.Writer) error { +func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, w io.Writer) error { _, err := pgio.WriteInt32(w, int32(len(src.Dimensions))) if err != nil { return err diff --git a/bool.go b/bool.go index a8e9b8e1..87316381 100644 --- a/bool.go +++ b/bool.go @@ -79,7 +79,7 @@ func (src *Bool) AssignTo(dst interface{}) error { return nil } -func (dst *Bool) DecodeText(src []byte) error { +func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bool{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *Bool) DecodeText(src []byte) error { return nil } -func (dst *Bool) DecodeBinary(src []byte) error { +func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bool{Status: Null} return nil @@ -107,7 +107,7 @@ func (dst *Bool) DecodeBinary(src []byte) error { return nil } -func (src Bool) EncodeText(w io.Writer) (bool, error) { +func (src Bool) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -126,7 +126,7 @@ func (src Bool) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Bool) EncodeBinary(w io.Writer) (bool, error) { +func (src Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/bool_array.go b/bool_array.go index 4c5fc563..1cb46cf6 100644 --- a/bool_array.go +++ b/bool_array.go @@ -83,7 +83,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return nil } -func (dst *BoolArray) DecodeText(src []byte) error { +func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = BoolArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *BoolArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *BoolArray) DecodeText(src []byte) error { return nil } -func (dst *BoolArray) DecodeBinary(src []byte) error { +func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = BoolArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *BoolArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *BoolArray) DecodeBinary(src []byte) error { return nil } -func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { +func (src *BoolArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *BoolArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, BoolOid) +func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, BoolOid) } -func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *BoolArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/bytea.go b/bytea.go index 5df05360..dc1e9c07 100644 --- a/bytea.go +++ b/bytea.go @@ -78,7 +78,7 @@ func (src *Bytea) AssignTo(dst interface{}) error { // DecodeText only supports the hex format. This has been the default since // PostgreSQL 9.0. -func (dst *Bytea) DecodeText(src []byte) error { +func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bytea{Status: Null} return nil @@ -98,7 +98,7 @@ func (dst *Bytea) DecodeText(src []byte) error { return nil } -func (dst *Bytea) DecodeBinary(src []byte) error { +func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bytea{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Bytea) DecodeBinary(src []byte) error { return nil } -func (src Bytea) EncodeText(w io.Writer) (bool, error) { +func (src Bytea) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -128,7 +128,7 @@ func (src Bytea) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Bytea) EncodeBinary(w io.Writer) (bool, error) { +func (src Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/bytea_array.go b/bytea_array.go index c6f676a4..30405509 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -83,7 +83,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { return nil } -func (dst *ByteaArray) DecodeText(src []byte) error { +func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = ByteaArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *ByteaArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *ByteaArray) DecodeText(src []byte) error { return nil } -func (dst *ByteaArray) DecodeBinary(src []byte) error { +func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = ByteaArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *ByteaArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *ByteaArray) DecodeBinary(src []byte) error { return nil } -func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { +func (src *ByteaArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *ByteaArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, ByteaOid) +func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, ByteaOid) } -func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *ByteaArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/cid.go b/cid.go index 20957f36..d86e8063 100644 --- a/cid.go +++ b/cid.go @@ -34,18 +34,18 @@ func (src *Cid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Cid) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *Cid) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Cid) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *Cid) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Cid) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src Cid) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } diff --git a/cidr.go b/cidr.go new file mode 100644 index 00000000..463b279d --- /dev/null +++ b/cidr.go @@ -0,0 +1,35 @@ +package pgtype + +import ( + "io" +) + +type Cidr Inet + +func (dst *Cidr) Set(src interface{}) error { + return (*Inet)(dst).Set(src) +} + +func (dst *Cidr) Get() interface{} { + return (*Inet)(dst).Get() +} + +func (src *Cidr) AssignTo(dst interface{}) error { + return (*Inet)(src).AssignTo(dst) +} + +func (dst *Cidr) DecodeText(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeText(ci, src) +} + +func (dst *Cidr) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeBinary(ci, src) +} + +func (src Cidr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Inet)(src).EncodeText(ci, w) +} + +func (src Cidr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Inet)(src).EncodeBinary(ci, w) +} diff --git a/cidr_array.go b/cidr_array.go index c30c53d3..32d2e7bf 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -1,35 +1,328 @@ package pgtype import ( + "bytes" + "encoding/binary" + "fmt" "io" + "net" + + "github.com/jackc/pgx/pgio" ) -type CidrArray InetArray +type CidrArray struct { + Elements []Cidr + Dimensions []ArrayDimension + Status Status +} func (dst *CidrArray) Set(src interface{}) error { - return (*InetArray)(dst).Set(src) + switch value := src.(type) { + + case []*net.IPNet: + if value == nil { + *dst = CidrArray{Status: Null} + } else if len(value) == 0 { + *dst = CidrArray{Status: Present} + } else { + elements := make([]Cidr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CidrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []net.IP: + if value == nil { + *dst = CidrArray{Status: Null} + } else if len(value) == 0 { + *dst = CidrArray{Status: Present} + } else { + elements := make([]Cidr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CidrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Cidr", value) + } + + return nil } func (dst *CidrArray) Get() interface{} { - return (*InetArray)(dst).Get() + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } } func (src *CidrArray) AssignTo(dst interface{}) error { - return (*InetArray)(src).AssignTo(dst) + switch v := dst.(type) { + + case *[]*net.IPNet: + if src.Status == Present { + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + case *[]net.IP: + if src.Status == Present { + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil } -func (dst *CidrArray) DecodeText(src []byte) error { - return (*InetArray)(dst).DecodeText(src) +func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CidrArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Cidr + + if len(uta.Elements) > 0 { + elements = make([]Cidr, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Cidr + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = CidrArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil } -func (dst *CidrArray) DecodeBinary(src []byte) error { - return (*InetArray)(dst).DecodeBinary(src) +func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CidrArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = CidrArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Cidr, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = CidrArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil } -func (src *CidrArray) EncodeText(w io.Writer) (bool, error) { - return (*InetArray)(src).EncodeText(w) +func (src *CidrArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil } -func (src *CidrArray) EncodeBinary(w io.Writer) (bool, error) { - return (*InetArray)(src).encodeBinary(w, CidrOid) +func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, CidrOid) +} + +func (src *CidrArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + ElementOid: elementOid, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err } diff --git a/cidr_array_test.go b/cidr_array_test.go new file mode 100644 index 00000000..ec105914 --- /dev/null +++ b/cidr_array_test.go @@ -0,0 +1,164 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestCidrArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "cidr[]", []interface{}{ + &pgtype.CidrArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CidrArray{Status: pgtype.Null}, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Cidr{Status: pgtype.Null}, + pgtype.Cidr{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestCidrArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CidrArray + }{ + { + source: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + result: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.CidrArray{Status: pgtype.Null}, + }, + { + source: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + result: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.CidrArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.CidrArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCidrArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + + simpleTests := []struct { + src pgtype.CidrArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.CidrArray{Status: pgtype.Null}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.CidrArray{Status: pgtype.Null}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/database_sql.go b/database_sql.go new file mode 100644 index 00000000..969d6542 --- /dev/null +++ b/database_sql.go @@ -0,0 +1,66 @@ +package pgtype + +import ( + "bytes" + "errors" +) + +func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { + switch src := src.(type) { + case *Bool: + return src.Bool, nil + case *Bytea: + return src.Bytes, nil + case *Date: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Float4: + return float64(src.Float), nil + case *Float8: + return src.Float, nil + case *GenericBinary: + return src.Bytes, nil + case *GenericText: + return src.String, nil + case *Int2: + return int64(src.Int), nil + case *Int4: + return int64(src.Int), nil + case *Int8: + return int64(src.Int), nil + case *Text: + return src.String, nil + case *Timestamp: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Timestamptz: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Unknown: + return src.String, nil + case *Varchar: + return src.String, nil + } + + buf := &bytes.Buffer{} + if textEncoder, ok := src.(TextEncoder); ok { + _, err := textEncoder.EncodeText(ci, buf) + if err != nil { + return nil, err + } + return buf.String(), nil + } + + if binaryEncoder, ok := src.(BinaryEncoder); ok { + _, err := binaryEncoder.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + return buf.Bytes(), nil + } + + return nil, errors.New("cannot convert to database/sql compatible value") +} diff --git a/date.go b/date.go index d0481637..b6cc8329 100644 --- a/date.go +++ b/date.go @@ -38,6 +38,9 @@ func (dst *Date) Set(src interface{}) error { func (dst *Date) Get() interface{} { switch dst.Status { case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } return dst.Time case Null: return nil @@ -76,7 +79,7 @@ func (src *Date) AssignTo(dst interface{}) error { return nil } -func (dst *Date) DecodeText(src []byte) error { +func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Date{Status: Null} return nil @@ -100,7 +103,7 @@ func (dst *Date) DecodeText(src []byte) error { return nil } -func (dst *Date) DecodeBinary(src []byte) error { +func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Date{Status: Null} return nil @@ -125,7 +128,7 @@ func (dst *Date) DecodeBinary(src []byte) error { return nil } -func (src Date) EncodeText(w io.Writer) (bool, error) { +func (src Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -148,7 +151,7 @@ func (src Date) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Date) EncodeBinary(w io.Writer) (bool, error) { +func (src Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/date_array.go b/date_array.go index 7f602d83..ba68d561 100644 --- a/date_array.go +++ b/date_array.go @@ -84,7 +84,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { return nil } -func (dst *DateArray) DecodeText(src []byte) error { +func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = DateArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *DateArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *DateArray) DecodeText(src []byte) error { return nil } -func (dst *DateArray) DecodeBinary(src []byte) error { +func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = DateArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *DateArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *DateArray) DecodeBinary(src []byte) error { return nil } -func (src *DateArray) EncodeText(w io.Writer) (bool, error) { +func (src *DateArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *DateArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, DateOid) +func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, DateOid) } -func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *DateArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/float4.go b/float4.go index 053af44b..94b7b7a1 100644 --- a/float4.go +++ b/float4.go @@ -102,7 +102,7 @@ func (src *Float4) AssignTo(dst interface{}) error { return float64AssignTo(float64(src.Float), src.Status, dst) } -func (dst *Float4) DecodeText(src []byte) error { +func (dst *Float4) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4{Status: Null} return nil @@ -117,7 +117,7 @@ func (dst *Float4) DecodeText(src []byte) error { return nil } -func (dst *Float4) DecodeBinary(src []byte) error { +func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4{Status: Null} return nil @@ -133,7 +133,7 @@ func (dst *Float4) DecodeBinary(src []byte) error { return nil } -func (src Float4) EncodeText(w io.Writer) (bool, error) { +func (src Float4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -145,7 +145,7 @@ func (src Float4) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Float4) EncodeBinary(w io.Writer) (bool, error) { +func (src Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/float4_array.go b/float4_array.go index 0e815e0b..40152bcf 100644 --- a/float4_array.go +++ b/float4_array.go @@ -83,7 +83,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float4Array) DecodeText(src []byte) error { +func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4Array{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *Float4Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *Float4Array) DecodeText(src []byte) error { return nil } -func (dst *Float4Array) DecodeBinary(src []byte) error { +func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *Float4Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *Float4Array) DecodeBinary(src []byte) error { return nil } -func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { +func (src *Float4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Float4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float4Oid) +func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Float4Oid) } -func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Float4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/float8.go b/float8.go index 635b7a09..dd2d592d 100644 --- a/float8.go +++ b/float8.go @@ -92,7 +92,7 @@ func (src *Float8) AssignTo(dst interface{}) error { return float64AssignTo(src.Float, src.Status, dst) } -func (dst *Float8) DecodeText(src []byte) error { +func (dst *Float8) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8{Status: Null} return nil @@ -107,7 +107,7 @@ func (dst *Float8) DecodeText(src []byte) error { return nil } -func (dst *Float8) DecodeBinary(src []byte) error { +func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8{Status: Null} return nil @@ -123,7 +123,7 @@ func (dst *Float8) DecodeBinary(src []byte) error { return nil } -func (src Float8) EncodeText(w io.Writer) (bool, error) { +func (src Float8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -135,7 +135,7 @@ func (src Float8) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Float8) EncodeBinary(w io.Writer) (bool, error) { +func (src Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/float8_array.go b/float8_array.go index 811c5a1f..d0ee0d70 100644 --- a/float8_array.go +++ b/float8_array.go @@ -83,7 +83,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float8Array) DecodeText(src []byte) error { +func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8Array{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *Float8Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *Float8Array) DecodeText(src []byte) error { return nil } -func (dst *Float8Array) DecodeBinary(src []byte) error { +func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *Float8Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *Float8Array) DecodeBinary(src []byte) error { return nil } -func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { +func (src *Float8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Float8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float8Oid) +func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Float8Oid) } -func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Float8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/generic_binary.go b/generic_binary.go index ac35ea60..aa28bb62 100644 --- a/generic_binary.go +++ b/generic_binary.go @@ -20,10 +20,10 @@ func (src *GenericBinary) AssignTo(dst interface{}) error { return (*Bytea)(src).AssignTo(dst) } -func (dst *GenericBinary) DecodeBinary(src []byte) error { - return (*Bytea)(dst).DecodeBinary(src) +func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Bytea)(dst).DecodeBinary(ci, src) } -func (src GenericBinary) EncodeBinary(w io.Writer) (bool, error) { - return (Bytea)(src).EncodeBinary(w) +func (src GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Bytea)(src).EncodeBinary(ci, w) } diff --git a/generic_text.go b/generic_text.go index 19f41059..bd75e0d0 100644 --- a/generic_text.go +++ b/generic_text.go @@ -20,10 +20,10 @@ func (src *GenericText) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } -func (dst *GenericText) DecodeText(src []byte) error { - return (*Text)(dst).DecodeText(src) +func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) } -func (src GenericText) EncodeText(w io.Writer) (bool, error) { - return (Text)(src).EncodeText(w) +func (src GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) } diff --git a/hstore.go b/hstore.go index c48ae6da..d771d6e6 100644 --- a/hstore.go +++ b/hstore.go @@ -70,7 +70,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { return nil } -func (dst *Hstore) DecodeText(src []byte) error { +func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Hstore{Status: Null} return nil @@ -90,7 +90,7 @@ func (dst *Hstore) DecodeText(src []byte) error { return nil } -func (dst *Hstore) DecodeBinary(src []byte) error { +func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Hstore{Status: Null} return nil @@ -132,7 +132,7 @@ func (dst *Hstore) DecodeBinary(src []byte) error { rp += valueLen var value Text - err := value.DecodeBinary(valueBuf) + err := value.DecodeBinary(ci, valueBuf) if err != nil { return err } @@ -144,7 +144,7 @@ func (dst *Hstore) DecodeBinary(src []byte) error { return nil } -func (src Hstore) EncodeText(w io.Writer) (bool, error) { +func (src Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -175,7 +175,7 @@ func (src Hstore) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := v.EncodeText(elemBuf) + null, err := v.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -196,7 +196,7 @@ func (src Hstore) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src Hstore) EncodeBinary(w io.Writer) (bool, error) { +func (src Hstore) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -220,7 +220,7 @@ func (src Hstore) EncodeBinary(w io.Writer) (bool, error) { return false, err } - null, err := v.EncodeText(elemBuf) + null, err := v.EncodeText(ci, elemBuf) if err != nil { return false, err } diff --git a/inet.go b/inet.go index 87d675f9..b83bd1c9 100644 --- a/inet.go +++ b/inet.go @@ -100,7 +100,7 @@ func (src *Inet) AssignTo(dst interface{}) error { return nil } -func (dst *Inet) DecodeText(src []byte) error { +func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Inet{Status: Null} return nil @@ -128,7 +128,7 @@ func (dst *Inet) DecodeText(src []byte) error { return nil } -func (dst *Inet) DecodeBinary(src []byte) error { +func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Inet{Status: Null} return nil @@ -153,7 +153,7 @@ func (dst *Inet) DecodeBinary(src []byte) error { return nil } -func (src Inet) EncodeText(w io.Writer) (bool, error) { +func (src Inet) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Inet) EncodeText(w io.Writer) (bool, error) { } // EncodeBinary encodes src into w. -func (src Inet) EncodeBinary(w io.Writer) (bool, error) { +func (src Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/inet_array.go b/inet_array.go index 1d1cf3fd..6cad82e7 100644 --- a/inet_array.go +++ b/inet_array.go @@ -115,7 +115,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { return nil } -func (dst *InetArray) DecodeText(src []byte) error { +func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = InetArray{Status: Null} return nil @@ -137,7 +137,7 @@ func (dst *InetArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -151,14 +151,14 @@ func (dst *InetArray) DecodeText(src []byte) error { return nil } -func (dst *InetArray) DecodeBinary(src []byte) error { +func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = InetArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -183,7 +183,7 @@ func (dst *InetArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -193,7 +193,7 @@ func (dst *InetArray) DecodeBinary(src []byte) error { return nil } -func (src *InetArray) EncodeText(w io.Writer) (bool, error) { +func (src *InetArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -240,7 +240,7 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -269,11 +269,11 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *InetArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, InetOid) +func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, InetOid) } -func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *InetArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -293,7 +293,7 @@ func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -303,7 +303,7 @@ func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/int2.go b/int2.go index 62e1bc69..6996cd4f 100644 --- a/int2.go +++ b/int2.go @@ -98,7 +98,7 @@ func (src *Int2) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int2) DecodeText(src []byte) error { +func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2{Status: Null} return nil @@ -113,7 +113,7 @@ func (dst *Int2) DecodeText(src []byte) error { return nil } -func (dst *Int2) DecodeBinary(src []byte) error { +func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2{Status: Null} return nil @@ -128,7 +128,7 @@ func (dst *Int2) DecodeBinary(src []byte) error { return nil } -func (src Int2) EncodeText(w io.Writer) (bool, error) { +func (src Int2) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -140,7 +140,7 @@ func (src Int2) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int2) EncodeBinary(w io.Writer) (bool, error) { +func (src Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/int2_array.go b/int2_array.go index 3d06c018..2bf1c237 100644 --- a/int2_array.go +++ b/int2_array.go @@ -114,7 +114,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int2Array) DecodeText(src []byte) error { +func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int2Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int2Array) DecodeText(src []byte) error { return nil } -func (dst *Int2Array) DecodeBinary(src []byte) error { +func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int2Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int2Array) DecodeBinary(src []byte) error { return nil } -func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int2Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int2Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int2Oid) +func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int2Oid) } -func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int2Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/int4.go b/int4.go index 8eaf5094..62ee366f 100644 --- a/int4.go +++ b/int4.go @@ -89,7 +89,7 @@ func (src *Int4) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int4) DecodeText(src []byte) error { +func (dst *Int4) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *Int4) DecodeText(src []byte) error { return nil } -func (dst *Int4) DecodeBinary(src []byte) error { +func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4{Status: Null} return nil @@ -119,7 +119,7 @@ func (dst *Int4) DecodeBinary(src []byte) error { return nil } -func (src Int4) EncodeText(w io.Writer) (bool, error) { +func (src Int4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -131,7 +131,7 @@ func (src Int4) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int4) EncodeBinary(w io.Writer) (bool, error) { +func (src Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/int4_array.go b/int4_array.go index 5cd91c04..dda88eaf 100644 --- a/int4_array.go +++ b/int4_array.go @@ -114,7 +114,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int4Array) DecodeText(src []byte) error { +func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int4Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int4Array) DecodeText(src []byte) error { return nil } -func (dst *Int4Array) DecodeBinary(src []byte) error { +func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int4Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int4Array) DecodeBinary(src []byte) error { return nil } -func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int4Oid) +func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int4Oid) } -func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/int8.go b/int8.go index 2416500d..7ed54f8e 100644 --- a/int8.go +++ b/int8.go @@ -80,7 +80,7 @@ func (src *Int8) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int8) DecodeText(src []byte) error { +func (dst *Int8) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8{Status: Null} return nil @@ -95,7 +95,7 @@ func (dst *Int8) DecodeText(src []byte) error { return nil } -func (dst *Int8) DecodeBinary(src []byte) error { +func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Int8) DecodeBinary(src []byte) error { return nil } -func (src Int8) EncodeText(w io.Writer) (bool, error) { +func (src Int8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -123,7 +123,7 @@ func (src Int8) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int8) EncodeBinary(w io.Writer) (bool, error) { +func (src Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/int8_array.go b/int8_array.go index 5efc0f45..468c126b 100644 --- a/int8_array.go +++ b/int8_array.go @@ -114,7 +114,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int8Array) DecodeText(src []byte) error { +func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int8Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int8Array) DecodeText(src []byte) error { return nil } -func (dst *Int8Array) DecodeBinary(src []byte) error { +func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int8Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int8Array) DecodeBinary(src []byte) error { return nil } -func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int8Oid) +func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int8Oid) } -func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/json.go b/json.go index ecdb3dab..bfffae14 100644 --- a/json.go +++ b/json.go @@ -84,7 +84,7 @@ func (src *Json) AssignTo(dst interface{}) error { return nil } -func (dst *Json) DecodeText(src []byte) error { +func (dst *Json) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Json{Status: Null} return nil @@ -97,11 +97,11 @@ func (dst *Json) DecodeText(src []byte) error { return nil } -func (dst *Json) DecodeBinary(src []byte) error { - return dst.DecodeText(src) +func (dst *Json) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) } -func (src Json) EncodeText(w io.Writer) (bool, error) { +func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -113,6 +113,6 @@ func (src Json) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Json) EncodeBinary(w io.Writer) (bool, error) { - return src.EncodeText(w) +func (src Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.EncodeText(ci, w) } diff --git a/jsonb.go b/jsonb.go index 13062e8e..e44f3c41 100644 --- a/jsonb.go +++ b/jsonb.go @@ -19,11 +19,11 @@ func (src *Jsonb) AssignTo(dst interface{}) error { return (*Json)(src).AssignTo(dst) } -func (dst *Jsonb) DecodeText(src []byte) error { - return (*Json)(dst).DecodeText(src) +func (dst *Jsonb) DecodeText(ci *ConnInfo, src []byte) error { + return (*Json)(dst).DecodeText(ci, src) } -func (dst *Jsonb) DecodeBinary(src []byte) error { +func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Jsonb{Status: Null} return nil @@ -46,11 +46,11 @@ func (dst *Jsonb) DecodeBinary(src []byte) error { } -func (src Jsonb) EncodeText(w io.Writer) (bool, error) { - return (Json)(src).EncodeText(w) +func (src Jsonb) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Json)(src).EncodeText(ci, w) } -func (src Jsonb) EncodeBinary(w io.Writer) (bool, error) { +func (src Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/name.go b/name.go index 9eb12ece..9ebf63d3 100644 --- a/name.go +++ b/name.go @@ -31,18 +31,18 @@ func (src *Name) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } -func (dst *Name) DecodeText(src []byte) error { - return (*Text)(dst).DecodeText(src) +func (dst *Name) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) } -func (dst *Name) DecodeBinary(src []byte) error { - return (*Text)(dst).DecodeBinary(src) +func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) } -func (src Name) EncodeText(w io.Writer) (bool, error) { - return (Text)(src).EncodeText(w) +func (src Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) } -func (src Name) EncodeBinary(w io.Writer) (bool, error) { - return (Text)(src).EncodeBinary(w) +func (src Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeBinary(ci, w) } diff --git a/oid.go b/oid.go index eab1fbcb..3edd7f3c 100644 --- a/oid.go +++ b/oid.go @@ -18,7 +18,7 @@ import ( // allow for NULL Oids use OidValue. type Oid uint32 -func (dst *Oid) DecodeText(src []byte) error { +func (dst *Oid) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { return fmt.Errorf("cannot decode nil into Oid") } @@ -32,7 +32,7 @@ func (dst *Oid) DecodeText(src []byte) error { return nil } -func (dst *Oid) DecodeBinary(src []byte) error { +func (dst *Oid) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { return fmt.Errorf("cannot decode nil into Oid") } @@ -46,12 +46,12 @@ func (dst *Oid) DecodeBinary(src []byte) error { return nil } -func (src Oid) EncodeText(w io.Writer) (bool, error) { +func (src Oid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { _, err := io.WriteString(w, strconv.FormatUint(uint64(src), 10)) return false, err } -func (src Oid) EncodeBinary(w io.Writer) (bool, error) { +func (src Oid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, uint32(src)) return false, err } diff --git a/oid_value.go b/oid_value.go index a2b2dcbe..1bce6e11 100644 --- a/oid_value.go +++ b/oid_value.go @@ -28,18 +28,18 @@ func (src *OidValue) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *OidValue) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *OidValue) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *OidValue) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *OidValue) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src OidValue) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src OidValue) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } diff --git a/pgtype.go b/pgtype.go index 7b1470b7..674c0db7 100644 --- a/pgtype.go +++ b/pgtype.go @@ -3,6 +3,7 @@ package pgtype import ( "errors" "io" + "reflect" ) // PostgreSQL oids for common types @@ -83,14 +84,14 @@ type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the // original SQL value is NULL. BinaryDecoder MUST not retain a reference to // src. It MUST make a copy if it needs to retain the raw bytes. - DecodeBinary(src []byte) error + DecodeBinary(ci *ConnInfo, src []byte) error } type TextDecoder interface { // DecodeText decodes src into TextDecoder. If src is nil then the original // SQL value is NULL. TextDecoder MUST not retain a reference to src. It MUST // make a copy if it needs to retain the raw bytes. - DecodeText(src []byte) error + DecodeText(ci *ConnInfo, src []byte) error } // BinaryEncoder is implemented by types that can encode themselves into the @@ -100,7 +101,7 @@ type BinaryEncoder interface { // SQL value NULL then write nothing and return (true, nil). The caller of // EncodeBinary is responsible for writing the correct NULL value or the // length of the data written. - EncodeBinary(w io.Writer) (null bool, err error) + EncodeBinary(ci *ConnInfo, w io.Writer) (null bool, err error) } // TextEncoder is implemented by types that can encode themselves into the @@ -110,7 +111,127 @@ type TextEncoder interface { // value NULL then write nothing and return (true, nil). The caller of // EncodeText is responsible for writing the correct NULL value or the length // of the data written. - EncodeText(w io.Writer) (null bool, err error) + EncodeText(ci *ConnInfo, w io.Writer) (null bool, err error) } var errUndefined = errors.New("cannot encode status undefined") + +type DataType struct { + Value Value + Name string + Oid Oid +} + +type ConnInfo struct { + oidToDataType map[Oid]*DataType + nameToDataType map[string]*DataType + reflectTypeToDataType map[reflect.Type]*DataType +} + +func NewConnInfo() *ConnInfo { + return &ConnInfo{ + oidToDataType: make(map[Oid]*DataType, 256), + nameToDataType: make(map[string]*DataType, 256), + reflectTypeToDataType: make(map[reflect.Type]*DataType, 256), + } +} + +func (ci *ConnInfo) InitializeDataTypes(nameOids map[string]Oid) { + for name, oid := range nameOids { + var value Value + if t, ok := nameValues[name]; ok { + value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value) + } else { + value = &GenericText{} + } + ci.RegisterDataType(DataType{Value: value, Name: name, Oid: oid}) + } +} + +func (ci *ConnInfo) RegisterDataType(t DataType) { + ci.oidToDataType[t.Oid] = &t + ci.nameToDataType[t.Name] = &t + ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t +} + +func (ci *ConnInfo) DataTypeForOid(oid Oid) (*DataType, bool) { + dt, ok := ci.oidToDataType[oid] + return dt, ok +} + +func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { + dt, ok := ci.nameToDataType[name] + return dt, ok +} + +func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { + dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()] + return dt, ok +} + +// DeepCopy makes a deep copy of the ConnInfo. +func (ci *ConnInfo) DeepCopy() *ConnInfo { + ci2 := &ConnInfo{ + oidToDataType: make(map[Oid]*DataType, len(ci.oidToDataType)), + nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)), + reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)), + } + + for _, dt := range ci.oidToDataType { + ci2.RegisterDataType(DataType{ + Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value), + Name: dt.Name, + Oid: dt.Oid, + }) + } + + return ci2 +} + +var nameValues map[string]Value + +func init() { + nameValues = map[string]Value{ + "_aclitem": &AclitemArray{}, + "_bool": &BoolArray{}, + "_bytea": &ByteaArray{}, + "_cidr": &CidrArray{}, + "_date": &DateArray{}, + "_float4": &Float4Array{}, + "_float8": &Float8Array{}, + "_inet": &InetArray{}, + "_int2": &Int2Array{}, + "_int4": &Int4Array{}, + "_int8": &Int8Array{}, + "_text": &TextArray{}, + "_timestamp": &TimestampArray{}, + "_timestamptz": &TimestamptzArray{}, + "_varchar": &VarcharArray{}, + "aclitem": &Aclitem{}, + "bool": &Bool{}, + "bytea": &Bytea{}, + "char": &QChar{}, + "cid": &Cid{}, + "cidr": &Cidr{}, + "date": &Date{}, + "float4": &Float4{}, + "float8": &Float8{}, + "hstore": &Hstore{}, + "inet": &Inet{}, + "int2": &Int2{}, + "int4": &Int4{}, + "int8": &Int8{}, + "json": &Json{}, + "jsonb": &Jsonb{}, + "name": &Name{}, + "oid": &OidValue{}, + "record": &Record{}, + "text": &Text{}, + "tid": &Tid{}, + "timestamp": &Timestamp{}, + "timestamptz": &Timestamptz{}, + "unknown": &Unknown{}, + "varchar": &Varchar{}, + "xid": &Xid{}, + } +} diff --git a/pgtype_test.go b/pgtype_test.go index f9b6f56d..391fed57 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -60,16 +60,16 @@ type forceTextEncoder struct { e pgtype.TextEncoder } -func (f forceTextEncoder) EncodeText(w io.Writer) (bool, error) { - return f.e.EncodeText(w) +func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeText(ci, w) } type forceBinaryEncoder struct { e pgtype.BinaryEncoder } -func (f forceBinaryEncoder) EncodeBinary(w io.Writer) (bool, error) { - return f.e.EncodeBinary(w) +func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeBinary(ci, w) } func forceEncoder(e interface{}, formatCode int16) interface{} { @@ -114,7 +114,7 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int ps.FieldDescriptions[0].FormatCode = fc.formatCode vEncoder := forceEncoder(v, fc.formatCode) if vEncoder == nil { - t.Logf("%v does not implement %v", fc.name) + t.Logf("%#v does not implement %v", v, fc.name) continue } // Derefence value if it is a pointer diff --git a/pguint32.go b/pguint32.go index 05c79c0e..3f9e7bf7 100644 --- a/pguint32.go +++ b/pguint32.go @@ -63,7 +63,7 @@ func (src *pguint32) AssignTo(dst interface{}) error { return nil } -func (dst *pguint32) DecodeText(src []byte) error { +func (dst *pguint32) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = pguint32{Status: Null} return nil @@ -78,7 +78,7 @@ func (dst *pguint32) DecodeText(src []byte) error { return nil } -func (dst *pguint32) DecodeBinary(src []byte) error { +func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = pguint32{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *pguint32) DecodeBinary(src []byte) error { return nil } -func (src pguint32) EncodeText(w io.Writer) (bool, error) { +func (src pguint32) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -105,7 +105,7 @@ func (src pguint32) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src pguint32) EncodeBinary(w io.Writer) (bool, error) { +func (src pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/qchar.go b/qchar.go index d46e716d..4b32ee4a 100644 --- a/qchar.go +++ b/qchar.go @@ -115,7 +115,7 @@ func (src *QChar) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *QChar) DecodeBinary(src []byte) error { +func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = QChar{Status: Null} return nil @@ -129,7 +129,7 @@ func (dst *QChar) DecodeBinary(src []byte) error { return nil } -func (src QChar) EncodeBinary(w io.Writer) (bool, error) { +func (src QChar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/record.go b/record.go new file mode 100644 index 00000000..1bfd05b9 --- /dev/null +++ b/record.go @@ -0,0 +1,123 @@ +package pgtype + +import ( + "encoding/binary" + "fmt" +) + +// Record is the generic PostgreSQL record type such as is created with the +// "row" function. Record only implements BinaryEncoder and Value. The text +// format output format from PostgreSQL does not include type information and is +// therefore impossible to decode. No encoders are implemented because +// PostgreSQL does not support input of generic records. +type Record struct { + Fields []Value + Status Status +} + +func (dst *Record) Set(src interface{}) error { + switch value := src.(type) { + case []Value: + *dst = Record{Fields: value, Status: Present} + default: + return fmt.Errorf("cannot convert %v to Record", src) + } + + return nil +} + +func (dst *Record) Get() interface{} { + switch dst.Status { + case Present: + return dst.Fields + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Record) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *[]Value: + switch src.Status { + case Present: + *v = make([]Value, len(src.Fields)) + copy(*v, src.Fields) + case Null: + *v = nil + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + case *[]interface{}: + switch src.Status { + case Present: + *v = make([]interface{}, len(src.Fields)) + for i := range *v { + (*v)[i] = src.Fields[i].Get() + } + case Null: + *v = nil + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Record{Status: Null} + return nil + } + + rp := 0 + + if len(src[rp:]) < 4 { + return fmt.Errorf("Record incomplete %v", src) + } + fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + fields := make([]Value, fieldCount) + + for i := 0; i < fieldCount; i++ { + if len(src[rp:]) < 8 { + return fmt.Errorf("Record incomplete %v", src) + } + fieldOid := Oid(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + var binaryDecoder BinaryDecoder + if dt, ok := ci.DataTypeForOid(fieldOid); ok { + if binaryDecoder, ok = dt.Value.(BinaryDecoder); !ok { + return fmt.Errorf("unknown oid while decoding record: %v", fieldOid) + } + } + + var fieldBytes []byte + if fieldLen >= 0 { + if len(src[rp:]) < fieldLen { + return fmt.Errorf("Record incomplete %v", src) + } + fieldBytes = src[rp : rp+fieldLen] + rp += fieldLen + } + + if err := binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + fields[i] = binaryDecoder.(Value) + } + + *dst = Record{Fields: fields, Status: Present} + + return nil +} diff --git a/record_test.go b/record_test.go new file mode 100644 index 00000000..bc6e5893 --- /dev/null +++ b/record_test.go @@ -0,0 +1,150 @@ +package pgtype_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func TestRecordTranscode(t *testing.T) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + tests := []struct { + sql string + expected pgtype.Record + }{ + { + sql: `select row()`, + expected: pgtype.Record{ + Fields: []pgtype.Value{}, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: 2, Status: pgtype.Present}, + pgtype.Int4{Status: pgtype.Null}, + pgtype.Int4{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row(null)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Unknown{Status: pgtype.Null}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select null::record`, + expected: pgtype.Record{ + Status: pgtype.Null, + }, + }, + } + + for i, tt := range tests { + psName := fmt.Sprintf("test%d", i) + ps, err := conn.Prepare(psName, tt.sql) + if err != nil { + t.Fatal(err) + } + ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode + + var result pgtype.Record + if err := conn.QueryRow(psName).Scan(&result); err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(tt.expected, result) { + t.Errorf("%d: expected %v, got %v", i, tt.expected, result) + } + } +} + +func TestRecordAssignTo(t *testing.T) { + var valueSlice []pgtype.Value + var interfaceSlice []interface{} + + simpleTests := []struct { + src pgtype.Record + dst interface{} + expected interface{} + }{ + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &valueSlice, + expected: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + }, + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &interfaceSlice, + expected: []interface{}{"foo", int32(42)}, + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &valueSlice, + expected: (([]pgtype.Value)(nil)), + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &interfaceSlice, + expected: (([]interface{})(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/text.go b/text.go index 3dd082c9..f1a76b6e 100644 --- a/text.go +++ b/text.go @@ -78,7 +78,7 @@ func (src *Text) AssignTo(dst interface{}) error { return nil } -func (dst *Text) DecodeText(src []byte) error { +func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Text{Status: Null} return nil @@ -88,11 +88,11 @@ func (dst *Text) DecodeText(src []byte) error { return nil } -func (dst *Text) DecodeBinary(src []byte) error { - return dst.DecodeText(src) +func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) } -func (src Text) EncodeText(w io.Writer) (bool, error) { +func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -104,6 +104,6 @@ func (src Text) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Text) EncodeBinary(w io.Writer) (bool, error) { - return src.EncodeText(w) +func (src Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.EncodeText(ci, w) } diff --git a/text_array.go b/text_array.go index 1e6677a9..6e89708f 100644 --- a/text_array.go +++ b/text_array.go @@ -83,7 +83,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { return nil } -func (dst *TextArray) DecodeText(src []byte) error { +func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TextArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *TextArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *TextArray) DecodeText(src []byte) error { return nil } -func (dst *TextArray) DecodeBinary(src []byte) error { +func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TextArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *TextArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *TextArray) DecodeBinary(src []byte) error { return nil } -func (src *TextArray) EncodeText(w io.Writer) (bool, error) { +func (src *TextArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TextArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TextOid) +func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TextOid) } -func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TextArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/tid.go b/tid.go index 20d962df..b91711d3 100644 --- a/tid.go +++ b/tid.go @@ -46,7 +46,7 @@ func (src *Tid) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v to %T", src, dst) } -func (dst *Tid) DecodeText(src []byte) error { +func (dst *Tid) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Tid{Status: Null} return nil @@ -75,7 +75,7 @@ func (dst *Tid) DecodeText(src []byte) error { return nil } -func (dst *Tid) DecodeBinary(src []byte) error { +func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Tid{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *Tid) DecodeBinary(src []byte) error { return nil } -func (src Tid) EncodeText(w io.Writer) (bool, error) { +func (src Tid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -105,7 +105,7 @@ func (src Tid) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Tid) EncodeBinary(w io.Writer) (bool, error) { +func (src Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/timestamp.go b/timestamp.go index 3bb8f080..9a9e74ea 100644 --- a/timestamp.go +++ b/timestamp.go @@ -85,7 +85,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { // DecodeText decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeText(src []byte) error { +func (dst *Timestamp) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamp{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Timestamp) DecodeText(src []byte) error { // DecodeBinary decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeBinary(src []byte) error { +func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamp{Status: Null} return nil @@ -139,7 +139,7 @@ func (dst *Timestamp) DecodeBinary(src []byte) error { // EncodeText writes the text encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeText(w io.Writer) (bool, error) { +func (src Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -167,7 +167,7 @@ func (src Timestamp) EncodeText(w io.Writer) (bool, error) { // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeBinary(w io.Writer) (bool, error) { +func (src Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/timestamp_array.go b/timestamp_array.go index c955dc42..064ad483 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -84,7 +84,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestampArray) DecodeText(src []byte) error { +func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestampArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *TimestampArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *TimestampArray) DecodeText(src []byte) error { return nil } -func (dst *TimestampArray) DecodeBinary(src []byte) error { +func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestampArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *TimestampArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *TimestampArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { +func (src *TimestampArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TimestampArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestampOid) +func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TimestampOid) } -func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TimestampArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, er } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, er for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/timestamptz.go b/timestamptz.go index 5b9f5038..7f57f4b7 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -84,7 +84,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { return nil } -func (dst *Timestamptz) DecodeText(src []byte) error { +func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamptz{Status: Null} return nil @@ -117,7 +117,7 @@ func (dst *Timestamptz) DecodeText(src []byte) error { return nil } -func (dst *Timestamptz) DecodeBinary(src []byte) error { +func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamptz{Status: Null} return nil @@ -143,7 +143,7 @@ func (dst *Timestamptz) DecodeBinary(src []byte) error { return nil } -func (src Timestamptz) EncodeText(w io.Writer) (bool, error) { +func (src Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Timestamptz) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Timestamptz) EncodeBinary(w io.Writer) (bool, error) { +func (src Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/timestamptz_array.go b/timestamptz_array.go index cd63e02e..4af1460b 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -84,7 +84,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestamptzArray) DecodeText(src []byte) error { +func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestamptzArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *TimestamptzArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *TimestamptzArray) DecodeText(src []byte) error { return nil } -func (dst *TimestamptzArray) DecodeBinary(src []byte) error { +func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestamptzArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *TimestamptzArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *TimestamptzArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { +func (src *TimestamptzArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TimestamptzArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestamptzOid) +func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TimestamptzOid) } -func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TimestamptzArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/typed_array.go.erb b/typed_array.go.erb index a56097c0..2a46a658 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -82,7 +82,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { +func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -118,14 +118,14 @@ func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { +func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -150,7 +150,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { elemSrc = src[rp:rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -160,7 +160,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { return nil } -func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { +func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -207,7 +207,7 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -236,11 +236,11 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, <%= element_oid %>) +func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, <%= element_oid %>) } -func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *<%= pgtype_array_type %>) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -260,7 +260,7 @@ func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -270,7 +270,7 @@ func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 41c1313f..5fde32aa 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -8,6 +8,8 @@ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_type erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4Oid text_null=NULL typed_array.go.erb > float4_array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8Oid text_null=NULL typed_array.go.erb > float8_array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOid text_null=NULL typed_array.go.erb > inet_array.go +erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_oid=CidrOid text_null=NULL typed_array.go.erb > cidr_array.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOid text_null='"NULL"' typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_oid=VarcharOid text_null='"NULL"' typed_array.go.erb > varchar_array.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOid text_null=NULL typed_array.go.erb > bytea_array.go erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_oid=AclitemOid text_null=NULL typed_array.go.erb > aclitem_array.go diff --git a/unknown.go b/unknown.go new file mode 100644 index 00000000..b951ad99 --- /dev/null +++ b/unknown.go @@ -0,0 +1,32 @@ +package pgtype + +// Unknown represents the PostgreSQL unknown type. It is either a string literal +// or NULL. It is used when PostgreSQL does not know the type of a value. In +// general, this will only be used in pgx when selecting a null value without +// type information. e.g. SELECT NULL; +type Unknown struct { + String string + Status Status +} + +func (dst *Unknown) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Unknown) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Unknown is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Unknown) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} diff --git a/varchar.go b/varchar.go new file mode 100644 index 00000000..adda6c49 --- /dev/null +++ b/varchar.go @@ -0,0 +1,40 @@ +package pgtype + +import ( + "io" +) + +type Varchar Text + +// Set converts from src to dst. Note that as Varchar is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *Varchar) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Varchar) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Varchar is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Varchar) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Varchar) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (src Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) +} + +func (src Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeBinary(ci, w) +} diff --git a/varchar_array.go b/varchar_array.go index 693b9a61..21e9ccff 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -1,35 +1,296 @@ package pgtype import ( + "bytes" + "encoding/binary" + "fmt" "io" + + "github.com/jackc/pgx/pgio" ) -type VarcharArray TextArray +type VarcharArray struct { + Elements []Varchar + Dimensions []ArrayDimension + Status Status +} func (dst *VarcharArray) Set(src interface{}) error { - return (*TextArray)(dst).Set(src) + switch value := src.(type) { + + case []string: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + elements := make([]Varchar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = VarcharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Varchar", value) + } + + return nil } func (dst *VarcharArray) Get() interface{} { - return (*TextArray)(dst).Get() + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } } func (src *VarcharArray) AssignTo(dst interface{}) error { - return (*TextArray)(src).AssignTo(dst) + switch v := dst.(type) { + + case *[]string: + if src.Status == Present { + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil } -func (dst *VarcharArray) DecodeText(src []byte) error { - return (*TextArray)(dst).DecodeText(src) +func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Varchar + + if len(uta.Elements) > 0 { + elements = make([]Varchar, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Varchar + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil } -func (dst *VarcharArray) DecodeBinary(src []byte) error { - return (*TextArray)(dst).DecodeBinary(src) +func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = VarcharArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Varchar, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil } -func (src *VarcharArray) EncodeText(w io.Writer) (bool, error) { - return (*TextArray)(src).EncodeText(w) +func (src *VarcharArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `"NULL"`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil } -func (src *VarcharArray) EncodeBinary(w io.Writer) (bool, error) { - return (*TextArray)(src).encodeBinary(w, VarcharOid) +func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, VarcharOid) +} + +func (src *VarcharArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + ElementOid: elementOid, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err } diff --git a/varchar_array_test.go b/varchar_array_test.go new file mode 100644 index 00000000..4a8b09b8 --- /dev/null +++ b/varchar_array_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestVarcharArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "varchar[]", []interface{}{ + &pgtype.VarcharArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "foo", Status: pgtype.Present}, + pgtype.Varchar{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{Status: pgtype.Null}, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "bar ", Status: pgtype.Present}, + pgtype.Varchar{String: "NuLL", Status: pgtype.Present}, + pgtype.Varchar{String: `wow"quz\`, Status: pgtype.Present}, + pgtype.Varchar{String: "", Status: pgtype.Present}, + pgtype.Varchar{Status: pgtype.Null}, + pgtype.Varchar{String: "null", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "bar", Status: pgtype.Present}, + pgtype.Varchar{String: "baz", Status: pgtype.Present}, + pgtype.Varchar{String: "quz", Status: pgtype.Present}, + pgtype.Varchar{String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestVarcharArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.VarcharArray + }{ + { + source: []string{"foo"}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.VarcharArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.VarcharArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestVarcharArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.VarcharArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.VarcharArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.VarcharArray + dst interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/xid.go b/xid.go index a53120de..c76548a4 100644 --- a/xid.go +++ b/xid.go @@ -37,18 +37,18 @@ func (src *Xid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Xid) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *Xid) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Xid) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *Xid) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Xid) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src Xid) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } From df8f8e17cfa493ed7c3b21199b7d685e32506661 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 12:40:54 -0500 Subject: [PATCH 0041/1158] Add pgtype.HstoreArray This required restructuring array types to lookup oid of element instead of hard-coding it due to hstore having a variable oid. --- bool_array.go | 11 +- bytea_array.go | 11 +- cidr_array.go | 11 +- date_array.go | 11 +- float4_array.go | 11 +- float8_array.go | 11 +- hstore_array.go | 297 +++++++++++++++++++++++++++++++++++++++++++ hstore_array_test.go | 183 ++++++++++++++++++++++++++ inet_array.go | 11 +- int2_array.go | 11 +- int4_array.go | 11 +- int8_array.go | 11 +- text_array.go | 11 +- timestamp_array.go | 11 +- timestamptz_array.go | 11 +- typed_array.go.erb | 11 +- typed_array_gen.sh | 31 ++--- varchar_array.go | 11 +- 18 files changed, 586 insertions(+), 90 deletions(-) create mode 100644 hstore_array.go create mode 100644 hstore_array_test.go diff --git a/bool_array.go b/bool_array.go index 1cb46cf6..6adfbb00 100644 --- a/bool_array.go +++ b/bool_array.go @@ -238,10 +238,6 @@ func (src *BoolArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, BoolOid) -} - -func (src *BoolArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *BoolArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("bool"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "bool") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/bytea_array.go b/bytea_array.go index 30405509..d318fa3b 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -238,10 +238,6 @@ func (src *ByteaArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, ByteaOid) -} - -func (src *ByteaArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *ByteaArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("bytea"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "bytea") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/cidr_array.go b/cidr_array.go index 32d2e7bf..3ab83ecd 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -270,10 +270,6 @@ func (src *CidrArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, CidrOid) -} - -func (src *CidrArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -282,10 +278,15 @@ func (src *CidrArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("cidr"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "cidr") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/date_array.go b/date_array.go index ba68d561..8bc8ff72 100644 --- a/date_array.go +++ b/date_array.go @@ -239,10 +239,6 @@ func (src *DateArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, DateOid) -} - -func (src *DateArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -251,10 +247,15 @@ func (src *DateArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("date"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "date") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/float4_array.go b/float4_array.go index 40152bcf..6abc1a31 100644 --- a/float4_array.go +++ b/float4_array.go @@ -238,10 +238,6 @@ func (src *Float4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Float4Oid) -} - -func (src *Float4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *Float4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32 } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("float4"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "float4") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/float8_array.go b/float8_array.go index d0ee0d70..050efa3f 100644 --- a/float8_array.go +++ b/float8_array.go @@ -238,10 +238,6 @@ func (src *Float8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Float8Oid) -} - -func (src *Float8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *Float8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32 } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("float8"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "float8") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/hstore_array.go b/hstore_array.go new file mode 100644 index 00000000..ba192462 --- /dev/null +++ b/hstore_array.go @@ -0,0 +1,297 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type HstoreArray struct { + Elements []Hstore + Dimensions []ArrayDimension + Status Status +} + +func (dst *HstoreArray) Set(src interface{}) error { + switch value := src.(type) { + + case []map[string]string: + if value == nil { + *dst = HstoreArray{Status: Null} + } else if len(value) == 0 { + *dst = HstoreArray{Status: Present} + } else { + elements := make([]Hstore, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = HstoreArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Hstore", value) + } + + return nil +} + +func (dst *HstoreArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *HstoreArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]map[string]string: + if src.Status == Present { + *v = make([]map[string]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = HstoreArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Hstore + + if len(uta.Elements) > 0 { + elements = make([]Hstore, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Hstore + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = HstoreArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = HstoreArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = HstoreArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Hstore, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = HstoreArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *HstoreArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil +} + +func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("hstore"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "hstore") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err +} diff --git a/hstore_array_test.go b/hstore_array_test.go new file mode 100644 index 00000000..e23c7b3b --- /dev/null +++ b/hstore_array_test.go @@ -0,0 +1,183 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func TestHstoreArrayTranscode(t *testing.T) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Status: pgtype.Present} + } + + values := []pgtype.Hstore{ + pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + pgtype.Hstore{Status: pgtype.Null}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + + // Special value values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + } + + src := pgtype.HstoreArray{ + Elements: values, + Dimensions: []pgtype.ArrayDimension{{Length: int32(len(values)), LowerBound: 1}}, + Status: pgtype.Present, + } + + ps, err := conn.Prepare("test", "select $1::hstore[]") + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + vEncoder := forceEncoder(src, fc.formatCode) + if vEncoder == nil { + t.Logf("%#v does not implement %v", src, fc.name) + continue + } + + var result pgtype.HstoreArray + err := conn.QueryRow("test", vEncoder).Scan(&result) + if err != nil { + t.Errorf("%v: %v", fc.name, err) + continue + } + + if result.Status != src.Status { + t.Errorf("%v: expected Status %v, got %v", fc.formatCode, src.Status, result.Status) + continue + } + + if len(result.Elements) != len(src.Elements) { + t.Errorf("%v: expected %v elements, got %v", fc.formatCode, len(src.Elements), len(result.Elements)) + continue + } + + for i := range result.Elements { + a := src.Elements[i] + b := result.Elements[i] + + if a.Status != b.Status { + t.Errorf("%v element idx %d: expected status %v, got %v", fc.formatCode, i, a.Status, b.Status) + } + + if len(a.Map) != len(b.Map) { + t.Errorf("%v element idx %d: expected %v pairs, got %v", fc.formatCode, i, len(a.Map), len(b.Map)) + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + t.Errorf("%v element idx %d: expected key %v to be %v, got %v", fc.formatCode, i, k, a.Map[k], b.Map[k]) + } + } + } + } +} + +func TestHstoreArraySet(t *testing.T) { + successfulTests := []struct { + src []map[string]string + result pgtype.HstoreArray + }{ + { + src: []map[string]string{map[string]string{"foo": "bar"}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + }, + } + + for i, tt := range successfulTests { + var dst pgtype.HstoreArray + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + +func TestHstoreArrayAssignTo(t *testing.T) { + var m []map[string]string + + simpleTests := []struct { + src pgtype.HstoreArray + dst *[]map[string]string + expected []map[string]string + }{ + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &m, + expected: []map[string]string{{"foo": "bar"}}}, + {src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &m, expected: (([]map[string]string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(*tt.dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/inet_array.go b/inet_array.go index 6cad82e7..d893a724 100644 --- a/inet_array.go +++ b/inet_array.go @@ -270,10 +270,6 @@ func (src *InetArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, InetOid) -} - -func (src *InetArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -282,10 +278,15 @@ func (src *InetArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("inet"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "inet") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/int2_array.go b/int2_array.go index 2bf1c237..b93a4fa3 100644 --- a/int2_array.go +++ b/int2_array.go @@ -269,10 +269,6 @@ func (src *Int2Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Int2Oid) -} - -func (src *Int2Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -281,10 +277,15 @@ func (src *Int2Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("int2"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "int2") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/int4_array.go b/int4_array.go index dda88eaf..0b96b7a4 100644 --- a/int4_array.go +++ b/int4_array.go @@ -269,10 +269,6 @@ func (src *Int4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Int4Oid) -} - -func (src *Int4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -281,10 +277,15 @@ func (src *Int4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("int4"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "int4") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/int8_array.go b/int8_array.go index 468c126b..02a240f4 100644 --- a/int8_array.go +++ b/int8_array.go @@ -269,10 +269,6 @@ func (src *Int8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Int8Oid) -} - -func (src *Int8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -281,10 +277,15 @@ func (src *Int8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("int8"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "int8") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/text_array.go b/text_array.go index 6e89708f..9f25727e 100644 --- a/text_array.go +++ b/text_array.go @@ -238,10 +238,6 @@ func (src *TextArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, TextOid) -} - -func (src *TextArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *TextArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("text"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "text") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/timestamp_array.go b/timestamp_array.go index 064ad483..bb19e502 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -239,10 +239,6 @@ func (src *TimestampArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, TimestampOid) -} - -func (src *TimestampArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -251,10 +247,15 @@ func (src *TimestampArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid in } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("timestamp"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "timestamp") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/timestamptz_array.go b/timestamptz_array.go index 4af1460b..6a85cefa 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -239,10 +239,6 @@ func (src *TimestamptzArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) } func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, TimestamptzOid) -} - -func (src *TimestamptzArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -251,10 +247,15 @@ func (src *TimestamptzArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("timestamptz"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "timestamptz") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/typed_array.go.erb b/typed_array.go.erb index 2a46a658..2b81666e 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -237,10 +237,6 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool } func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, <%= element_oid %>) -} - -func (src *<%= pgtype_array_type %>) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -249,10 +245,15 @@ func (src *<%= pgtype_array_type %>) encodeBinary(ci *ConnInfo, w io.Writer, ele } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 5fde32aa..166f8802 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,15 +1,16 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2Oid text_null=NULL typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4Oid text_null=NULL typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8Oid text_null=NULL typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOid text_null=NULL typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOid text_null=NULL typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOid text_null=NULL typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOid text_null=NULL typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4Oid text_null=NULL typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8Oid text_null=NULL typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOid text_null=NULL typed_array.go.erb > inet_array.go -erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_oid=CidrOid text_null=NULL typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOid text_null='"NULL"' typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_oid=VarcharOid text_null='"NULL"' typed_array.go.erb > varchar_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOid text_null=NULL typed_array.go.erb > bytea_array.go -erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_oid=AclitemOid text_null=NULL typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_type_name=int2 text_null=NULL typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_type_name=int4 text_null=NULL typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_type_name=int8 text_null=NULL typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_type_name=timestamptz text_null=NULL typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_type_name=timestamp text_null=NULL typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL typed_array.go.erb > inet_array.go +erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' typed_array.go.erb > varchar_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL typed_array.go.erb > bytea_array.go +erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL typed_array.go.erb > hstore_array.go diff --git a/varchar_array.go b/varchar_array.go index 21e9ccff..158ece94 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -238,10 +238,6 @@ func (src *VarcharArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, VarcharOid) -} - -func (src *VarcharArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *VarcharArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int3 } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("varchar"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "varchar") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true From d516894475a21d614652a4d25af635c32cd01654 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 14:42:36 -0500 Subject: [PATCH 0042/1158] Simplify []byte scanning --- text.go | 10 ++++++++++ text_test.go | 27 +++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/text.go b/text.go index f1a76b6e..af7f16fc 100644 --- a/text.go +++ b/text.go @@ -49,6 +49,16 @@ func (src *Text) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v to %T", src, dst) } *v = src.String + case *[]byte: + switch src.Status { + case Present: + *v = make([]byte, len(src.String)) + copy(*v, src.String) + case Null: + *v = nil + default: + return fmt.Errorf("unknown status") + } default: if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { el := v.Elem() diff --git a/text_test.go b/text_test.go index 39348bcc..34b6a784 100644 --- a/text_test.go +++ b/text_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "bytes" "reflect" "testing" @@ -44,7 +45,7 @@ func TestTextAssignTo(t *testing.T) { var s string var ps *string - simpleTests := []struct { + stringTests := []struct { src pgtype.Text dst interface{} expected interface{} @@ -53,7 +54,7 @@ func TestTextAssignTo(t *testing.T) { {src: pgtype.Text{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, } - for i, tt := range simpleTests { + for i, tt := range stringTests { err := tt.src.AssignTo(tt.dst) if err != nil { t.Errorf("%d: %v", i, err) @@ -64,6 +65,28 @@ func TestTextAssignTo(t *testing.T) { } } + var buf []byte + + bytesTests := []struct { + src pgtype.Text + dst *[]byte + expected []byte + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &buf, expected: []byte("foo")}, + {src: pgtype.Text{Status: pgtype.Null}, dst: &buf, expected: nil}, + } + + for i, tt := range bytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(*tt.dst, tt.expected) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) + } + } + pointerAllocTests := []struct { src pgtype.Text dst interface{} From 0f92da1f24d28f4c90a708967ccf59c60def0847 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 15:51:16 -0500 Subject: [PATCH 0043/1158] Remove unneeded idea file --- extra-interface.txt | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 extra-interface.txt diff --git a/extra-interface.txt b/extra-interface.txt deleted file mode 100644 index f07818bc..00000000 --- a/extra-interface.txt +++ /dev/null @@ -1,3 +0,0 @@ -Can pass function to get inet data and function to get oid/name mapping as optional interface with io.Reader or io.Writer - -Could be useful for arrays of types without defined Oids like hstore. From 85f7df1e81e85f1355469da32ed0cc5d4632e7d6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 16:54:08 -0500 Subject: [PATCH 0044/1158] Factor out duplication in AssignTo --- aclitem.go | 42 +++++------------- aclitem_array.go | 23 +++++----- bool.go | 42 +++++------------- bool_array.go | 23 +++++----- bytea.go | 43 ++++++------------ bytea_array.go | 23 +++++----- cidr_array.go | 30 ++++++------- convert.go | 102 +++++++++++++++++++++++++++++++++---------- date.go | 39 +++++++---------- date_array.go | 23 +++++----- float4_array.go | 23 +++++----- float8_array.go | 23 +++++----- hstore.go | 21 ++++----- hstore_array.go | 23 +++++----- inet.go | 44 ++++++------------- inet_array.go | 30 ++++++------- int2_array.go | 30 ++++++------- int4_array.go | 30 ++++++------- int8_array.go | 30 ++++++------- record.go | 31 ++++++------- text.go | 50 ++++++--------------- text_array.go | 23 +++++----- timestamp.go | 39 +++++++---------- timestamp_array.go | 23 +++++----- timestamptz.go | 39 +++++++---------- timestamptz_array.go | 23 +++++----- typed_array.go.erb | 27 ++++++------ varchar_array.go | 23 +++++----- 28 files changed, 430 insertions(+), 492 deletions(-) diff --git a/aclitem.go b/aclitem.go index f9faab20..e8386ae7 100644 --- a/aclitem.go +++ b/aclitem.go @@ -3,7 +3,6 @@ package pgtype import ( "fmt" "io" - "reflect" ) // Aclitem is used for PostgreSQL's aclitem data type. A sample aclitem @@ -55,39 +54,22 @@ func (dst *Aclitem) Get() interface{} { } func (src *Aclitem) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *string: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.String - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) - case reflect.String: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - el.SetString(src.String) - return nil + switch src.Status { + case Present: + switch v := dst.(type) { + case *string: + *v = src.String + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/aclitem_array.go b/aclitem_array.go index f02d339e..1c97e74f 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -58,28 +58,29 @@ func (dst *AclitemArray) Get() interface{} { } func (src *AclitemArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]string: - if src.Status == Present { + case *[]string: *v = make([]string, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/bool.go b/bool.go index 87316381..608a6f95 100644 --- a/bool.go +++ b/bool.go @@ -3,7 +3,6 @@ package pgtype import ( "fmt" "io" - "reflect" "strconv" ) @@ -44,39 +43,22 @@ func (dst *Bool) Get() interface{} { } func (src *Bool) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *bool: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Bool - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) - case reflect.Bool: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - el.SetBool(src.Bool) - return nil + switch src.Status { + case Present: + switch v := dst.(type) { + case *bool: + *v = src.Bool + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/bool_array.go b/bool_array.go index 6adfbb00..cdfe9685 100644 --- a/bool_array.go +++ b/bool_array.go @@ -59,28 +59,29 @@ func (dst *BoolArray) Get() interface{} { } func (src *BoolArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]bool: - if src.Status == Present { + case *[]bool: *v = make([]bool, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/bytea.go b/bytea.go index dc1e9c07..00bed8e8 100644 --- a/bytea.go +++ b/bytea.go @@ -4,7 +4,6 @@ import ( "encoding/hex" "fmt" "io" - "reflect" ) type Bytea struct { @@ -42,38 +41,24 @@ func (dst *Bytea) Get() interface{} { } func (src *Bytea) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *[]byte: - if src.Status == Present { - *v = src.Bytes - } else { - *v = nil - } - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) - } + switch src.Status { + case Present: + switch v := dst.(type) { + case *[]byte: + buf := make([]byte, len(src.Bytes)) + copy(buf, src.Bytes) + *v = buf + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } // DecodeText only supports the hex format. This has been the default since diff --git a/bytea_array.go b/bytea_array.go index d318fa3b..175ca2f6 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -59,28 +59,29 @@ func (dst *ByteaArray) Get() interface{} { } func (src *ByteaArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[][]byte: - if src.Status == Present { + case *[][]byte: *v = make([][]byte, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/cidr_array.go b/cidr_array.go index 3ab83ecd..49a2728b 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -79,40 +79,38 @@ func (dst *CidrArray) Get() interface{} { } func (src *CidrArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]*net.IPNet: - if src.Status == Present { + case *[]*net.IPNet: *v = make([]*net.IPNet, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]net.IP: - if src.Status == Present { + case *[]net.IP: *v = make([]net.IP, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/convert.go b/convert.go index 648209f5..4fba8430 100644 --- a/convert.go +++ b/convert.go @@ -184,28 +184,6 @@ func underlyingSliceType(val interface{}) (interface{}, bool) { return nil, false } -func underlyingPtrSliceType(val interface{}) (interface{}, bool) { - refVal := reflect.ValueOf(val) - - if refVal.Kind() != reflect.Ptr { - return nil, false - } - if refVal.IsNil() { - return nil, false - } - - sliceVal := refVal.Elem().Interface() - baseSliceType := reflect.SliceOf(reflect.TypeOf(sliceVal).Elem()) - ptrBaseSliceType := reflect.PtrTo(baseSliceType) - - if refVal.Type().ConvertibleTo(ptrBaseSliceType) { - convVal := refVal.Convert(ptrBaseSliceType) - return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() - } - - return nil, false -} - func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { if srcStatus == Present { switch v := dst.(type) { @@ -363,3 +341,83 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } + +func nullAssignTo(dst interface{}) error { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return fmt.Errorf("cannot assign NULL to %T", dst) + } + + dstVal := dstPtr.Elem() + + switch dstVal.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map: + dstVal.Set(reflect.Zero(dstVal.Type())) + return nil + } + + return fmt.Errorf("cannot assign NULL to %T", dst) +} + +var kindTypes map[reflect.Kind]reflect.Type + +// GetAssignToDstType attempts to convert dst to something AssignTo can assign +// to. If dst is a pointer to pointer it allocates a value and returns the +// dereferences pointer. If dst is a named type such as *Foo where Foo is type +// Foo int16, it converts dst to *int16. +// +// GetAssignToDstType returns the converted dst and a bool representing if any +// change was made. +func GetAssignToDstType(dst interface{}) (interface{}, bool) { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return nil, false + } + + dstVal := dstPtr.Elem() + + // if dst is a pointer to pointer, allocate space try again with the dereferenced pointer + if dstVal.Kind() == reflect.Ptr { + dstVal.Set(reflect.New(dstVal.Type().Elem())) + return dstVal.Interface(), true + } + + // if dst is pointer to a base type that has been renamed + if baseValType, ok := kindTypes[dstVal.Kind()]; ok { + nextDst := dstPtr.Convert(reflect.PtrTo(baseValType)) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + + if dstVal.Kind() == reflect.Slice { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + baseSliceType := reflect.PtrTo(reflect.SliceOf(baseElemType)) + nextDst := dstPtr.Convert(baseSliceType) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + } + + return nil, false +} + +func init() { + kindTypes = map[reflect.Kind]reflect.Type{ + reflect.Bool: reflect.TypeOf(false), + reflect.Float32: reflect.TypeOf(float32(0)), + reflect.Float64: reflect.TypeOf(float64(0)), + reflect.Int: reflect.TypeOf(int(0)), + reflect.Int8: reflect.TypeOf(int8(0)), + reflect.Int16: reflect.TypeOf(int16(0)), + reflect.Int32: reflect.TypeOf(int32(0)), + reflect.Int64: reflect.TypeOf(int64(0)), + reflect.Uint: reflect.TypeOf(uint(0)), + reflect.Uint8: reflect.TypeOf(uint8(0)), + reflect.Uint16: reflect.TypeOf(uint16(0)), + reflect.Uint32: reflect.TypeOf(uint32(0)), + reflect.Uint64: reflect.TypeOf(uint64(0)), + reflect.String: reflect.TypeOf(""), + } +} diff --git a/date.go b/date.go index b6cc8329..ab854eb2 100644 --- a/date.go +++ b/date.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "fmt" "io" - "reflect" "time" "github.com/jackc/pgx/pgio" @@ -50,33 +49,25 @@ func (dst *Date) Get() interface{} { } func (src *Date) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *time.Time: - if src.Status != Present || src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/date_array.go b/date_array.go index 8bc8ff72..bf791677 100644 --- a/date_array.go +++ b/date_array.go @@ -60,28 +60,29 @@ func (dst *DateArray) Get() interface{} { } func (src *DateArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]time.Time: - if src.Status == Present { + case *[]time.Time: *v = make([]time.Time, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/float4_array.go b/float4_array.go index 6abc1a31..b4d05c55 100644 --- a/float4_array.go +++ b/float4_array.go @@ -59,28 +59,29 @@ func (dst *Float4Array) Get() interface{} { } func (src *Float4Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]float32: - if src.Status == Present { + case *[]float32: *v = make([]float32, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/float8_array.go b/float8_array.go index 050efa3f..e000807e 100644 --- a/float8_array.go +++ b/float8_array.go @@ -59,28 +59,29 @@ func (dst *Float8Array) Get() interface{} { } func (src *Float8Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]float64: - if src.Status == Present { + case *[]float64: *v = make([]float64, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/hstore.go b/hstore.go index d771d6e6..8dc5b4d8 100644 --- a/hstore.go +++ b/hstore.go @@ -47,10 +47,10 @@ func (dst *Hstore) Get() interface{} { } func (src *Hstore) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *map[string]string: - switch src.Status { - case Present: + switch src.Status { + case Present: + switch v := dst.(type) { + case *map[string]string: *v = make(map[string]string, len(src.Map)) for k, val := range src.Map { if val.Status != Present { @@ -58,16 +58,17 @@ func (src *Hstore) AssignTo(dst interface{}) error { } (*v)[k] = val.String } - case Null: - *v = nil + return nil default: - return fmt.Errorf("cannot decode %v into %T", src, dst) + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - default: - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/hstore_array.go b/hstore_array.go index ba192462..9bd0ed3b 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -59,28 +59,29 @@ func (dst *HstoreArray) Get() interface{} { } func (src *HstoreArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]map[string]string: - if src.Status == Present { + case *[]map[string]string: *v = make([]map[string]string, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/inet.go b/inet.go index b83bd1c9..13764814 100644 --- a/inet.go +++ b/inet.go @@ -4,7 +4,6 @@ import ( "fmt" "io" "net" - "reflect" "github.com/jackc/pgx/pgio" ) @@ -61,43 +60,28 @@ func (dst *Inet) Get() interface{} { } func (src *Inet) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *net.IPNet: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = *src.IPNet - case *net.IP: - if src.Status == Present { - + switch src.Status { + case Present: + switch v := dst.(type) { + case *net.IPNet: + *v = *src.IPNet + return nil + case *net.IP: if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { return fmt.Errorf("cannot assign %v to %T", src, dst) } *v = src.IPNet.IP - } else { - *v = nil - } - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/inet_array.go b/inet_array.go index d893a724..1988a145 100644 --- a/inet_array.go +++ b/inet_array.go @@ -79,40 +79,38 @@ func (dst *InetArray) Get() interface{} { } func (src *InetArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]*net.IPNet: - if src.Status == Present { + case *[]*net.IPNet: *v = make([]*net.IPNet, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]net.IP: - if src.Status == Present { + case *[]net.IP: *v = make([]net.IP, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/int2_array.go b/int2_array.go index b93a4fa3..531e7dd6 100644 --- a/int2_array.go +++ b/int2_array.go @@ -78,40 +78,38 @@ func (dst *Int2Array) Get() interface{} { } func (src *Int2Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]int16: - if src.Status == Present { + case *[]int16: *v = make([]int16, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]uint16: - if src.Status == Present { + case *[]uint16: *v = make([]uint16, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/int4_array.go b/int4_array.go index 0b96b7a4..3617050f 100644 --- a/int4_array.go +++ b/int4_array.go @@ -78,40 +78,38 @@ func (dst *Int4Array) Get() interface{} { } func (src *Int4Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]int32: - if src.Status == Present { + case *[]int32: *v = make([]int32, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]uint32: - if src.Status == Present { + case *[]uint32: *v = make([]uint32, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/int8_array.go b/int8_array.go index 02a240f4..4f04b660 100644 --- a/int8_array.go +++ b/int8_array.go @@ -78,40 +78,38 @@ func (dst *Int8Array) Get() interface{} { } func (src *Int8Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]int64: - if src.Status == Present { + case *[]int64: *v = make([]int64, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]uint64: - if src.Status == Present { + case *[]uint64: *v = make([]uint64, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/record.go b/record.go index 1bfd05b9..89e081ca 100644 --- a/record.go +++ b/record.go @@ -38,34 +38,29 @@ func (dst *Record) Get() interface{} { } func (src *Record) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *[]Value: - switch src.Status { - case Present: + switch src.Status { + case Present: + switch v := dst.(type) { + case *[]Value: *v = make([]Value, len(src.Fields)) copy(*v, src.Fields) - case Null: - *v = nil - default: - return fmt.Errorf("cannot decode %v into %T", src, dst) - } - case *[]interface{}: - switch src.Status { - case Present: + return nil + case *[]interface{}: *v = make([]interface{}, len(src.Fields)) for i := range *v { (*v)[i] = src.Fields[i].Get() } - case Null: - *v = nil + return nil default: - return fmt.Errorf("cannot decode %v into %T", src, dst) + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - default: - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { diff --git a/text.go b/text.go index af7f16fc..dbc9362b 100644 --- a/text.go +++ b/text.go @@ -3,7 +3,6 @@ package pgtype import ( "fmt" "io" - "reflect" ) type Text struct { @@ -43,49 +42,26 @@ func (dst *Text) Get() interface{} { } func (src *Text) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *string: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.String - case *[]byte: - switch src.Status { - case Present: + switch src.Status { + case Present: + switch v := dst.(type) { + case *string: + *v = src.String + return nil + case *[]byte: *v = make([]byte, len(src.String)) copy(*v, src.String) - case Null: - *v = nil + return nil default: - return fmt.Errorf("unknown status") - } - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) - case reflect.String: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - el.SetString(src.String) - return nil + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/text_array.go b/text_array.go index 9f25727e..6e8ead26 100644 --- a/text_array.go +++ b/text_array.go @@ -59,28 +59,29 @@ func (dst *TextArray) Get() interface{} { } func (src *TextArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]string: - if src.Status == Present { + case *[]string: *v = make([]string, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/timestamp.go b/timestamp.go index 9a9e74ea..4b42f3cf 100644 --- a/timestamp.go +++ b/timestamp.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "fmt" "io" - "reflect" "time" "github.com/jackc/pgx/pgio" @@ -54,33 +53,25 @@ func (dst *Timestamp) Get() interface{} { } func (src *Timestamp) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *time.Time: - if src.Status != Present || src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot assign %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } // DecodeText decodes from src into dst. The decoded time is considered to diff --git a/timestamp_array.go b/timestamp_array.go index bb19e502..6a6950c7 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -60,28 +60,29 @@ func (dst *TimestampArray) Get() interface{} { } func (src *TimestampArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]time.Time: - if src.Status == Present { + case *[]time.Time: *v = make([]time.Time, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/timestamptz.go b/timestamptz.go index 7f57f4b7..ba849ac8 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "fmt" "io" - "reflect" "time" "github.com/jackc/pgx/pgio" @@ -55,33 +54,25 @@ func (dst *Timestamptz) Get() interface{} { } func (src *Timestamptz) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *time.Time: - if src.Status != Present || src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot assign %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/timestamptz_array.go b/timestamptz_array.go index 6a85cefa..347d0b8b 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -60,28 +60,29 @@ func (dst *TimestamptzArray) Get() interface{} { } func (src *TimestamptzArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]time.Time: - if src.Status == Present { + case *[]time.Time: *v = make([]time.Time, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/typed_array.go.erb b/typed_array.go.erb index 2b81666e..26c4671c 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -58,28 +58,29 @@ func (dst *<%= pgtype_array_type %>) Get() interface{} { } func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { - switch v := dst.(type) { - <% go_array_types.split(",").each do |t| %> - case *<%= t %>: - if src.Status == Present { + switch src.Status { + case Present: + switch v := dst.(type) { + <% go_array_types.split(",").each do |t| %> + case *<%= t %>: *v = make(<%= t %>, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil + return nil + <% end %> + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - <% end %> - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) - } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/varchar_array.go b/varchar_array.go index 158ece94..e1dd3910 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -59,28 +59,29 @@ func (dst *VarcharArray) Get() interface{} { } func (src *VarcharArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]string: - if src.Status == Present { + case *[]string: *v = make([]string, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { From 3acd3d8546eda21913a199dd0908bb96c2b81789 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 17:38:52 -0500 Subject: [PATCH 0045/1158] Optionally generate binary array format --- typed_array.go.erb | 92 ++++++++++++++++++++++++---------------------- typed_array_gen.sh | 32 ++++++++-------- 2 files changed, 64 insertions(+), 60 deletions(-) diff --git a/typed_array.go.erb b/typed_array.go.erb index 26c4671c..0e5725ce 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -119,6 +119,7 @@ func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error return nil } +<% if binary_format == "true" %> func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} @@ -160,6 +161,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) erro *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} return nil } +<% end %> func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { @@ -237,61 +239,63 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool return false, nil } -func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - switch src.Status { - case Null: - return true, nil - case Undefined: - return false, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { - arrayHeader.ElementOid = int32(dt.Oid) - } else { - return false, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break +<% if binary_format == "true" %> + func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } - elemBuf := &bytes.Buffer{} + if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") + } - for i := range src.Elements { - elemBuf.Reset() + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } - if null { - _, err = pgio.WriteInt32(w, -1) + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - } - return false, err -} + return false, err + } +<% end %> diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 166f8802..d77c8ca3 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,16 +1,16 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_type_name=int2 text_null=NULL typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_type_name=int4 text_null=NULL typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_type_name=int8 text_null=NULL typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_type_name=timestamptz text_null=NULL typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_type_name=timestamp text_null=NULL typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL typed_array.go.erb > inet_array.go -erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' typed_array.go.erb > varchar_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL typed_array.go.erb > bytea_array.go -erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL typed_array.go.erb > aclitem_array.go -erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL typed_array.go.erb > hstore_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go +erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' binary_format=true typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' binary_format=true typed_array.go.erb > varchar_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go +erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go From 6f9ef694d0322f02a47d378cd830e67821e59a2f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 21:11:43 -0500 Subject: [PATCH 0046/1158] Add database/sql support to pgtype --- aclitem.go | 30 ++++++++++++++++++++++ aclitem_array.go | 31 ++++++++++++++++++++++ bool.go | 33 ++++++++++++++++++++++++ bool_array.go | 31 ++++++++++++++++++++++ bytea.go | 38 +++++++++++++++++++++++++++ bytea_array.go | 31 ++++++++++++++++++++++ cid.go | 11 ++++++++ cidr_array.go | 31 ++++++++++++++++++++++ database_sql.go | 52 +++++++++++-------------------------- date.go | 47 +++++++++++++++++++++++++++++++--- date_array.go | 31 ++++++++++++++++++++++ date_test.go | 7 ++++- float4.go | 38 +++++++++++++++++++++++++++ float4_array.go | 31 ++++++++++++++++++++++ float8.go | 38 +++++++++++++++++++++++++++ float8_array.go | 31 ++++++++++++++++++++++ generic_binary.go | 11 ++++++++ generic_text.go | 11 ++++++++ hstore.go | 28 ++++++++++++++++++++ hstore_array.go | 31 ++++++++++++++++++++++ inet.go | 28 ++++++++++++++++++++ inet_array.go | 31 ++++++++++++++++++++++ int2.go | 44 ++++++++++++++++++++++++++++++++ int2_array.go | 31 ++++++++++++++++++++++ int4.go | 46 ++++++++++++++++++++++++++++++++- int4_array.go | 31 ++++++++++++++++++++++ int8.go | 38 +++++++++++++++++++++++++++ int8_array.go | 31 ++++++++++++++++++++++ json.go | 36 ++++++++++++++++++++++++++ jsonb.go | 11 ++++++++ name.go | 11 ++++++++ oid.go | 25 ++++++++++++++++++ oid_value.go | 11 ++++++++ pgtype.go | 13 ++++++++++ pgtype_test.go | 61 +++++++++++++++++++++++++++++++++++++++++++- pguint32.go | 45 ++++++++++++++++++++++++++++++++ qchar.go | 9 ++++++- qchar_test.go | 4 ++- record.go | 5 ++++ text.go | 41 +++++++++++++++++++++++++++++ text_array.go | 31 ++++++++++++++++++++++ tid.go | 23 +++++++++++++++++ timestamp.go | 47 +++++++++++++++++++++++++++++++--- timestamp_array.go | 31 ++++++++++++++++++++++ timestamptz.go | 47 +++++++++++++++++++++++++++++++--- timestamptz_array.go | 31 ++++++++++++++++++++++ typed_array.go.erb | 30 ++++++++++++++++++++++ unknown.go | 12 +++++++++ varchar.go | 11 ++++++++ varchar_array.go | 31 ++++++++++++++++++++++ xid.go | 11 ++++++++ 51 files changed, 1398 insertions(+), 51 deletions(-) diff --git a/aclitem.go b/aclitem.go index e8386ae7..77e385e6 100644 --- a/aclitem.go +++ b/aclitem.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" ) @@ -93,3 +94,32 @@ func (src Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { _, err := io.WriteString(w, src.String) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Aclitem) Scan(src interface{}) error { + if src == nil { + *dst = Aclitem{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Aclitem) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.String, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/aclitem_array.go b/aclitem_array.go index 1c97e74f..20a7636a 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "fmt" "io" @@ -194,3 +195,33 @@ func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } + +// Scan implements the database/sql Scanner interface. +func (dst *AclitemArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *AclitemArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/bool.go b/bool.go index 608a6f95..736d19cf 100644 --- a/bool.go +++ b/bool.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" "strconv" @@ -126,3 +127,35 @@ func (src Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := w.Write(buf) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Bool) Scan(src interface{}) error { + if src == nil { + *dst = Bool{Status: Null} + return nil + } + + switch src := src.(type) { + case bool: + *dst = Bool{Bool: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bool) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bool, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/bool_array.go b/bool_array.go index cdfe9685..4705d734 100644 --- a/bool_array.go +++ b/bool_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *BoolArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *BoolArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/bytea.go b/bytea.go index 00bed8e8..9f0266e7 100644 --- a/bytea.go +++ b/bytea.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/hex" "fmt" "io" @@ -12,6 +13,11 @@ type Bytea struct { } func (dst *Bytea) Set(src interface{}) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + switch value := src.(type) { case []byte: if value != nil { @@ -124,3 +130,35 @@ func (src Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := w.Write(src.Bytes) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Bytea) Scan(src interface{}) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + buf := make([]byte, len(src)) + copy(buf, src) + *dst = Bytea{Bytes: buf, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bytea) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/bytea_array.go b/bytea_array.go index 175ca2f6..268364c1 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *ByteaArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *ByteaArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/cid.go b/cid.go index d86e8063..63ba6a2f 100644 --- a/cid.go +++ b/cid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -49,3 +50,13 @@ func (src Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Cid) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Cid) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/cidr_array.go b/cidr_array.go index 49a2728b..6643bb47 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -325,3 +326,33 @@ func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *CidrArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *CidrArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/database_sql.go b/database_sql.go index 969d6542..2ddd842d 100644 --- a/database_sql.go +++ b/database_sql.go @@ -2,47 +2,13 @@ package pgtype import ( "bytes" + "database/sql/driver" "errors" ) func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { - switch src := src.(type) { - case *Bool: - return src.Bool, nil - case *Bytea: - return src.Bytes, nil - case *Date: - if src.InfinityModifier == None { - return src.Time, nil - } - case *Float4: - return float64(src.Float), nil - case *Float8: - return src.Float, nil - case *GenericBinary: - return src.Bytes, nil - case *GenericText: - return src.String, nil - case *Int2: - return int64(src.Int), nil - case *Int4: - return int64(src.Int), nil - case *Int8: - return int64(src.Int), nil - case *Text: - return src.String, nil - case *Timestamp: - if src.InfinityModifier == None { - return src.Time, nil - } - case *Timestamptz: - if src.InfinityModifier == None { - return src.Time, nil - } - case *Unknown: - return src.String, nil - case *Varchar: - return src.String, nil + if valuer, ok := src.(driver.Valuer); ok { + return valuer.Value() } buf := &bytes.Buffer{} @@ -64,3 +30,15 @@ func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { return nil, errors.New("cannot convert to database/sql compatible value") } + +func encodeValueText(src TextEncoder) (interface{}, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + return buf.String(), err +} diff --git a/date.go b/date.go index ab854eb2..7dd2c4f0 100644 --- a/date.go +++ b/date.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -10,9 +11,9 @@ import ( ) type Date struct { - Time time.Time - Status Status - InfinityModifier + Time time.Time + Status Status + InfinityModifier InfinityModifier } const ( @@ -21,6 +22,11 @@ const ( ) func (dst *Date) Set(src interface{}) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + switch value := src.(type) { case time.Time: *dst = Date{Time: value, Status: Present} @@ -167,3 +173,38 @@ func (src Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt32(w, daysSinceDateEpoch) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Date) Scan(src interface{}) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + case time.Time: + *dst = Date{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Date) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/date_array.go b/date_array.go index bf791677..f58de011 100644 --- a/date_array.go +++ b/date_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -297,3 +298,33 @@ func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *DateArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *DateArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/date_test.go b/date_test.go index cfc3dd70..1832b5b4 100644 --- a/date_test.go +++ b/date_test.go @@ -9,7 +9,7 @@ import ( ) func TestDateTranscode(t *testing.T) { - testSuccessfulTranscode(t, "date", []interface{}{ + testSuccessfulTranscodeEqFunc(t, "date", []interface{}{ pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, @@ -19,6 +19,11 @@ func TestDateTranscode(t *testing.T) { pgtype.Date{Status: pgtype.Null}, pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Date) + bt := b.(pgtype.Date) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier }) } diff --git a/float4.go b/float4.go index 94b7b7a1..e92149a6 100644 --- a/float4.go +++ b/float4.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Float4 struct { } func (dst *Float4) Set(src interface{}) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + switch value := src.(type) { case float32: *dst = Float4{Float: value, Status: Present} @@ -156,3 +162,35 @@ func (src Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt32(w, int32(math.Float32bits(src.Float))) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float4) Scan(src interface{}) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float4{Float: float32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float4) Value() (driver.Value, error) { + switch src.Status { + case Present: + return float64(src.Float), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/float4_array.go b/float4_array.go index b4d05c55..b9ee4b9e 100644 --- a/float4_array.go +++ b/float4_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Float4Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/float8.go b/float8.go index dd2d592d..4d094757 100644 --- a/float8.go +++ b/float8.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Float8 struct { } func (dst *Float8) Set(src interface{}) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + switch value := src.(type) { case float32: *dst = Float8{Float: float64(value), Status: Present} @@ -146,3 +152,35 @@ func (src Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, int64(math.Float64bits(src.Float))) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float8) Scan(src interface{}) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float8{Float: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float8) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Float, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/float8_array.go b/float8_array.go index e000807e..d49f18a7 100644 --- a/float8_array.go +++ b/float8_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Float8Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/generic_binary.go b/generic_binary.go index aa28bb62..f834bfb2 100644 --- a/generic_binary.go +++ b/generic_binary.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -27,3 +28,13 @@ func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { func (src GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (Bytea)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *GenericBinary) Scan(src interface{}) error { + return (*Bytea)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src GenericBinary) Value() (driver.Value, error) { + return (Bytea)(src).Value() +} diff --git a/generic_text.go b/generic_text.go index bd75e0d0..053ec504 100644 --- a/generic_text.go +++ b/generic_text.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -27,3 +28,13 @@ func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { func (src GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return (Text)(src).EncodeText(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *GenericText) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src GenericText) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/hstore.go b/hstore.go index 8dc5b4d8..b8b0c6f3 100644 --- a/hstore.go +++ b/hstore.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "errors" "fmt" @@ -21,6 +22,11 @@ type Hstore struct { } func (dst *Hstore) Set(src interface{}) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + switch value := src.(type) { case map[string]string: m := make(map[string]Text, len(value)) @@ -437,3 +443,25 @@ func parseHstore(s string) (k []string, v []Text, err error) { v = values return } + +// Scan implements the database/sql Scanner interface. +func (dst *Hstore) Scan(src interface{}) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Hstore) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/hstore_array.go b/hstore_array.go index 9bd0ed3b..097fec7b 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *HstoreArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *HstoreArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/inet.go b/inet.go index 13764814..0ca3ee7a 100644 --- a/inet.go +++ b/inet.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" "net" @@ -23,6 +24,11 @@ type Inet struct { } func (dst *Inet) Set(src interface{}) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + switch value := src.(type) { case net.IPNet: *dst = Inet{IPNet: &value, Status: Present} @@ -189,3 +195,25 @@ func (src Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := w.Write(src.IPNet.IP) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Inet) Scan(src interface{}) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Inet) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/inet_array.go b/inet_array.go index 1988a145..a108d75b 100644 --- a/inet_array.go +++ b/inet_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -325,3 +326,33 @@ func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *InetArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *InetArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/int2.go b/int2.go index 6996cd4f..3bcac63c 100644 --- a/int2.go +++ b/int2.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Int2 struct { } func (dst *Int2) Set(src interface{}) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = Int2{Int: int16(value), Status: Present} @@ -151,3 +157,41 @@ func (src Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt16(w, src.Int) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int2) Scan(src interface{}) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", src) + } + if src > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", src) + } + *dst = Int2{Int: int16(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/int2_array.go b/int2_array.go index 531e7dd6..bddb5ac2 100644 --- a/int2_array.go +++ b/int2_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -324,3 +325,33 @@ func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int2Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int2Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/int4.go b/int4.go index 62ee366f..5069dab4 100644 --- a/int4.go +++ b/int4.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Int4 struct { } func (dst *Int4) Set(src interface{}) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = Int4{Int: int32(value), Status: Present} @@ -68,7 +74,7 @@ func (dst *Int4) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int8", value) + return fmt.Errorf("cannot convert %v to Int4", value) } return nil @@ -142,3 +148,41 @@ func (src Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt32(w, src.Int) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src interface{}) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", src) + } + if src > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", src) + } + *dst = Int4{Int: int32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/int4_array.go b/int4_array.go index 3617050f..d5c8f911 100644 --- a/int4_array.go +++ b/int4_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -324,3 +325,33 @@ func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int4Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/int8.go b/int8.go index 7ed54f8e..cf701dc6 100644 --- a/int8.go +++ b/int8.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Int8 struct { } func (dst *Int8) Set(src interface{}) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = Int8{Int: int64(value), Status: Present} @@ -134,3 +140,35 @@ func (src Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, src.Int) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src interface{}) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + *dst = Int8{Int: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/int8_array.go b/int8_array.go index 4f04b660..ae2521fa 100644 --- a/int8_array.go +++ b/int8_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -324,3 +325,33 @@ func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int8Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/json.go b/json.go index bfffae14..05d965ca 100644 --- a/json.go +++ b/json.go @@ -1,7 +1,9 @@ package pgtype import ( + "database/sql/driver" "encoding/json" + "fmt" "io" ) @@ -11,6 +13,11 @@ type Json struct { } func (dst *Json) Set(src interface{}) error { + if src == nil { + *dst = Json{Status: Null} + return nil + } + switch value := src.(type) { case string: *dst = Json{Bytes: []byte(value), Status: Present} @@ -116,3 +123,32 @@ func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return src.EncodeText(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Json) Scan(src interface{}) error { + if src == nil { + *dst = Json{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Json) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/jsonb.go b/jsonb.go index e44f3c41..f47476d6 100644 --- a/jsonb.go +++ b/jsonb.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" ) @@ -66,3 +67,13 @@ func (src Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err = w.Write(src.Bytes) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Jsonb) Scan(src interface{}) error { + return (*Json)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Jsonb) Value() (driver.Value, error) { + return (Json)(src).Value() +} diff --git a/name.go b/name.go index 9ebf63d3..cc4ae23b 100644 --- a/name.go +++ b/name.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -46,3 +47,13 @@ func (src Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (Text)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Name) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Name) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/oid.go b/oid.go index 3edd7f3c..339dee0f 100644 --- a/oid.go +++ b/oid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -55,3 +56,27 @@ func (src Oid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, uint32(src)) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Oid) Scan(src interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", src) + } + + switch src := src.(type) { + case int64: + *dst = Oid(src) + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Oid) Value() (driver.Value, error) { + return int64(src), nil +} diff --git a/oid_value.go b/oid_value.go index 1bce6e11..cb03802e 100644 --- a/oid_value.go +++ b/oid_value.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -43,3 +44,13 @@ func (src OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *OidValue) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src OidValue) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/pgtype.go b/pgtype.go index 674c0db7..7e6633d9 100644 --- a/pgtype.go +++ b/pgtype.go @@ -67,6 +67,19 @@ const ( NegativeInfinity InfinityModifier = -Infinity ) +func (im InfinityModifier) String() string { + switch im { + case None: + return "none" + case Infinity: + return "infinity" + case NegativeInfinity: + return "-infinity" + default: + return "invalid" + } +} + type Value interface { // Set converts and assigns src to itself. Set(src interface{}) error diff --git a/pgtype_test.go b/pgtype_test.go index 391fed57..16cabfd1 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "database/sql" "fmt" "io" "net" @@ -10,6 +11,8 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" + _ "github.com/jackc/pgx/stdlib" + _ "github.com/lib/pq" ) // Test for renamed types @@ -24,6 +27,25 @@ type _float32Slice []float32 type _float64Slice []float64 type _byteSlice []byte +func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { + var sqlDriverName string + switch driverName { + case "github.com/lib/pq": + sqlDriverName = "postgres" + case "github.com/jackc/pgx/stdlib": + sqlDriverName = "pgx" + default: + t.Fatalf("Unknown driver %v", driverName) + } + + db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + return db +} + func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) if err != nil { @@ -93,6 +115,13 @@ func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface } func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } +} + +func testPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { conn := mustConnectPgx(t) defer mustClose(t, conn) @@ -114,7 +143,7 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int ps.FieldDescriptions[0].FormatCode = fc.formatCode vEncoder := forceEncoder(v, fc.formatCode) if vEncoder == nil { - t.Logf("%#v does not implement %v", v, fc.name) + t.Logf("Skipping: %#v does not implement %v", v, fc.name) continue } // Derefence value if it is a pointer @@ -136,3 +165,33 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int } } } + +func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectDatabaseSQL(t, driverName) + defer mustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := ps.QueryRow(v).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } + +} diff --git a/pguint32.go b/pguint32.go index 3f9e7bf7..7138a409 100644 --- a/pguint32.go +++ b/pguint32.go @@ -1,9 +1,11 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" + "math" "strconv" "github.com/jackc/pgx/pgio" @@ -21,6 +23,14 @@ type pguint32 struct { // types do. func (dst *pguint32) Set(src interface{}) error { switch value := src.(type) { + case int64: + if value < 0 { + return fmt.Errorf("%d is less than minimum value for pguint32", value) + } + if value > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for pguint32", value) + } + *dst = pguint32{Uint: uint32(value), Status: Present} case uint32: *dst = pguint32{Uint: value, Status: Present} default: @@ -116,3 +126,38 @@ func (src pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, src.Uint) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *pguint32) Scan(src interface{}) error { + if src == nil { + *dst = pguint32{Status: Null} + return nil + } + + switch src := src.(type) { + case uint32: + *dst = pguint32{Uint: src, Status: Present} + return nil + case int64: + *dst = pguint32{Uint: uint32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src pguint32) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Uint), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/qchar.go b/qchar.go index 4b32ee4a..49475bd3 100644 --- a/qchar.go +++ b/qchar.go @@ -17,13 +17,20 @@ import ( // standard type char. // // Not all possible values of QChar are representable in the text format. -// Therefore, QChar does not implement TextEncoder and TextDecoder. +// Therefore, QChar does not implement TextEncoder and TextDecoder. In +// addition, database/sql Scanner and database/sql/driver Value are not +// implemented. type QChar struct { Int int8 Status Status } func (dst *QChar) Set(src interface{}) error { + if src == nil { + *dst = QChar{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = QChar{Int: value, Status: Present} diff --git a/qchar_test.go b/qchar_test.go index a1b6d22e..afac5016 100644 --- a/qchar_test.go +++ b/qchar_test.go @@ -9,13 +9,15 @@ import ( ) func TestQCharTranscode(t *testing.T) { - testSuccessfulTranscode(t, `"char"`, []interface{}{ + testPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, pgtype.QChar{Int: -1, Status: pgtype.Present}, pgtype.QChar{Int: 0, Status: pgtype.Present}, pgtype.QChar{Int: 1, Status: pgtype.Present}, pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, pgtype.QChar{Int: 0, Status: pgtype.Null}, + }, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) }) } diff --git a/record.go b/record.go index 89e081ca..9c42c907 100644 --- a/record.go +++ b/record.go @@ -16,6 +16,11 @@ type Record struct { } func (dst *Record) Set(src interface{}) error { + if src == nil { + *dst = Record{Status: Null} + return nil + } + switch value := src.(type) { case []Value: *dst = Record{Fields: value, Status: Present} diff --git a/text.go b/text.go index dbc9362b..482c9023 100644 --- a/text.go +++ b/text.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" ) @@ -11,6 +12,11 @@ type Text struct { } func (dst *Text) Set(src interface{}) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + switch value := src.(type) { case string: *dst = Text{String: value, Status: Present} @@ -20,6 +26,12 @@ func (dst *Text) Set(src interface{}) error { } else { *dst = Text{String: *value, Status: Present} } + case []byte: + if value == nil { + *dst = Text{Status: Null} + } else { + *dst = Text{String: string(value), Status: Present} + } default: if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) @@ -93,3 +105,32 @@ func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return src.EncodeText(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Text) Scan(src interface{}) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Text) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.String, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/text_array.go b/text_array.go index 6e8ead26..64728048 100644 --- a/text_array.go +++ b/text_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *TextArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TextArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/tid.go b/tid.go index b91711d3..b363c1f9 100644 --- a/tid.go +++ b/tid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -121,3 +122,25 @@ func (src Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err = pgio.WriteUint16(w, src.OffsetNumber) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Tid) Scan(src interface{}) error { + if src == nil { + *dst = Tid{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Tid) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/timestamp.go b/timestamp.go index 4b42f3cf..78c6355e 100644 --- a/timestamp.go +++ b/timestamp.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -17,14 +18,19 @@ const pgTimestampFormat = "2006-01-02 15:04:05.999999999" // recommended to use timestamptz whenever possible. Timestamp methods either // convert to UTC or return an error on non-UTC times. type Timestamp struct { - Time time.Time // Time must always be in UTC. - Status Status - InfinityModifier + Time time.Time // Time must always be in UTC. + Status Status + InfinityModifier InfinityModifier } // Set converts src into a Timestamp and stores in dst. If src is a // time.Time in a non-UTC time zone, the time zone is discarded. func (dst *Timestamp) Set(src interface{}) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + switch value := src.(type) { case time.Time: *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} @@ -183,3 +189,38 @@ func (src Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, microsecSinceY2K) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamp) Scan(src interface{}) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + case time.Time: + *dst = Timestamp{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamp) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/timestamp_array.go b/timestamp_array.go index 6a6950c7..5d08f9cc 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -297,3 +298,33 @@ func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *TimestampArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TimestampArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/timestamptz.go b/timestamptz.go index ba849ac8..50370335 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -20,12 +21,17 @@ const ( ) type Timestamptz struct { - Time time.Time - Status Status - InfinityModifier + Time time.Time + Status Status + InfinityModifier InfinityModifier } func (dst *Timestamptz) Set(src interface{}) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + switch value := src.(type) { case time.Time: *dst = Timestamptz{Time: value, Status: Present} @@ -179,3 +185,38 @@ func (src Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, microsecSinceY2K) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamptz) Scan(src interface{}) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + case time.Time: + *dst = Timestamptz{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamptz) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/timestamptz_array.go b/timestamptz_array.go index 347d0b8b..107be06a 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -297,3 +298,33 @@ func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, erro return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *TimestamptzArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TimestamptzArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/typed_array.go.erb b/typed_array.go.erb index 0e5725ce..4b8f1a28 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -299,3 +299,33 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool return false, err } <% end %> + +// Scan implements the database/sql Scanner interface. +func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *<%= pgtype_array_type %>) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/unknown.go b/unknown.go index b951ad99..2dca0f87 100644 --- a/unknown.go +++ b/unknown.go @@ -1,5 +1,7 @@ package pgtype +import "database/sql/driver" + // Unknown represents the PostgreSQL unknown type. It is either a string literal // or NULL. It is used when PostgreSQL does not know the type of a value. In // general, this will only be used in pgx when selecting a null value without @@ -30,3 +32,13 @@ func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } + +// Scan implements the database/sql Scanner interface. +func (dst *Unknown) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Unknown) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/varchar.go b/varchar.go index adda6c49..f25ada5d 100644 --- a/varchar.go +++ b/varchar.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -38,3 +39,13 @@ func (src Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (Text)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Varchar) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Varchar) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/varchar_array.go b/varchar_array.go index e1dd3910..2712b4d2 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *VarcharArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *VarcharArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/xid.go b/xid.go index c76548a4..0a7fc7d9 100644 --- a/xid.go +++ b/xid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -52,3 +53,13 @@ func (src Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Xid) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Xid) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} From 46454758004cd091b96d86e82acb16a6b57d121e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 20 Mar 2017 08:00:43 -0500 Subject: [PATCH 0047/1158] Run goimports as part of array gen script --- typed_array_gen.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/typed_array_gen.sh b/typed_array_gen.sh index d77c8ca3..52612466 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -14,3 +14,4 @@ erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[] erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go +goimports -w *_array.go From 0e51991aaae85893affa3457e9b63756138c318e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 20 Mar 2017 08:58:28 -0500 Subject: [PATCH 0048/1158] Skip jsonb test if no jsonb type --- jsonb_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jsonb_test.go b/jsonb_test.go index 3978b0d4..91637eb8 100644 --- a/jsonb_test.go +++ b/jsonb_test.go @@ -9,6 +9,12 @@ import ( ) func TestJsonbTranscode(t *testing.T) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + if _, ok := conn.ConnInfo.DataTypeForName("jsonb"); !ok { + t.Skip("Skipping due to no jsonb type") + } + testSuccessfulTranscode(t, "jsonb", []interface{}{ pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, pgtype.Jsonb{Bytes: []byte("null"), Status: pgtype.Present}, From be04ad7b21659fd99e3014dbe500fd0be2a4b775 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 23 Mar 2017 18:41:52 -0500 Subject: [PATCH 0049/1158] Add int4range --- int4range.go | 268 +++++++++++++++++++++++++++++++++++++++++++++ int4range_test.go | 25 +++++ pgtype.go | 1 + pgtype_test.go | 93 ++++++++++++++++ range.go | 273 ++++++++++++++++++++++++++++++++++++++++++++++ range_test.go | 177 ++++++++++++++++++++++++++++++ 6 files changed, 837 insertions(+) create mode 100644 int4range.go create mode 100644 int4range_test.go create mode 100644 range.go create mode 100644 range_test.go diff --git a/int4range.go b/int4range.go new file mode 100644 index 00000000..cac4484c --- /dev/null +++ b/int4range.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Int4range struct { + Lower Int4 + Upper Int4 + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Int4range) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Int4range", src) +} + +func (dst *Int4range) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int4range) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Int4range) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Int4range{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Int4range{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4range) Scan(src interface{}) error { + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4range) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/int4range_test.go b/int4range_test.go new file mode 100644 index 00000000..c96fe9cd --- /dev/null +++ b/int4range_test.go @@ -0,0 +1,25 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt4rangeTranscode(t *testing.T) { + testSuccessfulTranscode(t, "int4range", []interface{}{ + pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + pgtype.Int4range{Status: pgtype.Null}, + }) +} + +func TestInt4rangeNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select int4range(1, 10, '(]')", + value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + }, + }) +} diff --git a/pgtype.go b/pgtype.go index 7e6633d9..7a95994c 100644 --- a/pgtype.go +++ b/pgtype.go @@ -233,6 +233,7 @@ func init() { "inet": &Inet{}, "int2": &Int2{}, "int4": &Int4{}, + "int4range": &Int4range{}, "int8": &Int8{}, "json": &Json{}, "jsonb": &Jsonb{}, diff --git a/pgtype_test.go b/pgtype_test.go index 16cabfd1..298cff64 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -189,6 +189,99 @@ func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeNa t.Errorf("%v %d: %v", driverName, i, err) } + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} + +type normalizeTest struct { + sql string + value interface{} +} + +func testSuccessfulNormalize(t testing.TB, tests []normalizeTest) { + testSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func testSuccessfulNormalizeEqFunc(t testing.TB, tests []normalizeTest, eqFunc func(a, b interface{}) bool) { + testPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc) + } +} + +func testPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []normalizeTest, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, tt := range tests { + for _, fc := range formats { + psName := fmt.Sprintf("test%d", i) + ps, err := conn.Prepare(psName, tt.sql) + if err != nil { + t.Fatal(err) + } + + ps.FieldDescriptions[0].FormatCode = fc.formatCode + if forceEncoder(tt.value, fc.formatCode) == nil { + t.Logf("Skipping: %#v does not implement %v", tt.value, fc.name) + continue + } + // Derefence value if it is a pointer + derefV := tt.value + refVal := reflect.ValueOf(tt.value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = conn.QueryRow(psName).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} + +func testDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []normalizeTest, eqFunc func(a, b interface{}) bool) { + conn := mustConnectDatabaseSQL(t, driverName) + defer mustClose(t, conn) + + for i, tt := range tests { + ps, err := conn.Prepare(tt.sql) + if err != nil { + t.Errorf("%d. %v", i, err) + continue + } + + // Derefence value if it is a pointer + derefV := tt.value + refVal := reflect.ValueOf(tt.value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = ps.QueryRow().Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + if !eqFunc(result.Elem().Interface(), derefV) { t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) } diff --git a/range.go b/range.go new file mode 100644 index 00000000..76daf8cc --- /dev/null +++ b/range.go @@ -0,0 +1,273 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +type BoundType byte + +const ( + Inclusive = BoundType('i') + Exclusive = BoundType('e') + Unbounded = BoundType('U') + Empty = BoundType('E') +) + +type UntypedTextRange struct { + Lower string + Upper string + LowerType BoundType + UpperType BoundType +} + +func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { + utr := &UntypedTextRange{} + if src == "empty" { + utr.LowerType = 'E' + utr.UpperType = 'E' + return utr, nil + } + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower bound: %v", err) + } + switch r { + case '(': + utr.LowerType = Exclusive + case '[': + utr.LowerType = Inclusive + default: + return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + buf.UnreadRune() + + if r == ',' { + utr.LowerType = Unbounded + } else { + utr.Lower, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing range separator: %v", err) + } + if r != ',' { + return nil, fmt.Errorf("missing range separator: %v", r) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + buf.UnreadRune() + + if r == ')' || r == ']' { + utr.UpperType = Unbounded + } else { + utr.Upper, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing upper bound: %v", err) + } + switch r { + case ')': + utr.UpperType = Exclusive + case ']': + utr.UpperType = Inclusive + default: + return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + return utr, nil +} + +func rangeParseValue(buf *bytes.Buffer) (string, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + if r == '"' { + return rangeParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case ',', '[', ']', '(', ')': + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + if r != '"' { + buf.UnreadRune() + return s.String(), nil + } + } + s.WriteRune(r) + } +} + +type UntypedBinaryRange struct { + Lower []byte + Upper []byte + LowerType BoundType + UpperType BoundType +} + +// 0 = () = 00000 +// 1 = empty = 00001 +// 2 = [) = 00010 +// 4 = (] = 00100 +// 6 = [] = 00110 +// 8 = ) = 01000 +// 12 = ] = 01100 +// 16 = ( = 10000 +// 18 = [ = 10010 +// 24 = = 11000 + +const emptyMask = 1 +const lowerInclusiveMask = 2 +const upperInclusiveMask = 4 +const lowerUnboundedMask = 8 +const upperUnboundedMask = 16 + +func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { + ubr := &UntypedBinaryRange{} + + if len(src) == 0 { + return nil, fmt.Errorf("range too short: %v", len(src)) + } + + rangeType := src[0] + rp := 1 + + if rangeType&emptyMask > 0 { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) + } + ubr.LowerType = Empty + ubr.UpperType = Empty + return ubr, nil + } + + if rangeType&lowerInclusiveMask > 0 { + ubr.LowerType = Inclusive + } else if rangeType&lowerUnboundedMask > 0 { + ubr.LowerType = Unbounded + } else { + ubr.LowerType = Exclusive + } + + if rangeType&upperInclusiveMask > 0 { + ubr.UpperType = Inclusive + } else if rangeType&upperUnboundedMask > 0 { + ubr.UpperType = Unbounded + } else { + ubr.UpperType = Exclusive + } + + if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) + } + return ubr, nil + } + + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + val := src[rp : rp+valueLen] + rp += valueLen + + if ubr.LowerType != Unbounded { + ubr.Lower = val + } else { + ubr.Upper = val + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + return ubr, nil + } + + if ubr.UpperType != Unbounded { + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + ubr.Upper = src[rp : rp+valueLen] + rp += valueLen + } + + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + + return ubr, nil + +} diff --git a/range_test.go b/range_test.go new file mode 100644 index 00000000..9e16df59 --- /dev/null +++ b/range_test.go @@ -0,0 +1,177 @@ +package pgtype + +import ( + "bytes" + "testing" +) + +func TestParseUntypedTextRange(t *testing.T) { + tests := []struct { + src string + result UntypedTextRange + err error + }{ + { + src: `[1,2)`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[1,2]`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: `(1,3)`, + result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: ` [1,2) `, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[ foo , bar )`, + result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["foo","bar")`, + result: UntypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["","bar")`, + result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[f\"oo\,,b\\ar\))`, + result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `empty`, + result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedTextRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if r.Lower != tt.result.Lower { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if r.Upper != tt.result.Upper { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} + +func TestParseUntypedBinaryRange(t *testing.T) { + tests := []struct { + src []byte + result UntypedBinaryRange + err error + }{ + { + src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{1}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, + err: nil, + }, + { + src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{8, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{12, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{16, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{18, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{24}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedBinaryRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if bytes.Compare(r.Lower, tt.result.Lower) != 0 { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if bytes.Compare(r.Upper, tt.result.Upper) != 0 { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} From a021a7717a43c140d0ed2e82e5232c66d9918dff Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 24 Mar 2017 13:36:10 -0500 Subject: [PATCH 0050/1158] Add Int8range Add code generation for ranges --- int8range.go | 268 +++++++++++++++++++++++++++++++++++++++++++++ int8range_test.go | 25 +++++ typed_range.go.erb | 268 +++++++++++++++++++++++++++++++++++++++++++++ typed_range_gen.sh | 3 + 4 files changed, 564 insertions(+) create mode 100644 int8range.go create mode 100644 int8range_test.go create mode 100644 typed_range.go.erb create mode 100644 typed_range_gen.sh diff --git a/int8range.go b/int8range.go new file mode 100644 index 00000000..44946be9 --- /dev/null +++ b/int8range.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Int8range struct { + Lower Int8 + Upper Int8 + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Int8range) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Int8range", src) +} + +func (dst *Int8range) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int8range) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Int8range{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Int8range{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8range) Scan(src interface{}) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8range) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/int8range_test.go b/int8range_test.go new file mode 100644 index 00000000..1b3e594c --- /dev/null +++ b/int8range_test.go @@ -0,0 +1,25 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt8rangeTranscode(t *testing.T) { + testSuccessfulTranscode(t, "Int8range", []interface{}{ + pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + pgtype.Int8range{Status: pgtype.Null}, + }) +} + +func TestInt8rangeNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select Int8range(1, 10, '(]')", + value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + }, + }) +} diff --git a/typed_range.go.erb b/typed_range.go.erb new file mode 100644 index 00000000..922b98b4 --- /dev/null +++ b/typed_range.go.erb @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type <%= range_type %> struct { + Lower <%= element_type %> + Upper <%= element_type %> + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *<%= range_type %>) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to <%= range_type %>", src) +} + +func (dst *<%= range_type %>) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *<%= range_type %>) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = <%= range_type %>{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *<%= range_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = <%= range_type %>{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src <%= range_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *<%= range_type %>) Scan(src interface{}) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src <%= range_type %>) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/typed_range_gen.sh b/typed_range_gen.sh new file mode 100644 index 00000000..af3e2cd1 --- /dev/null +++ b/typed_range_gen.sh @@ -0,0 +1,3 @@ +erb range_type=Int4range element_type=Int4 typed_range.go.erb > int4range.go +erb range_type=Int8range element_type=Int8 typed_range.go.erb > int8range.go +goimports -w *range.go From 94971db9e2d7001186522c2bab8e1fdf19637b82 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 24 Mar 2017 14:17:49 -0500 Subject: [PATCH 0051/1158] Add daterange, tsrange, and tstzrange --- daterange.go | 268 +++++++++++++++++++++++++++++++++++++++++++++ daterange_test.go | 66 +++++++++++ pgtype.go | 4 + tsrange.go | 268 +++++++++++++++++++++++++++++++++++++++++++++ tsrange_test.go | 40 +++++++ tstzrange.go | 268 +++++++++++++++++++++++++++++++++++++++++++++ tstzrange_test.go | 40 +++++++ typed_range_gen.sh | 3 + 8 files changed, 957 insertions(+) create mode 100644 daterange.go create mode 100644 daterange_test.go create mode 100644 tsrange.go create mode 100644 tsrange_test.go create mode 100644 tstzrange.go create mode 100644 tstzrange_test.go diff --git a/daterange.go b/daterange.go new file mode 100644 index 00000000..fbf51980 --- /dev/null +++ b/daterange.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Daterange struct { + Lower Date + Upper Date + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Daterange) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Daterange", src) +} + +func (dst *Daterange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Daterange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Daterange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Daterange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Daterange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Daterange) Scan(src interface{}) error { + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Daterange) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/daterange_test.go b/daterange_test.go new file mode 100644 index 00000000..8501cc7e --- /dev/null +++ b/daterange_test.go @@ -0,0 +1,66 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestDaterangeTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ + pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Daterange{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Daterange) + b := bb.(pgtype.Daterange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} + +func TestDaterangeNormalize(t *testing.T) { + testSuccessfulNormalizeEqFunc(t, []normalizeTest{ + { + sql: "select daterange('2010-01-01', '2010-01-11', '(]')", + value: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(2010, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2010, 1, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Daterange) + b := bb.(pgtype.Daterange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} diff --git a/pgtype.go b/pgtype.go index 7a95994c..3d691044 100644 --- a/pgtype.go +++ b/pgtype.go @@ -227,6 +227,7 @@ func init() { "cid": &Cid{}, "cidr": &Cidr{}, "date": &Date{}, + "daterange": &Daterange{}, "float4": &Float4{}, "float8": &Float8{}, "hstore": &Hstore{}, @@ -235,6 +236,7 @@ func init() { "int4": &Int4{}, "int4range": &Int4range{}, "int8": &Int8{}, + "int8range": &Int8range{}, "json": &Json{}, "jsonb": &Jsonb{}, "name": &Name{}, @@ -244,6 +246,8 @@ func init() { "tid": &Tid{}, "timestamp": &Timestamp{}, "timestamptz": &Timestamptz{}, + "tsrange": &Tsrange{}, + "tstzrange": &Tstzrange{}, "unknown": &Unknown{}, "varchar": &Varchar{}, "xid": &Xid{}, diff --git a/tsrange.go b/tsrange.go new file mode 100644 index 00000000..48992829 --- /dev/null +++ b/tsrange.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Tsrange struct { + Lower Timestamp + Upper Timestamp + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Tsrange) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Tsrange", src) +} + +func (dst *Tsrange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Tsrange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Tsrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Tsrange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Tsrange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Tsrange) Scan(src interface{}) error { + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Tsrange) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/tsrange_test.go b/tsrange_test.go new file mode 100644 index 00000000..448cb92f --- /dev/null +++ b/tsrange_test.go @@ -0,0 +1,40 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTsrangeTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ + pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + pgtype.Tsrange{ + Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamp{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Tsrange{ + Lower: pgtype.Timestamp{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Tsrange{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Tsrange) + b := bb.(pgtype.Tsrange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} diff --git a/tstzrange.go b/tstzrange.go new file mode 100644 index 00000000..61e94ab4 --- /dev/null +++ b/tstzrange.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Tstzrange struct { + Lower Timestamptz + Upper Timestamptz + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Tstzrange) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Tstzrange", src) +} + +func (dst *Tstzrange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Tstzrange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Tstzrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Tstzrange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Tstzrange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Tstzrange) Scan(src interface{}) error { + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Tstzrange) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/tstzrange_test.go b/tstzrange_test.go new file mode 100644 index 00000000..197aabbc --- /dev/null +++ b/tstzrange_test.go @@ -0,0 +1,40 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTstzrangeTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ + pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamptz{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Tstzrange{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Tstzrange) + b := bb.(pgtype.Tstzrange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} diff --git a/typed_range_gen.sh b/typed_range_gen.sh index af3e2cd1..b4220f09 100644 --- a/typed_range_gen.sh +++ b/typed_range_gen.sh @@ -1,3 +1,6 @@ erb range_type=Int4range element_type=Int4 typed_range.go.erb > int4range.go erb range_type=Int8range element_type=Int8 typed_range.go.erb > int8range.go +erb range_type=Tsrange element_type=Timestamp typed_range.go.erb > tsrange.go +erb range_type=Tstzrange element_type=Timestamptz typed_range.go.erb > tstzrange.go +erb range_type=Daterange element_type=Date typed_range.go.erb > daterange.go goimports -w *range.go From d25c346d6d67423c749a90ffd4a0338a16d235e9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 31 Mar 2017 20:11:18 -0500 Subject: [PATCH 0052/1158] Add interval type --- interval.go | 271 +++++++++++++++++++++++++++++++++++++++++++++++ interval_test.go | 62 +++++++++++ 2 files changed, 333 insertions(+) create mode 100644 interval.go create mode 100644 interval_test.go diff --git a/interval.go b/interval.go new file mode 100644 index 00000000..7eddb10f --- /dev/null +++ b/interval.go @@ -0,0 +1,271 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/pgio" +) + +const ( + microsecondsPerSecond = 1000000 + microsecondsPerMinute = 60 * microsecondsPerSecond + microsecondsPerHour = 60 * microsecondsPerMinute +) + +type Interval struct { + Microseconds int64 + Days int32 + Months int32 + Status Status +} + +func (dst *Interval) Set(src interface{}) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + switch value := src.(type) { + case time.Duration: + *dst = Interval{Microseconds: int64(value) / 1000, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Interval", value) + } + + return nil +} + +func (dst *Interval) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Interval) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Duration: + if src.Days > 0 || src.Months > 0 { + return fmt.Errorf("interval with months or days cannot be decoded into %T", dst) + } + *v = time.Duration(src.Microseconds) * time.Microsecond + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return nullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + var microseconds int64 + var days int32 + var months int32 + + parts := strings.Split(string(src), " ") + + for i := 0; i < len(parts)-1; i += 2 { + scalar, err := strconv.ParseInt(parts[i], 10, 64) + if err != nil { + return fmt.Errorf("bad interval format") + } + + switch parts[i+1] { + case "year", "years": + months += int32(scalar * 12) + case "mon", "mons": + months += int32(scalar) + case "day", "days": + days = int32(scalar) + } + } + + if len(parts)%2 == 1 { + timeParts := strings.SplitN(parts[len(parts)-1], ":", 3) + if len(timeParts) != 3 { + return fmt.Errorf("bad interval format") + } + + var negative bool + if timeParts[0][0] == '-' { + negative = true + timeParts[0] = timeParts[0][1:] + } + + hours, err := strconv.ParseInt(timeParts[0], 10, 64) + if err != nil { + return fmt.Errorf("bad interval hour format: %s", hours) + } + + minutes, err := strconv.ParseInt(timeParts[1], 10, 64) + if err != nil { + return fmt.Errorf("bad interval minute format: %s", minutes) + } + + secondParts := strings.SplitN(timeParts[2], ".", 2) + + seconds, err := strconv.ParseInt(secondParts[0], 10, 64) + if err != nil { + return fmt.Errorf("bad interval second format: %s", seconds) + } + + var uSeconds int64 + if len(secondParts) == 2 { + uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64) + if err != nil { + return fmt.Errorf("bad interval decimal format: %s", seconds) + } + + for i := 0; i < 6-len(secondParts[1]); i++ { + uSeconds *= 10 + } + } + + microseconds = hours * microsecondsPerHour + microseconds += minutes * microsecondsPerMinute + microseconds += seconds * microsecondsPerSecond + microseconds += uSeconds + + if negative { + microseconds = -microseconds + } + } + + *dst = Interval{Months: months, Days: days, Microseconds: microseconds, Status: Present} + return nil +} + +func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("Received an invalid size for a interval: %d", len(src)) + } + + microseconds := int64(binary.BigEndian.Uint64(src)) + days := int32(binary.BigEndian.Uint32(src[8:])) + months := int32(binary.BigEndian.Uint32(src[12:])) + + *dst = Interval{Microseconds: microseconds, Days: days, Months: months, Status: Present} + return nil +} + +func (src Interval) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if src.Months != 0 { + if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Months), 10)); err != nil { + return false, err + } + + if _, err := io.WriteString(w, " mon "); err != nil { + return false, err + } + } + + if src.Days != 0 { + if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Days), 10)); err != nil { + return false, err + } + + if _, err := io.WriteString(w, " day "); err != nil { + return false, err + } + } + + absMicroseconds := src.Microseconds + if absMicroseconds < 0 { + absMicroseconds = -absMicroseconds + + if err := pgio.WriteByte(w, '-'); err != nil { + return false, err + } + } + + hours := absMicroseconds / microsecondsPerHour + minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute + seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond + microseconds := absMicroseconds % microsecondsPerSecond + + timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) + + _, err := io.WriteString(w, timeStr) + return false, err +} + +// EncodeBinary encodes src into w. +func (src Interval) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteInt64(w, src.Microseconds); err != nil { + return false, err + } + if _, err := pgio.WriteInt32(w, src.Days); err != nil { + return false, err + } + if _, err := pgio.WriteInt32(w, src.Months); err != nil { + return false, err + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Interval) Scan(src interface{}) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Interval) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/interval_test.go b/interval_test.go new file mode 100644 index 00000000..db9614ef --- /dev/null +++ b/interval_test.go @@ -0,0 +1,62 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestIntervalTranscode(t *testing.T) { + testSuccessfulTranscode(t, "interval", []interface{}{ + pgtype.Interval{Microseconds: 1, Status: pgtype.Present}, + pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + pgtype.Interval{Days: 1, Status: pgtype.Present}, + pgtype.Interval{Months: 1, Status: pgtype.Present}, + pgtype.Interval{Months: 12, Status: pgtype.Present}, + pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Status: pgtype.Present}, + pgtype.Interval{Microseconds: -1, Status: pgtype.Present}, + pgtype.Interval{Microseconds: -1000000, Status: pgtype.Present}, + pgtype.Interval{Microseconds: -1000001, Status: pgtype.Present}, + pgtype.Interval{Microseconds: -123202800000000, Status: pgtype.Present}, + pgtype.Interval{Days: -1, Status: pgtype.Present}, + pgtype.Interval{Months: -1, Status: pgtype.Present}, + pgtype.Interval{Months: -12, Status: pgtype.Present}, + pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Status: pgtype.Present}, + pgtype.Interval{Status: pgtype.Null}, + }) +} + +func TestIntervalNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select '1 second'::interval", + value: pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + }, + { + sql: "select '1.000001 second'::interval", + value: pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + }, + { + sql: "select '34223 hours'::interval", + value: pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + }, + { + sql: "select '1 day'::interval", + value: pgtype.Interval{Days: 1, Status: pgtype.Present}, + }, + { + sql: "select '1 month'::interval", + value: pgtype.Interval{Months: 1, Status: pgtype.Present}, + }, + { + sql: "select '1 year'::interval", + value: pgtype.Interval{Months: 12, Status: pgtype.Present}, + }, + { + sql: "select '-13 mon'::interval", + value: pgtype.Interval{Months: -13, Status: pgtype.Present}, + }, + }) +} From f7191d3a5605bd3b128844799e5c8f7119fc7a1f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Apr 2017 23:33:04 -0500 Subject: [PATCH 0053/1158] Add pgtype.Numeric --- decimal.go | 35 +++ numeric.go | 602 ++++++++++++++++++++++++++++++++++++++++++++++++ numeric_test.go | 315 +++++++++++++++++++++++++ pgtype.go | 2 + 4 files changed, 954 insertions(+) create mode 100644 decimal.go create mode 100644 numeric.go create mode 100644 numeric_test.go diff --git a/decimal.go b/decimal.go new file mode 100644 index 00000000..728c748e --- /dev/null +++ b/decimal.go @@ -0,0 +1,35 @@ +package pgtype + +import ( + "io" +) + +type Decimal Numeric + +func (dst *Decimal) Set(src interface{}) error { + return (*Numeric)(dst).Set(src) +} + +func (dst *Decimal) Get() interface{} { + return (*Numeric)(dst).Get() +} + +func (src *Decimal) AssignTo(dst interface{}) error { + return (*Numeric)(src).AssignTo(dst) +} + +func (dst *Decimal) DecodeText(ci *ConnInfo, src []byte) error { + return (*Numeric)(dst).DecodeText(ci, src) +} + +func (dst *Decimal) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Numeric)(dst).DecodeBinary(ci, src) +} + +func (src *Decimal) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Numeric)(src).EncodeText(ci, w) +} + +func (src *Decimal) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Numeric)(src).EncodeBinary(ci, w) +} diff --git a/numeric.go b/numeric.go new file mode 100644 index 00000000..0f3f6529 --- /dev/null +++ b/numeric.go @@ -0,0 +1,602 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "math/big" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 +const nbase = 10000 + +var big0 *big.Int = big.NewInt(0) +var big10 *big.Int = big.NewInt(10) +var big100 *big.Int = big.NewInt(100) +var big1000 *big.Int = big.NewInt(1000) + +var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8) +var bigMinInt8 *big.Int = big.NewInt(math.MinInt8) +var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16) +var bigMinInt16 *big.Int = big.NewInt(math.MinInt16) +var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32) +var bigMinInt32 *big.Int = big.NewInt(math.MinInt32) +var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64) +var bigMinInt64 *big.Int = big.NewInt(math.MinInt64) +var bigMaxInt *big.Int = big.NewInt(int64(maxInt)) +var bigMinInt *big.Int = big.NewInt(int64(minInt)) + +var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8) +var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16) +var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32) +var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64)) +var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint)) + +var bigNBase *big.Int = big.NewInt(nbase) +var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) +var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) +var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) + +type Numeric struct { + Int *big.Int + Exp int32 + Status Status +} + +func (dst *Numeric) Set(src interface{}) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + switch value := src.(type) { + case float32: + num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + case float64: + num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + case int8: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint8: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int16: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint16: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int32: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint32: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int64: + *dst = Numeric{Int: big.NewInt(value), Status: Present} + case uint64: + *dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present} + case int: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint: + *dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present} + case string: + num, exp, err := parseNumericString(value) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + return nil +} + +func (dst *Numeric) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Numeric) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *float32: + f, err := strconv.ParseFloat(src.Int.String(), 64) + if err != nil { + return err + } + return float64AssignTo(f, src.Status, dst) + case *float64: + f, err := strconv.ParseFloat(src.Int.String(), 64) + if err != nil { + return err + } + return float64AssignTo(f, src.Status, dst) + case *int: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int(normalizedInt.Int64()) + case *int8: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt8) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt8) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int8(normalizedInt.Int64()) + case *int16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt16) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt16) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int16(normalizedInt.Int64()) + case *int32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt32) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt32) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int32(normalizedInt.Int64()) + case *int64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt64) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt64) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Int64() + case *uint: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint(normalizedInt.Uint64()) + case *uint8: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint8) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint8(normalizedInt.Uint64()) + case *uint16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint16) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint16(normalizedInt.Uint64()) + case *uint32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint32) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint32(normalizedInt.Uint64()) + case *uint64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint64) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Uint64() + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return nullAssignTo(dst) + } + + return nil +} + +func (dst *Numeric) toBigInt() (*big.Int, error) { + if dst.Exp == 0 { + return dst.Int, nil + } + + num := &big.Int{} + num.Set(dst.Int) + if dst.Exp > 0 { + mul := &big.Int{} + mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil) + num.Mul(num, mul) + return num, nil + } + + div := &big.Int{} + div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil) + remainder := &big.Int{} + num.DivMod(num, div, remainder) + if remainder.Cmp(big0) != 0 { + return nil, fmt.Errorf("cannot convert %v to integer", dst) + } + return num, nil +} + +func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + num, exp, err := parseNumericString(string(src)) + if err != nil { + return err + } + + *dst = Numeric{Int: num, Exp: exp, Status: Present} + return nil +} + +func parseNumericString(str string) (n *big.Int, exp int32, err error) { + parts := strings.SplitN(str, ".", 2) + digits := strings.Join(parts, "") + + if len(parts) > 1 { + exp = int32(-len(parts[1])) + } else { + for len(digits) > 1 && digits[len(digits)-1] == '0' { + digits = digits[:len(digits)-1] + exp++ + } + } + + accum := &big.Int{} + if _, ok := accum.SetString(digits, 10); !ok { + return nil, 0, fmt.Errorf("%s is not a number", str) + } + + return accum, exp, nil +} + +func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + if len(src) < 8 { + return fmt.Errorf("numeric incomplete %v", src) + } + + rp := 0 + ndigits := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if ndigits == 0 { + *dst = Numeric{Int: big.NewInt(0), Status: Present} + return nil + } + + weight := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + sign := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + dscale := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if len(src[rp:]) < int(ndigits)*2 { + return fmt.Errorf("numeric incomplete %v", src) + } + + accum := &big.Int{} + + for i := 0; i < int(ndigits+3)/4; i++ { + int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:]) + rp += bytesRead + + if i > 0 { + var mul *big.Int + switch digitsRead { + case 1: + mul = bigNBase + case 2: + mul = bigNBaseX2 + case 3: + mul = bigNBaseX3 + case 4: + mul = bigNBaseX4 + default: + return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) + } + accum.Mul(accum, mul) + } + + accum.Add(accum, big.NewInt(int64accum)) + } + + exp := (int32(weight) - int32(ndigits) + 1) * 4 + + if dscale > 0 { + fracNBaseDigits := ndigits - weight - 1 + fracDecimalDigits := fracNBaseDigits * 4 + + if dscale > fracDecimalDigits { + multCount := int(dscale - fracDecimalDigits) + for i := 0; i < multCount; i++ { + accum.Mul(accum, big10) + exp-- + } + } else if dscale < fracDecimalDigits { + divCount := int(fracDecimalDigits - dscale) + for i := 0; i < divCount; i++ { + accum.Div(accum, big10) + exp++ + } + } + } + + reduced := &big.Int{} + remainder := &big.Int{} + if exp >= 0 { + for { + reduced.DivMod(accum, big10, remainder) + if remainder.Cmp(big0) != 0 { + break + } + accum.Set(reduced) + exp++ + } + } + + if sign != 0 { + accum.Neg(accum) + } + + *dst = Numeric{Int: accum, Exp: exp, Status: Present} + + return nil + +} + +func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { + digits := len(src) / 2 + if digits > 4 { + digits = 4 + } + + rp := 0 + + for i := 0; i < digits; i++ { + if i > 0 { + accum *= nbase + } + accum += int64(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return accum, rp, digits +} + +func (src *Numeric) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := io.WriteString(w, src.Int.String()); err != nil { + return false, err + } + + if err := pgio.WriteByte(w, 'e'); err != nil { + return false, err + } + + if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Exp), 10)); err != nil { + return false, err + } + + return false, nil + +} + +func (src *Numeric) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var sign int16 + if src.Int.Cmp(big0) < 0 { + sign = 16384 + } + + absInt := &big.Int{} + wholePart := &big.Int{} + fracPart := &big.Int{} + remainder := &big.Int{} + absInt.Abs(src.Int) + + // Normalize absInt and exp to where exp is always a multiple of 4. This makes + // converting to 16-bit base 10,000 digits easier. + var exp int32 + switch src.Exp % 4 { + case 1, -3: + exp = src.Exp - 1 + absInt.Mul(absInt, big10) + case 2, -2: + exp = src.Exp - 2 + absInt.Mul(absInt, big100) + case 3, -1: + exp = src.Exp - 3 + absInt.Mul(absInt, big1000) + default: + exp = src.Exp + } + + if exp < 0 { + divisor := &big.Int{} + divisor.Exp(big10, big.NewInt(int64(-exp)), nil) + wholePart.DivMod(absInt, divisor, fracPart) + } else { + wholePart = absInt + } + + var wholeDigits, fracDigits []int16 + + for wholePart.Cmp(big0) != 0 { + wholePart.DivMod(wholePart, bigNBase, remainder) + wholeDigits = append(wholeDigits, int16(remainder.Int64())) + } + + for fracPart.Cmp(big0) != 0 { + fracPart.DivMod(fracPart, bigNBase, remainder) + fracDigits = append(fracDigits, int16(remainder.Int64())) + } + + if _, err := pgio.WriteInt16(w, int16(len(wholeDigits)+len(fracDigits))); err != nil { + return false, err + } + + var weight int16 + if len(wholeDigits) > 0 { + weight = int16(len(wholeDigits) - 1) + if exp > 0 { + weight += int16(exp / 4) + } + } else { + weight = int16(exp/4) - 1 + int16(len(fracDigits)) + } + if _, err := pgio.WriteInt16(w, weight); err != nil { + return false, err + } + + if _, err := pgio.WriteInt16(w, sign); err != nil { + return false, err + } + + var dscale int16 + if src.Exp < 0 { + dscale = int16(-src.Exp) + } + if _, err := pgio.WriteInt16(w, dscale); err != nil { + return false, err + } + + for i := len(wholeDigits) - 1; i >= 0; i-- { + if _, err := pgio.WriteInt16(w, wholeDigits[i]); err != nil { + return false, err + } + } + + for i := len(fracDigits) - 1; i >= 0; i-- { + if _, err := pgio.WriteInt16(w, fracDigits[i]); err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numeric) Scan(src interface{}) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + // TODO + // *dst = Numeric{Float: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Numeric) Value() (driver.Value, error) { + switch src.Status { + case Present: + buf := &bytes.Buffer{} + _, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + + return buf.String(), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/numeric_test.go b/numeric_test.go new file mode 100644 index 00000000..64dea847 --- /dev/null +++ b/numeric_test.go @@ -0,0 +1,315 @@ +package pgtype_test + +import ( + "math/big" + "math/rand" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +// For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) +func numericEqual(left, right *pgtype.Numeric) bool { + return left.Status == right.Status && + left.Exp == right.Exp && + ((left.Int == nil && right.Int == nil) || (left.Int != nil && right.Int != nil && left.Int.Cmp(right.Int) == 0)) +} + +// For test purposes only. +func numericNormalizedEqual(left, right *pgtype.Numeric) bool { + if left.Status != right.Status { + return false + } + + normLeft := &pgtype.Numeric{Int: (&big.Int{}).Set(left.Int), Status: left.Status} + normRight := &pgtype.Numeric{Int: (&big.Int{}).Set(right.Int), Status: right.Status} + + if left.Exp < right.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(right.Exp-left.Exp)), nil) + normRight.Int.Mul(normRight.Int, mul) + } else if left.Exp > right.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(left.Exp-right.Exp)), nil) + normLeft.Int.Mul(normLeft.Int, mul) + } + + return normLeft.Int.Cmp(normRight.Int) == 0 +} + +func mustParseBigInt(t *testing.T, src string) *big.Int { + i := &big.Int{} + if _, ok := i.SetString(src, 10); !ok { + t.Fatalf("could not parse big.Int: %s", src) + } + return i +} + +func TestNumericNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select '0'::numeric", + value: pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + }, + { + sql: "select '1'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + }, + { + sql: "select '10.00'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, + }, + { + sql: "select '1e-3'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, + }, + { + sql: "select '-1'::numeric", + value: pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + }, + { + sql: "select '10000'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, + }, + { + sql: "select '3.14'::numeric", + value: pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + }, + { + sql: "select '1.1'::numeric", + value: pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, + }, + { + sql: "select '100010001'::numeric", + value: pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, + }, + { + sql: "select '100010001.0001'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, + }, + { + sql: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", + value: pgtype.Numeric{ + Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), + Exp: -41, + Status: pgtype.Present, + }, + }, + { + sql: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", + value: pgtype.Numeric{ + Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), + Exp: -196, + Status: pgtype.Present, + }, + }, + { + sql: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", + value: pgtype.Numeric{ + Int: mustParseBigInt(t, "123"), + Exp: -186, + Status: pgtype.Present, + }, + }, + }) +} + +func TestNumericTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 6, Status: pgtype.Present}, + + // preserves significant zeroes + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -1, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -2, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -3, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -4, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -5, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -6, Status: pgtype.Present}, + + &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -7, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -8, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -9, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -1500, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3723409723490243842378942378901237502734019231380123"), Exp: 81, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "723409723490243842378942378901237502734019231380123"), Exp: 82, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409723490243842378942378901237502734019231380123"), Exp: 83, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409723490243842378942378901237502734019231380123"), Exp: 84, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3423409823409243892349028349023482934092340892390101"), Exp: -91, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "423409823409243892349028349023482934092340892390101"), Exp: -92, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409823409243892349028349023482934092340892390101"), Exp: -93, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409823409243892349028349023482934092340892390101"), Exp: -94, Status: pgtype.Present}, + &pgtype.Numeric{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Numeric) + b := bb.(pgtype.Numeric) + + return numericEqual(&a, &b) + }) + +} + +func TestNumericTranscodeFuzz(t *testing.T) { + r := rand.New(rand.NewSource(0)) + max := &big.Int{} + max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) + + values := make([]interface{}, 0, 2000) + for i := 0; i < 10; i++ { + for j := -50; j < 50; j++ { + num := (&big.Int{}).Rand(r, max) + negNum := &big.Int{} + negNum.Neg(num) + values = append(values, &pgtype.Numeric{Int: num, Exp: int32(j), Status: pgtype.Present}) + values = append(values, &pgtype.Numeric{Int: negNum, Exp: int32(j), Status: pgtype.Present}) + } + } + + testSuccessfulTranscodeEqFunc(t, "numeric", values, + func(aa, bb interface{}) bool { + a := aa.(pgtype.Numeric) + b := bb.(pgtype.Numeric) + + return numericNormalizedEqual(&a, &b) + }) +} + +func TestNumericSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result *pgtype.Numeric + }{ + {source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int8(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int16(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int32(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int64(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: uint8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: "1", result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: _int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float64(1000), result: &pgtype.Numeric{Int: big.NewInt(1), Exp: 3, Status: pgtype.Present}}, + {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}}, + {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}}, + {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + r := &pgtype.Numeric{} + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !numericEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericAssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src *pgtype.Numeric + dst interface{} + expected interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src *pgtype.Numeric + dst interface{} + expected interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src *pgtype.Numeric + dst interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(150), Status: pgtype.Present}, dst: &i8}, + {src: &pgtype.Numeric{Int: big.NewInt(40000), Status: pgtype.Present}, dst: &i16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui8}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui32}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui64}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype.go b/pgtype.go index 3d691044..84939b58 100644 --- a/pgtype.go +++ b/pgtype.go @@ -228,6 +228,7 @@ func init() { "cidr": &Cidr{}, "date": &Date{}, "daterange": &Daterange{}, + "decimal": &Decimal{}, "float4": &Float4{}, "float8": &Float8{}, "hstore": &Hstore{}, @@ -240,6 +241,7 @@ func init() { "json": &Json{}, "jsonb": &Jsonb{}, "name": &Name{}, + "numeric": &Numeric{}, "oid": &OidValue{}, "record": &Record{}, "text": &Text{}, From 066562fc89899eac7e67acc5194f285826e9a734 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Apr 2017 07:35:19 -0500 Subject: [PATCH 0054/1158] Add pgtype.Numrange --- numrange.go | 268 +++++++++++++++++++++++++++++++++++++++++++++ numrange_test.go | 33 ++++++ pgtype.go | 1 + typed_range_gen.sh | 1 + 4 files changed, 303 insertions(+) create mode 100644 numrange.go create mode 100644 numrange_test.go diff --git a/numrange.go b/numrange.go new file mode 100644 index 00000000..cf42dcbd --- /dev/null +++ b/numrange.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Numrange struct { + Lower Numeric + Upper Numeric + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Numrange) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Numrange", src) +} + +func (dst *Numrange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Numrange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Numrange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Numrange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numrange) Scan(src interface{}) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Numrange) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/numrange_test.go b/numrange_test.go new file mode 100644 index 00000000..81202362 --- /dev/null +++ b/numrange_test.go @@ -0,0 +1,33 @@ +package pgtype_test + +import ( + "math/big" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestNumrangeTranscode(t *testing.T) { + testSuccessfulTranscode(t, "numrange", []interface{}{ + pgtype.Numrange{ + LowerType: pgtype.Empty, + UpperType: pgtype.Empty, + Status: pgtype.Present, + }, + pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Numrange{Status: pgtype.Null}, + }) +} diff --git a/pgtype.go b/pgtype.go index 84939b58..d7e28641 100644 --- a/pgtype.go +++ b/pgtype.go @@ -242,6 +242,7 @@ func init() { "jsonb": &Jsonb{}, "name": &Name{}, "numeric": &Numeric{}, + "numrange": &Numrange{}, "oid": &OidValue{}, "record": &Record{}, "text": &Text{}, diff --git a/typed_range_gen.sh b/typed_range_gen.sh index b4220f09..bedda292 100644 --- a/typed_range_gen.sh +++ b/typed_range_gen.sh @@ -3,4 +3,5 @@ erb range_type=Int8range element_type=Int8 typed_range.go.erb > int8range.go erb range_type=Tsrange element_type=Timestamp typed_range.go.erb > tsrange.go erb range_type=Tstzrange element_type=Timestamptz typed_range.go.erb > tstzrange.go erb range_type=Daterange element_type=Date typed_range.go.erb > daterange.go +erb range_type=Numrange element_type=Numeric typed_range.go.erb > numrange.go goimports -w *range.go From cc873a0bcf6e2c6c51cab6597f1d3421fb84a121 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Apr 2017 07:46:45 -0500 Subject: [PATCH 0055/1158] Add pgtype.NumericArray --- numeric_array.go | 357 ++++++++++++++++++++++++++++++++++++++++++ numeric_array_test.go | 159 +++++++++++++++++++ pgtype.go | 1 + typed_array_gen.sh | 1 + 4 files changed, 518 insertions(+) create mode 100644 numeric_array.go create mode 100644 numeric_array_test.go diff --git a/numeric_array.go b/numeric_array.go new file mode 100644 index 00000000..b147e6a2 --- /dev/null +++ b/numeric_array.go @@ -0,0 +1,357 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type NumericArray struct { + Elements []Numeric + Dimensions []ArrayDimension + Status Status +} + +func (dst *NumericArray) Set(src interface{}) error { + switch value := src.(type) { + + case []float32: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []float64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + return nil +} + +func (dst *NumericArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *NumericArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]float32: + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]float64: + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return nullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = NumericArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Numeric + + if len(uta.Elements) > 0 { + elements = make([]Numeric, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Numeric + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = NumericArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = NumericArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = NumericArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Numeric, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = NumericArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *NumericArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil +} + +func (src *NumericArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("numeric"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "numeric") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *NumericArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *NumericArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/numeric_array_test.go b/numeric_array_test.go new file mode 100644 index 00000000..af2e8e51 --- /dev/null +++ b/numeric_array_test.go @@ -0,0 +1,159 @@ +package pgtype_test + +import ( + "math/big" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestNumericArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "numeric[]", []interface{}{ + &pgtype.NumericArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, + pgtype.Numeric{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.NumericArray{Status: pgtype.Null}, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(2), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(3), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(4), Status: pgtype.Present}, + pgtype.Numeric{Status: pgtype.Null}, + pgtype.Numeric{Int: big.NewInt(6), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(2), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(3), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(4), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestNumericArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.NumericArray + }{ + { + source: []float32{1}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []float64{1}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]float32)(nil)), + result: pgtype.NumericArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.NumericArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericArrayAssignTo(t *testing.T) { + var float32Slice []float32 + var float64Slice []float64 + + simpleTests := []struct { + src pgtype.NumericArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + expected: []float32{1}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float64Slice, + expected: []float64{1}, + }, + { + src: pgtype.NumericArray{Status: pgtype.Null}, + dst: &float32Slice, + expected: (([]float32)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.NumericArray + dst interface{} + }{ + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype.go b/pgtype.go index d7e28641..208b1f00 100644 --- a/pgtype.go +++ b/pgtype.go @@ -216,6 +216,7 @@ func init() { "_int2": &Int2Array{}, "_int4": &Int4Array{}, "_int8": &Int8Array{}, + "_numeric": &NumericArray{}, "_text": &TextArray{}, "_timestamp": &TimestampArray{}, "_timestamptz": &TimestamptzArray{}, diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 52612466..2e36b8b3 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -14,4 +14,5 @@ erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[] erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go +erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go goimports -w *_array.go From 0079bd5095f0ee58b0e67022b66f9c56b4d56326 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Apr 2017 17:53:32 -0500 Subject: [PATCH 0056/1158] Add pgtype.Point --- pgtype.go | 1 + point.go | 139 ++++++++++++++++++++++++++++++++++++++++++++++++++ point_test.go | 15 ++++++ 3 files changed, 155 insertions(+) create mode 100644 point.go create mode 100644 point_test.go diff --git a/pgtype.go b/pgtype.go index 208b1f00..911ab70e 100644 --- a/pgtype.go +++ b/pgtype.go @@ -245,6 +245,7 @@ func init() { "numeric": &Numeric{}, "numrange": &Numrange{}, "oid": &OidValue{}, + "point": &Point{}, "record": &Record{}, "text": &Text{}, "tid": &Tid{}, diff --git a/point.go b/point.go new file mode 100644 index 00000000..1b40bc44 --- /dev/null +++ b/point.go @@ -0,0 +1,139 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Point struct { + X float64 + Y float64 + Status Status +} + +func (dst *Point) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Point", src) +} + +func (dst *Point) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Point) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + y, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + *dst = Point{X: x, Y: y, Status: Present} + return nil +} + +func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + + *dst = Point{ + X: math.Float64frombits(x), + Y: math.Float64frombits(y), + Status: Present, + } + return nil +} + +func (src *Point) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.X, src.Y)) + return false, err +} + +func (src *Point) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := pgio.WriteUint64(w, math.Float64bits(src.X)) + if err != nil { + return false, err + } + + _, err = pgio.WriteUint64(w, math.Float64bits(src.Y)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Point) Scan(src interface{}) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Point) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/point_test.go b/point_test.go new file mode 100644 index 00000000..4ddb8009 --- /dev/null +++ b/point_test.go @@ -0,0 +1,15 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestPointTranscode(t *testing.T) { + testSuccessfulTranscode(t, "point", []interface{}{ + &pgtype.Point{X: 1.234, Y: 5.6789, Status: pgtype.Present}, + &pgtype.Point{X: -1.234, Y: -5.6789, Status: pgtype.Present}, + &pgtype.Point{Status: pgtype.Null}, + }) +} From dccbbc6a4043789dadc6cf7ad4753895d9898371 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Apr 2017 19:47:36 -0500 Subject: [PATCH 0057/1158] Add pgtype.Box --- box.go | 168 ++++++++++++++++++++++++++++++++++++++++++++++++++ box_test.go | 33 ++++++++++ pgtype.go | 1 + point.go | 13 ++-- point_test.go | 4 +- 5 files changed, 212 insertions(+), 7 deletions(-) create mode 100644 box.go create mode 100644 box_test.go diff --git a/box.go b/box.go new file mode 100644 index 00000000..eaaddbff --- /dev/null +++ b/box.go @@ -0,0 +1,168 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Box struct { + Corners [2]Vec2 + Status Status +} + +func (dst *Box) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Box", src) +} + +func (dst *Box) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Box) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Box{Status: Null} + return nil + } + + if len(src) < 11 { + return fmt.Errorf("invalid length for Box: %v", len(src)) + } + + str := string(src[1:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1 : len(str)-1] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Box{Corners: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} + return nil +} + +func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Box{Status: Null} + return nil + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for Box: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + *dst = Box{ + Corners: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Status: Present, + } + return nil +} + +func (src *Box) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f),(%f,%f)`, + src.Corners[0].X, src.Corners[0].Y, src.Corners[1].X, src.Corners[1].Y)) + return false, err +} + +func (src *Box) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[0].X)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[0].Y)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[1].X)); err != nil { + return false, err + } + + _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[1].Y)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Box) Scan(src interface{}) error { + if src == nil { + *dst = Box{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Box) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/box_test.go b/box_test.go new file mode 100644 index 00000000..21446dc3 --- /dev/null +++ b/box_test.go @@ -0,0 +1,33 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestBoxTranscode(t *testing.T) { + testSuccessfulTranscode(t, "box", []interface{}{ + &pgtype.Box{ + Corners: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + Status: pgtype.Present, + }, + &pgtype.Box{ + Corners: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Status: pgtype.Present, + }, + &pgtype.Box{Status: pgtype.Null}, + }) +} + +func TestBoxNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select '3.14, 1.678, 7.1, 5.234'::box", + value: &pgtype.Box{ + Corners: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + Status: pgtype.Present, + }, + }, + }) +} diff --git a/pgtype.go b/pgtype.go index 911ab70e..b29bc90c 100644 --- a/pgtype.go +++ b/pgtype.go @@ -223,6 +223,7 @@ func init() { "_varchar": &VarcharArray{}, "aclitem": &Aclitem{}, "bool": &Bool{}, + "box": &Box{}, "bytea": &Bytea{}, "char": &QChar{}, "cid": &Cid{}, diff --git a/point.go b/point.go index 1b40bc44..94f753e3 100644 --- a/point.go +++ b/point.go @@ -12,9 +12,13 @@ import ( "github.com/jackc/pgx/pgio" ) +type Vec2 struct { + X float64 + Y float64 +} + type Point struct { - X float64 - Y float64 + Vec2 Status Status } @@ -62,7 +66,7 @@ func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Point{X: x, Y: y, Status: Present} + *dst = Point{Vec2: Vec2{x, y}, Status: Present} return nil } @@ -80,8 +84,7 @@ func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { y := binary.BigEndian.Uint64(src[8:]) *dst = Point{ - X: math.Float64frombits(x), - Y: math.Float64frombits(y), + Vec2: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, Status: Present, } return nil diff --git a/point_test.go b/point_test.go index 4ddb8009..723dfa60 100644 --- a/point_test.go +++ b/point_test.go @@ -8,8 +8,8 @@ import ( func TestPointTranscode(t *testing.T) { testSuccessfulTranscode(t, "point", []interface{}{ - &pgtype.Point{X: 1.234, Y: 5.6789, Status: pgtype.Present}, - &pgtype.Point{X: -1.234, Y: -5.6789, Status: pgtype.Present}, + &pgtype.Point{Vec2: pgtype.Vec2{1.234, 5.6789}, Status: pgtype.Present}, + &pgtype.Point{Vec2: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, &pgtype.Point{Status: pgtype.Null}, }) } From 2fc89c69e9d1f205021721d63806edde5988918a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 08:04:40 -0500 Subject: [PATCH 0058/1158] Add pgtype.Line --- line.go | 148 +++++++++++++++++++++++++++++++++++++++++++++++++++ line_test.go | 21 ++++++++ pgtype.go | 1 + 3 files changed, 170 insertions(+) create mode 100644 line.go create mode 100644 line_test.go diff --git a/line.go b/line.go new file mode 100644 index 00000000..08a74e84 --- /dev/null +++ b/line.go @@ -0,0 +1,148 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Line struct { + A, B, C float64 + Status Status +} + +func (dst *Line) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Line", src) +} + +func (dst *Line) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Line) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Line{Status: Null} + return nil + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Line: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 3) + if len(parts) < 3 { + return fmt.Errorf("invalid format for line") + } + + a, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + b, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + c, err := strconv.ParseFloat(parts[2], 64) + if err != nil { + return err + } + + *dst = Line{A: a, B: b, C: c, Status: Present} + return nil +} + +func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Line{Status: Null} + return nil + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for Line: %v", len(src)) + } + + a := binary.BigEndian.Uint64(src) + b := binary.BigEndian.Uint64(src[8:]) + c := binary.BigEndian.Uint64(src[16:]) + + *dst = Line{ + A: math.Float64frombits(a), + B: math.Float64frombits(b), + C: math.Float64frombits(c), + Status: Present, + } + return nil +} + +func (src *Line) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`{%f,%f,%f}`, src.A, src.B, src.C)) + return false, err +} + +func (src *Line) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.A)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.B)); err != nil { + return false, err + } + + _, err := pgio.WriteUint64(w, math.Float64bits(src.C)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Line) Scan(src interface{}) error { + if src == nil { + *dst = Line{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Line) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/line_test.go b/line_test.go new file mode 100644 index 00000000..6d3b02e1 --- /dev/null +++ b/line_test.go @@ -0,0 +1,21 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestLineTranscode(t *testing.T) { + testSuccessfulTranscode(t, "line", []interface{}{ + &pgtype.Line{ + A: 1.23, B: 4.56, C: 7.89, + Status: pgtype.Present, + }, + &pgtype.Line{ + A: -1.23, B: -4.56, C: -7.89, + Status: pgtype.Present, + }, + &pgtype.Line{Status: pgtype.Null}, + }) +} diff --git a/pgtype.go b/pgtype.go index b29bc90c..c92dfccf 100644 --- a/pgtype.go +++ b/pgtype.go @@ -242,6 +242,7 @@ func init() { "int8range": &Int8range{}, "json": &Json{}, "jsonb": &Jsonb{}, + "line": &Line{}, "name": &Name{}, "numeric": &Numeric{}, "numrange": &Numrange{}, From d8a778811eefcabd76416f8d8c28a5d079d18626 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 08:16:02 -0500 Subject: [PATCH 0059/1158] Add pgtype.Lseg --- box.go | 18 +++--- box_test.go | 12 ++-- lseg.go | 168 +++++++++++++++++++++++++++++++++++++++++++++++++++ lseg_test.go | 21 +++++++ pgtype.go | 1 + 5 files changed, 205 insertions(+), 15 deletions(-) create mode 100644 lseg.go create mode 100644 lseg_test.go diff --git a/box.go b/box.go index eaaddbff..138953a5 100644 --- a/box.go +++ b/box.go @@ -13,8 +13,8 @@ import ( ) type Box struct { - Corners [2]Vec2 - Status Status + P [2]Vec2 + Status Status } func (dst *Box) Set(src interface{}) error { @@ -79,7 +79,7 @@ func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Box{Corners: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} + *dst = Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} return nil } @@ -99,7 +99,7 @@ func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { y2 := binary.BigEndian.Uint64(src[24:]) *dst = Box{ - Corners: [2]Vec2{ + P: [2]Vec2{ {math.Float64frombits(x1), math.Float64frombits(y1)}, {math.Float64frombits(x2), math.Float64frombits(y2)}, }, @@ -117,7 +117,7 @@ func (src *Box) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f),(%f,%f)`, - src.Corners[0].X, src.Corners[0].Y, src.Corners[1].X, src.Corners[1].Y)) + src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)) return false, err } @@ -129,19 +129,19 @@ func (src *Box) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, errUndefined } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[0].X)); err != nil { + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].X)); err != nil { return false, err } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[0].Y)); err != nil { + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].Y)); err != nil { return false, err } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[1].X)); err != nil { + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].X)); err != nil { return false, err } - _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[1].Y)) + _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].Y)) return false, err } diff --git a/box_test.go b/box_test.go index 21446dc3..00732973 100644 --- a/box_test.go +++ b/box_test.go @@ -9,12 +9,12 @@ import ( func TestBoxTranscode(t *testing.T) { testSuccessfulTranscode(t, "box", []interface{}{ &pgtype.Box{ - Corners: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, - Status: pgtype.Present, + P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + Status: pgtype.Present, }, &pgtype.Box{ - Corners: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Status: pgtype.Present, + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Status: pgtype.Present, }, &pgtype.Box{Status: pgtype.Null}, }) @@ -25,8 +25,8 @@ func TestBoxNormalize(t *testing.T) { { sql: "select '3.14, 1.678, 7.1, 5.234'::box", value: &pgtype.Box{ - Corners: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, - Status: pgtype.Present, + P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + Status: pgtype.Present, }, }, }) diff --git a/lseg.go b/lseg.go new file mode 100644 index 00000000..b86256e0 --- /dev/null +++ b/lseg.go @@ -0,0 +1,168 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Lseg struct { + P [2]Vec2 + Status Status +} + +func (dst *Lseg) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Lseg", src) +} + +func (dst *Lseg) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Lseg) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Lseg{Status: Null} + return nil + } + + if len(src) < 11 { + return fmt.Errorf("invalid length for Lseg: %v", len(src)) + } + + str := string(src[2:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1 : len(str)-2] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} + return nil +} + +func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Lseg{Status: Null} + return nil + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for Lseg: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + *dst = Lseg{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Status: Present, + } + return nil +} + +func (src *Lseg) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f),(%f,%f)`, + src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)) + return false, err +} + +func (src *Lseg) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].X)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].Y)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].X)); err != nil { + return false, err + } + + _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].Y)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Lseg) Scan(src interface{}) error { + if src == nil { + *dst = Lseg{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Lseg) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/lseg_test.go b/lseg_test.go new file mode 100644 index 00000000..5f041263 --- /dev/null +++ b/lseg_test.go @@ -0,0 +1,21 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestLsegTranscode(t *testing.T) { + testSuccessfulTranscode(t, "lseg", []interface{}{ + &pgtype.Lseg{ + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}}, + Status: pgtype.Present, + }, + &pgtype.Lseg{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Status: pgtype.Present, + }, + &pgtype.Lseg{Status: pgtype.Null}, + }) +} diff --git a/pgtype.go b/pgtype.go index c92dfccf..6d1f49af 100644 --- a/pgtype.go +++ b/pgtype.go @@ -243,6 +243,7 @@ func init() { "json": &Json{}, "jsonb": &Jsonb{}, "line": &Line{}, + "lseg": &Lseg{}, "name": &Name{}, "numeric": &Numeric{}, "numrange": &Numrange{}, From f4bdd8300f86289d9de626c992ddde6e049b89a8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 08:40:41 -0500 Subject: [PATCH 0060/1158] Add path --- path.go | 207 +++++++++++++++++++++++++++++++++++++++++++++++++++ path_test.go | 28 +++++++ pgtype.go | 1 + 3 files changed, 236 insertions(+) create mode 100644 path.go create mode 100644 path_test.go diff --git a/path.go b/path.go new file mode 100644 index 00000000..fb4193d9 --- /dev/null +++ b/path.go @@ -0,0 +1,207 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Path struct { + P []Vec2 + Closed bool + Status Status +} + +func (dst *Path) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Path", src) +} + +func (dst *Path) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Path) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Path{Status: Null} + return nil + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Path: %v", len(src)) + } + + closed := src[0] == '(' + points := make([]Vec2, 0) + + str := string(src[2:]) + + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + points = append(points, Vec2{x, y}) + + if end+3 < len(str) { + str = str[end+3:] + } else { + break + } + } + + *dst = Path{P: points, Closed: closed, Status: Present} + return nil +} + +func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Path{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for Path: %v", len(src)) + } + + closed := src[0] == 1 + pointCount := int(binary.BigEndian.Uint32(src[1:])) + + rp := 5 + + if 5+pointCount*16 != len(src) { + return fmt.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + *dst = Path{ + P: points, + Closed: closed, + Status: Present, + } + return nil +} + +func (src *Path) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var startByte, endByte byte + if src.Closed { + startByte = '(' + endByte = ')' + } else { + startByte = '[' + endByte = ']' + } + if err := pgio.WriteByte(w, startByte); err != nil { + return false, err + } + + for i, p := range src.P { + if i > 0 { + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + } + if _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)); err != nil { + return false, err + } + } + + err := pgio.WriteByte(w, endByte) + return false, err +} + +func (src *Path) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var closeByte byte + if src.Closed { + closeByte = 1 + } + if err := pgio.WriteByte(w, closeByte); err != nil { + return false, err + } + + if _, err := pgio.WriteInt32(w, int32(len(src.P))); err != nil { + return false, err + } + + for _, p := range src.P { + if _, err := pgio.WriteUint64(w, math.Float64bits(p.X)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(p.Y)); err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Path) Scan(src interface{}) error { + if src == nil { + *dst = Path{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Path) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/path_test.go b/path_test.go new file mode 100644 index 00000000..4e5f7f62 --- /dev/null +++ b/path_test.go @@ -0,0 +1,28 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestPathTranscode(t *testing.T) { + testSuccessfulTranscode(t, "path", []interface{}{ + &pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}}, + Closed: false, + Status: pgtype.Present, + }, + &pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, + Closed: true, + Status: pgtype.Present, + }, + &pgtype.Path{ + P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Closed: true, + Status: pgtype.Present, + }, + &pgtype.Path{Status: pgtype.Null}, + }) +} diff --git a/pgtype.go b/pgtype.go index 6d1f49af..18d21e20 100644 --- a/pgtype.go +++ b/pgtype.go @@ -248,6 +248,7 @@ func init() { "numeric": &Numeric{}, "numrange": &Numrange{}, "oid": &OidValue{}, + "path": &Path{}, "point": &Point{}, "record": &Record{}, "text": &Text{}, From 8cbf667b8e59f9e16e25df5f08d634c1978f95a4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 20:24:01 -0500 Subject: [PATCH 0061/1158] Add pgtype.Uuid --- pgtype.go | 1 + uuid.go | 173 +++++++++++++++++++++++++++++++++++++++++++++++++++ uuid_test.go | 95 ++++++++++++++++++++++++++++ 3 files changed, 269 insertions(+) create mode 100644 uuid.go create mode 100644 uuid_test.go diff --git a/pgtype.go b/pgtype.go index 18d21e20..5c8adb6e 100644 --- a/pgtype.go +++ b/pgtype.go @@ -258,6 +258,7 @@ func init() { "tsrange": &Tsrange{}, "tstzrange": &Tstzrange{}, "unknown": &Unknown{}, + "uuid": &Uuid{}, "varchar": &Varchar{}, "xid": &Xid{}, } diff --git a/uuid.go b/uuid.go new file mode 100644 index 00000000..111bed35 --- /dev/null +++ b/uuid.go @@ -0,0 +1,173 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/hex" + "fmt" + "io" +) + +type Uuid struct { + Bytes [16]byte + Status Status +} + +func (dst *Uuid) Set(src interface{}) error { + switch value := src.(type) { + case [16]byte: + *dst = Uuid{Bytes: value, Status: Present} + case []byte: + if len(value) != 16 { + return fmt.Errorf("[]byte must be 16 bytes to convert to Uuid: %d", len(value)) + } + *dst = Uuid{Status: Present} + copy(dst.Bytes[:], value) + case string: + uuid, err := parseUuid(value) + if err != nil { + return err + } + *dst = Uuid{Bytes: uuid, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Uuid", value) + } + + return nil +} + +func (dst *Uuid) Get() interface{} { + switch dst.Status { + case Present: + return dst.Bytes + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Uuid) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *[16]byte: + *v = src.Bytes + return nil + case *[]byte: + *v = make([]byte, 16) + copy(*v, src.Bytes[:]) + return nil + case *string: + *v = encodeUuid(src.Bytes) + return nil + default: + if nextDst, retry := GetAssignToDstType(v); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return nullAssignTo(dst) + } + + return fmt.Errorf("cannot assign %v into %T", src, dst) +} + +// parseUuid converts a string UUID in standard form to a byte array. +func parseUuid(src string) (dst [16]byte, err error) { + src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] + buf, err := hex.DecodeString(src) + if err != nil { + return dst, err + } + + copy(dst[:], buf) + return dst, err +} + +// encodeUuid converts a uuid byte array to UUID standard string form. +func encodeUuid(src [16]byte) string { + return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16]) +} + +func (dst *Uuid) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Uuid{Status: Null} + return nil + } + + if len(src) != 36 { + return fmt.Errorf("invalid length for Uuid: %v", len(src)) + } + + buf, err := parseUuid(string(src)) + if err != nil { + return err + } + + *dst = Uuid{Bytes: buf, Status: Present} + return nil +} + +func (dst *Uuid) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Uuid{Status: Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for Uuid: %v", len(src)) + } + + *dst = Uuid{Status: Present} + copy(dst.Bytes[:], src) + return nil +} + +func (src Uuid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, encodeUuid(src.Bytes)) + return false, err +} + +func (src Uuid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := w.Write(src.Bytes[:]) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Uuid) Scan(src interface{}) error { + if src == nil { + *dst = Uuid{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Uuid) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/uuid_test.go b/uuid_test.go new file mode 100644 index 00000000..1eba7e90 --- /dev/null +++ b/uuid_test.go @@ -0,0 +1,95 @@ +package pgtype_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestUuidTranscode(t *testing.T) { + testSuccessfulTranscode(t, "uuid", []interface{}{ + pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + pgtype.Uuid{Status: pgtype.Null}, + }) +} + +func TestUuidSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Uuid + }{ + { + source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: "00010203-0405-0607-0809-0a0b0c0d0e0f", + result: pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Uuid + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUuidAssignTo(t *testing.T) { + { + src := pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst [16]byte + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst []byte + expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare(dst, expected) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst string + expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + +} From 6a0b41e50a68c1099326223a0f15e01a7b9fb023 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 20:30:04 -0500 Subject: [PATCH 0062/1158] Add pgtype.Polygon --- pgtype.go | 1 + polygon.go | 186 ++++++++++++++++++++++++++++++++++++++++++++++++ polygon_test.go | 21 ++++++ 3 files changed, 208 insertions(+) create mode 100644 polygon.go create mode 100644 polygon_test.go diff --git a/pgtype.go b/pgtype.go index 5c8adb6e..cb0cec2c 100644 --- a/pgtype.go +++ b/pgtype.go @@ -250,6 +250,7 @@ func init() { "oid": &OidValue{}, "path": &Path{}, "point": &Point{}, + "polygon": &Polygon{}, "record": &Record{}, "text": &Text{}, "tid": &Tid{}, diff --git a/polygon.go b/polygon.go new file mode 100644 index 00000000..1e2df011 --- /dev/null +++ b/polygon.go @@ -0,0 +1,186 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Polygon struct { + P []Vec2 + Status Status +} + +func (dst *Polygon) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Polygon", src) +} + +func (dst *Polygon) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Polygon) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Polygon{Status: Null} + return nil + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Polygon: %v", len(src)) + } + + points := make([]Vec2, 0) + + str := string(src[2:]) + + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + points = append(points, Vec2{x, y}) + + if end+3 < len(str) { + str = str[end+3:] + } else { + break + } + } + + *dst = Polygon{P: points, Status: Present} + return nil +} + +func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Polygon{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for Polygon: %v", len(src)) + } + + pointCount := int(binary.BigEndian.Uint32(src)) + rp := 4 + + if 4+pointCount*16 != len(src) { + return fmt.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + *dst = Polygon{ + P: points, + Status: Present, + } + return nil +} + +func (src *Polygon) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + + for i, p := range src.P { + if i > 0 { + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + } + if _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)); err != nil { + return false, err + } + } + + err := pgio.WriteByte(w, ')') + return false, err +} + +func (src *Polygon) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteInt32(w, int32(len(src.P))); err != nil { + return false, err + } + + for _, p := range src.P { + if _, err := pgio.WriteUint64(w, math.Float64bits(p.X)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(p.Y)); err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Polygon) Scan(src interface{}) error { + if src == nil { + *dst = Polygon{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Polygon) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/polygon_test.go b/polygon_test.go new file mode 100644 index 00000000..3a7e1431 --- /dev/null +++ b/polygon_test.go @@ -0,0 +1,21 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestPolygonTranscode(t *testing.T) { + testSuccessfulTranscode(t, "polygon", []interface{}{ + &pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {5.0, 3.234}}, + Status: pgtype.Present, + }, + &pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, + Status: pgtype.Present, + }, + &pgtype.Polygon{Status: pgtype.Null}, + }) +} From d99d09b0d197629f85b8b15cfa1b6ff3f967de68 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 20:39:48 -0500 Subject: [PATCH 0063/1158] Add pgtype.Circle Also rename Point.Vec2 to Point.P to conform to rest of geometric types. --- circle.go | 150 +++++++++++++++++++++++++++++++++++++++++++++++++ circle_test.go | 15 +++++ pgtype.go | 1 + point.go | 12 ++-- point_test.go | 4 +- 5 files changed, 174 insertions(+), 8 deletions(-) create mode 100644 circle.go create mode 100644 circle_test.go diff --git a/circle.go b/circle.go new file mode 100644 index 00000000..62e2e8b3 --- /dev/null +++ b/circle.go @@ -0,0 +1,150 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Circle struct { + P Vec2 + R float64 + Status Status +} + +func (dst *Circle) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Circle", src) +} + +func (dst *Circle) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Circle) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + if len(src) < 9 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + str := string(src[2:]) + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+2 : len(str)-1] + + r, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Circle{P: Vec2{x, y}, R: r, Status: Present} + return nil +} + +func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + r := binary.BigEndian.Uint64(src[16:]) + + *dst = Circle{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + R: math.Float64frombits(r), + Status: Present, + } + return nil +} + +func (src *Circle) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`<(%f,%f),%f>`, src.P.X, src.P.Y, src.R)) + return false, err +} + +func (src *Circle) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P.X)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P.Y)); err != nil { + return false, err + } + + _, err := pgio.WriteUint64(w, math.Float64bits(src.R)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Circle) Scan(src interface{}) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Circle) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/circle_test.go b/circle_test.go new file mode 100644 index 00000000..9746dd74 --- /dev/null +++ b/circle_test.go @@ -0,0 +1,15 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestCircleTranscode(t *testing.T) { + testSuccessfulTranscode(t, "circle", []interface{}{ + &pgtype.Circle{P: pgtype.Vec2{1.234, 5.6789}, R: 3.5, Status: pgtype.Present}, + &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Status: pgtype.Present}, + &pgtype.Circle{Status: pgtype.Null}, + }) +} diff --git a/pgtype.go b/pgtype.go index cb0cec2c..52cad561 100644 --- a/pgtype.go +++ b/pgtype.go @@ -228,6 +228,7 @@ func init() { "char": &QChar{}, "cid": &Cid{}, "cidr": &Cidr{}, + "circle": &Circle{}, "date": &Date{}, "daterange": &Daterange{}, "decimal": &Decimal{}, diff --git a/point.go b/point.go index 94f753e3..788a76c9 100644 --- a/point.go +++ b/point.go @@ -18,7 +18,7 @@ type Vec2 struct { } type Point struct { - Vec2 + P Vec2 Status Status } @@ -66,7 +66,7 @@ func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Point{Vec2: Vec2{x, y}, Status: Present} + *dst = Point{P: Vec2{x, y}, Status: Present} return nil } @@ -84,7 +84,7 @@ func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { y := binary.BigEndian.Uint64(src[8:]) *dst = Point{ - Vec2: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, Status: Present, } return nil @@ -98,7 +98,7 @@ func (src *Point) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.X, src.Y)) + _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.P.X, src.P.Y)) return false, err } @@ -110,12 +110,12 @@ func (src *Point) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, errUndefined } - _, err := pgio.WriteUint64(w, math.Float64bits(src.X)) + _, err := pgio.WriteUint64(w, math.Float64bits(src.P.X)) if err != nil { return false, err } - _, err = pgio.WriteUint64(w, math.Float64bits(src.Y)) + _, err = pgio.WriteUint64(w, math.Float64bits(src.P.Y)) return false, err } diff --git a/point_test.go b/point_test.go index 723dfa60..c921f794 100644 --- a/point_test.go +++ b/point_test.go @@ -8,8 +8,8 @@ import ( func TestPointTranscode(t *testing.T) { testSuccessfulTranscode(t, "point", []interface{}{ - &pgtype.Point{Vec2: pgtype.Vec2{1.234, 5.6789}, Status: pgtype.Present}, - &pgtype.Point{Vec2: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, + &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789}, Status: pgtype.Present}, + &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, &pgtype.Point{Status: pgtype.Null}, }) } From 3631b076fe3654b6287ea9f8729e4c56932df05e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 21:07:27 -0500 Subject: [PATCH 0064/1158] Add pgtype.Macaddr --- macaddr.go | 154 ++++++++++++++++++++++++++++++++++++++++++++++++ macaddr_test.go | 77 ++++++++++++++++++++++++ pgtype.go | 1 + pgtype_test.go | 9 +++ 4 files changed, 241 insertions(+) create mode 100644 macaddr.go create mode 100644 macaddr_test.go diff --git a/macaddr.go b/macaddr.go new file mode 100644 index 00000000..2d09ff8c --- /dev/null +++ b/macaddr.go @@ -0,0 +1,154 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + "io" + "net" +) + +type Macaddr struct { + Addr net.HardwareAddr + Status Status +} + +func (dst *Macaddr) Set(src interface{}) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + switch value := src.(type) { + case net.HardwareAddr: + addr := make(net.HardwareAddr, len(value)) + copy(addr, value) + *dst = Macaddr{Addr: addr, Status: Present} + case string: + addr, err := net.ParseMAC(value) + if err != nil { + return err + } + *dst = Macaddr{Addr: addr, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Macaddr", value) + } + + return nil +} + +func (dst *Macaddr) Get() interface{} { + switch dst.Status { + case Present: + return dst.Addr + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Macaddr) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *net.HardwareAddr: + *v = make(net.HardwareAddr, len(src.Addr)) + copy(*v, src.Addr) + return nil + case *string: + *v = src.Addr.String() + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return nullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + addr, err := net.ParseMAC(string(src)) + if err != nil { + return err + } + + *dst = Macaddr{Addr: addr, Status: Present} + return nil +} + +func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + if len(src) != 6 { + return fmt.Errorf("Received an invalid size for a macaddr: %d", len(src)) + } + + addr := make(net.HardwareAddr, 6) + copy(addr, src) + + *dst = Macaddr{Addr: addr, Status: Present} + + return nil +} + +func (src Macaddr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, src.Addr.String()) + return false, err +} + +// EncodeBinary encodes src into w. +func (src Macaddr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := w.Write([]byte(src.Addr)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Macaddr) Scan(src interface{}) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Macaddr) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/macaddr_test.go b/macaddr_test.go new file mode 100644 index 00000000..6c7b8b89 --- /dev/null +++ b/macaddr_test.go @@ -0,0 +1,77 @@ +package pgtype_test + +import ( + "bytes" + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestMacaddrTranscode(t *testing.T) { + testSuccessfulTranscode(t, "macaddr", []interface{}{ + pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + pgtype.Macaddr{Status: pgtype.Null}, + }) +} + +func TestMacaddrSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Macaddr + }{ + { + source: mustParseMacaddr(t, "01:23:45:67:89:ab"), + result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + }, + { + source: "01:23:45:67:89:ab", + result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Macaddr + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestMacaddrAssignTo(t *testing.T) { + { + src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} + var dst net.HardwareAddr + expected := mustParseMacaddr(t, "01:23:45:67:89:ab") + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare([]byte(dst), []byte(expected)) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} + var dst string + expected := "01:23:45:67:89:ab" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } +} diff --git a/pgtype.go b/pgtype.go index 52cad561..6b06539b 100644 --- a/pgtype.go +++ b/pgtype.go @@ -245,6 +245,7 @@ func init() { "jsonb": &Jsonb{}, "line": &Line{}, "lseg": &Lseg{}, + "macaddr": &Macaddr{}, "name": &Name{}, "numeric": &Numeric{}, "numrange": &Numrange{}, diff --git a/pgtype_test.go b/pgtype_test.go index 298cff64..0b1ffc54 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -78,6 +78,15 @@ func mustParseCidr(t testing.TB, s string) *net.IPNet { return ipnet } +func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { + addr, err := net.ParseMAC(s) + if err != nil { + t.Fatal(err) + } + + return addr +} + type forceTextEncoder struct { e pgtype.TextEncoder } From c31fe24693870e99050826a4f257f6e5790b4d04 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 21:13:00 -0500 Subject: [PATCH 0065/1158] Fix pgtype.Inet.AssignTo assigning reference AssignTo should always assign copy. Added documentation for AssignTo interface. --- inet.go | 10 ++++++++-- pgtype.go | 3 ++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/inet.go b/inet.go index 0ca3ee7a..3e00e2fa 100644 --- a/inet.go +++ b/inet.go @@ -70,13 +70,19 @@ func (src *Inet) AssignTo(dst interface{}) error { case Present: switch v := dst.(type) { case *net.IPNet: - *v = *src.IPNet + *v = net.IPNet{ + IP: make(net.IP, len(src.IPNet.IP)), + Mask: make(net.IPMask, len(src.IPNet.Mask)), + } + copy(v.IP, src.IPNet.IP) + copy(v.Mask, src.IPNet.Mask) return nil case *net.IP: if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { return fmt.Errorf("cannot assign %v to %T", src, dst) } - *v = src.IPNet.IP + *v = make(net.IP, len(src.IPNet.IP)) + copy(*v, src.IPNet.IP) return nil default: if nextDst, retry := GetAssignToDstType(dst); retry { diff --git a/pgtype.go b/pgtype.go index 6b06539b..5de07b7d 100644 --- a/pgtype.go +++ b/pgtype.go @@ -89,7 +89,8 @@ type Value interface { // possible, then Get() returns Value. Get() interface{} - // AssignTo converts and assigns the Value to dst. + // AssignTo converts and assigns the Value to dst. It MUST make a deep copy of + // any reference types. AssignTo(dst interface{}) error } From 68fd815778efea51ea3811b359019495578ca334 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 5 Apr 2017 07:54:41 -0500 Subject: [PATCH 0066/1158] Add pgtype.Varbit --- pgtype.go | 1 + varbit.go | 141 +++++++++++++++++++++++++++++++++++++++++++++++++ varbit_test.go | 25 +++++++++ 3 files changed, 167 insertions(+) create mode 100644 varbit.go create mode 100644 varbit_test.go diff --git a/pgtype.go b/pgtype.go index 5de07b7d..338afc9b 100644 --- a/pgtype.go +++ b/pgtype.go @@ -263,6 +263,7 @@ func init() { "tstzrange": &Tstzrange{}, "unknown": &Unknown{}, "uuid": &Uuid{}, + "varbit": &Varbit{}, "varchar": &Varchar{}, "xid": &Xid{}, } diff --git a/varbit.go b/varbit.go new file mode 100644 index 00000000..d28e95cd --- /dev/null +++ b/varbit.go @@ -0,0 +1,141 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Varbit struct { + Bytes []byte + Len int32 // Number of bits + Status Status +} + +func (dst *Varbit) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Varbit", src) +} + +func (dst *Varbit) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Varbit) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Varbit) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Varbit{Status: Null} + return nil + } + + bitLen := len(src) + byteLen := bitLen / 8 + if bitLen%8 > 0 { + byteLen++ + } + buf := make([]byte, byteLen) + + for i, b := range src { + if b == '1' { + byteIdx := i / 8 + bitIdx := uint(i % 8) + buf[byteIdx] = buf[byteIdx] | (128 >> bitIdx) + } + } + + *dst = Varbit{Bytes: buf, Len: int32(bitLen), Status: Present} + return nil +} + +func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Varbit{Status: Null} + return nil + } + + if len(src) < 4 { + return fmt.Errorf("invalid length for varbit: %v", len(src)) + } + + bitLen := int32(binary.BigEndian.Uint32(src)) + rp := 4 + + buf := make([]byte, len(src[rp:])) + copy(buf, src[rp:]) + + *dst = Varbit{Bytes: buf, Len: bitLen, Status: Present} + return nil +} + +func (src *Varbit) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + buf := make([]byte, int(src.Len)) + for i, _ := range buf { + byteIdx := i / 8 + bitMask := byte(128 >> byte(i%8)) + char := byte('0') + if src.Bytes[byteIdx]&bitMask > 0 { + char = '1' + } + buf[i] = char + } + + _, err := w.Write(buf) + return false, err +} + +func (src *Varbit) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteInt32(w, src.Len); err != nil { + return false, err + } + + _, err := w.Write(src.Bytes) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Varbit) Scan(src interface{}) error { + if src == nil { + *dst = Varbit{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Varbit) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/varbit_test.go b/varbit_test.go new file mode 100644 index 00000000..cd146d26 --- /dev/null +++ b/varbit_test.go @@ -0,0 +1,25 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestVarbitTranscode(t *testing.T) { + testSuccessfulTranscode(t, "varbit", []interface{}{ + &pgtype.Varbit{Bytes: []byte{}, Len: 0, Status: pgtype.Present}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Status: pgtype.Present}, + &pgtype.Varbit{Status: pgtype.Null}, + }) +} + +func TestVarbitNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select B'111111111'", + value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, + }, + }) +} From 7ff405ff840a0f1177039f5a2aa384dd3fb3e3c2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 10 Apr 2017 08:58:51 -0500 Subject: [PATCH 0067/1158] Add simple protocol suuport with (Query|Exec)Ex --- cid_test.go | 17 +++++++++++++++-- json.go | 2 +- numeric.go | 21 +++++++++++++++++++-- numeric_test.go | 3 +++ pgtype_test.go | 31 +++++++++++++++++++++++++++++++ xid_test.go | 17 +++++++++++++++-- 6 files changed, 84 insertions(+), 7 deletions(-) diff --git a/cid_test.go b/cid_test.go index 0d114cda..210573f6 100644 --- a/cid_test.go +++ b/cid_test.go @@ -8,10 +8,23 @@ import ( ) func TestCidTranscode(t *testing.T) { - testSuccessfulTranscode(t, "cid", []interface{}{ + pgTypeName := "cid" + values := []interface{}{ pgtype.Cid{Uint: 42, Status: pgtype.Present}, pgtype.Cid{Status: pgtype.Null}, - }) + } + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + + // No direct conversion from int to cid, convert through text + testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) + + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } } func TestCidSet(t *testing.T) { diff --git a/json.go b/json.go index 05d965ca..b1c061f9 100644 --- a/json.go +++ b/json.go @@ -145,7 +145,7 @@ func (dst *Json) Scan(src interface{}) error { func (src Json) Value() (driver.Value, error) { switch src.Status { case Present: - return src.Bytes, nil + return string(src.Bytes), nil case Null: return nil, nil default: diff --git a/numeric.go b/numeric.go index 0f3f6529..a26e8c89 100644 --- a/numeric.go +++ b/numeric.go @@ -121,13 +121,13 @@ func (src *Numeric) AssignTo(dst interface{}) error { case Present: switch v := dst.(type) { case *float32: - f, err := strconv.ParseFloat(src.Int.String(), 64) + f, err := src.toFloat64() if err != nil { return err } return float64AssignTo(f, src.Status, dst) case *float64: - f, err := strconv.ParseFloat(src.Int.String(), 64) + f, err := src.toFloat64() if err != nil { return err } @@ -283,6 +283,23 @@ func (dst *Numeric) toBigInt() (*big.Int, error) { return num, nil } +func (src *Numeric) toFloat64() (float64, error) { + f, err := strconv.ParseFloat(src.Int.String(), 64) + if err != nil { + return 0, err + } + if src.Exp > 0 { + for i := 0; i < int(src.Exp); i++ { + f *= 10 + } + } else if src.Exp < 0 { + for i := 0; i > int(src.Exp); i-- { + f /= 10 + } + } + return f, nil +} + func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Numeric{Status: Null} diff --git a/numeric_test.go b/numeric_test.go index 64dea847..93aa8866 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -247,9 +247,12 @@ func TestNumericAssignTo(t *testing.T) { }{ {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: 3, Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, diff --git a/pgtype_test.go b/pgtype_test.go index 0b1ffc54..f486f077 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "context" "database/sql" "fmt" "io" @@ -125,6 +126,7 @@ func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) } @@ -175,6 +177,35 @@ func testPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } } +func testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRowEx( + context.Background(), + fmt.Sprintf("select ($1)::%s", pgTypeName), + &pgx.QueryExOptions{SimpleProtocol: true}, + v, + ).Scan(result.Interface()) + if err != nil { + t.Errorf("Simple protocol %d: %v", i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface()) + } + } +} + func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { conn := mustConnectDatabaseSQL(t, driverName) defer mustClose(t, conn) diff --git a/xid_test.go b/xid_test.go index fecfb64b..11dd0615 100644 --- a/xid_test.go +++ b/xid_test.go @@ -8,10 +8,23 @@ import ( ) func TestXidTranscode(t *testing.T) { - testSuccessfulTranscode(t, "xid", []interface{}{ + pgTypeName := "xid" + values := []interface{}{ pgtype.Xid{Uint: 42, Status: pgtype.Present}, pgtype.Xid{Status: pgtype.Null}, - }) + } + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + + // No direct conversion from int to xid, convert through text + testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) + + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } } func TestXidSet(t *testing.T) { From e76cf5617fa7bbdb4c502ee2e96c08665711f5de Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 11 Apr 2017 20:16:41 -0500 Subject: [PATCH 0068/1158] Skip line tests on when server version < PG 9.4 --- line_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/line_test.go b/line_test.go index 6d3b02e1..995eaad5 100644 --- a/line_test.go +++ b/line_test.go @@ -3,10 +3,24 @@ package pgtype_test import ( "testing" + version "github.com/hashicorp/go-version" "github.com/jackc/pgx/pgtype" ) func TestLineTranscode(t *testing.T) { + conn := mustConnectPgx(t) + serverVersion, err := version.NewVersion(conn.RuntimeParams["server_version"]) + if err != nil { + t.Fatalf("cannot get server version: %v", err) + } + mustClose(t, conn) + + minVersion := version.Must(version.NewVersion("9.4")) + + if serverVersion.LessThan(minVersion) { + t.Skipf("Skipping line test for server version %v", serverVersion) + } + testSuccessfulTranscode(t, "line", []interface{}{ &pgtype.Line{ A: 1.23, B: 4.56, C: 7.89, From 92474ef2927b843293441bb55bd6d998a93595d6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 13 Apr 2017 21:54:04 -0500 Subject: [PATCH 0069/1158] Add MarshalJSON to a few types --- int2.go | 13 +++++++++++++ int4.go | 13 +++++++++++++ int8.go | 13 +++++++++++++ pgtype.go | 1 + text.go | 14 ++++++++++++++ varchar.go | 4 ++++ 6 files changed, 58 insertions(+) diff --git a/int2.go b/int2.go index 3bcac63c..0cb6ef82 100644 --- a/int2.go +++ b/int2.go @@ -195,3 +195,16 @@ func (src Int2) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Int2) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil + case Null: + return []byte("null"), nil + case Undefined: + return []byte("undefined"), nil + } + + return nil, errBadStatus +} diff --git a/int4.go b/int4.go index 5069dab4..4a5bca51 100644 --- a/int4.go +++ b/int4.go @@ -186,3 +186,16 @@ func (src Int4) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Int4) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil + case Null: + return []byte("null"), nil + case Undefined: + return []byte("undefined"), nil + } + + return nil, errBadStatus +} diff --git a/int8.go b/int8.go index cf701dc6..0cc3545d 100644 --- a/int8.go +++ b/int8.go @@ -172,3 +172,16 @@ func (src Int8) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Int8) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(strconv.FormatInt(src.Int, 10)), nil + case Null: + return []byte("null"), nil + case Undefined: + return []byte("undefined"), nil + } + + return nil, errBadStatus +} diff --git a/pgtype.go b/pgtype.go index 338afc9b..27a1a091 100644 --- a/pgtype.go +++ b/pgtype.go @@ -129,6 +129,7 @@ type TextEncoder interface { } var errUndefined = errors.New("cannot encode status undefined") +var errBadStatus = errors.New("invalid status") type DataType struct { Value Value diff --git a/text.go b/text.go index 482c9023..62158b09 100644 --- a/text.go +++ b/text.go @@ -2,6 +2,7 @@ package pgtype import ( "database/sql/driver" + "encoding/json" "fmt" "io" ) @@ -134,3 +135,16 @@ func (src Text) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Text) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return json.Marshal(src.String) + case Null: + return []byte("null"), nil + case Undefined: + return []byte("undefined"), nil + } + + return nil, errBadStatus +} diff --git a/varchar.go b/varchar.go index f25ada5d..6c137b9a 100644 --- a/varchar.go +++ b/varchar.go @@ -49,3 +49,7 @@ func (dst *Varchar) Scan(src interface{}) error { func (src Varchar) Value() (driver.Value, error) { return (Text)(src).Value() } + +func (src Varchar) MarshalJSON() ([]byte, error) { + return (Text)(src).MarshalJSON() +} From b49035fdc15f15e90baf6f5698d6d28bb93b15cc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 12:18:49 -0500 Subject: [PATCH 0070/1158] Add shopspring.Numeric This adds PostgreSQL numeric mapping to and from github.com/shopspring/decimal. Makes pgtype.NullAssignTo public as external types need this functionality. Begin extraction of pgtype testing functionality so it can easily be used by external types. --- aclitem.go | 2 +- aclitem_array.go | 2 +- bool.go | 2 +- bool_array.go | 2 +- bytea.go | 2 +- bytea_array.go | 2 +- cidr_array.go | 2 +- convert.go | 2 +- date.go | 2 +- date_array.go | 2 +- ext/shopspring-numeric/decimal.go | 320 +++++++++++++++++++++++++ ext/shopspring-numeric/decimal_test.go | 281 ++++++++++++++++++++++ float4_array.go | 2 +- float8_array.go | 2 +- hstore.go | 2 +- hstore_array.go | 2 +- inet.go | 2 +- inet_array.go | 2 +- int2_array.go | 2 +- int4_array.go | 2 +- int8_array.go | 2 +- interval.go | 2 +- macaddr.go | 2 +- numeric.go | 2 +- numeric_array.go | 2 +- record.go | 2 +- testutil/testutil.go | 298 +++++++++++++++++++++++ text.go | 2 +- text_array.go | 2 +- timestamp.go | 2 +- timestamp_array.go | 2 +- timestamptz.go | 2 +- timestamptz_array.go | 2 +- typed_array.go.erb | 2 +- uuid.go | 2 +- varchar_array.go | 2 +- 36 files changed, 932 insertions(+), 33 deletions(-) create mode 100644 ext/shopspring-numeric/decimal.go create mode 100644 ext/shopspring-numeric/decimal_test.go create mode 100644 testutil/testutil.go diff --git a/aclitem.go b/aclitem.go index 77e385e6..3ccf8318 100644 --- a/aclitem.go +++ b/aclitem.go @@ -67,7 +67,7 @@ func (src *Aclitem) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/aclitem_array.go b/aclitem_array.go index 20a7636a..7ef76573 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -78,7 +78,7 @@ func (src *AclitemArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/bool.go b/bool.go index 736d19cf..1ebf590b 100644 --- a/bool.go +++ b/bool.go @@ -56,7 +56,7 @@ func (src *Bool) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/bool_array.go b/bool_array.go index 4705d734..468f6816 100644 --- a/bool_array.go +++ b/bool_array.go @@ -79,7 +79,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/bytea.go b/bytea.go index 9f0266e7..8bf5de2b 100644 --- a/bytea.go +++ b/bytea.go @@ -61,7 +61,7 @@ func (src *Bytea) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/bytea_array.go b/bytea_array.go index 268364c1..4aa2b862 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -79,7 +79,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/cidr_array.go b/cidr_array.go index 6643bb47..96d912ae 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -108,7 +108,7 @@ func (src *CidrArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/convert.go b/convert.go index 4fba8430..2b406426 100644 --- a/convert.go +++ b/convert.go @@ -342,7 +342,7 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } -func nullAssignTo(dst interface{}) error { +func NullAssignTo(dst interface{}) error { dstPtr := reflect.ValueOf(dst) // AssignTo dst must always be a pointer diff --git a/date.go b/date.go index 7dd2c4f0..34753f05 100644 --- a/date.go +++ b/date.go @@ -70,7 +70,7 @@ func (src *Date) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/date_array.go b/date_array.go index f58de011..f24bf6b9 100644 --- a/date_array.go +++ b/date_array.go @@ -80,7 +80,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go new file mode 100644 index 00000000..9c7e316b --- /dev/null +++ b/ext/shopspring-numeric/decimal.go @@ -0,0 +1,320 @@ +package numeric + +import ( + "bytes" + "database/sql/driver" + "errors" + "fmt" + "io" + "strconv" + + "github.com/jackc/pgx/pgtype" + "github.com/shopspring/decimal" +) + +var errUndefined = errors.New("cannot encode status undefined") + +type Numeric struct { + Decimal decimal.Decimal + Status pgtype.Status +} + +func (dst *Numeric) Set(src interface{}) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + switch value := src.(type) { + case decimal.Decimal: + *dst = Numeric{Decimal: value, Status: pgtype.Present} + case float32: + *dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present} + case float64: + *dst = Numeric{Decimal: decimal.NewFromFloat(value), Status: pgtype.Present} + case int8: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint8: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case int16: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint16: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case int32: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint32: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case int64: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint64: + // uint64 could be greater than int64 so convert to string then to decimal + dec, err := decimal.NewFromString(strconv.FormatUint(value, 10)) + if err != nil { + return err + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + case int: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint: + // uint could be greater than int64 so convert to string then to decimal + dec, err := decimal.NewFromString(strconv.FormatUint(uint64(value), 10)) + if err != nil { + return err + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + case string: + dec, err := decimal.NewFromString(value) + if err != nil { + return err + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + default: + // If all else fails see if pgtype.Numeric can handle it. If so, translate through that. + num := &pgtype.Numeric{} + if err := num.Set(value); err != nil { + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + buf := &bytes.Buffer{} + if _, err := num.EncodeText(nil, buf); err != nil { + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + dec, err := decimal.NewFromString(buf.String()) + if err != nil { + return fmt.Errorf("cannot convert %v to Numeric", value) + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + } + + return nil +} + +func (dst *Numeric) Get() interface{} { + switch dst.Status { + case pgtype.Present: + return dst.Decimal + case pgtype.Null: + return nil + default: + return dst.Status + } +} + +func (src *Numeric) AssignTo(dst interface{}) error { + switch src.Status { + case pgtype.Present: + switch v := dst.(type) { + case *decimal.Decimal: + *v = src.Decimal + case *float32: + f, _ := src.Decimal.Float64() + *v = float32(f) + case *float64: + f, _ := src.Decimal.Float64() + *v = f + case *int: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int(n) + case *int8: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 8) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int8(n) + case *int16: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 16) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int16(n) + case *int32: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 32) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int32(n) + case *int64: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 64) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int64(n) + case *uint: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint(n) + case *uint8: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 8) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint8(n) + case *uint16: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 16) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint16(n) + case *uint32: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 32) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint32(n) + case *uint64: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 64) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint64(n) + default: + if nextDst, retry := pgtype.GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case pgtype.Null: + return pgtype.NullAssignTo(dst) + } + + return nil +} + +func (dst *Numeric) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + dec, err := decimal.NewFromString(string(src)) + if err != nil { + return err + } + + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + return nil +} + +func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + // For now at least, implement this in terms of pgtype.Numeric + + num := &pgtype.Numeric{} + if err := num.DecodeBinary(ci, src); err != nil { + return err + } + + buf := &bytes.Buffer{} + if _, err := num.EncodeText(ci, buf); err != nil { + return err + } + + dec, err := decimal.NewFromString(buf.String()) + if err != nil { + return err + } + + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + + return nil +} + +func (src *Numeric) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, src.Decimal.String()) + return false, err +} + +func (src *Numeric) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, errUndefined + } + + // For now at least, implement this in terms of pgtype.Numeric + num := &pgtype.Numeric{} + if err := num.DecodeText(ci, []byte(src.Decimal.String())); err != nil { + return false, err + } + + return num.EncodeBinary(ci, w) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numeric) Scan(src interface{}) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Numeric{Decimal: decimal.NewFromFloat(src), Status: pgtype.Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Numeric) Value() (driver.Value, error) { + switch src.Status { + case pgtype.Present: + return src.Decimal.Value() + case pgtype.Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go new file mode 100644 index 00000000..50c0fb8b --- /dev/null +++ b/ext/shopspring-numeric/decimal_test.go @@ -0,0 +1,281 @@ +package numeric_test + +import ( + "fmt" + "math/big" + "math/rand" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + shopspring "github.com/jackc/pgx/pgtype/ext/shopspring-numeric" + "github.com/jackc/pgx/pgtype/testutil" + "github.com/shopspring/decimal" +) + +func mustParseDecimal(t *testing.T, src string) decimal.Decimal { + dec, err := decimal.NewFromString(src) + if err != nil { + t.Fatal(err) + } + return dec +} + +func TestNumericNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select '0'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, + }, + { + SQL: "select '1'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, + }, + { + SQL: "select '10.00'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present}, + }, + { + SQL: "select '1e-3'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, + }, + { + SQL: "select '-1'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, + }, + { + SQL: "select '10000'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present}, + }, + { + SQL: "select '3.14'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, + }, + { + SQL: "select '1.1'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present}, + }, + { + SQL: "select '100010001'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present}, + }, + { + SQL: "select '100010001.0001'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present}, + }, + { + SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", + Value: shopspring.Numeric{ + Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", + Value: shopspring.Numeric{ + Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", + Value: shopspring.Numeric{ + Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), + Status: pgtype.Present, + }, + }, + }) +} + +func TestNumericTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Status: pgtype.Present}, + + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Status: pgtype.Present}, + + &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Status: pgtype.Present}, + &shopspring.Numeric{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + }) + +} + +func TestNumericTranscodeFuzz(t *testing.T) { + r := rand.New(rand.NewSource(0)) + max := &big.Int{} + max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) + + values := make([]interface{}, 0, 2000) + for i := 0; i < 500; i++ { + num := fmt.Sprintf("%s.%s", (&big.Int{}).Rand(r, max).String(), (&big.Int{}).Rand(r, max).String()) + negNum := "-" + num + values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, num), Status: pgtype.Present}) + values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Status: pgtype.Present}) + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, + func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + }) +} + +func TestNumericSet(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result *shopspring.Numeric + }{ + {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Status: pgtype.Present}}, + {source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Status: pgtype.Present}}, + {source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Status: pgtype.Present}}, + {source: float64(12345.678901), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345.678901"), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + r := &shopspring.Numeric{} + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !(r.Status == tt.result.Status && r.Decimal.Equal(tt.result.Decimal)) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericAssignTo(t *testing.T) { + type _int8 int8 + + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src *shopspring.Numeric + dst interface{} + expected interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src *shopspring.Numeric + dst interface{} + expected interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src *shopspring.Numeric + dst interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Status: pgtype.Present}, dst: &i8}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Status: pgtype.Present}, dst: &i16}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui8}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui16}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui32}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui64}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/float4_array.go b/float4_array.go index b9ee4b9e..db1523f0 100644 --- a/float4_array.go +++ b/float4_array.go @@ -79,7 +79,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/float8_array.go b/float8_array.go index d49f18a7..19878bbb 100644 --- a/float8_array.go +++ b/float8_array.go @@ -79,7 +79,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/hstore.go b/hstore.go index b8b0c6f3..5dc78671 100644 --- a/hstore.go +++ b/hstore.go @@ -71,7 +71,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/hstore_array.go b/hstore_array.go index 097fec7b..e4263f20 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -79,7 +79,7 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/inet.go b/inet.go index 3e00e2fa..09fce04d 100644 --- a/inet.go +++ b/inet.go @@ -90,7 +90,7 @@ func (src *Inet) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/inet_array.go b/inet_array.go index a108d75b..4687b145 100644 --- a/inet_array.go +++ b/inet_array.go @@ -108,7 +108,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/int2_array.go b/int2_array.go index bddb5ac2..3506370e 100644 --- a/int2_array.go +++ b/int2_array.go @@ -107,7 +107,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/int4_array.go b/int4_array.go index d5c8f911..e4ec6455 100644 --- a/int4_array.go +++ b/int4_array.go @@ -107,7 +107,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/int8_array.go b/int8_array.go index ae2521fa..6c0dab65 100644 --- a/int8_array.go +++ b/int8_array.go @@ -107,7 +107,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/interval.go b/interval.go index 7eddb10f..20a4a419 100644 --- a/interval.go +++ b/interval.go @@ -71,7 +71,7 @@ func (src *Interval) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/macaddr.go b/macaddr.go index 2d09ff8c..2834d69f 100644 --- a/macaddr.go +++ b/macaddr.go @@ -67,7 +67,7 @@ func (src *Macaddr) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/numeric.go b/numeric.go index a26e8c89..63f99c06 100644 --- a/numeric.go +++ b/numeric.go @@ -253,7 +253,7 @@ func (src *Numeric) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return nil diff --git a/numeric_array.go b/numeric_array.go index b147e6a2..3d59a6b0 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -107,7 +107,7 @@ func (src *NumericArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/record.go b/record.go index 9c42c907..3b315d40 100644 --- a/record.go +++ b/record.go @@ -62,7 +62,7 @@ func (src *Record) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/testutil/testutil.go b/testutil/testutil.go new file mode 100644 index 00000000..610f0710 --- /dev/null +++ b/testutil/testutil.go @@ -0,0 +1,298 @@ +package testutil + +import ( + "context" + "database/sql" + "fmt" + "io" + "os" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" + _ "github.com/jackc/pgx/stdlib" + _ "github.com/lib/pq" +) + +func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { + var sqlDriverName string + switch driverName { + case "github.com/lib/pq": + sqlDriverName = "postgres" + case "github.com/jackc/pgx/stdlib": + sqlDriverName = "pgx" + default: + t.Fatalf("Unknown driver %v", driverName) + } + + db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + return db +} + +func mustConnectPgx(t testing.TB) *pgx.Conn { + config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + conn, err := pgx.Connect(config) + if err != nil { + t.Fatal(err) + } + + return conn +} + +func mustClose(t testing.TB, conn interface { + Close() error +}) { + err := conn.Close() + if err != nil { + t.Fatal(err) + } +} + +type forceTextEncoder struct { + e pgtype.TextEncoder +} + +func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeText(ci, w) +} + +type forceBinaryEncoder struct { + e pgtype.BinaryEncoder +} + +func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeBinary(ci, w) +} + +func forceEncoder(e interface{}, formatCode int16) interface{} { + switch formatCode { + case pgx.TextFormatCode: + if e, ok := e.(pgtype.TextEncoder); ok { + return forceTextEncoder{e: e} + } + case pgx.BinaryFormatCode: + if e, ok := e.(pgtype.BinaryEncoder); ok { + return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} + } + } + return nil +} + +func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { + TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } +} + +func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, v := range values { + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + vEncoder := forceEncoder(v, fc.formatCode) + if vEncoder == nil { + t.Logf("Skipping: %#v does not implement %v", v, fc.name) + continue + } + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} + +func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRowEx( + context.Background(), + fmt.Sprintf("select ($1)::%s", pgTypeName), + &pgx.QueryExOptions{SimpleProtocol: true}, + v, + ).Scan(result.Interface()) + if err != nil { + t.Errorf("Simple protocol %d: %v", i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface()) + } + } +} + +func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectDatabaseSQL(t, driverName) + defer mustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := ps.QueryRow(v).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} + +type NormalizeTest struct { + SQL string + Value interface{} +} + +func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) { + TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc) + } +} + +func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, tt := range tests { + for _, fc := range formats { + psName := fmt.Sprintf("test%d", i) + ps, err := conn.Prepare(psName, tt.SQL) + if err != nil { + t.Fatal(err) + } + + ps.FieldDescriptions[0].FormatCode = fc.formatCode + if forceEncoder(tt.Value, fc.formatCode) == nil { + t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name) + continue + } + // Derefence value if it is a pointer + derefV := tt.Value + refVal := reflect.ValueOf(tt.Value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = conn.QueryRow(psName).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} + +func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + conn := mustConnectDatabaseSQL(t, driverName) + defer mustClose(t, conn) + + for i, tt := range tests { + ps, err := conn.Prepare(tt.SQL) + if err != nil { + t.Errorf("%d. %v", i, err) + continue + } + + // Derefence value if it is a pointer + derefV := tt.Value + refVal := reflect.ValueOf(tt.Value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = ps.QueryRow().Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} diff --git a/text.go b/text.go index 62158b09..de80dd08 100644 --- a/text.go +++ b/text.go @@ -71,7 +71,7 @@ func (src *Text) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/text_array.go b/text_array.go index 64728048..a6bd4724 100644 --- a/text_array.go +++ b/text_array.go @@ -79,7 +79,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/timestamp.go b/timestamp.go index 78c6355e..e7bc1c7d 100644 --- a/timestamp.go +++ b/timestamp.go @@ -74,7 +74,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/timestamp_array.go b/timestamp_array.go index 5d08f9cc..2046c387 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -80,7 +80,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/timestamptz.go b/timestamptz.go index 50370335..ef2d7498 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -75,7 +75,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/timestamptz_array.go b/timestamptz_array.go index 107be06a..fd58d3be 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -80,7 +80,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/typed_array.go.erb b/typed_array.go.erb index 4b8f1a28..2a38ed82 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -77,7 +77,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/uuid.go b/uuid.go index 111bed35..88d2195b 100644 --- a/uuid.go +++ b/uuid.go @@ -69,7 +69,7 @@ func (src *Uuid) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot assign %v into %T", src, dst) diff --git a/varchar_array.go b/varchar_array.go index 2712b4d2..9ca16d7e 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -79,7 +79,7 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) From e380de7cd1b046970c603b53d8b61f952a336a91 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 12:38:33 -0500 Subject: [PATCH 0071/1158] Finish extraction of pgtype test helpers --- aclitem_array_test.go | 3 +- aclitem_test.go | 3 +- bool_array_test.go | 3 +- bool_test.go | 3 +- box_test.go | 9 +- bytea_array_test.go | 3 +- bytea_test.go | 3 +- cid_test.go | 7 +- cidr_array_test.go | 3 +- circle_test.go | 3 +- date_array_test.go | 3 +- date_test.go | 3 +- daterange_test.go | 9 +- float4_array_test.go | 3 +- float4_test.go | 3 +- float8_array_test.go | 3 +- float8_test.go | 3 +- hstore_array_test.go | 7 +- hstore_test.go | 3 +- inet_array_test.go | 3 +- inet_test.go | 3 +- int2_array_test.go | 3 +- int2_test.go | 3 +- int4_array_test.go | 3 +- int4_test.go | 3 +- int4range_test.go | 9 +- int8_array_test.go | 3 +- int8_test.go | 3 +- int8range_test.go | 9 +- interval_test.go | 33 ++--- json_test.go | 3 +- jsonb_test.go | 7 +- line_test.go | 7 +- lseg_test.go | 3 +- macaddr_test.go | 3 +- name_test.go | 3 +- numeric_array_test.go | 3 +- numeric_test.go | 59 ++++---- numrange_test.go | 3 +- oid_value_test.go | 3 +- path_test.go | 3 +- pgtype_test.go | 291 -------------------------------------- point_test.go | 3 +- polygon_test.go | 3 +- qchar_test.go | 3 +- record_test.go | 5 +- testutil/testutil.go | 34 ++--- text_array_test.go | 3 +- text_test.go | 3 +- tid_test.go | 3 +- timestamp_array_test.go | 3 +- timestamp_test.go | 3 +- timestamptz_array_test.go | 3 +- timestamptz_test.go | 3 +- tsrange_test.go | 3 +- tstzrange_test.go | 3 +- uuid_test.go | 3 +- varbit_test.go | 9 +- varchar_array_test.go | 3 +- xid_test.go | 7 +- 60 files changed, 202 insertions(+), 435 deletions(-) diff --git a/aclitem_array_test.go b/aclitem_array_test.go index 75c672bd..951e7847 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestAclitemArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "aclitem[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "aclitem[]", []interface{}{ &pgtype.AclitemArray{ Elements: nil, Dimensions: nil, diff --git a/aclitem_test.go b/aclitem_test.go index 1738025a..5389eab2 100644 --- a/aclitem_test.go +++ b/aclitem_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestAclitemTranscode(t *testing.T) { - testSuccessfulTranscode(t, "aclitem", []interface{}{ + testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, pgtype.Aclitem{Status: pgtype.Null}, diff --git a/bool_array_test.go b/bool_array_test.go index a526d892..87886da6 100644 --- a/bool_array_test.go +++ b/bool_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestBoolArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "bool[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "bool[]", []interface{}{ &pgtype.BoolArray{ Elements: nil, Dimensions: nil, diff --git a/bool_test.go b/bool_test.go index 412e2fd0..31f3d528 100644 --- a/bool_test.go +++ b/bool_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestBoolTranscode(t *testing.T) { - testSuccessfulTranscode(t, "bool", []interface{}{ + testutil.TestSuccessfulTranscode(t, "bool", []interface{}{ pgtype.Bool{Bool: false, Status: pgtype.Present}, pgtype.Bool{Bool: true, Status: pgtype.Present}, pgtype.Bool{Bool: false, Status: pgtype.Null}, diff --git a/box_test.go b/box_test.go index 00732973..f26cda68 100644 --- a/box_test.go +++ b/box_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestBoxTranscode(t *testing.T) { - testSuccessfulTranscode(t, "box", []interface{}{ + testutil.TestSuccessfulTranscode(t, "box", []interface{}{ &pgtype.Box{ P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, Status: pgtype.Present, @@ -21,10 +22,10 @@ func TestBoxTranscode(t *testing.T) { } func TestBoxNormalize(t *testing.T) { - testSuccessfulNormalize(t, []normalizeTest{ + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { - sql: "select '3.14, 1.678, 7.1, 5.234'::box", - value: &pgtype.Box{ + SQL: "select '3.14, 1.678, 7.1, 5.234'::box", + Value: &pgtype.Box{ P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, Status: pgtype.Present, }, diff --git a/bytea_array_test.go b/bytea_array_test.go index 22c6478b..451c2461 100644 --- a/bytea_array_test.go +++ b/bytea_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestByteaArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "bytea[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "bytea[]", []interface{}{ &pgtype.ByteaArray{ Elements: nil, Dimensions: nil, diff --git a/bytea_test.go b/bytea_test.go index e21296c6..7d32e294 100644 --- a/bytea_test.go +++ b/bytea_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestByteaTranscode(t *testing.T) { - testSuccessfulTranscode(t, "bytea", []interface{}{ + testutil.TestSuccessfulTranscode(t, "bytea", []interface{}{ pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, diff --git a/cid_test.go b/cid_test.go index 210573f6..385b8cac 100644 --- a/cid_test.go +++ b/cid_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestCidTranscode(t *testing.T) { @@ -17,13 +18,13 @@ func TestCidTranscode(t *testing.T) { return reflect.DeepEqual(a, b) } - testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) // No direct conversion from int to cid, convert through text - testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) + testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) } } diff --git a/cidr_array_test.go b/cidr_array_test.go index ec105914..1ebe5195 100644 --- a/cidr_array_test.go +++ b/cidr_array_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestCidrArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "cidr[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "cidr[]", []interface{}{ &pgtype.CidrArray{ Elements: nil, Dimensions: nil, diff --git a/circle_test.go b/circle_test.go index 9746dd74..2747d4f5 100644 --- a/circle_test.go +++ b/circle_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestCircleTranscode(t *testing.T) { - testSuccessfulTranscode(t, "circle", []interface{}{ + testutil.TestSuccessfulTranscode(t, "circle", []interface{}{ &pgtype.Circle{P: pgtype.Vec2{1.234, 5.6789}, R: 3.5, Status: pgtype.Present}, &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Status: pgtype.Present}, &pgtype.Circle{Status: pgtype.Null}, diff --git a/date_array_test.go b/date_array_test.go index a05f4254..74ebfbbe 100644 --- a/date_array_test.go +++ b/date_array_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestDateArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "date[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "date[]", []interface{}{ &pgtype.DateArray{ Elements: nil, Dimensions: nil, diff --git a/date_test.go b/date_test.go index 1832b5b4..d1493f5e 100644 --- a/date_test.go +++ b/date_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestDateTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "date", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "date", []interface{}{ pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, diff --git a/daterange_test.go b/daterange_test.go index 8501cc7e..7dfae0f4 100644 --- a/daterange_test.go +++ b/daterange_test.go @@ -5,10 +5,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestDaterangeTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, pgtype.Daterange{ Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, @@ -40,10 +41,10 @@ func TestDaterangeTranscode(t *testing.T) { } func TestDaterangeNormalize(t *testing.T) { - testSuccessfulNormalizeEqFunc(t, []normalizeTest{ + testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ { - sql: "select daterange('2010-01-01', '2010-01-11', '(]')", - value: pgtype.Daterange{ + SQL: "select daterange('2010-01-01', '2010-01-11', '(]')", + Value: pgtype.Daterange{ Lower: pgtype.Date{Time: time.Date(2010, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Date{Time: time.Date(2010, 1, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, diff --git a/float4_array_test.go b/float4_array_test.go index 06a1d2e0..6d6a4f30 100644 --- a/float4_array_test.go +++ b/float4_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestFloat4ArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "float4[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "float4[]", []interface{}{ &pgtype.Float4Array{ Elements: nil, Dimensions: nil, diff --git a/float4_test.go b/float4_test.go index ea60cd3a..57f4bc34 100644 --- a/float4_test.go +++ b/float4_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestFloat4Transcode(t *testing.T) { - testSuccessfulTranscode(t, "float4", []interface{}{ + testutil.TestSuccessfulTranscode(t, "float4", []interface{}{ pgtype.Float4{Float: -1, Status: pgtype.Present}, pgtype.Float4{Float: 0, Status: pgtype.Present}, pgtype.Float4{Float: 0.00001, Status: pgtype.Present}, diff --git a/float8_array_test.go b/float8_array_test.go index 635e249a..56801e80 100644 --- a/float8_array_test.go +++ b/float8_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestFloat8ArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "float8[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "float8[]", []interface{}{ &pgtype.Float8Array{ Elements: nil, Dimensions: nil, diff --git a/float8_test.go b/float8_test.go index 724e9350..b7527b86 100644 --- a/float8_test.go +++ b/float8_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestFloat8Transcode(t *testing.T) { - testSuccessfulTranscode(t, "float8", []interface{}{ + testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ pgtype.Float8{Float: -1, Status: pgtype.Present}, pgtype.Float8{Float: 0, Status: pgtype.Present}, pgtype.Float8{Float: 0.00001, Status: pgtype.Present}, diff --git a/hstore_array_test.go b/hstore_array_test.go index e23c7b3b..d26497b1 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -6,11 +6,12 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestHstoreArrayTranscode(t *testing.T) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) + conn := testutil.MustConnectPgx(t) + defer testutil.MustClose(t, conn) text := func(s string) pgtype.Text { return pgtype.Text{String: s, Status: pgtype.Present} @@ -69,7 +70,7 @@ func TestHstoreArrayTranscode(t *testing.T) { for _, fc := range formats { ps.FieldDescriptions[0].FormatCode = fc.formatCode - vEncoder := forceEncoder(src, fc.formatCode) + vEncoder := testutil.ForceEncoder(src, fc.formatCode) if vEncoder == nil { t.Logf("%#v does not implement %v", src, fc.name) continue diff --git a/hstore_test.go b/hstore_test.go index fbe8dee5..502a8df0 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestHstoreTranscode(t *testing.T) { @@ -44,7 +45,7 @@ func TestHstoreTranscode(t *testing.T) { values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key } - testSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { + testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { a := ai.(pgtype.Hstore) b := bi.(pgtype.Hstore) diff --git a/inet_array_test.go b/inet_array_test.go index fe22285d..c0465922 100644 --- a/inet_array_test.go +++ b/inet_array_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInetArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "inet[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "inet[]", []interface{}{ &pgtype.InetArray{ Elements: nil, Dimensions: nil, diff --git a/inet_test.go b/inet_test.go index 16035fca..532e9abe 100644 --- a/inet_test.go +++ b/inet_test.go @@ -6,11 +6,12 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInetTranscode(t *testing.T) { for _, pgTypeName := range []string{"inet", "cidr"} { - testSuccessfulTranscode(t, pgTypeName, []interface{}{ + testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ pgtype.Inet{IPNet: mustParseCidr(t, "0.0.0.0/32"), Status: pgtype.Present}, pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, diff --git a/int2_array_test.go b/int2_array_test.go index 8af4523d..0adc1aef 100644 --- a/int2_array_test.go +++ b/int2_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt2ArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "int2[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int2[]", []interface{}{ &pgtype.Int2Array{ Elements: nil, Dimensions: nil, diff --git a/int2_test.go b/int2_test.go index 2bd8e016..d81405a6 100644 --- a/int2_test.go +++ b/int2_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt2Transcode(t *testing.T) { - testSuccessfulTranscode(t, "int2", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, pgtype.Int2{Int: -1, Status: pgtype.Present}, pgtype.Int2{Int: 0, Status: pgtype.Present}, diff --git a/int4_array_test.go b/int4_array_test.go index 111cb56b..6fad18bb 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt4ArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "int4[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int4[]", []interface{}{ &pgtype.Int4Array{ Elements: nil, Dimensions: nil, diff --git a/int4_test.go b/int4_test.go index 3e000182..1354b47a 100644 --- a/int4_test.go +++ b/int4_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt4Transcode(t *testing.T) { - testSuccessfulTranscode(t, "int4", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, pgtype.Int4{Int: -1, Status: pgtype.Present}, pgtype.Int4{Int: 0, Status: pgtype.Present}, diff --git a/int4range_test.go b/int4range_test.go index c96fe9cd..74a91e59 100644 --- a/int4range_test.go +++ b/int4range_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt4rangeTranscode(t *testing.T) { - testSuccessfulTranscode(t, "int4range", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int4range", []interface{}{ pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, @@ -16,10 +17,10 @@ func TestInt4rangeTranscode(t *testing.T) { } func TestInt4rangeNormalize(t *testing.T) { - testSuccessfulNormalize(t, []normalizeTest{ + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { - sql: "select int4range(1, 10, '(]')", - value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + SQL: "select int4range(1, 10, '(]')", + Value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, }, }) } diff --git a/int8_array_test.go b/int8_array_test.go index 349a1f7e..4f5c4f9a 100644 --- a/int8_array_test.go +++ b/int8_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt8ArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "int8[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int8[]", []interface{}{ &pgtype.Int8Array{ Elements: nil, Dimensions: nil, diff --git a/int8_test.go b/int8_test.go index e1fe69fb..d6752205 100644 --- a/int8_test.go +++ b/int8_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt8Transcode(t *testing.T) { - testSuccessfulTranscode(t, "int8", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, pgtype.Int8{Int: -1, Status: pgtype.Present}, pgtype.Int8{Int: 0, Status: pgtype.Present}, diff --git a/int8range_test.go b/int8range_test.go index 1b3e594c..703f476e 100644 --- a/int8range_test.go +++ b/int8range_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt8rangeTranscode(t *testing.T) { - testSuccessfulTranscode(t, "Int8range", []interface{}{ + testutil.TestSuccessfulTranscode(t, "Int8range", []interface{}{ pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, @@ -16,10 +17,10 @@ func TestInt8rangeTranscode(t *testing.T) { } func TestInt8rangeNormalize(t *testing.T) { - testSuccessfulNormalize(t, []normalizeTest{ + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { - sql: "select Int8range(1, 10, '(]')", - value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + SQL: "select Int8range(1, 10, '(]')", + Value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, }, }) } diff --git a/interval_test.go b/interval_test.go index db9614ef..28e77e0a 100644 --- a/interval_test.go +++ b/interval_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestIntervalTranscode(t *testing.T) { - testSuccessfulTranscode(t, "interval", []interface{}{ + testutil.TestSuccessfulTranscode(t, "interval", []interface{}{ pgtype.Interval{Microseconds: 1, Status: pgtype.Present}, pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, @@ -29,34 +30,34 @@ func TestIntervalTranscode(t *testing.T) { } func TestIntervalNormalize(t *testing.T) { - testSuccessfulNormalize(t, []normalizeTest{ + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { - sql: "select '1 second'::interval", - value: pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + SQL: "select '1 second'::interval", + Value: pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, }, { - sql: "select '1.000001 second'::interval", - value: pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + SQL: "select '1.000001 second'::interval", + Value: pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, }, { - sql: "select '34223 hours'::interval", - value: pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + SQL: "select '34223 hours'::interval", + Value: pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, }, { - sql: "select '1 day'::interval", - value: pgtype.Interval{Days: 1, Status: pgtype.Present}, + SQL: "select '1 day'::interval", + Value: pgtype.Interval{Days: 1, Status: pgtype.Present}, }, { - sql: "select '1 month'::interval", - value: pgtype.Interval{Months: 1, Status: pgtype.Present}, + SQL: "select '1 month'::interval", + Value: pgtype.Interval{Months: 1, Status: pgtype.Present}, }, { - sql: "select '1 year'::interval", - value: pgtype.Interval{Months: 12, Status: pgtype.Present}, + SQL: "select '1 year'::interval", + Value: pgtype.Interval{Months: 12, Status: pgtype.Present}, }, { - sql: "select '-13 mon'::interval", - value: pgtype.Interval{Months: -13, Status: pgtype.Present}, + SQL: "select '-13 mon'::interval", + Value: pgtype.Interval{Months: -13, Status: pgtype.Present}, }, }) } diff --git a/json_test.go b/json_test.go index b0aa8c9b..6d7cccfd 100644 --- a/json_test.go +++ b/json_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestJsonTranscode(t *testing.T) { - testSuccessfulTranscode(t, "json", []interface{}{ + testutil.TestSuccessfulTranscode(t, "json", []interface{}{ pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, pgtype.Json{Bytes: []byte("null"), Status: pgtype.Present}, pgtype.Json{Bytes: []byte("42"), Status: pgtype.Present}, diff --git a/jsonb_test.go b/jsonb_test.go index 91637eb8..37c11858 100644 --- a/jsonb_test.go +++ b/jsonb_test.go @@ -6,16 +6,17 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestJsonbTranscode(t *testing.T) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) + conn := testutil.MustConnectPgx(t) + defer testutil.MustClose(t, conn) if _, ok := conn.ConnInfo.DataTypeForName("jsonb"); !ok { t.Skip("Skipping due to no jsonb type") } - testSuccessfulTranscode(t, "jsonb", []interface{}{ + testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, pgtype.Jsonb{Bytes: []byte("null"), Status: pgtype.Present}, pgtype.Jsonb{Bytes: []byte("42"), Status: pgtype.Present}, diff --git a/line_test.go b/line_test.go index 995eaad5..09e48019 100644 --- a/line_test.go +++ b/line_test.go @@ -5,15 +5,16 @@ import ( version "github.com/hashicorp/go-version" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestLineTranscode(t *testing.T) { - conn := mustConnectPgx(t) + conn := testutil.MustConnectPgx(t) serverVersion, err := version.NewVersion(conn.RuntimeParams["server_version"]) if err != nil { t.Fatalf("cannot get server version: %v", err) } - mustClose(t, conn) + testutil.MustClose(t, conn) minVersion := version.Must(version.NewVersion("9.4")) @@ -21,7 +22,7 @@ func TestLineTranscode(t *testing.T) { t.Skipf("Skipping line test for server version %v", serverVersion) } - testSuccessfulTranscode(t, "line", []interface{}{ + testutil.TestSuccessfulTranscode(t, "line", []interface{}{ &pgtype.Line{ A: 1.23, B: 4.56, C: 7.89, Status: pgtype.Present, diff --git a/lseg_test.go b/lseg_test.go index 5f041263..bd394e3c 100644 --- a/lseg_test.go +++ b/lseg_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestLsegTranscode(t *testing.T) { - testSuccessfulTranscode(t, "lseg", []interface{}{ + testutil.TestSuccessfulTranscode(t, "lseg", []interface{}{ &pgtype.Lseg{ P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}}, Status: pgtype.Present, diff --git a/macaddr_test.go b/macaddr_test.go index 6c7b8b89..c2542da3 100644 --- a/macaddr_test.go +++ b/macaddr_test.go @@ -7,10 +7,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestMacaddrTranscode(t *testing.T) { - testSuccessfulTranscode(t, "macaddr", []interface{}{ + testutil.TestSuccessfulTranscode(t, "macaddr", []interface{}{ pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, pgtype.Macaddr{Status: pgtype.Null}, }) diff --git a/name_test.go b/name_test.go index 81a766b8..348f8d39 100644 --- a/name_test.go +++ b/name_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestNameTranscode(t *testing.T) { - testSuccessfulTranscode(t, "name", []interface{}{ + testutil.TestSuccessfulTranscode(t, "name", []interface{}{ pgtype.Name{String: "", Status: pgtype.Present}, pgtype.Name{String: "foo", Status: pgtype.Present}, pgtype.Name{Status: pgtype.Null}, diff --git a/numeric_array_test.go b/numeric_array_test.go index af2e8e51..25531840 100644 --- a/numeric_array_test.go +++ b/numeric_array_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestNumericArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "numeric[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "numeric[]", []interface{}{ &pgtype.NumericArray{ Elements: nil, Dimensions: nil, diff --git a/numeric_test.go b/numeric_test.go index 93aa8866..d68a9347 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) // For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) @@ -45,66 +46,66 @@ func mustParseBigInt(t *testing.T, src string) *big.Int { } func TestNumericNormalize(t *testing.T) { - testSuccessfulNormalize(t, []normalizeTest{ + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { - sql: "select '0'::numeric", - value: pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + SQL: "select '0'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, }, { - sql: "select '1'::numeric", - value: pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + SQL: "select '1'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, }, { - sql: "select '10.00'::numeric", - value: pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, + SQL: "select '10.00'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, }, { - sql: "select '1e-3'::numeric", - value: pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, + SQL: "select '1e-3'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, }, { - sql: "select '-1'::numeric", - value: pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + SQL: "select '-1'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, }, { - sql: "select '10000'::numeric", - value: pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, + SQL: "select '10000'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, }, { - sql: "select '3.14'::numeric", - value: pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + SQL: "select '3.14'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, }, { - sql: "select '1.1'::numeric", - value: pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, + SQL: "select '1.1'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, }, { - sql: "select '100010001'::numeric", - value: pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, + SQL: "select '100010001'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, }, { - sql: "select '100010001.0001'::numeric", - value: pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, + SQL: "select '100010001.0001'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, }, { - sql: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", - value: pgtype.Numeric{ + SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", + Value: pgtype.Numeric{ Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), Exp: -41, Status: pgtype.Present, }, }, { - sql: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", - value: pgtype.Numeric{ + SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", + Value: pgtype.Numeric{ Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), Exp: -196, Status: pgtype.Present, }, }, { - sql: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", - value: pgtype.Numeric{ + SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", + Value: pgtype.Numeric{ Int: mustParseBigInt(t, "123"), Exp: -186, Status: pgtype.Present, @@ -114,7 +115,7 @@ func TestNumericNormalize(t *testing.T) { } func TestNumericTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, @@ -172,7 +173,7 @@ func TestNumericTranscodeFuzz(t *testing.T) { } } - testSuccessfulTranscodeEqFunc(t, "numeric", values, + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, func(aa, bb interface{}) bool { a := aa.(pgtype.Numeric) b := bb.(pgtype.Numeric) diff --git a/numrange_test.go b/numrange_test.go index 81202362..81e73c38 100644 --- a/numrange_test.go +++ b/numrange_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestNumrangeTranscode(t *testing.T) { - testSuccessfulTranscode(t, "numrange", []interface{}{ + testutil.TestSuccessfulTranscode(t, "numrange", []interface{}{ pgtype.Numrange{ LowerType: pgtype.Empty, UpperType: pgtype.Empty, diff --git a/oid_value_test.go b/oid_value_test.go index 21dd6f9d..d3412159 100644 --- a/oid_value_test.go +++ b/oid_value_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestOidValueTranscode(t *testing.T) { - testSuccessfulTranscode(t, "oid", []interface{}{ + testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ pgtype.OidValue{Uint: 42, Status: pgtype.Present}, pgtype.OidValue{Status: pgtype.Null}, }) diff --git a/path_test.go b/path_test.go index 4e5f7f62..d213a1b4 100644 --- a/path_test.go +++ b/path_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestPathTranscode(t *testing.T) { - testSuccessfulTranscode(t, "path", []interface{}{ + testutil.TestSuccessfulTranscode(t, "path", []interface{}{ &pgtype.Path{ P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}}, Closed: false, diff --git a/pgtype_test.go b/pgtype_test.go index f486f077..716e063d 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -1,17 +1,9 @@ package pgtype_test import ( - "context" - "database/sql" - "fmt" - "io" "net" - "os" - "reflect" "testing" - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" _ "github.com/jackc/pgx/stdlib" _ "github.com/lib/pq" ) @@ -28,48 +20,6 @@ type _float32Slice []float32 type _float64Slice []float64 type _byteSlice []byte -func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { - var sqlDriverName string - switch driverName { - case "github.com/lib/pq": - sqlDriverName = "postgres" - case "github.com/jackc/pgx/stdlib": - sqlDriverName = "pgx" - default: - t.Fatalf("Unknown driver %v", driverName) - } - - db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL")) - if err != nil { - t.Fatal(err) - } - - return db -} - -func mustConnectPgx(t testing.TB) *pgx.Conn { - config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) - if err != nil { - t.Fatal(err) - } - - conn, err := pgx.Connect(config) - if err != nil { - t.Fatal(err) - } - - return conn -} - -func mustClose(t testing.TB, conn interface { - Close() error -}) { - err := conn.Close() - if err != nil { - t.Fatal(err) - } -} - func mustParseCidr(t testing.TB, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { @@ -87,244 +37,3 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { return addr } - -type forceTextEncoder struct { - e pgtype.TextEncoder -} - -func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { - return f.e.EncodeText(ci, w) -} - -type forceBinaryEncoder struct { - e pgtype.BinaryEncoder -} - -func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { - return f.e.EncodeBinary(ci, w) -} - -func forceEncoder(e interface{}, formatCode int16) interface{} { - switch formatCode { - case pgx.TextFormatCode: - if e, ok := e.(pgtype.TextEncoder); ok { - return forceTextEncoder{e: e} - } - case pgx.BinaryFormatCode: - if e, ok := e.(pgtype.BinaryEncoder); ok { - return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} - } - } - return nil -} - -func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { - testSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) - } -} - -func testPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) - - ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for i, v := range values { - for _, fc := range formats { - ps.FieldDescriptions[0].FormatCode = fc.formatCode - vEncoder := forceEncoder(v, fc.formatCode) - if vEncoder == nil { - t.Logf("Skipping: %#v does not implement %v", v, fc.name) - continue - } - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", fc.name, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) - } - } - } -} - -func testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) - - for i, v := range values { - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRowEx( - context.Background(), - fmt.Sprintf("select ($1)::%s", pgTypeName), - &pgx.QueryExOptions{SimpleProtocol: true}, - v, - ).Scan(result.Interface()) - if err != nil { - t.Errorf("Simple protocol %d: %v", i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface()) - } - } -} - -func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := mustConnectDatabaseSQL(t, driverName) - defer mustClose(t, conn) - - ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - for i, v := range values { - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err := ps.QueryRow(v).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", driverName, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) - } - } -} - -type normalizeTest struct { - sql string - value interface{} -} - -func testSuccessfulNormalize(t testing.TB, tests []normalizeTest) { - testSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func testSuccessfulNormalizeEqFunc(t testing.TB, tests []normalizeTest, eqFunc func(a, b interface{}) bool) { - testPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc) - } -} - -func testPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []normalizeTest, eqFunc func(a, b interface{}) bool) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for i, tt := range tests { - for _, fc := range formats { - psName := fmt.Sprintf("test%d", i) - ps, err := conn.Prepare(psName, tt.sql) - if err != nil { - t.Fatal(err) - } - - ps.FieldDescriptions[0].FormatCode = fc.formatCode - if forceEncoder(tt.value, fc.formatCode) == nil { - t.Logf("Skipping: %#v does not implement %v", tt.value, fc.name) - continue - } - // Derefence value if it is a pointer - derefV := tt.value - refVal := reflect.ValueOf(tt.value) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err = conn.QueryRow(psName).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", fc.name, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) - } - } - } -} - -func testDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []normalizeTest, eqFunc func(a, b interface{}) bool) { - conn := mustConnectDatabaseSQL(t, driverName) - defer mustClose(t, conn) - - for i, tt := range tests { - ps, err := conn.Prepare(tt.sql) - if err != nil { - t.Errorf("%d. %v", i, err) - continue - } - - // Derefence value if it is a pointer - derefV := tt.value - refVal := reflect.ValueOf(tt.value) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err = ps.QueryRow().Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", driverName, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) - } - } - -} diff --git a/point_test.go b/point_test.go index c921f794..f46b342d 100644 --- a/point_test.go +++ b/point_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestPointTranscode(t *testing.T) { - testSuccessfulTranscode(t, "point", []interface{}{ + testutil.TestSuccessfulTranscode(t, "point", []interface{}{ &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789}, Status: pgtype.Present}, &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, &pgtype.Point{Status: pgtype.Null}, diff --git a/polygon_test.go b/polygon_test.go index 3a7e1431..48481dc5 100644 --- a/polygon_test.go +++ b/polygon_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestPolygonTranscode(t *testing.T) { - testSuccessfulTranscode(t, "polygon", []interface{}{ + testutil.TestSuccessfulTranscode(t, "polygon", []interface{}{ &pgtype.Polygon{ P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {5.0, 3.234}}, Status: pgtype.Present, diff --git a/qchar_test.go b/qchar_test.go index afac5016..b810b89c 100644 --- a/qchar_test.go +++ b/qchar_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestQCharTranscode(t *testing.T) { - testPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ + testutil.TestPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, pgtype.QChar{Int: -1, Status: pgtype.Present}, pgtype.QChar{Int: 0, Status: pgtype.Present}, diff --git a/record_test.go b/record_test.go index bc6e5893..df17501f 100644 --- a/record_test.go +++ b/record_test.go @@ -7,11 +7,12 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestRecordTranscode(t *testing.T) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) + conn := testutil.MustConnectPgx(t) + defer testutil.MustClose(t, conn) tests := []struct { sql string diff --git a/testutil/testutil.go b/testutil/testutil.go index 610f0710..d9aaa5c4 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -15,7 +15,7 @@ import ( _ "github.com/lib/pq" ) -func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { +func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { var sqlDriverName string switch driverName { case "github.com/lib/pq": @@ -34,7 +34,7 @@ func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { return db } -func mustConnectPgx(t testing.TB) *pgx.Conn { +func MustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) if err != nil { t.Fatal(err) @@ -48,7 +48,7 @@ func mustConnectPgx(t testing.TB) *pgx.Conn { return conn } -func mustClose(t testing.TB, conn interface { +func MustClose(t testing.TB, conn interface { Close() error }) { err := conn.Close() @@ -73,7 +73,7 @@ func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool return f.e.EncodeBinary(ci, w) } -func forceEncoder(e interface{}, formatCode int16) interface{} { +func ForceEncoder(e interface{}, formatCode int16) interface{} { switch formatCode { case pgx.TextFormatCode: if e, ok := e.(pgtype.TextEncoder); ok { @@ -102,8 +102,8 @@ func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int } func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) + conn := MustConnectPgx(t) + defer MustClose(t, conn) ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) if err != nil { @@ -121,7 +121,7 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] for i, v := range values { for _, fc := range formats { ps.FieldDescriptions[0].FormatCode = fc.formatCode - vEncoder := forceEncoder(v, fc.formatCode) + vEncoder := ForceEncoder(v, fc.formatCode) if vEncoder == nil { t.Logf("Skipping: %#v does not implement %v", v, fc.name) continue @@ -134,7 +134,7 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface()) + err := conn.QueryRow("test", ForceEncoder(v, fc.formatCode)).Scan(result.Interface()) if err != nil { t.Errorf("%v %d: %v", fc.name, i, err) } @@ -147,8 +147,8 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) + conn := MustConnectPgx(t) + defer MustClose(t, conn) for i, v := range values { // Derefence value if it is a pointer @@ -176,8 +176,8 @@ func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName str } func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := mustConnectDatabaseSQL(t, driverName) - defer mustClose(t, conn) + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) if err != nil { @@ -223,8 +223,8 @@ func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc f } func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) + conn := MustConnectPgx(t) + defer MustClose(t, conn) formats := []struct { name string @@ -243,7 +243,7 @@ func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFun } ps.FieldDescriptions[0].FormatCode = fc.formatCode - if forceEncoder(tt.Value, fc.formatCode) == nil { + if ForceEncoder(tt.Value, fc.formatCode) == nil { t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name) continue } @@ -268,8 +268,8 @@ func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFun } func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { - conn := mustConnectDatabaseSQL(t, driverName) - defer mustClose(t, conn) + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) for i, tt := range tests { ps, err := conn.Prepare(tt.SQL) diff --git a/text_array_test.go b/text_array_test.go index 5a78d7bc..35ebef96 100644 --- a/text_array_test.go +++ b/text_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTextArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "text[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "text[]", []interface{}{ &pgtype.TextArray{ Elements: nil, Dimensions: nil, diff --git a/text_test.go b/text_test.go index 34b6a784..e4c1dbd8 100644 --- a/text_test.go +++ b/text_test.go @@ -6,11 +6,12 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTextTranscode(t *testing.T) { for _, pgTypeName := range []string{"text", "varchar"} { - testSuccessfulTranscode(t, pgTypeName, []interface{}{ + testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ pgtype.Text{String: "", Status: pgtype.Present}, pgtype.Text{String: "foo", Status: pgtype.Present}, pgtype.Text{Status: pgtype.Null}, diff --git a/tid_test.go b/tid_test.go index 56595ef4..7eb7773a 100644 --- a/tid_test.go +++ b/tid_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTidTranscode(t *testing.T) { - testSuccessfulTranscode(t, "tid", []interface{}{ + testutil.TestSuccessfulTranscode(t, "tid", []interface{}{ pgtype.Tid{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, pgtype.Tid{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, pgtype.Tid{Status: pgtype.Null}, diff --git a/timestamp_array_test.go b/timestamp_array_test.go index a15d3696..c75d101f 100644 --- a/timestamp_array_test.go +++ b/timestamp_array_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTimestampArrayTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "timestamp[]", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp[]", []interface{}{ &pgtype.TimestampArray{ Elements: nil, Dimensions: nil, diff --git a/timestamp_test.go b/timestamp_test.go index 58828806..c0427a5c 100644 --- a/timestamp_test.go +++ b/timestamp_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTimestampTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, diff --git a/timestamptz_array_test.go b/timestamptz_array_test.go index e0017828..50ee65d0 100644 --- a/timestamptz_array_test.go +++ b/timestamptz_array_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTimestamptzArrayTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "timestamptz[]", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz[]", []interface{}{ &pgtype.TimestamptzArray{ Elements: nil, Dimensions: nil, diff --git a/timestamptz_test.go b/timestamptz_test.go index 6ddfc1bc..bbc001e5 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTimestamptzTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, diff --git a/tsrange_test.go b/tsrange_test.go index 448cb92f..865233c2 100644 --- a/tsrange_test.go +++ b/tsrange_test.go @@ -5,10 +5,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTsrangeTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, pgtype.Tsrange{ Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, diff --git a/tstzrange_test.go b/tstzrange_test.go index 197aabbc..8eb00ab9 100644 --- a/tstzrange_test.go +++ b/tstzrange_test.go @@ -5,10 +5,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTstzrangeTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, pgtype.Tstzrange{ Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, diff --git a/uuid_test.go b/uuid_test.go index 1eba7e90..b745d542 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestUuidTranscode(t *testing.T) { - testSuccessfulTranscode(t, "uuid", []interface{}{ + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, pgtype.Uuid{Status: pgtype.Null}, }) diff --git a/varbit_test.go b/varbit_test.go index cd146d26..6c813aae 100644 --- a/varbit_test.go +++ b/varbit_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestVarbitTranscode(t *testing.T) { - testSuccessfulTranscode(t, "varbit", []interface{}{ + testutil.TestSuccessfulTranscode(t, "varbit", []interface{}{ &pgtype.Varbit{Bytes: []byte{}, Len: 0, Status: pgtype.Present}, &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Status: pgtype.Present}, @@ -16,10 +17,10 @@ func TestVarbitTranscode(t *testing.T) { } func TestVarbitNormalize(t *testing.T) { - testSuccessfulNormalize(t, []normalizeTest{ + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { - sql: "select B'111111111'", - value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, + SQL: "select B'111111111'", + Value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, }, }) } diff --git a/varchar_array_test.go b/varchar_array_test.go index 4a8b09b8..7d6fb39b 100644 --- a/varchar_array_test.go +++ b/varchar_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestVarcharArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "varchar[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "varchar[]", []interface{}{ &pgtype.VarcharArray{ Elements: nil, Dimensions: nil, diff --git a/xid_test.go b/xid_test.go index 11dd0615..868c101e 100644 --- a/xid_test.go +++ b/xid_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestXidTranscode(t *testing.T) { @@ -17,13 +18,13 @@ func TestXidTranscode(t *testing.T) { return reflect.DeepEqual(a, b) } - testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) // No direct conversion from int to xid, convert through text - testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) + testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) } } From d94f8daeb1f6fcef85749d3491c7fb5b06d3c3eb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 13:08:05 -0500 Subject: [PATCH 0072/1158] Use pointer methods for all struct pgtypes Now no need to no whether certain interfaces are implemented by struct or pointer to struct. --- aclitem.go | 4 ++-- aclitem_test.go | 6 +++--- bool.go | 6 +++--- bool_test.go | 6 +++--- bytea.go | 6 +++--- bytea_test.go | 6 +++--- cid.go | 12 ++++++------ cid_test.go | 4 ++-- cidr.go | 8 ++++---- date.go | 6 +++--- date_test.go | 18 +++++++++--------- daterange.go | 6 +++--- daterange_test.go | 8 ++++---- float4.go | 6 +++--- float4_test.go | 12 ++++++------ float8.go | 6 +++--- float8_test.go | 12 ++++++------ generic_binary.go | 8 ++++---- generic_text.go | 8 ++++---- hstore.go | 6 +++--- hstore_test.go | 28 ++++++++++++++-------------- inet.go | 6 +++--- inet_test.go | 22 +++++++++++----------- int2.go | 8 ++++---- int2_test.go | 12 ++++++------ int4.go | 8 ++++---- int4_test.go | 12 ++++++------ int4range.go | 6 +++--- int4range_test.go | 8 ++++---- int8.go | 8 ++++---- int8_test.go | 12 ++++++------ int8range.go | 6 +++--- int8range_test.go | 8 ++++---- interval.go | 6 +++--- interval_test.go | 34 +++++++++++++++++----------------- json.go | 6 +++--- json_test.go | 10 +++++----- jsonb.go | 10 +++++----- jsonb_test.go | 10 +++++----- macaddr.go | 6 +++--- macaddr_test.go | 4 ++-- name.go | 12 ++++++------ name_test.go | 6 +++--- numrange.go | 6 +++--- numrange_test.go | 8 ++++---- oid_value.go | 12 ++++++------ oid_value_test.go | 4 ++-- pguint32.go | 6 +++--- qchar.go | 2 +- text.go | 8 ++++---- text_test.go | 6 +++--- tid.go | 6 +++--- tid_test.go | 6 +++--- timestamp.go | 6 +++--- timestamp_test.go | 26 +++++++++++++------------- timestamptz.go | 6 +++--- timestamptz_test.go | 26 +++++++++++++------------- tsrange.go | 6 +++--- tsrange_test.go | 8 ++++---- tstzrange.go | 6 +++--- tstzrange_test.go | 8 ++++---- unknown.go | 4 ++-- uuid.go | 6 +++--- uuid_test.go | 4 ++-- varchar.go | 16 ++++++++-------- xid.go | 12 ++++++------ xid_test.go | 4 ++-- 67 files changed, 302 insertions(+), 302 deletions(-) diff --git a/aclitem.go b/aclitem.go index 3ccf8318..ebfcc3e7 100644 --- a/aclitem.go +++ b/aclitem.go @@ -83,7 +83,7 @@ func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (src Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -113,7 +113,7 @@ func (dst *Aclitem) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Aclitem) Value() (driver.Value, error) { +func (src *Aclitem) Value() (driver.Value, error) { switch src.Status { case Present: return src.String, nil diff --git a/aclitem_test.go b/aclitem_test.go index 5389eab2..13c63395 100644 --- a/aclitem_test.go +++ b/aclitem_test.go @@ -10,9 +10,9 @@ import ( func TestAclitemTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ - pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - pgtype.Aclitem{Status: pgtype.Null}, + &pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + &pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + &pgtype.Aclitem{Status: pgtype.Null}, }) } diff --git a/bool.go b/bool.go index 1ebf590b..9d309f0c 100644 --- a/bool.go +++ b/bool.go @@ -90,7 +90,7 @@ func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Bool) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bool) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -109,7 +109,7 @@ func (src Bool) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -149,7 +149,7 @@ func (dst *Bool) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Bool) Value() (driver.Value, error) { +func (src *Bool) Value() (driver.Value, error) { switch src.Status { case Present: return src.Bool, nil diff --git a/bool_test.go b/bool_test.go index 31f3d528..2712e3b0 100644 --- a/bool_test.go +++ b/bool_test.go @@ -10,9 +10,9 @@ import ( func TestBoolTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "bool", []interface{}{ - pgtype.Bool{Bool: false, Status: pgtype.Present}, - pgtype.Bool{Bool: true, Status: pgtype.Present}, - pgtype.Bool{Bool: false, Status: pgtype.Null}, + &pgtype.Bool{Bool: false, Status: pgtype.Present}, + &pgtype.Bool{Bool: true, Status: pgtype.Present}, + &pgtype.Bool{Bool: false, Status: pgtype.Null}, }) } diff --git a/bytea.go b/bytea.go index 8bf5de2b..3e2661db 100644 --- a/bytea.go +++ b/bytea.go @@ -102,7 +102,7 @@ func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Bytea) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bytea) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -119,7 +119,7 @@ func (src Bytea) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -152,7 +152,7 @@ func (dst *Bytea) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Bytea) Value() (driver.Value, error) { +func (src *Bytea) Value() (driver.Value, error) { switch src.Status { case Present: return src.Bytes, nil diff --git a/bytea_test.go b/bytea_test.go index 7d32e294..fd5a0dec 100644 --- a/bytea_test.go +++ b/bytea_test.go @@ -10,9 +10,9 @@ import ( func TestByteaTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "bytea", []interface{}{ - pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, - pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, + &pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + &pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + &pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, }) } diff --git a/cid.go b/cid.go index 63ba6a2f..c2b3073b 100644 --- a/cid.go +++ b/cid.go @@ -43,12 +43,12 @@ func (dst *Cid) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(ci, w) +func (src *Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*pguint32)(src).EncodeText(ci, w) } -func (src Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(ci, w) +func (src *Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*pguint32)(src).EncodeBinary(ci, w) } // Scan implements the database/sql Scanner interface. @@ -57,6 +57,6 @@ func (dst *Cid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Cid) Value() (driver.Value, error) { - return (pguint32)(src).Value() +func (src *Cid) Value() (driver.Value, error) { + return (*pguint32)(src).Value() } diff --git a/cid_test.go b/cid_test.go index 385b8cac..c3bf3132 100644 --- a/cid_test.go +++ b/cid_test.go @@ -11,8 +11,8 @@ import ( func TestCidTranscode(t *testing.T) { pgTypeName := "cid" values := []interface{}{ - pgtype.Cid{Uint: 42, Status: pgtype.Present}, - pgtype.Cid{Status: pgtype.Null}, + &pgtype.Cid{Uint: 42, Status: pgtype.Present}, + &pgtype.Cid{Status: pgtype.Null}, } eqFunc := func(a, b interface{}) bool { return reflect.DeepEqual(a, b) diff --git a/cidr.go b/cidr.go index 463b279d..39a87a26 100644 --- a/cidr.go +++ b/cidr.go @@ -26,10 +26,10 @@ func (dst *Cidr) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Inet)(dst).DecodeBinary(ci, src) } -func (src Cidr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (Inet)(src).EncodeText(ci, w) +func (src *Cidr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Inet)(src).EncodeText(ci, w) } -func (src Cidr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (Inet)(src).EncodeBinary(ci, w) +func (src *Cidr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Inet)(src).EncodeBinary(ci, w) } diff --git a/date.go b/date.go index 34753f05..993a04c5 100644 --- a/date.go +++ b/date.go @@ -125,7 +125,7 @@ func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -148,7 +148,7 @@ func (src Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -195,7 +195,7 @@ func (dst *Date) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Date) Value() (driver.Value, error) { +func (src *Date) Value() (driver.Value, error) { switch src.Status { case Present: if src.InfinityModifier != None { diff --git a/date_test.go b/date_test.go index d1493f5e..d98e1652 100644 --- a/date_test.go +++ b/date_test.go @@ -11,15 +11,15 @@ import ( func TestDateTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "date", []interface{}{ - pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Status: pgtype.Null}, - pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + &pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Status: pgtype.Null}, + &pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + &pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, }, func(a, b interface{}) bool { at := a.(pgtype.Date) bt := b.(pgtype.Date) diff --git a/daterange.go b/daterange.go index fbf51980..d78c4803 100644 --- a/daterange.go +++ b/daterange.go @@ -106,7 +106,7 @@ func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -263,6 +263,6 @@ func (dst *Daterange) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Daterange) Value() (driver.Value, error) { +func (src *Daterange) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/daterange_test.go b/daterange_test.go index 7dfae0f4..d2af5986 100644 --- a/daterange_test.go +++ b/daterange_test.go @@ -10,22 +10,22 @@ import ( func TestDaterangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ - pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - pgtype.Daterange{ + &pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Daterange{ Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Daterange{ + &pgtype.Daterange{ Lower: pgtype.Date{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Daterange{Status: pgtype.Null}, + &pgtype.Daterange{Status: pgtype.Null}, }, func(aa, bb interface{}) bool { a := aa.(pgtype.Daterange) b := bb.(pgtype.Daterange) diff --git a/float4.go b/float4.go index e92149a6..76be4203 100644 --- a/float4.go +++ b/float4.go @@ -139,7 +139,7 @@ func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Float4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -151,7 +151,7 @@ func (src Float4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -184,7 +184,7 @@ func (dst *Float4) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Float4) Value() (driver.Value, error) { +func (src *Float4) Value() (driver.Value, error) { switch src.Status { case Present: return float64(src.Float), nil diff --git a/float4_test.go b/float4_test.go index 57f4bc34..2ed8d05d 100644 --- a/float4_test.go +++ b/float4_test.go @@ -10,12 +10,12 @@ import ( func TestFloat4Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "float4", []interface{}{ - pgtype.Float4{Float: -1, Status: pgtype.Present}, - pgtype.Float4{Float: 0, Status: pgtype.Present}, - pgtype.Float4{Float: 0.00001, Status: pgtype.Present}, - pgtype.Float4{Float: 1, Status: pgtype.Present}, - pgtype.Float4{Float: 9999.99, Status: pgtype.Present}, - pgtype.Float4{Float: 0, Status: pgtype.Null}, + &pgtype.Float4{Float: -1, Status: pgtype.Present}, + &pgtype.Float4{Float: 0, Status: pgtype.Present}, + &pgtype.Float4{Float: 0.00001, Status: pgtype.Present}, + &pgtype.Float4{Float: 1, Status: pgtype.Present}, + &pgtype.Float4{Float: 9999.99, Status: pgtype.Present}, + &pgtype.Float4{Float: 0, Status: pgtype.Null}, }) } diff --git a/float8.go b/float8.go index 4d094757..8cfc53c5 100644 --- a/float8.go +++ b/float8.go @@ -129,7 +129,7 @@ func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Float8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -141,7 +141,7 @@ func (src Float8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -174,7 +174,7 @@ func (dst *Float8) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Float8) Value() (driver.Value, error) { +func (src *Float8) Value() (driver.Value, error) { switch src.Status { case Present: return src.Float, nil diff --git a/float8_test.go b/float8_test.go index b7527b86..46fc8d5d 100644 --- a/float8_test.go +++ b/float8_test.go @@ -10,12 +10,12 @@ import ( func TestFloat8Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ - pgtype.Float8{Float: -1, Status: pgtype.Present}, - pgtype.Float8{Float: 0, Status: pgtype.Present}, - pgtype.Float8{Float: 0.00001, Status: pgtype.Present}, - pgtype.Float8{Float: 1, Status: pgtype.Present}, - pgtype.Float8{Float: 9999.99, Status: pgtype.Present}, - pgtype.Float8{Float: 0, Status: pgtype.Null}, + &pgtype.Float8{Float: -1, Status: pgtype.Present}, + &pgtype.Float8{Float: 0, Status: pgtype.Present}, + &pgtype.Float8{Float: 0.00001, Status: pgtype.Present}, + &pgtype.Float8{Float: 1, Status: pgtype.Present}, + &pgtype.Float8{Float: 9999.99, Status: pgtype.Present}, + &pgtype.Float8{Float: 0, Status: pgtype.Null}, }) } diff --git a/generic_binary.go b/generic_binary.go index f834bfb2..094bd64e 100644 --- a/generic_binary.go +++ b/generic_binary.go @@ -25,8 +25,8 @@ func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Bytea)(dst).DecodeBinary(ci, src) } -func (src GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (Bytea)(src).EncodeBinary(ci, w) +func (src *GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Bytea)(src).EncodeBinary(ci, w) } // Scan implements the database/sql Scanner interface. @@ -35,6 +35,6 @@ func (dst *GenericBinary) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src GenericBinary) Value() (driver.Value, error) { - return (Bytea)(src).Value() +func (src *GenericBinary) Value() (driver.Value, error) { + return (*Bytea)(src).Value() } diff --git a/generic_text.go b/generic_text.go index 053ec504..5d0d83be 100644 --- a/generic_text.go +++ b/generic_text.go @@ -25,8 +25,8 @@ func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeText(ci, src) } -func (src GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (Text)(src).EncodeText(ci, w) +func (src *GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Text)(src).EncodeText(ci, w) } // Scan implements the database/sql Scanner interface. @@ -35,6 +35,6 @@ func (dst *GenericText) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src GenericText) Value() (driver.Value, error) { - return (Text)(src).Value() +func (src *GenericText) Value() (driver.Value, error) { + return (*Text)(src).Value() } diff --git a/hstore.go b/hstore.go index 5dc78671..3d55f783 100644 --- a/hstore.go +++ b/hstore.go @@ -151,7 +151,7 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -203,7 +203,7 @@ func (src Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Hstore) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Hstore) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -462,6 +462,6 @@ func (dst *Hstore) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Hstore) Value() (driver.Value, error) { +func (src *Hstore) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/hstore_test.go b/hstore_test.go index 502a8df0..dc2439fc 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -14,12 +14,12 @@ func TestHstoreTranscode(t *testing.T) { } values := []interface{}{ - pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, - pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, - pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, - pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, - pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, - pgtype.Hstore{Status: pgtype.Null}, + &pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + &pgtype.Hstore{Status: pgtype.Null}, } specialStrings := []string{ @@ -33,16 +33,16 @@ func TestHstoreTranscode(t *testing.T) { } for _, s := range specialStrings { // Special key values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key // Special value values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key } testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { diff --git a/inet.go b/inet.go index 09fce04d..62734088 100644 --- a/inet.go +++ b/inet.go @@ -149,7 +149,7 @@ func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Inet) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Inet) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -162,7 +162,7 @@ func (src Inet) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } // EncodeBinary encodes src into w. -func (src Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -220,6 +220,6 @@ func (dst *Inet) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Inet) Value() (driver.Value, error) { +func (src *Inet) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/inet_test.go b/inet_test.go index 532e9abe..b883df8e 100644 --- a/inet_test.go +++ b/inet_test.go @@ -12,17 +12,17 @@ import ( func TestInetTranscode(t *testing.T) { for _, pgTypeName := range []string{"inet", "cidr"} { testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ - pgtype.Inet{IPNet: mustParseCidr(t, "0.0.0.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "192.168.1.0/24"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "255.255.255.255/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "::/128"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "::/0"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "::1/128"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - pgtype.Inet{Status: pgtype.Null}, + &pgtype.Inet{IPNet: mustParseCidr(t, "0.0.0.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "192.168.1.0/24"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "255.255.255.255/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "::/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "::/0"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "::1/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + &pgtype.Inet{Status: pgtype.Null}, }) } } diff --git a/int2.go b/int2.go index 0cb6ef82..4a3beb22 100644 --- a/int2.go +++ b/int2.go @@ -134,7 +134,7 @@ func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int2) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int2) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -146,7 +146,7 @@ func (src Int2) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -185,7 +185,7 @@ func (dst *Int2) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Int2) Value() (driver.Value, error) { +func (src *Int2) Value() (driver.Value, error) { switch src.Status { case Present: return int64(src.Int), nil @@ -196,7 +196,7 @@ func (src Int2) Value() (driver.Value, error) { } } -func (src Int2) MarshalJSON() ([]byte, error) { +func (src *Int2) MarshalJSON() ([]byte, error) { switch src.Status { case Present: return []byte(strconv.FormatInt(int64(src.Int), 10)), nil diff --git a/int2_test.go b/int2_test.go index d81405a6..d20bf0ed 100644 --- a/int2_test.go +++ b/int2_test.go @@ -11,12 +11,12 @@ import ( func TestInt2Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ - pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, - pgtype.Int2{Int: -1, Status: pgtype.Present}, - pgtype.Int2{Int: 0, Status: pgtype.Present}, - pgtype.Int2{Int: 1, Status: pgtype.Present}, - pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}, - pgtype.Int2{Int: 0, Status: pgtype.Null}, + &pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, + &pgtype.Int2{Int: -1, Status: pgtype.Present}, + &pgtype.Int2{Int: 0, Status: pgtype.Present}, + &pgtype.Int2{Int: 1, Status: pgtype.Present}, + &pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}, + &pgtype.Int2{Int: 0, Status: pgtype.Null}, }) } diff --git a/int4.go b/int4.go index 4a5bca51..f429d887 100644 --- a/int4.go +++ b/int4.go @@ -125,7 +125,7 @@ func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -137,7 +137,7 @@ func (src Int4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -176,7 +176,7 @@ func (dst *Int4) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Int4) Value() (driver.Value, error) { +func (src *Int4) Value() (driver.Value, error) { switch src.Status { case Present: return int64(src.Int), nil @@ -187,7 +187,7 @@ func (src Int4) Value() (driver.Value, error) { } } -func (src Int4) MarshalJSON() ([]byte, error) { +func (src *Int4) MarshalJSON() ([]byte, error) { switch src.Status { case Present: return []byte(strconv.FormatInt(int64(src.Int), 10)), nil diff --git a/int4_test.go b/int4_test.go index 1354b47a..02f5409f 100644 --- a/int4_test.go +++ b/int4_test.go @@ -11,12 +11,12 @@ import ( func TestInt4Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ - pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, - pgtype.Int4{Int: -1, Status: pgtype.Present}, - pgtype.Int4{Int: 0, Status: pgtype.Present}, - pgtype.Int4{Int: 1, Status: pgtype.Present}, - pgtype.Int4{Int: math.MaxInt32, Status: pgtype.Present}, - pgtype.Int4{Int: 0, Status: pgtype.Null}, + &pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, + &pgtype.Int4{Int: -1, Status: pgtype.Present}, + &pgtype.Int4{Int: 0, Status: pgtype.Present}, + &pgtype.Int4{Int: 1, Status: pgtype.Present}, + &pgtype.Int4{Int: math.MaxInt32, Status: pgtype.Present}, + &pgtype.Int4{Int: 0, Status: pgtype.Null}, }) } diff --git a/int4range.go b/int4range.go index cac4484c..8b04cf3c 100644 --- a/int4range.go +++ b/int4range.go @@ -106,7 +106,7 @@ func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -263,6 +263,6 @@ func (dst *Int4range) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Int4range) Value() (driver.Value, error) { +func (src *Int4range) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/int4range_test.go b/int4range_test.go index 74a91e59..088097d8 100644 --- a/int4range_test.go +++ b/int4range_test.go @@ -9,10 +9,10 @@ import ( func TestInt4rangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "int4range", []interface{}{ - pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - pgtype.Int4range{Status: pgtype.Null}, + &pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int4range{Status: pgtype.Null}, }) } diff --git a/int8.go b/int8.go index 0cc3545d..97db8393 100644 --- a/int8.go +++ b/int8.go @@ -117,7 +117,7 @@ func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -129,7 +129,7 @@ func (src Int8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -162,7 +162,7 @@ func (dst *Int8) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Int8) Value() (driver.Value, error) { +func (src *Int8) Value() (driver.Value, error) { switch src.Status { case Present: return int64(src.Int), nil @@ -173,7 +173,7 @@ func (src Int8) Value() (driver.Value, error) { } } -func (src Int8) MarshalJSON() ([]byte, error) { +func (src *Int8) MarshalJSON() ([]byte, error) { switch src.Status { case Present: return []byte(strconv.FormatInt(src.Int, 10)), nil diff --git a/int8_test.go b/int8_test.go index d6752205..0b3bb3eb 100644 --- a/int8_test.go +++ b/int8_test.go @@ -11,12 +11,12 @@ import ( func TestInt8Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ - pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, - pgtype.Int8{Int: -1, Status: pgtype.Present}, - pgtype.Int8{Int: 0, Status: pgtype.Present}, - pgtype.Int8{Int: 1, Status: pgtype.Present}, - pgtype.Int8{Int: math.MaxInt64, Status: pgtype.Present}, - pgtype.Int8{Int: 0, Status: pgtype.Null}, + &pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, + &pgtype.Int8{Int: -1, Status: pgtype.Present}, + &pgtype.Int8{Int: 0, Status: pgtype.Present}, + &pgtype.Int8{Int: 1, Status: pgtype.Present}, + &pgtype.Int8{Int: math.MaxInt64, Status: pgtype.Present}, + &pgtype.Int8{Int: 0, Status: pgtype.Null}, }) } diff --git a/int8range.go b/int8range.go index 44946be9..f8e056cb 100644 --- a/int8range.go +++ b/int8range.go @@ -106,7 +106,7 @@ func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -263,6 +263,6 @@ func (dst *Int8range) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Int8range) Value() (driver.Value, error) { +func (src *Int8range) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/int8range_test.go b/int8range_test.go index 703f476e..c039ec65 100644 --- a/int8range_test.go +++ b/int8range_test.go @@ -9,10 +9,10 @@ import ( func TestInt8rangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "Int8range", []interface{}{ - pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - pgtype.Int8range{Status: pgtype.Null}, + &pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int8range{Status: pgtype.Null}, }) } diff --git a/interval.go b/interval.go index 20a4a419..1cbdffc3 100644 --- a/interval.go +++ b/interval.go @@ -178,7 +178,7 @@ func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Interval) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Interval) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -227,7 +227,7 @@ func (src Interval) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } // EncodeBinary encodes src into w. -func (src Interval) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Interval) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -266,6 +266,6 @@ func (dst *Interval) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Interval) Value() (driver.Value, error) { +func (src *Interval) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/interval_test.go b/interval_test.go index 28e77e0a..18e21ddd 100644 --- a/interval_test.go +++ b/interval_test.go @@ -9,23 +9,23 @@ import ( func TestIntervalTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "interval", []interface{}{ - pgtype.Interval{Microseconds: 1, Status: pgtype.Present}, - pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, - pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, - pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, - pgtype.Interval{Days: 1, Status: pgtype.Present}, - pgtype.Interval{Months: 1, Status: pgtype.Present}, - pgtype.Interval{Months: 12, Status: pgtype.Present}, - pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Status: pgtype.Present}, - pgtype.Interval{Microseconds: -1, Status: pgtype.Present}, - pgtype.Interval{Microseconds: -1000000, Status: pgtype.Present}, - pgtype.Interval{Microseconds: -1000001, Status: pgtype.Present}, - pgtype.Interval{Microseconds: -123202800000000, Status: pgtype.Present}, - pgtype.Interval{Days: -1, Status: pgtype.Present}, - pgtype.Interval{Months: -1, Status: pgtype.Present}, - pgtype.Interval{Months: -12, Status: pgtype.Present}, - pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Status: pgtype.Present}, - pgtype.Interval{Status: pgtype.Null}, + &pgtype.Interval{Microseconds: 1, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + &pgtype.Interval{Days: 1, Status: pgtype.Present}, + &pgtype.Interval{Months: 1, Status: pgtype.Present}, + &pgtype.Interval{Months: 12, Status: pgtype.Present}, + &pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -1, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -1000000, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -1000001, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -123202800000000, Status: pgtype.Present}, + &pgtype.Interval{Days: -1, Status: pgtype.Present}, + &pgtype.Interval{Months: -1, Status: pgtype.Present}, + &pgtype.Interval{Months: -12, Status: pgtype.Present}, + &pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Status: pgtype.Present}, + &pgtype.Interval{Status: pgtype.Null}, }) } diff --git a/json.go b/json.go index b1c061f9..a027a91c 100644 --- a/json.go +++ b/json.go @@ -108,7 +108,7 @@ func (dst *Json) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -120,7 +120,7 @@ func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return src.EncodeText(ci, w) } @@ -142,7 +142,7 @@ func (dst *Json) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Json) Value() (driver.Value, error) { +func (src *Json) Value() (driver.Value, error) { switch src.Status { case Present: return string(src.Bytes), nil diff --git a/json_test.go b/json_test.go index 6d7cccfd..3d8d2a68 100644 --- a/json_test.go +++ b/json_test.go @@ -11,11 +11,11 @@ import ( func TestJsonTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "json", []interface{}{ - pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, - pgtype.Json{Bytes: []byte("null"), Status: pgtype.Present}, - pgtype.Json{Bytes: []byte("42"), Status: pgtype.Present}, - pgtype.Json{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - pgtype.Json{Status: pgtype.Null}, + &pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, + &pgtype.Json{Bytes: []byte("null"), Status: pgtype.Present}, + &pgtype.Json{Bytes: []byte("42"), Status: pgtype.Present}, + &pgtype.Json{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + &pgtype.Json{Status: pgtype.Null}, }) } diff --git a/jsonb.go b/jsonb.go index f47476d6..82cbb21f 100644 --- a/jsonb.go +++ b/jsonb.go @@ -47,11 +47,11 @@ func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { } -func (src Jsonb) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (Json)(src).EncodeText(ci, w) +func (src *Jsonb) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Json)(src).EncodeText(ci, w) } -func (src Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -74,6 +74,6 @@ func (dst *Jsonb) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Jsonb) Value() (driver.Value, error) { - return (Json)(src).Value() +func (src *Jsonb) Value() (driver.Value, error) { + return (*Json)(src).Value() } diff --git a/jsonb_test.go b/jsonb_test.go index 37c11858..86c8a12c 100644 --- a/jsonb_test.go +++ b/jsonb_test.go @@ -17,11 +17,11 @@ func TestJsonbTranscode(t *testing.T) { } testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ - pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, - pgtype.Jsonb{Bytes: []byte("null"), Status: pgtype.Present}, - pgtype.Jsonb{Bytes: []byte("42"), Status: pgtype.Present}, - pgtype.Jsonb{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - pgtype.Jsonb{Status: pgtype.Null}, + &pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, + &pgtype.Jsonb{Bytes: []byte("null"), Status: pgtype.Present}, + &pgtype.Jsonb{Bytes: []byte("42"), Status: pgtype.Present}, + &pgtype.Jsonb{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + &pgtype.Jsonb{Status: pgtype.Null}, }) } diff --git a/macaddr.go b/macaddr.go index 2834d69f..cfbb513d 100644 --- a/macaddr.go +++ b/macaddr.go @@ -106,7 +106,7 @@ func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Macaddr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Macaddr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -119,7 +119,7 @@ func (src Macaddr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } // EncodeBinary encodes src into w. -func (src Macaddr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Macaddr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -149,6 +149,6 @@ func (dst *Macaddr) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Macaddr) Value() (driver.Value, error) { +func (src *Macaddr) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/macaddr_test.go b/macaddr_test.go index c2542da3..5d329249 100644 --- a/macaddr_test.go +++ b/macaddr_test.go @@ -12,8 +12,8 @@ import ( func TestMacaddrTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "macaddr", []interface{}{ - pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - pgtype.Macaddr{Status: pgtype.Null}, + &pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + &pgtype.Macaddr{Status: pgtype.Null}, }) } diff --git a/name.go b/name.go index cc4ae23b..05e92563 100644 --- a/name.go +++ b/name.go @@ -40,12 +40,12 @@ func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } -func (src Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (Text)(src).EncodeText(ci, w) +func (src *Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Text)(src).EncodeText(ci, w) } -func (src Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (Text)(src).EncodeBinary(ci, w) +func (src *Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Text)(src).EncodeBinary(ci, w) } // Scan implements the database/sql Scanner interface. @@ -54,6 +54,6 @@ func (dst *Name) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Name) Value() (driver.Value, error) { - return (Text)(src).Value() +func (src *Name) Value() (driver.Value, error) { + return (*Text)(src).Value() } diff --git a/name_test.go b/name_test.go index 348f8d39..ec0820c4 100644 --- a/name_test.go +++ b/name_test.go @@ -10,9 +10,9 @@ import ( func TestNameTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "name", []interface{}{ - pgtype.Name{String: "", Status: pgtype.Present}, - pgtype.Name{String: "foo", Status: pgtype.Present}, - pgtype.Name{Status: pgtype.Null}, + &pgtype.Name{String: "", Status: pgtype.Present}, + &pgtype.Name{String: "foo", Status: pgtype.Present}, + &pgtype.Name{Status: pgtype.Null}, }) } diff --git a/numrange.go b/numrange.go index cf42dcbd..a1b5b184 100644 --- a/numrange.go +++ b/numrange.go @@ -106,7 +106,7 @@ func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -263,6 +263,6 @@ func (dst *Numrange) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Numrange) Value() (driver.Value, error) { +func (src *Numrange) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/numrange_test.go b/numrange_test.go index 81e73c38..32267c86 100644 --- a/numrange_test.go +++ b/numrange_test.go @@ -10,25 +10,25 @@ import ( func TestNumrangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "numrange", []interface{}{ - pgtype.Numrange{ + &pgtype.Numrange{ LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present, }, - pgtype.Numrange{ + &pgtype.Numrange{ Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Status: pgtype.Present}, Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Numrange{ + &pgtype.Numrange{ Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Numrange{Status: pgtype.Null}, + &pgtype.Numrange{Status: pgtype.Null}, }) } diff --git a/oid_value.go b/oid_value.go index cb03802e..4a7de921 100644 --- a/oid_value.go +++ b/oid_value.go @@ -37,12 +37,12 @@ func (dst *OidValue) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(ci, w) +func (src *OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*pguint32)(src).EncodeText(ci, w) } -func (src OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(ci, w) +func (src *OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*pguint32)(src).EncodeBinary(ci, w) } // Scan implements the database/sql Scanner interface. @@ -51,6 +51,6 @@ func (dst *OidValue) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src OidValue) Value() (driver.Value, error) { - return (pguint32)(src).Value() +func (src *OidValue) Value() (driver.Value, error) { + return (*pguint32)(src).Value() } diff --git a/oid_value_test.go b/oid_value_test.go index d3412159..52ce4064 100644 --- a/oid_value_test.go +++ b/oid_value_test.go @@ -10,8 +10,8 @@ import ( func TestOidValueTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ - pgtype.OidValue{Uint: 42, Status: pgtype.Present}, - pgtype.OidValue{Status: pgtype.Null}, + &pgtype.OidValue{Uint: 42, Status: pgtype.Present}, + &pgtype.OidValue{Status: pgtype.Null}, }) } diff --git a/pguint32.go b/pguint32.go index 7138a409..0caa0cba 100644 --- a/pguint32.go +++ b/pguint32.go @@ -103,7 +103,7 @@ func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src pguint32) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *pguint32) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -115,7 +115,7 @@ func (src pguint32) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -151,7 +151,7 @@ func (dst *pguint32) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src pguint32) Value() (driver.Value, error) { +func (src *pguint32) Value() (driver.Value, error) { switch src.Status { case Present: return int64(src.Uint), nil diff --git a/qchar.go b/qchar.go index 49475bd3..10b56534 100644 --- a/qchar.go +++ b/qchar.go @@ -136,7 +136,7 @@ func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src QChar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *QChar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/text.go b/text.go index de80dd08..8e42a756 100644 --- a/text.go +++ b/text.go @@ -91,7 +91,7 @@ func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -103,7 +103,7 @@ func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return src.EncodeText(ci, w) } @@ -125,7 +125,7 @@ func (dst *Text) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Text) Value() (driver.Value, error) { +func (src *Text) Value() (driver.Value, error) { switch src.Status { case Present: return src.String, nil @@ -136,7 +136,7 @@ func (src Text) Value() (driver.Value, error) { } } -func (src Text) MarshalJSON() ([]byte, error) { +func (src *Text) MarshalJSON() ([]byte, error) { switch src.Status { case Present: return json.Marshal(src.String) diff --git a/text_test.go b/text_test.go index e4c1dbd8..bd971807 100644 --- a/text_test.go +++ b/text_test.go @@ -12,9 +12,9 @@ import ( func TestTextTranscode(t *testing.T) { for _, pgTypeName := range []string{"text", "varchar"} { testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ - pgtype.Text{String: "", Status: pgtype.Present}, - pgtype.Text{String: "foo", Status: pgtype.Present}, - pgtype.Text{Status: pgtype.Null}, + &pgtype.Text{String: "", Status: pgtype.Present}, + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Text{Status: pgtype.Null}, }) } } diff --git a/tid.go b/tid.go index b363c1f9..f24c6244 100644 --- a/tid.go +++ b/tid.go @@ -94,7 +94,7 @@ func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Tid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -106,7 +106,7 @@ func (src Tid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -141,6 +141,6 @@ func (dst *Tid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Tid) Value() (driver.Value, error) { +func (src *Tid) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/tid_test.go b/tid_test.go index 7eb7773a..a5430d11 100644 --- a/tid_test.go +++ b/tid_test.go @@ -9,8 +9,8 @@ import ( func TestTidTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "tid", []interface{}{ - pgtype.Tid{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, - pgtype.Tid{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, - pgtype.Tid{Status: pgtype.Null}, + &pgtype.Tid{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, + &pgtype.Tid{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, + &pgtype.Tid{Status: pgtype.Null}, }) } diff --git a/timestamp.go b/timestamp.go index e7bc1c7d..694b63c0 100644 --- a/timestamp.go +++ b/timestamp.go @@ -136,7 +136,7 @@ func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { // EncodeText writes the text encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -164,7 +164,7 @@ func (src Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -211,7 +211,7 @@ func (dst *Timestamp) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Timestamp) Value() (driver.Value, error) { +func (src *Timestamp) Value() (driver.Value, error) { switch src.Status { case Present: if src.InfinityModifier != None { diff --git a/timestamp_test.go b/timestamp_test.go index c0427a5c..267f1a7e 100644 --- a/timestamp_test.go +++ b/timestamp_test.go @@ -11,19 +11,19 @@ import ( func TestTimestampTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ - pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Status: pgtype.Null}, - pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + &pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Status: pgtype.Null}, + &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, }, func(a, b interface{}) bool { at := a.(pgtype.Timestamp) bt := b.(pgtype.Timestamp) diff --git a/timestamptz.go b/timestamptz.go index ef2d7498..3c76ec03 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -140,7 +140,7 @@ func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -163,7 +163,7 @@ func (src Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -207,7 +207,7 @@ func (dst *Timestamptz) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Timestamptz) Value() (driver.Value, error) { +func (src *Timestamptz) Value() (driver.Value, error) { switch src.Status { case Present: if src.InfinityModifier != None { diff --git a/timestamptz_test.go b/timestamptz_test.go index bbc001e5..c326802d 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -11,19 +11,19 @@ import ( func TestTimestamptzTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ - pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Status: pgtype.Null}, - pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + &pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Status: pgtype.Null}, + &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, }, func(a, b interface{}) bool { at := a.(pgtype.Timestamptz) bt := b.(pgtype.Timestamptz) diff --git a/tsrange.go b/tsrange.go index 48992829..3bf5f5ca 100644 --- a/tsrange.go +++ b/tsrange.go @@ -106,7 +106,7 @@ func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -263,6 +263,6 @@ func (dst *Tsrange) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Tsrange) Value() (driver.Value, error) { +func (src *Tsrange) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/tsrange_test.go b/tsrange_test.go index 865233c2..78eb1cd3 100644 --- a/tsrange_test.go +++ b/tsrange_test.go @@ -10,22 +10,22 @@ import ( func TestTsrangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ - pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - pgtype.Tsrange{ + &pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Tsrange{ Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Timestamp{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Tsrange{ + &pgtype.Tsrange{ Lower: pgtype.Timestamp{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Tsrange{Status: pgtype.Null}, + &pgtype.Tsrange{Status: pgtype.Null}, }, func(aa, bb interface{}) bool { a := aa.(pgtype.Tsrange) b := bb.(pgtype.Tsrange) diff --git a/tstzrange.go b/tstzrange.go index 61e94ab4..8e80a8f9 100644 --- a/tstzrange.go +++ b/tstzrange.go @@ -106,7 +106,7 @@ func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -263,6 +263,6 @@ func (dst *Tstzrange) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Tstzrange) Value() (driver.Value, error) { +func (src *Tstzrange) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/tstzrange_test.go b/tstzrange_test.go index 8eb00ab9..a27ddd3a 100644 --- a/tstzrange_test.go +++ b/tstzrange_test.go @@ -10,22 +10,22 @@ import ( func TestTstzrangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ - pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - pgtype.Tstzrange{ + &pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Tstzrange{ Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Timestamptz{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Tstzrange{ + &pgtype.Tstzrange{ Lower: pgtype.Timestamptz{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Tstzrange{Status: pgtype.Null}, + &pgtype.Tstzrange{Status: pgtype.Null}, }, func(aa, bb interface{}) bool { a := aa.(pgtype.Tstzrange) b := bb.(pgtype.Tstzrange) diff --git a/unknown.go b/unknown.go index 2dca0f87..567831d7 100644 --- a/unknown.go +++ b/unknown.go @@ -39,6 +39,6 @@ func (dst *Unknown) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Unknown) Value() (driver.Value, error) { - return (Text)(src).Value() +func (src *Unknown) Value() (driver.Value, error) { + return (*Text)(src).Value() } diff --git a/uuid.go b/uuid.go index 88d2195b..03029ffd 100644 --- a/uuid.go +++ b/uuid.go @@ -126,7 +126,7 @@ func (dst *Uuid) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Uuid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Uuid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -138,7 +138,7 @@ func (src Uuid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Uuid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Uuid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -168,6 +168,6 @@ func (dst *Uuid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Uuid) Value() (driver.Value, error) { +func (src *Uuid) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/uuid_test.go b/uuid_test.go index b745d542..4c6ad2cd 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -10,8 +10,8 @@ import ( func TestUuidTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - pgtype.Uuid{Status: pgtype.Null}, + &pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &pgtype.Uuid{Status: pgtype.Null}, }) } diff --git a/varchar.go b/varchar.go index 6c137b9a..80673fa8 100644 --- a/varchar.go +++ b/varchar.go @@ -32,12 +32,12 @@ func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } -func (src Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (Text)(src).EncodeText(ci, w) +func (src *Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Text)(src).EncodeText(ci, w) } -func (src Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (Text)(src).EncodeBinary(ci, w) +func (src *Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Text)(src).EncodeBinary(ci, w) } // Scan implements the database/sql Scanner interface. @@ -46,10 +46,10 @@ func (dst *Varchar) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Varchar) Value() (driver.Value, error) { - return (Text)(src).Value() +func (src *Varchar) Value() (driver.Value, error) { + return (*Text)(src).Value() } -func (src Varchar) MarshalJSON() ([]byte, error) { - return (Text)(src).MarshalJSON() +func (src *Varchar) MarshalJSON() ([]byte, error) { + return (*Text)(src).MarshalJSON() } diff --git a/xid.go b/xid.go index 0a7fc7d9..90a8d691 100644 --- a/xid.go +++ b/xid.go @@ -46,12 +46,12 @@ func (dst *Xid) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(ci, w) +func (src *Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*pguint32)(src).EncodeText(ci, w) } -func (src Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(ci, w) +func (src *Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*pguint32)(src).EncodeBinary(ci, w) } // Scan implements the database/sql Scanner interface. @@ -60,6 +60,6 @@ func (dst *Xid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Xid) Value() (driver.Value, error) { - return (pguint32)(src).Value() +func (src *Xid) Value() (driver.Value, error) { + return (*pguint32)(src).Value() } diff --git a/xid_test.go b/xid_test.go index 868c101e..c4a1bec3 100644 --- a/xid_test.go +++ b/xid_test.go @@ -11,8 +11,8 @@ import ( func TestXidTranscode(t *testing.T) { pgTypeName := "xid" values := []interface{}{ - pgtype.Xid{Uint: 42, Status: pgtype.Present}, - pgtype.Xid{Status: pgtype.Null}, + &pgtype.Xid{Uint: 42, Status: pgtype.Present}, + &pgtype.Xid{Status: pgtype.Null}, } eqFunc := func(a, b interface{}) bool { return reflect.DeepEqual(a, b) From f0e9337d8f7a1d561c7b15934192a93be9cc7443 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 16:46:39 -0500 Subject: [PATCH 0073/1158] Add satori-uuid type Make pgtype.EncodeValueText public --- box.go | 2 +- circle.go | 2 +- database_sql.go | 2 +- daterange.go | 2 +- ext/satori-uuid/uuid.go | 164 +++++++++++++++++++++++++++++++++++ ext/satori-uuid/uuid_test.go | 97 +++++++++++++++++++++ hstore.go | 2 +- inet.go | 2 +- int4range.go | 2 +- int8range.go | 2 +- interval.go | 2 +- line.go | 2 +- lseg.go | 2 +- macaddr.go | 2 +- numrange.go | 2 +- path.go | 2 +- point.go | 2 +- polygon.go | 2 +- tid.go | 2 +- tsrange.go | 2 +- tstzrange.go | 2 +- typed_range.go.erb | 2 +- uuid.go | 2 +- varbit.go | 2 +- 24 files changed, 283 insertions(+), 22 deletions(-) create mode 100644 ext/satori-uuid/uuid.go create mode 100644 ext/satori-uuid/uuid_test.go diff --git a/box.go b/box.go index 138953a5..2e4f39ee 100644 --- a/box.go +++ b/box.go @@ -164,5 +164,5 @@ func (dst *Box) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Box) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/circle.go b/circle.go index 62e2e8b3..8c8f4693 100644 --- a/circle.go +++ b/circle.go @@ -146,5 +146,5 @@ func (dst *Circle) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Circle) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/database_sql.go b/database_sql.go index 2ddd842d..e255b646 100644 --- a/database_sql.go +++ b/database_sql.go @@ -31,7 +31,7 @@ func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { return nil, errors.New("cannot convert to database/sql compatible value") } -func encodeValueText(src TextEncoder) (interface{}, error) { +func EncodeValueText(src TextEncoder) (interface{}, error) { buf := &bytes.Buffer{} null, err := src.EncodeText(nil, buf) if err != nil { diff --git a/daterange.go b/daterange.go index d78c4803..5cecca20 100644 --- a/daterange.go +++ b/daterange.go @@ -264,5 +264,5 @@ func (dst *Daterange) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Daterange) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/ext/satori-uuid/uuid.go b/ext/satori-uuid/uuid.go new file mode 100644 index 00000000..1b65f48a --- /dev/null +++ b/ext/satori-uuid/uuid.go @@ -0,0 +1,164 @@ +package uuid + +import ( + "database/sql/driver" + "errors" + "fmt" + "io" + + "github.com/jackc/pgx/pgtype" + uuid "github.com/satori/go.uuid" +) + +var errUndefined = errors.New("cannot encode status undefined") + +type Uuid struct { + UUID uuid.UUID + Status pgtype.Status +} + +func (dst *Uuid) Set(src interface{}) error { + switch value := src.(type) { + case uuid.UUID: + *dst = Uuid{UUID: value, Status: pgtype.Present} + case [16]byte: + *dst = Uuid{UUID: uuid.UUID(value), Status: pgtype.Present} + case []byte: + if len(value) != 16 { + return fmt.Errorf("[]byte must be 16 bytes to convert to Uuid: %d", len(value)) + } + *dst = Uuid{Status: pgtype.Present} + copy(dst.UUID[:], value) + case string: + uuid, err := uuid.FromString(value) + if err != nil { + return err + } + *dst = Uuid{UUID: uuid, Status: pgtype.Present} + default: + // If all else fails see if pgtype.Uuid can handle it. If so, translate through that. + pgUuid := &pgtype.Uuid{} + if err := pgUuid.Set(value); err != nil { + return fmt.Errorf("cannot convert %v to Uuid", value) + } + + *dst = Uuid{UUID: uuid.UUID(pgUuid.Bytes), Status: pgUuid.Status} + } + + return nil +} + +func (dst *Uuid) Get() interface{} { + switch dst.Status { + case pgtype.Present: + return dst.UUID + case pgtype.Null: + return nil + default: + return dst.Status + } +} + +func (src *Uuid) AssignTo(dst interface{}) error { + switch src.Status { + case pgtype.Present: + switch v := dst.(type) { + case *uuid.UUID: + *v = src.UUID + case *[16]byte: + *v = [16]byte(src.UUID) + return nil + case *[]byte: + *v = make([]byte, 16) + copy(*v, src.UUID[:]) + return nil + case *string: + *v = src.UUID.String() + return nil + default: + if nextDst, retry := pgtype.GetAssignToDstType(v); retry { + return src.AssignTo(nextDst) + } + } + case pgtype.Null: + return pgtype.NullAssignTo(dst) + } + + return fmt.Errorf("cannot assign %v into %T", src, dst) +} + +func (dst *Uuid) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Uuid{Status: pgtype.Null} + return nil + } + + u, err := uuid.FromString(string(src)) + if err != nil { + return err + } + + *dst = Uuid{UUID: u, Status: pgtype.Present} + return nil +} + +func (dst *Uuid) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Uuid{Status: pgtype.Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for Uuid: %v", len(src)) + } + + *dst = Uuid{Status: pgtype.Present} + copy(dst.UUID[:], src) + return nil +} + +func (src *Uuid) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, src.UUID.String()) + return false, err +} + +func (src *Uuid) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, errUndefined + } + + _, err := w.Write(src.UUID[:]) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Uuid) Scan(src interface{}) error { + if src == nil { + *dst = Uuid{Status: pgtype.Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Uuid) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/ext/satori-uuid/uuid_test.go b/ext/satori-uuid/uuid_test.go new file mode 100644 index 00000000..993fb837 --- /dev/null +++ b/ext/satori-uuid/uuid_test.go @@ -0,0 +1,97 @@ +package uuid_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/pgtype" + satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestUuidTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + &satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &satori.Uuid{Status: pgtype.Null}, + }) +} + +func TestUuidSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result satori.Uuid + }{ + { + source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: "00010203-0405-0607-0809-0a0b0c0d0e0f", + result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r satori.Uuid + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUuidAssignTo(t *testing.T) { + { + src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst [16]byte + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst []byte + expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare(dst, expected) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst string + expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + +} diff --git a/hstore.go b/hstore.go index 3d55f783..04df2acc 100644 --- a/hstore.go +++ b/hstore.go @@ -463,5 +463,5 @@ func (dst *Hstore) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Hstore) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/inet.go b/inet.go index 62734088..e3a7ec88 100644 --- a/inet.go +++ b/inet.go @@ -221,5 +221,5 @@ func (dst *Inet) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Inet) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/int4range.go b/int4range.go index 8b04cf3c..12a48dab 100644 --- a/int4range.go +++ b/int4range.go @@ -264,5 +264,5 @@ func (dst *Int4range) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Int4range) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/int8range.go b/int8range.go index f8e056cb..3541dbe2 100644 --- a/int8range.go +++ b/int8range.go @@ -264,5 +264,5 @@ func (dst *Int8range) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Int8range) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/interval.go b/interval.go index 1cbdffc3..050d5610 100644 --- a/interval.go +++ b/interval.go @@ -267,5 +267,5 @@ func (dst *Interval) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Interval) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/line.go b/line.go index 08a74e84..06f01f21 100644 --- a/line.go +++ b/line.go @@ -144,5 +144,5 @@ func (dst *Line) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Line) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/lseg.go b/lseg.go index b86256e0..986724cc 100644 --- a/lseg.go +++ b/lseg.go @@ -164,5 +164,5 @@ func (dst *Lseg) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Lseg) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/macaddr.go b/macaddr.go index cfbb513d..0fe092e4 100644 --- a/macaddr.go +++ b/macaddr.go @@ -150,5 +150,5 @@ func (dst *Macaddr) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Macaddr) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/numrange.go b/numrange.go index a1b5b184..b0baec9a 100644 --- a/numrange.go +++ b/numrange.go @@ -264,5 +264,5 @@ func (dst *Numrange) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Numrange) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/path.go b/path.go index fb4193d9..2fd6cfc7 100644 --- a/path.go +++ b/path.go @@ -203,5 +203,5 @@ func (dst *Path) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Path) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/point.go b/point.go index 788a76c9..3d51766e 100644 --- a/point.go +++ b/point.go @@ -138,5 +138,5 @@ func (dst *Point) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Point) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/polygon.go b/polygon.go index 1e2df011..af99ee3d 100644 --- a/polygon.go +++ b/polygon.go @@ -182,5 +182,5 @@ func (dst *Polygon) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Polygon) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/tid.go b/tid.go index f24c6244..7976afde 100644 --- a/tid.go +++ b/tid.go @@ -142,5 +142,5 @@ func (dst *Tid) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Tid) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/tsrange.go b/tsrange.go index 3bf5f5ca..78a94af2 100644 --- a/tsrange.go +++ b/tsrange.go @@ -264,5 +264,5 @@ func (dst *Tsrange) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Tsrange) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/tstzrange.go b/tstzrange.go index 8e80a8f9..d1fc7326 100644 --- a/tstzrange.go +++ b/tstzrange.go @@ -264,5 +264,5 @@ func (dst *Tstzrange) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Tstzrange) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/typed_range.go.erb b/typed_range.go.erb index 922b98b4..e46f71c7 100644 --- a/typed_range.go.erb +++ b/typed_range.go.erb @@ -264,5 +264,5 @@ func (dst *<%= range_type %>) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src <%= range_type %>) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/uuid.go b/uuid.go index 03029ffd..c830c086 100644 --- a/uuid.go +++ b/uuid.go @@ -169,5 +169,5 @@ func (dst *Uuid) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Uuid) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/varbit.go b/varbit.go index d28e95cd..00c34e10 100644 --- a/varbit.go +++ b/varbit.go @@ -137,5 +137,5 @@ func (dst *Varbit) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Varbit) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } From 851479b0d3c268d54644560ec626ff393a2cae41 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 17:11:39 -0500 Subject: [PATCH 0074/1158] Replace DATABASE_URL with PGX_TEST_DATABASE PGX_TEST_DATABASE is much less likely to collide with another environment variable. This is especially valuable when using direnv to automatically set environment variables. --- testutil/testutil.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/testutil/testutil.go b/testutil/testutil.go index d9aaa5c4..6bf9f878 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -26,7 +26,7 @@ func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { t.Fatalf("Unknown driver %v", driverName) } - db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL")) + db, err := sql.Open(sqlDriverName, os.Getenv("PGX_TEST_DATABASE")) if err != nil { t.Fatal(err) } @@ -35,7 +35,7 @@ func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { } func MustConnectPgx(t testing.TB) *pgx.Conn { - config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) + config, err := pgx.ParseURI(os.Getenv("PGX_TEST_DATABASE")) if err != nil { t.Fatal(err) } From fa68e44e5ffd91508aa3cad30468ccd810985293 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 17:21:32 -0500 Subject: [PATCH 0075/1158] Use pgx.ParseConnectionString in test helper This allows using URI or DSN for database connection information. DSN allows using unix domain sockets. --- testutil/testutil.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testutil/testutil.go b/testutil/testutil.go index 6bf9f878..5dd2fbe1 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -35,7 +35,7 @@ func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { } func MustConnectPgx(t testing.TB) *pgx.Conn { - config, err := pgx.ParseURI(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgx.ParseConnectionString(os.Getenv("PGX_TEST_DATABASE")) if err != nil { t.Fatal(err) } From 4e2900b774649cbc58686ac047bf84f54aff4e18 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 10:02:38 -0500 Subject: [PATCH 0076/1158] Introduce pgproto3 package pgproto3 will wrap the message encoding and decoding for the PostgreSQL frontend/backend protocol version 3. --- authentication.go | 54 +++++++++++ backend_key_data.go | 47 +++++++++ big_endian.go | 37 +++++++ bind_complete.go | 29 ++++++ close_complete.go | 29 ++++++ command_complete.go | 47 +++++++++ copy_both_response.go | 64 +++++++++++++ copy_data.go | 41 ++++++++ copy_in_response.go | 64 +++++++++++++ copy_out_response.go | 64 +++++++++++++ data_row.go | 103 ++++++++++++++++++++ empty_query_response.go | 29 ++++++ error_response.go | 197 ++++++++++++++++++++++++++++++++++++++ frontend.go | 70 ++++++++++++++ function_call_response.go | 73 ++++++++++++++ no_data.go | 29 ++++++ notice_response.go | 13 +++ notification_response.go | 65 +++++++++++++ parameter_description.go | 60 ++++++++++++ parameter_status.go | 62 ++++++++++++ parse_complete.go | 29 ++++++ pgproto3.go | 88 +++++++++++++++++ query.go | 43 +++++++++ ready_for_query.go | 35 +++++++ row_description.go | 101 +++++++++++++++++++ 25 files changed, 1473 insertions(+) create mode 100644 authentication.go create mode 100644 backend_key_data.go create mode 100644 big_endian.go create mode 100644 bind_complete.go create mode 100644 close_complete.go create mode 100644 command_complete.go create mode 100644 copy_both_response.go create mode 100644 copy_data.go create mode 100644 copy_in_response.go create mode 100644 copy_out_response.go create mode 100644 data_row.go create mode 100644 empty_query_response.go create mode 100644 error_response.go create mode 100644 frontend.go create mode 100644 function_call_response.go create mode 100644 no_data.go create mode 100644 notice_response.go create mode 100644 notification_response.go create mode 100644 parameter_description.go create mode 100644 parameter_status.go create mode 100644 parse_complete.go create mode 100644 pgproto3.go create mode 100644 query.go create mode 100644 ready_for_query.go create mode 100644 row_description.go diff --git a/authentication.go b/authentication.go new file mode 100644 index 00000000..e265a247 --- /dev/null +++ b/authentication.go @@ -0,0 +1,54 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +const ( + AuthTypeOk = 0 + AuthTypeCleartextPassword = 3 + AuthTypeMD5Password = 5 +) + +type Authentication struct { + Type uint32 + + // MD5Password fields + Salt [4]byte +} + +func (*Authentication) Backend() {} + +func (dst *Authentication) UnmarshalBinary(src []byte) error { + *dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])} + + switch dst.Type { + case AuthTypeOk: + case AuthTypeCleartextPassword: + case AuthTypeMD5Password: + copy(dst.Salt[:], src[4:8]) + default: + return fmt.Errorf("unknown authentication type: %d", dst.Type) + } + + return nil +} + +func (src *Authentication) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('R') + buf.Write(bigEndian.Uint32(0)) + buf.Write(bigEndian.Uint32(src.Type)) + + switch src.Type { + case AuthTypeMD5Password: + buf.Write(src.Salt[:]) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} diff --git a/backend_key_data.go b/backend_key_data.go new file mode 100644 index 00000000..5d8eb496 --- /dev/null +++ b/backend_key_data.go @@ -0,0 +1,47 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type BackendKeyData struct { + ProcessID uint32 + SecretKey uint32 +} + +func (*BackendKeyData) Backend() {} + +func (dst *BackendKeyData) UnmarshalBinary(src []byte) error { + if len(src) != 8 { + return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} + } + + dst.ProcessID = binary.BigEndian.Uint32(src[:4]) + dst.SecretKey = binary.BigEndian.Uint32(src[4:]) + + return nil +} + +func (src *BackendKeyData) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('K') + buf.Write(bigEndian.Uint32(12)) + buf.Write(bigEndian.Uint32(src.ProcessID)) + buf.Write(bigEndian.Uint32(src.SecretKey)) + return buf.Bytes(), nil +} + +func (src *BackendKeyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProcessID uint32 + SecretKey uint32 + }{ + Type: "BackendKeyData", + ProcessID: src.ProcessID, + SecretKey: src.SecretKey, + }) +} diff --git a/big_endian.go b/big_endian.go new file mode 100644 index 00000000..f7bdb97e --- /dev/null +++ b/big_endian.go @@ -0,0 +1,37 @@ +package pgproto3 + +import ( + "encoding/binary" +) + +type BigEndianBuf [8]byte + +func (b BigEndianBuf) Int16(n int16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, uint16(n)) + return buf +} + +func (b BigEndianBuf) Uint16(n uint16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, n) + return buf +} + +func (b BigEndianBuf) Int32(n int32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, uint32(n)) + return buf +} + +func (b BigEndianBuf) Uint32(n uint32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, n) + return buf +} + +func (b BigEndianBuf) Int64(n int64) []byte { + buf := b[0:8] + binary.BigEndian.PutUint64(buf, uint64(n)) + return buf +} diff --git a/bind_complete.go b/bind_complete.go new file mode 100644 index 00000000..756a30e6 --- /dev/null +++ b/bind_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type BindComplete struct{} + +func (*BindComplete) Backend() {} + +func (dst *BindComplete) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *BindComplete) MarshalBinary() ([]byte, error) { + return []byte{'2', 0, 0, 0, 4}, nil +} + +func (src *BindComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "BindComplete", + }) +} diff --git a/close_complete.go b/close_complete.go new file mode 100644 index 00000000..fd6ff180 --- /dev/null +++ b/close_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type CloseComplete struct{} + +func (*CloseComplete) Backend() {} + +func (dst *CloseComplete) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *CloseComplete) MarshalBinary() ([]byte, error) { + return []byte{'3', 0, 0, 0, 4}, nil +} + +func (src *CloseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "CloseComplete", + }) +} diff --git a/command_complete.go b/command_complete.go new file mode 100644 index 00000000..ac60153e --- /dev/null +++ b/command_complete.go @@ -0,0 +1,47 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" +) + +type CommandComplete struct { + CommandTag string +} + +func (*CommandComplete) Backend() {} + +func (dst *CommandComplete) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.CommandTag = string(b[:len(b)-1]) + + return nil +} + +func (src *CommandComplete) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('C') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.CommandTag) + 1))) + + buf.WriteString(src.CommandTag) + buf.WriteByte(0) + + return buf.Bytes(), nil +} + +func (src *CommandComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + CommandTag string + }{ + Type: "CommandComplete", + CommandTag: src.CommandTag, + }) +} diff --git a/copy_both_response.go b/copy_both_response.go new file mode 100644 index 00000000..2a4c58af --- /dev/null +++ b/copy_both_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type CopyBothResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyBothResponse) Backend() {} + +func (dst *CopyBothResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyBothResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('W') + buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) + + for _, fc := range src.ColumnFormatCodes { + buf.Write(bigEndian.Uint16(fc)) + } + + return buf.Bytes(), nil +} + +func (src *CopyBothResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyBothResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/copy_data.go b/copy_data.go new file mode 100644 index 00000000..b9ea6272 --- /dev/null +++ b/copy_data.go @@ -0,0 +1,41 @@ +package pgproto3 + +import ( + "bytes" + "encoding/hex" + "encoding/json" +) + +type CopyData struct { + Data []byte +} + +func (*CopyData) Backend() {} +func (*CopyData) Frontend() {} + +func (dst *CopyData) UnmarshalBinary(src []byte) error { + dst.Data = make([]byte, len(src)) + copy(dst.Data, src) + return nil +} + +func (src *CopyData) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('d') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.Data)))) + buf.Write(src.Data) + + return buf.Bytes(), nil +} + +func (src *CopyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "CopyData", + Data: hex.EncodeToString(src.Data), + }) +} diff --git a/copy_in_response.go b/copy_in_response.go new file mode 100644 index 00000000..63868c7a --- /dev/null +++ b/copy_in_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type CopyInResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyInResponse) Backend() {} + +func (dst *CopyInResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyInResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('G') + buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) + + for _, fc := range src.ColumnFormatCodes { + buf.Write(bigEndian.Uint16(fc)) + } + + return buf.Bytes(), nil +} + +func (src *CopyInResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyInResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/copy_out_response.go b/copy_out_response.go new file mode 100644 index 00000000..e46d9e8f --- /dev/null +++ b/copy_out_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type CopyOutResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyOutResponse) Backend() {} + +func (dst *CopyOutResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyOutResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('H') + buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) + + for _, fc := range src.ColumnFormatCodes { + buf.Write(bigEndian.Uint16(fc)) + } + + return buf.Bytes(), nil +} + +func (src *CopyOutResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyOutResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/data_row.go b/data_row.go new file mode 100644 index 00000000..c95861b9 --- /dev/null +++ b/data_row.go @@ -0,0 +1,103 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" +) + +type DataRow struct { + Values [][]byte +} + +func (*DataRow) Backend() {} + +func (dst *DataRow) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + dst.Values = make([][]byte, fieldCount) + + for i := 0; i < fieldCount; i++ { + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(buf.Next(4)))) + + // null + if msgSize == -1 { + continue + } + + value := make([]byte, msgSize) + _, err := buf.Read(value) + if err != nil { + return err + } + + dst.Values[i] = value + } + + return nil +} + +func (src *DataRow) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('D') + buf.Write(bigEndian.Uint32(0)) + + buf.Write(bigEndian.Uint16(uint16(len(src.Values)))) + + for _, v := range src.Values { + if v == nil { + buf.Write(bigEndian.Int32(-1)) + continue + } + + buf.Write(bigEndian.Int32(int32(len(v)))) + buf.Write(v) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *DataRow) MarshalJSON() ([]byte, error) { + formattedValues := make([]map[string]string, len(src.Values)) + for i, v := range src.Values { + if v == nil { + continue + } + + var hasNonPrintable bool + for _, b := range v { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)} + } else { + formattedValues[i] = map[string]string{"text": string(v)} + } + } + + return json.Marshal(struct { + Type string + Values []map[string]string + }{ + Type: "DataRow", + Values: formattedValues, + }) +} diff --git a/empty_query_response.go b/empty_query_response.go new file mode 100644 index 00000000..de6e6272 --- /dev/null +++ b/empty_query_response.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type EmptyQueryResponse struct{} + +func (*EmptyQueryResponse) Backend() {} + +func (dst *EmptyQueryResponse) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *EmptyQueryResponse) MarshalBinary() ([]byte, error) { + return []byte{'I', 0, 0, 0, 4}, nil +} + +func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "EmptyQueryResponse", + }) +} diff --git a/error_response.go b/error_response.go new file mode 100644 index 00000000..82e408d7 --- /dev/null +++ b/error_response.go @@ -0,0 +1,197 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "strconv" +) + +type ErrorResponse struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string +} + +func (*ErrorResponse) Backend() {} + +func (dst *ErrorResponse) UnmarshalBinary(src []byte) error { + *dst = ErrorResponse{} + + buf := bytes.NewBuffer(src) + + for { + k, err := buf.ReadByte() + if err != nil { + return err + } + if k == 0 { + break + } + + vb, err := buf.ReadBytes(0) + if err != nil { + return err + } + v := string(vb[:len(vb)-1]) + + switch k { + case 'S': + dst.Severity = v + case 'C': + dst.Code = v + case 'M': + dst.Message = v + case 'D': + dst.Detail = v + case 'H': + dst.Hint = v + case 'P': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Position = int32(n) + case 'p': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.InternalPosition = int32(n) + case 'q': + dst.InternalQuery = v + case 'W': + dst.Where = v + case 's': + dst.SchemaName = v + case 't': + dst.TableName = v + case 'c': + dst.ColumnName = v + case 'd': + dst.DataTypeName = v + case 'n': + dst.ConstraintName = v + case 'F': + dst.File = v + case 'L': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Line = int32(n) + case 'R': + dst.Routine = v + + default: + if dst.UnknownFields == nil { + dst.UnknownFields = make(map[byte]string) + } + dst.UnknownFields[k] = v + } + } + + return nil +} + +func (src *ErrorResponse) MarshalBinary() ([]byte, error) { + return src.marshalBinary('E') +} + +func (src *ErrorResponse) marshalBinary(typeByte byte) ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte(typeByte) + buf.Write(bigEndian.Uint32(0)) + + if src.Severity != "" { + buf.WriteString(src.Severity) + buf.WriteByte(0) + } + if src.Code != "" { + buf.WriteString(src.Code) + buf.WriteByte(0) + } + if src.Message != "" { + buf.WriteString(src.Message) + buf.WriteByte(0) + } + if src.Detail != "" { + buf.WriteString(src.Detail) + buf.WriteByte(0) + } + if src.Hint != "" { + buf.WriteString(src.Hint) + buf.WriteByte(0) + } + if src.Position != 0 { + buf.WriteString(strconv.Itoa(int(src.Position))) + buf.WriteByte(0) + } + if src.InternalPosition != 0 { + buf.WriteString(strconv.Itoa(int(src.InternalPosition))) + buf.WriteByte(0) + } + if src.InternalQuery != "" { + buf.WriteString(src.InternalQuery) + buf.WriteByte(0) + } + if src.Where != "" { + buf.WriteString(src.Where) + buf.WriteByte(0) + } + if src.SchemaName != "" { + buf.WriteString(src.SchemaName) + buf.WriteByte(0) + } + if src.TableName != "" { + buf.WriteString(src.TableName) + buf.WriteByte(0) + } + if src.ColumnName != "" { + buf.WriteString(src.ColumnName) + buf.WriteByte(0) + } + if src.DataTypeName != "" { + buf.WriteString(src.DataTypeName) + buf.WriteByte(0) + } + if src.ConstraintName != "" { + buf.WriteString(src.ConstraintName) + buf.WriteByte(0) + } + if src.File != "" { + buf.WriteString(src.File) + buf.WriteByte(0) + } + if src.Line != 0 { + buf.WriteString(strconv.Itoa(int(src.Line))) + buf.WriteByte(0) + } + if src.Routine != "" { + buf.WriteString(src.Routine) + buf.WriteByte(0) + } + + for k, v := range src.UnknownFields { + buf.WriteByte(k) + buf.WriteByte(0) + buf.WriteString(v) + buf.WriteByte(0) + } + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} diff --git a/frontend.go b/frontend.go new file mode 100644 index 00000000..c1dec461 --- /dev/null +++ b/frontend.go @@ -0,0 +1,70 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/jackc/pgx/chunkreader" +) + +type Frontend struct { + cr *chunkreader.ChunkReader + w io.Writer +} + +func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { + cr := chunkreader.NewChunkReader(r) + return &Frontend{cr: cr, w: w}, nil +} + +func (b *Frontend) Send(msg FrontendMessage) error { + return errors.New("not implemented") +} + +func (b *Frontend) Receive() (BackendMessage, error) { + backendMessages := map[byte]BackendMessage{ + '1': &ParseComplete{}, + '2': &BindComplete{}, + '3': &CloseComplete{}, + 'A': &NotificationResponse{}, + 'C': &CommandComplete{}, + 'd': &CopyData{}, + 'D': &DataRow{}, + 'E': &ErrorResponse{}, + 'G': &CopyInResponse{}, + 'H': &CopyOutResponse{}, + 'I': &EmptyQueryResponse{}, + 'K': &BackendKeyData{}, + 'n': &NoData{}, + 'N': &NoticeResponse{}, + 'R': &Authentication{}, + 'S': &ParameterStatus{}, + 't': &ParameterDescription{}, + 'T': &RowDescription{}, + 'V': &FunctionCallResponse{}, + 'W': &CopyBothResponse{}, + 'Z': &ReadyForQuery{}, + } + + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + msgType := header[0] + bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 + + msgBody, err := b.cr.Next(bodyLen) + if err != nil { + return nil, err + } + + if msg, ok := backendMessages[msgType]; ok { + err = msg.UnmarshalBinary(msgBody) + return msg, err + } + + return nil, fmt.Errorf("unknown message type: %c", msgType) +} diff --git a/function_call_response.go b/function_call_response.go new file mode 100644 index 00000000..5c692b36 --- /dev/null +++ b/function_call_response.go @@ -0,0 +1,73 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" +) + +type FunctionCallResponse struct { + Result []byte +} + +func (*FunctionCallResponse) Backend() {} + +func (dst *FunctionCallResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + resultSize := int(binary.BigEndian.Uint32(buf.Next(4))) + if buf.Len() != resultSize { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + + dst.Result = make([]byte, resultSize) + copy(dst.Result, buf.Bytes()) + + return nil +} + +func (src *FunctionCallResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('V') + buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Result)))) + + if src.Result == nil { + buf.Write(bigEndian.Int32(-1)) + } else { + buf.Write(bigEndian.Int32(int32(len(src.Result)))) + buf.Write(src.Result) + } + + return buf.Bytes(), nil +} + +func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) { + var formattedValue map[string]string + var hasNonPrintable bool + for _, b := range src.Result { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)} + } else { + formattedValue = map[string]string{"text": string(src.Result)} + } + + return json.Marshal(struct { + Type string + Result map[string]string + }{ + Type: "FunctionCallResponse", + Result: formattedValue, + }) +} diff --git a/no_data.go b/no_data.go new file mode 100644 index 00000000..47ebf28e --- /dev/null +++ b/no_data.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type NoData struct{} + +func (*NoData) Backend() {} + +func (dst *NoData) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *NoData) MarshalBinary() ([]byte, error) { + return []byte{'n', 0, 0, 0, 4}, nil +} + +func (src *NoData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "NoData", + }) +} diff --git a/notice_response.go b/notice_response.go new file mode 100644 index 00000000..767c9a67 --- /dev/null +++ b/notice_response.go @@ -0,0 +1,13 @@ +package pgproto3 + +type NoticeResponse ErrorResponse + +func (*NoticeResponse) Backend() {} + +func (dst *NoticeResponse) UnmarshalBinary(src []byte) error { + return (*ErrorResponse)(dst).UnmarshalBinary(src) +} + +func (src *NoticeResponse) MarshalBinary() ([]byte, error) { + return (*ErrorResponse)(src).marshalBinary('N') +} diff --git a/notification_response.go b/notification_response.go new file mode 100644 index 00000000..4ae8bab3 --- /dev/null +++ b/notification_response.go @@ -0,0 +1,65 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type NotificationResponse struct { + PID uint32 + Channel string + Payload string +} + +func (*NotificationResponse) Backend() {} + +func (dst *NotificationResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + pid := binary.BigEndian.Uint32(buf.Next(4)) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + channel := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + payload := string(b[:len(b)-1]) + + *dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload} + return nil +} + +func (src *NotificationResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('A') + buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Channel) + len(src.Payload)))) + + buf.WriteString(src.Channel) + buf.WriteByte(0) + buf.WriteString(src.Payload) + buf.WriteByte(0) + + return buf.Bytes(), nil +} + +func (src *NotificationResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + PID uint32 + Channel string + Payload string + }{ + Type: "NotificationResponse", + PID: src.PID, + Channel: src.Channel, + Payload: src.Payload, + }) +} diff --git a/parameter_description.go b/parameter_description.go new file mode 100644 index 00000000..40d92c50 --- /dev/null +++ b/parameter_description.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type ParameterDescription struct { + ParameterOIDs []uint32 +} + +func (*ParameterDescription) Backend() {} + +func (dst *ParameterDescription) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "ParameterDescription"} + } + + // Reported parameter count will be incorrect when number of args is greater than uint16 + buf.Next(2) + // Instead infer parameter count by remaining size of message + parameterCount := buf.Len() / 4 + + *dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)} + + for i := 0; i < parameterCount; i++ { + dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4)) + } + + return nil +} + +func (src *ParameterDescription) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('t') + buf.Write(bigEndian.Uint32(uint32(4 + 2 + 4*len(src.ParameterOIDs)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs)))) + + for _, oid := range src.ParameterOIDs { + buf.Write(bigEndian.Uint32(oid)) + } + + return buf.Bytes(), nil +} + +func (src *ParameterDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ParameterOIDs []uint32 + }{ + Type: "ParameterDescription", + ParameterOIDs: src.ParameterOIDs, + }) +} diff --git a/parameter_status.go b/parameter_status.go new file mode 100644 index 00000000..b8ce7f8d --- /dev/null +++ b/parameter_status.go @@ -0,0 +1,62 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type ParameterStatus struct { + Name string + Value string +} + +func (*ParameterStatus) Backend() {} + +func (dst *ParameterStatus) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + name := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + value := string(b[:len(b)-1]) + + *dst = ParameterStatus{Name: name, Value: value} + return nil +} + +func (src *ParameterStatus) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('S') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteString(src.Name) + buf.WriteByte(0) + buf.WriteString(src.Value) + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (ps *ParameterStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Value string + }{ + Type: "ParameterStatus", + Name: ps.Name, + Value: ps.Value, + }) +} diff --git a/parse_complete.go b/parse_complete.go new file mode 100644 index 00000000..24951e3d --- /dev/null +++ b/parse_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type ParseComplete struct{} + +func (*ParseComplete) Backend() {} + +func (dst *ParseComplete) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *ParseComplete) MarshalBinary() ([]byte, error) { + return []byte{'1', 0, 0, 0, 4}, nil +} + +func (src *ParseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "ParseComplete", + }) +} diff --git a/pgproto3.go b/pgproto3.go new file mode 100644 index 00000000..a9221239 --- /dev/null +++ b/pgproto3.go @@ -0,0 +1,88 @@ +package pgproto3 + +import "fmt" + +type Message interface { + UnmarshalBinary(data []byte) error + MarshalBinary() (data []byte, err error) +} + +type FrontendMessage interface { + Message + Frontend() // no-op method to distinguish frontend from backend methods +} + +type BackendMessage interface { + Message + Backend() // no-op method to distinguish frontend from backend methods +} + +// func ParseBackend(typeByte byte, body []byte) (BackendMessage, error) { +// switch typeByte { +// case '1': +// return ParseParseComplete(body) +// case '2': +// return ParseBindComplete(body) +// case 'C': +// return ParseCommandComplete(body) +// case 'D': +// return ParseDataRow(body) +// case 'E': +// return ParseErrorResponse(body) +// case 'K': +// return ParseBackendKeyData(body) +// case 'R': +// return ParseAuthentication(body) +// case 'S': +// return ParseParameterStatus(body) +// case 'T': +// return ParseRowDescription(body) +// case 't': +// return ParseParameterDescription(body) +// case 'Z': +// return ParseReadyForQuery(body) +// default: +// return ParseUnknownMessage(typeByte, body) +// } +// } + +// func ParseFrontend(typeByte byte, body []byte) (FrontendMessage, error) { +// switch typeByte { +// case 'B': +// return ParseBind(body) +// case 'D': +// return ParseDescribe(body) +// case 'E': +// return ParseExecute(body) +// case 'P': +// return ParseParse(body) +// case 'p': +// return ParsePasswordMessage(body) +// case 'Q': +// return ParseQuery(body) +// case 'S': +// return ParseSync(body) +// case 'X': +// return ParseTerminate(body) +// default: +// return ParseUnknownMessage(typeByte, body) +// } +// } + +type invalidMessageLenErr struct { + messageType string + expectedLen int + actualLen int +} + +func (e *invalidMessageLenErr) Error() string { + return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen) +} + +type invalidMessageFormatErr struct { + messageType string +} + +func (e *invalidMessageFormatErr) Error() string { + return fmt.Sprintf("%s body is invalid", e.messageType) +} diff --git a/query.go b/query.go new file mode 100644 index 00000000..a3fc32eb --- /dev/null +++ b/query.go @@ -0,0 +1,43 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" +) + +type Query struct { + String string +} + +func (*Query) Frontend() {} + +func (dst *Query) UnmarshalBinary(src []byte) error { + i := bytes.IndexByte(src, 0) + if i != len(src)-1 { + return &invalidMessageFormatErr{messageType: "Query"} + } + + dst.String = string(src[:i]) + + return nil +} + +func (src *Query) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('Q') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.String) + 1))) + buf.WriteString(src.String) + buf.WriteByte(0) + return buf.Bytes(), nil +} + +func (src *Query) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + String string + }{ + Type: "Query", + String: src.String, + }) +} diff --git a/ready_for_query.go b/ready_for_query.go new file mode 100644 index 00000000..09005d00 --- /dev/null +++ b/ready_for_query.go @@ -0,0 +1,35 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type ReadyForQuery struct { + TxStatus byte +} + +func (*ReadyForQuery) Backend() {} + +func (dst *ReadyForQuery) UnmarshalBinary(src []byte) error { + if len(src) != 1 { + return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} + } + + dst.TxStatus = src[0] + + return nil +} + +func (src *ReadyForQuery) MarshalBinary() ([]byte, error) { + return []byte{'Z', 0, 0, 0, 5, src.TxStatus}, nil +} + +func (src *ReadyForQuery) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + TxStatus string + }{ + Type: "ReadyForQuery", + TxStatus: string(src.TxStatus), + }) +} diff --git a/row_description.go b/row_description.go new file mode 100644 index 00000000..294a6aa9 --- /dev/null +++ b/row_description.go @@ -0,0 +1,101 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +const ( + TextFormat = 0 + BinaryFormat = 1 +) + +type FieldDescription struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier uint32 + Format int16 +} + +type RowDescription struct { + Fields []FieldDescription +} + +func (*RowDescription) Backend() {} + +func (dst *RowDescription) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + *dst = RowDescription{Fields: make([]FieldDescription, fieldCount)} + + for i := 0; i < fieldCount; i++ { + var fd FieldDescription + bName, err := buf.ReadBytes(0) + if err != nil { + return err + } + fd.Name = string(bName[:len(bName)-1]) + + // Since buf.Next() doesn't return an error if we hit the end of the buffer + // check Len ahead of time + if buf.Len() < 18 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + + fd.TableOID = binary.BigEndian.Uint32(buf.Next(4)) + fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2)) + fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4)) + fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2))) + fd.TypeModifier = binary.BigEndian.Uint32(buf.Next(4)) + fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2))) + + dst.Fields[i] = fd + } + + return nil +} + +func (src *RowDescription) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('T') + buf.Write(bigEndian.Uint32(0)) + + buf.Write(bigEndian.Uint16(uint16(len(src.Fields)))) + + for _, fd := range src.Fields { + buf.WriteString(fd.Name) + buf.WriteByte(0) + + buf.Write(bigEndian.Uint32(fd.TableOID)) + buf.Write(bigEndian.Uint16(fd.TableAttributeNumber)) + buf.Write(bigEndian.Uint32(fd.DataTypeOID)) + buf.Write(bigEndian.Uint16(uint16(fd.DataTypeSize))) + buf.Write(bigEndian.Uint32(fd.TypeModifier)) + buf.Write(bigEndian.Uint16(uint16(fd.Format))) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *RowDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Fields []FieldDescription + }{ + Type: "RowDescription", + Fields: src.Fields, + }) +} From de9bb7e6d8ef0133ee19ce46985e140a863b163f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 11:01:54 -0500 Subject: [PATCH 0077/1158] Use flyweight pattern for pgproto3 messages --- frontend.go | 103 +++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 73 insertions(+), 30 deletions(-) diff --git a/frontend.go b/frontend.go index c1dec461..df67b718 100644 --- a/frontend.go +++ b/frontend.go @@ -12,6 +12,29 @@ import ( type Frontend struct { cr *chunkreader.ChunkReader w io.Writer + + // Backend message flyweights + authentication Authentication + backendKeyData BackendKeyData + bindComplete BindComplete + closeComplete CloseComplete + commandComplete CommandComplete + copyBothResponse CopyBothResponse + copyData CopyData + copyInResponse CopyInResponse + copyOutResponse CopyOutResponse + dataRow DataRow + emptyQueryResponse EmptyQueryResponse + errorResponse ErrorResponse + functionCallResponse FunctionCallResponse + noData NoData + noticeResponse NoticeResponse + notificationResponse NotificationResponse + parameterDescription ParameterDescription + parameterStatus ParameterStatus + parseComplete ParseComplete + readyForQuery ReadyForQuery + rowDescription RowDescription } func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { @@ -24,30 +47,6 @@ func (b *Frontend) Send(msg FrontendMessage) error { } func (b *Frontend) Receive() (BackendMessage, error) { - backendMessages := map[byte]BackendMessage{ - '1': &ParseComplete{}, - '2': &BindComplete{}, - '3': &CloseComplete{}, - 'A': &NotificationResponse{}, - 'C': &CommandComplete{}, - 'd': &CopyData{}, - 'D': &DataRow{}, - 'E': &ErrorResponse{}, - 'G': &CopyInResponse{}, - 'H': &CopyOutResponse{}, - 'I': &EmptyQueryResponse{}, - 'K': &BackendKeyData{}, - 'n': &NoData{}, - 'N': &NoticeResponse{}, - 'R': &Authentication{}, - 'S': &ParameterStatus{}, - 't': &ParameterDescription{}, - 'T': &RowDescription{}, - 'V': &FunctionCallResponse{}, - 'W': &CopyBothResponse{}, - 'Z': &ReadyForQuery{}, - } - header, err := b.cr.Next(5) if err != nil { return nil, err @@ -56,15 +55,59 @@ func (b *Frontend) Receive() (BackendMessage, error) { msgType := header[0] bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 + var msg BackendMessage + switch msgType { + case '1': + msg = &b.parseComplete + case '2': + msg = &b.bindComplete + case '3': + msg = &b.closeComplete + case 'A': + msg = &b.notificationResponse + case 'C': + msg = &b.commandComplete + case 'd': + msg = &b.copyData + case 'D': + msg = &b.dataRow + case 'E': + msg = &b.errorResponse + case 'G': + msg = &b.copyInResponse + case 'H': + msg = &b.copyOutResponse + case 'I': + msg = &b.emptyQueryResponse + case 'K': + msg = &b.backendKeyData + case 'n': + msg = &b.noData + case 'N': + msg = &b.noticeResponse + case 'R': + msg = &b.authentication + case 'S': + msg = &b.parameterStatus + case 't': + msg = &b.parameterDescription + case 'T': + msg = &b.rowDescription + case 'V': + msg = &b.functionCallResponse + case 'W': + msg = &b.copyBothResponse + case 'Z': + msg = &b.readyForQuery + default: + return nil, fmt.Errorf("unknown message type: %c", msgType) + } + msgBody, err := b.cr.Next(bodyLen) if err != nil { return nil, err } - if msg, ok := backendMessages[msgType]; ok { - err = msg.UnmarshalBinary(msgBody) - return msg, err - } - - return nil, fmt.Errorf("unknown message type: %c", msgType) + err = msg.UnmarshalBinary(msgBody) + return msg, err } From eff55451cfd12e893e19a0120edd519be33a6581 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 11:55:14 -0500 Subject: [PATCH 0078/1158] Reduce allocations and copies in pgproto3 Altered chunkreader to never reuse memory. Altered pgproto3 to to copy memory when decoding. Renamed UnmarshalBinary to Decode because of changed semantics. --- authentication.go | 2 +- backend_key_data.go | 2 +- bind_complete.go | 2 +- close_complete.go | 2 +- command_complete.go | 13 ++++----- copy_both_response.go | 2 +- copy_data.go | 5 ++-- copy_in_response.go | 2 +- copy_out_response.go | 2 +- data_row.go | 39 +++++++++++++++----------- empty_query_response.go | 2 +- error_response.go | 2 +- frontend.go | 2 +- function_call_response.go | 22 +++++++++------ no_data.go | 2 +- notice_response.go | 4 +-- notification_response.go | 2 +- parameter_description.go | 2 +- parameter_status.go | 2 +- parse_complete.go | 2 +- pgproto3.go | 59 ++++----------------------------------- query.go | 2 +- ready_for_query.go | 2 +- row_description.go | 2 +- 24 files changed, 70 insertions(+), 108 deletions(-) diff --git a/authentication.go b/authentication.go index e265a247..54f4978f 100644 --- a/authentication.go +++ b/authentication.go @@ -21,7 +21,7 @@ type Authentication struct { func (*Authentication) Backend() {} -func (dst *Authentication) UnmarshalBinary(src []byte) error { +func (dst *Authentication) Decode(src []byte) error { *dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])} switch dst.Type { diff --git a/backend_key_data.go b/backend_key_data.go index 5d8eb496..04f31aec 100644 --- a/backend_key_data.go +++ b/backend_key_data.go @@ -13,7 +13,7 @@ type BackendKeyData struct { func (*BackendKeyData) Backend() {} -func (dst *BackendKeyData) UnmarshalBinary(src []byte) error { +func (dst *BackendKeyData) Decode(src []byte) error { if len(src) != 8 { return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} } diff --git a/bind_complete.go b/bind_complete.go index 756a30e6..4f1c44b8 100644 --- a/bind_complete.go +++ b/bind_complete.go @@ -8,7 +8,7 @@ type BindComplete struct{} func (*BindComplete) Backend() {} -func (dst *BindComplete) UnmarshalBinary(src []byte) error { +func (dst *BindComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} } diff --git a/close_complete.go b/close_complete.go index fd6ff180..9bab3e8c 100644 --- a/close_complete.go +++ b/close_complete.go @@ -8,7 +8,7 @@ type CloseComplete struct{} func (*CloseComplete) Backend() {} -func (dst *CloseComplete) UnmarshalBinary(src []byte) error { +func (dst *CloseComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} } diff --git a/command_complete.go b/command_complete.go index ac60153e..86653804 100644 --- a/command_complete.go +++ b/command_complete.go @@ -11,14 +11,13 @@ type CommandComplete struct { func (*CommandComplete) Backend() {} -func (dst *CommandComplete) UnmarshalBinary(src []byte) error { - buf := bytes.NewBuffer(src) - - b, err := buf.ReadBytes(0) - if err != nil { - return err +func (dst *CommandComplete) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx != len(src)-1 { + return &invalidMessageFormatErr{messageType: "CommandComplete"} } - dst.CommandTag = string(b[:len(b)-1]) + + dst.CommandTag = string(src[:idx]) return nil } diff --git a/copy_both_response.go b/copy_both_response.go index 2a4c58af..3857c187 100644 --- a/copy_both_response.go +++ b/copy_both_response.go @@ -13,7 +13,7 @@ type CopyBothResponse struct { func (*CopyBothResponse) Backend() {} -func (dst *CopyBothResponse) UnmarshalBinary(src []byte) error { +func (dst *CopyBothResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 3 { diff --git a/copy_data.go b/copy_data.go index b9ea6272..de7ab4ff 100644 --- a/copy_data.go +++ b/copy_data.go @@ -13,9 +13,8 @@ type CopyData struct { func (*CopyData) Backend() {} func (*CopyData) Frontend() {} -func (dst *CopyData) UnmarshalBinary(src []byte) error { - dst.Data = make([]byte, len(src)) - copy(dst.Data, src) +func (dst *CopyData) Decode(src []byte) error { + dst.Data = src return nil } diff --git a/copy_in_response.go b/copy_in_response.go index 63868c7a..9854d665 100644 --- a/copy_in_response.go +++ b/copy_in_response.go @@ -13,7 +13,7 @@ type CopyInResponse struct { func (*CopyInResponse) Backend() {} -func (dst *CopyInResponse) UnmarshalBinary(src []byte) error { +func (dst *CopyInResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 3 { diff --git a/copy_out_response.go b/copy_out_response.go index e46d9e8f..5ef6e4c1 100644 --- a/copy_out_response.go +++ b/copy_out_response.go @@ -13,7 +13,7 @@ type CopyOutResponse struct { func (*CopyOutResponse) Backend() {} -func (dst *CopyOutResponse) UnmarshalBinary(src []byte) error { +func (dst *CopyOutResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 3 { diff --git a/data_row.go b/data_row.go index c95861b9..6b27f728 100644 --- a/data_row.go +++ b/data_row.go @@ -13,35 +13,42 @@ type DataRow struct { func (*DataRow) Backend() {} -func (dst *DataRow) UnmarshalBinary(src []byte) error { - buf := bytes.NewBuffer(src) - - if buf.Len() < 2 { +func (dst *DataRow) Decode(src []byte) error { + if len(src) < 2 { return &invalidMessageFormatErr{messageType: "DataRow"} } - fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + rp := 0 + fieldCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 - dst.Values = make([][]byte, fieldCount) + // If the capacity of the values slice is too small OR substantially too + // large reallocate. This is too avoid one row with many columns from + // permanently allocating memory. + if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 { + dst.Values = make([][]byte, fieldCount, 32) + } else { + dst.Values = dst.Values[:fieldCount] + } for i := 0; i < fieldCount; i++ { - if buf.Len() < 4 { + if len(src[rp:]) < 4 { return &invalidMessageFormatErr{messageType: "DataRow"} } - msgSize := int(int32(binary.BigEndian.Uint32(buf.Next(4)))) + msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 // null if msgSize == -1 { - continue - } + dst.Values[i] = nil + } else { + if len(src[rp:]) < msgSize { + return &invalidMessageFormatErr{messageType: "DataRow"} + } - value := make([]byte, msgSize) - _, err := buf.Read(value) - if err != nil { - return err + dst.Values[i] = src[rp : rp+msgSize] + rp += msgSize } - - dst.Values[i] = value } return nil diff --git a/empty_query_response.go b/empty_query_response.go index de6e6272..13ed1886 100644 --- a/empty_query_response.go +++ b/empty_query_response.go @@ -8,7 +8,7 @@ type EmptyQueryResponse struct{} func (*EmptyQueryResponse) Backend() {} -func (dst *EmptyQueryResponse) UnmarshalBinary(src []byte) error { +func (dst *EmptyQueryResponse) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} } diff --git a/error_response.go b/error_response.go index 82e408d7..602dd2a1 100644 --- a/error_response.go +++ b/error_response.go @@ -30,7 +30,7 @@ type ErrorResponse struct { func (*ErrorResponse) Backend() {} -func (dst *ErrorResponse) UnmarshalBinary(src []byte) error { +func (dst *ErrorResponse) Decode(src []byte) error { *dst = ErrorResponse{} buf := bytes.NewBuffer(src) diff --git a/frontend.go b/frontend.go index df67b718..50835836 100644 --- a/frontend.go +++ b/frontend.go @@ -108,6 +108,6 @@ func (b *Frontend) Receive() (BackendMessage, error) { return nil, err } - err = msg.UnmarshalBinary(msgBody) + err = msg.Decode(msgBody) return msg, err } diff --git a/function_call_response.go b/function_call_response.go index 5c692b36..1e0f16af 100644 --- a/function_call_response.go +++ b/function_call_response.go @@ -13,20 +13,24 @@ type FunctionCallResponse struct { func (*FunctionCallResponse) Backend() {} -func (dst *FunctionCallResponse) UnmarshalBinary(src []byte) error { - buf := bytes.NewBuffer(src) - - if buf.Len() < 4 { +func (dst *FunctionCallResponse) Decode(src []byte) error { + if len(src) < 4 { return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} } - resultSize := int(binary.BigEndian.Uint32(buf.Next(4))) - if buf.Len() != resultSize { + rp := 0 + resultSize := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + if resultSize == -1 { + dst.Result = nil + return nil + } + + if len(src[rp:]) != resultSize { return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} } - dst.Result = make([]byte, resultSize) - copy(dst.Result, buf.Bytes()) - + dst.Result = src[rp:] return nil } diff --git a/no_data.go b/no_data.go index 47ebf28e..3adec4ad 100644 --- a/no_data.go +++ b/no_data.go @@ -8,7 +8,7 @@ type NoData struct{} func (*NoData) Backend() {} -func (dst *NoData) UnmarshalBinary(src []byte) error { +func (dst *NoData) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} } diff --git a/notice_response.go b/notice_response.go index 767c9a67..8af55baf 100644 --- a/notice_response.go +++ b/notice_response.go @@ -4,8 +4,8 @@ type NoticeResponse ErrorResponse func (*NoticeResponse) Backend() {} -func (dst *NoticeResponse) UnmarshalBinary(src []byte) error { - return (*ErrorResponse)(dst).UnmarshalBinary(src) +func (dst *NoticeResponse) Decode(src []byte) error { + return (*ErrorResponse)(dst).Decode(src) } func (src *NoticeResponse) MarshalBinary() ([]byte, error) { diff --git a/notification_response.go b/notification_response.go index 4ae8bab3..7262844e 100644 --- a/notification_response.go +++ b/notification_response.go @@ -14,7 +14,7 @@ type NotificationResponse struct { func (*NotificationResponse) Backend() {} -func (dst *NotificationResponse) UnmarshalBinary(src []byte) error { +func (dst *NotificationResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) pid := binary.BigEndian.Uint32(buf.Next(4)) diff --git a/parameter_description.go b/parameter_description.go index 40d92c50..32b6e1c1 100644 --- a/parameter_description.go +++ b/parameter_description.go @@ -12,7 +12,7 @@ type ParameterDescription struct { func (*ParameterDescription) Backend() {} -func (dst *ParameterDescription) UnmarshalBinary(src []byte) error { +func (dst *ParameterDescription) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 2 { diff --git a/parameter_status.go b/parameter_status.go index b8ce7f8d..9b10824c 100644 --- a/parameter_status.go +++ b/parameter_status.go @@ -13,7 +13,7 @@ type ParameterStatus struct { func (*ParameterStatus) Backend() {} -func (dst *ParameterStatus) UnmarshalBinary(src []byte) error { +func (dst *ParameterStatus) Decode(src []byte) error { buf := bytes.NewBuffer(src) b, err := buf.ReadBytes(0) diff --git a/parse_complete.go b/parse_complete.go index 24951e3d..e949c14c 100644 --- a/parse_complete.go +++ b/parse_complete.go @@ -8,7 +8,7 @@ type ParseComplete struct{} func (*ParseComplete) Backend() {} -func (dst *ParseComplete) UnmarshalBinary(src []byte) error { +func (dst *ParseComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} } diff --git a/pgproto3.go b/pgproto3.go index a9221239..3fe8fc93 100644 --- a/pgproto3.go +++ b/pgproto3.go @@ -2,8 +2,13 @@ package pgproto3 import "fmt" +// Message is the interface implemented by an object that can decode and encode +// a particular PostgreSQL message. +// +// Decode is allowed and expected to retain a reference to data after +// returning (unlike encoding.BinaryUnmarshaler). type Message interface { - UnmarshalBinary(data []byte) error + Decode(data []byte) error MarshalBinary() (data []byte, err error) } @@ -17,58 +22,6 @@ type BackendMessage interface { Backend() // no-op method to distinguish frontend from backend methods } -// func ParseBackend(typeByte byte, body []byte) (BackendMessage, error) { -// switch typeByte { -// case '1': -// return ParseParseComplete(body) -// case '2': -// return ParseBindComplete(body) -// case 'C': -// return ParseCommandComplete(body) -// case 'D': -// return ParseDataRow(body) -// case 'E': -// return ParseErrorResponse(body) -// case 'K': -// return ParseBackendKeyData(body) -// case 'R': -// return ParseAuthentication(body) -// case 'S': -// return ParseParameterStatus(body) -// case 'T': -// return ParseRowDescription(body) -// case 't': -// return ParseParameterDescription(body) -// case 'Z': -// return ParseReadyForQuery(body) -// default: -// return ParseUnknownMessage(typeByte, body) -// } -// } - -// func ParseFrontend(typeByte byte, body []byte) (FrontendMessage, error) { -// switch typeByte { -// case 'B': -// return ParseBind(body) -// case 'D': -// return ParseDescribe(body) -// case 'E': -// return ParseExecute(body) -// case 'P': -// return ParseParse(body) -// case 'p': -// return ParsePasswordMessage(body) -// case 'Q': -// return ParseQuery(body) -// case 'S': -// return ParseSync(body) -// case 'X': -// return ParseTerminate(body) -// default: -// return ParseUnknownMessage(typeByte, body) -// } -// } - type invalidMessageLenErr struct { messageType string expectedLen int diff --git a/query.go b/query.go index a3fc32eb..b5fc2dbc 100644 --- a/query.go +++ b/query.go @@ -11,7 +11,7 @@ type Query struct { func (*Query) Frontend() {} -func (dst *Query) UnmarshalBinary(src []byte) error { +func (dst *Query) Decode(src []byte) error { i := bytes.IndexByte(src, 0) if i != len(src)-1 { return &invalidMessageFormatErr{messageType: "Query"} diff --git a/ready_for_query.go b/ready_for_query.go index 09005d00..e0e4707a 100644 --- a/ready_for_query.go +++ b/ready_for_query.go @@ -10,7 +10,7 @@ type ReadyForQuery struct { func (*ReadyForQuery) Backend() {} -func (dst *ReadyForQuery) UnmarshalBinary(src []byte) error { +func (dst *ReadyForQuery) Decode(src []byte) error { if len(src) != 1 { return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} } diff --git a/row_description.go b/row_description.go index 294a6aa9..b1110290 100644 --- a/row_description.go +++ b/row_description.go @@ -27,7 +27,7 @@ type RowDescription struct { func (*RowDescription) Backend() {} -func (dst *RowDescription) UnmarshalBinary(src []byte) error { +func (dst *RowDescription) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 2 { From 61026b7c21bc88079489858a2d9c7c7f708a0348 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 11:55:14 -0500 Subject: [PATCH 0079/1158] Reduce allocations and copies in pgproto3 Altered chunkreader to never reuse memory. Altered pgproto3 to to copy memory when decoding. Renamed UnmarshalBinary to Decode because of changed semantics. --- chunkreader.go | 25 ++++--------------------- chunkreader_test.go | 44 ++++++-------------------------------------- 2 files changed, 10 insertions(+), 59 deletions(-) diff --git a/chunkreader.go b/chunkreader.go index f9d6555c..f8d437b2 100644 --- a/chunkreader.go +++ b/chunkreader.go @@ -9,14 +9,12 @@ type ChunkReader struct { buf []byte rp, wp int // buf read position and write position - taken bool options Options } type Options struct { MinBufLen int // Minimum buffer length - BlockLen int // Increments to expand buffer (e.g. a 8000 byte request with a BlockLen of 1024 would yield a buffer len of 8192) } func NewChunkReader(r io.Reader) *ChunkReader { @@ -32,9 +30,6 @@ func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { if options.MinBufLen == 0 { options.MinBufLen = 4096 } - if options.BlockLen == 0 { - options.BlockLen = 512 - } return &ChunkReader{ r: r, @@ -43,8 +38,8 @@ func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { }, nil } -// Next returns buf filled with the next n bytes. buf is only valid until the -// next call to Next. If an error occurs, buf will be nil. +// Next returns buf filled with the next n bytes. If an error occurs, buf will +// be nil. func (r *ChunkReader) Next(n int) (buf []byte, err error) { // n bytes already in buf if (r.wp - r.rp) >= n { @@ -56,17 +51,12 @@ func (r *ChunkReader) Next(n int) (buf []byte, err error) { // available space in buf is less than n if len(r.buf) < n { r.copyBufContents(r.newBuf(n)) - r.taken = false } // buf is large enough, but need to shift filled area to start to make enough contiguous space minReadCount := n - (r.wp - r.rp) if (len(r.buf) - r.wp) < minReadCount { - newBuf := r.buf - if r.taken { - newBuf = r.newBuf(n) - r.taken = false - } + newBuf := r.newBuf(n) r.copyBufContents(newBuf) } @@ -79,20 +69,13 @@ func (r *ChunkReader) Next(n int) (buf []byte, err error) { return buf, nil } -// KeepLast prevents the last data retrieved by Next from being reused by the -// ChunkReader. -func (r *ChunkReader) KeepLast() { - r.taken = true -} - func (r *ChunkReader) appendAtLeast(fillLen int) error { n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen) r.wp += n return err } -func (r *ChunkReader) newBuf(min int) []byte { - size := ((min / r.options.BlockLen) + 1) * r.options.BlockLen +func (r *ChunkReader) newBuf(size int) []byte { if size < r.options.MinBufLen { size = r.options.MinBufLen } diff --git a/chunkreader_test.go b/chunkreader_test.go index 9c19ff4a..3be07e3c 100644 --- a/chunkreader_test.go +++ b/chunkreader_test.go @@ -7,7 +7,7 @@ import ( func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 2}) + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) if err != nil { t.Fatal(err) } @@ -44,7 +44,7 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 2}) + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) if err != nil { t.Fatal(err) } @@ -59,14 +59,14 @@ func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { if bytes.Compare(n1, src[0:5]) != 0 { t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1) } - if len(r.buf) != 6 { - t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 6, len(r.buf)) + if len(r.buf) != 5 { + t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf)) } } -func TestChunkReaderNextReusesBuf(t *testing.T) { +func TestChunkReaderDoesNotReuseBuf(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 1}) + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) if err != nil { t.Fatal(err) } @@ -90,38 +90,6 @@ func TestChunkReaderNextReusesBuf(t *testing.T) { t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) } - if bytes.Compare(n1, src[4:8]) != 0 { - t.Fatalf("Expected Next to have reused buf, %v found instead of %v", src[4:8], n1) - } -} - -func TestChunkReaderKeepLastPreventsBufReuse(t *testing.T) { - server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 1}) - if err != nil { - t.Fatal(err) - } - - src := []byte{1, 2, 3, 4, 5, 6, 7, 8} - server.Write(src) - - n1, err := r.Next(4) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n1, src[0:4]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1) - } - r.KeepLast() - - n2, err := r.Next(4) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n2, src[4:8]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) - } - if bytes.Compare(n1, src[0:4]) != 0 { t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1) } From ab21bc4ec76e4a88a8870618087e12b812fc2d61 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 12:23:51 -0500 Subject: [PATCH 0080/1158] pgtype DecodeText and DecodeBinary do not copy They now take ownership of the src argument. Needed to change Scan to make a copy of []byte arguments as lib/pq apparently gives Scan a shared memory buffer. --- aclitem.go | 4 +++- aclitem_array.go | 4 +++- bool.go | 4 +++- bool_array.go | 4 +++- box.go | 4 +++- bytea.go | 5 +---- bytea_array.go | 4 +++- cidr_array.go | 4 +++- circle.go | 4 +++- date.go | 4 +++- date_array.go | 4 +++- daterange.go | 10 ++++++---- float4.go | 4 +++- float4_array.go | 4 +++- float8.go | 4 +++- float8_array.go | 4 +++- hstore.go | 4 +++- hstore_array.go | 4 +++- inet.go | 4 +++- inet_array.go | 4 +++- int2.go | 4 +++- int2_array.go | 4 +++- int4.go | 4 +++- int4_array.go | 4 +++- int4range.go | 10 ++++++---- int8.go | 4 +++- int8_array.go | 4 +++- int8range.go | 10 ++++++---- interval.go | 4 +++- json.go | 9 ++++----- jsonb.go | 6 +----- line.go | 4 +++- lseg.go | 4 +++- macaddr.go | 4 +++- numeric.go | 4 +++- numeric_array.go | 4 +++- numrange.go | 10 ++++++---- oid.go | 4 +++- path.go | 4 +++- pgtype.go | 8 ++++---- pguint32.go | 4 +++- point.go | 4 +++- polygon.go | 4 +++- text.go | 4 +++- text_array.go | 4 +++- tid.go | 4 +++- timestamp.go | 4 +++- timestamp_array.go | 4 +++- timestamptz.go | 4 +++- timestamptz_array.go | 4 +++- tsrange.go | 10 ++++++---- tstzrange.go | 10 ++++++---- typed_array.go.erb | 4 +++- typed_range.go.erb | 4 +++- uuid.go | 4 +++- varbit.go | 9 ++++----- varchar_array.go | 4 +++- 57 files changed, 188 insertions(+), 93 deletions(-) diff --git a/aclitem.go b/aclitem.go index ebfcc3e7..31065764 100644 --- a/aclitem.go +++ b/aclitem.go @@ -106,7 +106,9 @@ func (dst *Aclitem) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/aclitem_array.go b/aclitem_array.go index 7ef76573..480b5bba 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -206,7 +206,9 @@ func (dst *AclitemArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/bool.go b/bool.go index 9d309f0c..ba876c91 100644 --- a/bool.go +++ b/bool.go @@ -142,7 +142,9 @@ func (dst *Bool) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/bool_array.go b/bool_array.go index 468f6816..4e92a616 100644 --- a/bool_array.go +++ b/bool_array.go @@ -308,7 +308,9 @@ func (dst *BoolArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/box.go b/box.go index 2e4f39ee..e25af854 100644 --- a/box.go +++ b/box.go @@ -156,7 +156,9 @@ func (dst *Box) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/bytea.go b/bytea.go index 3e2661db..bf774476 100644 --- a/bytea.go +++ b/bytea.go @@ -95,10 +95,7 @@ func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } - buf := make([]byte, len(src)) - copy(buf, src) - - *dst = Bytea{Bytes: buf, Status: Present} + *dst = Bytea{Bytes: src, Status: Present} return nil } diff --git a/bytea_array.go b/bytea_array.go index 4aa2b862..dd79b991 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -308,7 +308,9 @@ func (dst *ByteaArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/cidr_array.go b/cidr_array.go index 96d912ae..0aa289e7 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -337,7 +337,9 @@ func (dst *CidrArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/circle.go b/circle.go index 8c8f4693..e9268a06 100644 --- a/circle.go +++ b/circle.go @@ -138,7 +138,9 @@ func (dst *Circle) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/date.go b/date.go index 993a04c5..a7e4762a 100644 --- a/date.go +++ b/date.go @@ -185,7 +185,9 @@ func (dst *Date) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) case time.Time: *dst = Date{Time: src, Status: Present} return nil diff --git a/date_array.go b/date_array.go index f24bf6b9..91e2ee62 100644 --- a/date_array.go +++ b/date_array.go @@ -309,7 +309,9 @@ func (dst *DateArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/daterange.go b/daterange.go index 5cecca20..a5cd5d95 100644 --- a/daterange.go +++ b/daterange.go @@ -106,7 +106,7 @@ func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src *Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src *Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -256,13 +256,15 @@ func (dst *Daterange) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. -func (src *Daterange) Value() (driver.Value, error) { +func (src Daterange) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/float4.go b/float4.go index 76be4203..77bc4878 100644 --- a/float4.go +++ b/float4.go @@ -177,7 +177,9 @@ func (dst *Float4) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/float4_array.go b/float4_array.go index db1523f0..38508a52 100644 --- a/float4_array.go +++ b/float4_array.go @@ -308,7 +308,9 @@ func (dst *Float4Array) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/float8.go b/float8.go index 8cfc53c5..5322e251 100644 --- a/float8.go +++ b/float8.go @@ -167,7 +167,9 @@ func (dst *Float8) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/float8_array.go b/float8_array.go index 19878bbb..2f310bbd 100644 --- a/float8_array.go +++ b/float8_array.go @@ -308,7 +308,9 @@ func (dst *Float8Array) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/hstore.go b/hstore.go index 04df2acc..69a35b17 100644 --- a/hstore.go +++ b/hstore.go @@ -455,7 +455,9 @@ func (dst *Hstore) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/hstore_array.go b/hstore_array.go index e4263f20..9f773af2 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -308,7 +308,9 @@ func (dst *HstoreArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/inet.go b/inet.go index e3a7ec88..7c09a549 100644 --- a/inet.go +++ b/inet.go @@ -213,7 +213,9 @@ func (dst *Inet) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/inet_array.go b/inet_array.go index 4687b145..ed9f5d1c 100644 --- a/inet_array.go +++ b/inet_array.go @@ -337,7 +337,9 @@ func (dst *InetArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/int2.go b/int2.go index 4a3beb22..028cdfcf 100644 --- a/int2.go +++ b/int2.go @@ -178,7 +178,9 @@ func (dst *Int2) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/int2_array.go b/int2_array.go index 3506370e..cdfcde48 100644 --- a/int2_array.go +++ b/int2_array.go @@ -336,7 +336,9 @@ func (dst *Int2Array) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/int4.go b/int4.go index f429d887..cae0d32a 100644 --- a/int4.go +++ b/int4.go @@ -169,7 +169,9 @@ func (dst *Int4) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/int4_array.go b/int4_array.go index e4ec6455..9ca0b067 100644 --- a/int4_array.go +++ b/int4_array.go @@ -336,7 +336,9 @@ func (dst *Int4Array) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/int4range.go b/int4range.go index 12a48dab..29b8371e 100644 --- a/int4range.go +++ b/int4range.go @@ -106,7 +106,7 @@ func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src *Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src *Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -256,13 +256,15 @@ func (dst *Int4range) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. -func (src *Int4range) Value() (driver.Value, error) { +func (src Int4range) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/int8.go b/int8.go index 97db8393..a4ec4e62 100644 --- a/int8.go +++ b/int8.go @@ -155,7 +155,9 @@ func (dst *Int8) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/int8_array.go b/int8_array.go index 6c0dab65..c5026f83 100644 --- a/int8_array.go +++ b/int8_array.go @@ -336,7 +336,9 @@ func (dst *Int8Array) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/int8range.go b/int8range.go index 3541dbe2..e3e0486f 100644 --- a/int8range.go +++ b/int8range.go @@ -106,7 +106,7 @@ func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src *Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src *Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -256,13 +256,15 @@ func (dst *Int8range) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. -func (src *Int8range) Value() (driver.Value, error) { +func (src Int8range) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/interval.go b/interval.go index 050d5610..8ce345a3 100644 --- a/interval.go +++ b/interval.go @@ -259,7 +259,9 @@ func (dst *Interval) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/json.go b/json.go index a027a91c..44880863 100644 --- a/json.go +++ b/json.go @@ -97,10 +97,7 @@ func (dst *Json) DecodeText(ci *ConnInfo, src []byte) error { return nil } - buf := make([]byte, len(src)) - copy(buf, src) - - *dst = Json{Bytes: buf, Status: Present} + *dst = Json{Bytes: src, Status: Present} return nil } @@ -135,7 +132,9 @@ func (dst *Json) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/jsonb.go b/jsonb.go index 82cbb21f..5533b4b4 100644 --- a/jsonb.go +++ b/jsonb.go @@ -37,12 +37,8 @@ func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { if src[0] != 1 { return fmt.Errorf("unknown jsonb version number %d", src[0]) } - src = src[1:] - buf := make([]byte, len(src)) - copy(buf, src) - - *dst = Jsonb{Bytes: buf, Status: Present} + *dst = Jsonb{Bytes: src[1:], Status: Present} return nil } diff --git a/line.go b/line.go index 06f01f21..75fdf207 100644 --- a/line.go +++ b/line.go @@ -136,7 +136,9 @@ func (dst *Line) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/lseg.go b/lseg.go index 986724cc..823c2c09 100644 --- a/lseg.go +++ b/lseg.go @@ -156,7 +156,9 @@ func (dst *Lseg) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/macaddr.go b/macaddr.go index 0fe092e4..785148a2 100644 --- a/macaddr.go +++ b/macaddr.go @@ -142,7 +142,9 @@ func (dst *Macaddr) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/numeric.go b/numeric.go index 63f99c06..8dbc0251 100644 --- a/numeric.go +++ b/numeric.go @@ -594,7 +594,9 @@ func (dst *Numeric) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/numeric_array.go b/numeric_array.go index 3d59a6b0..2fc844eb 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -336,7 +336,9 @@ func (dst *NumericArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/numrange.go b/numrange.go index b0baec9a..bac6fc4b 100644 --- a/numrange.go +++ b/numrange.go @@ -106,7 +106,7 @@ func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src *Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src *Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -256,13 +256,15 @@ func (dst *Numrange) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. -func (src *Numrange) Value() (driver.Value, error) { +func (src Numrange) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/oid.go b/oid.go index 339dee0f..58a7b0f5 100644 --- a/oid.go +++ b/oid.go @@ -70,7 +70,9 @@ func (dst *Oid) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/path.go b/path.go index 2fd6cfc7..c1aa76bc 100644 --- a/path.go +++ b/path.go @@ -195,7 +195,9 @@ func (dst *Path) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype.go b/pgtype.go index 27a1a091..3a6b7471 100644 --- a/pgtype.go +++ b/pgtype.go @@ -96,15 +96,15 @@ type Value interface { type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the - // original SQL value is NULL. BinaryDecoder MUST not retain a reference to - // src. It MUST make a copy if it needs to retain the raw bytes. + // original SQL value is NULL. BinaryDecoder takes ownership of src. The + // caller MUST not use it again. DecodeBinary(ci *ConnInfo, src []byte) error } type TextDecoder interface { // DecodeText decodes src into TextDecoder. If src is nil then the original - // SQL value is NULL. TextDecoder MUST not retain a reference to src. It MUST - // make a copy if it needs to retain the raw bytes. + // SQL value is NULL. TextDecoder takes ownership of src. The caller MUST not + // use it again. DecodeText(ci *ConnInfo, src []byte) error } diff --git a/pguint32.go b/pguint32.go index 0caa0cba..a13c1fcd 100644 --- a/pguint32.go +++ b/pguint32.go @@ -144,7 +144,9 @@ func (dst *pguint32) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/point.go b/point.go index 3d51766e..62901340 100644 --- a/point.go +++ b/point.go @@ -130,7 +130,9 @@ func (dst *Point) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/polygon.go b/polygon.go index af99ee3d..c4383765 100644 --- a/polygon.go +++ b/polygon.go @@ -174,7 +174,9 @@ func (dst *Polygon) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/text.go b/text.go index 8e42a756..54e2d774 100644 --- a/text.go +++ b/text.go @@ -118,7 +118,9 @@ func (dst *Text) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/text_array.go b/text_array.go index a6bd4724..8a573d83 100644 --- a/text_array.go +++ b/text_array.go @@ -308,7 +308,9 @@ func (dst *TextArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/tid.go b/tid.go index 7976afde..7456b155 100644 --- a/tid.go +++ b/tid.go @@ -134,7 +134,9 @@ func (dst *Tid) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/timestamp.go b/timestamp.go index 694b63c0..4fb10abc 100644 --- a/timestamp.go +++ b/timestamp.go @@ -201,7 +201,9 @@ func (dst *Timestamp) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) case time.Time: *dst = Timestamp{Time: src, Status: Present} return nil diff --git a/timestamp_array.go b/timestamp_array.go index 2046c387..49815dae 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -309,7 +309,9 @@ func (dst *TimestampArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/timestamptz.go b/timestamptz.go index 3c76ec03..8606b2f2 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -197,7 +197,9 @@ func (dst *Timestamptz) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) case time.Time: *dst = Timestamptz{Time: src, Status: Present} return nil diff --git a/timestamptz_array.go b/timestamptz_array.go index fd58d3be..bf983b6b 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -309,7 +309,9 @@ func (dst *TimestamptzArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/tsrange.go b/tsrange.go index 78a94af2..429a5cbe 100644 --- a/tsrange.go +++ b/tsrange.go @@ -106,7 +106,7 @@ func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src *Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src *Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -256,13 +256,15 @@ func (dst *Tsrange) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. -func (src *Tsrange) Value() (driver.Value, error) { +func (src Tsrange) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/tstzrange.go b/tstzrange.go index d1fc7326..f03a9f65 100644 --- a/tstzrange.go +++ b/tstzrange.go @@ -106,7 +106,7 @@ func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src *Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src *Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -256,13 +256,15 @@ func (dst *Tstzrange) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. -func (src *Tstzrange) Value() (driver.Value, error) { +func (src Tstzrange) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/typed_array.go.erb b/typed_array.go.erb index 2a38ed82..6752bd5b 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -310,7 +310,9 @@ func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/typed_range.go.erb b/typed_range.go.erb index e46f71c7..49db1b1d 100644 --- a/typed_range.go.erb +++ b/typed_range.go.erb @@ -256,7 +256,9 @@ func (dst *<%= range_type %>) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/uuid.go b/uuid.go index c830c086..a4a93ab3 100644 --- a/uuid.go +++ b/uuid.go @@ -161,7 +161,9 @@ func (dst *Uuid) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/varbit.go b/varbit.go index 00c34e10..b986f02a 100644 --- a/varbit.go +++ b/varbit.go @@ -72,10 +72,7 @@ func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { bitLen := int32(binary.BigEndian.Uint32(src)) rp := 4 - buf := make([]byte, len(src[rp:])) - copy(buf, src[rp:]) - - *dst = Varbit{Bytes: buf, Len: bitLen, Status: Present} + *dst = Varbit{Bytes: src[rp:], Len: bitLen, Status: Present} return nil } @@ -129,7 +126,9 @@ func (dst *Varbit) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/varchar_array.go b/varchar_array.go index 9ca16d7e..d84fac02 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -308,7 +308,9 @@ func (dst *VarcharArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) From d25abf56747f299c190301717e7f022b7e78dbbb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 1 May 2017 18:11:55 -0500 Subject: [PATCH 0081/1158] Add pgproto3.Backend --- backend.go | 74 ++++++++++++++++++++ bind.go | 167 ++++++++++++++++++++++++++++++++++++++++++++ describe.go | 60 ++++++++++++++++ execute.go | 60 ++++++++++++++++ parse.go | 82 ++++++++++++++++++++++ password_message.go | 44 ++++++++++++ sync.go | 29 ++++++++ terminate.go | 29 ++++++++ 8 files changed, 545 insertions(+) create mode 100644 backend.go create mode 100644 bind.go create mode 100644 describe.go create mode 100644 execute.go create mode 100644 parse.go create mode 100644 password_message.go create mode 100644 sync.go create mode 100644 terminate.go diff --git a/backend.go b/backend.go new file mode 100644 index 00000000..c04116a8 --- /dev/null +++ b/backend.go @@ -0,0 +1,74 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/jackc/pgx/chunkreader" +) + +type Backend struct { + cr *chunkreader.ChunkReader + w io.Writer + + // Frontend message flyweights + bind Bind + describe Describe + execute Execute + parse Parse + passwordMessage PasswordMessage + query Query + sync Sync + terminate Terminate +} + +func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { + cr := chunkreader.NewChunkReader(r) + return &Backend{cr: cr, w: w}, nil +} + +func (b *Backend) Send(msg BackendMessage) error { + return errors.New("not implemented") +} + +func (b *Backend) Receive() (FrontendMessage, error) { + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + msgType := header[0] + bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 + + var msg FrontendMessage + switch msgType { + case 'B': + msg = &b.bind + case 'D': + msg = &b.describe + case 'E': + msg = &b.execute + case 'P': + msg = &b.parse + case 'p': + msg = &b.passwordMessage + case 'Q': + msg = &b.query + case 'S': + msg = &b.sync + case 'X': + msg = &b.terminate + default: + return nil, fmt.Errorf("unknown message type: %c", msgType) + } + + msgBody, err := b.cr.Next(bodyLen) + if err != nil { + return nil, err + } + + err = msg.Decode(msgBody) + return msg, err +} diff --git a/bind.go b/bind.go new file mode 100644 index 00000000..6661a775 --- /dev/null +++ b/bind.go @@ -0,0 +1,167 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" +) + +type Bind struct { + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters [][]byte + ResultFormatCodes []int16 +} + +func (*Bind) Frontend() {} + +func (dst *Bind) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + dst.DestinationPortal = string(src[:idx]) + rp := idx + 1 + + idx = bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + dst.PreparedStatement = string(src[rp : rp+idx]) + rp += idx + 1 + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount) + + if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < parameterFormatCodeCount; i++ { + dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + parameterCount := int(binary.BigEndian.Uint16(src[rp:])) + + dst.Parameters = make([][]byte, parameterCount) + + for i := 0; i < parameterCount; i++ { + if len(src[rp:]) < 4 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + // null + if msgSize == -1 { + continue + } + + if len(src[rp:]) < msgSize { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + dst.Parameters[i] = src[rp : rp+msgSize] + rp += msgSize + } + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + dst.ResultFormatCodes = make([]int16, resultFormatCodeCount) + if len(src[rp:]) < len(dst.ResultFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < resultFormatCodeCount; i++ { + dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return nil +} + +func (src *Bind) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('B') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteString(src.DestinationPortal) + buf.WriteByte(0) + buf.WriteString(src.PreparedStatement) + buf.WriteByte(0) + + buf.Write(bigEndian.Uint16(uint16(len(src.ParameterFormatCodes)))) + + for _, fc := range src.ParameterFormatCodes { + buf.Write(bigEndian.Int16(fc)) + } + + buf.Write(bigEndian.Uint16(uint16(len(src.Parameters)))) + + for _, p := range src.Parameters { + if p == nil { + buf.Write(bigEndian.Int32(-1)) + continue + } + + buf.Write(bigEndian.Int32(int32(len(p)))) + buf.Write(p) + } + + buf.Write(bigEndian.Uint16(uint16(len(src.ResultFormatCodes)))) + + for _, fc := range src.ResultFormatCodes { + buf.Write(bigEndian.Int16(fc)) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Bind) MarshalJSON() ([]byte, error) { + formattedParameters := make([]map[string]string, len(src.Parameters)) + for i, p := range src.Parameters { + if p == nil { + continue + } + + if src.ParameterFormatCodes[i] == 0 { + formattedParameters[i] = map[string]string{"text": string(p)} + } else { + formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)} + } + } + + return json.Marshal(struct { + Type string + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters []map[string]string + ResultFormatCodes []int16 + }{ + Type: "Bind", + DestinationPortal: src.DestinationPortal, + PreparedStatement: src.PreparedStatement, + ParameterFormatCodes: src.ParameterFormatCodes, + Parameters: formattedParameters, + ResultFormatCodes: src.ResultFormatCodes, + }) +} diff --git a/describe.go b/describe.go new file mode 100644 index 00000000..ea55ed9d --- /dev/null +++ b/describe.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type Describe struct { + ObjectType byte // 'S' = prepared statement, 'P' = portal + Name string +} + +func (*Describe) Frontend() {} + +func (dst *Describe) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "Describe"} + } + + dst.ObjectType = src[0] + rp := 1 + + idx := bytes.IndexByte(src[rp:], 0) + if idx != len(src[rp:])-1 { + return &invalidMessageFormatErr{messageType: "Describe"} + } + + dst.Name = string(src[rp : len(src)-1]) + + return nil +} + +func (src *Describe) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('D') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteByte(src.ObjectType) + buf.WriteString(src.Name) + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Describe) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ObjectType string + Name string + }{ + Type: "Describe", + ObjectType: string(src.ObjectType), + Name: src.Name, + }) +} diff --git a/execute.go b/execute.go new file mode 100644 index 00000000..4892e7b3 --- /dev/null +++ b/execute.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type Execute struct { + Portal string + MaxRows uint32 +} + +func (*Execute) Frontend() {} + +func (dst *Execute) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Portal = string(b[:len(b)-1]) + + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "Execute"} + } + dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4)) + + return nil +} + +func (src *Execute) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('E') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteString(src.Portal) + buf.WriteByte(0) + + buf.Write(bigEndian.Uint32(src.MaxRows)) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Execute) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Portal string + MaxRows uint32 + }{ + Type: "Execute", + Portal: src.Portal, + MaxRows: src.MaxRows, + }) +} diff --git a/parse.go b/parse.go new file mode 100644 index 00000000..5d17ed11 --- /dev/null +++ b/parse.go @@ -0,0 +1,82 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type Parse struct { + Name string + Query string + ParameterOIDs []uint32 +} + +func (*Parse) Frontend() {} + +func (dst *Parse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Name = string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + dst.Query = string(b[:len(b)-1]) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "Parse"} + } + parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + for i := 0; i < parameterOIDCount; i++ { + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "Parse"} + } + dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4))) + } + + return nil +} + +func (src *Parse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('P') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteString(src.Name) + buf.WriteByte(0) + buf.WriteString(src.Query) + buf.WriteByte(0) + + buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs)))) + + for _, v := range src.ParameterOIDs { + buf.Write(bigEndian.Uint32(v)) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Parse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Query string + ParameterOIDs []uint32 + }{ + Type: "Parse", + Name: src.Name, + Query: src.Query, + ParameterOIDs: src.ParameterOIDs, + }) +} diff --git a/password_message.go b/password_message.go new file mode 100644 index 00000000..69df6362 --- /dev/null +++ b/password_message.go @@ -0,0 +1,44 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" +) + +type PasswordMessage struct { + Password string +} + +func (*PasswordMessage) Frontend() {} + +func (dst *PasswordMessage) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Password = string(b[:len(b)-1]) + + return nil +} + +func (src *PasswordMessage) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('p') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.Password) + 1))) + buf.WriteString(src.Password) + buf.WriteByte(0) + return buf.Bytes(), nil +} + +func (src *PasswordMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Password string + }{ + Type: "PasswordMessage", + Password: src.Password, + }) +} diff --git a/sync.go b/sync.go new file mode 100644 index 00000000..da3fa727 --- /dev/null +++ b/sync.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Sync struct{} + +func (*Sync) Frontend() {} + +func (dst *Sync) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *Sync) MarshalBinary() ([]byte, error) { + return []byte{'S', 0, 0, 0, 4}, nil +} + +func (src *Sync) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Sync", + }) +} diff --git a/terminate.go b/terminate.go new file mode 100644 index 00000000..77977f20 --- /dev/null +++ b/terminate.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Terminate struct{} + +func (*Terminate) Frontend() {} + +func (dst *Terminate) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *Terminate) MarshalBinary() ([]byte, error) { + return []byte{'X', 0, 0, 0, 4}, nil +} + +func (src *Terminate) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Terminate", + }) +} From eb9fc6e7a5ddeb45c286ff0e1954610bde8e266d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 1 May 2017 19:46:37 -0500 Subject: [PATCH 0082/1158] Fix queries with more than 32 columns fixes #270 --- data_row.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/data_row.go b/data_row.go index 6b27f728..3e600e84 100644 --- a/data_row.go +++ b/data_row.go @@ -25,7 +25,11 @@ func (dst *DataRow) Decode(src []byte) error { // large reallocate. This is too avoid one row with many columns from // permanently allocating memory. if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 { - dst.Values = make([][]byte, fieldCount, 32) + newCap := 32 + if newCap < fieldCount { + newCap = fieldCount + } + dst.Values = make([][]byte, fieldCount, newCap) } else { dst.Values = dst.Values[:fieldCount] } From d4fe3edf84071a0dee5e60bf356f5c97eb3f87fd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 2 May 2017 20:38:26 -0500 Subject: [PATCH 0083/1158] Refactor pgio and types to append buffers --- read.go | 106 ++++++++++---------------------------------------- read_test.go | 57 +++++++++++++++++++++++++++ write.go | 105 ++++++++++++------------------------------------- write_test.go | 78 +++++++++++++++++++++++++++++++++++++ 4 files changed, 180 insertions(+), 166 deletions(-) create mode 100644 read_test.go create mode 100644 write_test.go diff --git a/read.go b/read.go index 7c39162c..7ddad508 100644 --- a/read.go +++ b/read.go @@ -2,103 +2,39 @@ package pgio import ( "encoding/binary" - "io" ) -type Uint16Reader interface { - ReadUint16() (n uint16, err error) +func NextByte(buf []byte) ([]byte, byte) { + b := buf[0] + return buf[1:], b } -type Uint32Reader interface { - ReadUint32() (n uint32, err error) +func NextUint16(buf []byte) ([]byte, uint16) { + n := binary.BigEndian.Uint16(buf) + return buf[2:], n } -type Uint64Reader interface { - ReadUint64() (n uint64, err error) +func NextUint32(buf []byte) ([]byte, uint32) { + n := binary.BigEndian.Uint32(buf) + return buf[4:], n } -// ReadByte reads a byte from r. -func ReadByte(r io.Reader) (byte, error) { - if r, ok := r.(io.ByteReader); ok { - return r.ReadByte() - } - - buf := make([]byte, 1) - _, err := r.Read(buf) - return buf[0], err +func NextUint64(buf []byte) ([]byte, uint64) { + n := binary.BigEndian.Uint64(buf) + return buf[8:], n } -// ReadUint16 reads an uint16 from r in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Read if r provides a ReadUint16 -// method. -func ReadUint16(r io.Reader) (uint16, error) { - if r, ok := r.(Uint16Reader); ok { - return r.ReadUint16() - } - - buf := make([]byte, 2) - _, err := io.ReadFull(r, buf) - if err != nil { - return 0, err - } - - return binary.BigEndian.Uint16(buf), nil +func NextInt16(buf []byte) ([]byte, int16) { + buf, n := NextUint16(buf) + return buf, int16(n) } -// ReadInt16 reads an int16 r in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Read if r provides a ReadUint16 -// method. -func ReadInt16(r io.Reader) (int16, error) { - n, err := ReadUint16(r) - return int16(n), err +func NextInt32(buf []byte) ([]byte, int32) { + buf, n := NextUint32(buf) + return buf, int32(n) } -// ReadUint32 reads an uint32 r in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Read if r provides a ReadUint32 -// method. -func ReadUint32(r io.Reader) (uint32, error) { - if r, ok := r.(Uint32Reader); ok { - return r.ReadUint32() - } - - buf := make([]byte, 4) - _, err := io.ReadFull(r, buf) - if err != nil { - return 0, err - } - - return binary.BigEndian.Uint32(buf), nil -} - -// ReadInt32 reads an int32 r in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Read if r provides a ReadUint32 -// method. -func ReadInt32(r io.Reader) (int32, error) { - n, err := ReadUint32(r) - return int32(n), err -} - -// ReadUint64 reads an uint64 r in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Read if r provides a ReadUint64 -// method. -func ReadUint64(r io.Reader) (uint64, error) { - if r, ok := r.(Uint64Reader); ok { - return r.ReadUint64() - } - - buf := make([]byte, 8) - _, err := io.ReadFull(r, buf) - if err != nil { - return 0, err - } - - return binary.BigEndian.Uint64(buf), nil -} - -// ReadInt64 reads an int64 r in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Read if r provides a ReadUint64 -// method. -func ReadInt64(r io.Reader) (int64, error) { - n, err := ReadUint64(r) - return int64(n), err +func NextInt64(buf []byte) ([]byte, int64) { + buf, n := NextUint64(buf) + return buf, int64(n) } diff --git a/read_test.go b/read_test.go new file mode 100644 index 00000000..fbe29ae4 --- /dev/null +++ b/read_test.go @@ -0,0 +1,57 @@ +package pgio + +import ( + "testing" +) + +func TestNextByte(t *testing.T) { + buf := []byte{42, 1} + var b byte + buf, b = NextByte(buf) + if b != 42 { + t.Errorf("NextByte(buf) => %v, want %v", b, 42) + } + buf, b = NextByte(buf) + if b != 1 { + t.Errorf("NextByte(buf) => %v, want %v", b, 1) + } +} + +func TestNextUint16(t *testing.T) { + buf := []byte{0, 42, 0, 1} + var n uint16 + buf, n = NextUint16(buf) + if n != 42 { + t.Errorf("NextUint16(buf) => %v, want %v", n, 42) + } + buf, n = NextUint16(buf) + if n != 1 { + t.Errorf("NextUint16(buf) => %v, want %v", n, 1) + } +} + +func TestNextUint32(t *testing.T) { + buf := []byte{0, 0, 0, 42, 0, 0, 0, 1} + var n uint32 + buf, n = NextUint32(buf) + if n != 42 { + t.Errorf("NextUint32(buf) => %v, want %v", n, 42) + } + buf, n = NextUint32(buf) + if n != 1 { + t.Errorf("NextUint32(buf) => %v, want %v", n, 1) + } +} + +func TestNextUint64(t *testing.T) { + buf := []byte{0, 0, 0, 0, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 1} + var n uint64 + buf, n = NextUint64(buf) + if n != 42 { + t.Errorf("NextUint64(buf) => %v, want %v", n, 42) + } + buf, n = NextUint64(buf) + if n != 1 { + t.Errorf("NextUint64(buf) => %v, want %v", n, 1) + } +} diff --git a/write.go b/write.go index 823fbd00..96aedf9d 100644 --- a/write.go +++ b/write.go @@ -1,97 +1,40 @@ package pgio -import ( - "encoding/binary" - "io" -) +import "encoding/binary" -type Uint16Writer interface { - WriteUint16(uint16) (n int, err error) +func AppendUint16(buf []byte, n uint16) []byte { + wp := len(buf) + buf = append(buf, 0, 0) + binary.BigEndian.PutUint16(buf[wp:], n) + return buf } -type Uint32Writer interface { - WriteUint32(uint32) (n int, err error) +func AppendUint32(buf []byte, n uint32) []byte { + wp := len(buf) + buf = append(buf, 0, 0, 0, 0) + binary.BigEndian.PutUint32(buf[wp:], n) + return buf } -type Uint64Writer interface { - WriteUint64(uint64) (n int, err error) +func AppendUint64(buf []byte, n uint64) []byte { + wp := len(buf) + buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) + binary.BigEndian.PutUint64(buf[wp:], n) + return buf } -// WriteByte writes b to w. -func WriteByte(w io.Writer, b byte) error { - if w, ok := w.(io.ByteWriter); ok { - return w.WriteByte(b) - } - _, err := w.Write([]byte{b}) - return err +func AppendInt16(buf []byte, n int16) []byte { + return AppendUint16(buf, uint16(n)) } -// WriteUint16 writes n to w in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Write if w provides a WriteUint16 -// method. -func WriteUint16(w io.Writer, n uint16) (int, error) { - if w, ok := w.(Uint16Writer); ok { - return w.WriteUint16(n) - } - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, n) - return w.Write(b) +func AppendInt32(buf []byte, n int32) []byte { + return AppendUint32(buf, uint32(n)) } -// WriteInt16 writes n to w in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Write if w provides a WriteUint16 -// method. -func WriteInt16(w io.Writer, n int16) (int, error) { - return WriteUint16(w, uint16(n)) +func AppendInt64(buf []byte, n int64) []byte { + return AppendUint64(buf, uint64(n)) } -// WriteUint32 writes n to w in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Write if w provides a WriteUint32 -// method. -func WriteUint32(w io.Writer, n uint32) (int, error) { - if w, ok := w.(Uint32Writer); ok { - return w.WriteUint32(n) - } - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, n) - return w.Write(b) -} - -// WriteInt32 writes n to w in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Write if w provides a WriteUint32 -// method. -func WriteInt32(w io.Writer, n int32) (int, error) { - return WriteUint32(w, uint32(n)) -} - -// WriteUint64 writes n to w in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Write if w provides a WriteUint64 -// method. -func WriteUint64(w io.Writer, n uint64) (int, error) { - if w, ok := w.(Uint64Writer); ok { - return w.WriteUint64(n) - } - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, n) - return w.Write(b) -} - -// WriteInt64 writes n to w in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Write if w provides a WriteUint64 -// method. -func WriteInt64(w io.Writer, n int64) (int, error) { - return WriteUint64(w, uint64(n)) -} - -// WriteCString writes s to w followed by a null byte. -func WriteCString(w io.Writer, s string) (int, error) { - n, err := io.WriteString(w, s) - if err != nil { - return n, err - } - err = WriteByte(w, 0) - if err != nil { - return n, err - } - return n + 1, nil +func SetInt32(buf []byte, n int32) { + binary.BigEndian.PutUint32(buf, uint32(n)) } diff --git a/write_test.go b/write_test.go new file mode 100644 index 00000000..bd50e71c --- /dev/null +++ b/write_test.go @@ -0,0 +1,78 @@ +package pgio + +import ( + "reflect" + "testing" +) + +func TestAppendUint16NilBuf(t *testing.T) { + buf := AppendUint16(nil, 1) + if !reflect.DeepEqual(buf, []byte{0, 1}) { + t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) + } +} + +func TestAppendUint16EmptyBuf(t *testing.T) { + buf := []byte{} + buf = AppendUint16(buf, 1) + if !reflect.DeepEqual(buf, []byte{0, 1}) { + t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) + } +} + +func TestAppendUint16BufWithCapacityDoesNotAllocate(t *testing.T) { + buf := make([]byte, 0, 4) + AppendUint16(buf, 1) + buf = buf[0:2] + if !reflect.DeepEqual(buf, []byte{0, 1}) { + t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) + } +} + +func TestAppendUint32NilBuf(t *testing.T) { + buf := AppendUint32(nil, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { + t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) + } +} + +func TestAppendUint32EmptyBuf(t *testing.T) { + buf := []byte{} + buf = AppendUint32(buf, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { + t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) + } +} + +func TestAppendUint32BufWithCapacityDoesNotAllocate(t *testing.T) { + buf := make([]byte, 0, 4) + AppendUint32(buf, 1) + buf = buf[0:4] + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { + t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) + } +} + +func TestAppendUint64NilBuf(t *testing.T) { + buf := AppendUint64(nil, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { + t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) + } +} + +func TestAppendUint64EmptyBuf(t *testing.T) { + buf := []byte{} + buf = AppendUint64(buf, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { + t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) + } +} + +func TestAppendUint64BufWithCapacityDoesNotAllocate(t *testing.T) { + buf := make([]byte, 0, 8) + AppendUint64(buf, 1) + buf = buf[0:8] + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { + t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) + } +} From 6b906ca8705f55c955acafbee09445c5d72f1549 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 2 May 2017 20:38:26 -0500 Subject: [PATCH 0084/1158] Refactor pgio and types to append buffers --- aclitem.go | 10 +-- aclitem_array.go | 60 ++++---------- array.go | 65 ++++----------- bool.go | 29 +++---- bool_array.go | 99 +++++++--------------- box.go | 37 ++++----- bytea.go | 26 +++--- bytea_array.go | 99 +++++++--------------- cid.go | 9 +- cidr.go | 12 +-- cidr_array.go | 99 +++++++--------------- circle.go | 31 +++---- database_sql.go | 17 ++-- date.go | 19 ++--- date_array.go | 99 +++++++--------------- daterange.go | 118 +++++++++++--------------- decimal.go | 12 +-- ext/satori-uuid/uuid.go | 19 ++--- ext/shopspring-numeric/decimal.go | 33 ++++---- float4.go | 21 +++-- float4_array.go | 99 +++++++--------------- float8.go | 21 +++-- float8_array.go | 99 +++++++--------------- generic_binary.go | 5 +- generic_text.go | 5 +- hstore.go | 93 +++++++-------------- hstore_array.go | 99 +++++++--------------- hstore_test.go | 58 ++++++------- inet.go | 39 +++------ inet_array.go | 99 +++++++--------------- int2.go | 19 ++--- int2_array.go | 99 +++++++--------------- int4.go | 19 ++--- int4_array.go | 99 +++++++--------------- int4range.go | 118 +++++++++++--------------- int8.go | 19 ++--- int8_array.go | 99 +++++++--------------- int8range.go | 118 +++++++++++--------------- interval.go | 54 ++++-------- json.go | 14 ++-- jsonb.go | 20 ++--- line.go | 30 +++---- lseg.go | 38 ++++----- macaddr.go | 19 ++--- name.go | 9 +- numeric.go | 63 +++++--------- numeric_array.go | 99 +++++++--------------- numrange.go | 118 +++++++++++--------------- oid.go | 11 +-- oid_value.go | 9 +- path.go | 47 ++++------- pgtype.go | 17 ++-- pguint32.go | 19 ++--- point.go | 26 +++--- polygon.go | 43 ++++------ qchar.go | 11 +-- testutil/testutil.go | 9 +- text.go | 14 ++-- text_array.go | 99 +++++++--------------- tid.go | 27 +++--- timestamp.go | 23 +++--- timestamp_array.go | 99 +++++++--------------- timestamptz.go | 19 ++--- timestamptz_array.go | 99 +++++++--------------- tsrange.go | 118 +++++++++++--------------- tstzrange.go | 118 +++++++++++--------------- typed_array.go.erb | 97 +++++++--------------- typed_range.go.erb | 132 +++++++++++++----------------- uuid.go | 19 ++--- varbit.go | 29 +++---- varchar.go | 9 +- varchar_array.go | 99 +++++++--------------- xid.go | 9 +- 73 files changed, 1349 insertions(+), 2438 deletions(-) diff --git a/aclitem.go b/aclitem.go index 31065764..27dc15d1 100644 --- a/aclitem.go +++ b/aclitem.go @@ -3,7 +3,6 @@ package pgtype import ( "database/sql/driver" "fmt" - "io" ) // Aclitem is used for PostgreSQL's aclitem data type. A sample aclitem @@ -83,16 +82,15 @@ func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (src *Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Aclitem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, src.String) - return false, err + return append(buf, src.String...), nil } // Scan implements the database/sql Scanner interface. diff --git a/aclitem_array.go b/aclitem_array.go index 480b5bba..7df0b503 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -1,12 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" - - "github.com/jackc/pgx/pgio" ) type AclitemArray struct { @@ -120,23 +116,19 @@ func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *AclitemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -149,51 +141,36 @@ func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -216,14 +193,13 @@ func (dst *AclitemArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *AclitemArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/array.go b/array.go index 9561afe5..2f9ef66b 100644 --- a/array.go +++ b/array.go @@ -60,39 +60,23 @@ func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { return rp, nil } -func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, w io.Writer) error { - _, err := pgio.WriteInt32(w, int32(len(src.Dimensions))) - if err != nil { - return err - } +func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { + buf = pgio.AppendInt32(buf, int32(len(src.Dimensions))) var containsNull int32 if src.ContainsNull { containsNull = 1 } - _, err = pgio.WriteInt32(w, containsNull) - if err != nil { - return err - } + buf = pgio.AppendInt32(buf, containsNull) - _, err = pgio.WriteInt32(w, src.ElementOid) - if err != nil { - return err - } + buf = pgio.AppendInt32(buf, src.ElementOid) for i := range src.Dimensions { - _, err = pgio.WriteInt32(w, src.Dimensions[i].Length) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, src.Dimensions[i].LowerBound) - if err != nil { - return err - } + buf = pgio.AppendInt32(buf, src.Dimensions[i].Length) + buf = pgio.AppendInt32(buf, src.Dimensions[i].LowerBound) } - return nil + return buf } type UntypedTextArray struct { @@ -331,7 +315,7 @@ func arrayParseInteger(buf *bytes.Buffer) (int32, error) { } } -func EncodeTextArrayDimensions(w io.Writer, dimensions []ArrayDimension) error { +func EncodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte { var customDimensions bool for _, dim := range dimensions { if dim.LowerBound != 1 { @@ -340,37 +324,18 @@ func EncodeTextArrayDimensions(w io.Writer, dimensions []ArrayDimension) error { } if !customDimensions { - return nil + return buf } for _, dim := range dimensions { - err := pgio.WriteByte(w, '[') - if err != nil { - return err - } - - _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound), 10)) - if err != nil { - return err - } - - err = pgio.WriteByte(w, ':') - if err != nil { - return err - } - - _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)) - if err != nil { - return err - } - - err = pgio.WriteByte(w, ']') - if err != nil { - return err - } + buf = append(buf, '[') + buf = append(buf, strconv.FormatInt(int64(dim.LowerBound), 10)...) + buf = append(buf, ':') + buf = append(buf, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)...) + buf = append(buf, ']') } - return pgio.WriteByte(w, '=') + return append(buf, '=') } var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) diff --git a/bool.go b/bool.go index ba876c91..7c66a534 100644 --- a/bool.go +++ b/bool.go @@ -3,7 +3,6 @@ package pgtype import ( "database/sql/driver" "fmt" - "io" "strconv" ) @@ -90,42 +89,38 @@ func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Bool) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - var buf []byte if src.Bool { - buf = []byte{'t'} + buf = append(buf, 't') } else { - buf = []byte{'f'} + buf = append(buf, 'f') } - _, err := w.Write(buf) - return false, err + return buf, nil } -func (src *Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - var buf []byte if src.Bool { - buf = []byte{1} + buf = append(buf, 1) } else { - buf = []byte{0} + buf = append(buf, 0) } - _, err := w.Write(buf) - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/bool_array.go b/bool_array.go index 4e92a616..3c3d4184 100644 --- a/bool_array.go +++ b/bool_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *BoolArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *BoolArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("bool"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "bool") + return nil, fmt.Errorf("unable to find oid for type name %v", "bool") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *BoolArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *BoolArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/box.go b/box.go index e25af854..2d098058 100644 --- a/box.go +++ b/box.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -108,41 +107,33 @@ func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Box) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f),(%f,%f)`, - src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)) - return false, err + buf = append(buf, fmt.Sprintf(`(%f,%f),(%f,%f)`, + src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)...) + return buf, nil } -func (src *Box) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].X)); err != nil { - return false, err - } + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].Y)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].X)); err != nil { - return false, err - } - - _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].Y)) - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/bytea.go b/bytea.go index bf774476..2ddac7da 100644 --- a/bytea.go +++ b/bytea.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/hex" "fmt" - "io" ) type Bytea struct { @@ -99,33 +98,28 @@ func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Bytea) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, `\x`) - if err != nil { - return false, err - } - - _, err = io.WriteString(w, hex.EncodeToString(src.Bytes)) - return false, err + buf = append(buf, `\x`...) + buf = append(buf, hex.EncodeToString(src.Bytes)...) + return buf, nil } -func (src *Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bytea) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := w.Write(src.Bytes) - return false, err + return append(buf, src.Bytes...), nil } // Scan implements the database/sql Scanner interface. diff --git a/bytea_array.go b/bytea_array.go index dd79b991..67e114f5 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *ByteaArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *ByteaArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("bytea"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "bytea") + return nil, fmt.Errorf("unable to find oid for type name %v", "bytea") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *ByteaArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *ByteaArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/cid.go b/cid.go index c2b3073b..b7718f88 100644 --- a/cid.go +++ b/cid.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) // Cid is PostgreSQL's Command Identifier type. @@ -43,12 +42,12 @@ func (dst *Cid) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*pguint32)(src).EncodeText(ci, w) +func (src *Cid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeText(ci, buf) } -func (src *Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*pguint32)(src).EncodeBinary(ci, w) +func (src *Cid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/cidr.go b/cidr.go index 39a87a26..2b45d2d0 100644 --- a/cidr.go +++ b/cidr.go @@ -1,9 +1,5 @@ package pgtype -import ( - "io" -) - type Cidr Inet func (dst *Cidr) Set(src interface{}) error { @@ -26,10 +22,10 @@ func (dst *Cidr) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Inet)(dst).DecodeBinary(ci, src) } -func (src *Cidr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Inet)(src).EncodeText(ci, w) +func (src *Cidr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Inet)(src).EncodeText(ci, buf) } -func (src *Cidr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Inet)(src).EncodeBinary(ci, w) +func (src *Cidr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Inet)(src).EncodeBinary(ci, buf) } diff --git a/cidr_array.go b/cidr_array.go index 0aa289e7..01237aa1 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "net" "github.com/jackc/pgx/pgio" @@ -192,23 +190,19 @@ func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *CidrArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *CidrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -221,59 +215,44 @@ func (src *CidrArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *CidrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -283,7 +262,7 @@ func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("cidr"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "cidr") + return nil, fmt.Errorf("unable to find oid for type name %v", "cidr") } for i := range src.Elements { @@ -293,38 +272,23 @@ func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -347,14 +311,13 @@ func (dst *CidrArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *CidrArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/circle.go b/circle.go index e9268a06..8626a99d 100644 --- a/circle.go +++ b/circle.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -95,36 +94,30 @@ func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Circle) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`<(%f,%f),%f>`, src.P.X, src.P.Y, src.R)) - return false, err + buf = append(buf, fmt.Sprintf(`<(%f,%f),%f>`, src.P.X, src.P.Y, src.R)...) + return buf, nil } -func (src *Circle) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P.X)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P.Y)); err != nil { - return false, err - } - - _, err := pgio.WriteUint64(w, math.Float64bits(src.R)) - return false, err + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.R)) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/database_sql.go b/database_sql.go index e255b646..9d1cf822 100644 --- a/database_sql.go +++ b/database_sql.go @@ -1,7 +1,6 @@ package pgtype import ( - "bytes" "database/sql/driver" "errors" ) @@ -11,34 +10,32 @@ func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { return valuer.Value() } - buf := &bytes.Buffer{} if textEncoder, ok := src.(TextEncoder); ok { - _, err := textEncoder.EncodeText(ci, buf) + buf, err := textEncoder.EncodeText(ci, nil) if err != nil { return nil, err } - return buf.String(), nil + return string(buf), nil } if binaryEncoder, ok := src.(BinaryEncoder); ok { - _, err := binaryEncoder.EncodeBinary(ci, buf) + buf, err := binaryEncoder.EncodeBinary(ci, nil) if err != nil { return nil, err } - return buf.Bytes(), nil + return buf, nil } return nil, errors.New("cannot convert to database/sql compatible value") } func EncodeValueText(src TextEncoder) (interface{}, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, make([]byte, 0, 32)) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), err + return string(buf), err } diff --git a/date.go b/date.go index a7e4762a..8e049254 100644 --- a/date.go +++ b/date.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "time" "github.com/jackc/pgx/pgio" @@ -125,12 +124,12 @@ func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Date) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var s string @@ -144,16 +143,15 @@ func (src *Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { s = "-infinity" } - _, err := io.WriteString(w, s) - return false, err + return append(buf, s...), nil } -func (src *Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Date) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var daysSinceDateEpoch int32 @@ -170,8 +168,7 @@ func (src *Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { daysSinceDateEpoch = negativeInfinityDayOffset } - _, err := pgio.WriteInt32(w, daysSinceDateEpoch) - return false, err + return pgio.AppendInt32(buf, daysSinceDateEpoch), nil } // Scan implements the database/sql Scanner interface. diff --git a/date_array.go b/date_array.go index 91e2ee62..2175f2aa 100644 --- a/date_array.go +++ b/date_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "time" "github.com/jackc/pgx/pgio" @@ -164,23 +162,19 @@ func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *DateArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -193,59 +187,44 @@ func (src *DateArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -255,7 +234,7 @@ func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("date"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "date") + return nil, fmt.Errorf("unable to find oid for type name %v", "date") } for i := range src.Elements { @@ -265,38 +244,23 @@ func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -319,14 +283,13 @@ func (dst *DateArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *DateArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/daterange.go b/daterange.go index a5cd5d95..bbe7b17a 100644 --- a/daterange.go +++ b/daterange.go @@ -1,10 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -106,72 +104,65 @@ func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var rangeType byte @@ -182,10 +173,9 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +185,44 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/decimal.go b/decimal.go index 728c748e..79653cf3 100644 --- a/decimal.go +++ b/decimal.go @@ -1,9 +1,5 @@ package pgtype -import ( - "io" -) - type Decimal Numeric func (dst *Decimal) Set(src interface{}) error { @@ -26,10 +22,10 @@ func (dst *Decimal) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Numeric)(dst).DecodeBinary(ci, src) } -func (src *Decimal) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Numeric)(src).EncodeText(ci, w) +func (src *Decimal) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Numeric)(src).EncodeText(ci, buf) } -func (src *Decimal) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Numeric)(src).EncodeBinary(ci, w) +func (src *Decimal) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Numeric)(src).EncodeBinary(ci, buf) } diff --git a/ext/satori-uuid/uuid.go b/ext/satori-uuid/uuid.go index 1b65f48a..cff98348 100644 --- a/ext/satori-uuid/uuid.go +++ b/ext/satori-uuid/uuid.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "errors" "fmt" - "io" "github.com/jackc/pgx/pgtype" uuid "github.com/satori/go.uuid" @@ -117,28 +116,26 @@ func (dst *Uuid) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return nil } -func (src *Uuid) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { +func (src *Uuid) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: - return true, nil + return nil, nil case pgtype.Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, src.UUID.String()) - return false, err + return append(buf, src.UUID.String()...), nil } -func (src *Uuid) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { +func (src *Uuid) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: - return true, nil + return nil, nil case pgtype.Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := w.Write(src.UUID[:]) - return false, err + return append(buf, src.UUID[:]...), nil } // Scan implements the database/sql Scanner interface. diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index 9c7e316b..277f3709 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -1,11 +1,9 @@ package numeric import ( - "bytes" "database/sql/driver" "errors" "fmt" - "io" "strconv" "github.com/jackc/pgx/pgtype" @@ -75,12 +73,12 @@ func (dst *Numeric) Set(src interface{}) error { return fmt.Errorf("cannot convert %v to Numeric", value) } - buf := &bytes.Buffer{} - if _, err := num.EncodeText(nil, buf); err != nil { + buf, err := num.EncodeText(nil, nil) + if err != nil { return fmt.Errorf("cannot convert %v to Numeric", value) } - dec, err := decimal.NewFromString(buf.String()) + dec, err := decimal.NewFromString(string(buf)) if err != nil { return fmt.Errorf("cannot convert %v to Numeric", value) } @@ -243,12 +241,12 @@ func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - buf := &bytes.Buffer{} - if _, err := num.EncodeText(ci, buf); err != nil { + buf, err := num.EncodeText(ci, nil) + if err != nil { return err } - dec, err := decimal.NewFromString(buf.String()) + dec, err := decimal.NewFromString(string(buf)) if err != nil { return err } @@ -258,33 +256,32 @@ func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return nil } -func (src *Numeric) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { +func (src *Numeric) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: - return true, nil + return nil, nil case pgtype.Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, src.Decimal.String()) - return false, err + return append(buf, src.Decimal.String()...), nil } -func (src *Numeric) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { +func (src *Numeric) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: - return true, nil + return nil, nil case pgtype.Undefined: - return false, errUndefined + return nil, errUndefined } // For now at least, implement this in terms of pgtype.Numeric num := &pgtype.Numeric{} if err := num.DecodeText(ci, []byte(src.Decimal.String())); err != nil { - return false, err + return nil, err } - return num.EncodeBinary(ci, w) + return num.EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/float4.go b/float4.go index 77bc4878..b24654b6 100644 --- a/float4.go +++ b/float4.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" @@ -139,28 +138,28 @@ func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, strconv.FormatFloat(float64(src.Float), 'f', -1, 32)) - return false, err + buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 32)...) + return buf, nil } -func (src *Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteInt32(w, int32(math.Float32bits(src.Float))) - return false, err + buf = pgio.AppendUint32(buf, math.Float32bits(src.Float)) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/float4_array.go b/float4_array.go index 38508a52..37db8acc 100644 --- a/float4_array.go +++ b/float4_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *Float4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("float4"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "float4") + return nil, fmt.Errorf("unable to find oid for type name %v", "float4") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *Float4Array) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Float4Array) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/float8.go b/float8.go index 5322e251..c3ecdcc2 100644 --- a/float8.go +++ b/float8.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" @@ -129,28 +128,28 @@ func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, strconv.FormatFloat(float64(src.Float), 'f', -1, 64)) - return false, err + buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 64)...) + return buf, nil } -func (src *Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteInt64(w, int64(math.Float64bits(src.Float))) - return false, err + buf = pgio.AppendUint64(buf, math.Float64bits(src.Float)) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/float8_array.go b/float8_array.go index 2f310bbd..dd3fccf1 100644 --- a/float8_array.go +++ b/float8_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *Float8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("float8"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "float8") + return nil, fmt.Errorf("unable to find oid for type name %v", "float8") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *Float8Array) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Float8Array) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/generic_binary.go b/generic_binary.go index 094bd64e..2596ecae 100644 --- a/generic_binary.go +++ b/generic_binary.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) // GenericBinary is a placeholder for binary format values that no other type exists @@ -25,8 +24,8 @@ func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Bytea)(dst).DecodeBinary(ci, src) } -func (src *GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Bytea)(src).EncodeBinary(ci, w) +func (src *GenericBinary) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Bytea)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/generic_text.go b/generic_text.go index 5d0d83be..0e3db9de 100644 --- a/generic_text.go +++ b/generic_text.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) // GenericText is a placeholder for text format values that no other type exists @@ -25,8 +24,8 @@ func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeText(ci, src) } -func (src *GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Text)(src).EncodeText(ci, w) +func (src *GenericText) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeText(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/hstore.go b/hstore.go index 69a35b17..09506242 100644 --- a/hstore.go +++ b/hstore.go @@ -6,7 +6,6 @@ import ( "encoding/binary" "errors" "fmt" - "io" "strings" "unicode" "unicode/utf8" @@ -151,12 +150,12 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } firstPair := true @@ -165,90 +164,56 @@ func (src *Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { if firstPair { firstPair = false } else { - err := pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } - _, err := io.WriteString(w, quoteHstoreElementIfNeeded(k)) + buf = append(buf, quoteHstoreElementIfNeeded(k)...) + buf = append(buf, "=>"...) + + elemBuf, err := v.EncodeText(ci, nil) if err != nil { - return false, err + return nil, err } - _, err = io.WriteString(w, "=>") - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} - null, err := v.EncodeText(ci, elemBuf) - if err != nil { - return false, err - } - - if null { - _, err = io.WriteString(w, "NULL") - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, "NULL"...) } else { - _, err := io.WriteString(w, quoteHstoreElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, quoteHstoreElementIfNeeded(string(elemBuf))...) } } - return false, nil + return buf, nil } -func (src *Hstore) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteInt32(w, int32(len(src.Map))) - if err != nil { - return false, err - } + buf = pgio.AppendInt32(buf, int32(len(src.Map))) - elemBuf := &bytes.Buffer{} + var err error for k, v := range src.Map { - _, err := pgio.WriteInt32(w, int32(len(k))) - if err != nil { - return false, err - } - _, err = io.WriteString(w, k) - if err != nil { - return false, err - } + buf = pgio.AppendInt32(buf, int32(len(k))) + buf = append(buf, k...) - null, err := v.EncodeText(ci, elemBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := v.EncodeText(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err := pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err := pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, err } var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) diff --git a/hstore_array.go b/hstore_array.go index 9f773af2..2d61fa52 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *HstoreArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *HstoreArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("hstore"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "hstore") + return nil, fmt.Errorf("unable to find oid for type name %v", "hstore") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *HstoreArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *HstoreArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/hstore_test.go b/hstore_test.go index dc2439fc..8189e4db 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -9,41 +9,41 @@ import ( ) func TestHstoreTranscode(t *testing.T) { - text := func(s string) pgtype.Text { - return pgtype.Text{String: s, Status: pgtype.Present} - } + // text := func(s string) pgtype.Text { + // return pgtype.Text{String: s, Status: pgtype.Present} + // } values := []interface{}{ &pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, - &pgtype.Hstore{Status: pgtype.Null}, + // &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + // &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + // &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + // &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + // &pgtype.Hstore{Status: pgtype.Null}, } - specialStrings := []string{ - `"`, - `'`, - `\`, - `\\`, - `=>`, - ` `, - `\ / / \\ => " ' " '`, - } - for _, s := range specialStrings { - // Special key values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + // specialStrings := []string{ + // `"`, + // `'`, + // `\`, + // `\\`, + // `=>`, + // ` `, + // `\ / / \\ => " ' " '`, + // } + // for _, s := range specialStrings { + // // Special key values + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key - // Special value values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key - } + // // Special value values + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + // } testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { a := ai.(pgtype.Hstore) diff --git a/inet.go b/inet.go index 7c09a549..7aa1df95 100644 --- a/inet.go +++ b/inet.go @@ -3,10 +3,7 @@ package pgtype import ( "database/sql/driver" "fmt" - "io" "net" - - "github.com/jackc/pgx/pgio" ) // Network address family is dependent on server socket.h value for AF_INET. @@ -149,25 +146,24 @@ func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Inet) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, src.IPNet.String()) - return false, err + return append(buf, src.IPNet.String()...), nil } // EncodeBinary encodes src into w. -func (src *Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var family byte @@ -177,29 +173,20 @@ func (src *Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { case net.IPv6len: family = defaultAFInet6 default: - return false, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) + return nil, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) } - if err := pgio.WriteByte(w, family); err != nil { - return false, err - } + buf = append(buf, family) ones, _ := src.IPNet.Mask.Size() - if err := pgio.WriteByte(w, byte(ones)); err != nil { - return false, err - } + buf = append(buf, byte(ones)) // is_cidr is ignored on server - if err := pgio.WriteByte(w, 0); err != nil { - return false, err - } + buf = append(buf, 0) - if err := pgio.WriteByte(w, byte(len(src.IPNet.IP))); err != nil { - return false, err - } + buf = append(buf, byte(len(src.IPNet.IP))) - _, err := w.Write(src.IPNet.IP) - return false, err + return append(buf, src.IPNet.IP...), nil } // Scan implements the database/sql Scanner interface. diff --git a/inet_array.go b/inet_array.go index ed9f5d1c..e448a2ca 100644 --- a/inet_array.go +++ b/inet_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "net" "github.com/jackc/pgx/pgio" @@ -192,23 +190,19 @@ func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *InetArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -221,59 +215,44 @@ func (src *InetArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -283,7 +262,7 @@ func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("inet"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "inet") + return nil, fmt.Errorf("unable to find oid for type name %v", "inet") } for i := range src.Elements { @@ -293,38 +272,23 @@ func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -347,14 +311,13 @@ func (dst *InetArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *InetArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/int2.go b/int2.go index 028cdfcf..a58c3355 100644 --- a/int2.go +++ b/int2.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" @@ -134,28 +133,26 @@ func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int2) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, strconv.FormatInt(int64(src.Int), 10)) - return false, err + return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil } -func (src *Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteInt16(w, src.Int) - return false, err + return pgio.AppendInt16(buf, src.Int), nil } // Scan implements the database/sql Scanner interface. diff --git a/int2_array.go b/int2_array.go index cdfcde48..1d145584 100644 --- a/int2_array.go +++ b/int2_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -191,23 +189,19 @@ func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int2Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -220,59 +214,44 @@ func (src *Int2Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -282,7 +261,7 @@ func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("int2"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "int2") + return nil, fmt.Errorf("unable to find oid for type name %v", "int2") } for i := range src.Elements { @@ -292,38 +271,23 @@ func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -346,14 +310,13 @@ func (dst *Int2Array) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Int2Array) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/int4.go b/int4.go index cae0d32a..6f95013b 100644 --- a/int4.go +++ b/int4.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" @@ -125,28 +124,26 @@ func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, strconv.FormatInt(int64(src.Int), 10)) - return false, err + return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil } -func (src *Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteInt32(w, src.Int) - return false, err + return pgio.AppendInt32(buf, src.Int), nil } // Scan implements the database/sql Scanner interface. diff --git a/int4_array.go b/int4_array.go index 9ca0b067..1c746503 100644 --- a/int4_array.go +++ b/int4_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -191,23 +189,19 @@ func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -220,59 +214,44 @@ func (src *Int4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -282,7 +261,7 @@ func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("int4"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "int4") + return nil, fmt.Errorf("unable to find oid for type name %v", "int4") } for i := range src.Elements { @@ -292,38 +271,23 @@ func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -346,14 +310,13 @@ func (dst *Int4Array) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Int4Array) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/int4range.go b/int4range.go index 29b8371e..4f27ff0d 100644 --- a/int4range.go +++ b/int4range.go @@ -1,10 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -106,72 +104,65 @@ func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var rangeType byte @@ -182,10 +173,9 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +185,44 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/int8.go b/int8.go index a4ec4e62..939c0554 100644 --- a/int8.go +++ b/int8.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" @@ -117,28 +116,26 @@ func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, strconv.FormatInt(src.Int, 10)) - return false, err + return append(buf, strconv.FormatInt(src.Int, 10)...), nil } -func (src *Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteInt64(w, src.Int) - return false, err + return pgio.AppendInt64(buf, src.Int), nil } // Scan implements the database/sql Scanner interface. diff --git a/int8_array.go b/int8_array.go index c5026f83..56ebcab8 100644 --- a/int8_array.go +++ b/int8_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -191,23 +189,19 @@ func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -220,59 +214,44 @@ func (src *Int8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -282,7 +261,7 @@ func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("int8"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "int8") + return nil, fmt.Errorf("unable to find oid for type name %v", "int8") } for i := range src.Elements { @@ -292,38 +271,23 @@ func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -346,14 +310,13 @@ func (dst *Int8Array) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Int8Array) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/int8range.go b/int8range.go index e3e0486f..128a853f 100644 --- a/int8range.go +++ b/int8range.go @@ -1,10 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -106,72 +104,65 @@ func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var rangeType byte @@ -182,10 +173,9 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +185,44 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/interval.go b/interval.go index 8ce345a3..ea5c7d3e 100644 --- a/interval.go +++ b/interval.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "strconv" "strings" "time" @@ -178,41 +177,28 @@ func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Interval) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if src.Months != 0 { - if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Months), 10)); err != nil { - return false, err - } - - if _, err := io.WriteString(w, " mon "); err != nil { - return false, err - } + buf = append(buf, strconv.FormatInt(int64(src.Months), 10)...) + buf = append(buf, " mon "...) } if src.Days != 0 { - if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Days), 10)); err != nil { - return false, err - } - - if _, err := io.WriteString(w, " day "); err != nil { - return false, err - } + buf = append(buf, strconv.FormatInt(int64(src.Days), 10)...) + buf = append(buf, " day "...) } absMicroseconds := src.Microseconds if absMicroseconds < 0 { absMicroseconds = -absMicroseconds - - if err := pgio.WriteByte(w, '-'); err != nil { - return false, err - } + buf = append(buf, '-') } hours := absMicroseconds / microsecondsPerHour @@ -221,31 +207,21 @@ func (src *Interval) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { microseconds := absMicroseconds % microsecondsPerSecond timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) - - _, err := io.WriteString(w, timeStr) - return false, err + return append(buf, timeStr...), nil } // EncodeBinary encodes src into w. -func (src *Interval) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Interval) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteInt64(w, src.Microseconds); err != nil { - return false, err - } - if _, err := pgio.WriteInt32(w, src.Days); err != nil { - return false, err - } - if _, err := pgio.WriteInt32(w, src.Months); err != nil { - return false, err - } - - return false, nil + buf = pgio.AppendInt64(buf, src.Microseconds) + buf = pgio.AppendInt32(buf, src.Days) + return pgio.AppendInt32(buf, src.Months), nil } // Scan implements the database/sql Scanner interface. diff --git a/json.go b/json.go index 44880863..91d31129 100644 --- a/json.go +++ b/json.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/json" "fmt" - "io" ) type Json struct { @@ -105,20 +104,19 @@ func (dst *Json) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (src *Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Json) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := w.Write(src.Bytes) - return false, err + return append(buf, src.Bytes...), nil } -func (src *Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.EncodeText(ci, w) +func (src *Json) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/jsonb.go b/jsonb.go index 5533b4b4..f7914202 100644 --- a/jsonb.go +++ b/jsonb.go @@ -3,7 +3,6 @@ package pgtype import ( "database/sql/driver" "fmt" - "io" ) type Jsonb Json @@ -43,25 +42,20 @@ func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { } -func (src *Jsonb) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Json)(src).EncodeText(ci, w) +func (src *Jsonb) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Json)(src).EncodeText(ci, buf) } -func (src *Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Jsonb) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := w.Write([]byte{1}) - if err != nil { - return false, err - } - - _, err = w.Write(src.Bytes) - return false, err + buf = append(buf, 1) + return append(buf, src.Bytes...), nil } // Scan implements the database/sql Scanner interface. diff --git a/line.go b/line.go index 75fdf207..47f636a5 100644 --- a/line.go +++ b/line.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -93,36 +92,29 @@ func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Line) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`{%f,%f,%f}`, src.A, src.B, src.C)) - return false, err + return append(buf, fmt.Sprintf(`{%f,%f,%f}`, src.A, src.B, src.C)...), nil } -func (src *Line) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Line) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.A)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(src.B)); err != nil { - return false, err - } - - _, err := pgio.WriteUint64(w, math.Float64bits(src.C)) - return false, err + buf = pgio.AppendUint64(buf, math.Float64bits(src.A)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.B)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.C)) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/lseg.go b/lseg.go index 823c2c09..44c2b63c 100644 --- a/lseg.go +++ b/lseg.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -108,41 +107,32 @@ func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Lseg) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f),(%f,%f)`, - src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)) - return false, err + buf = append(buf, fmt.Sprintf(`(%f,%f),(%f,%f)`, + src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)...) + return buf, nil } -func (src *Lseg) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Lseg) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].X)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].Y)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].X)); err != nil { - return false, err - } - - _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].Y)) - return false, err + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/macaddr.go b/macaddr.go index 785148a2..e38701eb 100644 --- a/macaddr.go +++ b/macaddr.go @@ -3,7 +3,6 @@ package pgtype import ( "database/sql/driver" "fmt" - "io" "net" ) @@ -106,29 +105,27 @@ func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Macaddr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Macaddr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, src.Addr.String()) - return false, err + return append(buf, src.Addr.String()...), nil } // EncodeBinary encodes src into w. -func (src *Macaddr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Macaddr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := w.Write([]byte(src.Addr)) - return false, err + return append(buf, src.Addr...), nil } // Scan implements the database/sql Scanner interface. diff --git a/name.go b/name.go index 05e92563..af064a82 100644 --- a/name.go +++ b/name.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) // Name is a type used for PostgreSQL's special 63-byte @@ -40,12 +39,12 @@ func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } -func (src *Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Text)(src).EncodeText(ci, w) +func (src *Name) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeText(ci, buf) } -func (src *Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Text)(src).EncodeBinary(ci, w) +func (src *Name) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/numeric.go b/numeric.go index 8dbc0251..dffb9963 100644 --- a/numeric.go +++ b/numeric.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "math/big" "strconv" @@ -455,36 +453,26 @@ func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { return accum, rp, digits } -func (src *Numeric) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := io.WriteString(w, src.Int.String()); err != nil { - return false, err - } - - if err := pgio.WriteByte(w, 'e'); err != nil { - return false, err - } - - if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Exp), 10)); err != nil { - return false, err - } - - return false, nil - + buf = append(buf, src.Int.String()...) + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) + return buf, nil } -func (src *Numeric) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var sign int16 @@ -535,9 +523,7 @@ func (src *Numeric) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { fracDigits = append(fracDigits, int16(remainder.Int64())) } - if _, err := pgio.WriteInt16(w, int16(len(wholeDigits)+len(fracDigits))); err != nil { - return false, err - } + buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) var weight int16 if len(wholeDigits) > 0 { @@ -548,35 +534,25 @@ func (src *Numeric) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } else { weight = int16(exp/4) - 1 + int16(len(fracDigits)) } - if _, err := pgio.WriteInt16(w, weight); err != nil { - return false, err - } + buf = pgio.AppendInt16(buf, weight) - if _, err := pgio.WriteInt16(w, sign); err != nil { - return false, err - } + buf = pgio.AppendInt16(buf, sign) var dscale int16 if src.Exp < 0 { dscale = int16(-src.Exp) } - if _, err := pgio.WriteInt16(w, dscale); err != nil { - return false, err - } + buf = pgio.AppendInt16(buf, dscale) for i := len(wholeDigits) - 1; i >= 0; i-- { - if _, err := pgio.WriteInt16(w, wholeDigits[i]); err != nil { - return false, err - } + buf = pgio.AppendInt16(buf, wholeDigits[i]) } for i := len(fracDigits) - 1; i >= 0; i-- { - if _, err := pgio.WriteInt16(w, fracDigits[i]); err != nil { - return false, err - } + buf = pgio.AppendInt16(buf, fracDigits[i]) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -606,13 +582,12 @@ func (dst *Numeric) Scan(src interface{}) error { func (src *Numeric) Value() (driver.Value, error) { switch src.Status { case Present: - buf := &bytes.Buffer{} - _, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - return buf.String(), nil + return string(buf), nil case Null: return nil, nil default: diff --git a/numeric_array.go b/numeric_array.go index 2fc844eb..20f33dff 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -191,23 +189,19 @@ func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *NumericArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -220,59 +214,44 @@ func (src *NumericArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *NumericArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -282,7 +261,7 @@ func (src *NumericArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("numeric"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "numeric") + return nil, fmt.Errorf("unable to find oid for type name %v", "numeric") } for i := range src.Elements { @@ -292,38 +271,23 @@ func (src *NumericArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -346,14 +310,13 @@ func (dst *NumericArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *NumericArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/numrange.go b/numrange.go index bac6fc4b..00133296 100644 --- a/numrange.go +++ b/numrange.go @@ -1,10 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -106,72 +104,65 @@ func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var rangeType byte @@ -182,10 +173,9 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +185,44 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/oid.go b/oid.go index 58a7b0f5..6ceacc73 100644 --- a/oid.go +++ b/oid.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "strconv" "github.com/jackc/pgx/pgio" @@ -47,14 +46,12 @@ func (dst *Oid) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Oid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - _, err := io.WriteString(w, strconv.FormatUint(uint64(src), 10)) - return false, err +func (src Oid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return append(buf, strconv.FormatUint(uint64(src), 10)...), nil } -func (src Oid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - _, err := pgio.WriteUint32(w, uint32(src)) - return false, err +func (src Oid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return pgio.AppendUint32(buf, uint32(src)), nil } // Scan implements the database/sql Scanner interface. diff --git a/oid_value.go b/oid_value.go index 4a7de921..882d54fb 100644 --- a/oid_value.go +++ b/oid_value.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) // OidValue (Object Identifier Type) is, according to @@ -37,12 +36,12 @@ func (dst *OidValue) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*pguint32)(src).EncodeText(ci, w) +func (src *OidValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeText(ci, buf) } -func (src *OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*pguint32)(src).EncodeBinary(ci, w) +func (src *OidValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/path.go b/path.go index c1aa76bc..3575342d 100644 --- a/path.go +++ b/path.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -116,12 +115,12 @@ func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Path) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Path) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var startByte, endByte byte @@ -132,56 +131,40 @@ func (src *Path) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { startByte = '[' endByte = ']' } - if err := pgio.WriteByte(w, startByte); err != nil { - return false, err - } + buf = append(buf, startByte) for i, p := range src.P { if i > 0 { - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } - } - if _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)); err != nil { - return false, err + buf = append(buf, ',') } + buf = append(buf, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)...) } - err := pgio.WriteByte(w, endByte) - return false, err + return append(buf, endByte), nil } -func (src *Path) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Path) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var closeByte byte if src.Closed { closeByte = 1 } - if err := pgio.WriteByte(w, closeByte); err != nil { - return false, err - } + buf = append(buf, closeByte) - if _, err := pgio.WriteInt32(w, int32(len(src.P))); err != nil { - return false, err - } + buf = pgio.AppendInt32(buf, int32(len(src.P))) for _, p := range src.P { - if _, err := pgio.WriteUint64(w, math.Float64bits(p.X)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(p.Y)); err != nil { - return false, err - } + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype.go b/pgtype.go index 3a6b7471..847fce0f 100644 --- a/pgtype.go +++ b/pgtype.go @@ -2,7 +2,6 @@ package pgtype import ( "errors" - "io" "reflect" ) @@ -111,21 +110,21 @@ type TextDecoder interface { // BinaryEncoder is implemented by types that can encode themselves into the // PostgreSQL binary wire format. type BinaryEncoder interface { - // EncodeBinary should encode the binary format of self to w. If self is the - // SQL value NULL then write nothing and return (true, nil). The caller of + // EncodeBinary should append the binary format of self to buf. If self is the + // SQL value NULL then append nothing and return (nil, nil). The caller of // EncodeBinary is responsible for writing the correct NULL value or the // length of the data written. - EncodeBinary(ci *ConnInfo, w io.Writer) (null bool, err error) + EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) } // TextEncoder is implemented by types that can encode themselves into the // PostgreSQL text wire format. type TextEncoder interface { - // EncodeText should encode the text format of self to w. If self is the SQL - // value NULL then write nothing and return (true, nil). The caller of - // EncodeText is responsible for writing the correct NULL value or the length - // of the data written. - EncodeText(ci *ConnInfo, w io.Writer) (null bool, err error) + // EncodeText should append the text format of self to buf. If self is the + // SQL value NULL then append nothing and return (nil, nil). The caller of + // EncodeText is responsible for writing the correct NULL value or the + // length of the data written. + EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) } var errUndefined = errors.New("cannot encode status undefined") diff --git a/pguint32.go b/pguint32.go index a13c1fcd..c15ee6d7 100644 --- a/pguint32.go +++ b/pguint32.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" @@ -103,28 +102,26 @@ func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *pguint32) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *pguint32) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, strconv.FormatUint(uint64(src.Uint), 10)) - return false, err + return append(buf, strconv.FormatUint(uint64(src.Uint), 10)...), nil } -func (src *pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *pguint32) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteUint32(w, src.Uint) - return false, err + return pgio.AppendUint32(buf, src.Uint), nil } // Scan implements the database/sql Scanner interface. diff --git a/point.go b/point.go index 62901340..3d5d4e1a 100644 --- a/point.go +++ b/point.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -90,33 +89,28 @@ func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Point) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Point) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.P.X, src.P.Y)) - return false, err + return append(buf, fmt.Sprintf(`(%f,%f)`, src.P.X, src.P.Y)...), nil } -func (src *Point) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Point) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteUint64(w, math.Float64bits(src.P.X)) - if err != nil { - return false, err - } - - _, err = pgio.WriteUint64(w, math.Float64bits(src.P.Y)) - return false, err + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/polygon.go b/polygon.go index c4383765..d0b50061 100644 --- a/polygon.go +++ b/polygon.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -111,56 +110,42 @@ func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Polygon) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Polygon) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') for i, p := range src.P { if i > 0 { - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } - } - if _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)); err != nil { - return false, err + buf = append(buf, ',') } + buf = append(buf, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)...) } - err := pgio.WriteByte(w, ')') - return false, err + return append(buf, ')'), nil } -func (src *Polygon) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Polygon) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteInt32(w, int32(len(src.P))); err != nil { - return false, err - } + buf = pgio.AppendInt32(buf, int32(len(src.P))) for _, p := range src.P { - if _, err := pgio.WriteUint64(w, math.Float64bits(p.X)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(p.Y)); err != nil { - return false, err - } + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/qchar.go b/qchar.go index 10b56534..9c40ce18 100644 --- a/qchar.go +++ b/qchar.go @@ -2,11 +2,8 @@ package pgtype import ( "fmt" - "io" "math" "strconv" - - "github.com/jackc/pgx/pgio" ) // QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C @@ -136,13 +133,13 @@ func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *QChar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *QChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - return false, pgio.WriteByte(w, byte(src.Int)) + return append(buf, byte(src.Int)), nil } diff --git a/testutil/testutil.go b/testutil/testutil.go index 5dd2fbe1..0effb42d 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "fmt" - "io" "os" "reflect" "testing" @@ -61,16 +60,16 @@ type forceTextEncoder struct { e pgtype.TextEncoder } -func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { - return f.e.EncodeText(ci, w) +func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + return f.e.EncodeText(ci, buf) } type forceBinaryEncoder struct { e pgtype.BinaryEncoder } -func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { - return f.e.EncodeBinary(ci, w) +func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + return f.e.EncodeBinary(ci, buf) } func ForceEncoder(e interface{}, formatCode int16) interface{} { diff --git a/text.go b/text.go index 54e2d774..6638c354 100644 --- a/text.go +++ b/text.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/json" "fmt" - "io" ) type Text struct { @@ -91,20 +90,19 @@ func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (src *Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Text) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, src.String) - return false, err + return append(buf, src.String...), nil } -func (src *Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.EncodeText(ci, w) +func (src *Text) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/text_array.go b/text_array.go index 8a573d83..ed240e12 100644 --- a/text_array.go +++ b/text_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TextArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *TextArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `"NULL"`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `"NULL"`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("text"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "text") + return nil, fmt.Errorf("unable to find oid for type name %v", "text") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *TextArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *TextArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/tid.go b/tid.go index 7456b155..2f4412cb 100644 --- a/tid.go +++ b/tid.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "strconv" "strings" @@ -94,33 +93,29 @@ func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Tid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber)) - return false, err + buf = append(buf, fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber)...) + return buf, nil } -func (src *Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteUint32(w, src.BlockNumber) - if err != nil { - return false, err - } - - _, err = pgio.WriteUint16(w, src.OffsetNumber) - return false, err + buf = pgio.AppendUint32(buf, src.BlockNumber) + buf = pgio.AppendUint16(buf, src.OffsetNumber) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/timestamp.go b/timestamp.go index 4fb10abc..75c6cffa 100644 --- a/timestamp.go +++ b/timestamp.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "time" "github.com/jackc/pgx/pgio" @@ -136,15 +135,15 @@ func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { // EncodeText writes the text encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src *Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if src.Time.Location() != time.UTC { - return false, fmt.Errorf("cannot encode non-UTC time into timestamp") + return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") } var s string @@ -158,21 +157,20 @@ func (src *Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { s = "-infinity" } - _, err := io.WriteString(w, s) - return false, err + return append(buf, s...), nil } // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src *Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if src.Time.Location() != time.UTC { - return false, fmt.Errorf("cannot encode non-UTC time into timestamp") + return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") } var microsecSinceY2K int64 @@ -186,8 +184,7 @@ func (src *Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { microsecSinceY2K = negativeInfinityMicrosecondOffset } - _, err := pgio.WriteInt64(w, microsecSinceY2K) - return false, err + return pgio.AppendInt64(buf, microsecSinceY2K), nil } // Scan implements the database/sql Scanner interface. diff --git a/timestamp_array.go b/timestamp_array.go index 49815dae..a4f1b9dd 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "time" "github.com/jackc/pgx/pgio" @@ -164,23 +162,19 @@ func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TimestampArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -193,59 +187,44 @@ func (src *TimestampArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -255,7 +234,7 @@ func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) if dt, ok := ci.DataTypeForName("timestamp"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "timestamp") + return nil, fmt.Errorf("unable to find oid for type name %v", "timestamp") } for i := range src.Elements { @@ -265,38 +244,23 @@ func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -319,14 +283,13 @@ func (dst *TimestampArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *TimestampArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/timestamptz.go b/timestamptz.go index 8606b2f2..97b0de2a 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "time" "github.com/jackc/pgx/pgio" @@ -140,12 +139,12 @@ func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamptz) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var s string @@ -159,16 +158,15 @@ func (src *Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { s = "-infinity" } - _, err := io.WriteString(w, s) - return false, err + return append(buf, s...), nil } -func (src *Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamptz) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var microsecSinceY2K int64 @@ -182,8 +180,7 @@ func (src *Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { microsecSinceY2K = negativeInfinityMicrosecondOffset } - _, err := pgio.WriteInt64(w, microsecSinceY2K) - return false, err + return pgio.AppendInt64(buf, microsecSinceY2K), nil } // Scan implements the database/sql Scanner interface. diff --git a/timestamptz_array.go b/timestamptz_array.go index bf983b6b..34d4f8a8 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "time" "github.com/jackc/pgx/pgio" @@ -164,23 +162,19 @@ func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TimestamptzArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -193,59 +187,44 @@ func (src *TimestamptzArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -255,7 +234,7 @@ func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, erro if dt, ok := ci.DataTypeForName("timestamptz"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "timestamptz") + return nil, fmt.Errorf("unable to find oid for type name %v", "timestamptz") } for i := range src.Elements { @@ -265,38 +244,23 @@ func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, erro } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -319,14 +283,13 @@ func (dst *TimestamptzArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *TimestamptzArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/tsrange.go b/tsrange.go index 429a5cbe..783fb086 100644 --- a/tsrange.go +++ b/tsrange.go @@ -1,10 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -106,72 +104,65 @@ func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var rangeType byte @@ -182,10 +173,9 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +185,44 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/tstzrange.go b/tstzrange.go index f03a9f65..8fd3fd68 100644 --- a/tstzrange.go +++ b/tstzrange.go @@ -1,10 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -106,72 +104,65 @@ func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var rangeType byte @@ -182,10 +173,9 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +185,44 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/typed_array.go.erb b/typed_array.go.erb index 6752bd5b..0d454ac8 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -163,23 +163,19 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) erro } <% end %> -func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,60 +188,45 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `<%= text_null %>`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `<%= text_null %>`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } <% if binary_format == "true" %> - func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -255,7 +236,7 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") + return nil, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") } for i := range src.Elements { @@ -265,38 +246,23 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } <% end %> @@ -320,14 +286,13 @@ func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *<%= pgtype_array_type %>) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/typed_range.go.erb b/typed_range.go.erb index 49db1b1d..90c23991 100644 --- a/typed_range.go.erb +++ b/typed_range.go.erb @@ -106,73 +106,66 @@ func (dst *<%= range_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src <%= range_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - switch src.Status { - case Null: - return true, nil - case Undefined: - return false, errUndefined - } +func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - switch src.Status { - case Null: - return true, nil - case Undefined: - return false, errUndefined - } +func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } var rangeType byte switch src.LowerType { @@ -182,10 +175,9 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, erro rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +187,44 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, erro rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/uuid.go b/uuid.go index a4a93ab3..c73c501e 100644 --- a/uuid.go +++ b/uuid.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/hex" "fmt" - "io" ) type Uuid struct { @@ -126,28 +125,26 @@ func (dst *Uuid) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Uuid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Uuid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, encodeUuid(src.Bytes)) - return false, err + return append(buf, encodeUuid(src.Bytes)...), nil } -func (src *Uuid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Uuid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := w.Write(src.Bytes[:]) - return false, err + return append(buf, src.Bytes[:]...), nil } // Scan implements the database/sql Scanner interface. diff --git a/varbit.go b/varbit.go index b986f02a..9a9fe1e1 100644 --- a/varbit.go +++ b/varbit.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -76,43 +75,37 @@ func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Varbit) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Varbit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - buf := make([]byte, int(src.Len)) - for i, _ := range buf { + for i := int32(0); i < src.Len; i++ { byteIdx := i / 8 bitMask := byte(128 >> byte(i%8)) char := byte('0') if src.Bytes[byteIdx]&bitMask > 0 { char = '1' } - buf[i] = char + buf = append(buf, char) } - _, err := w.Write(buf) - return false, err + return buf, nil } -func (src *Varbit) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Varbit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteInt32(w, src.Len); err != nil { - return false, err - } - - _, err := w.Write(src.Bytes) - return false, err + buf = pgio.AppendInt32(buf, src.Len) + return append(buf, src.Bytes...), nil } // Scan implements the database/sql Scanner interface. diff --git a/varchar.go b/varchar.go index 80673fa8..371efd7e 100644 --- a/varchar.go +++ b/varchar.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) type Varchar Text @@ -32,12 +31,12 @@ func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } -func (src *Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Text)(src).EncodeText(ci, w) +func (src *Varchar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeText(ci, buf) } -func (src *Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Text)(src).EncodeBinary(ci, w) +func (src *Varchar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/varchar_array.go b/varchar_array.go index d84fac02..c34ac0b6 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *VarcharArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *VarcharArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `"NULL"`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `"NULL"`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("varchar"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "varchar") + return nil, fmt.Errorf("unable to find oid for type name %v", "varchar") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *VarcharArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *VarcharArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/xid.go b/xid.go index 90a8d691..84acd1b0 100644 --- a/xid.go +++ b/xid.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) // Xid is PostgreSQL's Transaction ID type. @@ -46,12 +45,12 @@ func (dst *Xid) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*pguint32)(src).EncodeText(ci, w) +func (src *Xid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeText(ci, buf) } -func (src *Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*pguint32)(src).EncodeBinary(ci, w) +func (src *Xid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. From 6f398d8bb59696901d4fe3a9c88e158f54b9b395 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 08:48:40 -0500 Subject: [PATCH 0085/1158] Update pgproto3 to enable pgmock --- read.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/read.go b/read.go index 7ddad508..033bada4 100644 --- a/read.go +++ b/read.go @@ -1,6 +1,7 @@ package pgio import ( + "bytes" "encoding/binary" ) @@ -38,3 +39,13 @@ func NextInt64(buf []byte) ([]byte, int64) { buf, n := NextUint64(buf) return buf, int64(n) } + +func NextCString(buf []byte) ([]byte, string, bool) { + idx := bytes.IndexByte(buf, 0) + if idx < 0 { + return buf, "", false + } + cstring := string(buf[:idx]) + buf = buf[:idx+1] + return buf, cstring, true +} From 61d43869314ff91cd05d38edff698f18ba5d06f4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 08:48:40 -0500 Subject: [PATCH 0086/1158] Update pgproto3 to enable pgmock --- backend.go | 30 ++++++++++++++- frontend.go | 9 ++++- startup_message.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 130 insertions(+), 4 deletions(-) create mode 100644 startup_message.go diff --git a/backend.go b/backend.go index c04116a8..bd477315 100644 --- a/backend.go +++ b/backend.go @@ -2,7 +2,6 @@ package pgproto3 import ( "encoding/binary" - "errors" "fmt" "io" @@ -20,6 +19,7 @@ type Backend struct { parse Parse passwordMessage PasswordMessage query Query + startupMessage StartupMessage sync Sync terminate Terminate } @@ -30,7 +30,33 @@ func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { } func (b *Backend) Send(msg BackendMessage) error { - return errors.New("not implemented") + buf, err := msg.MarshalBinary() + if err != nil { + return nil + } + + _, err = b.w.Write(buf) + return err +} + +func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { + buf, err := b.cr.Next(4) + if err != nil { + return nil, err + } + msgSize := int(binary.BigEndian.Uint32(buf) - 4) + + buf, err = b.cr.Next(msgSize) + if err != nil { + return nil, err + } + + err = b.startupMessage.Decode(buf) + if err != nil { + return nil, err + } + + return &b.startupMessage, nil } func (b *Backend) Receive() (FrontendMessage, error) { diff --git a/frontend.go b/frontend.go index 50835836..27a9890a 100644 --- a/frontend.go +++ b/frontend.go @@ -2,7 +2,6 @@ package pgproto3 import ( "encoding/binary" - "errors" "fmt" "io" @@ -43,7 +42,13 @@ func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { } func (b *Frontend) Send(msg FrontendMessage) error { - return errors.New("not implemented") + buf, err := msg.MarshalBinary() + if err != nil { + return nil + } + + _, err = b.w.Write(buf) + return err } func (b *Frontend) Receive() (BackendMessage, error) { diff --git a/startup_message.go b/startup_message.go new file mode 100644 index 00000000..ebb804fe --- /dev/null +++ b/startup_message.go @@ -0,0 +1,95 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "fmt" +) + +const ( + protocolVersionNumber = 196608 // 3.0 + sslRequestNumber = 80877103 +) + +type StartupMessage struct { + ProtocolVersion uint32 + Parameters map[string]string +} + +func (*StartupMessage) Frontend() {} + +func (dst *StartupMessage) Decode(src []byte) error { + if len(src) < 4 { + return fmt.Errorf("startup message too short") + } + + dst.ProtocolVersion = binary.BigEndian.Uint32(src) + rp := 4 + + if dst.ProtocolVersion == sslRequestNumber { + return fmt.Errorf("can't handle ssl connection request") + } + + if dst.ProtocolVersion != protocolVersionNumber { + return fmt.Errorf("Bad startup message version number. Expected %d, got %d", protocolVersionNumber, dst.ProtocolVersion) + } + + dst.Parameters = make(map[string]string) + for { + idx := bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "StartupMesage"} + } + key := string(src[rp : rp+idx]) + rp += idx + 1 + + idx = bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "StartupMesage"} + } + value := string(src[rp : rp+idx]) + rp += idx + 1 + + dst.Parameters[key] = value + + if len(src[rp:]) == 1 { + if src[rp] != 0 { + return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) + } + break + } + } + + return nil +} + +func (src *StartupMessage) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.Write(bigEndian.Uint32(0)) + buf.Write(bigEndian.Uint32(src.ProtocolVersion)) + for k, v := range src.Parameters { + buf.WriteString(k) + buf.WriteByte(0) + buf.WriteString(v) + buf.WriteByte(0) + } + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[0:4], uint32(buf.Len())) + + return buf.Bytes(), nil +} + +func (src *StartupMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "StartupMessage", + ProtocolVersion: src.ProtocolVersion, + Parameters: src.Parameters, + }) +} From 2d209bd579c721e74cdc7614b1e643958fbae0e3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 08:53:37 -0500 Subject: [PATCH 0087/1158] Remove read functions from pgio and update docs --- doc.go | 8 +++----- read.go | 51 ---------------------------------------------- read_test.go | 57 ---------------------------------------------------- 3 files changed, 3 insertions(+), 113 deletions(-) delete mode 100644 read.go delete mode 100644 read_test.go diff --git a/doc.go b/doc.go index 36233a47..ef2dcc7f 100644 --- a/doc.go +++ b/doc.go @@ -1,8 +1,6 @@ -// Package pgio a extremely low-level IO toolkit for the PostgreSQL wire protocol. +// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol. /* -pgio provides functions for reading and writing integers from io.Reader and -io.Writer while doing byte order conversion. It publishes interfaces which -readers and writers may implement to decode and encode messages with the minimum -of memory allocations. +pgio provides functions for appending integers to a []byte while doing byte +order conversion. */ package pgio diff --git a/read.go b/read.go deleted file mode 100644 index 033bada4..00000000 --- a/read.go +++ /dev/null @@ -1,51 +0,0 @@ -package pgio - -import ( - "bytes" - "encoding/binary" -) - -func NextByte(buf []byte) ([]byte, byte) { - b := buf[0] - return buf[1:], b -} - -func NextUint16(buf []byte) ([]byte, uint16) { - n := binary.BigEndian.Uint16(buf) - return buf[2:], n -} - -func NextUint32(buf []byte) ([]byte, uint32) { - n := binary.BigEndian.Uint32(buf) - return buf[4:], n -} - -func NextUint64(buf []byte) ([]byte, uint64) { - n := binary.BigEndian.Uint64(buf) - return buf[8:], n -} - -func NextInt16(buf []byte) ([]byte, int16) { - buf, n := NextUint16(buf) - return buf, int16(n) -} - -func NextInt32(buf []byte) ([]byte, int32) { - buf, n := NextUint32(buf) - return buf, int32(n) -} - -func NextInt64(buf []byte) ([]byte, int64) { - buf, n := NextUint64(buf) - return buf, int64(n) -} - -func NextCString(buf []byte) ([]byte, string, bool) { - idx := bytes.IndexByte(buf, 0) - if idx < 0 { - return buf, "", false - } - cstring := string(buf[:idx]) - buf = buf[:idx+1] - return buf, cstring, true -} diff --git a/read_test.go b/read_test.go deleted file mode 100644 index fbe29ae4..00000000 --- a/read_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package pgio - -import ( - "testing" -) - -func TestNextByte(t *testing.T) { - buf := []byte{42, 1} - var b byte - buf, b = NextByte(buf) - if b != 42 { - t.Errorf("NextByte(buf) => %v, want %v", b, 42) - } - buf, b = NextByte(buf) - if b != 1 { - t.Errorf("NextByte(buf) => %v, want %v", b, 1) - } -} - -func TestNextUint16(t *testing.T) { - buf := []byte{0, 42, 0, 1} - var n uint16 - buf, n = NextUint16(buf) - if n != 42 { - t.Errorf("NextUint16(buf) => %v, want %v", n, 42) - } - buf, n = NextUint16(buf) - if n != 1 { - t.Errorf("NextUint16(buf) => %v, want %v", n, 1) - } -} - -func TestNextUint32(t *testing.T) { - buf := []byte{0, 0, 0, 42, 0, 0, 0, 1} - var n uint32 - buf, n = NextUint32(buf) - if n != 42 { - t.Errorf("NextUint32(buf) => %v, want %v", n, 42) - } - buf, n = NextUint32(buf) - if n != 1 { - t.Errorf("NextUint32(buf) => %v, want %v", n, 1) - } -} - -func TestNextUint64(t *testing.T) { - buf := []byte{0, 0, 0, 0, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 1} - var n uint64 - buf, n = NextUint64(buf) - if n != 42 { - t.Errorf("NextUint64(buf) => %v, want %v", n, 42) - } - buf, n = NextUint64(buf) - if n != 1 { - t.Errorf("NextUint64(buf) => %v, want %v", n, 1) - } -} From 45b67f9b95a4d94bbb81b8ecff5b27ca92a9b1f1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 19:48:03 -0500 Subject: [PATCH 0088/1158] Fix issues identified by go vet --- interval.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/interval.go b/interval.go index ea5c7d3e..85d76d99 100644 --- a/interval.go +++ b/interval.go @@ -118,26 +118,26 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { hours, err := strconv.ParseInt(timeParts[0], 10, 64) if err != nil { - return fmt.Errorf("bad interval hour format: %s", hours) + return fmt.Errorf("bad interval hour format: %s", timeParts[0]) } minutes, err := strconv.ParseInt(timeParts[1], 10, 64) if err != nil { - return fmt.Errorf("bad interval minute format: %s", minutes) + return fmt.Errorf("bad interval minute format: %s", timeParts[1]) } secondParts := strings.SplitN(timeParts[2], ".", 2) seconds, err := strconv.ParseInt(secondParts[0], 10, 64) if err != nil { - return fmt.Errorf("bad interval second format: %s", seconds) + return fmt.Errorf("bad interval second format: %s", secondParts[0]) } var uSeconds int64 if len(secondParts) == 2 { uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64) if err != nil { - return fmt.Errorf("bad interval decimal format: %s", seconds) + return fmt.Errorf("bad interval decimal format: %s", secondParts[1]) } for i := 0; i < 6-len(secondParts[1]); i++ { From 80edb27deed6fb956c714b847b0a798baa7a4ed6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 8 May 2017 18:07:11 -0500 Subject: [PATCH 0089/1158] Fix Bind Decode to advance rp --- bind.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bind.go b/bind.go index 6661a775..cbd71e13 100644 --- a/bind.go +++ b/bind.go @@ -52,6 +52,7 @@ func (dst *Bind) Decode(src []byte) error { return &invalidMessageFormatErr{messageType: "Bind"} } parameterCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 dst.Parameters = make([][]byte, parameterCount) From c6aef151817d0b6deb9c09847bb368b93dd3292f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 May 2017 17:56:54 -0500 Subject: [PATCH 0090/1158] Add basic pgmock support Primarily useful for testing pgx itself. Design is still subject to change. --- startup_message.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/startup_message.go b/startup_message.go index ebb804fe..4847d629 100644 --- a/startup_message.go +++ b/startup_message.go @@ -8,7 +8,7 @@ import ( ) const ( - protocolVersionNumber = 196608 // 3.0 + ProtocolVersionNumber = 196608 // 3.0 sslRequestNumber = 80877103 ) @@ -31,8 +31,8 @@ func (dst *StartupMessage) Decode(src []byte) error { return fmt.Errorf("can't handle ssl connection request") } - if dst.ProtocolVersion != protocolVersionNumber { - return fmt.Errorf("Bad startup message version number. Expected %d, got %d", protocolVersionNumber, dst.ProtocolVersion) + if dst.ProtocolVersion != ProtocolVersionNumber { + return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) } dst.Parameters = make(map[string]string) From e45a42c7efeeabc38dc4e4bbab834cbdecd93bec Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 15:50:27 -0500 Subject: [PATCH 0091/1158] Do not create empty slices in Bind.Decode --- bind.go | 58 +++++++++++++++++++++++++++++++-------------------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/bind.go b/bind.go index cbd71e13..79fb4503 100644 --- a/bind.go +++ b/bind.go @@ -18,6 +18,8 @@ type Bind struct { func (*Bind) Frontend() {} func (dst *Bind) Decode(src []byte) error { + *dst = Bind{} + idx := bytes.IndexByte(src, 0) if idx < 0 { return &invalidMessageFormatErr{messageType: "Bind"} @@ -38,14 +40,16 @@ func (dst *Bind) Decode(src []byte) error { parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 - dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount) + if parameterFormatCodeCount > 0 { + dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount) - if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { - return &invalidMessageFormatErr{messageType: "Bind"} - } - for i := 0; i < parameterFormatCodeCount; i++ { - dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) - rp += 2 + if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < parameterFormatCodeCount; i++ { + dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } } if len(src[rp:]) < 2 { @@ -54,27 +58,29 @@ func (dst *Bind) Decode(src []byte) error { parameterCount := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 - dst.Parameters = make([][]byte, parameterCount) + if parameterCount > 0 { + dst.Parameters = make([][]byte, parameterCount) - for i := 0; i < parameterCount; i++ { - if len(src[rp:]) < 4 { - return &invalidMessageFormatErr{messageType: "Bind"} + for i := 0; i < parameterCount; i++ { + if len(src[rp:]) < 4 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + // null + if msgSize == -1 { + continue + } + + if len(src[rp:]) < msgSize { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + dst.Parameters[i] = src[rp : rp+msgSize] + rp += msgSize } - - msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - - // null - if msgSize == -1 { - continue - } - - if len(src[rp:]) < msgSize { - return &invalidMessageFormatErr{messageType: "Bind"} - } - - dst.Parameters[i] = src[rp : rp+msgSize] - rp += msgSize } if len(src[rp:]) < 2 { From b1934ad4c27caf51c362c064469e84b77e6a3018 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 17:31:22 -0500 Subject: [PATCH 0092/1158] Add flush and close messages to pgproto3 --- backend.go | 6 ++++++ close.go | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ flush.go | 29 ++++++++++++++++++++++++++ 3 files changed, 95 insertions(+) create mode 100644 close.go create mode 100644 flush.go diff --git a/backend.go b/backend.go index bd477315..df66a799 100644 --- a/backend.go +++ b/backend.go @@ -14,8 +14,10 @@ type Backend struct { // Frontend message flyweights bind Bind + _close Close describe Describe execute Execute + flush Flush parse Parse passwordMessage PasswordMessage query Query @@ -72,10 +74,14 @@ func (b *Backend) Receive() (FrontendMessage, error) { switch msgType { case 'B': msg = &b.bind + case 'C': + msg = &b._close case 'D': msg = &b.describe case 'E': msg = &b.execute + case 'H': + msg = &b.flush case 'P': msg = &b.parse case 'p': diff --git a/close.go b/close.go new file mode 100644 index 00000000..454ef68e --- /dev/null +++ b/close.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type Close struct { + ObjectType byte // 'S' = prepared statement, 'P' = portal + Name string +} + +func (*Close) Frontend() {} + +func (dst *Close) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "Close"} + } + + dst.ObjectType = src[0] + rp := 1 + + idx := bytes.IndexByte(src[rp:], 0) + if idx != len(src[rp:])-1 { + return &invalidMessageFormatErr{messageType: "Close"} + } + + dst.Name = string(src[rp : len(src)-1]) + + return nil +} + +func (src *Close) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('C') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteByte(src.ObjectType) + buf.WriteString(src.Name) + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Close) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ObjectType string + Name string + }{ + Type: "Close", + ObjectType: string(src.ObjectType), + Name: src.Name, + }) +} diff --git a/flush.go b/flush.go new file mode 100644 index 00000000..d26f5c0c --- /dev/null +++ b/flush.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Flush struct{} + +func (*Flush) Frontend() {} + +func (dst *Flush) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *Flush) MarshalBinary() ([]byte, error) { + return []byte{'H', 0, 0, 0, 4}, nil +} + +func (src *Flush) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Flush", + }) +} From fe36df4fff215e003021fcb1f225d5edffa703e0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 08:34:20 -0500 Subject: [PATCH 0093/1158] Uncomment Hstore tests --- hstore_test.go | 58 +++++++++++++++++++++++++------------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/hstore_test.go b/hstore_test.go index 8189e4db..dc2439fc 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -9,41 +9,41 @@ import ( ) func TestHstoreTranscode(t *testing.T) { - // text := func(s string) pgtype.Text { - // return pgtype.Text{String: s, Status: pgtype.Present} - // } + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Status: pgtype.Present} + } values := []interface{}{ &pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, - // &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, - // &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, - // &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, - // &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, - // &pgtype.Hstore{Status: pgtype.Null}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + &pgtype.Hstore{Status: pgtype.Null}, } - // specialStrings := []string{ - // `"`, - // `'`, - // `\`, - // `\\`, - // `=>`, - // ` `, - // `\ / / \\ => " ' " '`, - // } - // for _, s := range specialStrings { - // // Special key values - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key - // // Special value values - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key - // } + // Special value values + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + } testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { a := ai.(pgtype.Hstore) From 4c51d6af822a3b881f008e9d8ac11604f18a9a40 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 08:36:40 -0500 Subject: [PATCH 0094/1158] Test &pgtype.QChar --- qchar_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/qchar_test.go b/qchar_test.go index b810b89c..057a557f 100644 --- a/qchar_test.go +++ b/qchar_test.go @@ -11,12 +11,12 @@ import ( func TestQCharTranscode(t *testing.T) { testutil.TestPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ - pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, - pgtype.QChar{Int: -1, Status: pgtype.Present}, - pgtype.QChar{Int: 0, Status: pgtype.Present}, - pgtype.QChar{Int: 1, Status: pgtype.Present}, - pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, - pgtype.QChar{Int: 0, Status: pgtype.Null}, + &pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, + &pgtype.QChar{Int: -1, Status: pgtype.Present}, + &pgtype.QChar{Int: 0, Status: pgtype.Present}, + &pgtype.QChar{Int: 1, Status: pgtype.Present}, + &pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, + &pgtype.QChar{Int: 0, Status: pgtype.Null}, }, func(a, b interface{}) bool { return reflect.DeepEqual(a, b) }) From 6ba93d4e54a3b99bd824a477f6e5cf3efbef06d9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 08:38:27 -0500 Subject: [PATCH 0095/1158] Fix TestNumericNormalize --- numeric_test.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/numeric_test.go b/numeric_test.go index d68a9347..5f3a3416 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -49,47 +49,47 @@ func TestNumericNormalize(t *testing.T) { testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { SQL: "select '0'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, }, { SQL: "select '1'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, }, { SQL: "select '10.00'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, }, { SQL: "select '1e-3'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, }, { SQL: "select '-1'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, }, { SQL: "select '10000'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, }, { SQL: "select '3.14'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, }, { SQL: "select '1.1'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, }, { SQL: "select '100010001'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, }, { SQL: "select '100010001.0001'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, }, { SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", - Value: pgtype.Numeric{ + Value: &pgtype.Numeric{ Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), Exp: -41, Status: pgtype.Present, @@ -97,7 +97,7 @@ func TestNumericNormalize(t *testing.T) { }, { SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", - Value: pgtype.Numeric{ + Value: &pgtype.Numeric{ Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), Exp: -196, Status: pgtype.Present, @@ -105,7 +105,7 @@ func TestNumericNormalize(t *testing.T) { }, { SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", - Value: pgtype.Numeric{ + Value: &pgtype.Numeric{ Int: mustParseBigInt(t, "123"), Exp: -186, Status: pgtype.Present, From 97a927bb03fb17a22a2c7aeb856cf28e8e3b6bd9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 08:39:53 -0500 Subject: [PATCH 0096/1158] Fix TestIntervalNormalize --- interval_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/interval_test.go b/interval_test.go index 18e21ddd..76ea3240 100644 --- a/interval_test.go +++ b/interval_test.go @@ -33,31 +33,31 @@ func TestIntervalNormalize(t *testing.T) { testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { SQL: "select '1 second'::interval", - Value: pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + Value: &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, }, { SQL: "select '1.000001 second'::interval", - Value: pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + Value: &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, }, { SQL: "select '34223 hours'::interval", - Value: pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + Value: &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, }, { SQL: "select '1 day'::interval", - Value: pgtype.Interval{Days: 1, Status: pgtype.Present}, + Value: &pgtype.Interval{Days: 1, Status: pgtype.Present}, }, { SQL: "select '1 month'::interval", - Value: pgtype.Interval{Months: 1, Status: pgtype.Present}, + Value: &pgtype.Interval{Months: 1, Status: pgtype.Present}, }, { SQL: "select '1 year'::interval", - Value: pgtype.Interval{Months: 12, Status: pgtype.Present}, + Value: &pgtype.Interval{Months: 12, Status: pgtype.Present}, }, { SQL: "select '-13 mon'::interval", - Value: pgtype.Interval{Months: -13, Status: pgtype.Present}, + Value: &pgtype.Interval{Months: -13, Status: pgtype.Present}, }, }) } From a3e05ea29f41fbd74c37e74a81b5a2783c885e87 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 08:42:39 -0500 Subject: [PATCH 0097/1158] Fix TestHstoreArrayTranscode --- hstore_array_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hstore_array_test.go b/hstore_array_test.go index d26497b1..fcf08c49 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -49,7 +49,7 @@ func TestHstoreArrayTranscode(t *testing.T) { values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key } - src := pgtype.HstoreArray{ + src := &pgtype.HstoreArray{ Elements: values, Dimensions: []pgtype.ArrayDimension{{Length: int32(len(values)), LowerBound: 1}}, Status: pgtype.Present, From 1f1677ba5e005575694a1de28421bc31e5fa48cb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 09:44:15 -0500 Subject: [PATCH 0098/1158] Ensure shopspring-numeric tests run --- ext/shopspring-numeric/decimal_test.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index 50c0fb8b..08483dda 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -25,61 +25,61 @@ func TestNumericNormalize(t *testing.T) { testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { SQL: "select '0'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, }, { SQL: "select '1'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, }, { SQL: "select '10.00'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present}, }, { SQL: "select '1e-3'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, }, { SQL: "select '-1'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, }, { SQL: "select '10000'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present}, }, { SQL: "select '3.14'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, }, { SQL: "select '1.1'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present}, }, { SQL: "select '100010001'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present}, }, { SQL: "select '100010001.0001'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present}, }, { SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", - Value: shopspring.Numeric{ + Value: &shopspring.Numeric{ Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), Status: pgtype.Present, }, }, { SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", - Value: shopspring.Numeric{ + Value: &shopspring.Numeric{ Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), Status: pgtype.Present, }, }, { SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", - Value: shopspring.Numeric{ + Value: &shopspring.Numeric{ Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), Status: pgtype.Present, }, From 071de0b674ac105e5b459fde0b5a98098049b927 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 09:46:06 -0500 Subject: [PATCH 0099/1158] Fix shopsprint-numeric test --- ext/shopspring-numeric/decimal_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index 08483dda..79121ef3 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -22,7 +22,7 @@ func mustParseDecimal(t *testing.T, src string) decimal.Decimal { } func TestNumericNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ { SQL: "select '0'::numeric", Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, @@ -84,6 +84,11 @@ func TestNumericNormalize(t *testing.T) { Status: pgtype.Present, }, }, + }, func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) }) } From d6312305ae92ccfb03cc100a343377a1bba84687 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 26 May 2017 17:00:44 -0500 Subject: [PATCH 0100/1158] Replace MarshalBinary with Encode This new approach can avoid allocations. --- authentication.go | 20 +++++++++--------- backend.go | 7 +------ backend_key_data.go | 17 ++++++++-------- bind.go | 43 ++++++++++++++++++--------------------- bind_complete.go | 4 ++-- close.go | 23 ++++++++++----------- close_complete.go | 4 ++-- command_complete.go | 18 ++++++++-------- copy_both_response.go | 21 ++++++++++--------- copy_data.go | 17 +++++++--------- copy_in_response.go | 21 ++++++++++--------- copy_out_response.go | 21 ++++++++++--------- data_row.go | 26 +++++++++++------------ describe.go | 23 ++++++++++----------- empty_query_response.go | 4 ++-- error_response.go | 8 ++++---- execute.go | 22 ++++++++++---------- flush.go | 4 ++-- frontend.go | 7 +------ function_call_response.go | 23 +++++++++++---------- no_data.go | 4 ++-- notice_response.go | 4 ++-- notification_response.go | 22 +++++++++++--------- parameter_description.go | 21 ++++++++++--------- parameter_status.go | 25 +++++++++++------------ parse.go | 31 ++++++++++++++-------------- parse_complete.go | 4 ++-- password_message.go | 18 ++++++++-------- pgproto3.go | 9 ++++---- query.go | 18 ++++++++-------- ready_for_query.go | 4 ++-- row_description.go | 35 ++++++++++++++++--------------- startup_message.go | 26 ++++++++++++----------- sync.go | 4 ++-- terminate.go | 4 ++-- 35 files changed, 277 insertions(+), 285 deletions(-) diff --git a/authentication.go b/authentication.go index 54f4978f..c04ee448 100644 --- a/authentication.go +++ b/authentication.go @@ -1,9 +1,10 @@ package pgproto3 import ( - "bytes" "encoding/binary" "fmt" + + "github.com/jackc/pgx/pgio" ) const ( @@ -36,19 +37,18 @@ func (dst *Authentication) Decode(src []byte) error { return nil } -func (src *Authentication) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - buf.WriteByte('R') - buf.Write(bigEndian.Uint32(0)) - buf.Write(bigEndian.Uint32(src.Type)) +func (src *Authentication) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, src.Type) switch src.Type { case AuthTypeMD5Password: - buf.Write(src.Salt[:]) + dst = append(dst, src.Salt[:]...) } - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } diff --git a/backend.go b/backend.go index df66a799..bf96ba95 100644 --- a/backend.go +++ b/backend.go @@ -32,12 +32,7 @@ func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { } func (b *Backend) Send(msg BackendMessage) error { - buf, err := msg.MarshalBinary() - if err != nil { - return nil - } - - _, err = b.w.Write(buf) + _, err := b.w.Write(msg.Encode(nil)) return err } diff --git a/backend_key_data.go b/backend_key_data.go index 04f31aec..5a478f10 100644 --- a/backend_key_data.go +++ b/backend_key_data.go @@ -1,9 +1,10 @@ package pgproto3 import ( - "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type BackendKeyData struct { @@ -24,14 +25,12 @@ func (dst *BackendKeyData) Decode(src []byte) error { return nil } -func (src *BackendKeyData) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - buf.WriteByte('K') - buf.Write(bigEndian.Uint32(12)) - buf.Write(bigEndian.Uint32(src.ProcessID)) - buf.Write(bigEndian.Uint32(src.SecretKey)) - return buf.Bytes(), nil +func (src *BackendKeyData) Encode(dst []byte) []byte { + dst = append(dst, 'K') + dst = pgio.AppendUint32(dst, 12) + dst = pgio.AppendUint32(dst, src.ProcessID) + dst = pgio.AppendUint32(dst, src.SecretKey) + return dst } func (src *BackendKeyData) MarshalJSON() ([]byte, error) { diff --git a/bind.go b/bind.go index 79fb4503..cceee6ab 100644 --- a/bind.go +++ b/bind.go @@ -5,6 +5,8 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type Bind struct { @@ -101,45 +103,40 @@ func (dst *Bind) Decode(src []byte) error { return nil } -func (src *Bind) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *Bind) Encode(dst []byte) []byte { + dst = append(dst, 'B') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('B') - buf.Write(bigEndian.Uint32(0)) - - buf.WriteString(src.DestinationPortal) - buf.WriteByte(0) - buf.WriteString(src.PreparedStatement) - buf.WriteByte(0) - - buf.Write(bigEndian.Uint16(uint16(len(src.ParameterFormatCodes)))) + dst = append(dst, src.DestinationPortal...) + dst = append(dst, 0) + dst = append(dst, src.PreparedStatement...) + dst = append(dst, 0) + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) for _, fc := range src.ParameterFormatCodes { - buf.Write(bigEndian.Int16(fc)) + dst = pgio.AppendInt16(dst, fc) } - buf.Write(bigEndian.Uint16(uint16(len(src.Parameters)))) - + dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) for _, p := range src.Parameters { if p == nil { - buf.Write(bigEndian.Int32(-1)) + dst = pgio.AppendInt32(dst, -1) continue } - buf.Write(bigEndian.Int32(int32(len(p)))) - buf.Write(p) + dst = pgio.AppendInt32(dst, int32(len(p))) + dst = append(dst, p...) } - buf.Write(bigEndian.Uint16(uint16(len(src.ResultFormatCodes)))) - + dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) for _, fc := range src.ResultFormatCodes { - buf.Write(bigEndian.Int16(fc)) + dst = pgio.AppendInt16(dst, fc) } - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *Bind) MarshalJSON() ([]byte, error) { diff --git a/bind_complete.go b/bind_complete.go index 4f1c44b8..60360519 100644 --- a/bind_complete.go +++ b/bind_complete.go @@ -16,8 +16,8 @@ func (dst *BindComplete) Decode(src []byte) error { return nil } -func (src *BindComplete) MarshalBinary() ([]byte, error) { - return []byte{'2', 0, 0, 0, 4}, nil +func (src *BindComplete) Encode(dst []byte) []byte { + return append(dst, '2', 0, 0, 0, 4) } func (src *BindComplete) MarshalJSON() ([]byte, error) { diff --git a/close.go b/close.go index 454ef68e..5ff4c886 100644 --- a/close.go +++ b/close.go @@ -2,8 +2,9 @@ package pgproto3 import ( "bytes" - "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type Close struct { @@ -31,20 +32,18 @@ func (dst *Close) Decode(src []byte) error { return nil } -func (src *Close) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *Close) Encode(dst []byte) []byte { + dst = append(dst, 'C') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('C') - buf.Write(bigEndian.Uint32(0)) + dst = append(dst, src.ObjectType) + dst = append(dst, src.Name...) + dst = append(dst, 0) - buf.WriteByte(src.ObjectType) - buf.WriteString(src.Name) - buf.WriteByte(0) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) - - return buf.Bytes(), nil + return dst } func (src *Close) MarshalJSON() ([]byte, error) { diff --git a/close_complete.go b/close_complete.go index 9bab3e8c..db793c94 100644 --- a/close_complete.go +++ b/close_complete.go @@ -16,8 +16,8 @@ func (dst *CloseComplete) Decode(src []byte) error { return nil } -func (src *CloseComplete) MarshalBinary() ([]byte, error) { - return []byte{'3', 0, 0, 0, 4}, nil +func (src *CloseComplete) Encode(dst []byte) []byte { + return append(dst, '3', 0, 0, 0, 4) } func (src *CloseComplete) MarshalJSON() ([]byte, error) { diff --git a/command_complete.go b/command_complete.go index 86653804..85848532 100644 --- a/command_complete.go +++ b/command_complete.go @@ -3,6 +3,8 @@ package pgproto3 import ( "bytes" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type CommandComplete struct { @@ -22,17 +24,17 @@ func (dst *CommandComplete) Decode(src []byte) error { return nil } -func (src *CommandComplete) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *CommandComplete) Encode(dst []byte) []byte { + dst = append(dst, 'C') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('C') - buf.Write(bigEndian.Uint32(uint32(4 + len(src.CommandTag) + 1))) + dst = append(dst, src.CommandTag...) + dst = append(dst, 0) - buf.WriteString(src.CommandTag) - buf.WriteByte(0) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *CommandComplete) MarshalJSON() ([]byte, error) { diff --git a/copy_both_response.go b/copy_both_response.go index 3857c187..2862a34f 100644 --- a/copy_both_response.go +++ b/copy_both_response.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type CopyBothResponse struct { @@ -37,20 +39,19 @@ func (dst *CopyBothResponse) Decode(src []byte) error { return nil } -func (src *CopyBothResponse) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('W') - buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) - - buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) +func (src *CopyBothResponse) Encode(dst []byte) []byte { + dst = append(dst, 'W') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { - buf.Write(bigEndian.Uint16(fc)) + dst = pgio.AppendUint16(dst, fc) } - return buf.Bytes(), nil + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst } func (src *CopyBothResponse) MarshalJSON() ([]byte, error) { diff --git a/copy_data.go b/copy_data.go index de7ab4ff..fab139e6 100644 --- a/copy_data.go +++ b/copy_data.go @@ -1,9 +1,10 @@ package pgproto3 import ( - "bytes" "encoding/hex" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type CopyData struct { @@ -18,15 +19,11 @@ func (dst *CopyData) Decode(src []byte) error { return nil } -func (src *CopyData) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('d') - buf.Write(bigEndian.Uint32(uint32(4 + len(src.Data)))) - buf.Write(src.Data) - - return buf.Bytes(), nil +func (src *CopyData) Encode(dst []byte) []byte { + dst = append(dst, 'd') + dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) + dst = append(dst, src.Data...) + return dst } func (src *CopyData) MarshalJSON() ([]byte, error) { diff --git a/copy_in_response.go b/copy_in_response.go index 9854d665..54083cd6 100644 --- a/copy_in_response.go +++ b/copy_in_response.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type CopyInResponse struct { @@ -37,20 +39,19 @@ func (dst *CopyInResponse) Decode(src []byte) error { return nil } -func (src *CopyInResponse) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('G') - buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) - - buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) +func (src *CopyInResponse) Encode(dst []byte) []byte { + dst = append(dst, 'G') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { - buf.Write(bigEndian.Uint16(fc)) + dst = pgio.AppendUint16(dst, fc) } - return buf.Bytes(), nil + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst } func (src *CopyInResponse) MarshalJSON() ([]byte, error) { diff --git a/copy_out_response.go b/copy_out_response.go index 5ef6e4c1..eaa33b8b 100644 --- a/copy_out_response.go +++ b/copy_out_response.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type CopyOutResponse struct { @@ -37,20 +39,19 @@ func (dst *CopyOutResponse) Decode(src []byte) error { return nil } -func (src *CopyOutResponse) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('H') - buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) - - buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) +func (src *CopyOutResponse) Encode(dst []byte) []byte { + dst = append(dst, 'H') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { - buf.Write(bigEndian.Uint16(fc)) + dst = pgio.AppendUint16(dst, fc) } - return buf.Bytes(), nil + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst } func (src *CopyOutResponse) MarshalJSON() ([]byte, error) { diff --git a/data_row.go b/data_row.go index 3e600e84..e46d3cc0 100644 --- a/data_row.go +++ b/data_row.go @@ -1,10 +1,11 @@ package pgproto3 import ( - "bytes" "encoding/binary" "encoding/hex" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type DataRow struct { @@ -58,28 +59,25 @@ func (dst *DataRow) Decode(src []byte) error { return nil } -func (src *DataRow) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('D') - buf.Write(bigEndian.Uint32(0)) - - buf.Write(bigEndian.Uint16(uint16(len(src.Values)))) +func (src *DataRow) Encode(dst []byte) []byte { + dst = append(dst, 'D') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint16(dst, uint16(len(src.Values))) for _, v := range src.Values { if v == nil { - buf.Write(bigEndian.Int32(-1)) + dst = pgio.AppendInt32(dst, -1) continue } - buf.Write(bigEndian.Int32(int32(len(v)))) - buf.Write(v) + dst = pgio.AppendInt32(dst, int32(len(v))) + dst = append(dst, v...) } - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *DataRow) MarshalJSON() ([]byte, error) { diff --git a/describe.go b/describe.go index ea55ed9d..bb7bc056 100644 --- a/describe.go +++ b/describe.go @@ -2,8 +2,9 @@ package pgproto3 import ( "bytes" - "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type Describe struct { @@ -31,20 +32,18 @@ func (dst *Describe) Decode(src []byte) error { return nil } -func (src *Describe) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *Describe) Encode(dst []byte) []byte { + dst = append(dst, 'D') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('D') - buf.Write(bigEndian.Uint32(0)) + dst = append(dst, src.ObjectType) + dst = append(dst, src.Name...) + dst = append(dst, 0) - buf.WriteByte(src.ObjectType) - buf.WriteString(src.Name) - buf.WriteByte(0) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) - - return buf.Bytes(), nil + return dst } func (src *Describe) MarshalJSON() ([]byte, error) { diff --git a/empty_query_response.go b/empty_query_response.go index 13ed1886..d283b06d 100644 --- a/empty_query_response.go +++ b/empty_query_response.go @@ -16,8 +16,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error { return nil } -func (src *EmptyQueryResponse) MarshalBinary() ([]byte, error) { - return []byte{'I', 0, 0, 0, 4}, nil +func (src *EmptyQueryResponse) Encode(dst []byte) []byte { + return append(dst, 'I', 0, 0, 0, 4) } func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) { diff --git a/error_response.go b/error_response.go index 602dd2a1..160234f2 100644 --- a/error_response.go +++ b/error_response.go @@ -103,11 +103,11 @@ func (dst *ErrorResponse) Decode(src []byte) error { return nil } -func (src *ErrorResponse) MarshalBinary() ([]byte, error) { - return src.marshalBinary('E') +func (src *ErrorResponse) Encode(dst []byte) []byte { + return append(dst, src.marshalBinary('E')...) } -func (src *ErrorResponse) marshalBinary(typeByte byte) ([]byte, error) { +func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { var bigEndian BigEndianBuf buf := &bytes.Buffer{} @@ -193,5 +193,5 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) ([]byte, error) { binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) - return buf.Bytes(), nil + return buf.Bytes() } diff --git a/execute.go b/execute.go index 4892e7b3..76da9943 100644 --- a/execute.go +++ b/execute.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type Execute struct { @@ -30,21 +32,19 @@ func (dst *Execute) Decode(src []byte) error { return nil } -func (src *Execute) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *Execute) Encode(dst []byte) []byte { + dst = append(dst, 'E') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('E') - buf.Write(bigEndian.Uint32(0)) + dst = append(dst, src.Portal...) + dst = append(dst, 0) - buf.WriteString(src.Portal) - buf.WriteByte(0) + dst = pgio.AppendUint32(dst, src.MaxRows) - buf.Write(bigEndian.Uint32(src.MaxRows)) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) - - return buf.Bytes(), nil + return dst } func (src *Execute) MarshalJSON() ([]byte, error) { diff --git a/flush.go b/flush.go index d26f5c0c..7fd5e987 100644 --- a/flush.go +++ b/flush.go @@ -16,8 +16,8 @@ func (dst *Flush) Decode(src []byte) error { return nil } -func (src *Flush) MarshalBinary() ([]byte, error) { - return []byte{'H', 0, 0, 0, 4}, nil +func (src *Flush) Encode(dst []byte) []byte { + return append(dst, 'H', 0, 0, 0, 4) } func (src *Flush) MarshalJSON() ([]byte, error) { diff --git a/frontend.go b/frontend.go index 27a9890a..630a5cba 100644 --- a/frontend.go +++ b/frontend.go @@ -42,12 +42,7 @@ func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { } func (b *Frontend) Send(msg FrontendMessage) error { - buf, err := msg.MarshalBinary() - if err != nil { - return nil - } - - _, err = b.w.Write(buf) + _, err := b.w.Write(msg.Encode(nil)) return err } diff --git a/function_call_response.go b/function_call_response.go index 1e0f16af..bb325b69 100644 --- a/function_call_response.go +++ b/function_call_response.go @@ -1,10 +1,11 @@ package pgproto3 import ( - "bytes" "encoding/binary" "encoding/hex" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type FunctionCallResponse struct { @@ -34,21 +35,21 @@ func (dst *FunctionCallResponse) Decode(src []byte) error { return nil } -func (src *FunctionCallResponse) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('V') - buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Result)))) +func (src *FunctionCallResponse) Encode(dst []byte) []byte { + dst = append(dst, 'V') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) if src.Result == nil { - buf.Write(bigEndian.Int32(-1)) + dst = pgio.AppendInt32(dst, -1) } else { - buf.Write(bigEndian.Int32(int32(len(src.Result)))) - buf.Write(src.Result) + dst = pgio.AppendInt32(dst, int32(len(src.Result))) + dst = append(dst, src.Result...) } - return buf.Bytes(), nil + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst } func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) { diff --git a/no_data.go b/no_data.go index 3adec4ad..1fb47c2a 100644 --- a/no_data.go +++ b/no_data.go @@ -16,8 +16,8 @@ func (dst *NoData) Decode(src []byte) error { return nil } -func (src *NoData) MarshalBinary() ([]byte, error) { - return []byte{'n', 0, 0, 0, 4}, nil +func (src *NoData) Encode(dst []byte) []byte { + return append(dst, 'n', 0, 0, 0, 4) } func (src *NoData) MarshalJSON() ([]byte, error) { diff --git a/notice_response.go b/notice_response.go index 8af55baf..e4595aa5 100644 --- a/notice_response.go +++ b/notice_response.go @@ -8,6 +8,6 @@ func (dst *NoticeResponse) Decode(src []byte) error { return (*ErrorResponse)(dst).Decode(src) } -func (src *NoticeResponse) MarshalBinary() ([]byte, error) { - return (*ErrorResponse)(src).marshalBinary('N') +func (src *NoticeResponse) Encode(dst []byte) []byte { + return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) } diff --git a/notification_response.go b/notification_response.go index 7262844e..b14007b4 100644 --- a/notification_response.go +++ b/notification_response.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type NotificationResponse struct { @@ -35,19 +37,19 @@ func (dst *NotificationResponse) Decode(src []byte) error { return nil } -func (src *NotificationResponse) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *NotificationResponse) Encode(dst []byte) []byte { + dst = append(dst, 'A') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('A') - buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Channel) + len(src.Payload)))) + dst = append(dst, src.Channel...) + dst = append(dst, 0) + dst = append(dst, src.Payload...) + dst = append(dst, 0) - buf.WriteString(src.Channel) - buf.WriteByte(0) - buf.WriteString(src.Payload) - buf.WriteByte(0) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *NotificationResponse) MarshalJSON() ([]byte, error) { diff --git a/parameter_description.go b/parameter_description.go index 32b6e1c1..1fa3c927 100644 --- a/parameter_description.go +++ b/parameter_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type ParameterDescription struct { @@ -33,20 +35,19 @@ func (dst *ParameterDescription) Decode(src []byte) error { return nil } -func (src *ParameterDescription) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('t') - buf.Write(bigEndian.Uint32(uint32(4 + 2 + 4*len(src.ParameterOIDs)))) - - buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs)))) +func (src *ParameterDescription) Encode(dst []byte) []byte { + dst = append(dst, 't') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { - buf.Write(bigEndian.Uint32(oid)) + dst = pgio.AppendUint32(dst, oid) } - return buf.Bytes(), nil + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst } func (src *ParameterDescription) MarshalJSON() ([]byte, error) { diff --git a/parameter_status.go b/parameter_status.go index 9b10824c..b3bac33f 100644 --- a/parameter_status.go +++ b/parameter_status.go @@ -2,8 +2,9 @@ package pgproto3 import ( "bytes" - "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type ParameterStatus struct { @@ -32,21 +33,19 @@ func (dst *ParameterStatus) Decode(src []byte) error { return nil } -func (src *ParameterStatus) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *ParameterStatus) Encode(dst []byte) []byte { + dst = append(dst, 'S') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('S') - buf.Write(bigEndian.Uint32(0)) + dst = append(dst, src.Name...) + dst = append(dst, 0) + dst = append(dst, src.Value...) + dst = append(dst, 0) - buf.WriteString(src.Name) - buf.WriteByte(0) - buf.WriteString(src.Value) - buf.WriteByte(0) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) - - return buf.Bytes(), nil + return dst } func (ps *ParameterStatus) MarshalJSON() ([]byte, error) { diff --git a/parse.go b/parse.go index 5d17ed11..b8775547 100644 --- a/parse.go +++ b/parse.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type Parse struct { @@ -44,27 +46,24 @@ func (dst *Parse) Decode(src []byte) error { return nil } -func (src *Parse) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *Parse) Encode(dst []byte) []byte { + dst = append(dst, 'P') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('P') - buf.Write(bigEndian.Uint32(0)) + dst = append(dst, src.Name...) + dst = append(dst, 0) + dst = append(dst, src.Query...) + dst = append(dst, 0) - buf.WriteString(src.Name) - buf.WriteByte(0) - buf.WriteString(src.Query) - buf.WriteByte(0) - - buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs)))) - - for _, v := range src.ParameterOIDs { - buf.Write(bigEndian.Uint32(v)) + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) + for _, oid := range src.ParameterOIDs { + dst = pgio.AppendUint32(dst, oid) } - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *Parse) MarshalJSON() ([]byte, error) { diff --git a/parse_complete.go b/parse_complete.go index e949c14c..462a89ba 100644 --- a/parse_complete.go +++ b/parse_complete.go @@ -16,8 +16,8 @@ func (dst *ParseComplete) Decode(src []byte) error { return nil } -func (src *ParseComplete) MarshalBinary() ([]byte, error) { - return []byte{'1', 0, 0, 0, 4}, nil +func (src *ParseComplete) Encode(dst []byte) []byte { + return append(dst, '1', 0, 0, 0, 4) } func (src *ParseComplete) MarshalJSON() ([]byte, error) { diff --git a/password_message.go b/password_message.go index 69df6362..2ad3fe4a 100644 --- a/password_message.go +++ b/password_message.go @@ -3,6 +3,8 @@ package pgproto3 import ( "bytes" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type PasswordMessage struct { @@ -23,14 +25,14 @@ func (dst *PasswordMessage) Decode(src []byte) error { return nil } -func (src *PasswordMessage) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - buf.WriteByte('p') - buf.Write(bigEndian.Uint32(uint32(4 + len(src.Password) + 1))) - buf.WriteString(src.Password) - buf.WriteByte(0) - return buf.Bytes(), nil +func (src *PasswordMessage) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) + + dst = append(dst, src.Password...) + dst = append(dst, 0) + + return dst } func (src *PasswordMessage) MarshalJSON() ([]byte, error) { diff --git a/pgproto3.go b/pgproto3.go index 3fe8fc93..fe7b085b 100644 --- a/pgproto3.go +++ b/pgproto3.go @@ -4,12 +4,13 @@ import "fmt" // Message is the interface implemented by an object that can decode and encode // a particular PostgreSQL message. -// -// Decode is allowed and expected to retain a reference to data after -// returning (unlike encoding.BinaryUnmarshaler). type Message interface { + // Decode is allowed and expected to retain a reference to data after + // returning (unlike encoding.BinaryUnmarshaler). Decode(data []byte) error - MarshalBinary() (data []byte, err error) + + // Encode appends itself to dst and returns the new buffer. + Encode(dst []byte) []byte } type FrontendMessage interface { diff --git a/query.go b/query.go index b5fc2dbc..d80c0fb4 100644 --- a/query.go +++ b/query.go @@ -3,6 +3,8 @@ package pgproto3 import ( "bytes" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type Query struct { @@ -22,14 +24,14 @@ func (dst *Query) Decode(src []byte) error { return nil } -func (src *Query) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - buf.WriteByte('Q') - buf.Write(bigEndian.Uint32(uint32(4 + len(src.String) + 1))) - buf.WriteString(src.String) - buf.WriteByte(0) - return buf.Bytes(), nil +func (src *Query) Encode(dst []byte) []byte { + dst = append(dst, 'Q') + dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) + + dst = append(dst, src.String...) + dst = append(dst, 0) + + return dst } func (src *Query) MarshalJSON() ([]byte, error) { diff --git a/ready_for_query.go b/ready_for_query.go index e0e4707a..63b902bd 100644 --- a/ready_for_query.go +++ b/ready_for_query.go @@ -20,8 +20,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error { return nil } -func (src *ReadyForQuery) MarshalBinary() ([]byte, error) { - return []byte{'Z', 0, 0, 0, 5, src.TxStatus}, nil +func (src *ReadyForQuery) Encode(dst []byte) []byte { + return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) } func (src *ReadyForQuery) MarshalJSON() ([]byte, error) { diff --git a/row_description.go b/row_description.go index b1110290..d0df11b0 100644 --- a/row_description.go +++ b/row_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) const ( @@ -64,30 +66,27 @@ func (dst *RowDescription) Decode(src []byte) error { return nil } -func (src *RowDescription) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('T') - buf.Write(bigEndian.Uint32(0)) - - buf.Write(bigEndian.Uint16(uint16(len(src.Fields)))) +func (src *RowDescription) Encode(dst []byte) []byte { + dst = append(dst, 'T') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) for _, fd := range src.Fields { - buf.WriteString(fd.Name) - buf.WriteByte(0) + dst = append(dst, fd.Name...) + dst = append(dst, 0) - buf.Write(bigEndian.Uint32(fd.TableOID)) - buf.Write(bigEndian.Uint16(fd.TableAttributeNumber)) - buf.Write(bigEndian.Uint32(fd.DataTypeOID)) - buf.Write(bigEndian.Uint16(uint16(fd.DataTypeSize))) - buf.Write(bigEndian.Uint32(fd.TypeModifier)) - buf.Write(bigEndian.Uint16(uint16(fd.Format))) + dst = pgio.AppendUint32(dst, fd.TableOID) + dst = pgio.AppendUint16(dst, fd.TableAttributeNumber) + dst = pgio.AppendUint32(dst, fd.DataTypeOID) + dst = pgio.AppendInt16(dst, fd.DataTypeSize) + dst = pgio.AppendUint32(dst, fd.TypeModifier) + dst = pgio.AppendInt16(dst, fd.Format) } - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *RowDescription) MarshalJSON() ([]byte, error) { diff --git a/startup_message.go b/startup_message.go index 4847d629..4e2df27d 100644 --- a/startup_message.go +++ b/startup_message.go @@ -5,6 +5,8 @@ import ( "encoding/binary" "encoding/json" "fmt" + + "github.com/jackc/pgx/pgio" ) const ( @@ -64,22 +66,22 @@ func (dst *StartupMessage) Decode(src []byte) error { return nil } -func (src *StartupMessage) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - buf.Write(bigEndian.Uint32(0)) - buf.Write(bigEndian.Uint32(src.ProtocolVersion)) +func (src *StartupMessage) Encode(dst []byte) []byte { + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint32(dst, src.ProtocolVersion) for k, v := range src.Parameters { - buf.WriteString(k) - buf.WriteByte(0) - buf.WriteString(v) - buf.WriteByte(0) + dst = append(dst, k...) + dst = append(dst, 0) + dst = append(dst, v...) + dst = append(dst, 0) } - buf.WriteByte(0) + dst = append(dst, 0) - binary.BigEndian.PutUint32(buf.Bytes()[0:4], uint32(buf.Len())) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *StartupMessage) MarshalJSON() ([]byte, error) { diff --git a/sync.go b/sync.go index da3fa727..85f4749a 100644 --- a/sync.go +++ b/sync.go @@ -16,8 +16,8 @@ func (dst *Sync) Decode(src []byte) error { return nil } -func (src *Sync) MarshalBinary() ([]byte, error) { - return []byte{'S', 0, 0, 0, 4}, nil +func (src *Sync) Encode(dst []byte) []byte { + return append(dst, 'S', 0, 0, 0, 4) } func (src *Sync) MarshalJSON() ([]byte, error) { diff --git a/terminate.go b/terminate.go index 77977f20..0a3310da 100644 --- a/terminate.go +++ b/terminate.go @@ -16,8 +16,8 @@ func (dst *Terminate) Decode(src []byte) error { return nil } -func (src *Terminate) MarshalBinary() ([]byte, error) { - return []byte{'X', 0, 0, 0, 4}, nil +func (src *Terminate) Encode(dst []byte) []byte { + return append(dst, 'X', 0, 0, 0, 4) } func (src *Terminate) MarshalJSON() ([]byte, error) { From 8e404a02a32e9cc442175688667538ee09ef84ba Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 11:24:49 -0500 Subject: [PATCH 0101/1158] Ensure pgproto3.Parse.Decode overwrites itself entirely --- parse.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/parse.go b/parse.go index b8775547..ca4834c6 100644 --- a/parse.go +++ b/parse.go @@ -17,6 +17,8 @@ type Parse struct { func (*Parse) Frontend() {} func (dst *Parse) Decode(src []byte) error { + *dst = Parse{} + buf := bytes.NewBuffer(src) b, err := buf.ReadBytes(0) From 2140814606183d59c1ab4e8be57d2f750e85aa02 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2017 11:53:49 -0500 Subject: [PATCH 0102/1158] Use Go casing convention for OID --- array.go | 6 +-- bool_array.go | 2 +- bytea_array.go | 2 +- cidr_array.go | 2 +- date_array.go | 2 +- float4_array.go | 2 +- float8_array.go | 2 +- hstore_array.go | 2 +- inet_array.go | 2 +- int2_array.go | 2 +- int4_array.go | 2 +- int8_array.go | 2 +- numeric_array.go | 2 +- oid.go | 30 ++++++------- oid_value.go | 28 ++++++------ oid_value_test.go | 30 ++++++------- pgtype.go | 104 +++++++++++++++++++++---------------------- record.go | 6 +-- text_array.go | 2 +- timestamp_array.go | 2 +- timestamptz_array.go | 2 +- typed_array.go.erb | 2 +- varchar_array.go | 2 +- 23 files changed, 119 insertions(+), 119 deletions(-) diff --git a/array.go b/array.go index 2f9ef66b..e5504455 100644 --- a/array.go +++ b/array.go @@ -18,7 +18,7 @@ import ( type ArrayHeader struct { ContainsNull bool - ElementOid int32 + ElementOID int32 Dimensions []ArrayDimension } @@ -40,7 +40,7 @@ func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 rp += 4 - dst.ElementOid = int32(binary.BigEndian.Uint32(src[rp:])) + dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:])) rp += 4 if numDims > 0 { @@ -69,7 +69,7 @@ func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { } buf = pgio.AppendInt32(buf, containsNull) - buf = pgio.AppendInt32(buf, src.ElementOid) + buf = pgio.AppendInt32(buf, src.ElementOID) for i := range src.Dimensions { buf = pgio.AppendInt32(buf, src.Dimensions[i].Length) diff --git a/bool_array.go b/bool_array.go index 3c3d4184..e20a0381 100644 --- a/bool_array.go +++ b/bool_array.go @@ -231,7 +231,7 @@ func (src *BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("bool"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "bool") } diff --git a/bytea_array.go b/bytea_array.go index 67e114f5..0d381693 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -231,7 +231,7 @@ func (src *ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("bytea"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "bytea") } diff --git a/cidr_array.go b/cidr_array.go index 01237aa1..b8a70d63 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -260,7 +260,7 @@ func (src *CidrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("cidr"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "cidr") } diff --git a/date_array.go b/date_array.go index 2175f2aa..ef91cf3e 100644 --- a/date_array.go +++ b/date_array.go @@ -232,7 +232,7 @@ func (src *DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("date"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "date") } diff --git a/float4_array.go b/float4_array.go index 37db8acc..a35657b0 100644 --- a/float4_array.go +++ b/float4_array.go @@ -231,7 +231,7 @@ func (src *Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("float4"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "float4") } diff --git a/float8_array.go b/float8_array.go index dd3fccf1..486e3a4e 100644 --- a/float8_array.go +++ b/float8_array.go @@ -231,7 +231,7 @@ func (src *Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("float8"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "float8") } diff --git a/hstore_array.go b/hstore_array.go index 2d61fa52..3e5a003f 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -231,7 +231,7 @@ func (src *HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("hstore"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "hstore") } diff --git a/inet_array.go b/inet_array.go index e448a2ca..57123c1c 100644 --- a/inet_array.go +++ b/inet_array.go @@ -260,7 +260,7 @@ func (src *InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("inet"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "inet") } diff --git a/int2_array.go b/int2_array.go index 1d145584..e4993104 100644 --- a/int2_array.go +++ b/int2_array.go @@ -259,7 +259,7 @@ func (src *Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("int2"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "int2") } diff --git a/int4_array.go b/int4_array.go index 1c746503..6bc06e86 100644 --- a/int4_array.go +++ b/int4_array.go @@ -259,7 +259,7 @@ func (src *Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("int4"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "int4") } diff --git a/int8_array.go b/int8_array.go index 56ebcab8..4404d22a 100644 --- a/int8_array.go +++ b/int8_array.go @@ -259,7 +259,7 @@ func (src *Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("int8"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "int8") } diff --git a/numeric_array.go b/numeric_array.go index 20f33dff..f193a2a5 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -259,7 +259,7 @@ func (src *NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) } if dt, ok := ci.DataTypeForName("numeric"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "numeric") } diff --git a/oid.go b/oid.go index 6ceacc73..d37f4e57 100644 --- a/oid.go +++ b/oid.go @@ -9,18 +9,18 @@ import ( "github.com/jackc/pgx/pgio" ) -// Oid (Object Identifier Type) is, according to +// OID (Object Identifier Type) is, according to // https://www.postgresql.org/docs/current/static/datatype-oid.html, used // internally by PostgreSQL as a primary key for various system tables. It is // currently implemented as an unsigned four-byte integer. Its definition can be // found in src/include/postgres_ext.h in the PostgreSQL sources. Because it is -// so frequently required to be in a NOT NULL condition Oid cannot be NULL. To -// allow for NULL Oids use OidValue. -type Oid uint32 +// so frequently required to be in a NOT NULL condition OID cannot be NULL. To +// allow for NULL OIDs use OIDValue. +type OID uint32 -func (dst *Oid) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - return fmt.Errorf("cannot decode nil into Oid") + return fmt.Errorf("cannot decode nil into OID") } n, err := strconv.ParseUint(string(src), 10, 32) @@ -28,13 +28,13 @@ func (dst *Oid) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Oid(n) + *dst = OID(n) return nil } -func (dst *Oid) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *OID) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - return fmt.Errorf("cannot decode nil into Oid") + return fmt.Errorf("cannot decode nil into OID") } if len(src) != 4 { @@ -42,27 +42,27 @@ func (dst *Oid) DecodeBinary(ci *ConnInfo, src []byte) error { } n := binary.BigEndian.Uint32(src) - *dst = Oid(n) + *dst = OID(n) return nil } -func (src Oid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src OID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, strconv.FormatUint(uint64(src), 10)...), nil } -func (src Oid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src OID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return pgio.AppendUint32(buf, uint32(src)), nil } // Scan implements the database/sql Scanner interface. -func (dst *Oid) Scan(src interface{}) error { +func (dst *OID) Scan(src interface{}) error { if src == nil { return fmt.Errorf("cannot scan NULL into %T", src) } switch src := src.(type) { case int64: - *dst = Oid(src) + *dst = OID(src) return nil case string: return dst.DecodeText(nil, []byte(src)) @@ -76,6 +76,6 @@ func (dst *Oid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Oid) Value() (driver.Value, error) { +func (src OID) Value() (driver.Value, error) { return int64(src), nil } diff --git a/oid_value.go b/oid_value.go index 882d54fb..7eae4bf1 100644 --- a/oid_value.go +++ b/oid_value.go @@ -4,52 +4,52 @@ import ( "database/sql/driver" ) -// OidValue (Object Identifier Type) is, according to -// https://www.postgresql.org/docs/current/static/datatype-OidValue.html, used +// OIDValue (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-OIDValue.html, used // internally by PostgreSQL as a primary key for various system tables. It is // currently implemented as an unsigned four-byte integer. Its definition can be // found in src/include/postgres_ext.h in the PostgreSQL sources. -type OidValue pguint32 +type OIDValue pguint32 -// Set converts from src to dst. Note that as OidValue is not a general +// Set converts from src to dst. Note that as OIDValue is not a general // number type Set does not do automatic type conversion as other number // types do. -func (dst *OidValue) Set(src interface{}) error { +func (dst *OIDValue) Set(src interface{}) error { return (*pguint32)(dst).Set(src) } -func (dst *OidValue) Get() interface{} { +func (dst *OIDValue) Get() interface{} { return (*pguint32)(dst).Get() } -// AssignTo assigns from src to dst. Note that as OidValue is not a general number +// AssignTo assigns from src to dst. Note that as OIDValue is not a general number // type AssignTo does not do automatic type conversion as other number types do. -func (src *OidValue) AssignTo(dst interface{}) error { +func (src *OIDValue) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *OidValue) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *OIDValue) DecodeText(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *OidValue) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *OIDValue) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *OidValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *OIDValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (*pguint32)(src).EncodeText(ci, buf) } -func (src *OidValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *OIDValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return (*pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. -func (dst *OidValue) Scan(src interface{}) error { +func (dst *OIDValue) Scan(src interface{}) error { return (*pguint32)(dst).Scan(src) } // Value implements the database/sql/driver Valuer interface. -func (src *OidValue) Value() (driver.Value, error) { +func (src *OIDValue) Value() (driver.Value, error) { return (*pguint32)(src).Value() } diff --git a/oid_value_test.go b/oid_value_test.go index 52ce4064..f5ff16cf 100644 --- a/oid_value_test.go +++ b/oid_value_test.go @@ -8,23 +8,23 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestOidValueTranscode(t *testing.T) { +func TestOIDValueTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ - &pgtype.OidValue{Uint: 42, Status: pgtype.Present}, - &pgtype.OidValue{Status: pgtype.Null}, + &pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, + &pgtype.OIDValue{Status: pgtype.Null}, }) } -func TestOidValueSet(t *testing.T) { +func TestOIDValueSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.OidValue + result pgtype.OIDValue }{ - {source: uint32(1), result: pgtype.OidValue{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.OIDValue{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.OidValue + var r pgtype.OIDValue err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -36,17 +36,17 @@ func TestOidValueSet(t *testing.T) { } } -func TestOidValueAssignTo(t *testing.T) { +func TestOIDValueAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.OidValue + src pgtype.OIDValue dst interface{} expected interface{} }{ - {src: pgtype.OidValue{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.OidValue{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -61,11 +61,11 @@ func TestOidValueAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.OidValue + src pgtype.OIDValue dst interface{} expected interface{} }{ - {src: pgtype.OidValue{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -80,10 +80,10 @@ func TestOidValueAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.OidValue + src pgtype.OIDValue dst interface{} }{ - {src: pgtype.OidValue{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/pgtype.go b/pgtype.go index 847fce0f..4c1e86f6 100644 --- a/pgtype.go +++ b/pgtype.go @@ -7,47 +7,47 @@ import ( // PostgreSQL oids for common types const ( - BoolOid = 16 - ByteaOid = 17 - CharOid = 18 - NameOid = 19 - Int8Oid = 20 - Int2Oid = 21 - Int4Oid = 23 - TextOid = 25 - OidOid = 26 - TidOid = 27 - XidOid = 28 - CidOid = 29 - JsonOid = 114 - CidrOid = 650 - CidrArrayOid = 651 - Float4Oid = 700 - Float8Oid = 701 - UnknownOid = 705 - InetOid = 869 - BoolArrayOid = 1000 - Int2ArrayOid = 1005 - Int4ArrayOid = 1007 - TextArrayOid = 1009 - ByteaArrayOid = 1001 - VarcharArrayOid = 1015 - Int8ArrayOid = 1016 - Float4ArrayOid = 1021 - Float8ArrayOid = 1022 - AclitemOid = 1033 - AclitemArrayOid = 1034 - InetArrayOid = 1041 - VarcharOid = 1043 - DateOid = 1082 - TimestampOid = 1114 - TimestampArrayOid = 1115 - DateArrayOid = 1182 - TimestamptzOid = 1184 - TimestamptzArrayOid = 1185 - RecordOid = 2249 - UuidOid = 2950 - JsonbOid = 3802 + BoolOID = 16 + ByteaOID = 17 + CharOID = 18 + NameOID = 19 + Int8OID = 20 + Int2OID = 21 + Int4OID = 23 + TextOID = 25 + OIDOID = 26 + TidOID = 27 + XidOID = 28 + CidOID = 29 + JsonOID = 114 + CidrOID = 650 + CidrArrayOID = 651 + Float4OID = 700 + Float8OID = 701 + UnknownOID = 705 + InetOID = 869 + BoolArrayOID = 1000 + Int2ArrayOID = 1005 + Int4ArrayOID = 1007 + TextArrayOID = 1009 + ByteaArrayOID = 1001 + VarcharArrayOID = 1015 + Int8ArrayOID = 1016 + Float4ArrayOID = 1021 + Float8ArrayOID = 1022 + AclitemOID = 1033 + AclitemArrayOID = 1034 + InetArrayOID = 1041 + VarcharOID = 1043 + DateOID = 1082 + TimestampOID = 1114 + TimestampArrayOID = 1115 + DateArrayOID = 1182 + TimestamptzOID = 1184 + TimestamptzArrayOID = 1185 + RecordOID = 2249 + UuidOID = 2950 + JsonbOID = 3802 ) type Status byte @@ -133,42 +133,42 @@ var errBadStatus = errors.New("invalid status") type DataType struct { Value Value Name string - Oid Oid + OID OID } type ConnInfo struct { - oidToDataType map[Oid]*DataType + oidToDataType map[OID]*DataType nameToDataType map[string]*DataType reflectTypeToDataType map[reflect.Type]*DataType } func NewConnInfo() *ConnInfo { return &ConnInfo{ - oidToDataType: make(map[Oid]*DataType, 256), + oidToDataType: make(map[OID]*DataType, 256), nameToDataType: make(map[string]*DataType, 256), reflectTypeToDataType: make(map[reflect.Type]*DataType, 256), } } -func (ci *ConnInfo) InitializeDataTypes(nameOids map[string]Oid) { - for name, oid := range nameOids { +func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) { + for name, oid := range nameOIDs { var value Value if t, ok := nameValues[name]; ok { value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value) } else { value = &GenericText{} } - ci.RegisterDataType(DataType{Value: value, Name: name, Oid: oid}) + ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid}) } } func (ci *ConnInfo) RegisterDataType(t DataType) { - ci.oidToDataType[t.Oid] = &t + ci.oidToDataType[t.OID] = &t ci.nameToDataType[t.Name] = &t ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t } -func (ci *ConnInfo) DataTypeForOid(oid Oid) (*DataType, bool) { +func (ci *ConnInfo) DataTypeForOID(oid OID) (*DataType, bool) { dt, ok := ci.oidToDataType[oid] return dt, ok } @@ -186,7 +186,7 @@ func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { // DeepCopy makes a deep copy of the ConnInfo. func (ci *ConnInfo) DeepCopy() *ConnInfo { ci2 := &ConnInfo{ - oidToDataType: make(map[Oid]*DataType, len(ci.oidToDataType)), + oidToDataType: make(map[OID]*DataType, len(ci.oidToDataType)), nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)), reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)), } @@ -195,7 +195,7 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { ci2.RegisterDataType(DataType{ Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value), Name: dt.Name, - Oid: dt.Oid, + OID: dt.OID, }) } @@ -250,7 +250,7 @@ func init() { "name": &Name{}, "numeric": &Numeric{}, "numrange": &Numrange{}, - "oid": &OidValue{}, + "oid": &OIDValue{}, "path": &Path{}, "point": &Point{}, "polygon": &Polygon{}, diff --git a/record.go b/record.go index 3b315d40..7c8736df 100644 --- a/record.go +++ b/record.go @@ -88,16 +88,16 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { if len(src[rp:]) < 8 { return fmt.Errorf("Record incomplete %v", src) } - fieldOid := Oid(binary.BigEndian.Uint32(src[rp:])) + fieldOID := OID(binary.BigEndian.Uint32(src[rp:])) rp += 4 fieldLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 var binaryDecoder BinaryDecoder - if dt, ok := ci.DataTypeForOid(fieldOid); ok { + if dt, ok := ci.DataTypeForOID(fieldOID); ok { if binaryDecoder, ok = dt.Value.(BinaryDecoder); !ok { - return fmt.Errorf("unknown oid while decoding record: %v", fieldOid) + return fmt.Errorf("unknown oid while decoding record: %v", fieldOID) } } diff --git a/text_array.go b/text_array.go index ed240e12..dab7d36e 100644 --- a/text_array.go +++ b/text_array.go @@ -231,7 +231,7 @@ func (src *TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("text"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "text") } diff --git a/timestamp_array.go b/timestamp_array.go index a4f1b9dd..fca9ad93 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -232,7 +232,7 @@ func (src *TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error } if dt, ok := ci.DataTypeForName("timestamp"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "timestamp") } diff --git a/timestamptz_array.go b/timestamptz_array.go index 34d4f8a8..e0866d69 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -232,7 +232,7 @@ func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err } if dt, ok := ci.DataTypeForName("timestamptz"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "timestamptz") } diff --git a/typed_array.go.erb b/typed_array.go.erb index 0d454ac8..01072549 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -234,7 +234,7 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byt } if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") } diff --git a/varchar_array.go b/varchar_array.go index c34ac0b6..95b5cfc1 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -231,7 +231,7 @@ func (src *VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) } if dt, ok := ci.DataTypeForName("varchar"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "varchar") } From 496c5a4dff03b8d8277605a2198bf14550c8b180 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2017 11:54:57 -0500 Subject: [PATCH 0103/1158] Use Go casing convention for UUID --- ext/satori-uuid/uuid.go | 52 +++++++++++++++---------------- ext/satori-uuid/uuid_test.go | 26 ++++++++-------- pgtype.go | 4 +-- uuid.go | 60 ++++++++++++++++++------------------ uuid_test.go | 26 ++++++++-------- 5 files changed, 84 insertions(+), 84 deletions(-) diff --git a/ext/satori-uuid/uuid.go b/ext/satori-uuid/uuid.go index cff98348..b7b776f9 100644 --- a/ext/satori-uuid/uuid.go +++ b/ext/satori-uuid/uuid.go @@ -11,43 +11,43 @@ import ( var errUndefined = errors.New("cannot encode status undefined") -type Uuid struct { +type UUID struct { UUID uuid.UUID Status pgtype.Status } -func (dst *Uuid) Set(src interface{}) error { +func (dst *UUID) Set(src interface{}) error { switch value := src.(type) { case uuid.UUID: - *dst = Uuid{UUID: value, Status: pgtype.Present} + *dst = UUID{UUID: value, Status: pgtype.Present} case [16]byte: - *dst = Uuid{UUID: uuid.UUID(value), Status: pgtype.Present} + *dst = UUID{UUID: uuid.UUID(value), Status: pgtype.Present} case []byte: if len(value) != 16 { - return fmt.Errorf("[]byte must be 16 bytes to convert to Uuid: %d", len(value)) + return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) } - *dst = Uuid{Status: pgtype.Present} + *dst = UUID{Status: pgtype.Present} copy(dst.UUID[:], value) case string: uuid, err := uuid.FromString(value) if err != nil { return err } - *dst = Uuid{UUID: uuid, Status: pgtype.Present} + *dst = UUID{UUID: uuid, Status: pgtype.Present} default: - // If all else fails see if pgtype.Uuid can handle it. If so, translate through that. - pgUuid := &pgtype.Uuid{} - if err := pgUuid.Set(value); err != nil { - return fmt.Errorf("cannot convert %v to Uuid", value) + // If all else fails see if pgtype.UUID can handle it. If so, translate through that. + pgUUID := &pgtype.UUID{} + if err := pgUUID.Set(value); err != nil { + return fmt.Errorf("cannot convert %v to UUID", value) } - *dst = Uuid{UUID: uuid.UUID(pgUuid.Bytes), Status: pgUuid.Status} + *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Status: pgUUID.Status} } return nil } -func (dst *Uuid) Get() interface{} { +func (dst *UUID) Get() interface{} { switch dst.Status { case pgtype.Present: return dst.UUID @@ -58,7 +58,7 @@ func (dst *Uuid) Get() interface{} { } } -func (src *Uuid) AssignTo(dst interface{}) error { +func (src *UUID) AssignTo(dst interface{}) error { switch src.Status { case pgtype.Present: switch v := dst.(type) { @@ -86,9 +86,9 @@ func (src *Uuid) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v into %T", src, dst) } -func (dst *Uuid) DecodeText(ci *pgtype.ConnInfo, src []byte) error { +func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { if src == nil { - *dst = Uuid{Status: pgtype.Null} + *dst = UUID{Status: pgtype.Null} return nil } @@ -97,26 +97,26 @@ func (dst *Uuid) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return err } - *dst = Uuid{UUID: u, Status: pgtype.Present} + *dst = UUID{UUID: u, Status: pgtype.Present} return nil } -func (dst *Uuid) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { +func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { if src == nil { - *dst = Uuid{Status: pgtype.Null} + *dst = UUID{Status: pgtype.Null} return nil } if len(src) != 16 { - return fmt.Errorf("invalid length for Uuid: %v", len(src)) + return fmt.Errorf("invalid length for UUID: %v", len(src)) } - *dst = Uuid{Status: pgtype.Present} + *dst = UUID{Status: pgtype.Present} copy(dst.UUID[:], src) return nil } -func (src *Uuid) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { +func (src *UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: return nil, nil @@ -127,7 +127,7 @@ func (src *Uuid) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { return append(buf, src.UUID.String()...), nil } -func (src *Uuid) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { +func (src *UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: return nil, nil @@ -139,9 +139,9 @@ func (src *Uuid) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Uuid) Scan(src interface{}) error { +func (dst *UUID) Scan(src interface{}) error { if src == nil { - *dst = Uuid{Status: pgtype.Null} + *dst = UUID{Status: pgtype.Null} return nil } @@ -156,6 +156,6 @@ func (dst *Uuid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Uuid) Value() (driver.Value, error) { +func (src *UUID) Value() (driver.Value, error) { return pgtype.EncodeValueText(src) } diff --git a/ext/satori-uuid/uuid_test.go b/ext/satori-uuid/uuid_test.go index 993fb837..02ebb770 100644 --- a/ext/satori-uuid/uuid_test.go +++ b/ext/satori-uuid/uuid_test.go @@ -9,34 +9,34 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestUuidTranscode(t *testing.T) { +func TestUUIDTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - &satori.Uuid{Status: pgtype.Null}, + &satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &satori.UUID{Status: pgtype.Null}, }) } -func TestUuidSet(t *testing.T) { +func TestUUIDSet(t *testing.T) { successfulTests := []struct { source interface{} - result satori.Uuid + result satori.UUID }{ { source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, { source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, } for i, tt := range successfulTests { - var r satori.Uuid + var r satori.UUID err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -48,9 +48,9 @@ func TestUuidSet(t *testing.T) { } } -func TestUuidAssignTo(t *testing.T) { +func TestUUIDAssignTo(t *testing.T) { { - src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst [16]byte expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -65,7 +65,7 @@ func TestUuidAssignTo(t *testing.T) { } { - src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst []byte expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -80,7 +80,7 @@ func TestUuidAssignTo(t *testing.T) { } { - src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst string expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" diff --git a/pgtype.go b/pgtype.go index 4c1e86f6..60fab59f 100644 --- a/pgtype.go +++ b/pgtype.go @@ -46,7 +46,7 @@ const ( TimestamptzOID = 1184 TimestamptzArrayOID = 1185 RecordOID = 2249 - UuidOID = 2950 + UUIDOID = 2950 JsonbOID = 3802 ) @@ -262,7 +262,7 @@ func init() { "tsrange": &Tsrange{}, "tstzrange": &Tstzrange{}, "unknown": &Unknown{}, - "uuid": &Uuid{}, + "uuid": &UUID{}, "varbit": &Varbit{}, "varchar": &Varchar{}, "xid": &Xid{}, diff --git a/uuid.go b/uuid.go index c73c501e..d1ab1a38 100644 --- a/uuid.go +++ b/uuid.go @@ -6,38 +6,38 @@ import ( "fmt" ) -type Uuid struct { +type UUID struct { Bytes [16]byte Status Status } -func (dst *Uuid) Set(src interface{}) error { +func (dst *UUID) Set(src interface{}) error { switch value := src.(type) { case [16]byte: - *dst = Uuid{Bytes: value, Status: Present} + *dst = UUID{Bytes: value, Status: Present} case []byte: if len(value) != 16 { - return fmt.Errorf("[]byte must be 16 bytes to convert to Uuid: %d", len(value)) + return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) } - *dst = Uuid{Status: Present} + *dst = UUID{Status: Present} copy(dst.Bytes[:], value) case string: - uuid, err := parseUuid(value) + uuid, err := parseUUID(value) if err != nil { return err } - *dst = Uuid{Bytes: uuid, Status: Present} + *dst = UUID{Bytes: uuid, Status: Present} default: if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Uuid", value) + return fmt.Errorf("cannot convert %v to UUID", value) } return nil } -func (dst *Uuid) Get() interface{} { +func (dst *UUID) Get() interface{} { switch dst.Status { case Present: return dst.Bytes @@ -48,7 +48,7 @@ func (dst *Uuid) Get() interface{} { } } -func (src *Uuid) AssignTo(dst interface{}) error { +func (src *UUID) AssignTo(dst interface{}) error { switch src.Status { case Present: switch v := dst.(type) { @@ -60,7 +60,7 @@ func (src *Uuid) AssignTo(dst interface{}) error { copy(*v, src.Bytes[:]) return nil case *string: - *v = encodeUuid(src.Bytes) + *v = encodeUUID(src.Bytes) return nil default: if nextDst, retry := GetAssignToDstType(v); retry { @@ -74,8 +74,8 @@ func (src *Uuid) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v into %T", src, dst) } -// parseUuid converts a string UUID in standard form to a byte array. -func parseUuid(src string) (dst [16]byte, err error) { +// parseUUID converts a string UUID in standard form to a byte array. +func parseUUID(src string) (dst [16]byte, err error) { src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] buf, err := hex.DecodeString(src) if err != nil { @@ -86,46 +86,46 @@ func parseUuid(src string) (dst [16]byte, err error) { return dst, err } -// encodeUuid converts a uuid byte array to UUID standard string form. -func encodeUuid(src [16]byte) string { +// encodeUUID converts a uuid byte array to UUID standard string form. +func encodeUUID(src [16]byte) string { return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16]) } -func (dst *Uuid) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Uuid{Status: Null} + *dst = UUID{Status: Null} return nil } if len(src) != 36 { - return fmt.Errorf("invalid length for Uuid: %v", len(src)) + return fmt.Errorf("invalid length for UUID: %v", len(src)) } - buf, err := parseUuid(string(src)) + buf, err := parseUUID(string(src)) if err != nil { return err } - *dst = Uuid{Bytes: buf, Status: Present} + *dst = UUID{Bytes: buf, Status: Present} return nil } -func (dst *Uuid) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Uuid{Status: Null} + *dst = UUID{Status: Null} return nil } if len(src) != 16 { - return fmt.Errorf("invalid length for Uuid: %v", len(src)) + return fmt.Errorf("invalid length for UUID: %v", len(src)) } - *dst = Uuid{Status: Present} + *dst = UUID{Status: Present} copy(dst.Bytes[:], src) return nil } -func (src *Uuid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -133,10 +133,10 @@ func (src *Uuid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } - return append(buf, encodeUuid(src.Bytes)...), nil + return append(buf, encodeUUID(src.Bytes)...), nil } -func (src *Uuid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -148,9 +148,9 @@ func (src *Uuid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Uuid) Scan(src interface{}) error { +func (dst *UUID) Scan(src interface{}) error { if src == nil { - *dst = Uuid{Status: Null} + *dst = UUID{Status: Null} return nil } @@ -167,6 +167,6 @@ func (dst *Uuid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Uuid) Value() (driver.Value, error) { +func (src *UUID) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/uuid_test.go b/uuid_test.go index 4c6ad2cd..5ab52b35 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -8,34 +8,34 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestUuidTranscode(t *testing.T) { +func TestUUIDTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - &pgtype.Uuid{Status: pgtype.Null}, + &pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &pgtype.UUID{Status: pgtype.Null}, }) } -func TestUuidSet(t *testing.T) { +func TestUUIDSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Uuid + result pgtype.UUID }{ { source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, { source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, } for i, tt := range successfulTests { - var r pgtype.Uuid + var r pgtype.UUID err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -47,9 +47,9 @@ func TestUuidSet(t *testing.T) { } } -func TestUuidAssignTo(t *testing.T) { +func TestUUIDAssignTo(t *testing.T) { { - src := pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst [16]byte expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -64,7 +64,7 @@ func TestUuidAssignTo(t *testing.T) { } { - src := pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst []byte expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -79,7 +79,7 @@ func TestUuidAssignTo(t *testing.T) { } { - src := pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst string expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" From aab8b77215e38bb52b163866354f5c582b668db8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2017 11:57:14 -0500 Subject: [PATCH 0104/1158] Use Go casing convention for JSON(B) --- json.go | 40 +++++++++++++++++++-------------------- json_test.go | 52 +++++++++++++++++++++++++-------------------------- jsonb.go | 38 ++++++++++++++++++------------------- jsonb_test.go | 52 +++++++++++++++++++++++++-------------------------- pgtype.go | 8 ++++---- 5 files changed, 95 insertions(+), 95 deletions(-) diff --git a/json.go b/json.go index 91d31129..ee00e9a4 100644 --- a/json.go +++ b/json.go @@ -6,44 +6,44 @@ import ( "fmt" ) -type Json struct { +type JSON struct { Bytes []byte Status Status } -func (dst *Json) Set(src interface{}) error { +func (dst *JSON) Set(src interface{}) error { if src == nil { - *dst = Json{Status: Null} + *dst = JSON{Status: Null} return nil } switch value := src.(type) { case string: - *dst = Json{Bytes: []byte(value), Status: Present} + *dst = JSON{Bytes: []byte(value), Status: Present} case *string: if value == nil { - *dst = Json{Status: Null} + *dst = JSON{Status: Null} } else { - *dst = Json{Bytes: []byte(*value), Status: Present} + *dst = JSON{Bytes: []byte(*value), Status: Present} } case []byte: if value == nil { - *dst = Json{Status: Null} + *dst = JSON{Status: Null} } else { - *dst = Json{Bytes: value, Status: Present} + *dst = JSON{Bytes: value, Status: Present} } default: buf, err := json.Marshal(value) if err != nil { return err } - *dst = Json{Bytes: buf, Status: Present} + *dst = JSON{Bytes: buf, Status: Present} } return nil } -func (dst *Json) Get() interface{} { +func (dst *JSON) Get() interface{} { switch dst.Status { case Present: var i interface{} @@ -59,7 +59,7 @@ func (dst *Json) Get() interface{} { } } -func (src *Json) AssignTo(dst interface{}) error { +func (src *JSON) AssignTo(dst interface{}) error { switch v := dst.(type) { case *string: if src.Status != Present { @@ -90,21 +90,21 @@ func (src *Json) AssignTo(dst interface{}) error { return nil } -func (dst *Json) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *JSON) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Json{Status: Null} + *dst = JSON{Status: Null} return nil } - *dst = Json{Bytes: src, Status: Present} + *dst = JSON{Bytes: src, Status: Present} return nil } -func (dst *Json) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *JSON) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (src *Json) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -115,14 +115,14 @@ func (src *Json) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, src.Bytes...), nil } -func (src *Json) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *JSON) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return src.EncodeText(ci, buf) } // Scan implements the database/sql Scanner interface. -func (dst *Json) Scan(src interface{}) error { +func (dst *JSON) Scan(src interface{}) error { if src == nil { - *dst = Json{Status: Null} + *dst = JSON{Status: Null} return nil } @@ -139,7 +139,7 @@ func (dst *Json) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Json) Value() (driver.Value, error) { +func (src *JSON) Value() (driver.Value, error) { switch src.Status { case Present: return string(src.Bytes), nil diff --git a/json_test.go b/json_test.go index 3d8d2a68..82c02539 100644 --- a/json_test.go +++ b/json_test.go @@ -9,31 +9,31 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestJsonTranscode(t *testing.T) { +func TestJSONTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "json", []interface{}{ - &pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, - &pgtype.Json{Bytes: []byte("null"), Status: pgtype.Present}, - &pgtype.Json{Bytes: []byte("42"), Status: pgtype.Present}, - &pgtype.Json{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - &pgtype.Json{Status: pgtype.Null}, + &pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, + &pgtype.JSON{Bytes: []byte("null"), Status: pgtype.Present}, + &pgtype.JSON{Bytes: []byte("42"), Status: pgtype.Present}, + &pgtype.JSON{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + &pgtype.JSON{Status: pgtype.Null}, }) } -func TestJsonSet(t *testing.T) { +func TestJSONSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Json + result pgtype.JSON }{ - {source: "{}", result: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: []byte("{}"), result: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: ([]byte)(nil), result: pgtype.Json{Status: pgtype.Null}}, - {source: (*string)(nil), result: pgtype.Json{Status: pgtype.Null}}, - {source: []int{1, 2, 3}, result: pgtype.Json{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.Json{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + {source: "{}", result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: []byte("{}"), result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: ([]byte)(nil), result: pgtype.JSON{Status: pgtype.Null}}, + {source: (*string)(nil), result: pgtype.JSON{Status: pgtype.Null}}, + {source: []int{1, 2, 3}, result: pgtype.JSON{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, } for i, tt := range successfulTests { - var d pgtype.Json + var d pgtype.JSON err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -45,17 +45,17 @@ func TestJsonSet(t *testing.T) { } } -func TestJsonAssignTo(t *testing.T) { +func TestJSONAssignTo(t *testing.T) { var s string var ps *string var b []byte rawStringTests := []struct { - src pgtype.Json + src pgtype.JSON dst *string expected string }{ - {src: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, } for i, tt := range rawStringTests { @@ -70,12 +70,12 @@ func TestJsonAssignTo(t *testing.T) { } rawBytesTests := []struct { - src pgtype.Json + src pgtype.JSON dst *[]byte expected []byte }{ - {src: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, - {src: pgtype.Json{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, + {src: pgtype.JSON{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, } for i, tt := range rawBytesTests { @@ -97,12 +97,12 @@ func TestJsonAssignTo(t *testing.T) { var strDst structDst unmarshalTests := []struct { - src pgtype.Json + src pgtype.JSON dst interface{} expected interface{} }{ - {src: pgtype.Json{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.Json{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + {src: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.JSON{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, } for i, tt := range unmarshalTests { err := tt.src.AssignTo(tt.dst) @@ -116,11 +116,11 @@ func TestJsonAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.Json + src pgtype.JSON dst **string expected *string }{ - {src: pgtype.Json{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + {src: pgtype.JSON{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, } for i, tt := range pointerAllocTests { diff --git a/jsonb.go b/jsonb.go index f7914202..9a06c1b4 100644 --- a/jsonb.go +++ b/jsonb.go @@ -5,27 +5,27 @@ import ( "fmt" ) -type Jsonb Json +type JSONB JSON -func (dst *Jsonb) Set(src interface{}) error { - return (*Json)(dst).Set(src) +func (dst *JSONB) Set(src interface{}) error { + return (*JSON)(dst).Set(src) } -func (dst *Jsonb) Get() interface{} { - return (*Json)(dst).Get() +func (dst *JSONB) Get() interface{} { + return (*JSON)(dst).Get() } -func (src *Jsonb) AssignTo(dst interface{}) error { - return (*Json)(src).AssignTo(dst) +func (src *JSONB) AssignTo(dst interface{}) error { + return (*JSON)(src).AssignTo(dst) } -func (dst *Jsonb) DecodeText(ci *ConnInfo, src []byte) error { - return (*Json)(dst).DecodeText(ci, src) +func (dst *JSONB) DecodeText(ci *ConnInfo, src []byte) error { + return (*JSON)(dst).DecodeText(ci, src) } -func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Jsonb{Status: Null} + *dst = JSONB{Status: Null} return nil } @@ -37,16 +37,16 @@ func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { return fmt.Errorf("unknown jsonb version number %d", src[0]) } - *dst = Jsonb{Bytes: src[1:], Status: Present} + *dst = JSONB{Bytes: src[1:], Status: Present} return nil } -func (src *Jsonb) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Json)(src).EncodeText(ci, buf) +func (src *JSONB) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*JSON)(src).EncodeText(ci, buf) } -func (src *Jsonb) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *JSONB) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -59,11 +59,11 @@ func (src *Jsonb) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Jsonb) Scan(src interface{}) error { - return (*Json)(dst).Scan(src) +func (dst *JSONB) Scan(src interface{}) error { + return (*JSON)(dst).Scan(src) } // Value implements the database/sql/driver Valuer interface. -func (src *Jsonb) Value() (driver.Value, error) { - return (*Json)(src).Value() +func (src *JSONB) Value() (driver.Value, error) { + return (*JSON)(src).Value() } diff --git a/jsonb_test.go b/jsonb_test.go index 86c8a12c..1a9a3056 100644 --- a/jsonb_test.go +++ b/jsonb_test.go @@ -9,7 +9,7 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestJsonbTranscode(t *testing.T) { +func TestJSONBTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustClose(t, conn) if _, ok := conn.ConnInfo.DataTypeForName("jsonb"); !ok { @@ -17,29 +17,29 @@ func TestJsonbTranscode(t *testing.T) { } testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ - &pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, - &pgtype.Jsonb{Bytes: []byte("null"), Status: pgtype.Present}, - &pgtype.Jsonb{Bytes: []byte("42"), Status: pgtype.Present}, - &pgtype.Jsonb{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - &pgtype.Jsonb{Status: pgtype.Null}, + &pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, + &pgtype.JSONB{Bytes: []byte("null"), Status: pgtype.Present}, + &pgtype.JSONB{Bytes: []byte("42"), Status: pgtype.Present}, + &pgtype.JSONB{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + &pgtype.JSONB{Status: pgtype.Null}, }) } -func TestJsonbSet(t *testing.T) { +func TestJSONBSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Jsonb + result pgtype.JSONB }{ - {source: "{}", result: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: []byte("{}"), result: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: ([]byte)(nil), result: pgtype.Jsonb{Status: pgtype.Null}}, - {source: (*string)(nil), result: pgtype.Jsonb{Status: pgtype.Null}}, - {source: []int{1, 2, 3}, result: pgtype.Jsonb{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.Jsonb{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + {source: "{}", result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: []byte("{}"), result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: ([]byte)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, + {source: (*string)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, + {source: []int{1, 2, 3}, result: pgtype.JSONB{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, } for i, tt := range successfulTests { - var d pgtype.Jsonb + var d pgtype.JSONB err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -51,17 +51,17 @@ func TestJsonbSet(t *testing.T) { } } -func TestJsonbAssignTo(t *testing.T) { +func TestJSONBAssignTo(t *testing.T) { var s string var ps *string var b []byte rawStringTests := []struct { - src pgtype.Jsonb + src pgtype.JSONB dst *string expected string }{ - {src: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, } for i, tt := range rawStringTests { @@ -76,12 +76,12 @@ func TestJsonbAssignTo(t *testing.T) { } rawBytesTests := []struct { - src pgtype.Jsonb + src pgtype.JSONB dst *[]byte expected []byte }{ - {src: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, - {src: pgtype.Jsonb{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, + {src: pgtype.JSONB{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, } for i, tt := range rawBytesTests { @@ -103,12 +103,12 @@ func TestJsonbAssignTo(t *testing.T) { var strDst structDst unmarshalTests := []struct { - src pgtype.Jsonb + src pgtype.JSONB dst interface{} expected interface{} }{ - {src: pgtype.Jsonb{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.Jsonb{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + {src: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.JSONB{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, } for i, tt := range unmarshalTests { err := tt.src.AssignTo(tt.dst) @@ -122,11 +122,11 @@ func TestJsonbAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.Jsonb + src pgtype.JSONB dst **string expected *string }{ - {src: pgtype.Jsonb{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + {src: pgtype.JSONB{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, } for i, tt := range pointerAllocTests { diff --git a/pgtype.go b/pgtype.go index 60fab59f..2bfc9527 100644 --- a/pgtype.go +++ b/pgtype.go @@ -19,7 +19,7 @@ const ( TidOID = 27 XidOID = 28 CidOID = 29 - JsonOID = 114 + JSONOID = 114 CidrOID = 650 CidrArrayOID = 651 Float4OID = 700 @@ -47,7 +47,7 @@ const ( TimestamptzArrayOID = 1185 RecordOID = 2249 UUIDOID = 2950 - JsonbOID = 3802 + JSONBOID = 3802 ) type Status byte @@ -242,8 +242,8 @@ func init() { "int4range": &Int4range{}, "int8": &Int8{}, "int8range": &Int8range{}, - "json": &Json{}, - "jsonb": &Jsonb{}, + "json": &JSON{}, + "jsonb": &JSONB{}, "line": &Line{}, "lseg": &Lseg{}, "macaddr": &Macaddr{}, From 01fa5960b2a537b1759ad5d6425fbf4dbf14ea54 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2017 11:58:40 -0500 Subject: [PATCH 0105/1158] Use Go casing convention for ACLItem --- aclitem.go | 32 +++++++++---------- aclitem_array.go | 38 +++++++++++----------- aclitem_array_test.go | 74 +++++++++++++++++++++---------------------- aclitem_test.go | 34 ++++++++++---------- pgtype.go | 8 ++--- typed_array_gen.sh | 2 +- 6 files changed, 94 insertions(+), 94 deletions(-) diff --git a/aclitem.go b/aclitem.go index 27dc15d1..829eb908 100644 --- a/aclitem.go +++ b/aclitem.go @@ -5,7 +5,7 @@ import ( "fmt" ) -// Aclitem is used for PostgreSQL's aclitem data type. A sample aclitem +// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem // might look like this: // // postgres=arwdDxt/postgres @@ -17,32 +17,32 @@ import ( // // postgres=arwdDxt/"role with spaces" // -type Aclitem struct { +type ACLItem struct { String string Status Status } -func (dst *Aclitem) Set(src interface{}) error { +func (dst *ACLItem) Set(src interface{}) error { switch value := src.(type) { case string: - *dst = Aclitem{String: value, Status: Present} + *dst = ACLItem{String: value, Status: Present} case *string: if value == nil { - *dst = Aclitem{Status: Null} + *dst = ACLItem{Status: Null} } else { - *dst = Aclitem{String: *value, Status: Present} + *dst = ACLItem{String: *value, Status: Present} } default: if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Aclitem", value) + return fmt.Errorf("cannot convert %v to ACLItem", value) } return nil } -func (dst *Aclitem) Get() interface{} { +func (dst *ACLItem) Get() interface{} { switch dst.Status { case Present: return dst.String @@ -53,7 +53,7 @@ func (dst *Aclitem) Get() interface{} { } } -func (src *Aclitem) AssignTo(dst interface{}) error { +func (src *ACLItem) AssignTo(dst interface{}) error { switch src.Status { case Present: switch v := dst.(type) { @@ -72,17 +72,17 @@ func (src *Aclitem) AssignTo(dst interface{}) error { return fmt.Errorf("cannot decode %v into %T", src, dst) } -func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Aclitem{Status: Null} + *dst = ACLItem{Status: Null} return nil } - *dst = Aclitem{String: string(src), Status: Present} + *dst = ACLItem{String: string(src), Status: Present} return nil } -func (src *Aclitem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -94,9 +94,9 @@ func (src *Aclitem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Aclitem) Scan(src interface{}) error { +func (dst *ACLItem) Scan(src interface{}) error { if src == nil { - *dst = Aclitem{Status: Null} + *dst = ACLItem{Status: Null} return nil } @@ -113,7 +113,7 @@ func (dst *Aclitem) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Aclitem) Value() (driver.Value, error) { +func (src *ACLItem) Value() (driver.Value, error) { switch src.Status { case Present: return src.String, nil diff --git a/aclitem_array.go b/aclitem_array.go index 7df0b503..f9215a93 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -5,28 +5,28 @@ import ( "fmt" ) -type AclitemArray struct { - Elements []Aclitem +type ACLItemArray struct { + Elements []ACLItem Dimensions []ArrayDimension Status Status } -func (dst *AclitemArray) Set(src interface{}) error { +func (dst *ACLItemArray) Set(src interface{}) error { switch value := src.(type) { case []string: if value == nil { - *dst = AclitemArray{Status: Null} + *dst = ACLItemArray{Status: Null} } else if len(value) == 0 { - *dst = AclitemArray{Status: Present} + *dst = ACLItemArray{Status: Present} } else { - elements := make([]Aclitem, len(value)) + elements := make([]ACLItem, len(value)) for i := range value { if err := elements[i].Set(value[i]); err != nil { return err } } - *dst = AclitemArray{ + *dst = ACLItemArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, Status: Present, @@ -37,13 +37,13 @@ func (dst *AclitemArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Aclitem", value) + return fmt.Errorf("cannot convert %v to ACLItem", value) } return nil } -func (dst *AclitemArray) Get() interface{} { +func (dst *ACLItemArray) Get() interface{} { switch dst.Status { case Present: return dst @@ -54,7 +54,7 @@ func (dst *AclitemArray) Get() interface{} { } } -func (src *AclitemArray) AssignTo(dst interface{}) error { +func (src *ACLItemArray) AssignTo(dst interface{}) error { switch src.Status { case Present: switch v := dst.(type) { @@ -80,9 +80,9 @@ func (src *AclitemArray) AssignTo(dst interface{}) error { return fmt.Errorf("cannot decode %v into %T", src, dst) } -func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = AclitemArray{Status: Null} + *dst = ACLItemArray{Status: Null} return nil } @@ -91,13 +91,13 @@ func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { return err } - var elements []Aclitem + var elements []ACLItem if len(uta.Elements) > 0 { - elements = make([]Aclitem, len(uta.Elements)) + elements = make([]ACLItem, len(uta.Elements)) for i, s := range uta.Elements { - var elem Aclitem + var elem ACLItem var elemSrc []byte if s != "NULL" { elemSrc = []byte(s) @@ -111,12 +111,12 @@ func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = AclitemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} return nil } -func (src *AclitemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -174,7 +174,7 @@ func (src *AclitemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *AclitemArray) Scan(src interface{}) error { +func (dst *ACLItemArray) Scan(src interface{}) error { if src == nil { return dst.DecodeText(nil, nil) } @@ -192,7 +192,7 @@ func (dst *AclitemArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *AclitemArray) Value() (driver.Value, error) { +func (src *ACLItemArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/aclitem_array_test.go b/aclitem_array_test.go index 951e7847..c01eaa13 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -8,40 +8,40 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestAclitemArrayTranscode(t *testing.T) { +func TestACLItemArrayTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "aclitem[]", []interface{}{ - &pgtype.AclitemArray{ + &pgtype.ACLItemArray{ Elements: nil, Dimensions: nil, Status: pgtype.Present, }, - &pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{ - pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.Aclitem{Status: pgtype.Null}, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, - &pgtype.AclitemArray{Status: pgtype.Null}, - &pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{ - pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.Aclitem{Status: pgtype.Null}, - pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, + &pgtype.ACLItemArray{Status: pgtype.Null}, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, - &pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{ - pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, @@ -52,26 +52,26 @@ func TestAclitemArrayTranscode(t *testing.T) { }) } -func TestAclitemArraySet(t *testing.T) { +func TestACLItemArraySet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.AclitemArray + result pgtype.ACLItemArray }{ { source: []string{"=r/postgres"}, - result: pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{{String: "=r/postgres", Status: pgtype.Present}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, { source: (([]string)(nil)), - result: pgtype.AclitemArray{Status: pgtype.Null}, + result: pgtype.ACLItemArray{Status: pgtype.Null}, }, } for i, tt := range successfulTests { - var r pgtype.AclitemArray + var r pgtype.ACLItemArray err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -83,19 +83,19 @@ func TestAclitemArraySet(t *testing.T) { } } -func TestAclitemArrayAssignTo(t *testing.T) { +func TestACLItemArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice simpleTests := []struct { - src pgtype.AclitemArray + src pgtype.ACLItemArray dst interface{} expected interface{} }{ { - src: pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{{String: "=r/postgres", Status: pgtype.Present}}, + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, @@ -103,8 +103,8 @@ func TestAclitemArrayAssignTo(t *testing.T) { expected: []string{"=r/postgres"}, }, { - src: pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{{String: "=r/postgres", Status: pgtype.Present}}, + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, @@ -112,7 +112,7 @@ func TestAclitemArrayAssignTo(t *testing.T) { expected: _stringSlice{"=r/postgres"}, }, { - src: pgtype.AclitemArray{Status: pgtype.Null}, + src: pgtype.ACLItemArray{Status: pgtype.Null}, dst: &stringSlice, expected: (([]string)(nil)), }, @@ -130,12 +130,12 @@ func TestAclitemArrayAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.AclitemArray + src pgtype.ACLItemArray dst interface{} }{ { - src: pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{{Status: pgtype.Null}}, + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{Status: pgtype.Null}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, diff --git a/aclitem_test.go b/aclitem_test.go index 13c63395..65399a30 100644 --- a/aclitem_test.go +++ b/aclitem_test.go @@ -8,25 +8,25 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestAclitemTranscode(t *testing.T) { +func TestACLItemTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ - &pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - &pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - &pgtype.Aclitem{Status: pgtype.Null}, + &pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + &pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + &pgtype.ACLItem{Status: pgtype.Null}, }) } -func TestAclitemSet(t *testing.T) { +func TestACLItemSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Aclitem + result pgtype.ACLItem }{ - {source: "postgres=arwdDxt/postgres", result: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - {source: (*string)(nil), result: pgtype.Aclitem{Status: pgtype.Null}}, + {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}}, } for i, tt := range successfulTests { - var d pgtype.Aclitem + var d pgtype.ACLItem err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -38,17 +38,17 @@ func TestAclitemSet(t *testing.T) { } } -func TestAclitemAssignTo(t *testing.T) { +func TestACLItemAssignTo(t *testing.T) { var s string var ps *string simpleTests := []struct { - src pgtype.Aclitem + src pgtype.ACLItem dst interface{} expected interface{} }{ - {src: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, - {src: pgtype.Aclitem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, } for i, tt := range simpleTests { @@ -63,11 +63,11 @@ func TestAclitemAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.Aclitem + src pgtype.ACLItem dst interface{} expected interface{} }{ - {src: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, } for i, tt := range pointerAllocTests { @@ -82,10 +82,10 @@ func TestAclitemAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.Aclitem + src pgtype.ACLItem dst interface{} }{ - {src: pgtype.Aclitem{Status: pgtype.Null}, dst: &s}, + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s}, } for i, tt := range errorTests { diff --git a/pgtype.go b/pgtype.go index 2bfc9527..4fdcf3c2 100644 --- a/pgtype.go +++ b/pgtype.go @@ -35,8 +35,8 @@ const ( Int8ArrayOID = 1016 Float4ArrayOID = 1021 Float8ArrayOID = 1022 - AclitemOID = 1033 - AclitemArrayOID = 1034 + ACLItemOID = 1033 + ACLItemArrayOID = 1034 InetArrayOID = 1041 VarcharOID = 1043 DateOID = 1082 @@ -206,7 +206,7 @@ var nameValues map[string]Value func init() { nameValues = map[string]Value{ - "_aclitem": &AclitemArray{}, + "_aclitem": &ACLItemArray{}, "_bool": &BoolArray{}, "_bytea": &ByteaArray{}, "_cidr": &CidrArray{}, @@ -222,7 +222,7 @@ func init() { "_timestamp": &TimestampArray{}, "_timestamptz": &TimestamptzArray{}, "_varchar": &VarcharArray{}, - "aclitem": &Aclitem{}, + "aclitem": &ACLItem{}, "bool": &Bool{}, "box": &Box{}, "bytea": &Bytea{}, diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 2e36b8b3..d7abcbcf 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -12,7 +12,7 @@ erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.I erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' binary_format=true typed_array.go.erb > text_array.go erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' binary_format=true typed_array.go.erb > varchar_array.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go -erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go goimports -w *_array.go From 654adbdd4a6117789b1c14423964a9386f3dacbc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2017 12:01:49 -0500 Subject: [PATCH 0106/1158] Use Go casing convention for CID/TID/XID/CIDR --- cid.go | 26 ++++++------- cid_test.go | 30 +++++++-------- cidr.go | 16 ++++---- cidr_array.go | 58 ++++++++++++++--------------- cidr_array_test.go | 92 +++++++++++++++++++++++----------------------- inet_array_test.go | 36 +++++++++--------- inet_test.go | 36 +++++++++--------- pgtype.go | 20 +++++----- pgtype_test.go | 2 +- pguint32.go | 2 +- tid.go | 34 ++++++++--------- tid_test.go | 8 ++-- typed_array_gen.sh | 2 +- xid.go | 26 ++++++------- xid_test.go | 30 +++++++-------- 15 files changed, 209 insertions(+), 209 deletions(-) diff --git a/cid.go b/cid.go index b7718f88..0ed54f44 100644 --- a/cid.go +++ b/cid.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" ) -// Cid is PostgreSQL's Command Identifier type. +// CID is PostgreSQL's Command Identifier type. // // When one does // @@ -15,47 +15,47 @@ import ( // It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/c.h as CommandId // in the PostgreSQL sources. -type Cid pguint32 +type CID pguint32 -// Set converts from src to dst. Note that as Cid is not a general +// Set converts from src to dst. Note that as CID is not a general // number type Set does not do automatic type conversion as other number // types do. -func (dst *Cid) Set(src interface{}) error { +func (dst *CID) Set(src interface{}) error { return (*pguint32)(dst).Set(src) } -func (dst *Cid) Get() interface{} { +func (dst *CID) Get() interface{} { return (*pguint32)(dst).Get() } -// AssignTo assigns from src to dst. Note that as Cid is not a general number +// AssignTo assigns from src to dst. Note that as CID is not a general number // type AssignTo does not do automatic type conversion as other number types do. -func (src *Cid) AssignTo(dst interface{}) error { +func (src *CID) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Cid) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *CID) DecodeText(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Cid) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *CID) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *Cid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (*pguint32)(src).EncodeText(ci, buf) } -func (src *Cid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return (*pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. -func (dst *Cid) Scan(src interface{}) error { +func (dst *CID) Scan(src interface{}) error { return (*pguint32)(dst).Scan(src) } // Value implements the database/sql/driver Valuer interface. -func (src *Cid) Value() (driver.Value, error) { +func (src *CID) Value() (driver.Value, error) { return (*pguint32)(src).Value() } diff --git a/cid_test.go b/cid_test.go index c3bf3132..0dfc56d4 100644 --- a/cid_test.go +++ b/cid_test.go @@ -8,11 +8,11 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestCidTranscode(t *testing.T) { +func TestCIDTranscode(t *testing.T) { pgTypeName := "cid" values := []interface{}{ - &pgtype.Cid{Uint: 42, Status: pgtype.Present}, - &pgtype.Cid{Status: pgtype.Null}, + &pgtype.CID{Uint: 42, Status: pgtype.Present}, + &pgtype.CID{Status: pgtype.Null}, } eqFunc := func(a, b interface{}) bool { return reflect.DeepEqual(a, b) @@ -28,16 +28,16 @@ func TestCidTranscode(t *testing.T) { } } -func TestCidSet(t *testing.T) { +func TestCIDSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Cid + result pgtype.CID }{ - {source: uint32(1), result: pgtype.Cid{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.CID{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.Cid + var r pgtype.CID err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -49,17 +49,17 @@ func TestCidSet(t *testing.T) { } } -func TestCidAssignTo(t *testing.T) { +func TestCIDAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.Cid + src pgtype.CID dst interface{} expected interface{} }{ - {src: pgtype.Cid{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Cid{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -74,11 +74,11 @@ func TestCidAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.Cid + src pgtype.CID dst interface{} expected interface{} }{ - {src: pgtype.Cid{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -93,10 +93,10 @@ func TestCidAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.Cid + src pgtype.CID dst interface{} }{ - {src: pgtype.Cid{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.CID{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/cidr.go b/cidr.go index 2b45d2d0..519b9cae 100644 --- a/cidr.go +++ b/cidr.go @@ -1,31 +1,31 @@ package pgtype -type Cidr Inet +type CIDR Inet -func (dst *Cidr) Set(src interface{}) error { +func (dst *CIDR) Set(src interface{}) error { return (*Inet)(dst).Set(src) } -func (dst *Cidr) Get() interface{} { +func (dst *CIDR) Get() interface{} { return (*Inet)(dst).Get() } -func (src *Cidr) AssignTo(dst interface{}) error { +func (src *CIDR) AssignTo(dst interface{}) error { return (*Inet)(src).AssignTo(dst) } -func (dst *Cidr) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *CIDR) DecodeText(ci *ConnInfo, src []byte) error { return (*Inet)(dst).DecodeText(ci, src) } -func (dst *Cidr) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *CIDR) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Inet)(dst).DecodeBinary(ci, src) } -func (src *Cidr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CIDR) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (*Inet)(src).EncodeText(ci, buf) } -func (src *Cidr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CIDR) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return (*Inet)(src).EncodeBinary(ci, buf) } diff --git a/cidr_array.go b/cidr_array.go index b8a70d63..9b7b50fa 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -9,28 +9,28 @@ import ( "github.com/jackc/pgx/pgio" ) -type CidrArray struct { - Elements []Cidr +type CIDRArray struct { + Elements []CIDR Dimensions []ArrayDimension Status Status } -func (dst *CidrArray) Set(src interface{}) error { +func (dst *CIDRArray) Set(src interface{}) error { switch value := src.(type) { case []*net.IPNet: if value == nil { - *dst = CidrArray{Status: Null} + *dst = CIDRArray{Status: Null} } else if len(value) == 0 { - *dst = CidrArray{Status: Present} + *dst = CIDRArray{Status: Present} } else { - elements := make([]Cidr, len(value)) + elements := make([]CIDR, len(value)) for i := range value { if err := elements[i].Set(value[i]); err != nil { return err } } - *dst = CidrArray{ + *dst = CIDRArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, Status: Present, @@ -39,17 +39,17 @@ func (dst *CidrArray) Set(src interface{}) error { case []net.IP: if value == nil { - *dst = CidrArray{Status: Null} + *dst = CIDRArray{Status: Null} } else if len(value) == 0 { - *dst = CidrArray{Status: Present} + *dst = CIDRArray{Status: Present} } else { - elements := make([]Cidr, len(value)) + elements := make([]CIDR, len(value)) for i := range value { if err := elements[i].Set(value[i]); err != nil { return err } } - *dst = CidrArray{ + *dst = CIDRArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, Status: Present, @@ -60,13 +60,13 @@ func (dst *CidrArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Cidr", value) + return fmt.Errorf("cannot convert %v to CIDR", value) } return nil } -func (dst *CidrArray) Get() interface{} { +func (dst *CIDRArray) Get() interface{} { switch dst.Status { case Present: return dst @@ -77,7 +77,7 @@ func (dst *CidrArray) Get() interface{} { } } -func (src *CidrArray) AssignTo(dst interface{}) error { +func (src *CIDRArray) AssignTo(dst interface{}) error { switch src.Status { case Present: switch v := dst.(type) { @@ -112,9 +112,9 @@ func (src *CidrArray) AssignTo(dst interface{}) error { return fmt.Errorf("cannot decode %v into %T", src, dst) } -func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = CidrArray{Status: Null} + *dst = CIDRArray{Status: Null} return nil } @@ -123,13 +123,13 @@ func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { return err } - var elements []Cidr + var elements []CIDR if len(uta.Elements) > 0 { - elements = make([]Cidr, len(uta.Elements)) + elements = make([]CIDR, len(uta.Elements)) for i, s := range uta.Elements { - var elem Cidr + var elem CIDR var elemSrc []byte if s != "NULL" { elemSrc = []byte(s) @@ -143,14 +143,14 @@ func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = CidrArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = CIDRArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} return nil } -func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = CidrArray{Status: Null} + *dst = CIDRArray{Status: Null} return nil } @@ -161,7 +161,7 @@ func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = CidrArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = CIDRArray{Dimensions: arrayHeader.Dimensions, Status: Present} return nil } @@ -170,7 +170,7 @@ func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { elementCount *= d.Length } - elements := make([]Cidr, elementCount) + elements := make([]CIDR, elementCount) for i := range elements { elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) @@ -186,11 +186,11 @@ func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = CidrArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = CIDRArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} return nil } -func (src *CidrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -247,7 +247,7 @@ func (src *CidrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *CidrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -292,7 +292,7 @@ func (src *CidrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *CidrArray) Scan(src interface{}) error { +func (dst *CIDRArray) Scan(src interface{}) error { if src == nil { return dst.DecodeText(nil, nil) } @@ -310,7 +310,7 @@ func (dst *CidrArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *CidrArray) Value() (driver.Value, error) { +func (src *CIDRArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/cidr_array_test.go b/cidr_array_test.go index 1ebe5195..70d3f65b 100644 --- a/cidr_array_test.go +++ b/cidr_array_test.go @@ -9,40 +9,40 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestCidrArrayTranscode(t *testing.T) { +func TestCIDRArrayTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "cidr[]", []interface{}{ - &pgtype.CidrArray{ + &pgtype.CIDRArray{ Elements: nil, Dimensions: nil, Status: pgtype.Present, }, - &pgtype.CidrArray{ - Elements: []pgtype.Cidr{ - pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Cidr{Status: pgtype.Null}, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.CIDR{Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, - &pgtype.CidrArray{Status: pgtype.Null}, - &pgtype.CidrArray{ - Elements: []pgtype.Cidr{ - pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - pgtype.Cidr{Status: pgtype.Null}, - pgtype.Cidr{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, + &pgtype.CIDRArray{Status: pgtype.Null}, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + pgtype.CIDR{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.CIDR{Status: pgtype.Null}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, - &pgtype.CidrArray{ - Elements: []pgtype.Cidr{ - pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + pgtype.CIDR{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, @@ -53,37 +53,37 @@ func TestCidrArrayTranscode(t *testing.T) { }) } -func TestCidrArraySet(t *testing.T) { +func TestCIDRArraySet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.CidrArray + result pgtype.CIDRArray }{ { - source: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, - result: pgtype.CidrArray{ - Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, { source: (([]*net.IPNet)(nil)), - result: pgtype.CidrArray{Status: pgtype.Null}, + result: pgtype.CIDRArray{Status: pgtype.Null}, }, { - source: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, - result: pgtype.CidrArray{ - Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, { source: (([]net.IP)(nil)), - result: pgtype.CidrArray{Status: pgtype.Null}, + result: pgtype.CIDRArray{Status: pgtype.Null}, }, } for i, tt := range successfulTests { - var r pgtype.CidrArray + var r pgtype.CIDRArray err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -95,27 +95,27 @@ func TestCidrArraySet(t *testing.T) { } } -func TestCidrArrayAssignTo(t *testing.T) { +func TestCIDRArrayAssignTo(t *testing.T) { var ipnetSlice []*net.IPNet var ipSlice []net.IP simpleTests := []struct { - src pgtype.CidrArray + src pgtype.CIDRArray dst interface{} expected interface{} }{ { - src: pgtype.CidrArray{ - Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, }, { - src: pgtype.CidrArray{ - Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{Status: pgtype.Null}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, @@ -123,17 +123,17 @@ func TestCidrArrayAssignTo(t *testing.T) { expected: []*net.IPNet{nil}, }, { - src: pgtype.CidrArray{ - Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, dst: &ipSlice, - expected: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, }, { - src: pgtype.CidrArray{ - Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{Status: pgtype.Null}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, @@ -141,12 +141,12 @@ func TestCidrArrayAssignTo(t *testing.T) { expected: []net.IP{nil}, }, { - src: pgtype.CidrArray{Status: pgtype.Null}, + src: pgtype.CIDRArray{Status: pgtype.Null}, dst: &ipnetSlice, expected: (([]*net.IPNet)(nil)), }, { - src: pgtype.CidrArray{Status: pgtype.Null}, + src: pgtype.CIDRArray{Status: pgtype.Null}, dst: &ipSlice, expected: (([]net.IP)(nil)), }, diff --git a/inet_array_test.go b/inet_array_test.go index c0465922..3e2b6a3c 100644 --- a/inet_array_test.go +++ b/inet_array_test.go @@ -18,7 +18,7 @@ func TestInetArrayTranscode(t *testing.T) { }, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, pgtype.Inet{Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, @@ -27,22 +27,22 @@ func TestInetArrayTranscode(t *testing.T) { &pgtype.InetArray{Status: pgtype.Null}, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, pgtype.Inet{Status: pgtype.Null}, - pgtype.Inet{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, @@ -59,9 +59,9 @@ func TestInetArraySet(t *testing.T) { result pgtype.InetArray }{ { - source: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, @@ -70,9 +70,9 @@ func TestInetArraySet(t *testing.T) { result: pgtype.InetArray{Status: pgtype.Null}, }, { - source: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, @@ -106,12 +106,12 @@ func TestInetArrayAssignTo(t *testing.T) { }{ { src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, }, { src: pgtype.InetArray{ @@ -124,12 +124,12 @@ func TestInetArrayAssignTo(t *testing.T) { }, { src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, dst: &ipSlice, - expected: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, }, { src: pgtype.InetArray{ diff --git a/inet_test.go b/inet_test.go index b883df8e..32d66999 100644 --- a/inet_test.go +++ b/inet_test.go @@ -12,16 +12,16 @@ import ( func TestInetTranscode(t *testing.T) { for _, pgTypeName := range []string{"inet", "cidr"} { testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ - &pgtype.Inet{IPNet: mustParseCidr(t, "0.0.0.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "192.168.1.0/24"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "255.255.255.255/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "::/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "::/0"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "::1/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, &pgtype.Inet{Status: pgtype.Null}, }) } @@ -32,9 +32,9 @@ func TestInetSet(t *testing.T) { source interface{} result pgtype.Inet }{ - {source: mustParseCidr(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: mustParseCidr(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, } for i, tt := range successfulTests { @@ -61,8 +61,8 @@ func TestInetAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCidr(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCidr(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, {src: pgtype.Inet{Status: pgtype.Null}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, {src: pgtype.Inet{Status: pgtype.Null}, dst: &pip, expected: ((*net.IP)(nil))}, } @@ -83,8 +83,8 @@ func TestInetAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCidr(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCidr(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, } for i, tt := range pointerAllocTests { @@ -102,7 +102,7 @@ func TestInetAssignTo(t *testing.T) { src pgtype.Inet dst interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCidr(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, {src: pgtype.Inet{Status: pgtype.Null}, dst: &ipnet}, } diff --git a/pgtype.go b/pgtype.go index 4fdcf3c2..4302a5fe 100644 --- a/pgtype.go +++ b/pgtype.go @@ -16,12 +16,12 @@ const ( Int4OID = 23 TextOID = 25 OIDOID = 26 - TidOID = 27 - XidOID = 28 - CidOID = 29 + TIDOID = 27 + XIDOID = 28 + CIDOID = 29 JSONOID = 114 - CidrOID = 650 - CidrArrayOID = 651 + CIDROID = 650 + CIDRArrayOID = 651 Float4OID = 700 Float8OID = 701 UnknownOID = 705 @@ -209,7 +209,7 @@ func init() { "_aclitem": &ACLItemArray{}, "_bool": &BoolArray{}, "_bytea": &ByteaArray{}, - "_cidr": &CidrArray{}, + "_cidr": &CIDRArray{}, "_date": &DateArray{}, "_float4": &Float4Array{}, "_float8": &Float8Array{}, @@ -227,8 +227,8 @@ func init() { "box": &Box{}, "bytea": &Bytea{}, "char": &QChar{}, - "cid": &Cid{}, - "cidr": &Cidr{}, + "cid": &CID{}, + "cidr": &CIDR{}, "circle": &Circle{}, "date": &Date{}, "daterange": &Daterange{}, @@ -256,7 +256,7 @@ func init() { "polygon": &Polygon{}, "record": &Record{}, "text": &Text{}, - "tid": &Tid{}, + "tid": &TID{}, "timestamp": &Timestamp{}, "timestamptz": &Timestamptz{}, "tsrange": &Tsrange{}, @@ -265,6 +265,6 @@ func init() { "uuid": &UUID{}, "varbit": &Varbit{}, "varchar": &Varchar{}, - "xid": &Xid{}, + "xid": &XID{}, } } diff --git a/pgtype_test.go b/pgtype_test.go index 716e063d..f7e743b2 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -20,7 +20,7 @@ type _float32Slice []float32 type _float64Slice []float64 type _byteSlice []byte -func mustParseCidr(t testing.TB, s string) *net.IPNet { +func mustParseCIDR(t testing.TB, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { t.Fatal(err) diff --git a/pguint32.go b/pguint32.go index c15ee6d7..15b0f38d 100644 --- a/pguint32.go +++ b/pguint32.go @@ -11,7 +11,7 @@ import ( ) // pguint32 is the core type that is used to implement PostgreSQL types such as -// Cid and Xid. +// CID and XID. type pguint32 struct { Uint uint32 Status Status diff --git a/tid.go b/tid.go index 2f4412cb..d44ea3a6 100644 --- a/tid.go +++ b/tid.go @@ -10,7 +10,7 @@ import ( "github.com/jackc/pgx/pgio" ) -// Tid is PostgreSQL's Tuple Identifier type. +// TID is PostgreSQL's Tuple Identifier type. // // When one does // @@ -21,17 +21,17 @@ import ( // It is currently implemented as a pair unsigned two byte integers. // Its conversion functions can be found in src/backend/utils/adt/tid.c // in the PostgreSQL sources. -type Tid struct { +type TID struct { BlockNumber uint32 OffsetNumber uint16 Status Status } -func (dst *Tid) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Tid", src) +func (dst *TID) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to TID", src) } -func (dst *Tid) Get() interface{} { +func (dst *TID) Get() interface{} { switch dst.Status { case Present: return dst @@ -42,13 +42,13 @@ func (dst *Tid) Get() interface{} { } } -func (src *Tid) AssignTo(dst interface{}) error { +func (src *TID) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v to %T", src, dst) } -func (dst *Tid) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Tid{Status: Null} + *dst = TID{Status: Null} return nil } @@ -71,13 +71,13 @@ func (dst *Tid) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Tid{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} + *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} return nil } -func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *TID) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Tid{Status: Null} + *dst = TID{Status: Null} return nil } @@ -85,7 +85,7 @@ func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { return fmt.Errorf("invalid length for tid: %v", len(src)) } - *dst = Tid{ + *dst = TID{ BlockNumber: binary.BigEndian.Uint32(src), OffsetNumber: binary.BigEndian.Uint16(src[4:]), Status: Present, @@ -93,7 +93,7 @@ func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Tid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *TID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -105,7 +105,7 @@ func (src *Tid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Tid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *TID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -119,9 +119,9 @@ func (src *Tid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Tid) Scan(src interface{}) error { +func (dst *TID) Scan(src interface{}) error { if src == nil { - *dst = Tid{Status: Null} + *dst = TID{Status: Null} return nil } @@ -138,6 +138,6 @@ func (dst *Tid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Tid) Value() (driver.Value, error) { +func (src *TID) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/tid_test.go b/tid_test.go index a5430d11..9185cb31 100644 --- a/tid_test.go +++ b/tid_test.go @@ -7,10 +7,10 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestTidTranscode(t *testing.T) { +func TestTIDTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "tid", []interface{}{ - &pgtype.Tid{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, - &pgtype.Tid{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, - &pgtype.Tid{Status: pgtype.Null}, + &pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, + &pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, + &pgtype.TID{Status: pgtype.Null}, }) } diff --git a/typed_array_gen.sh b/typed_array_gen.sh index d7abcbcf..1aa6c354 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -8,7 +8,7 @@ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_type erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go -erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go +erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' binary_format=true typed_array.go.erb > text_array.go erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' binary_format=true typed_array.go.erb > varchar_array.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go diff --git a/xid.go b/xid.go index 84acd1b0..f66f5367 100644 --- a/xid.go +++ b/xid.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" ) -// Xid is PostgreSQL's Transaction ID type. +// XID is PostgreSQL's Transaction ID type. // // In later versions of PostgreSQL, it is the type used for the backend_xid // and backend_xmin columns of the pg_stat_activity system view. @@ -18,47 +18,47 @@ import ( // It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/postgres_ext.h as TransactionId // in the PostgreSQL sources. -type Xid pguint32 +type XID pguint32 -// Set converts from src to dst. Note that as Xid is not a general +// Set converts from src to dst. Note that as XID is not a general // number type Set does not do automatic type conversion as other number // types do. -func (dst *Xid) Set(src interface{}) error { +func (dst *XID) Set(src interface{}) error { return (*pguint32)(dst).Set(src) } -func (dst *Xid) Get() interface{} { +func (dst *XID) Get() interface{} { return (*pguint32)(dst).Get() } -// AssignTo assigns from src to dst. Note that as Xid is not a general number +// AssignTo assigns from src to dst. Note that as XID is not a general number // type AssignTo does not do automatic type conversion as other number types do. -func (src *Xid) AssignTo(dst interface{}) error { +func (src *XID) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Xid) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *XID) DecodeText(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Xid) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *XID) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *Xid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *XID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (*pguint32)(src).EncodeText(ci, buf) } -func (src *Xid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *XID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return (*pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. -func (dst *Xid) Scan(src interface{}) error { +func (dst *XID) Scan(src interface{}) error { return (*pguint32)(dst).Scan(src) } // Value implements the database/sql/driver Valuer interface. -func (src *Xid) Value() (driver.Value, error) { +func (src *XID) Value() (driver.Value, error) { return (*pguint32)(src).Value() } diff --git a/xid_test.go b/xid_test.go index c4a1bec3..d0f3f0ab 100644 --- a/xid_test.go +++ b/xid_test.go @@ -8,11 +8,11 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestXidTranscode(t *testing.T) { +func TestXIDTranscode(t *testing.T) { pgTypeName := "xid" values := []interface{}{ - &pgtype.Xid{Uint: 42, Status: pgtype.Present}, - &pgtype.Xid{Status: pgtype.Null}, + &pgtype.XID{Uint: 42, Status: pgtype.Present}, + &pgtype.XID{Status: pgtype.Null}, } eqFunc := func(a, b interface{}) bool { return reflect.DeepEqual(a, b) @@ -28,16 +28,16 @@ func TestXidTranscode(t *testing.T) { } } -func TestXidSet(t *testing.T) { +func TestXIDSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Xid + result pgtype.XID }{ - {source: uint32(1), result: pgtype.Xid{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.XID{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.Xid + var r pgtype.XID err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -49,17 +49,17 @@ func TestXidSet(t *testing.T) { } } -func TestXidAssignTo(t *testing.T) { +func TestXIDAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.Xid + src pgtype.XID dst interface{} expected interface{} }{ - {src: pgtype.Xid{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Xid{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.XID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -74,11 +74,11 @@ func TestXidAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.Xid + src pgtype.XID dst interface{} expected interface{} }{ - {src: pgtype.Xid{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -93,10 +93,10 @@ func TestXidAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.Xid + src pgtype.XID dst interface{} }{ - {src: pgtype.Xid{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.XID{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { From a5f166bd217dcdd8691694185207116d89ccb289 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 4 Jun 2017 21:30:03 -0500 Subject: [PATCH 0107/1158] Use github.com/pkg/errors --- aclitem.go | 9 ++--- aclitem_array.go | 9 ++--- array.go | 36 +++++++++---------- bool.go | 13 +++---- bool_array.go | 10 +++--- box.go | 11 +++--- bytea.go | 11 +++--- bytea_array.go | 10 +++--- cidr_array.go | 10 +++--- circle.go | 11 +++--- convert.go | 59 ++++++++++++++++--------------- database_sql.go | 3 +- date.go | 12 +++---- date_array.go | 10 +++--- daterange.go | 24 ++++++------- ext/satori-uuid/uuid.go | 14 ++++---- ext/shopspring-numeric/decimal.go | 52 +++++++++++++-------------- float4.go | 20 +++++------ float4_array.go | 10 +++--- float8.go | 16 ++++----- float8_array.go | 10 +++--- hstore.go | 32 ++++++++--------- hstore_array.go | 10 +++--- inet.go | 15 ++++---- inet_array.go | 10 +++--- int2.go | 32 ++++++++--------- int2_array.go | 10 +++--- int4.go | 26 +++++++------- int4_array.go | 10 +++--- int4range.go | 24 ++++++------- int8.go | 16 ++++----- int8_array.go | 10 +++--- int8range.go | 24 ++++++------- interval.go | 23 ++++++------ json.go | 5 +-- jsonb.go | 7 ++-- line.go | 13 +++---- lseg.go | 11 +++--- macaddr.go | 11 +++--- numeric.go | 56 ++++++++++++++--------------- numeric_array.go | 10 +++--- numrange.go | 24 ++++++------- oid.go | 12 +++---- path.go | 13 +++---- pgtype.go | 3 +- pguint32.go | 14 ++++---- point.go | 13 +++---- polygon.go | 13 +++---- qchar.go | 33 ++++++++--------- range.go | 39 ++++++++++---------- record.go | 15 ++++---- text.go | 9 ++--- text_array.go | 10 +++--- tid.go | 13 +++---- timestamp.go | 16 ++++----- timestamp_array.go | 10 +++--- timestamptz.go | 12 +++---- timestamptz_array.go | 10 +++--- tsrange.go | 24 ++++++------- tstzrange.go | 24 ++++++------- typed_array.go.erb | 8 ++--- typed_range.go.erb | 22 ++++++------ uuid.go | 14 ++++---- varbit.go | 10 +++--- varchar_array.go | 10 +++--- 65 files changed, 556 insertions(+), 530 deletions(-) diff --git a/aclitem.go b/aclitem.go index 829eb908..35269e91 100644 --- a/aclitem.go +++ b/aclitem.go @@ -2,7 +2,8 @@ package pgtype import ( "database/sql/driver" - "fmt" + + "github.com/pkg/errors" ) // ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem @@ -36,7 +37,7 @@ func (dst *ACLItem) Set(src interface{}) error { if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to ACLItem", value) + return errors.Errorf("cannot convert %v to ACLItem", value) } return nil @@ -69,7 +70,7 @@ func (src *ACLItem) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error { @@ -109,7 +110,7 @@ func (dst *ACLItem) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/aclitem_array.go b/aclitem_array.go index f9215a93..fe0af434 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -2,7 +2,8 @@ package pgtype import ( "database/sql/driver" - "fmt" + + "github.com/pkg/errors" ) type ACLItemArray struct { @@ -37,7 +38,7 @@ func (dst *ACLItemArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to ACLItem", value) + return errors.Errorf("cannot convert %v to ACLItem", value) } return nil @@ -77,7 +78,7 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -188,7 +189,7 @@ func (dst *ACLItemArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/array.go b/array.go index e5504455..5b852ed5 100644 --- a/array.go +++ b/array.go @@ -3,13 +3,13 @@ package pgtype import ( "bytes" "encoding/binary" - "fmt" "io" "strconv" "strings" "unicode" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) // Information on the internals of PostgreSQL arrays can be found in @@ -29,7 +29,7 @@ type ArrayDimension struct { func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { if len(src) < 12 { - return 0, fmt.Errorf("array header too short: %d", len(src)) + return 0, errors.Errorf("array header too short: %d", len(src)) } rp := 0 @@ -47,7 +47,7 @@ func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { dst.Dimensions = make([]ArrayDimension, numDims) } if len(src) < 12+numDims*8 { - return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) + return 0, errors.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) } for i := range dst.Dimensions { dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) @@ -93,7 +93,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { r, _, err := buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } var explicitDimensions []ArrayDimension @@ -105,41 +105,41 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } if r == '=' { break } else if r != '[' { - return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r) + return nil, errors.Errorf("invalid array, expected '[' or '=' got %v", r) } lower, err := arrayParseInteger(buf) if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } if r != ':' { - return nil, fmt.Errorf("invalid array, expected ':' got %v", r) + return nil, errors.Errorf("invalid array, expected ':' got %v", r) } upper, err := arrayParseInteger(buf) if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } if r != ']' { - return nil, fmt.Errorf("invalid array, expected ']' got %v", r) + return nil, errors.Errorf("invalid array, expected ']' got %v", r) } explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) @@ -147,12 +147,12 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } } if r != '{' { - return nil, fmt.Errorf("invalid array, expected '{': %v", err) + return nil, errors.Errorf("invalid array, expected '{': %v", err) } implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} @@ -161,7 +161,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } if r == '{' { @@ -178,7 +178,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } switch r { @@ -197,7 +197,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { buf.UnreadRune() value, err := arrayParseValue(buf) if err != nil { - return nil, fmt.Errorf("invalid array value: %v", err) + return nil, errors.Errorf("invalid array value: %v", err) } if currentDim == counterDim { implicitDimensions[currentDim].Length++ @@ -213,7 +213,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { skipWhitespace(buf) if buf.Len() > 0 { - return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + return nil, errors.Errorf("unexpected trailing data: %v", buf.String()) } if len(dst.Elements) == 0 { diff --git a/bool.go b/bool.go index 7c66a534..3a3eef48 100644 --- a/bool.go +++ b/bool.go @@ -2,8 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "strconv" + + "github.com/pkg/errors" ) type Bool struct { @@ -25,7 +26,7 @@ func (dst *Bool) Set(src interface{}) error { if originalSrc, ok := underlyingBoolType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Bool", value) + return errors.Errorf("cannot convert %v to Bool", value) } return nil @@ -58,7 +59,7 @@ func (src *Bool) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { @@ -68,7 +69,7 @@ func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) + return errors.Errorf("invalid length for bool: %v", len(src)) } *dst = Bool{Bool: src[0] == 't', Status: Present} @@ -82,7 +83,7 @@ func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) + return errors.Errorf("invalid length for bool: %v", len(src)) } *dst = Bool{Bool: src[0] == 1, Status: Present} @@ -142,7 +143,7 @@ func (dst *Bool) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/bool_array.go b/bool_array.go index e20a0381..e23c27e5 100644 --- a/bool_array.go +++ b/bool_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type BoolArray struct { @@ -40,7 +40,7 @@ func (dst *BoolArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Bool", value) + return errors.Errorf("cannot convert %v to Bool", value) } return nil @@ -80,7 +80,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("bool"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "bool") + return nil, errors.Errorf("unable to find oid for type name %v", "bool") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *BoolArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/box.go b/box.go index 2d098058..83df0499 100644 --- a/box.go +++ b/box.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Box struct { @@ -17,7 +18,7 @@ type Box struct { } func (dst *Box) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Box", src) + return errors.Errorf("cannot convert %v to Box", src) } func (dst *Box) Get() interface{} { @@ -32,7 +33,7 @@ func (dst *Box) Get() interface{} { } func (src *Box) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { @@ -42,7 +43,7 @@ func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 11 { - return fmt.Errorf("invalid length for Box: %v", len(src)) + return errors.Errorf("invalid length for Box: %v", len(src)) } str := string(src[1:]) @@ -89,7 +90,7 @@ func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 32 { - return fmt.Errorf("invalid length for Box: %v", len(src)) + return errors.Errorf("invalid length for Box: %v", len(src)) } x1 := binary.BigEndian.Uint64(src) @@ -152,7 +153,7 @@ func (dst *Box) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/bytea.go b/bytea.go index 2ddac7da..c7117f48 100644 --- a/bytea.go +++ b/bytea.go @@ -3,7 +3,8 @@ package pgtype import ( "database/sql/driver" "encoding/hex" - "fmt" + + "github.com/pkg/errors" ) type Bytea struct { @@ -28,7 +29,7 @@ func (dst *Bytea) Set(src interface{}) error { if originalSrc, ok := underlyingBytesType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Bytea", value) + return errors.Errorf("cannot convert %v to Bytea", value) } return nil @@ -63,7 +64,7 @@ func (src *Bytea) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } // DecodeText only supports the hex format. This has been the default since @@ -75,7 +76,7 @@ func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 2 || src[0] != '\\' || src[1] != 'x' { - return fmt.Errorf("invalid hex format") + return errors.Errorf("invalid hex format") } buf := make([]byte, (len(src)-2)/2) @@ -139,7 +140,7 @@ func (dst *Bytea) Scan(src interface{}) error { return nil } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/bytea_array.go b/bytea_array.go index 0d381693..f2842179 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type ByteaArray struct { @@ -40,7 +40,7 @@ func (dst *ByteaArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Bytea", value) + return errors.Errorf("cannot convert %v to Bytea", value) } return nil @@ -80,7 +80,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("bytea"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "bytea") + return nil, errors.Errorf("unable to find oid for type name %v", "bytea") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *ByteaArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/cidr_array.go b/cidr_array.go index 9b7b50fa..2373da46 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "net" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type CIDRArray struct { @@ -60,7 +60,7 @@ func (dst *CIDRArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to CIDR", value) + return errors.Errorf("cannot convert %v to CIDR", value) } return nil @@ -109,7 +109,7 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -262,7 +262,7 @@ func (src *CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("cidr"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "cidr") + return nil, errors.Errorf("unable to find oid for type name %v", "cidr") } for i := range src.Elements { @@ -306,7 +306,7 @@ func (dst *CIDRArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/circle.go b/circle.go index 8626a99d..97ecbf31 100644 --- a/circle.go +++ b/circle.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Circle struct { @@ -18,7 +19,7 @@ type Circle struct { } func (dst *Circle) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Circle", src) + return errors.Errorf("cannot convert %v to Circle", src) } func (dst *Circle) Get() interface{} { @@ -33,7 +34,7 @@ func (dst *Circle) Get() interface{} { } func (src *Circle) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { @@ -43,7 +44,7 @@ func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 9 { - return fmt.Errorf("invalid length for Circle: %v", len(src)) + return errors.Errorf("invalid length for Circle: %v", len(src)) } str := string(src[2:]) @@ -79,7 +80,7 @@ func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 24 { - return fmt.Errorf("invalid length for Circle: %v", len(src)) + return errors.Errorf("invalid length for Circle: %v", len(src)) } x := binary.BigEndian.Uint64(src) @@ -136,7 +137,7 @@ func (dst *Circle) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/convert.go b/convert.go index 2b406426..5dfb738e 100644 --- a/convert.go +++ b/convert.go @@ -1,10 +1,11 @@ package pgtype import ( - "fmt" "math" "reflect" "time" + + "github.com/pkg/errors" ) const maxUint = ^uint(0) @@ -189,70 +190,70 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { switch v := dst.(type) { case *int: if srcVal < int64(minInt) { - return fmt.Errorf("%d is less than minimum value for int", srcVal) + return errors.Errorf("%d is less than minimum value for int", srcVal) } else if srcVal > int64(maxInt) { - return fmt.Errorf("%d is greater than maximum value for int", srcVal) + return errors.Errorf("%d is greater than maximum value for int", srcVal) } *v = int(srcVal) case *int8: if srcVal < math.MinInt8 { - return fmt.Errorf("%d is less than minimum value for int8", srcVal) + return errors.Errorf("%d is less than minimum value for int8", srcVal) } else if srcVal > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for int8", srcVal) + return errors.Errorf("%d is greater than maximum value for int8", srcVal) } *v = int8(srcVal) case *int16: if srcVal < math.MinInt16 { - return fmt.Errorf("%d is less than minimum value for int16", srcVal) + return errors.Errorf("%d is less than minimum value for int16", srcVal) } else if srcVal > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for int16", srcVal) + return errors.Errorf("%d is greater than maximum value for int16", srcVal) } *v = int16(srcVal) case *int32: if srcVal < math.MinInt32 { - return fmt.Errorf("%d is less than minimum value for int32", srcVal) + return errors.Errorf("%d is less than minimum value for int32", srcVal) } else if srcVal > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for int32", srcVal) + return errors.Errorf("%d is greater than maximum value for int32", srcVal) } *v = int32(srcVal) case *int64: if srcVal < math.MinInt64 { - return fmt.Errorf("%d is less than minimum value for int64", srcVal) + return errors.Errorf("%d is less than minimum value for int64", srcVal) } else if srcVal > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for int64", srcVal) + return errors.Errorf("%d is greater than maximum value for int64", srcVal) } *v = int64(srcVal) case *uint: if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint", srcVal) + return errors.Errorf("%d is less than zero for uint", srcVal) } else if uint64(srcVal) > uint64(maxUint) { - return fmt.Errorf("%d is greater than maximum value for uint", srcVal) + return errors.Errorf("%d is greater than maximum value for uint", srcVal) } *v = uint(srcVal) case *uint8: if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint8", srcVal) + return errors.Errorf("%d is less than zero for uint8", srcVal) } else if srcVal > math.MaxUint8 { - return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) + return errors.Errorf("%d is greater than maximum value for uint8", srcVal) } *v = uint8(srcVal) case *uint16: if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint32", srcVal) + return errors.Errorf("%d is less than zero for uint32", srcVal) } else if srcVal > math.MaxUint16 { - return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) + return errors.Errorf("%d is greater than maximum value for uint16", srcVal) } *v = uint16(srcVal) case *uint32: if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint32", srcVal) + return errors.Errorf("%d is less than zero for uint32", srcVal) } else if srcVal > math.MaxUint32 { - return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) + return errors.Errorf("%d is greater than maximum value for uint32", srcVal) } *v = uint32(srcVal) case *uint64: if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint64", srcVal) + return errors.Errorf("%d is less than zero for uint64", srcVal) } *v = uint64(srcVal) default: @@ -268,22 +269,22 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { return int64AssignTo(srcVal, srcStatus, el.Interface()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if el.OverflowInt(int64(srcVal)) { - return fmt.Errorf("cannot put %d into %T", srcVal, dst) + return errors.Errorf("cannot put %d into %T", srcVal, dst) } el.SetInt(int64(srcVal)) return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if srcVal < 0 { - return fmt.Errorf("%d is less than zero for %T", srcVal, dst) + return errors.Errorf("%d is less than zero for %T", srcVal, dst) } if el.OverflowUint(uint64(srcVal)) { - return fmt.Errorf("cannot put %d into %T", srcVal, dst) + return errors.Errorf("cannot put %d into %T", srcVal, dst) } el.SetUint(uint64(srcVal)) return nil } } - return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + return errors.Errorf("cannot assign %v into %T", srcVal, dst) } return nil } @@ -297,7 +298,7 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { } } - return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) + return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { @@ -325,7 +326,7 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { } } } - return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + return errors.Errorf("cannot assign %v into %T", srcVal, dst) } return nil } @@ -339,7 +340,7 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { } } - return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) + return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } func NullAssignTo(dst interface{}) error { @@ -347,7 +348,7 @@ func NullAssignTo(dst interface{}) error { // AssignTo dst must always be a pointer if dstPtr.Kind() != reflect.Ptr { - return fmt.Errorf("cannot assign NULL to %T", dst) + return errors.Errorf("cannot assign NULL to %T", dst) } dstVal := dstPtr.Elem() @@ -358,7 +359,7 @@ func NullAssignTo(dst interface{}) error { return nil } - return fmt.Errorf("cannot assign NULL to %T", dst) + return errors.Errorf("cannot assign NULL to %T", dst) } var kindTypes map[reflect.Kind]reflect.Type diff --git a/database_sql.go b/database_sql.go index 9d1cf822..969536dd 100644 --- a/database_sql.go +++ b/database_sql.go @@ -2,7 +2,8 @@ package pgtype import ( "database/sql/driver" - "errors" + + "github.com/pkg/errors" ) func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { diff --git a/date.go b/date.go index 8e049254..f1c0d8bd 100644 --- a/date.go +++ b/date.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Date struct { @@ -33,7 +33,7 @@ func (dst *Date) Set(src interface{}) error { if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Date", value) + return errors.Errorf("cannot convert %v to Date", value) } return nil @@ -59,7 +59,7 @@ func (src *Date) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: if src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } *v = src.Time return nil @@ -72,7 +72,7 @@ func (src *Date) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { @@ -106,7 +106,7 @@ func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 4 { - return fmt.Errorf("invalid length for date: %v", len(src)) + return errors.Errorf("invalid length for date: %v", len(src)) } dayOffset := int32(binary.BigEndian.Uint32(src)) @@ -190,7 +190,7 @@ func (dst *Date) Scan(src interface{}) error { return nil } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/date_array.go b/date_array.go index ef91cf3e..383945e7 100644 --- a/date_array.go +++ b/date_array.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type DateArray struct { @@ -41,7 +41,7 @@ func (dst *DateArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Date", value) + return errors.Errorf("cannot convert %v to Date", value) } return nil @@ -81,7 +81,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -234,7 +234,7 @@ func (src *DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("date"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "date") + return nil, errors.Errorf("unable to find oid for type name %v", "date") } for i := range src.Elements { @@ -278,7 +278,7 @@ func (dst *DateArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/daterange.go b/daterange.go index bbe7b17a..47cd7e46 100644 --- a/daterange.go +++ b/daterange.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Daterange struct { @@ -16,7 +16,7 @@ type Daterange struct { } func (dst *Daterange) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Daterange", src) + return errors.Errorf("cannot convert %v to Daterange", src) } func (dst *Daterange) Get() interface{} { @@ -31,7 +31,7 @@ func (dst *Daterange) Get() interface{} { } func (src *Daterange) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Daterange) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -130,7 +130,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -141,7 +141,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -151,7 +151,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -175,7 +175,7 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -185,7 +185,7 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -201,7 +201,7 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -216,7 +216,7 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -241,7 +241,7 @@ func (dst *Daterange) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/ext/satori-uuid/uuid.go b/ext/satori-uuid/uuid.go index b7b776f9..78a90035 100644 --- a/ext/satori-uuid/uuid.go +++ b/ext/satori-uuid/uuid.go @@ -2,8 +2,8 @@ package uuid import ( "database/sql/driver" - "errors" - "fmt" + + "github.com/pkg/errors" "github.com/jackc/pgx/pgtype" uuid "github.com/satori/go.uuid" @@ -24,7 +24,7 @@ func (dst *UUID) Set(src interface{}) error { *dst = UUID{UUID: uuid.UUID(value), Status: pgtype.Present} case []byte: if len(value) != 16 { - return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) } *dst = UUID{Status: pgtype.Present} copy(dst.UUID[:], value) @@ -38,7 +38,7 @@ func (dst *UUID) Set(src interface{}) error { // If all else fails see if pgtype.UUID can handle it. If so, translate through that. pgUUID := &pgtype.UUID{} if err := pgUUID.Set(value); err != nil { - return fmt.Errorf("cannot convert %v to UUID", value) + return errors.Errorf("cannot convert %v to UUID", value) } *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Status: pgUUID.Status} @@ -83,7 +83,7 @@ func (src *UUID) AssignTo(dst interface{}) error { return pgtype.NullAssignTo(dst) } - return fmt.Errorf("cannot assign %v into %T", src, dst) + return errors.Errorf("cannot assign %v into %T", src, dst) } func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { @@ -108,7 +108,7 @@ func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { } if len(src) != 16 { - return fmt.Errorf("invalid length for UUID: %v", len(src)) + return errors.Errorf("invalid length for UUID: %v", len(src)) } *dst = UUID{Status: pgtype.Present} @@ -152,7 +152,7 @@ func (dst *UUID) Scan(src interface{}) error { return dst.DecodeText(nil, src) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index 277f3709..507a93dc 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -2,10 +2,10 @@ package numeric import ( "database/sql/driver" - "errors" - "fmt" "strconv" + "github.com/pkg/errors" + "github.com/jackc/pgx/pgtype" "github.com/shopspring/decimal" ) @@ -70,17 +70,17 @@ func (dst *Numeric) Set(src interface{}) error { // If all else fails see if pgtype.Numeric can handle it. If so, translate through that. num := &pgtype.Numeric{} if err := num.Set(value); err != nil { - return fmt.Errorf("cannot convert %v to Numeric", value) + return errors.Errorf("cannot convert %v to Numeric", value) } buf, err := num.EncodeText(nil, nil) if err != nil { - return fmt.Errorf("cannot convert %v to Numeric", value) + return errors.Errorf("cannot convert %v to Numeric", value) } dec, err := decimal.NewFromString(string(buf)) if err != nil { - return fmt.Errorf("cannot convert %v to Numeric", value) + return errors.Errorf("cannot convert %v to Numeric", value) } *dst = Numeric{Decimal: dec, Status: pgtype.Present} } @@ -113,92 +113,92 @@ func (src *Numeric) AssignTo(dst interface{}) error { *v = f case *int: if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = int(n) case *int8: if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, 8) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = int8(n) case *int16: if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, 16) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = int16(n) case *int32: if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, 32) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = int32(n) case *int64: if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, 64) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = int64(n) case *uint: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = uint(n) case *uint8: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, 8) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = uint8(n) case *uint16: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, 16) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = uint16(n) case *uint32: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, 32) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = uint32(n) case *uint64: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, 64) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = uint64(n) default: @@ -301,7 +301,7 @@ func (dst *Numeric) Scan(src interface{}) error { return dst.DecodeText(nil, src) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/float4.go b/float4.go index b24654b6..2207594a 100644 --- a/float4.go +++ b/float4.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Float4 struct { @@ -39,42 +39,42 @@ func (dst *Float4) Set(src interface{}) error { if int32(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) + return errors.Errorf("%v cannot be exactly represented as float32", value) } case uint32: f32 := float32(value) if uint32(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) + return errors.Errorf("%v cannot be exactly represented as float32", value) } case int64: f32 := float32(value) if int64(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) + return errors.Errorf("%v cannot be exactly represented as float32", value) } case uint64: f32 := float32(value) if uint64(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) + return errors.Errorf("%v cannot be exactly represented as float32", value) } case int: f32 := float32(value) if int(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) + return errors.Errorf("%v cannot be exactly represented as float32", value) } case uint: f32 := float32(value) if uint(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) + return errors.Errorf("%v cannot be exactly represented as float32", value) } case string: num, err := strconv.ParseFloat(value, 32) @@ -86,7 +86,7 @@ func (dst *Float4) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Float8", value) + return errors.Errorf("cannot convert %v to Float8", value) } return nil @@ -129,7 +129,7 @@ func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 4 { - return fmt.Errorf("invalid length for float4: %v", len(src)) + return errors.Errorf("invalid length for float4: %v", len(src)) } n := int32(binary.BigEndian.Uint32(src)) @@ -181,7 +181,7 @@ func (dst *Float4) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/float4_array.go b/float4_array.go index a35657b0..6499064b 100644 --- a/float4_array.go +++ b/float4_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Float4Array struct { @@ -40,7 +40,7 @@ func (dst *Float4Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Float4", value) + return errors.Errorf("cannot convert %v to Float4", value) } return nil @@ -80,7 +80,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("float4"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "float4") + return nil, errors.Errorf("unable to find oid for type name %v", "float4") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *Float4Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/float8.go b/float8.go index c3ecdcc2..dd34f541 100644 --- a/float8.go +++ b/float8.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Float8 struct { @@ -43,28 +43,28 @@ func (dst *Float8) Set(src interface{}) error { if int64(f64) == value { *dst = Float8{Float: f64, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float64", value) + return errors.Errorf("%v cannot be exactly represented as float64", value) } case uint64: f64 := float64(value) if uint64(f64) == value { *dst = Float8{Float: f64, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float64", value) + return errors.Errorf("%v cannot be exactly represented as float64", value) } case int: f64 := float64(value) if int(f64) == value { *dst = Float8{Float: f64, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float64", value) + return errors.Errorf("%v cannot be exactly represented as float64", value) } case uint: f64 := float64(value) if uint(f64) == value { *dst = Float8{Float: f64, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float64", value) + return errors.Errorf("%v cannot be exactly represented as float64", value) } case string: num, err := strconv.ParseFloat(value, 64) @@ -76,7 +76,7 @@ func (dst *Float8) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Float8", value) + return errors.Errorf("cannot convert %v to Float8", value) } return nil @@ -119,7 +119,7 @@ func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return fmt.Errorf("invalid length for float4: %v", len(src)) + return errors.Errorf("invalid length for float4: %v", len(src)) } n := int64(binary.BigEndian.Uint64(src)) @@ -171,7 +171,7 @@ func (dst *Float8) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/float8_array.go b/float8_array.go index 486e3a4e..27b24836 100644 --- a/float8_array.go +++ b/float8_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Float8Array struct { @@ -40,7 +40,7 @@ func (dst *Float8Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Float8", value) + return errors.Errorf("cannot convert %v to Float8", value) } return nil @@ -80,7 +80,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("float8"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "float8") + return nil, errors.Errorf("unable to find oid for type name %v", "float8") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *Float8Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/hstore.go b/hstore.go index 09506242..347446ae 100644 --- a/hstore.go +++ b/hstore.go @@ -4,12 +4,12 @@ import ( "bytes" "database/sql/driver" "encoding/binary" - "errors" - "fmt" "strings" "unicode" "unicode/utf8" + "github.com/pkg/errors" + "github.com/jackc/pgx/pgio" ) @@ -34,7 +34,7 @@ func (dst *Hstore) Set(src interface{}) error { } *dst = Hstore{Map: m, Status: Present} default: - return fmt.Errorf("cannot convert %v to Hstore", src) + return errors.Errorf("cannot convert %v to Hstore", src) } return nil @@ -59,7 +59,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { *v = make(map[string]string, len(src.Map)) for k, val := range src.Map { if val.Status != Present { - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } (*v)[k] = val.String } @@ -73,7 +73,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { @@ -105,7 +105,7 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { rp := 0 if len(src[rp:]) < 4 { - return fmt.Errorf("hstore incomplete %v", src) + return errors.Errorf("hstore incomplete %v", src) } pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 @@ -114,19 +114,19 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { for i := 0; i < pairCount; i++ { if len(src[rp:]) < 4 { - return fmt.Errorf("hstore incomplete %v", src) + return errors.Errorf("hstore incomplete %v", src) } keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 if len(src[rp:]) < keyLen { - return fmt.Errorf("hstore incomplete %v", src) + return errors.Errorf("hstore incomplete %v", src) } key := string(src[rp : rp+keyLen]) rp += keyLen if len(src[rp:]) < 4 { - return fmt.Errorf("hstore incomplete %v", src) + return errors.Errorf("hstore incomplete %v", src) } valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 @@ -333,13 +333,13 @@ func parseHstore(s string) (k []string, v []Text, err error) { case r == 'N': state = hsNul default: - err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) + err = errors.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) } default: - err = fmt.Errorf("Invalid character after '=', expecting '>'") + err = errors.Errorf("Invalid character after '=', expecting '>'") } } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) + err = errors.Errorf("Invalid character '%c' after value, expecting '='", r) } case hsVal: switch r { @@ -376,7 +376,7 @@ func parseHstore(s string) (k []string, v []Text, err error) { values = append(values, Text{Status: Null}) state = hsNext } else { - err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + err = errors.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) } case hsNext: if r == ',' { @@ -388,10 +388,10 @@ func parseHstore(s string) (k []string, v []Text, err error) { r, end = p.Consume() state = hsKey default: - err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) + err = errors.Errorf("Invalid character '%c' after ', ', expecting \"", r) } } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) + err = errors.Errorf("Invalid character '%c' after value, expecting ','", r) } } @@ -425,7 +425,7 @@ func (dst *Hstore) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/hstore_array.go b/hstore_array.go index 3e5a003f..38ce457b 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type HstoreArray struct { @@ -40,7 +40,7 @@ func (dst *HstoreArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Hstore", value) + return errors.Errorf("cannot convert %v to Hstore", value) } return nil @@ -80,7 +80,7 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("hstore"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "hstore") + return nil, errors.Errorf("unable to find oid for type name %v", "hstore") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *HstoreArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/inet.go b/inet.go index 7aa1df95..01fc0e5b 100644 --- a/inet.go +++ b/inet.go @@ -2,8 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "net" + + "github.com/pkg/errors" ) // Network address family is dependent on server socket.h value for AF_INET. @@ -45,7 +46,7 @@ func (dst *Inet) Set(src interface{}) error { if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Inet", value) + return errors.Errorf("cannot convert %v to Inet", value) } return nil @@ -76,7 +77,7 @@ func (src *Inet) AssignTo(dst interface{}) error { return nil case *net.IP: if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } *v = make(net.IP, len(src.IPNet.IP)) copy(*v, src.IPNet.IP) @@ -90,7 +91,7 @@ func (src *Inet) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { @@ -128,7 +129,7 @@ func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 && len(src) != 20 { - return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) + return errors.Errorf("Received an invalid size for a inet: %d", len(src)) } // ignore family @@ -173,7 +174,7 @@ func (src *Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case net.IPv6len: family = defaultAFInet6 default: - return nil, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) + return nil, errors.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) } buf = append(buf, family) @@ -205,7 +206,7 @@ func (dst *Inet) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/inet_array.go b/inet_array.go index 57123c1c..3ece23eb 100644 --- a/inet_array.go +++ b/inet_array.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "net" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type InetArray struct { @@ -60,7 +60,7 @@ func (dst *InetArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Inet", value) + return errors.Errorf("cannot convert %v to Inet", value) } return nil @@ -109,7 +109,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -262,7 +262,7 @@ func (src *InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("inet"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "inet") + return nil, errors.Errorf("unable to find oid for type name %v", "inet") } for i := range src.Elements { @@ -306,7 +306,7 @@ func (dst *InetArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int2.go b/int2.go index a58c3355..45bce93c 100644 --- a/int2.go +++ b/int2.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int2 struct { @@ -30,46 +30,46 @@ func (dst *Int2) Set(src interface{}) error { *dst = Int2{Int: int16(value), Status: Present} case uint16: if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case int32: if value < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case uint32: if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case int64: if value < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case uint64: if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case int: if value < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case uint: if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case string: @@ -82,7 +82,7 @@ func (dst *Int2) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int2", value) + return errors.Errorf("cannot convert %v to Int2", value) } return nil @@ -125,7 +125,7 @@ func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 2 { - return fmt.Errorf("invalid length for int2: %v", len(src)) + return errors.Errorf("invalid length for int2: %v", len(src)) } n := int16(binary.BigEndian.Uint16(src)) @@ -165,10 +165,10 @@ func (dst *Int2) Scan(src interface{}) error { switch src := src.(type) { case int64: if src < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", src) + return errors.Errorf("%d is greater than maximum value for Int2", src) } if src > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", src) + return errors.Errorf("%d is greater than maximum value for Int2", src) } *dst = Int2{Int: int16(src), Status: Present} return nil @@ -180,7 +180,7 @@ func (dst *Int2) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int2_array.go b/int2_array.go index e4993104..e939411b 100644 --- a/int2_array.go +++ b/int2_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int2Array struct { @@ -59,7 +59,7 @@ func (dst *Int2Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int2", value) + return errors.Errorf("cannot convert %v to Int2", value) } return nil @@ -108,7 +108,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { @@ -261,7 +261,7 @@ func (src *Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("int2"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "int2") + return nil, errors.Errorf("unable to find oid for type name %v", "int2") } for i := range src.Elements { @@ -305,7 +305,7 @@ func (dst *Int2Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int4.go b/int4.go index 6f95013b..a3499fef 100644 --- a/int4.go +++ b/int4.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int4 struct { @@ -34,33 +34,33 @@ func (dst *Int4) Set(src interface{}) error { *dst = Int4{Int: int32(value), Status: Present} case uint32: if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case int64: if value < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case uint64: if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case int: if value < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case uint: if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case string: @@ -73,7 +73,7 @@ func (dst *Int4) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int4", value) + return errors.Errorf("cannot convert %v to Int4", value) } return nil @@ -116,7 +116,7 @@ func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 4 { - return fmt.Errorf("invalid length for int4: %v", len(src)) + return errors.Errorf("invalid length for int4: %v", len(src)) } n := int32(binary.BigEndian.Uint32(src)) @@ -156,10 +156,10 @@ func (dst *Int4) Scan(src interface{}) error { switch src := src.(type) { case int64: if src < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", src) + return errors.Errorf("%d is greater than maximum value for Int4", src) } if src > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", src) + return errors.Errorf("%d is greater than maximum value for Int4", src) } *dst = Int4{Int: int32(src), Status: Present} return nil @@ -171,7 +171,7 @@ func (dst *Int4) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int4_array.go b/int4_array.go index 6bc06e86..1a907d2e 100644 --- a/int4_array.go +++ b/int4_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int4Array struct { @@ -59,7 +59,7 @@ func (dst *Int4Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int4", value) + return errors.Errorf("cannot convert %v to Int4", value) } return nil @@ -108,7 +108,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { @@ -261,7 +261,7 @@ func (src *Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("int4"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "int4") + return nil, errors.Errorf("unable to find oid for type name %v", "int4") } for i := range src.Elements { @@ -305,7 +305,7 @@ func (dst *Int4Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int4range.go b/int4range.go index 4f27ff0d..95ad1521 100644 --- a/int4range.go +++ b/int4range.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int4range struct { @@ -16,7 +16,7 @@ type Int4range struct { } func (dst *Int4range) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Int4range", src) + return errors.Errorf("cannot convert %v to Int4range", src) } func (dst *Int4range) Get() interface{} { @@ -31,7 +31,7 @@ func (dst *Int4range) Get() interface{} { } func (src *Int4range) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Int4range) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -130,7 +130,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -141,7 +141,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -151,7 +151,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -175,7 +175,7 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -185,7 +185,7 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -201,7 +201,7 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -216,7 +216,7 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -241,7 +241,7 @@ func (dst *Int4range) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int8.go b/int8.go index 939c0554..d671eda7 100644 --- a/int8.go +++ b/int8.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int8 struct { @@ -38,20 +38,20 @@ func (dst *Int8) Set(src interface{}) error { *dst = Int8{Int: int64(value), Status: Present} case uint64: if value > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", value) + return errors.Errorf("%d is greater than maximum value for Int8", value) } *dst = Int8{Int: int64(value), Status: Present} case int: if int64(value) < math.MinInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", value) + return errors.Errorf("%d is greater than maximum value for Int8", value) } if int64(value) > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", value) + return errors.Errorf("%d is greater than maximum value for Int8", value) } *dst = Int8{Int: int64(value), Status: Present} case uint: if uint64(value) > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", value) + return errors.Errorf("%d is greater than maximum value for Int8", value) } *dst = Int8{Int: int64(value), Status: Present} case string: @@ -64,7 +64,7 @@ func (dst *Int8) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int8", value) + return errors.Errorf("cannot convert %v to Int8", value) } return nil @@ -107,7 +107,7 @@ func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return fmt.Errorf("invalid length for int8: %v", len(src)) + return errors.Errorf("invalid length for int8: %v", len(src)) } n := int64(binary.BigEndian.Uint64(src)) @@ -157,7 +157,7 @@ func (dst *Int8) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int8_array.go b/int8_array.go index 4404d22a..4f3ab4dc 100644 --- a/int8_array.go +++ b/int8_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int8Array struct { @@ -59,7 +59,7 @@ func (dst *Int8Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int8", value) + return errors.Errorf("cannot convert %v to Int8", value) } return nil @@ -108,7 +108,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { @@ -261,7 +261,7 @@ func (src *Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("int8"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "int8") + return nil, errors.Errorf("unable to find oid for type name %v", "int8") } for i := range src.Elements { @@ -305,7 +305,7 @@ func (dst *Int8Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int8range.go b/int8range.go index 128a853f..61d860d3 100644 --- a/int8range.go +++ b/int8range.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int8range struct { @@ -16,7 +16,7 @@ type Int8range struct { } func (dst *Int8range) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Int8range", src) + return errors.Errorf("cannot convert %v to Int8range", src) } func (dst *Int8range) Get() interface{} { @@ -31,7 +31,7 @@ func (dst *Int8range) Get() interface{} { } func (src *Int8range) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -130,7 +130,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -141,7 +141,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -151,7 +151,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -175,7 +175,7 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -185,7 +185,7 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -201,7 +201,7 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -216,7 +216,7 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -241,7 +241,7 @@ func (dst *Int8range) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/interval.go b/interval.go index 85d76d99..799ce53a 100644 --- a/interval.go +++ b/interval.go @@ -9,6 +9,7 @@ import ( "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) const ( @@ -37,7 +38,7 @@ func (dst *Interval) Set(src interface{}) error { if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Interval", value) + return errors.Errorf("cannot convert %v to Interval", value) } return nil @@ -60,7 +61,7 @@ func (src *Interval) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Duration: if src.Days > 0 || src.Months > 0 { - return fmt.Errorf("interval with months or days cannot be decoded into %T", dst) + return errors.Errorf("interval with months or days cannot be decoded into %T", dst) } *v = time.Duration(src.Microseconds) * time.Microsecond return nil @@ -73,7 +74,7 @@ func (src *Interval) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { @@ -91,7 +92,7 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { for i := 0; i < len(parts)-1; i += 2 { scalar, err := strconv.ParseInt(parts[i], 10, 64) if err != nil { - return fmt.Errorf("bad interval format") + return errors.Errorf("bad interval format") } switch parts[i+1] { @@ -107,7 +108,7 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { if len(parts)%2 == 1 { timeParts := strings.SplitN(parts[len(parts)-1], ":", 3) if len(timeParts) != 3 { - return fmt.Errorf("bad interval format") + return errors.Errorf("bad interval format") } var negative bool @@ -118,26 +119,26 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { hours, err := strconv.ParseInt(timeParts[0], 10, 64) if err != nil { - return fmt.Errorf("bad interval hour format: %s", timeParts[0]) + return errors.Errorf("bad interval hour format: %s", timeParts[0]) } minutes, err := strconv.ParseInt(timeParts[1], 10, 64) if err != nil { - return fmt.Errorf("bad interval minute format: %s", timeParts[1]) + return errors.Errorf("bad interval minute format: %s", timeParts[1]) } secondParts := strings.SplitN(timeParts[2], ".", 2) seconds, err := strconv.ParseInt(secondParts[0], 10, 64) if err != nil { - return fmt.Errorf("bad interval second format: %s", secondParts[0]) + return errors.Errorf("bad interval second format: %s", secondParts[0]) } var uSeconds int64 if len(secondParts) == 2 { uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64) if err != nil { - return fmt.Errorf("bad interval decimal format: %s", secondParts[1]) + return errors.Errorf("bad interval decimal format: %s", secondParts[1]) } for i := 0; i < 6-len(secondParts[1]); i++ { @@ -166,7 +167,7 @@ func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 16 { - return fmt.Errorf("Received an invalid size for a interval: %d", len(src)) + return errors.Errorf("Received an invalid size for a interval: %d", len(src)) } microseconds := int64(binary.BigEndian.Uint64(src)) @@ -240,7 +241,7 @@ func (dst *Interval) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/json.go b/json.go index ee00e9a4..562722aa 100644 --- a/json.go +++ b/json.go @@ -3,7 +3,8 @@ package pgtype import ( "database/sql/driver" "encoding/json" - "fmt" + + "github.com/pkg/errors" ) type JSON struct { @@ -135,7 +136,7 @@ func (dst *JSON) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/jsonb.go b/jsonb.go index 9a06c1b4..c315c588 100644 --- a/jsonb.go +++ b/jsonb.go @@ -2,7 +2,8 @@ package pgtype import ( "database/sql/driver" - "fmt" + + "github.com/pkg/errors" ) type JSONB JSON @@ -30,11 +31,11 @@ func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) == 0 { - return fmt.Errorf("jsonb too short") + return errors.Errorf("jsonb too short") } if src[0] != 1 { - return fmt.Errorf("unknown jsonb version number %d", src[0]) + return errors.Errorf("unknown jsonb version number %d", src[0]) } *dst = JSONB{Bytes: src[1:], Status: Present} diff --git a/line.go b/line.go index 47f636a5..f6eadf0e 100644 --- a/line.go +++ b/line.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Line struct { @@ -17,7 +18,7 @@ type Line struct { } func (dst *Line) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Line", src) + return errors.Errorf("cannot convert %v to Line", src) } func (dst *Line) Get() interface{} { @@ -32,7 +33,7 @@ func (dst *Line) Get() interface{} { } func (src *Line) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { @@ -42,12 +43,12 @@ func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 7 { - return fmt.Errorf("invalid length for Line: %v", len(src)) + return errors.Errorf("invalid length for Line: %v", len(src)) } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 3) if len(parts) < 3 { - return fmt.Errorf("invalid format for line") + return errors.Errorf("invalid format for line") } a, err := strconv.ParseFloat(parts[0], 64) @@ -76,7 +77,7 @@ func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 24 { - return fmt.Errorf("invalid length for Line: %v", len(src)) + return errors.Errorf("invalid length for Line: %v", len(src)) } a := binary.BigEndian.Uint64(src) @@ -133,7 +134,7 @@ func (dst *Line) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/lseg.go b/lseg.go index 44c2b63c..a9d740cf 100644 --- a/lseg.go +++ b/lseg.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Lseg struct { @@ -17,7 +18,7 @@ type Lseg struct { } func (dst *Lseg) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Lseg", src) + return errors.Errorf("cannot convert %v to Lseg", src) } func (dst *Lseg) Get() interface{} { @@ -32,7 +33,7 @@ func (dst *Lseg) Get() interface{} { } func (src *Lseg) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { @@ -42,7 +43,7 @@ func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 11 { - return fmt.Errorf("invalid length for Lseg: %v", len(src)) + return errors.Errorf("invalid length for Lseg: %v", len(src)) } str := string(src[2:]) @@ -89,7 +90,7 @@ func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 32 { - return fmt.Errorf("invalid length for Lseg: %v", len(src)) + return errors.Errorf("invalid length for Lseg: %v", len(src)) } x1 := binary.BigEndian.Uint64(src) @@ -151,7 +152,7 @@ func (dst *Lseg) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/macaddr.go b/macaddr.go index e38701eb..4c6e2212 100644 --- a/macaddr.go +++ b/macaddr.go @@ -2,8 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "net" + + "github.com/pkg/errors" ) type Macaddr struct { @@ -32,7 +33,7 @@ func (dst *Macaddr) Set(src interface{}) error { if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Macaddr", value) + return errors.Errorf("cannot convert %v to Macaddr", value) } return nil @@ -69,7 +70,7 @@ func (src *Macaddr) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { @@ -94,7 +95,7 @@ func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 6 { - return fmt.Errorf("Received an invalid size for a macaddr: %d", len(src)) + return errors.Errorf("Received an invalid size for a macaddr: %d", len(src)) } addr := make(net.HardwareAddr, 6) @@ -144,7 +145,7 @@ func (dst *Macaddr) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/numeric.go b/numeric.go index dffb9963..fded6359 100644 --- a/numeric.go +++ b/numeric.go @@ -3,13 +3,13 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "math/big" "strconv" "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) // PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 @@ -97,7 +97,7 @@ func (dst *Numeric) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Numeric", value) + return errors.Errorf("cannot convert %v to Numeric", value) } return nil @@ -136,10 +136,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = int(normalizedInt.Int64()) case *int8: @@ -148,10 +148,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt8) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt8) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = int8(normalizedInt.Int64()) case *int16: @@ -160,10 +160,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt16) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt16) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = int16(normalizedInt.Int64()) case *int32: @@ -172,10 +172,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt32) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt32) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = int32(normalizedInt.Int64()) case *int64: @@ -184,10 +184,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt64) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt64) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = normalizedInt.Int64() case *uint: @@ -196,9 +196,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = uint(normalizedInt.Uint64()) case *uint8: @@ -207,9 +207,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint8) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = uint8(normalizedInt.Uint64()) case *uint16: @@ -218,9 +218,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint16) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = uint16(normalizedInt.Uint64()) case *uint32: @@ -229,9 +229,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint32) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = uint32(normalizedInt.Uint64()) case *uint64: @@ -240,9 +240,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint64) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = normalizedInt.Uint64() default: @@ -276,7 +276,7 @@ func (dst *Numeric) toBigInt() (*big.Int, error) { remainder := &big.Int{} num.DivMod(num, div, remainder) if remainder.Cmp(big0) != 0 { - return nil, fmt.Errorf("cannot convert %v to integer", dst) + return nil, errors.Errorf("cannot convert %v to integer", dst) } return num, nil } @@ -328,7 +328,7 @@ func parseNumericString(str string) (n *big.Int, exp int32, err error) { accum := &big.Int{} if _, ok := accum.SetString(digits, 10); !ok { - return nil, 0, fmt.Errorf("%s is not a number", str) + return nil, 0, errors.Errorf("%s is not a number", str) } return accum, exp, nil @@ -341,7 +341,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) < 8 { - return fmt.Errorf("numeric incomplete %v", src) + return errors.Errorf("numeric incomplete %v", src) } rp := 0 @@ -361,7 +361,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { rp += 2 if len(src[rp:]) < int(ndigits)*2 { - return fmt.Errorf("numeric incomplete %v", src) + return errors.Errorf("numeric incomplete %v", src) } accum := &big.Int{} @@ -382,7 +382,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { case 4: mul = bigNBaseX4 default: - return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) + return errors.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) } accum.Mul(accum, mul) } @@ -575,7 +575,7 @@ func (dst *Numeric) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/numeric_array.go b/numeric_array.go index f193a2a5..6dfbe5e3 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type NumericArray struct { @@ -59,7 +59,7 @@ func (dst *NumericArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Numeric", value) + return errors.Errorf("cannot convert %v to Numeric", value) } return nil @@ -108,7 +108,7 @@ func (src *NumericArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -261,7 +261,7 @@ func (src *NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) if dt, ok := ci.DataTypeForName("numeric"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "numeric") + return nil, errors.Errorf("unable to find oid for type name %v", "numeric") } for i := range src.Elements { @@ -305,7 +305,7 @@ func (dst *NumericArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/numrange.go b/numrange.go index 00133296..aaed62ce 100644 --- a/numrange.go +++ b/numrange.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Numrange struct { @@ -16,7 +16,7 @@ type Numrange struct { } func (dst *Numrange) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Numrange", src) + return errors.Errorf("cannot convert %v to Numrange", src) } func (dst *Numrange) Get() interface{} { @@ -31,7 +31,7 @@ func (dst *Numrange) Get() interface{} { } func (src *Numrange) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -130,7 +130,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -141,7 +141,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -151,7 +151,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -175,7 +175,7 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -185,7 +185,7 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -201,7 +201,7 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -216,7 +216,7 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -241,7 +241,7 @@ func (dst *Numrange) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/oid.go b/oid.go index d37f4e57..59370d66 100644 --- a/oid.go +++ b/oid.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) // OID (Object Identifier Type) is, according to @@ -20,7 +20,7 @@ type OID uint32 func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - return fmt.Errorf("cannot decode nil into OID") + return errors.Errorf("cannot decode nil into OID") } n, err := strconv.ParseUint(string(src), 10, 32) @@ -34,11 +34,11 @@ func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { func (dst *OID) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - return fmt.Errorf("cannot decode nil into OID") + return errors.Errorf("cannot decode nil into OID") } if len(src) != 4 { - return fmt.Errorf("invalid length: %v", len(src)) + return errors.Errorf("invalid length: %v", len(src)) } n := binary.BigEndian.Uint32(src) @@ -57,7 +57,7 @@ func (src OID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *OID) Scan(src interface{}) error { if src == nil { - return fmt.Errorf("cannot scan NULL into %T", src) + return errors.Errorf("cannot scan NULL into %T", src) } switch src := src.(type) { @@ -72,7 +72,7 @@ func (dst *OID) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/path.go b/path.go index 3575342d..aa0cee8e 100644 --- a/path.go +++ b/path.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Path struct { @@ -18,7 +19,7 @@ type Path struct { } func (dst *Path) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Path", src) + return errors.Errorf("cannot convert %v to Path", src) } func (dst *Path) Get() interface{} { @@ -33,7 +34,7 @@ func (dst *Path) Get() interface{} { } func (src *Path) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { @@ -43,7 +44,7 @@ func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 7 { - return fmt.Errorf("invalid length for Path: %v", len(src)) + return errors.Errorf("invalid length for Path: %v", len(src)) } closed := src[0] == '(' @@ -86,7 +87,7 @@ func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) < 5 { - return fmt.Errorf("invalid length for Path: %v", len(src)) + return errors.Errorf("invalid length for Path: %v", len(src)) } closed := src[0] == 1 @@ -95,7 +96,7 @@ func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { rp := 5 if 5+pointCount*16 != len(src) { - return fmt.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) + return errors.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) } points := make([]Vec2, pointCount) @@ -183,7 +184,7 @@ func (dst *Path) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype.go b/pgtype.go index 4302a5fe..6f8e7986 100644 --- a/pgtype.go +++ b/pgtype.go @@ -1,8 +1,9 @@ package pgtype import ( - "errors" "reflect" + + "github.com/pkg/errors" ) // PostgreSQL oids for common types diff --git a/pguint32.go b/pguint32.go index 15b0f38d..e441a690 100644 --- a/pguint32.go +++ b/pguint32.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) // pguint32 is the core type that is used to implement PostgreSQL types such as @@ -24,16 +24,16 @@ func (dst *pguint32) Set(src interface{}) error { switch value := src.(type) { case int64: if value < 0 { - return fmt.Errorf("%d is less than minimum value for pguint32", value) + return errors.Errorf("%d is less than minimum value for pguint32", value) } if value > math.MaxUint32 { - return fmt.Errorf("%d is greater than maximum value for pguint32", value) + return errors.Errorf("%d is greater than maximum value for pguint32", value) } *dst = pguint32{Uint: uint32(value), Status: Present} case uint32: *dst = pguint32{Uint: value, Status: Present} default: - return fmt.Errorf("cannot convert %v to pguint32", value) + return errors.Errorf("cannot convert %v to pguint32", value) } return nil @@ -58,7 +58,7 @@ func (src *pguint32) AssignTo(dst interface{}) error { if src.Status == Present { *v = src.Uint } else { - return fmt.Errorf("cannot assign %v into %T", src, dst) + return errors.Errorf("cannot assign %v into %T", src, dst) } case **uint32: if src.Status == Present { @@ -94,7 +94,7 @@ func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 4 { - return fmt.Errorf("invalid length: %v", len(src)) + return errors.Errorf("invalid length: %v", len(src)) } n := binary.BigEndian.Uint32(src) @@ -146,7 +146,7 @@ func (dst *pguint32) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/point.go b/point.go index 3d5d4e1a..3132a939 100644 --- a/point.go +++ b/point.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Vec2 struct { @@ -22,7 +23,7 @@ type Point struct { } func (dst *Point) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Point", src) + return errors.Errorf("cannot convert %v to Point", src) } func (dst *Point) Get() interface{} { @@ -37,7 +38,7 @@ func (dst *Point) Get() interface{} { } func (src *Point) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { @@ -47,12 +48,12 @@ func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 5 { - return fmt.Errorf("invalid length for point: %v", len(src)) + return errors.Errorf("invalid length for point: %v", len(src)) } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) if len(parts) < 2 { - return fmt.Errorf("invalid format for point") + return errors.Errorf("invalid format for point") } x, err := strconv.ParseFloat(parts[0], 64) @@ -76,7 +77,7 @@ func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 16 { - return fmt.Errorf("invalid length for point: %v", len(src)) + return errors.Errorf("invalid length for point: %v", len(src)) } x := binary.BigEndian.Uint64(src) @@ -129,7 +130,7 @@ func (dst *Point) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/polygon.go b/polygon.go index d0b50061..3f3d9f53 100644 --- a/polygon.go +++ b/polygon.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Polygon struct { @@ -17,7 +18,7 @@ type Polygon struct { } func (dst *Polygon) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Polygon", src) + return errors.Errorf("cannot convert %v to Polygon", src) } func (dst *Polygon) Get() interface{} { @@ -32,7 +33,7 @@ func (dst *Polygon) Get() interface{} { } func (src *Polygon) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { @@ -42,7 +43,7 @@ func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 7 { - return fmt.Errorf("invalid length for Polygon: %v", len(src)) + return errors.Errorf("invalid length for Polygon: %v", len(src)) } points := make([]Vec2, 0) @@ -84,14 +85,14 @@ func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) < 5 { - return fmt.Errorf("invalid length for Polygon: %v", len(src)) + return errors.Errorf("invalid length for Polygon: %v", len(src)) } pointCount := int(binary.BigEndian.Uint32(src)) rp := 4 if 4+pointCount*16 != len(src) { - return fmt.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) + return errors.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) } points := make([]Vec2, pointCount) @@ -164,7 +165,7 @@ func (dst *Polygon) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/qchar.go b/qchar.go index 9c40ce18..064dab1e 100644 --- a/qchar.go +++ b/qchar.go @@ -1,9 +1,10 @@ package pgtype import ( - "fmt" "math" "strconv" + + "github.com/pkg/errors" ) // QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C @@ -33,59 +34,59 @@ func (dst *QChar) Set(src interface{}) error { *dst = QChar{Int: value, Status: Present} case uint8: if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case int16: if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case uint16: if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case int32: if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case uint32: if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case int64: if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case uint64: if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case int: if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case uint: if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case string: @@ -98,7 +99,7 @@ func (dst *QChar) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to QChar", value) + return errors.Errorf("cannot convert %v to QChar", value) } return nil @@ -126,7 +127,7 @@ func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 1 { - return fmt.Errorf(`invalid length for "char": %v`, len(src)) + return errors.Errorf(`invalid length for "char": %v`, len(src)) } *dst = QChar{Int: int8(src[0]), Status: Present} diff --git a/range.go b/range.go index 76daf8cc..d870834f 100644 --- a/range.go +++ b/range.go @@ -3,7 +3,8 @@ package pgtype import ( "bytes" "encoding/binary" - "fmt" + + "github.com/pkg/errors" ) type BoundType byte @@ -36,7 +37,7 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { r, _, err := buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid lower bound: %v", err) + return nil, errors.Errorf("invalid lower bound: %v", err) } switch r { case '(': @@ -44,12 +45,12 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { case '[': utr.LowerType = Inclusive default: - return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) + return nil, errors.Errorf("missing lower bound, instead got: %v", string(r)) } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid lower value: %v", err) + return nil, errors.Errorf("invalid lower value: %v", err) } buf.UnreadRune() @@ -58,21 +59,21 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { } else { utr.Lower, err = rangeParseValue(buf) if err != nil { - return nil, fmt.Errorf("invalid lower value: %v", err) + return nil, errors.Errorf("invalid lower value: %v", err) } } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("missing range separator: %v", err) + return nil, errors.Errorf("missing range separator: %v", err) } if r != ',' { - return nil, fmt.Errorf("missing range separator: %v", r) + return nil, errors.Errorf("missing range separator: %v", r) } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid upper value: %v", err) + return nil, errors.Errorf("invalid upper value: %v", err) } buf.UnreadRune() @@ -81,13 +82,13 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { } else { utr.Upper, err = rangeParseValue(buf) if err != nil { - return nil, fmt.Errorf("invalid upper value: %v", err) + return nil, errors.Errorf("invalid upper value: %v", err) } } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("missing upper bound: %v", err) + return nil, errors.Errorf("missing upper bound: %v", err) } switch r { case ')': @@ -95,13 +96,13 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { case ']': utr.UpperType = Inclusive default: - return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) + return nil, errors.Errorf("missing upper bound, instead got: %v", string(r)) } skipWhitespace(buf) if buf.Len() > 0 { - return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + return nil, errors.Errorf("unexpected trailing data: %v", buf.String()) } return utr, nil @@ -197,7 +198,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { ubr := &UntypedBinaryRange{} if len(src) == 0 { - return nil, fmt.Errorf("range too short: %v", len(src)) + return nil, errors.Errorf("range too short: %v", len(src)) } rangeType := src[0] @@ -205,7 +206,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { if rangeType&emptyMask > 0 { if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) + return nil, errors.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) } ubr.LowerType = Empty ubr.UpperType = Empty @@ -230,13 +231,13 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) + return nil, errors.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) } return ubr, nil } if len(src[rp:]) < 4 { - return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + return nil, errors.Errorf("too few bytes for size: %v", src[rp:]) } valueLen := int(binary.BigEndian.Uint32(src[rp:])) rp += 4 @@ -249,14 +250,14 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { } else { ubr.Upper = val if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + return nil, errors.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) } return ubr, nil } if ubr.UpperType != Unbounded { if len(src[rp:]) < 4 { - return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + return nil, errors.Errorf("too few bytes for size: %v", src[rp:]) } valueLen := int(binary.BigEndian.Uint32(src[rp:])) rp += 4 @@ -265,7 +266,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { } if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + return nil, errors.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) } return ubr, nil diff --git a/record.go b/record.go index 7c8736df..14b415c3 100644 --- a/record.go +++ b/record.go @@ -2,7 +2,8 @@ package pgtype import ( "encoding/binary" - "fmt" + + "github.com/pkg/errors" ) // Record is the generic PostgreSQL record type such as is created with the @@ -25,7 +26,7 @@ func (dst *Record) Set(src interface{}) error { case []Value: *dst = Record{Fields: value, Status: Present} default: - return fmt.Errorf("cannot convert %v to Record", src) + return errors.Errorf("cannot convert %v to Record", src) } return nil @@ -65,7 +66,7 @@ func (src *Record) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { @@ -77,7 +78,7 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { rp := 0 if len(src[rp:]) < 4 { - return fmt.Errorf("Record incomplete %v", src) + return errors.Errorf("Record incomplete %v", src) } fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 @@ -86,7 +87,7 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { for i := 0; i < fieldCount; i++ { if len(src[rp:]) < 8 { - return fmt.Errorf("Record incomplete %v", src) + return errors.Errorf("Record incomplete %v", src) } fieldOID := OID(binary.BigEndian.Uint32(src[rp:])) rp += 4 @@ -97,14 +98,14 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { var binaryDecoder BinaryDecoder if dt, ok := ci.DataTypeForOID(fieldOID); ok { if binaryDecoder, ok = dt.Value.(BinaryDecoder); !ok { - return fmt.Errorf("unknown oid while decoding record: %v", fieldOID) + return errors.Errorf("unknown oid while decoding record: %v", fieldOID) } } var fieldBytes []byte if fieldLen >= 0 { if len(src[rp:]) < fieldLen { - return fmt.Errorf("Record incomplete %v", src) + return errors.Errorf("Record incomplete %v", src) } fieldBytes = src[rp : rp+fieldLen] rp += fieldLen diff --git a/text.go b/text.go index 6638c354..f05e1e89 100644 --- a/text.go +++ b/text.go @@ -3,7 +3,8 @@ package pgtype import ( "database/sql/driver" "encoding/json" - "fmt" + + "github.com/pkg/errors" ) type Text struct { @@ -36,7 +37,7 @@ func (dst *Text) Set(src interface{}) error { if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Text", value) + return errors.Errorf("cannot convert %v to Text", value) } return nil @@ -73,7 +74,7 @@ func (src *Text) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { @@ -121,7 +122,7 @@ func (dst *Text) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/text_array.go b/text_array.go index dab7d36e..2609a2cc 100644 --- a/text_array.go +++ b/text_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type TextArray struct { @@ -40,7 +40,7 @@ func (dst *TextArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Text", value) + return errors.Errorf("cannot convert %v to Text", value) } return nil @@ -80,7 +80,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("text"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "text") + return nil, errors.Errorf("unable to find oid for type name %v", "text") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *TextArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/tid.go b/tid.go index d44ea3a6..21852a14 100644 --- a/tid.go +++ b/tid.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) // TID is PostgreSQL's Tuple Identifier type. @@ -28,7 +29,7 @@ type TID struct { } func (dst *TID) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to TID", src) + return errors.Errorf("cannot convert %v to TID", src) } func (dst *TID) Get() interface{} { @@ -43,7 +44,7 @@ func (dst *TID) Get() interface{} { } func (src *TID) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { @@ -53,12 +54,12 @@ func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 5 { - return fmt.Errorf("invalid length for tid: %v", len(src)) + return errors.Errorf("invalid length for tid: %v", len(src)) } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) if len(parts) < 2 { - return fmt.Errorf("invalid format for tid") + return errors.Errorf("invalid format for tid") } blockNumber, err := strconv.ParseUint(parts[0], 10, 32) @@ -82,7 +83,7 @@ func (dst *TID) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 6 { - return fmt.Errorf("invalid length for tid: %v", len(src)) + return errors.Errorf("invalid length for tid: %v", len(src)) } *dst = TID{ @@ -134,7 +135,7 @@ func (dst *TID) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/timestamp.go b/timestamp.go index 75c6cffa..d906f467 100644 --- a/timestamp.go +++ b/timestamp.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) const pgTimestampFormat = "2006-01-02 15:04:05.999999999" @@ -37,7 +37,7 @@ func (dst *Timestamp) Set(src interface{}) error { if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Timestamp", value) + return errors.Errorf("cannot convert %v to Timestamp", value) } return nil @@ -63,7 +63,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: if src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } *v = src.Time return nil @@ -76,7 +76,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } // DecodeText decodes from src into dst. The decoded time is considered to @@ -114,7 +114,7 @@ func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return fmt.Errorf("invalid length for timestamp: %v", len(src)) + return errors.Errorf("invalid length for timestamp: %v", len(src)) } microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) @@ -143,7 +143,7 @@ func (src *Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } if src.Time.Location() != time.UTC { - return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") + return nil, errors.Errorf("cannot encode non-UTC time into timestamp") } var s string @@ -170,7 +170,7 @@ func (src *Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } if src.Time.Location() != time.UTC { - return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") + return nil, errors.Errorf("cannot encode non-UTC time into timestamp") } var microsecSinceY2K int64 @@ -206,7 +206,7 @@ func (dst *Timestamp) Scan(src interface{}) error { return nil } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/timestamp_array.go b/timestamp_array.go index fca9ad93..be281f2e 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type TimestampArray struct { @@ -41,7 +41,7 @@ func (dst *TimestampArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Timestamp", value) + return errors.Errorf("cannot convert %v to Timestamp", value) } return nil @@ -81,7 +81,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -234,7 +234,7 @@ func (src *TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error if dt, ok := ci.DataTypeForName("timestamp"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "timestamp") + return nil, errors.Errorf("unable to find oid for type name %v", "timestamp") } for i := range src.Elements { @@ -278,7 +278,7 @@ func (dst *TimestampArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/timestamptz.go b/timestamptz.go index 97b0de2a..74fe4954 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" @@ -38,7 +38,7 @@ func (dst *Timestamptz) Set(src interface{}) error { if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Timestamptz", value) + return errors.Errorf("cannot convert %v to Timestamptz", value) } return nil @@ -64,7 +64,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: if src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } *v = src.Time return nil @@ -77,7 +77,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return fmt.Errorf("invalid length for timestamptz: %v", len(src)) + return errors.Errorf("invalid length for timestamptz: %v", len(src)) } microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) @@ -202,7 +202,7 @@ func (dst *Timestamptz) Scan(src interface{}) error { return nil } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/timestamptz_array.go b/timestamptz_array.go index e0866d69..086a4ef0 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type TimestamptzArray struct { @@ -41,7 +41,7 @@ func (dst *TimestamptzArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Timestamptz", value) + return errors.Errorf("cannot convert %v to Timestamptz", value) } return nil @@ -81,7 +81,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -234,7 +234,7 @@ func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err if dt, ok := ci.DataTypeForName("timestamptz"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "timestamptz") + return nil, errors.Errorf("unable to find oid for type name %v", "timestamptz") } for i := range src.Elements { @@ -278,7 +278,7 @@ func (dst *TimestamptzArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/tsrange.go b/tsrange.go index 783fb086..8a67d65e 100644 --- a/tsrange.go +++ b/tsrange.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Tsrange struct { @@ -16,7 +16,7 @@ type Tsrange struct { } func (dst *Tsrange) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Tsrange", src) + return errors.Errorf("cannot convert %v to Tsrange", src) } func (dst *Tsrange) Get() interface{} { @@ -31,7 +31,7 @@ func (dst *Tsrange) Get() interface{} { } func (src *Tsrange) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Tsrange) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -130,7 +130,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -141,7 +141,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -151,7 +151,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -175,7 +175,7 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -185,7 +185,7 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -201,7 +201,7 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -216,7 +216,7 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -241,7 +241,7 @@ func (dst *Tsrange) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/tstzrange.go b/tstzrange.go index 8fd3fd68..b5129093 100644 --- a/tstzrange.go +++ b/tstzrange.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Tstzrange struct { @@ -16,7 +16,7 @@ type Tstzrange struct { } func (dst *Tstzrange) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Tstzrange", src) + return errors.Errorf("cannot convert %v to Tstzrange", src) } func (dst *Tstzrange) Get() interface{} { @@ -31,7 +31,7 @@ func (dst *Tstzrange) Get() interface{} { } func (src *Tstzrange) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Tstzrange) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -130,7 +130,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -141,7 +141,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -151,7 +151,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -175,7 +175,7 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -185,7 +185,7 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -201,7 +201,7 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -216,7 +216,7 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -241,7 +241,7 @@ func (dst *Tstzrange) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/typed_array.go.erb b/typed_array.go.erb index 01072549..7a69d0ab 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -40,7 +40,7 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to <%= pgtype_element_type %>", value) + return errors.Errorf("cannot convert %v to <%= pgtype_element_type %>", value) } return nil @@ -80,7 +80,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { @@ -236,7 +236,7 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byt if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") + return nil, errors.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") } for i := range src.Elements { @@ -281,7 +281,7 @@ func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/typed_range.go.erb b/typed_range.go.erb index 90c23991..91a5cb97 100644 --- a/typed_range.go.erb +++ b/typed_range.go.erb @@ -18,7 +18,7 @@ type <%= range_type %> struct { } func (dst *<%= range_type %>) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to <%= range_type %>", src) + return errors.Errorf("cannot convert %v to <%= range_type %>", src) } func (dst *<%= range_type %>) Get() interface{} { @@ -33,7 +33,7 @@ func (dst *<%= range_type %>) Get() interface{} { } func (src *<%= range_type %>) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { @@ -122,7 +122,7 @@ func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -132,7 +132,7 @@ func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -143,7 +143,7 @@ func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -153,7 +153,7 @@ func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -177,7 +177,7 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -187,7 +187,7 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -203,7 +203,7 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -218,7 +218,7 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -243,7 +243,7 @@ func (dst *<%= range_type %>) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/uuid.go b/uuid.go index d1ab1a38..33e79536 100644 --- a/uuid.go +++ b/uuid.go @@ -4,6 +4,8 @@ import ( "database/sql/driver" "encoding/hex" "fmt" + + "github.com/pkg/errors" ) type UUID struct { @@ -17,7 +19,7 @@ func (dst *UUID) Set(src interface{}) error { *dst = UUID{Bytes: value, Status: Present} case []byte: if len(value) != 16 { - return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) } *dst = UUID{Status: Present} copy(dst.Bytes[:], value) @@ -31,7 +33,7 @@ func (dst *UUID) Set(src interface{}) error { if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to UUID", value) + return errors.Errorf("cannot convert %v to UUID", value) } return nil @@ -71,7 +73,7 @@ func (src *UUID) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot assign %v into %T", src, dst) + return errors.Errorf("cannot assign %v into %T", src, dst) } // parseUUID converts a string UUID in standard form to a byte array. @@ -98,7 +100,7 @@ func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) != 36 { - return fmt.Errorf("invalid length for UUID: %v", len(src)) + return errors.Errorf("invalid length for UUID: %v", len(src)) } buf, err := parseUUID(string(src)) @@ -117,7 +119,7 @@ func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 16 { - return fmt.Errorf("invalid length for UUID: %v", len(src)) + return errors.Errorf("invalid length for UUID: %v", len(src)) } *dst = UUID{Status: Present} @@ -163,7 +165,7 @@ func (dst *UUID) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/varbit.go b/varbit.go index 9a9fe1e1..dfa194d2 100644 --- a/varbit.go +++ b/varbit.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Varbit struct { @@ -15,7 +15,7 @@ type Varbit struct { } func (dst *Varbit) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Varbit", src) + return errors.Errorf("cannot convert %v to Varbit", src) } func (dst *Varbit) Get() interface{} { @@ -30,7 +30,7 @@ func (dst *Varbit) Get() interface{} { } func (src *Varbit) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Varbit) DecodeText(ci *ConnInfo, src []byte) error { @@ -65,7 +65,7 @@ func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) < 4 { - return fmt.Errorf("invalid length for varbit: %v", len(src)) + return errors.Errorf("invalid length for varbit: %v", len(src)) } bitLen := int32(binary.BigEndian.Uint32(src)) @@ -124,7 +124,7 @@ func (dst *Varbit) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/varchar_array.go b/varchar_array.go index 95b5cfc1..fecbb2e5 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type VarcharArray struct { @@ -40,7 +40,7 @@ func (dst *VarcharArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Varchar", value) + return errors.Errorf("cannot convert %v to Varchar", value) } return nil @@ -80,7 +80,7 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) if dt, ok := ci.DataTypeForName("varchar"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "varchar") + return nil, errors.Errorf("unable to find oid for type name %v", "varchar") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *VarcharArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. From ffa9ff221392368427bf369ced0a3c104bc37c02 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 4 Jun 2017 21:30:03 -0500 Subject: [PATCH 0108/1158] Use github.com/pkg/errors --- authentication.go | 4 ++-- backend.go | 4 ++-- frontend.go | 4 ++-- startup_message.go | 10 +++++----- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/authentication.go b/authentication.go index c04ee448..77750b86 100644 --- a/authentication.go +++ b/authentication.go @@ -2,9 +2,9 @@ package pgproto3 import ( "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) const ( @@ -31,7 +31,7 @@ func (dst *Authentication) Decode(src []byte) error { case AuthTypeMD5Password: copy(dst.Salt[:], src[4:8]) default: - return fmt.Errorf("unknown authentication type: %d", dst.Type) + return errors.Errorf("unknown authentication type: %d", dst.Type) } return nil diff --git a/backend.go b/backend.go index bf96ba95..9a7ef342 100644 --- a/backend.go +++ b/backend.go @@ -2,10 +2,10 @@ package pgproto3 import ( "encoding/binary" - "fmt" "io" "github.com/jackc/pgx/chunkreader" + "github.com/pkg/errors" ) type Backend struct { @@ -88,7 +88,7 @@ func (b *Backend) Receive() (FrontendMessage, error) { case 'X': msg = &b.terminate default: - return nil, fmt.Errorf("unknown message type: %c", msgType) + return nil, errors.Errorf("unknown message type: %c", msgType) } msgBody, err := b.cr.Next(bodyLen) diff --git a/frontend.go b/frontend.go index 630a5cba..c8ab5f15 100644 --- a/frontend.go +++ b/frontend.go @@ -2,10 +2,10 @@ package pgproto3 import ( "encoding/binary" - "fmt" "io" "github.com/jackc/pgx/chunkreader" + "github.com/pkg/errors" ) type Frontend struct { @@ -100,7 +100,7 @@ func (b *Frontend) Receive() (BackendMessage, error) { case 'Z': msg = &b.readyForQuery default: - return nil, fmt.Errorf("unknown message type: %c", msgType) + return nil, errors.Errorf("unknown message type: %c", msgType) } msgBody, err := b.cr.Next(bodyLen) diff --git a/startup_message.go b/startup_message.go index 4e2df27d..6c5d4f99 100644 --- a/startup_message.go +++ b/startup_message.go @@ -4,9 +4,9 @@ import ( "bytes" "encoding/binary" "encoding/json" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) const ( @@ -23,18 +23,18 @@ func (*StartupMessage) Frontend() {} func (dst *StartupMessage) Decode(src []byte) error { if len(src) < 4 { - return fmt.Errorf("startup message too short") + return errors.Errorf("startup message too short") } dst.ProtocolVersion = binary.BigEndian.Uint32(src) rp := 4 if dst.ProtocolVersion == sslRequestNumber { - return fmt.Errorf("can't handle ssl connection request") + return errors.Errorf("can't handle ssl connection request") } if dst.ProtocolVersion != ProtocolVersionNumber { - return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) + return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) } dst.Parameters = make(map[string]string) @@ -57,7 +57,7 @@ func (dst *StartupMessage) Decode(src []byte) error { if len(src[rp:]) == 1 { if src[rp] != 0 { - return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) + return errors.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) } break } From 10fa3a64977dec30b25fba7ce6b35cbc275d0c2e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Aug 2017 16:40:18 -0500 Subject: [PATCH 0109/1158] Return error on MarshalJSON of status Undefined Previously "undefined" was returned as a value. While this is a valid JavaScript value, it is not valid JSON. --- int2.go | 2 +- int4.go | 2 +- int8.go | 2 +- text.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/int2.go b/int2.go index 45bce93c..6156ea77 100644 --- a/int2.go +++ b/int2.go @@ -202,7 +202,7 @@ func (src *Int2) MarshalJSON() ([]byte, error) { case Null: return []byte("null"), nil case Undefined: - return []byte("undefined"), nil + return nil, errUndefined } return nil, errBadStatus diff --git a/int4.go b/int4.go index a3499fef..37d00511 100644 --- a/int4.go +++ b/int4.go @@ -193,7 +193,7 @@ func (src *Int4) MarshalJSON() ([]byte, error) { case Null: return []byte("null"), nil case Undefined: - return []byte("undefined"), nil + return nil, errUndefined } return nil, errBadStatus diff --git a/int8.go b/int8.go index d671eda7..17a676eb 100644 --- a/int8.go +++ b/int8.go @@ -179,7 +179,7 @@ func (src *Int8) MarshalJSON() ([]byte, error) { case Null: return []byte("null"), nil case Undefined: - return []byte("undefined"), nil + return nil, errUndefined } return nil, errBadStatus diff --git a/text.go b/text.go index f05e1e89..e7fba682 100644 --- a/text.go +++ b/text.go @@ -144,7 +144,7 @@ func (src *Text) MarshalJSON() ([]byte, error) { case Null: return []byte("null"), nil case Undefined: - return []byte("undefined"), nil + return nil, errUndefined } return nil, errBadStatus From f18a22e066785802cd97cc58805fa4a95084a83a Mon Sep 17 00:00:00 2001 From: Wei Congrui Date: Fri, 18 Aug 2017 15:20:39 +0800 Subject: [PATCH 0110/1158] Fix numeric EncodeBinary bug --- numeric.go | 10 +++++++--- numeric_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/numeric.go b/numeric.go index fded6359..fb63df75 100644 --- a/numeric.go +++ b/numeric.go @@ -16,6 +16,7 @@ import ( const nbase = 10000 var big0 *big.Int = big.NewInt(0) +var big1 *big.Int = big.NewInt(1) var big10 *big.Int = big.NewInt(10) var big100 *big.Int = big.NewInt(100) var big1000 *big.Int = big.NewInt(1000) @@ -507,6 +508,7 @@ func (src *Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { divisor := &big.Int{} divisor.Exp(big10, big.NewInt(int64(-exp)), nil) wholePart.DivMod(absInt, divisor, fracPart) + fracPart.Add(fracPart, divisor) } else { wholePart = absInt } @@ -518,9 +520,11 @@ func (src *Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { wholeDigits = append(wholeDigits, int16(remainder.Int64())) } - for fracPart.Cmp(big0) != 0 { - fracPart.DivMod(fracPart, bigNBase, remainder) - fracDigits = append(fracDigits, int16(remainder.Int64())) + if fracPart.Cmp(big0) != 0 { + for fracPart.Cmp(big1) != 0 { + fracPart.DivMod(fracPart, bigNBase, remainder) + fracDigits = append(fracDigits, int16(remainder.Int64())) + } } buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) diff --git a/numeric_test.go b/numeric_test.go index 5f3a3416..9d7d83d6 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -317,3 +317,39 @@ func TestNumericAssignTo(t *testing.T) { } } } + +func TestNumericEncodeDecodeBinary(t *testing.T) { + ci := pgtype.NewConnInfo() + tests := []interface{}{ + 123, + 0.000012345, + 1.00002345, + } + + for i, tt := range tests { + toString := func(n *pgtype.Numeric) string { + ci := pgtype.NewConnInfo() + text, err := n.EncodeText(ci, nil) + if err != nil { + t.Errorf("%d: %v", i, err) + } + return string(text) + } + numeric := &pgtype.Numeric{} + numeric.Set(tt) + + encoded, err := numeric.EncodeBinary(ci, nil) + if err != nil { + t.Errorf("%d: %v", i, err) + } + decoded := &pgtype.Numeric{} + decoded.DecodeBinary(ci, encoded) + + text0 := toString(numeric) + text1 := toString(decoded) + + if text0 != text1 { + t.Errorf("%d: expected %v to equal to %v, but doesn't", i, text0, text1) + } + } +} From 43c2b979d0ac55a31677f92f77a77d185c3853b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Peignier?= Date: Fri, 18 Aug 2017 18:22:08 -0700 Subject: [PATCH 0111/1158] Add more ColumnType support --- pgtype.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pgtype.go b/pgtype.go index 6f8e7986..00175a30 100644 --- a/pgtype.go +++ b/pgtype.go @@ -46,6 +46,7 @@ const ( DateArrayOID = 1182 TimestamptzOID = 1184 TimestamptzArrayOID = 1185 + NumericOID = 1700 RecordOID = 2249 UUIDOID = 2950 JSONBOID = 3802 From 2dfcf74f6241620c90fb34d24a2a6ea4e1580001 Mon Sep 17 00:00:00 2001 From: Kelsey Francis Date: Sun, 27 Aug 2017 19:31:22 -0700 Subject: [PATCH 0112/1158] Add UUIDArray type Also change UUID.Set() to convert nil to NULL in order for UUIDArray.Set() to support converting [][]byte slices that contain nil. --- pgtype.go | 2 + uuid.go | 17 ++- uuid_array.go | 355 +++++++++++++++++++++++++++++++++++++++++++++ uuid_array_test.go | 205 ++++++++++++++++++++++++++ uuid_test.go | 8 + 5 files changed, 583 insertions(+), 4 deletions(-) create mode 100644 uuid_array.go create mode 100644 uuid_array_test.go diff --git a/pgtype.go b/pgtype.go index 00175a30..be13ec77 100644 --- a/pgtype.go +++ b/pgtype.go @@ -49,6 +49,7 @@ const ( NumericOID = 1700 RecordOID = 2249 UUIDOID = 2950 + UUIDArrayOID = 2951 JSONBOID = 3802 ) @@ -223,6 +224,7 @@ func init() { "_text": &TextArray{}, "_timestamp": &TimestampArray{}, "_timestamptz": &TimestamptzArray{}, + "_uuid": &UUIDArray{}, "_varchar": &VarcharArray{}, "aclitem": &ACLItem{}, "bool": &Bool{}, diff --git a/uuid.go b/uuid.go index 33e79536..f8297b39 100644 --- a/uuid.go +++ b/uuid.go @@ -14,15 +14,24 @@ type UUID struct { } func (dst *UUID) Set(src interface{}) error { + if src == nil { + *dst = UUID{Status: Null} + return nil + } + switch value := src.(type) { case [16]byte: *dst = UUID{Bytes: value, Status: Present} case []byte: - if len(value) != 16 { - return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + if value != nil { + if len(value) != 16 { + return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + } + *dst = UUID{Status: Present} + copy(dst.Bytes[:], value) + } else { + *dst = UUID{Status: Null} } - *dst = UUID{Status: Present} - copy(dst.Bytes[:], value) case string: uuid, err := parseUUID(value) if err != nil { diff --git a/uuid_array.go b/uuid_array.go new file mode 100644 index 00000000..c18aec4f --- /dev/null +++ b/uuid_array.go @@ -0,0 +1,355 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type UUIDArray struct { + Elements []UUID + Dimensions []ArrayDimension + Status Status +} + +func (dst *UUIDArray) Set(src interface{}) error { + if src == nil { + *dst = UUIDArray{Status: Null} + return nil + } + + switch value := src.(type) { + + case [][16]byte: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case [][]byte: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []string: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to UUIDArray", value) + } + + return nil +} + +func (dst *UUIDArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *UUIDArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[][16]byte: + *v = make([][16]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *UUIDArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = UUIDArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []UUID + + if len(uta.Elements) > 0 { + elements = make([]UUID, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem UUID + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = UUIDArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *UUIDArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = UUIDArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = UUIDArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]UUID, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = UUIDArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("uuid"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "uuid") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *UUIDArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *UUIDArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/uuid_array_test.go b/uuid_array_test.go new file mode 100644 index 00000000..ee9d3dfa --- /dev/null +++ b/uuid_array_test.go @@ -0,0 +1,205 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestUUIDArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid[]", []interface{}{ + &pgtype.UUIDArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.UUIDArray{Status: pgtype.Null}, + &pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestUUIDArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.UUIDArray + }{ + { + source: nil, + result: pgtype.UUIDArray{Status: pgtype.Null}, + }, + { + source: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][16]byte{}, + result: pgtype.UUIDArray{Status: pgtype.Present}, + }, + { + source: ([][16]byte)(nil), + result: pgtype.UUIDArray{Status: pgtype.Null}, + }, + { + source: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][]byte{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + nil, + {32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, + }, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 4}}, + Status: pgtype.Present}, + }, + { + source: [][]byte{}, + result: pgtype.UUIDArray{Status: pgtype.Present}, + }, + { + source: ([][]byte)(nil), + result: pgtype.UUIDArray{Status: pgtype.Null}, + }, + { + source: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []string{}, + result: pgtype.UUIDArray{Status: pgtype.Present}, + }, + { + source: ([]string)(nil), + result: pgtype.UUIDArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.UUIDArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUUIDArrayAssignTo(t *testing.T) { + var byteArraySlice [][16]byte + var byteSliceSlice [][]byte + var stringSlice []string + + simpleTests := []struct { + src pgtype.UUIDArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &byteArraySlice, + expected: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + }, + { + src: pgtype.UUIDArray{Status: pgtype.Null}, + dst: &byteArraySlice, + expected: ([][16]byte)(nil), + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &byteSliceSlice, + expected: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + }, + { + src: pgtype.UUIDArray{Status: pgtype.Null}, + dst: &byteSliceSlice, + expected: ([][]byte)(nil), + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, + }, + { + src: pgtype.UUIDArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: ([]string)(nil), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/uuid_test.go b/uuid_test.go index 5ab52b35..162d999f 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -20,6 +20,10 @@ func TestUUIDSet(t *testing.T) { source interface{} result pgtype.UUID }{ + { + source: nil, + result: pgtype.UUID{Status: pgtype.Null}, + }, { source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, @@ -28,6 +32,10 @@ func TestUUIDSet(t *testing.T) { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, + { + source: ([]byte)(nil), + result: pgtype.UUID{Status: pgtype.Null}, + }, { source: "00010203-0405-0607-0809-0a0b0c0d0e0f", result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, From 703ce85513a224247c1b4b1dc92dfb799bda46f8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 29 Aug 2017 14:33:25 -0500 Subject: [PATCH 0113/1158] Generate UUIDArray from template - Fix error in Set - Specifically handle untyped nil --- aclitem_array.go | 8 +++++++- bool_array.go | 8 +++++++- bytea_array.go | 8 +++++++- cidr_array.go | 8 +++++++- date_array.go | 8 +++++++- float4_array.go | 8 +++++++- float8_array.go | 8 +++++++- hstore_array.go | 8 +++++++- inet_array.go | 8 +++++++- int2_array.go | 8 +++++++- int4_array.go | 8 +++++++- int8_array.go | 8 +++++++- numeric_array.go | 8 +++++++- text_array.go | 8 +++++++- timestamp_array.go | 8 +++++++- timestamptz_array.go | 8 +++++++- typed_array.go.erb | 8 +++++++- typed_array_gen.sh | 1 + uuid_array.go | 3 ++- varchar_array.go | 8 +++++++- 20 files changed, 129 insertions(+), 19 deletions(-) diff --git a/aclitem_array.go b/aclitem_array.go index fe0af434..0a829295 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -13,6 +13,12 @@ type ACLItemArray struct { } func (dst *ACLItemArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = ACLItemArray{Status: Null} + return nil + } + switch value := src.(type) { case []string: @@ -38,7 +44,7 @@ func (dst *ACLItemArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to ACLItem", value) + return errors.Errorf("cannot convert %v to ACLItemArray", value) } return nil diff --git a/bool_array.go b/bool_array.go index e23c27e5..67dd92a7 100644 --- a/bool_array.go +++ b/bool_array.go @@ -15,6 +15,12 @@ type BoolArray struct { } func (dst *BoolArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = BoolArray{Status: Null} + return nil + } + switch value := src.(type) { case []bool: @@ -40,7 +46,7 @@ func (dst *BoolArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Bool", value) + return errors.Errorf("cannot convert %v to BoolArray", value) } return nil diff --git a/bytea_array.go b/bytea_array.go index f2842179..c8eb5669 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -15,6 +15,12 @@ type ByteaArray struct { } func (dst *ByteaArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = ByteaArray{Status: Null} + return nil + } + switch value := src.(type) { case [][]byte: @@ -40,7 +46,7 @@ func (dst *ByteaArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Bytea", value) + return errors.Errorf("cannot convert %v to ByteaArray", value) } return nil diff --git a/cidr_array.go b/cidr_array.go index 2373da46..e4bb7614 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -16,6 +16,12 @@ type CIDRArray struct { } func (dst *CIDRArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = CIDRArray{Status: Null} + return nil + } + switch value := src.(type) { case []*net.IPNet: @@ -60,7 +66,7 @@ func (dst *CIDRArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to CIDR", value) + return errors.Errorf("cannot convert %v to CIDRArray", value) } return nil diff --git a/date_array.go b/date_array.go index 383945e7..0cb64581 100644 --- a/date_array.go +++ b/date_array.go @@ -16,6 +16,12 @@ type DateArray struct { } func (dst *DateArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = DateArray{Status: Null} + return nil + } + switch value := src.(type) { case []time.Time: @@ -41,7 +47,7 @@ func (dst *DateArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Date", value) + return errors.Errorf("cannot convert %v to DateArray", value) } return nil diff --git a/float4_array.go b/float4_array.go index 6499064b..02c28caa 100644 --- a/float4_array.go +++ b/float4_array.go @@ -15,6 +15,12 @@ type Float4Array struct { } func (dst *Float4Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Float4Array{Status: Null} + return nil + } + switch value := src.(type) { case []float32: @@ -40,7 +46,7 @@ func (dst *Float4Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Float4", value) + return errors.Errorf("cannot convert %v to Float4Array", value) } return nil diff --git a/float8_array.go b/float8_array.go index 27b24836..b92a8205 100644 --- a/float8_array.go +++ b/float8_array.go @@ -15,6 +15,12 @@ type Float8Array struct { } func (dst *Float8Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Float8Array{Status: Null} + return nil + } + switch value := src.(type) { case []float64: @@ -40,7 +46,7 @@ func (dst *Float8Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Float8", value) + return errors.Errorf("cannot convert %v to Float8Array", value) } return nil diff --git a/hstore_array.go b/hstore_array.go index 38ce457b..80530c26 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -15,6 +15,12 @@ type HstoreArray struct { } func (dst *HstoreArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = HstoreArray{Status: Null} + return nil + } + switch value := src.(type) { case []map[string]string: @@ -40,7 +46,7 @@ func (dst *HstoreArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Hstore", value) + return errors.Errorf("cannot convert %v to HstoreArray", value) } return nil diff --git a/inet_array.go b/inet_array.go index 3ece23eb..f3e4efbf 100644 --- a/inet_array.go +++ b/inet_array.go @@ -16,6 +16,12 @@ type InetArray struct { } func (dst *InetArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = InetArray{Status: Null} + return nil + } + switch value := src.(type) { case []*net.IPNet: @@ -60,7 +66,7 @@ func (dst *InetArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Inet", value) + return errors.Errorf("cannot convert %v to InetArray", value) } return nil diff --git a/int2_array.go b/int2_array.go index e939411b..f50d9275 100644 --- a/int2_array.go +++ b/int2_array.go @@ -15,6 +15,12 @@ type Int2Array struct { } func (dst *Int2Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int2Array{Status: Null} + return nil + } + switch value := src.(type) { case []int16: @@ -59,7 +65,7 @@ func (dst *Int2Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int2", value) + return errors.Errorf("cannot convert %v to Int2Array", value) } return nil diff --git a/int4_array.go b/int4_array.go index 1a907d2e..6c9418ba 100644 --- a/int4_array.go +++ b/int4_array.go @@ -15,6 +15,12 @@ type Int4Array struct { } func (dst *Int4Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int4Array{Status: Null} + return nil + } + switch value := src.(type) { case []int32: @@ -59,7 +65,7 @@ func (dst *Int4Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int4", value) + return errors.Errorf("cannot convert %v to Int4Array", value) } return nil diff --git a/int8_array.go b/int8_array.go index 4f3ab4dc..bb6ce004 100644 --- a/int8_array.go +++ b/int8_array.go @@ -15,6 +15,12 @@ type Int8Array struct { } func (dst *Int8Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int8Array{Status: Null} + return nil + } + switch value := src.(type) { case []int64: @@ -59,7 +65,7 @@ func (dst *Int8Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int8", value) + return errors.Errorf("cannot convert %v to Int8Array", value) } return nil diff --git a/numeric_array.go b/numeric_array.go index 6dfbe5e3..d991234a 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -15,6 +15,12 @@ type NumericArray struct { } func (dst *NumericArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = NumericArray{Status: Null} + return nil + } + switch value := src.(type) { case []float32: @@ -59,7 +65,7 @@ func (dst *NumericArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Numeric", value) + return errors.Errorf("cannot convert %v to NumericArray", value) } return nil diff --git a/text_array.go b/text_array.go index 2609a2cc..e40f4b86 100644 --- a/text_array.go +++ b/text_array.go @@ -15,6 +15,12 @@ type TextArray struct { } func (dst *TextArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TextArray{Status: Null} + return nil + } + switch value := src.(type) { case []string: @@ -40,7 +46,7 @@ func (dst *TextArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Text", value) + return errors.Errorf("cannot convert %v to TextArray", value) } return nil diff --git a/timestamp_array.go b/timestamp_array.go index be281f2e..546a3810 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -16,6 +16,12 @@ type TimestampArray struct { } func (dst *TimestampArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TimestampArray{Status: Null} + return nil + } + switch value := src.(type) { case []time.Time: @@ -41,7 +47,7 @@ func (dst *TimestampArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Timestamp", value) + return errors.Errorf("cannot convert %v to TimestampArray", value) } return nil diff --git a/timestamptz_array.go b/timestamptz_array.go index 086a4ef0..88b6cc5f 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -16,6 +16,12 @@ type TimestamptzArray struct { } func (dst *TimestamptzArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TimestamptzArray{Status: Null} + return nil + } + switch value := src.(type) { case []time.Time: @@ -41,7 +47,7 @@ func (dst *TimestamptzArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Timestamptz", value) + return errors.Errorf("cannot convert %v to TimestamptzArray", value) } return nil diff --git a/typed_array.go.erb b/typed_array.go.erb index 7a69d0ab..6fafc2df 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -15,6 +15,12 @@ type <%= pgtype_array_type %> struct { } func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + switch value := src.(type) { <% go_array_types.split(",").each do |t| %> case <%= t %>: @@ -40,7 +46,7 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to <%= pgtype_element_type %>", value) + return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>", value) } return nil diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 1aa6c354..80ece93c 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -15,4 +15,5 @@ erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]by erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go +erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go goimports -w *_array.go diff --git a/uuid_array.go b/uuid_array.go index c18aec4f..9c7843a7 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -15,6 +15,7 @@ type UUIDArray struct { } func (dst *UUIDArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different if src == nil { *dst = UUIDArray{Status: Null} return nil @@ -80,7 +81,7 @@ func (dst *UUIDArray) Set(src interface{}) error { } default: - if originalSrc, ok := underlyingPtrType(src); ok { + if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } return errors.Errorf("cannot convert %v to UUIDArray", value) diff --git a/varchar_array.go b/varchar_array.go index fecbb2e5..09eba3ea 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -15,6 +15,12 @@ type VarcharArray struct { } func (dst *VarcharArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + switch value := src.(type) { case []string: @@ -40,7 +46,7 @@ func (dst *VarcharArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Varchar", value) + return errors.Errorf("cannot convert %v to VarcharArray", value) } return nil From 2e630dddf9b2ebf2301b92c934f90c3b51e1b439 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 29 Aug 2017 15:37:22 -0500 Subject: [PATCH 0114/1158] Fix decoding row with same type values Row decoding was reusing and returning connection owned values for decoding. Instead allocate new value each time. fixes #313 --- record.go | 4 ++++ record_test.go | 12 +++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/record.go b/record.go index 14b415c3..26411af2 100644 --- a/record.go +++ b/record.go @@ -2,6 +2,7 @@ package pgtype import ( "encoding/binary" + "reflect" "github.com/pkg/errors" ) @@ -111,6 +112,9 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { rp += fieldLen } + // Duplicate struct to scan into + binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder) + if err := binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { return err } diff --git a/record_test.go b/record_test.go index df17501f..dc01cbbf 100644 --- a/record_test.go +++ b/record_test.go @@ -35,6 +35,16 @@ func TestRecordTranscode(t *testing.T) { Status: pgtype.Present, }, }, + { + sql: `select row(100.0::float4, 1.09::float4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Float4{Float: 100, Status: pgtype.Present}, + &pgtype.Float4{Float: 1.09, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, { sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, expected: pgtype.Record{ @@ -87,7 +97,7 @@ func TestRecordTranscode(t *testing.T) { } if !reflect.DeepEqual(tt.expected, result) { - t.Errorf("%d: expected %v, got %v", i, tt.expected, result) + t.Errorf("%d: expected %#v, got %#v", i, tt.expected, result) } } } From 3453586e891009352749b44c3e0a3b50fed6c36a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 29 Sep 2017 15:25:53 -0500 Subject: [PATCH 0115/1158] Add UnmarshalJSON to a few types --- int4.go | 13 +++++++++++++ text.go | 12 ++++++++++++ varchar.go | 4 ++++ 3 files changed, 29 insertions(+) diff --git a/int4.go b/int4.go index 37d00511..261c5118 100644 --- a/int4.go +++ b/int4.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "math" "strconv" @@ -198,3 +199,15 @@ func (src *Int4) MarshalJSON() ([]byte, error) { return nil, errBadStatus } + +func (dst *Int4) UnmarshalJSON(b []byte) error { + var n int32 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + *dst = Int4{Int: n, Status: Present} + + return nil +} diff --git a/text.go b/text.go index e7fba682..bceeffd4 100644 --- a/text.go +++ b/text.go @@ -149,3 +149,15 @@ func (src *Text) MarshalJSON() ([]byte, error) { return nil, errBadStatus } + +func (dst *Text) UnmarshalJSON(b []byte) error { + var s string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + *dst = Text{String: s, Status: Present} + + return nil +} diff --git a/varchar.go b/varchar.go index 371efd7e..6be1a035 100644 --- a/varchar.go +++ b/varchar.go @@ -52,3 +52,7 @@ func (src *Varchar) Value() (driver.Value, error) { func (src *Varchar) MarshalJSON() ([]byte, error) { return (*Text)(src).MarshalJSON() } + +func (dst *Varchar) UnmarshalJSON(b []byte) error { + return (*Text)(dst).UnmarshalJSON(b) +} From 5ba28cf2c5b58b46efc0c4d3f0bffcb46cf35adf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 17 Oct 2017 20:24:55 -0500 Subject: [PATCH 0116/1158] Add support for array of enum fixes #338 --- enum_array.go | 212 +++++++++++++++++++++++++++++++++++++++++++++ enum_array_test.go | 150 ++++++++++++++++++++++++++++++++ typed_array_gen.sh | 4 + 3 files changed, 366 insertions(+) create mode 100644 enum_array.go create mode 100644 enum_array_test.go diff --git a/enum_array.go b/enum_array.go new file mode 100644 index 00000000..3a948015 --- /dev/null +++ b/enum_array.go @@ -0,0 +1,212 @@ +package pgtype + +import ( + "database/sql/driver" + + "github.com/pkg/errors" +) + +type EnumArray struct { + Elements []GenericText + Dimensions []ArrayDimension + Status Status +} + +func (dst *EnumArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = EnumArray{Status: Null} + return nil + } + + switch value := src.(type) { + + case []string: + if value == nil { + *dst = EnumArray{Status: Null} + } else if len(value) == 0 { + *dst = EnumArray{Status: Present} + } else { + elements := make([]GenericText, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = EnumArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to EnumArray", value) + } + + return nil +} + +func (dst *EnumArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *EnumArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = EnumArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []GenericText + + if len(uta.Elements) > 0 { + elements = make([]GenericText, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem GenericText + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = EnumArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (src *EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *EnumArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *EnumArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/enum_array_test.go b/enum_array_test.go new file mode 100644 index 00000000..94774e1e --- /dev/null +++ b/enum_array_test.go @@ -0,0 +1,150 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestEnumArrayTranscode(t *testing.T) { + setupConn := testutil.MustConnectPgx(t) + defer testutil.MustClose(t, setupConn) + + if _, err := setupConn.Exec("drop type if exists color"); err != nil { + t.Fatal(err) + } + if _, err := setupConn.Exec("create type color as enum ('red', 'green', 'blue')"); err != nil { + t.Fatal(err) + } + + testutil.TestSuccessfulTranscode(t, "color[]", []interface{}{ + &pgtype.EnumArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + pgtype.GenericText{String: "red", Status: pgtype.Present}, + pgtype.GenericText{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.EnumArray{Status: pgtype.Null}, + &pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + pgtype.GenericText{String: "red", Status: pgtype.Present}, + pgtype.GenericText{String: "green", Status: pgtype.Present}, + pgtype.GenericText{String: "blue", Status: pgtype.Present}, + pgtype.GenericText{String: "red", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestEnumArrayArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.EnumArray + }{ + { + source: []string{"foo"}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.EnumArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.EnumArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestEnumArrayArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.EnumArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.EnumArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.EnumArray + dst interface{} + }{ + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 80ece93c..2a1eab99 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -16,4 +16,8 @@ erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[] erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go + +# While the binary format is theoretically possible it is only practical to use the text format. In addition, the text format for NULL enums is unquoted so TextArray or a possible GenericTextArray cannot be used. +erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string text_null='NULL' binary_format=false typed_array.go.erb > enum_array.go + goimports -w *_array.go From 6618ea669e756ad82fadf3a3b14855f3c5d15643 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Nov 2017 13:37:47 -0500 Subject: [PATCH 0117/1158] Use named value instead of literal --- range.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/range.go b/range.go index d870834f..49e766ff 100644 --- a/range.go +++ b/range.go @@ -26,8 +26,8 @@ type UntypedTextRange struct { func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { utr := &UntypedTextRange{} if src == "empty" { - utr.LowerType = 'E' - utr.UpperType = 'E' + utr.LowerType = Empty + utr.UpperType = Empty return utr, nil } From 5ab54cb24f0eda8099e100c4090f40d1545cf785 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Nov 2017 13:47:03 -0500 Subject: [PATCH 0118/1158] Add String method to pgtype.BoundType Character representation is much easier to read than numeric. --- range.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/range.go b/range.go index 49e766ff..a200825e 100644 --- a/range.go +++ b/range.go @@ -16,6 +16,10 @@ const ( Empty = BoundType('E') ) +func (bt BoundType) String() string { + return string(bt) +} + type UntypedTextRange struct { Lower string Upper string From 4e334054dd226f121ecf202cbafd2d7939471585 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Nov 2017 14:03:46 -0500 Subject: [PATCH 0119/1158] Fix ranges with text format where end is unbounded fixes #342 --- int4range_test.go | 2 ++ int8range_test.go | 2 ++ numrange_test.go | 12 ++++++++++++ range.go | 26 +++++++++++++------------- 4 files changed, 29 insertions(+), 13 deletions(-) diff --git a/int4range_test.go b/int4range_test.go index 088097d8..961678bb 100644 --- a/int4range_test.go +++ b/int4range_test.go @@ -12,6 +12,8 @@ func TestInt4rangeTranscode(t *testing.T) { &pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, &pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Status: pgtype.Present}, + &pgtype.Int4range{Upper: pgtype.Int4{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Status: pgtype.Present}, &pgtype.Int4range{Status: pgtype.Null}, }) } diff --git a/int8range_test.go b/int8range_test.go index c039ec65..f33ae4d8 100644 --- a/int8range_test.go +++ b/int8range_test.go @@ -12,6 +12,8 @@ func TestInt8rangeTranscode(t *testing.T) { &pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, &pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Status: pgtype.Present}, + &pgtype.Int8range{Upper: pgtype.Int8{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Status: pgtype.Present}, &pgtype.Int8range{Status: pgtype.Null}, }) } diff --git a/numrange_test.go b/numrange_test.go index 32267c86..ccc794d5 100644 --- a/numrange_test.go +++ b/numrange_test.go @@ -29,6 +29,18 @@ func TestNumrangeTranscode(t *testing.T) { UpperType: pgtype.Exclusive, Status: pgtype.Present, }, + &pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Unbounded, + Status: pgtype.Present, + }, + &pgtype.Numrange{ + Upper: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, + LowerType: pgtype.Unbounded, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, &pgtype.Numrange{Status: pgtype.Null}, }) } diff --git a/range.go b/range.go index a200825e..54fc6ca0 100644 --- a/range.go +++ b/range.go @@ -79,28 +79,28 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { if err != nil { return nil, errors.Errorf("invalid upper value: %v", err) } - buf.UnreadRune() if r == ')' || r == ']' { utr.UpperType = Unbounded } else { + buf.UnreadRune() utr.Upper, err = rangeParseValue(buf) if err != nil { return nil, errors.Errorf("invalid upper value: %v", err) } - } - r, _, err = buf.ReadRune() - if err != nil { - return nil, errors.Errorf("missing upper bound: %v", err) - } - switch r { - case ')': - utr.UpperType = Exclusive - case ']': - utr.UpperType = Inclusive - default: - return nil, errors.Errorf("missing upper bound, instead got: %v", string(r)) + r, _, err = buf.ReadRune() + if err != nil { + return nil, errors.Errorf("missing upper bound: %v", err) + } + switch r { + case ')': + utr.UpperType = Exclusive + case ']': + utr.UpperType = Inclusive + default: + return nil, errors.Errorf("missing upper bound, instead got: %v", string(r)) + } } skipWhitespace(buf) From 3f02d66ae0bc47cc7611ccd3d24c80e8cc3dabc9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Nov 2017 19:09:24 -0500 Subject: [PATCH 0120/1158] Detect erroneous JSON(B) encoding JSON(B) automatically marshals any value. Avoid marshalling values of pgtype.JSON and pgtype.JSONB. The caller certainly meant to call on a pointer. See https://github.com/jackc/pgx/issues/350 for discussion. refs #350 --- json.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/json.go b/json.go index 562722aa..ef8231b1 100644 --- a/json.go +++ b/json.go @@ -33,6 +33,15 @@ func (dst *JSON) Set(src interface{}) error { } else { *dst = JSON{Bytes: value, Status: Present} } + // Encode* methods are defined on *JSON. If JSON is passed directly then the + // struct itself would be encoded instead of Bytes. This is clearly a footgun + // so detect and return an error. See https://github.com/jackc/pgx/issues/350. + case JSON: + return errors.New("use pointer to pgtype.JSON instead of value") + // Same as above but for JSONB (because they share implementation) + case JSONB: + return errors.New("use pointer to pgtype.JSONB instead of value") + default: buf, err := json.Marshal(value) if err != nil { From 4e6de12a62e1cd6f645d6e7f50e61d15f8b45023 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 17 Nov 2017 09:37:22 -0600 Subject: [PATCH 0121/1158] Fix missing interval mapping --- pgtype.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pgtype.go b/pgtype.go index be13ec77..83311cf4 100644 --- a/pgtype.go +++ b/pgtype.go @@ -246,6 +246,7 @@ func init() { "int4range": &Int4range{}, "int8": &Int8{}, "int8range": &Int8range{}, + "interval": &Interval{}, "json": &JSON{}, "jsonb": &JSONB{}, "line": &Line{}, From a01653c3df06bd60090b6d50a9226f8fbd7e61f9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Nov 2017 21:13:34 -0600 Subject: [PATCH 0122/1158] Add support for bit type --- bit.go | 37 +++++++++++++++++++++++++++++++++++++ bit_test.go | 25 +++++++++++++++++++++++++ pgtype.go | 1 + 3 files changed, 63 insertions(+) create mode 100644 bit.go create mode 100644 bit_test.go diff --git a/bit.go b/bit.go new file mode 100644 index 00000000..f892cee5 --- /dev/null +++ b/bit.go @@ -0,0 +1,37 @@ +package pgtype + +import ( + "database/sql/driver" +) + +type Bit Varbit + +func (dst *Bit) Set(src interface{}) error { + return (*Varbit)(dst).Set(src) +} + +func (dst *Bit) Get() interface{} { + return (*Varbit)(dst).Get() +} + +func (src *Bit) AssignTo(dst interface{}) error { + return (*Varbit)(src).AssignTo(dst) +} + +func (dst *Bit) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Varbit)(dst).DecodeBinary(ci, src) +} + +func (src *Bit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Varbit)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Bit) Scan(src interface{}) error { + return (*Varbit)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Bit) Value() (driver.Value, error) { + return (*Varbit)(src).Value() +} diff --git a/bit_test.go b/bit_test.go new file mode 100644 index 00000000..19492bc9 --- /dev/null +++ b/bit_test.go @@ -0,0 +1,25 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestBitTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bit(40)", []interface{}{ + &pgtype.Varbit{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Status: pgtype.Present}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, + &pgtype.Varbit{Status: pgtype.Null}, + }) +} + +func TestBitNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select B'111111111'", + Value: &pgtype.Bit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, + }, + }) +} diff --git a/pgtype.go b/pgtype.go index 83311cf4..f7a1a300 100644 --- a/pgtype.go +++ b/pgtype.go @@ -227,6 +227,7 @@ func init() { "_uuid": &UUIDArray{}, "_varchar": &VarcharArray{}, "aclitem": &ACLItem{}, + "bit": &Bit{}, "bool": &Bool{}, "box": &Box{}, "bytea": &Bytea{}, From b3d0cbd0e6ca13ccf6f787d381e1b008ba736d21 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Dec 2017 13:45:22 -0600 Subject: [PATCH 0123/1158] Fix reading interrupted messages When an message is received and a timeout occurs after reading the header but before reading the entire body the connection state could be corrupted due to the header being consumed. The next read would consider the body of the previous message as the header for the next. fixes #348 --- backend.go | 27 ++++++++++++++------- backend_test.go | 37 +++++++++++++++++++++++++++++ frontend.go | 27 ++++++++++++++------- frontend_test.go | 62 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 135 insertions(+), 18 deletions(-) create mode 100644 backend_test.go create mode 100644 frontend_test.go diff --git a/backend.go b/backend.go index 9a7ef342..8f3c3478 100644 --- a/backend.go +++ b/backend.go @@ -24,6 +24,10 @@ type Backend struct { startupMessage StartupMessage sync Sync terminate Terminate + + bodyLen int + msgType byte + partialMsg bool } func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { @@ -57,16 +61,19 @@ func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { } func (b *Backend) Receive() (FrontendMessage, error) { - header, err := b.cr.Next(5) - if err != nil { - return nil, err + if !b.partialMsg { + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + b.msgType = header[0] + b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + b.partialMsg = true } - msgType := header[0] - bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 - var msg FrontendMessage - switch msgType { + switch b.msgType { case 'B': msg = &b.bind case 'C': @@ -88,14 +95,16 @@ func (b *Backend) Receive() (FrontendMessage, error) { case 'X': msg = &b.terminate default: - return nil, errors.Errorf("unknown message type: %c", msgType) + return nil, errors.Errorf("unknown message type: %c", b.msgType) } - msgBody, err := b.cr.Next(bodyLen) + msgBody, err := b.cr.Next(b.bodyLen) if err != nil { return nil, err } + b.partialMsg = false + err = msg.Decode(msgBody) return msg, err } diff --git a/backend_test.go b/backend_test.go new file mode 100644 index 00000000..02a5e9ca --- /dev/null +++ b/backend_test.go @@ -0,0 +1,37 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/pgproto3" +) + +func TestBackendReceiveInterrupted(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Q', 0, 0, 0, 6}) + + backend, err := pgproto3.NewBackend(server, nil) + if err != nil { + t.Fatal(err) + } + + msg, err := backend.Receive() + if err == nil { + t.Fatal("expected err") + } + if msg != nil { + t.Fatalf("did not expect msg, but %v", msg) + } + + server.push([]byte{'I', 0}) + + msg, err = backend.Receive() + if err != nil { + t.Fatal(err) + } + if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" { + t.Fatalf("unexpected msg: %v", msg) + } +} diff --git a/frontend.go b/frontend.go index c8ab5f15..d803d362 100644 --- a/frontend.go +++ b/frontend.go @@ -34,6 +34,10 @@ type Frontend struct { parseComplete ParseComplete readyForQuery ReadyForQuery rowDescription RowDescription + + bodyLen int + msgType byte + partialMsg bool } func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { @@ -47,16 +51,19 @@ func (b *Frontend) Send(msg FrontendMessage) error { } func (b *Frontend) Receive() (BackendMessage, error) { - header, err := b.cr.Next(5) - if err != nil { - return nil, err + if !b.partialMsg { + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + b.msgType = header[0] + b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + b.partialMsg = true } - msgType := header[0] - bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 - var msg BackendMessage - switch msgType { + switch b.msgType { case '1': msg = &b.parseComplete case '2': @@ -100,14 +107,16 @@ func (b *Frontend) Receive() (BackendMessage, error) { case 'Z': msg = &b.readyForQuery default: - return nil, errors.Errorf("unknown message type: %c", msgType) + return nil, errors.Errorf("unknown message type: %c", b.msgType) } - msgBody, err := b.cr.Next(bodyLen) + msgBody, err := b.cr.Next(b.bodyLen) if err != nil { return nil, err } + b.partialMsg = false + err = msg.Decode(msgBody) return msg, err } diff --git a/frontend_test.go b/frontend_test.go new file mode 100644 index 00000000..7d6652c1 --- /dev/null +++ b/frontend_test.go @@ -0,0 +1,62 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/pkg/errors" + + "github.com/jackc/pgx/pgproto3" +) + +type interruptReader struct { + chunks [][]byte +} + +func (ir *interruptReader) Read(p []byte) (n int, err error) { + if len(ir.chunks) == 0 { + return 0, errors.New("no data") + } + + n = copy(p, ir.chunks[0]) + if n != len(ir.chunks[0]) { + panic("this test reader doesn't support partial reads of chunks") + } + + ir.chunks = ir.chunks[1:] + + return n, nil +} + +func (ir *interruptReader) push(p []byte) { + ir.chunks = append(ir.chunks, p) +} + +func TestFrontendReceiveInterrupted(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Z', 0, 0, 0, 5}) + + frontend, err := pgproto3.NewFrontend(server, nil) + if err != nil { + t.Fatal(err) + } + + msg, err := frontend.Receive() + if err == nil { + t.Fatal("expected err") + } + if msg != nil { + t.Fatalf("did not expect msg, but %v", msg) + } + + server.push([]byte{'I'}) + + msg, err = frontend.Receive() + if err != nil { + t.Fatal(err) + } + if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' { + t.Fatalf("unexpected msg: %v", msg) + } +} From e22e7e67ecddacbef13de2b0d7b4213e2dce3023 Mon Sep 17 00:00:00 2001 From: Iurii Krasnoshchok Date: Wed, 20 Dec 2017 14:47:52 +0100 Subject: [PATCH 0124/1158] Return error on unknown oid while decoding record instead of panic --- record.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/record.go b/record.go index 26411af2..aeca1c54 100644 --- a/record.go +++ b/record.go @@ -98,9 +98,10 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { var binaryDecoder BinaryDecoder if dt, ok := ci.DataTypeForOID(fieldOID); ok { - if binaryDecoder, ok = dt.Value.(BinaryDecoder); !ok { - return errors.Errorf("unknown oid while decoding record: %v", fieldOID) - } + binaryDecoder, _ = dt.Value.(BinaryDecoder) + } + if binaryDecoder == nil { + return errors.Errorf("unknown oid while decoding record: %v", fieldOID) } var fieldBytes []byte From 645e646183d78b7ba15871f8615d4786bc60c713 Mon Sep 17 00:00:00 2001 From: ferhat elmas Date: Thu, 21 Dec 2017 23:45:26 +0100 Subject: [PATCH 0125/1158] Run gofmt with simplify flag --- aclitem_array_test.go | 24 ++++++++++++------------ bool_array_test.go | 24 ++++++++++++------------ bytea_array_test.go | 24 ++++++++++++------------ cidr_array_test.go | 24 ++++++++++++------------ date_array_test.go | 24 ++++++++++++------------ enum_array_test.go | 12 ++++++------ float4_array_test.go | 24 ++++++++++++------------ float8_array_test.go | 24 ++++++++++++------------ hstore_array_test.go | 18 +++++++++--------- hstore_test.go | 4 ++-- inet_array_test.go | 24 ++++++++++++------------ int2_array_test.go | 24 ++++++++++++------------ int4_array_test.go | 24 ++++++++++++------------ int8_array_test.go | 24 ++++++++++++------------ numeric_array_test.go | 24 ++++++++++++------------ record_test.go | 8 ++++---- text_array_test.go | 24 ++++++++++++------------ timestamp_array_test.go | 24 ++++++++++++------------ timestamptz_array_test.go | 24 ++++++++++++------------ varchar_array_test.go | 24 ++++++++++++------------ 20 files changed, 213 insertions(+), 213 deletions(-) diff --git a/aclitem_array_test.go b/aclitem_array_test.go index c01eaa13..4e60afca 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -17,8 +17,8 @@ func TestACLItemArrayTranscode(t *testing.T) { }, &pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{Status: pgtype.Null}, + {String: "=r/postgres", Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -26,22 +26,22 @@ func TestACLItemArrayTranscode(t *testing.T) { &pgtype.ACLItemArray{Status: pgtype.Null}, &pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{Status: pgtype.Null}, - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {Status: pgtype.Null}, + {String: "=r/postgres", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/bool_array_test.go b/bool_array_test.go index 87886da6..b529555e 100644 --- a/bool_array_test.go +++ b/bool_array_test.go @@ -17,8 +17,8 @@ func TestBoolArrayTranscode(t *testing.T) { }, &pgtype.BoolArray{ Elements: []pgtype.Bool{ - pgtype.Bool{Bool: true, Status: pgtype.Present}, - pgtype.Bool{Status: pgtype.Null}, + {Bool: true, Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -26,22 +26,22 @@ func TestBoolArrayTranscode(t *testing.T) { &pgtype.BoolArray{Status: pgtype.Null}, &pgtype.BoolArray{ Elements: []pgtype.Bool{ - pgtype.Bool{Bool: true, Status: pgtype.Present}, - pgtype.Bool{Bool: true, Status: pgtype.Present}, - pgtype.Bool{Bool: false, Status: pgtype.Present}, - pgtype.Bool{Bool: true, Status: pgtype.Present}, - pgtype.Bool{Status: pgtype.Null}, - pgtype.Bool{Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Bool: false, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.BoolArray{ Elements: []pgtype.Bool{ - pgtype.Bool{Bool: true, Status: pgtype.Present}, - pgtype.Bool{Bool: false, Status: pgtype.Present}, - pgtype.Bool{Bool: true, Status: pgtype.Present}, - pgtype.Bool{Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/bytea_array_test.go b/bytea_array_test.go index 451c2461..8450b71b 100644 --- a/bytea_array_test.go +++ b/bytea_array_test.go @@ -17,8 +17,8 @@ func TestByteaArrayTranscode(t *testing.T) { }, &pgtype.ByteaArray{ Elements: []pgtype.Bytea{ - pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - pgtype.Bytea{Status: pgtype.Null}, + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -26,22 +26,22 @@ func TestByteaArrayTranscode(t *testing.T) { &pgtype.ByteaArray{Status: pgtype.Null}, &pgtype.ByteaArray{ Elements: []pgtype.Bytea{ - pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, - pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - pgtype.Bytea{Status: pgtype.Null}, - pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present}, + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{}, Status: pgtype.Present}, + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Bytes: []byte{1}, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.ByteaArray{ Elements: []pgtype.Bytea{ - pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, - pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present}, + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{}, Status: pgtype.Present}, + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{1}, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/cidr_array_test.go b/cidr_array_test.go index 70d3f65b..206a590f 100644 --- a/cidr_array_test.go +++ b/cidr_array_test.go @@ -18,8 +18,8 @@ func TestCIDRArrayTranscode(t *testing.T) { }, &pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.CIDR{Status: pgtype.Null}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -27,22 +27,22 @@ func TestCIDRArrayTranscode(t *testing.T) { &pgtype.CIDRArray{Status: pgtype.Null}, &pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - pgtype.CIDR{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.CIDR{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.CIDR{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - pgtype.CIDR{Status: pgtype.Null}, - pgtype.CIDR{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + {Status: pgtype.Null}, + {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - pgtype.CIDR{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.CIDR{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.CIDR{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/date_array_test.go b/date_array_test.go index 74ebfbbe..2ba19d1a 100644 --- a/date_array_test.go +++ b/date_array_test.go @@ -18,8 +18,8 @@ func TestDateArrayTranscode(t *testing.T) { }, &pgtype.DateArray{ Elements: []pgtype.Date{ - pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Status: pgtype.Null}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -27,22 +27,22 @@ func TestDateArrayTranscode(t *testing.T) { &pgtype.DateArray{Status: pgtype.Null}, &pgtype.DateArray{ Elements: []pgtype.Date{ - pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Status: pgtype.Null}, - pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Status: pgtype.Null}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.DateArray{ Elements: []pgtype.Date{ - pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/enum_array_test.go b/enum_array_test.go index 94774e1e..9cc950af 100644 --- a/enum_array_test.go +++ b/enum_array_test.go @@ -27,8 +27,8 @@ func TestEnumArrayTranscode(t *testing.T) { }, &pgtype.EnumArray{ Elements: []pgtype.GenericText{ - pgtype.GenericText{String: "red", Status: pgtype.Present}, - pgtype.GenericText{Status: pgtype.Null}, + {String: "red", Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -36,10 +36,10 @@ func TestEnumArrayTranscode(t *testing.T) { &pgtype.EnumArray{Status: pgtype.Null}, &pgtype.EnumArray{ Elements: []pgtype.GenericText{ - pgtype.GenericText{String: "red", Status: pgtype.Present}, - pgtype.GenericText{String: "green", Status: pgtype.Present}, - pgtype.GenericText{String: "blue", Status: pgtype.Present}, - pgtype.GenericText{String: "red", Status: pgtype.Present}, + {String: "red", Status: pgtype.Present}, + {String: "green", Status: pgtype.Present}, + {String: "blue", Status: pgtype.Present}, + {String: "red", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/float4_array_test.go b/float4_array_test.go index 6d6a4f30..4d6511b4 100644 --- a/float4_array_test.go +++ b/float4_array_test.go @@ -17,8 +17,8 @@ func TestFloat4ArrayTranscode(t *testing.T) { }, &pgtype.Float4Array{ Elements: []pgtype.Float4{ - pgtype.Float4{Float: 1, Status: pgtype.Present}, - pgtype.Float4{Status: pgtype.Null}, + {Float: 1, Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -26,22 +26,22 @@ func TestFloat4ArrayTranscode(t *testing.T) { &pgtype.Float4Array{Status: pgtype.Null}, &pgtype.Float4Array{ Elements: []pgtype.Float4{ - pgtype.Float4{Float: 1, Status: pgtype.Present}, - pgtype.Float4{Float: 2, Status: pgtype.Present}, - pgtype.Float4{Float: 3, Status: pgtype.Present}, - pgtype.Float4{Float: 4, Status: pgtype.Present}, - pgtype.Float4{Status: pgtype.Null}, - pgtype.Float4{Float: 6, Status: pgtype.Present}, + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Float: 6, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.Float4Array{ Elements: []pgtype.Float4{ - pgtype.Float4{Float: 1, Status: pgtype.Present}, - pgtype.Float4{Float: 2, Status: pgtype.Present}, - pgtype.Float4{Float: 3, Status: pgtype.Present}, - pgtype.Float4{Float: 4, Status: pgtype.Present}, + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/float8_array_test.go b/float8_array_test.go index 56801e80..ff8e3b26 100644 --- a/float8_array_test.go +++ b/float8_array_test.go @@ -17,8 +17,8 @@ func TestFloat8ArrayTranscode(t *testing.T) { }, &pgtype.Float8Array{ Elements: []pgtype.Float8{ - pgtype.Float8{Float: 1, Status: pgtype.Present}, - pgtype.Float8{Status: pgtype.Null}, + {Float: 1, Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -26,22 +26,22 @@ func TestFloat8ArrayTranscode(t *testing.T) { &pgtype.Float8Array{Status: pgtype.Null}, &pgtype.Float8Array{ Elements: []pgtype.Float8{ - pgtype.Float8{Float: 1, Status: pgtype.Present}, - pgtype.Float8{Float: 2, Status: pgtype.Present}, - pgtype.Float8{Float: 3, Status: pgtype.Present}, - pgtype.Float8{Float: 4, Status: pgtype.Present}, - pgtype.Float8{Status: pgtype.Null}, - pgtype.Float8{Float: 6, Status: pgtype.Present}, + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Float: 6, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.Float8Array{ Elements: []pgtype.Float8{ - pgtype.Float8{Float: 1, Status: pgtype.Present}, - pgtype.Float8{Float: 2, Status: pgtype.Present}, - pgtype.Float8{Float: 3, Status: pgtype.Present}, - pgtype.Float8{Float: 4, Status: pgtype.Present}, + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/hstore_array_test.go b/hstore_array_test.go index fcf08c49..d629a04b 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -18,12 +18,12 @@ func TestHstoreArrayTranscode(t *testing.T) { } values := []pgtype.Hstore{ - pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, - pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, - pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, - pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, - pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, - pgtype.Hstore{Status: pgtype.Null}, + {Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + {Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + {Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + {Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + {Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + {Status: pgtype.Null}, } specialStrings := []string{ @@ -120,11 +120,11 @@ func TestHstoreArraySet(t *testing.T) { result pgtype.HstoreArray }{ { - src: []map[string]string{map[string]string{"foo": "bar"}}, + src: []map[string]string{{"foo": "bar"}}, result: pgtype.HstoreArray{ Elements: []pgtype.Hstore{ { - Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, Status: pgtype.Present, }, }, @@ -159,7 +159,7 @@ func TestHstoreArrayAssignTo(t *testing.T) { src: pgtype.HstoreArray{ Elements: []pgtype.Hstore{ { - Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, Status: pgtype.Present, }, }, diff --git a/hstore_test.go b/hstore_test.go index dc2439fc..d76c9942 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -68,7 +68,7 @@ func TestHstoreSet(t *testing.T) { src map[string]string result pgtype.Hstore }{ - {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}}, + {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}}, } for i, tt := range successfulTests { @@ -92,7 +92,7 @@ func TestHstoreAssignTo(t *testing.T) { dst *map[string]string expected map[string]string }{ - {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}, dst: &m, expected: map[string]string{"foo": "bar"}}, + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}, dst: &m, expected: map[string]string{"foo": "bar"}}, {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]string)(nil))}, } diff --git a/inet_array_test.go b/inet_array_test.go index 3e2b6a3c..ca528ed3 100644 --- a/inet_array_test.go +++ b/inet_array_test.go @@ -18,8 +18,8 @@ func TestInetArrayTranscode(t *testing.T) { }, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{Status: pgtype.Null}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -27,22 +27,22 @@ func TestInetArrayTranscode(t *testing.T) { &pgtype.InetArray{Status: pgtype.Null}, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - pgtype.Inet{Status: pgtype.Null}, - pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + {Status: pgtype.Null}, + {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/int2_array_test.go b/int2_array_test.go index 0adc1aef..0fe763c1 100644 --- a/int2_array_test.go +++ b/int2_array_test.go @@ -17,8 +17,8 @@ func TestInt2ArrayTranscode(t *testing.T) { }, &pgtype.Int2Array{ Elements: []pgtype.Int2{ - pgtype.Int2{Int: 1, Status: pgtype.Present}, - pgtype.Int2{Status: pgtype.Null}, + {Int: 1, Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -26,22 +26,22 @@ func TestInt2ArrayTranscode(t *testing.T) { &pgtype.Int2Array{Status: pgtype.Null}, &pgtype.Int2Array{ Elements: []pgtype.Int2{ - pgtype.Int2{Int: 1, Status: pgtype.Present}, - pgtype.Int2{Int: 2, Status: pgtype.Present}, - pgtype.Int2{Int: 3, Status: pgtype.Present}, - pgtype.Int2{Int: 4, Status: pgtype.Present}, - pgtype.Int2{Status: pgtype.Null}, - pgtype.Int2{Int: 6, Status: pgtype.Present}, + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Int: 6, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.Int2Array{ Elements: []pgtype.Int2{ - pgtype.Int2{Int: 1, Status: pgtype.Present}, - pgtype.Int2{Int: 2, Status: pgtype.Present}, - pgtype.Int2{Int: 3, Status: pgtype.Present}, - pgtype.Int2{Int: 4, Status: pgtype.Present}, + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/int4_array_test.go b/int4_array_test.go index 6fad18bb..602a3657 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -17,8 +17,8 @@ func TestInt4ArrayTranscode(t *testing.T) { }, &pgtype.Int4Array{ Elements: []pgtype.Int4{ - pgtype.Int4{Int: 1, Status: pgtype.Present}, - pgtype.Int4{Status: pgtype.Null}, + {Int: 1, Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -26,22 +26,22 @@ func TestInt4ArrayTranscode(t *testing.T) { &pgtype.Int4Array{Status: pgtype.Null}, &pgtype.Int4Array{ Elements: []pgtype.Int4{ - pgtype.Int4{Int: 1, Status: pgtype.Present}, - pgtype.Int4{Int: 2, Status: pgtype.Present}, - pgtype.Int4{Int: 3, Status: pgtype.Present}, - pgtype.Int4{Int: 4, Status: pgtype.Present}, - pgtype.Int4{Status: pgtype.Null}, - pgtype.Int4{Int: 6, Status: pgtype.Present}, + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Int: 6, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.Int4Array{ Elements: []pgtype.Int4{ - pgtype.Int4{Int: 1, Status: pgtype.Present}, - pgtype.Int4{Int: 2, Status: pgtype.Present}, - pgtype.Int4{Int: 3, Status: pgtype.Present}, - pgtype.Int4{Int: 4, Status: pgtype.Present}, + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/int8_array_test.go b/int8_array_test.go index 4f5c4f9a..2ca65173 100644 --- a/int8_array_test.go +++ b/int8_array_test.go @@ -17,8 +17,8 @@ func TestInt8ArrayTranscode(t *testing.T) { }, &pgtype.Int8Array{ Elements: []pgtype.Int8{ - pgtype.Int8{Int: 1, Status: pgtype.Present}, - pgtype.Int8{Status: pgtype.Null}, + {Int: 1, Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -26,22 +26,22 @@ func TestInt8ArrayTranscode(t *testing.T) { &pgtype.Int8Array{Status: pgtype.Null}, &pgtype.Int8Array{ Elements: []pgtype.Int8{ - pgtype.Int8{Int: 1, Status: pgtype.Present}, - pgtype.Int8{Int: 2, Status: pgtype.Present}, - pgtype.Int8{Int: 3, Status: pgtype.Present}, - pgtype.Int8{Int: 4, Status: pgtype.Present}, - pgtype.Int8{Status: pgtype.Null}, - pgtype.Int8{Int: 6, Status: pgtype.Present}, + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Int: 6, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.Int8Array{ Elements: []pgtype.Int8{ - pgtype.Int8{Int: 1, Status: pgtype.Present}, - pgtype.Int8{Int: 2, Status: pgtype.Present}, - pgtype.Int8{Int: 3, Status: pgtype.Present}, - pgtype.Int8{Int: 4, Status: pgtype.Present}, + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/numeric_array_test.go b/numeric_array_test.go index 25531840..22ee1bc4 100644 --- a/numeric_array_test.go +++ b/numeric_array_test.go @@ -18,8 +18,8 @@ func TestNumericArrayTranscode(t *testing.T) { }, &pgtype.NumericArray{ Elements: []pgtype.Numeric{ - pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, - pgtype.Numeric{Status: pgtype.Null}, + {Int: big.NewInt(1), Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -27,22 +27,22 @@ func TestNumericArrayTranscode(t *testing.T) { &pgtype.NumericArray{Status: pgtype.Null}, &pgtype.NumericArray{ Elements: []pgtype.Numeric{ - pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, - pgtype.Numeric{Int: big.NewInt(2), Status: pgtype.Present}, - pgtype.Numeric{Int: big.NewInt(3), Status: pgtype.Present}, - pgtype.Numeric{Int: big.NewInt(4), Status: pgtype.Present}, - pgtype.Numeric{Status: pgtype.Null}, - pgtype.Numeric{Int: big.NewInt(6), Status: pgtype.Present}, + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Status: pgtype.Null}, + {Int: big.NewInt(6), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.NumericArray{ Elements: []pgtype.Numeric{ - pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, - pgtype.Numeric{Int: big.NewInt(2), Status: pgtype.Present}, - pgtype.Numeric{Int: big.NewInt(3), Status: pgtype.Present}, - pgtype.Numeric{Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/record_test.go b/record_test.go index dc01cbbf..a3730a3e 100644 --- a/record_test.go +++ b/record_test.go @@ -52,10 +52,10 @@ func TestRecordTranscode(t *testing.T) { &pgtype.Text{String: "foo", Status: pgtype.Present}, &pgtype.Int4Array{ Elements: []pgtype.Int4{ - pgtype.Int4{Int: 1, Status: pgtype.Present}, - pgtype.Int4{Int: 2, Status: pgtype.Present}, - pgtype.Int4{Status: pgtype.Null}, - pgtype.Int4{Int: 4, Status: pgtype.Present}, + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Int: 4, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, Status: pgtype.Present, diff --git a/text_array_test.go b/text_array_test.go index 35ebef96..105d9353 100644 --- a/text_array_test.go +++ b/text_array_test.go @@ -17,8 +17,8 @@ func TestTextArrayTranscode(t *testing.T) { }, &pgtype.TextArray{ Elements: []pgtype.Text{ - pgtype.Text{String: "foo", Status: pgtype.Present}, - pgtype.Text{Status: pgtype.Null}, + {String: "foo", Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -26,22 +26,22 @@ func TestTextArrayTranscode(t *testing.T) { &pgtype.TextArray{Status: pgtype.Null}, &pgtype.TextArray{ Elements: []pgtype.Text{ - pgtype.Text{String: "bar ", Status: pgtype.Present}, - pgtype.Text{String: "NuLL", Status: pgtype.Present}, - pgtype.Text{String: `wow"quz\`, Status: pgtype.Present}, - pgtype.Text{String: "", Status: pgtype.Present}, - pgtype.Text{Status: pgtype.Null}, - pgtype.Text{String: "null", Status: pgtype.Present}, + {String: "bar ", Status: pgtype.Present}, + {String: "NuLL", Status: pgtype.Present}, + {String: `wow"quz\`, Status: pgtype.Present}, + {String: "", Status: pgtype.Present}, + {Status: pgtype.Null}, + {String: "null", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.TextArray{ Elements: []pgtype.Text{ - pgtype.Text{String: "bar", Status: pgtype.Present}, - pgtype.Text{String: "baz", Status: pgtype.Present}, - pgtype.Text{String: "quz", Status: pgtype.Present}, - pgtype.Text{String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "quz", Status: pgtype.Present}, + {String: "foo", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/timestamp_array_test.go b/timestamp_array_test.go index c75d101f..5821f43a 100644 --- a/timestamp_array_test.go +++ b/timestamp_array_test.go @@ -18,8 +18,8 @@ func TestTimestampArrayTranscode(t *testing.T) { }, &pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Status: pgtype.Null}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -27,22 +27,22 @@ func TestTimestampArrayTranscode(t *testing.T) { &pgtype.TimestampArray{Status: pgtype.Null}, &pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Status: pgtype.Null}, - pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Status: pgtype.Null}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/timestamptz_array_test.go b/timestamptz_array_test.go index 50ee65d0..8d7ea4c9 100644 --- a/timestamptz_array_test.go +++ b/timestamptz_array_test.go @@ -18,8 +18,8 @@ func TestTimestamptzArrayTranscode(t *testing.T) { }, &pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamptz{Status: pgtype.Null}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -27,22 +27,22 @@ func TestTimestamptzArrayTranscode(t *testing.T) { &pgtype.TimestamptzArray{Status: pgtype.Null}, &pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamptz{Status: pgtype.Null}, - pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Status: pgtype.Null}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, diff --git a/varchar_array_test.go b/varchar_array_test.go index 7d6fb39b..9fb0960f 100644 --- a/varchar_array_test.go +++ b/varchar_array_test.go @@ -17,8 +17,8 @@ func TestVarcharArrayTranscode(t *testing.T) { }, &pgtype.VarcharArray{ Elements: []pgtype.Varchar{ - pgtype.Varchar{String: "foo", Status: pgtype.Present}, - pgtype.Varchar{Status: pgtype.Null}, + {String: "foo", Status: pgtype.Present}, + {Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, @@ -26,22 +26,22 @@ func TestVarcharArrayTranscode(t *testing.T) { &pgtype.VarcharArray{Status: pgtype.Null}, &pgtype.VarcharArray{ Elements: []pgtype.Varchar{ - pgtype.Varchar{String: "bar ", Status: pgtype.Present}, - pgtype.Varchar{String: "NuLL", Status: pgtype.Present}, - pgtype.Varchar{String: `wow"quz\`, Status: pgtype.Present}, - pgtype.Varchar{String: "", Status: pgtype.Present}, - pgtype.Varchar{Status: pgtype.Null}, - pgtype.Varchar{String: "null", Status: pgtype.Present}, + {String: "bar ", Status: pgtype.Present}, + {String: "NuLL", Status: pgtype.Present}, + {String: `wow"quz\`, Status: pgtype.Present}, + {String: "", Status: pgtype.Present}, + {Status: pgtype.Null}, + {String: "null", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.VarcharArray{ Elements: []pgtype.Varchar{ - pgtype.Varchar{String: "bar", Status: pgtype.Present}, - pgtype.Varchar{String: "baz", Status: pgtype.Present}, - pgtype.Varchar{String: "quz", Status: pgtype.Present}, - pgtype.Varchar{String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "quz", Status: pgtype.Present}, + {String: "foo", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, From 5bd04dc568a655e04995e004c48f7263cd95aee2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Dec 2017 10:24:09 -0600 Subject: [PATCH 0126/1158] Add test for record with unknown OID --- record_test.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/record_test.go b/record_test.go index dc01cbbf..7cc8a59f 100644 --- a/record_test.go +++ b/record_test.go @@ -102,6 +102,28 @@ func TestRecordTranscode(t *testing.T) { } } +func TestRecordWithUnknownOID(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustClose(t, conn) + + _, err := conn.Exec(`drop type if exists floatrange; + +create type floatrange as range ( + subtype = float8, + subtype_diff = float8mi +);`) + if err != nil { + t.Fatal(err) + } + defer conn.Exec("drop type floatrange") + + var result pgtype.Record + err = conn.QueryRow("select row('foo'::text, floatrange(1, 10), 'bar'::text)").Scan(&result) + if err == nil { + t.Errorf("expected error but none") + } +} + func TestRecordAssignTo(t *testing.T) { var valueSlice []pgtype.Value var interfaceSlice []interface{} From fbc0fc7e3ef3760b78754ec980789bcd3615646f Mon Sep 17 00:00:00 2001 From: eruca Date: Fri, 29 Dec 2017 21:09:22 +0800 Subject: [PATCH 0127/1158] UnmarshalJSON for Int8 missing --- int8.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/int8.go b/int8.go index 17a676eb..bbdda6b8 100644 --- a/int8.go +++ b/int8.go @@ -184,3 +184,15 @@ func (src *Int8) MarshalJSON() ([]byte, error) { return nil, errBadStatus } + +func (dst *Int8) UnmarshalJSON(b []byte) error { + var n int64 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + *dst = Int8{Int: n, Status: Present} + + return nil +} From 91bb74b5263658526511d258810613330ca036b5 Mon Sep 17 00:00:00 2001 From: Iurii Krasnoshchok Date: Sat, 16 Dec 2017 02:20:34 +0100 Subject: [PATCH 0128/1158] Add support for bpchar type --- bpchar.go | 68 ++++++++++ bpchar_array.go | 300 +++++++++++++++++++++++++++++++++++++++++++ bpchar_array_test.go | 55 ++++++++ bpchar_test.go | 51 ++++++++ pgtype.go | 4 + typed_array_gen.sh | 1 + 6 files changed, 479 insertions(+) create mode 100644 bpchar.go create mode 100644 bpchar_array.go create mode 100644 bpchar_array_test.go create mode 100644 bpchar_test.go diff --git a/bpchar.go b/bpchar.go new file mode 100644 index 00000000..21263184 --- /dev/null +++ b/bpchar.go @@ -0,0 +1,68 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// BPChar is fixed-length, blank padded char type +// character(n), char(n) +type BPChar Text + +// Set converts from src to dst. +func (dst *BPChar) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +// Get returns underlying value +func (dst *BPChar) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. +func (src *BPChar) AssignTo(dst interface{}) error { + if src.Status == Present { + switch v := dst.(type) { + case *rune: + runes := []rune(src.String) + if len(runes) == 1 { + *v = runes[0] + return nil + } + } + } + return (*Text)(src).AssignTo(dst) +} + +func (dst *BPChar) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *BPChar) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (src *BPChar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeText(ci, buf) +} + +func (src *BPChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *BPChar) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *BPChar) Value() (driver.Value, error) { + return (*Text)(src).Value() +} + +func (src *BPChar) MarshalJSON() ([]byte, error) { + return (*Text)(src).MarshalJSON() +} + +func (dst *BPChar) UnmarshalJSON(b []byte) error { + return (*Text)(dst).UnmarshalJSON(b) +} diff --git a/bpchar_array.go b/bpchar_array.go new file mode 100644 index 00000000..1e6220f7 --- /dev/null +++ b/bpchar_array.go @@ -0,0 +1,300 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type BPCharArray struct { + Elements []BPChar + Dimensions []ArrayDimension + Status Status +} + +func (dst *BPCharArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = BPCharArray{Status: Null} + return nil + } + + switch value := src.(type) { + + case []string: + if value == nil { + *dst = BPCharArray{Status: Null} + } else if len(value) == 0 { + *dst = BPCharArray{Status: Present} + } else { + elements := make([]BPChar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BPCharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to BPCharArray", value) + } + + return nil +} + +func (dst *BPCharArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *BPCharArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BPCharArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []BPChar + + if len(uta.Elements) > 0 { + elements = make([]BPChar, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem BPChar + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = BPCharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *BPCharArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BPCharArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = BPCharArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]BPChar, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = BPCharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("bpchar"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "bpchar") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *BPCharArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *BPCharArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/bpchar_array_test.go b/bpchar_array_test.go new file mode 100644 index 00000000..e4f2e7eb --- /dev/null +++ b/bpchar_array_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestBPCharArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "char(8)[]", []interface{}{ + &pgtype.BPCharArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.BPCharArray{ + Elements: []pgtype.BPChar{ + pgtype.BPChar{String: "foo ", Status: pgtype.Present}, + pgtype.BPChar{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.BPCharArray{Status: pgtype.Null}, + &pgtype.BPCharArray{ + Elements: []pgtype.BPChar{ + pgtype.BPChar{String: "bar ", Status: pgtype.Present}, + pgtype.BPChar{String: "NuLL ", Status: pgtype.Present}, + pgtype.BPChar{String: `wow"quz\`, Status: pgtype.Present}, + pgtype.BPChar{String: "1 ", Status: pgtype.Present}, + pgtype.BPChar{String: "1 ", Status: pgtype.Present}, + pgtype.BPChar{String: "null ", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 3, LowerBound: 1}, + {Length: 2, LowerBound: 1}, + }, + Status: pgtype.Present, + }, + &pgtype.BPCharArray{ + Elements: []pgtype.BPChar{ + pgtype.BPChar{String: " bar ", Status: pgtype.Present}, + pgtype.BPChar{String: " baz ", Status: pgtype.Present}, + pgtype.BPChar{String: " quz ", Status: pgtype.Present}, + pgtype.BPChar{String: "foo ", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} diff --git a/bpchar_test.go b/bpchar_test.go new file mode 100644 index 00000000..c076ca1b --- /dev/null +++ b/bpchar_test.go @@ -0,0 +1,51 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestChar3Transcode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "char(3)", []interface{}{ + &pgtype.BPChar{String: "a ", Status: pgtype.Present}, + &pgtype.BPChar{String: " a ", Status: pgtype.Present}, + &pgtype.BPChar{String: "å—¨ ", Status: pgtype.Present}, + &pgtype.BPChar{String: " ", Status: pgtype.Present}, + &pgtype.BPChar{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.BPChar) + b := bb.(pgtype.BPChar) + + return a.Status == b.Status && a.String == b.String + }) +} + +func TestBPCharAssignTo(t *testing.T) { + var ( + str string + run rune + ) + simpleTests := []struct { + src pgtype.BPChar + dst interface{} + expected interface{} + }{ + {src: pgtype.BPChar{String: "simple", Status: pgtype.Present}, dst: &str, expected: "simple"}, + {src: pgtype.BPChar{String: "å—¨", Status: pgtype.Present}, dst: &run, expected: 'å—¨'}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + +} diff --git a/pgtype.go b/pgtype.go index f7a1a300..2643314e 100644 --- a/pgtype.go +++ b/pgtype.go @@ -32,6 +32,7 @@ const ( Int4ArrayOID = 1007 TextArrayOID = 1009 ByteaArrayOID = 1001 + BPCharArrayOID = 1014 VarcharArrayOID = 1015 Int8ArrayOID = 1016 Float4ArrayOID = 1021 @@ -39,6 +40,7 @@ const ( ACLItemOID = 1033 ACLItemArrayOID = 1034 InetArrayOID = 1041 + BPCharOID = 1042 VarcharOID = 1043 DateOID = 1082 TimestampOID = 1114 @@ -211,6 +213,7 @@ func init() { nameValues = map[string]Value{ "_aclitem": &ACLItemArray{}, "_bool": &BoolArray{}, + "_bpchar": &BPCharArray{}, "_bytea": &ByteaArray{}, "_cidr": &CIDRArray{}, "_date": &DateArray{}, @@ -230,6 +233,7 @@ func init() { "bit": &Bit{}, "bool": &Bool{}, "box": &Box{}, + "bpchar": &BPChar{}, "bytea": &Bytea{}, "char": &QChar{}, "cid": &CID{}, diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 2a1eab99..4a8211bc 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -11,6 +11,7 @@ erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.I erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' binary_format=true typed_array.go.erb > text_array.go erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' binary_format=true typed_array.go.erb > varchar_array.go +erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string element_type_name=bpchar text_null='NULL' binary_format=true typed_array.go.erb > bpchar_array.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go From 44bb11de828339712efc7c33b5295af8a095f736 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Jan 2018 18:14:42 -0600 Subject: [PATCH 0129/1158] Import encoding/json package --- int8.go | 1 + 1 file changed, 1 insertion(+) diff --git a/int8.go b/int8.go index bbdda6b8..00a8cd00 100644 --- a/int8.go +++ b/int8.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "math" "strconv" From f078754e05c37a8281976ead88bc28f428e8c5eb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 16 Feb 2018 21:37:41 -0600 Subject: [PATCH 0130/1158] Skip test based on missing line type Instead of explicit server version checking. Ubuntu installed version string is not parsable by go-version. e.g. 10.2 (Ubuntu 10.2-1.pgdg16.04+1) --- line_test.go | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/line_test.go b/line_test.go index 09e48019..45242ad1 100644 --- a/line_test.go +++ b/line_test.go @@ -3,23 +3,14 @@ package pgtype_test import ( "testing" - version "github.com/hashicorp/go-version" "github.com/jackc/pgx/pgtype" "github.com/jackc/pgx/pgtype/testutil" ) func TestLineTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) - serverVersion, err := version.NewVersion(conn.RuntimeParams["server_version"]) - if err != nil { - t.Fatalf("cannot get server version: %v", err) - } - testutil.MustClose(t, conn) - - minVersion := version.Must(version.NewVersion("9.4")) - - if serverVersion.LessThan(minVersion) { - t.Skipf("Skipping line test for server version %v", serverVersion) + if _, ok := conn.ConnInfo.DataTypeForName("line"); !ok { + t.Skip("Skipping due to no line type") } testutil.TestSuccessfulTranscode(t, "line", []interface{}{ From 7ed0a8732c5a16b5bf7db7dbd18211d02109cc1b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 8 Mar 2018 07:40:25 -0500 Subject: [PATCH 0131/1158] Update shopspring decimal integration test New version of shopspring/decimal improves precision. This broke a test. --- ext/shopspring-numeric/decimal_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index 79121ef3..b237478d 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -171,7 +171,7 @@ func TestNumericSet(t *testing.T) { {source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Status: pgtype.Present}}, {source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Status: pgtype.Present}}, {source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Status: pgtype.Present}}, - {source: float64(12345.678901), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345.678901"), Status: pgtype.Present}}, + {source: float64(1.25), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.25"), Status: pgtype.Present}}, } for i, tt := range successfulTests { From 898fc86e25e291e291abc0a81a5237c137534c37 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 8 Mar 2018 08:05:54 -0500 Subject: [PATCH 0132/1158] Skip line test of PG 9.3 --- line_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/line_test.go b/line_test.go index 45242ad1..019cbf0c 100644 --- a/line_test.go +++ b/line_test.go @@ -13,6 +13,16 @@ func TestLineTranscode(t *testing.T) { t.Skip("Skipping due to no line type") } + // line may exist but not be usable on 9.3 :( + var isPG93 bool + err := conn.QueryRow("select version() ~ '9.3'").Scan(&isPG93) + if err != nil { + t.Fatal(err) + } + if isPG93 { + t.Skip("Skipping due to unimplemented line type in PG 9.3") + } + testutil.TestSuccessfulTranscode(t, "line", []interface{}{ &pgtype.Line{ A: 1.23, B: 4.56, C: 7.89, From 46d0f7e1c828a1f3ee5e175ff1944676d92641be Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 17 Mar 2018 10:26:03 -0500 Subject: [PATCH 0133/1158] Fix precision loss for test format geometric types fixes #399 --- box.go | 8 ++++++-- box_test.go | 2 +- circle.go | 7 ++++++- circle_test.go | 2 +- line.go | 8 +++++++- line_test.go | 2 +- lseg.go | 9 +++++++-- lseg_test.go | 2 +- path.go | 5 ++++- path_test.go | 2 +- point.go | 5 ++++- point_test.go | 2 +- polygon.go | 5 ++++- polygon_test.go | 2 +- 14 files changed, 45 insertions(+), 16 deletions(-) diff --git a/box.go b/box.go index 83df0499..4c5a4406 100644 --- a/box.go +++ b/box.go @@ -116,8 +116,12 @@ func (src *Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } - buf = append(buf, fmt.Sprintf(`(%f,%f),(%f,%f)`, - src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)...) + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(src.P[0].X, 'f', -1, 64), + strconv.FormatFloat(src.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(src.P[1].X, 'f', -1, 64), + strconv.FormatFloat(src.P[1].Y, 'f', -1, 64), + )...) return buf, nil } diff --git a/box_test.go b/box_test.go index f26cda68..197401f3 100644 --- a/box_test.go +++ b/box_test.go @@ -10,7 +10,7 @@ import ( func TestBoxTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "box", []interface{}{ &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, Status: pgtype.Present, }, &pgtype.Box{ diff --git a/circle.go b/circle.go index 97ecbf31..15ea447b 100644 --- a/circle.go +++ b/circle.go @@ -103,7 +103,12 @@ func (src *Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } - buf = append(buf, fmt.Sprintf(`<(%f,%f),%f>`, src.P.X, src.P.Y, src.R)...) + buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, + strconv.FormatFloat(src.P.X, 'f', -1, 64), + strconv.FormatFloat(src.P.Y, 'f', -1, 64), + strconv.FormatFloat(src.R, 'f', -1, 64), + )...) + return buf, nil } diff --git a/circle_test.go b/circle_test.go index 2747d4f5..634c5832 100644 --- a/circle_test.go +++ b/circle_test.go @@ -9,7 +9,7 @@ import ( func TestCircleTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "circle", []interface{}{ - &pgtype.Circle{P: pgtype.Vec2{1.234, 5.6789}, R: 3.5, Status: pgtype.Present}, + &pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Status: pgtype.Present}, &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Status: pgtype.Present}, &pgtype.Circle{Status: pgtype.Null}, }) diff --git a/line.go b/line.go index f6eadf0e..5fdc5604 100644 --- a/line.go +++ b/line.go @@ -101,7 +101,13 @@ func (src *Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } - return append(buf, fmt.Sprintf(`{%f,%f,%f}`, src.A, src.B, src.C)...), nil + buf = append(buf, fmt.Sprintf(`{%s,%s,%s}`, + strconv.FormatFloat(src.A, 'f', -1, 64), + strconv.FormatFloat(src.B, 'f', -1, 64), + strconv.FormatFloat(src.C, 'f', -1, 64), + )...) + + return buf, nil } func (src *Line) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { diff --git a/line_test.go b/line_test.go index 019cbf0c..200d1d4c 100644 --- a/line_test.go +++ b/line_test.go @@ -25,7 +25,7 @@ func TestLineTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "line", []interface{}{ &pgtype.Line{ - A: 1.23, B: 4.56, C: 7.89, + A: 1.23, B: 4.56, C: 7.89012345, Status: pgtype.Present, }, &pgtype.Line{ diff --git a/lseg.go b/lseg.go index a9d740cf..4445ea51 100644 --- a/lseg.go +++ b/lseg.go @@ -116,8 +116,13 @@ func (src *Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } - buf = append(buf, fmt.Sprintf(`(%f,%f),(%f,%f)`, - src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)...) + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(src.P[0].X, 'f', -1, 64), + strconv.FormatFloat(src.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(src.P[1].X, 'f', -1, 64), + strconv.FormatFloat(src.P[1].Y, 'f', -1, 64), + )...) + return buf, nil } diff --git a/lseg_test.go b/lseg_test.go index bd394e3c..0a25090a 100644 --- a/lseg_test.go +++ b/lseg_test.go @@ -10,7 +10,7 @@ import ( func TestLsegTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "lseg", []interface{}{ &pgtype.Lseg{ - P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}}, + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, Status: pgtype.Present, }, &pgtype.Lseg{ diff --git a/path.go b/path.go index aa0cee8e..69083712 100644 --- a/path.go +++ b/path.go @@ -138,7 +138,10 @@ func (src *Path) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if i > 0 { buf = append(buf, ',') } - buf = append(buf, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)...) + buf = append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(p.X, 'f', -1, 64), + strconv.FormatFloat(p.Y, 'f', -1, 64), + )...) } return append(buf, endByte), nil diff --git a/path_test.go b/path_test.go index d213a1b4..bc2d7435 100644 --- a/path_test.go +++ b/path_test.go @@ -10,7 +10,7 @@ import ( func TestPathTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "path", []interface{}{ &pgtype.Path{ - P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}}, + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, Closed: false, Status: pgtype.Present, }, diff --git a/point.go b/point.go index 3132a939..98a32d34 100644 --- a/point.go +++ b/point.go @@ -98,7 +98,10 @@ func (src *Point) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } - return append(buf, fmt.Sprintf(`(%f,%f)`, src.P.X, src.P.Y)...), nil + return append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(src.P.X, 'f', -1, 64), + strconv.FormatFloat(src.P.Y, 'f', -1, 64), + )...), nil } func (src *Point) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { diff --git a/point_test.go b/point_test.go index f46b342d..af70b38b 100644 --- a/point_test.go +++ b/point_test.go @@ -9,7 +9,7 @@ import ( func TestPointTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "point", []interface{}{ - &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789}, Status: pgtype.Present}, + &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Status: pgtype.Present}, &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, &pgtype.Point{Status: pgtype.Null}, }) diff --git a/polygon.go b/polygon.go index 3f3d9f53..d84a0abd 100644 --- a/polygon.go +++ b/polygon.go @@ -125,7 +125,10 @@ func (src *Polygon) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if i > 0 { buf = append(buf, ',') } - buf = append(buf, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)...) + buf = append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(p.X, 'f', -1, 64), + strconv.FormatFloat(p.Y, 'f', -1, 64), + )...) } return append(buf, ')'), nil diff --git a/polygon_test.go b/polygon_test.go index 48481dc5..5ff3bbb3 100644 --- a/polygon_test.go +++ b/polygon_test.go @@ -10,7 +10,7 @@ import ( func TestPolygonTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "polygon", []interface{}{ &pgtype.Polygon{ - P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {5.0, 3.234}}, + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, Status: pgtype.Present, }, &pgtype.Polygon{ From 9bb19fd8e7120f92356a0db47dc72f03b46a11eb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 14 Apr 2018 09:17:56 -0500 Subject: [PATCH 0134/1158] pgtype.JSON(B).Value now returns []byte Allows scanning jsonb column into *json.RawMessage. fixes #409 --- json.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/json.go b/json.go index ef8231b1..b05aba6b 100644 --- a/json.go +++ b/json.go @@ -152,7 +152,7 @@ func (dst *JSON) Scan(src interface{}) error { func (src *JSON) Value() (driver.Value, error) { switch src.Status { case Present: - return string(src.Bytes), nil + return src.Bytes, nil case Null: return nil, nil default: From 5524d654d3a54652b81443615bff0afe5e76cd75 Mon Sep 17 00:00:00 2001 From: Anthony Regeda Date: Tue, 24 Apr 2018 16:31:31 +0300 Subject: [PATCH 0135/1158] numeric_with_uint64 numeric array supports both types int64 and uint64 --- numeric_array.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++ typed_array_gen.sh | 2 +- 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/numeric_array.go b/numeric_array.go index d991234a..9c8f8eb3 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -61,6 +61,44 @@ func (dst *NumericArray) Set(src interface{}) error { } } + case []int64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -105,6 +143,24 @@ func (src *NumericArray) AssignTo(dst interface{}) error { } return nil + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 4a8211bc..38b9e1d0 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -15,7 +15,7 @@ erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]st erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go -erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go +erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64,[]int64,[]uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go # While the binary format is theoretically possible it is only practical to use the text format. In addition, the text format for NULL enums is unquoted so TextArray or a possible GenericTextArray cannot be used. From 3ec4c6ca23f761fd10042cf3d3a7c212a4862cd5 Mon Sep 17 00:00:00 2001 From: Tarik Demirci Date: Thu, 17 May 2018 12:20:11 +0200 Subject: [PATCH 0136/1158] Allow setting nil to pgtype.Bool --- bool.go | 5 +++++ bool_test.go | 1 + 2 files changed, 6 insertions(+) diff --git a/bool.go b/bool.go index 3a3eef48..308ba304 100644 --- a/bool.go +++ b/bool.go @@ -13,6 +13,11 @@ type Bool struct { } func (dst *Bool) Set(src interface{}) error { + if src == nil { + *dst = Bool{Status: Null} + return nil + } + switch value := src.(type) { case bool: *dst = Bool{Bool: value, Status: Present} diff --git a/bool_test.go b/bool_test.go index 2712e3b0..04d9337d 100644 --- a/bool_test.go +++ b/bool_test.go @@ -29,6 +29,7 @@ func TestBoolSet(t *testing.T) { {source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, {source: _bool(true), result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, {source: _bool(false), result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: nil, result: pgtype.Bool{Status: pgtype.Null}}, } for i, tt := range successfulTests { From 79ba0275de20522fae625286a319e634225933e5 Mon Sep 17 00:00:00 2001 From: Damir Vandic Date: Mon, 4 Jun 2018 21:02:20 +0200 Subject: [PATCH 0137/1158] Add the type of the value in all decode error messages --- aclitem.go | 2 +- aclitem_array.go | 2 +- bool.go | 2 +- bool_array.go | 2 +- bpchar_array.go | 2 +- bytea.go | 2 +- bytea_array.go | 2 +- cidr_array.go | 2 +- date.go | 2 +- date_array.go | 2 +- enum_array.go | 2 +- float4_array.go | 2 +- float8_array.go | 2 +- hstore.go | 4 ++-- hstore_array.go | 2 +- inet.go | 2 +- inet_array.go | 2 +- int2_array.go | 2 +- int4_array.go | 2 +- int8_array.go | 2 +- interval.go | 2 +- macaddr.go | 2 +- numeric_array.go | 2 +- record.go | 2 +- text.go | 2 +- text_array.go | 2 +- timestamp.go | 2 +- timestamp_array.go | 2 +- timestamptz.go | 2 +- timestamptz_array.go | 2 +- typed_array.go.erb | 2 +- uuid_array.go | 2 +- varchar_array.go | 2 +- 33 files changed, 34 insertions(+), 34 deletions(-) diff --git a/aclitem.go b/aclitem.go index 35269e91..4da962dd 100644 --- a/aclitem.go +++ b/aclitem.go @@ -70,7 +70,7 @@ func (src *ACLItem) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/aclitem_array.go b/aclitem_array.go index 0a829295..d8bf3303 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -84,7 +84,7 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/bool.go b/bool.go index 308ba304..0574588d 100644 --- a/bool.go +++ b/bool.go @@ -64,7 +64,7 @@ func (src *Bool) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/bool_array.go b/bool_array.go index 67dd92a7..4231e29d 100644 --- a/bool_array.go +++ b/bool_array.go @@ -86,7 +86,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/bpchar_array.go b/bpchar_array.go index 1e6220f7..b3f36cb6 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -86,7 +86,7 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/bytea.go b/bytea.go index c7117f48..4506dc31 100644 --- a/bytea.go +++ b/bytea.go @@ -64,7 +64,7 @@ func (src *Bytea) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } // DecodeText only supports the hex format. This has been the default since diff --git a/bytea_array.go b/bytea_array.go index c8eb5669..9c094b28 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -86,7 +86,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/cidr_array.go b/cidr_array.go index e4bb7614..c254c834 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -115,7 +115,7 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/date.go b/date.go index f1c0d8bd..b1d4c11d 100644 --- a/date.go +++ b/date.go @@ -72,7 +72,7 @@ func (src *Date) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/date_array.go b/date_array.go index 0cb64581..c0f5c21c 100644 --- a/date_array.go +++ b/date_array.go @@ -87,7 +87,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/enum_array.go b/enum_array.go index 3a948015..7168cb8a 100644 --- a/enum_array.go +++ b/enum_array.go @@ -84,7 +84,7 @@ func (src *EnumArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/float4_array.go b/float4_array.go index 02c28caa..fba181d3 100644 --- a/float4_array.go +++ b/float4_array.go @@ -86,7 +86,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/float8_array.go b/float8_array.go index b92a8205..13dbf27f 100644 --- a/float8_array.go +++ b/float8_array.go @@ -86,7 +86,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/hstore.go b/hstore.go index 347446ae..71b030f9 100644 --- a/hstore.go +++ b/hstore.go @@ -59,7 +59,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { *v = make(map[string]string, len(src.Map)) for k, val := range src.Map { if val.Status != Present { - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } (*v)[k] = val.String } @@ -73,7 +73,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/hstore_array.go b/hstore_array.go index 80530c26..2b8cf37e 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -86,7 +86,7 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/inet.go b/inet.go index 01fc0e5b..d93e6347 100644 --- a/inet.go +++ b/inet.go @@ -91,7 +91,7 @@ func (src *Inet) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/inet_array.go b/inet_array.go index f3e4efbf..dba369d2 100644 --- a/inet_array.go +++ b/inet_array.go @@ -115,7 +115,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/int2_array.go b/int2_array.go index f50d9275..7fefbd95 100644 --- a/int2_array.go +++ b/int2_array.go @@ -114,7 +114,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/int4_array.go b/int4_array.go index 6c9418ba..4e78ce71 100644 --- a/int4_array.go +++ b/int4_array.go @@ -114,7 +114,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/int8_array.go b/int8_array.go index bb6ce004..15a8398a 100644 --- a/int8_array.go +++ b/int8_array.go @@ -114,7 +114,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/interval.go b/interval.go index 799ce53a..dc696319 100644 --- a/interval.go +++ b/interval.go @@ -74,7 +74,7 @@ func (src *Interval) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/macaddr.go b/macaddr.go index 4c6e2212..79004be4 100644 --- a/macaddr.go +++ b/macaddr.go @@ -70,7 +70,7 @@ func (src *Macaddr) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/numeric_array.go b/numeric_array.go index 9c8f8eb3..b5e38539 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -170,7 +170,7 @@ func (src *NumericArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/record.go b/record.go index aeca1c54..64c6f13a 100644 --- a/record.go +++ b/record.go @@ -67,7 +67,7 @@ func (src *Record) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { diff --git a/text.go b/text.go index bceeffd4..919743fe 100644 --- a/text.go +++ b/text.go @@ -74,7 +74,7 @@ func (src *Text) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/text_array.go b/text_array.go index e40f4b86..d53f0b7b 100644 --- a/text_array.go +++ b/text_array.go @@ -86,7 +86,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/timestamp.go b/timestamp.go index d906f467..6292521a 100644 --- a/timestamp.go +++ b/timestamp.go @@ -76,7 +76,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } // DecodeText decodes from src into dst. The decoded time is considered to diff --git a/timestamp_array.go b/timestamp_array.go index 546a3810..11b32a11 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -87,7 +87,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/timestamptz.go b/timestamptz.go index 74fe4954..2b9d2a64 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -77,7 +77,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/timestamptz_array.go b/timestamptz_array.go index 88b6cc5f..31c11f94 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -87,7 +87,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/typed_array.go.erb b/typed_array.go.erb index 6fafc2df..6b46a23e 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -86,7 +86,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/uuid_array.go b/uuid_array.go index 9c7843a7..13efdb23 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -142,7 +142,7 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *UUIDArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/varchar_array.go b/varchar_array.go index 09eba3ea..a7f23fba 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -86,7 +86,7 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { From 5f39bbaf35772562b62214493b9d56b9f59439f4 Mon Sep 17 00:00:00 2001 From: Murat Kabilov Date: Mon, 30 Jul 2018 17:29:26 +0200 Subject: [PATCH 0138/1158] Add *Conn. CopyFromTextual, CopyToTextual, which use textual format for copying data --- copy_done.go | 30 ++++++++++++++++++++++++++++++ frontend.go | 3 +++ 2 files changed, 33 insertions(+) create mode 100644 copy_done.go diff --git a/copy_done.go b/copy_done.go new file mode 100644 index 00000000..92481908 --- /dev/null +++ b/copy_done.go @@ -0,0 +1,30 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type CopyDone struct { +} + +func (*CopyDone) Backend() {} + +func (dst *CopyDone) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *CopyDone) Encode(dst []byte) []byte { + return append(dst, 'c', 0, 0, 0, 4) +} + +func (src *CopyDone) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "CopyDone", + }) +} diff --git a/frontend.go b/frontend.go index d803d362..d1541c74 100644 --- a/frontend.go +++ b/frontend.go @@ -22,6 +22,7 @@ type Frontend struct { copyData CopyData copyInResponse CopyInResponse copyOutResponse CopyOutResponse + copyDone CopyDone dataRow DataRow emptyQueryResponse EmptyQueryResponse errorResponse ErrorResponse @@ -72,6 +73,8 @@ func (b *Frontend) Receive() (BackendMessage, error) { msg = &b.closeComplete case 'A': msg = &b.notificationResponse + case 'c': + msg = &b.copyDone case 'C': msg = &b.commandComplete case 'd': From 88d317af97479da36e7f7f69ac663f24b8df79b7 Mon Sep 17 00:00:00 2001 From: Anthony Regeda Date: Sat, 1 Sep 2018 16:06:20 +0300 Subject: [PATCH 0139/1158] macaddr-array macaddr array is introduced --- macaddr_array.go | 301 ++++++++++++++++++++++++++++++++++++++++++ macaddr_array_test.go | 105 +++++++++++++++ typed_array_gen.sh | 1 + 3 files changed, 407 insertions(+) create mode 100644 macaddr_array.go create mode 100644 macaddr_array_test.go diff --git a/macaddr_array.go b/macaddr_array.go new file mode 100644 index 00000000..bd8b4c5a --- /dev/null +++ b/macaddr_array.go @@ -0,0 +1,301 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "net" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type MacaddrArray struct { + Elements []Macaddr + Dimensions []ArrayDimension + Status Status +} + +func (dst *MacaddrArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = MacaddrArray{Status: Null} + return nil + } + + switch value := src.(type) { + + case []net.HardwareAddr: + if value == nil { + *dst = MacaddrArray{Status: Null} + } else if len(value) == 0 { + *dst = MacaddrArray{Status: Present} + } else { + elements := make([]Macaddr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = MacaddrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to MacaddrArray", value) + } + + return nil +} + +func (dst *MacaddrArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *MacaddrArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]net.HardwareAddr: + *v = make([]net.HardwareAddr, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *MacaddrArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = MacaddrArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Macaddr + + if len(uta.Elements) > 0 { + elements = make([]Macaddr, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Macaddr + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = MacaddrArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *MacaddrArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = MacaddrArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = MacaddrArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Macaddr, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = MacaddrArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *MacaddrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("macaddr"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "macaddr") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *MacaddrArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *MacaddrArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/macaddr_array_test.go b/macaddr_array_test.go new file mode 100644 index 00000000..d4bb2f01 --- /dev/null +++ b/macaddr_array_test.go @@ -0,0 +1,105 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestMacaddrArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "macaddr[]", []interface{}{ + &pgtype.MacaddrArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.MacaddrArray{Status: pgtype.Null}, + }) +} + +func TestMacaddrArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.MacaddrArray + }{ + { + source: []net.HardwareAddr{mustParseMacaddr(t, "01:23:45:67:89:ab")}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.HardwareAddr)(nil)), + result: pgtype.MacaddrArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.MacaddrArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestMacaddrArrayAssignTo(t *testing.T) { + var macaddrSlice []net.HardwareAddr + + simpleTests := []struct { + src pgtype.MacaddrArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &macaddrSlice, + expected: []net.HardwareAddr{mustParseMacaddr(t, "01:23:45:67:89:ab")}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &macaddrSlice, + expected: []net.HardwareAddr{nil}, + }, + { + src: pgtype.MacaddrArray{Status: pgtype.Null}, + dst: &macaddrSlice, + expected: (([]net.HardwareAddr)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 38b9e1d0..bd70faa4 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -8,6 +8,7 @@ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_type erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go +erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' binary_format=true typed_array.go.erb > text_array.go erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' binary_format=true typed_array.go.erb > varchar_array.go From 8f7c03a47f201367ae6ac1faef798eefeae7f498 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Sep 2018 18:40:42 -0500 Subject: [PATCH 0140/1158] Fix: do not silently ignore assign NULL to *string AssignTo can only assign NULL to a **string. Previous code tried to assign nil to a *string, which did nothing. Correct behavior is to detect this as an error. --- json.go | 16 +++++++++++----- json_test.go | 2 +- jsonb_test.go | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/json.go b/json.go index b05aba6b..377a1546 100644 --- a/json.go +++ b/json.go @@ -72,14 +72,20 @@ func (dst *JSON) Get() interface{} { func (src *JSON) AssignTo(dst interface{}) error { switch v := dst.(type) { case *string: - if src.Status != Present { - v = nil - } else { + if src.Status == Present { *v = string(src.Bytes) + } else { + return errors.Errorf("cannot assign non-present status to %T", dst) } case **string: - *v = new(string) - return src.AssignTo(*v) + if src.Status == Present { + s := string(src.Bytes) + *v = &s + return nil + } else { + *v = nil + return nil + } case *[]byte: if src.Status != Present { *v = nil diff --git a/json_test.go b/json_test.go index 82c02539..38494841 100644 --- a/json_test.go +++ b/json_test.go @@ -129,7 +129,7 @@ func TestJSONAssignTo(t *testing.T) { t.Errorf("%d: %v", i, err) } - if *tt.dst == tt.expected { + if *tt.dst != tt.expected { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) } } diff --git a/jsonb_test.go b/jsonb_test.go index 1a9a3056..ab743151 100644 --- a/jsonb_test.go +++ b/jsonb_test.go @@ -135,7 +135,7 @@ func TestJSONBAssignTo(t *testing.T) { t.Errorf("%d: %v", i, err) } - if *tt.dst == tt.expected { + if *tt.dst != tt.expected { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) } } From 64b1ecf96fbb8a9824e96a77be94bfeaaf476a6e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Sep 2018 07:43:18 -0500 Subject: [PATCH 0141/1158] Type modifier should be int32 not uint32 --- row_description.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/row_description.go b/row_description.go index d0df11b0..3c5a6faa 100644 --- a/row_description.go +++ b/row_description.go @@ -19,7 +19,7 @@ type FieldDescription struct { TableAttributeNumber uint16 DataTypeOID uint32 DataTypeSize int16 - TypeModifier uint32 + TypeModifier int32 Format int16 } @@ -57,7 +57,7 @@ func (dst *RowDescription) Decode(src []byte) error { fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2)) fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4)) fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2))) - fd.TypeModifier = binary.BigEndian.Uint32(buf.Next(4)) + fd.TypeModifier = int32(binary.BigEndian.Uint32(buf.Next(4))) fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2))) dst.Fields[i] = fd @@ -80,7 +80,7 @@ func (src *RowDescription) Encode(dst []byte) []byte { dst = pgio.AppendUint16(dst, fd.TableAttributeNumber) dst = pgio.AppendUint32(dst, fd.DataTypeOID) dst = pgio.AppendInt16(dst, fd.DataTypeSize) - dst = pgio.AppendUint32(dst, fd.TypeModifier) + dst = pgio.AppendInt32(dst, fd.TypeModifier) dst = pgio.AppendInt16(dst, fd.Format) } From f9440700e563fd5be090cc093eff2f2c36832ec0 Mon Sep 17 00:00:00 2001 From: maxarchx Date: Fri, 30 Nov 2018 15:13:43 +0500 Subject: [PATCH 0142/1158] Apply UUID string length check before parsing --- uuid.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/uuid.go b/uuid.go index f8297b39..5e1eead5 100644 --- a/uuid.go +++ b/uuid.go @@ -87,6 +87,9 @@ func (src *UUID) AssignTo(dst interface{}) error { // parseUUID converts a string UUID in standard form to a byte array. func parseUUID(src string) (dst [16]byte, err error) { + if len(src) < 36 { + return dst, errors.Errorf("cannot parse UUID %v", src) + } src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] buf, err := hex.DecodeString(src) if err != nil { From 5d17ec41567f376fb2a78d995a3e8e262dce5c9b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 28 Dec 2018 17:09:56 -0600 Subject: [PATCH 0143/1158] Rename base package to pgconn --- pgconn.go | 478 +++++++++++++++++++++++++++++++++++++++++++++++++ pgconn_test.go | 34 ++++ 2 files changed, 512 insertions(+) create mode 100644 pgconn.go create mode 100644 pgconn_test.go diff --git a/pgconn.go b/pgconn.go new file mode 100644 index 00000000..c9caef42 --- /dev/null +++ b/pgconn.go @@ -0,0 +1,478 @@ +package pgconn + +import ( + "crypto/md5" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "os" + "os/user" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/pgproto3" +) + +const batchBufferSize = 4096 + +// PgError represents an error reported by the PostgreSQL server. See +// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for +// detailed field description. +type PgError struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string +} + +func (pe PgError) Error() string { + return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" +} + +// DialFunc is a function that can be used to connect to a PostgreSQL server +type DialFunc func(network, addr string) (net.Conn, error) + +// ErrTLSRefused occurs when the connection attempt requires TLS and the +// PostgreSQL server refuses to use TLS +var ErrTLSRefused = errors.New("server refused TLS connection") + +type ConnConfig struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 // default: 5432 + Database string + User string // default: OS user name + Password string + TLSConfig *tls.Config // config for TLS connection -- nil disables TLS + Dial DialFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) +} + +func (cc *ConnConfig) NetworkAddress() (network, address string) { + // If host is a valid path, then address is unix socket + if _, err := os.Stat(cc.Host); err == nil { + network = "unix" + address = cc.Host + if !strings.Contains(address, "/.s.PGSQL.") { + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) + } + } else { + network = "tcp" + address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) + } + + return network, address +} + +func (cc *ConnConfig) assignDefaults() error { + if cc.User == "" { + user, err := user.Current() + if err != nil { + return err + } + cc.User = user.Username + } + + if cc.Port == 0 { + cc.Port = 5432 + } + + if cc.Dial == nil { + defaultDialer := &net.Dialer{KeepAlive: 5 * time.Minute} + cc.Dial = defaultDialer.Dial + } + + return nil +} + +// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. +type PgConn struct { + NetConn net.Conn // the underlying TCP or unix domain socket connection + PID uint32 // backend pid + SecretKey uint32 // key to use to send a cancel query message to the server + parameterStatuses map[string]string // parameters that have been reported by the server + TxStatus byte + Frontend *pgproto3.Frontend + + Config ConnConfig + + batchBuf []byte + batchCount int32 + + pendingReadyForQueryCount int32 + + closed bool +} + +func Connect(cc ConnConfig) (*PgConn, error) { + err := cc.assignDefaults() + if err != nil { + return nil, err + } + + pgConn := new(PgConn) + pgConn.Config = cc + + pgConn.NetConn, err = cc.Dial(cc.NetworkAddress()) + if err != nil { + return nil, err + } + + pgConn.parameterStatuses = make(map[string]string) + + if cc.TLSConfig != nil { + if err := pgConn.startTLS(cc.TLSConfig); err != nil { + return nil, err + } + } + + pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.NetConn, pgConn.NetConn) + if err != nil { + return nil, err + } + + startupMsg := pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersionNumber, + Parameters: make(map[string]string), + } + + // Copy default run-time params + for k, v := range cc.RuntimeParams { + startupMsg.Parameters[k] = v + } + + startupMsg.Parameters["user"] = cc.User + if cc.Database != "" { + startupMsg.Parameters["database"] = cc.Database + } + + if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { + return nil, err + } + + for { + msg, err := pgConn.ReceiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.BackendKeyData: + pgConn.PID = msg.ProcessID + pgConn.SecretKey = msg.SecretKey + case *pgproto3.Authentication: + if err = pgConn.rxAuthenticationX(msg); err != nil { + return nil, err + } + case *pgproto3.ReadyForQuery: + return pgConn, nil + case *pgproto3.ParameterStatus: + // handled by ReceiveMessage + case *pgproto3.ErrorResponse: + return nil, PgError{ + Severity: msg.Severity, + Code: msg.Code, + Message: msg.Message, + Detail: msg.Detail, + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: msg.InternalQuery, + Where: msg.Where, + SchemaName: msg.SchemaName, + TableName: msg.TableName, + ColumnName: msg.ColumnName, + DataTypeName: msg.DataTypeName, + ConstraintName: msg.ConstraintName, + File: msg.File, + Line: msg.Line, + Routine: msg.Routine, + } + default: + return nil, errors.New("unexpected message") + } + } +} + +func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { + err = binary.Write(pgConn.NetConn, binary.BigEndian, []int32{8, 80877103}) + if err != nil { + return + } + + response := make([]byte, 1) + if _, err = io.ReadFull(pgConn.NetConn, response); err != nil { + return + } + + if response[0] != 'S' { + return ErrTLSRefused + } + + pgConn.NetConn = tls.Client(pgConn.NetConn, tlsConfig) + + return nil +} + +func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { + switch msg.Type { + case pgproto3.AuthTypeOk: + case pgproto3.AuthTypeCleartextPassword: + err = c.txPasswordMessage(c.Config.Password) + case pgproto3.AuthTypeMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:])) + err = c.txPasswordMessage(digestedPassword) + default: + err = errors.New("Received unknown authentication message") + } + + return +} + +func (pgConn *PgConn) txPasswordMessage(password string) (err error) { + msg := &pgproto3.PasswordMessage{Password: password} + _, err = pgConn.NetConn.Write(msg.Encode(nil)) + return err +} + +func hexMD5(s string) string { + hash := md5.New() + io.WriteString(hash, s) + return hex.EncodeToString(hash.Sum(nil)) +} + +func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { + msg, err := pgConn.Frontend.Receive() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + // Under normal circumstances pendingReadyForQueryCount will be > 0 when a + // ReadyForQuery is received. However, this is not the case on initial + // connection. + if pgConn.pendingReadyForQueryCount > 0 { + pgConn.pendingReadyForQueryCount -= 1 + } + pgConn.TxStatus = msg.TxStatus + case *pgproto3.ParameterStatus: + pgConn.parameterStatuses[msg.Name] = msg.Value + case *pgproto3.ErrorResponse: + if msg.Severity == "FATAL" { + // TODO - close pgConn + return nil, errorResponseToPgError(msg) + } + } + + return msg, nil +} + +// Close closes a connection. It is safe to call Close on a already closed +// connection. +func (pgConn *PgConn) Close() error { + if pgConn.closed { + return nil + } + pgConn.closed = true + + _, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) + if err != nil { + pgConn.NetConn.Close() + return err + } + + _, err = pgConn.NetConn.Read(make([]byte, 1)) + if err != io.EOF { + pgConn.NetConn.Close() + return err + } + + return pgConn.NetConn.Close() +} + +// ParameterStatus returns the value of a parameter reported by the server (e.g. +// server_version). Returns an empty string for unknown parameters. +func (pgConn *PgConn) ParameterStatus(key string) string { + return pgConn.parameterStatuses[key] +} + +// CommandTag is the result of an Exec function +type CommandTag string + +// RowsAffected returns the number of rows affected. If the CommandTag was not +// for a row affecting command (e.g. "CREATE TABLE") then it returns 0. +func (ct CommandTag) RowsAffected() int64 { + s := string(ct) + index := strings.LastIndex(s, " ") + if index == -1 { + return 0 + } + n, _ := strconv.ParseInt(s[index+1:], 10, 64) + return n +} + +// SendExec enqueues the execution of sql via the PostgreSQL simple query +// protocol. sql may contain multipe queries. Multiple queries will be processed +// within a single transation. It is only sent to the PostgreSQL server when +// Flush is called. +func (pgConn *PgConn) SendExec(sql string) { + pgConn.batchBuf = appendQuery(pgConn.batchBuf, sql) + pgConn.batchCount += 1 +} + +// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it. +func appendQuery(buf []byte, query string) []byte { + buf = append(buf, 'Q') + buf = pgio.AppendInt32(buf, int32(len(query)+5)) + buf = append(buf, query...) + buf = append(buf, 0) + return buf +} + +type PgResultReader struct { + pgConn *PgConn + fieldDescriptions []pgproto3.FieldDescription + rowValues [][]byte + commandTag CommandTag + err error + complete bool +} + +// GetResult returns a PgResultReader for the next result. If all results are +// consumed it returns nil. If an error occurs it will be reported on the +// returned PgResultReader. +func (pgConn *PgConn) GetResult() *PgResultReader { + if pgConn.pendingReadyForQueryCount == 0 { + return nil + } + + return &PgResultReader{pgConn: pgConn} +} + +func (rr *PgResultReader) NextRow() (present bool) { + if rr.complete { + return false + } + + for { + msg, err := rr.pgConn.ReceiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rr.fieldDescriptions = msg.Fields + case *pgproto3.DataRow: + rr.rowValues = msg.Values + return true + case *pgproto3.CommandComplete: + rr.commandTag = CommandTag(msg.CommandTag) + rr.complete = true + return false + case *pgproto3.ErrorResponse: + rr.err = errorResponseToPgError(msg) + rr.complete = true + return false + } + } +} + +func (rr *PgResultReader) Value(c int) []byte { + return rr.rowValues[c] +} + +// Close consumes any remaining result data and returns the command tag or +// error. +func (rr *PgResultReader) Close() (CommandTag, error) { + if rr.complete { + return rr.commandTag, rr.err + } + + for { + msg, err := rr.pgConn.ReceiveMessage() + if err != nil { + rr.err = err + rr.complete = true + return rr.commandTag, rr.err + } + + switch msg := msg.(type) { + case *pgproto3.CommandComplete: + rr.commandTag = CommandTag(msg.CommandTag) + rr.complete = true + return rr.commandTag, rr.err + case *pgproto3.ErrorResponse: + rr.err = errorResponseToPgError(msg) + rr.complete = true + return rr.commandTag, rr.err + } + } +} + +// Flush sends the enqueued execs to the server. +func (pgConn *PgConn) Flush() error { + defer pgConn.resetBatch() + + n, err := pgConn.NetConn.Write(pgConn.batchBuf) + if err != nil { + if n > 0 { + // TODO - kill connection - we sent a partial message + } + return err + } + + pgConn.pendingReadyForQueryCount += pgConn.batchCount + return nil +} + +func (pgConn *PgConn) resetBatch() { + pgConn.batchCount = 0 + if len(pgConn.batchBuf) > batchBufferSize { + pgConn.batchBuf = make([]byte, 0, batchBufferSize) + } else { + pgConn.batchBuf = pgConn.batchBuf[0:0] + } +} + +func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError { + return PgError{ + Severity: msg.Severity, + Code: msg.Code, + Message: msg.Message, + Detail: msg.Detail, + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: msg.InternalQuery, + Where: msg.Where, + SchemaName: msg.SchemaName, + TableName: msg.TableName, + ColumnName: msg.ColumnName, + DataTypeName: msg.DataTypeName, + ConstraintName: msg.ConstraintName, + File: msg.File, + Line: msg.Line, + Routine: msg.Routine, + } +} diff --git a/pgconn_test.go b/pgconn_test.go new file mode 100644 index 00000000..dbcf2704 --- /dev/null +++ b/pgconn_test.go @@ -0,0 +1,34 @@ +package pgconn_test + +import ( + "github.com/jackc/pgx/pgconn" + + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSimple(t *testing.T) { + pgConn, err := pgconn.Connect(pgconn.ConnConfig{Host: "/var/run/postgresql", User: "jack", Database: "pgx_test"}) + require.Nil(t, err) + + pgConn.SendExec("select current_database()") + err = pgConn.Flush() + require.Nil(t, err) + + result := pgConn.GetResult() + require.NotNil(t, result) + + rowFound := result.NextRow() + assert.True(t, rowFound) + if rowFound { + assert.Equal(t, "pgx_test", string(result.Value(0))) + } + + _, err = result.Close() + assert.Nil(t, err) + + err = pgConn.Close() + assert.Nil(t, err) +} From beeb69ff0bed06647f93f4eafae419ff43fd4da1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 30 Dec 2018 16:53:57 -0600 Subject: [PATCH 0144/1158] Restructure connect process - Moved lots of connection logic to pgconn from pgx - Extracted pgpassfile package --- config.go | 421 +++++++++++++++++++++++++++++++++++++++++++++++++ config_test.go | 392 +++++++++++++++++++++++++++++++++++++++++++++ pgconn.go | 130 ++++++++------- pgconn_test.go | 8 +- 4 files changed, 881 insertions(+), 70 deletions(-) create mode 100644 config.go create mode 100644 config_test.go diff --git a/config.go b/config.go new file mode 100644 index 00000000..515d6356 --- /dev/null +++ b/config.go @@ -0,0 +1,421 @@ +package pgconn + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "math" + "net" + "net/url" + "os" + "os/user" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/pgpassfile" + "github.com/pkg/errors" +) + +// Config is the settings used to establish a connection to a PostgreSQL server. +type Config struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + DialFunc DialFunc // e.g. net.Dialer.DialContext + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + + Fallbacks []*FallbackConfig +} + +// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a +// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. +type FallbackConfig struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + TLSConfig *tls.Config // nil disables TLS +} + +// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with +// net.Dial. +func NetworkAddress(host string, port uint16) (network, address string) { + if strings.HasPrefix(host, "/") { + network = "unix" + address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) + } else { + network = "tcp" + address = fmt.Sprintf("%s:%d", host, port) + } + return network, address +} + +// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. +// It uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment +// variables. connString may be a URL or a DSN. It also may be empty to only read from the +// environment. If a password is not supplied it will attempt to read the .pgpass file. +// +// Example DSN: "user=jack password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-ca" +// +// Example URL: "postgres://jack:secret@1.2.3.4:5432/mydb?sslmode=verify-ca" +// +// Multiple configs may be returned due to sslmode settings with fallback options (e.g. +// sslmode=prefer). Future implementations may also support multiple hosts +// (https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS). +// +// ParseConfig currently recognizes the following environment variable and their parameter key word +// equivalents passed via database URL or DSN: +// +// PGHOST +// PGPORT +// PGDATABASE +// PGUSER +// PGPASSWORD +// PGPASSFILE +// PGSSLMODE +// PGSSLCERT +// PGSSLKEY +// PGSSLROOTCERT +// PGAPPNAME +// PGCONNECT_TIMEOUT +// +// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of +// environment variables. +// +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key +// word names. They are usually but not always the environment variable name downcased and without +// the "PG" prefix. +// +// Important TLS Security Notes: +// +// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to +// "prefer" behavior if not set. +// +// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on +// what level of security each sslmode provides. +// +// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger +// security guarantees than it would with libpq. Do not rely on this behavior as it +// may be possible to match libpq in the future. If you need full security use +// "verify-full". +func ParseConfig(connString string) (*Config, error) { + settings := defaultSettings() + addEnvSettings(settings) + + if connString != "" { + // connString may be a database URL or a DSN + if strings.HasPrefix(connString, "postgres://") { + url, err := url.Parse(connString) + if err != nil { + return nil, err + } + + err = addURLSettings(settings, url) + if err != nil { + return nil, err + } + } else { + err := addDSNSettings(settings, connString) + if err != nil { + return nil, err + } + } + } + + config := &Config{ + Host: settings["host"], + Database: settings["database"], + User: settings["user"], + Password: settings["password"], + RuntimeParams: make(map[string]string), + } + + if port, err := parsePort(settings["port"]); err == nil { + config.Port = port + } else { + return nil, fmt.Errorf("invalid port: %v", settings["port"]) + } + + if connectTimeout, present := settings["connect_timeout"]; present { + dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) + if err != nil { + return nil, err + } + config.DialFunc = dialFunc + } else { + defaultDialer := makeDefaultDialer() + config.DialFunc = defaultDialer.DialContext + } + + notRuntimeParams := map[string]struct{}{ + "host": struct{}{}, + "port": struct{}{}, + "database": struct{}{}, + "user": struct{}{}, + "password": struct{}{}, + "passfile": struct{}{}, + "connect_timeout": struct{}{}, + "sslmode": struct{}{}, + "sslkey": struct{}{}, + "sslcert": struct{}{}, + "sslrootcert": struct{}{}, + } + + for k, v := range settings { + if _, present := notRuntimeParams[k]; present { + continue + } + config.RuntimeParams[k] = v + } + + var tlsConfigs []*tls.Config + + // Ignore TLS settings if Unix domain socket like libpq + if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { + tlsConfigs = append(tlsConfigs, nil) + } else { + var err error + tlsConfigs, err = configTLS(settings) + if err != nil { + return nil, err + } + } + + config.TLSConfig = tlsConfigs[0] + + for _, tlsConfig := range tlsConfigs[1:] { + config.Fallbacks = append(config.Fallbacks, &FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: tlsConfig, + }) + } + + passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) + if err == nil { + if config.Password == "" { + host := config.Host + if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { + host = "localhost" + } + + config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User) + } + } + + return config, nil +} + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + if err == nil { + settings["user"] = user.Username + settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") + } + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + candidatePaths := []string{ + "/var/run/postgresql", // Debian + "/private/tmp", // OSX - homebrew + "/tmp", // standard PostgreSQL + } + + for _, path := range candidatePaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + + return "localhost" +} + +func addEnvSettings(settings map[string]string) { + nameMap := map[string]string{ + "PGHOST": "host", + "PGPORT": "port", + "PGDATABASE": "database", + "PGUSER": "user", + "PGPASSWORD": "password", + "PGPASSFILE": "passfile", + "PGAPPNAME": "application_name", + "PGCONNECT_TIMEOUT": "connect_timeout", + "PGSSLMODE": "sslmode", + "PGSSLKEY": "sslkey", + "PGSSLCERT": "sslcert", + "PGSSLROOTCERT": "sslrootcert", + } + + for envname, realname := range nameMap { + value := os.Getenv(envname) + if value != "" { + settings[realname] = value + } + } +} + +func addURLSettings(settings map[string]string, url *url.URL) error { + if url.User != nil { + settings["user"] = url.User.Username() + if password, present := url.User.Password(); present { + settings["password"] = password + } + } + + parts := strings.SplitN(url.Host, ":", 2) + if parts[0] != "" { + settings["host"] = parts[0] + } + if len(parts) == 2 { + settings["port"] = parts[1] + } + + database := strings.TrimLeft(url.Path, "/") + if database != "" { + settings["database"] = database + } + + for k, v := range url.Query() { + settings[k] = v[0] + } + + return nil +} + +var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) + +func addDSNSettings(settings map[string]string, s string) error { + m := dsnRegexp.FindAllStringSubmatch(s, -1) + + for _, b := range m { + settings[b[1]] = b[2] + } + + return nil +} + +type pgTLSArgs struct { + sslMode string + sslRootCert string + sslCert string + sslKey string +} + +// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is +// necessary to allow returning multiple TLS configs as sslmode "allow" and +// "prefer" allow fallback. +func configTLS(settings map[string]string) ([]*tls.Config, error) { + host := settings["host"] + sslmode := settings["sslmode"] + sslrootcert := settings["sslrootcert"] + sslcert := settings["sslcert"] + sslkey := settings["sslkey"] + + // Match libpq default behavior + if sslmode == "" { + sslmode = "prefer" + } + + tlsConfig := &tls.Config{} + + switch sslmode { + case "disable": + return []*tls.Config{nil}, nil + case "allow", "prefer": + tlsConfig.InsecureSkipVerify = true + case "require": + tlsConfig.InsecureSkipVerify = sslrootcert == "" + case "verify-ca", "verify-full": + tlsConfig.ServerName = host + default: + return nil, errors.New("sslmode is invalid") + } + + if sslrootcert != "" { + caCertPool := x509.NewCertPool() + + caPath := sslrootcert + caCert, err := ioutil.ReadFile(caPath) + if err != nil { + return nil, errors.Wrapf(err, "unable to read CA file %q", caPath) + } + + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, errors.Wrap(err, "unable to add CA to cert pool") + } + + tlsConfig.RootCAs = caCertPool + tlsConfig.ClientCAs = caCertPool + } + + if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { + return nil, fmt.Errorf(`both "sslcert" and "sslkey" are required`) + } + + if sslcert != "" && sslkey != "" { + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + return nil, errors.Wrap(err, "unable to read cert") + } + + tlsConfig.Certificates = []tls.Certificate{cert} + } + + switch sslmode { + case "allow": + return []*tls.Config{nil, tlsConfig}, nil + case "prefer": + return []*tls.Config{tlsConfig, nil}, nil + case "require", "verify-ca", "verify-full": + return []*tls.Config{tlsConfig}, nil + default: + panic("BUG: bad sslmode should already have been caught") + } +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +func makeDefaultDialer() *net.Dialer { + return &net.Dialer{KeepAlive: 5 * time.Minute} +} + +func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { + timeout, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, err + } + if timeout < 0 { + return nil, errors.New("negative timeout") + } + + d := makeDefaultDialer() + d.Timeout = time.Duration(timeout) * time.Second + return d.DialContext, nil +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 00000000..796876f2 --- /dev/null +++ b/config_test.go @@ -0,0 +1,392 @@ +package pgconn_test + +import ( + "crypto/tls" + "fmt" + "io/ioutil" + "os" + "os/user" + "testing" + + "github.com/jackc/pgx/pgconn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseConfig(t *testing.T) { + t.Parallel() + + var osUserName string + osUser, err := user.Current() + if err == nil { + osUserName = osUser.Username + } + + tests := []struct { + name string + connString string + config *pgconn.Config + }{ + // Test all sslmodes + { + name: "sslmode not set (prefer)", + connString: "postgres://jack:secret@localhost:5432/mydb", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode disable", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode allow", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + }, + }, + }, + { + name: "sslmode prefer", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", + config: &pgconn.Config{ + + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode require", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-ca", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ServerName: "localhost"}, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-full", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ServerName: "localhost"}, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url everything", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + { + name: "database url missing password", + connString: "postgres://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing user and password", + connString: "postgres://localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: osUserName, + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing port", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url unix domain socket host", + connString: "postgres:///foo?host=/tmp", + config: &pgconn.Config{ + User: osUserName, + Host: "/tmp", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN everything", + connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + } + + for i, tt := range tests { + config, err := pgconn.ParseConfig(tt.connString) + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) + assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) + assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) + assert.Equalf(t, expected.User, actual.User, "%s - User", testName) + assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) + assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { + if expected.TLSConfig != nil { + assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) + } + } + + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks %v", testName) { + for i := range expected.Fallbacks { + assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) + assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) + + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName) { + if expected.Fallbacks[i].TLSConfig != nil { + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) + } + } + } + } +} + +func TestParseConfigEnvLibpq(t *testing.T) { + var osUserName string + osUser, err := user.Current() + if err == nil { + osUserName = osUser.Username + } + + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} + + savedEnv := make(map[string]string) + for _, n := range pgEnvvars { + savedEnv[n] = os.Getenv(n) + } + defer func() { + for k, v := range savedEnv { + err := os.Setenv(k, v) + if err != nil { + t.Fatalf("Unable to restore environment: %v", err) + } + } + }() + + tests := []struct { + name string + envvars map[string]string + config *pgconn.Config + }{ + { + // not testing no environment at all as that would use default host and that can vary. + name: "PGHOST only", + envvars: map[string]string{"PGHOST": "123.123.123.123"}, + config: &pgconn.Config{ + User: osUserName, + Host: "123.123.123.123", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "123.123.123.123", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "All non-TLS environment", + envvars: map[string]string{ + "PGHOST": "123.123.123.123", + "PGPORT": "7777", + "PGDATABASE": "foo", + "PGUSER": "bar", + "PGPASSWORD": "baz", + "PGCONNECT_TIMEOUT": "10", + "PGSSLMODE": "disable", + "PGAPPNAME": "pgxtest", + }, + config: &pgconn.Config{ + Host: "123.123.123.123", + Port: 7777, + Database: "foo", + User: "bar", + Password: "baz", + TLSConfig: nil, + RuntimeParams: map[string]string{"application_name": "pgxtest"}, + }, + }, + } + + for i, tt := range tests { + for _, n := range pgEnvvars { + err := os.Unsetenv(n) + require.Nil(t, err) + } + + for k, v := range tt.envvars { + err := os.Setenv(k, v) + require.Nil(t, err) + } + + config, err := pgconn.ParseConfig("") + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +func TestParseConfigReadsPgPassfile(t *testing.T) { + tf, err := ioutil.TempFile("", "") + require.Nil(t, err) + + defer tf.Close() + defer os.Remove(tf.Name()) + + _, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk")) + require.Nil(t, err) + + connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name()) + expected := &pgconn.Config{ + User: "curly", + Password: "nyuknyuknyuk", + Host: "test1", + Port: 5432, + Database: "curlydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + } + + actual, err := pgconn.ParseConfig(connString) + assert.Nil(t, err) + + assertConfigsEqual(t, expected, actual, "passfile") +} diff --git a/pgconn.go b/pgconn.go index c9caef42..37a205dc 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1,20 +1,16 @@ package pgconn import ( + "context" "crypto/md5" "crypto/tls" "encoding/binary" "encoding/hex" "errors" - "fmt" "io" "net" - "os" - "os/user" - "path/filepath" "strconv" "strings" - "time" "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" @@ -23,7 +19,7 @@ import ( const batchBufferSize = 4096 // PgError represents an error reported by the PostgreSQL server. See -// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for +// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. type PgError struct { Severity string @@ -50,60 +46,12 @@ func (pe PgError) Error() string { } // DialFunc is a function that can be used to connect to a PostgreSQL server -type DialFunc func(network, addr string) (net.Conn, error) +type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // ErrTLSRefused occurs when the connection attempt requires TLS and the // PostgreSQL server refuses to use TLS var ErrTLSRefused = errors.New("server refused TLS connection") -type ConnConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 // default: 5432 - Database string - User string // default: OS user name - Password string - TLSConfig *tls.Config // config for TLS connection -- nil disables TLS - Dial DialFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) -} - -func (cc *ConnConfig) NetworkAddress() (network, address string) { - // If host is a valid path, then address is unix socket - if _, err := os.Stat(cc.Host); err == nil { - network = "unix" - address = cc.Host - if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) - } - } else { - network = "tcp" - address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) - } - - return network, address -} - -func (cc *ConnConfig) assignDefaults() error { - if cc.User == "" { - user, err := user.Current() - if err != nil { - return err - } - cc.User = user.Username - } - - if cc.Port == 0 { - cc.Port = 5432 - } - - if cc.Dial == nil { - defaultDialer := &net.Dialer{KeepAlive: 5 * time.Minute} - cc.Dial = defaultDialer.Dial - } - - return nil -} - // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { NetConn net.Conn // the underlying TCP or unix domain socket connection @@ -113,7 +61,7 @@ type PgConn struct { TxStatus byte Frontend *pgproto3.Frontend - Config ConnConfig + Config *Config batchBuf []byte batchCount int32 @@ -123,24 +71,72 @@ type PgConn struct { closed bool } -func Connect(cc ConnConfig) (*PgConn, error) { - err := cc.assignDefaults() +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) +// to provide configuration. See documention for ParseConfig for details. ctx can be used to cancel a connect attempt. +func Connect(ctx context.Context, connString string) (*PgConn, error) { + config, err := ParseConfig(connString) if err != nil { return nil, err } - pgConn := new(PgConn) - pgConn.Config = cc + return ConnectConfig(ctx, config) +} - pgConn.NetConn, err = cc.Dial(cc.NetworkAddress()) +// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt. +// +// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An +// authentication error will terminate the chain of attempts (like libpq: +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, +// if all attempts fail the last error is returned. +func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { + // For convenience set a few defaults if not already set. This makes it simpler to directly construct a config. + if config.Port == 0 { + config.Port = 5432 + } + if config.DialFunc == nil { + config.DialFunc = makeDefaultDialer().DialContext + } + if config.RuntimeParams == nil { + config.RuntimeParams = make(map[string]string) + } + + // Simplify usage by treating primary config and fallbacks the same. + fallbackConfigs := []*FallbackConfig{ + { + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + } + fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + + for _, fc := range fallbackConfigs { + pgConn, err = connect(ctx, config, fc) + if err == nil { + return pgConn, nil + } else if err, ok := err.(PgError); ok { + return nil, err + } + } + + return nil, err +} + +func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { + pgConn := new(PgConn) + pgConn.Config = config + + var err error + network, address := NetworkAddress(config.Host, config.Port) + pgConn.NetConn, err = config.DialFunc(ctx, network, address) if err != nil { return nil, err } pgConn.parameterStatuses = make(map[string]string) - if cc.TLSConfig != nil { - if err := pgConn.startTLS(cc.TLSConfig); err != nil { + if config.TLSConfig != nil { + if err := pgConn.startTLS(config.TLSConfig); err != nil { return nil, err } } @@ -156,13 +152,13 @@ func Connect(cc ConnConfig) (*PgConn, error) { } // Copy default run-time params - for k, v := range cc.RuntimeParams { + for k, v := range config.RuntimeParams { startupMsg.Parameters[k] = v } - startupMsg.Parameters["user"] = cc.User - if cc.Database != "" { - startupMsg.Parameters["database"] = cc.Database + startupMsg.Parameters["user"] = config.User + if config.Database != "" { + startupMsg.Parameters["database"] = config.Database } if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { diff --git a/pgconn_test.go b/pgconn_test.go index dbcf2704..f165786e 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1,16 +1,18 @@ package pgconn_test import ( - "github.com/jackc/pgx/pgconn" - + "context" + "os" "testing" + "github.com/jackc/pgx/pgconn" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSimple(t *testing.T) { - pgConn, err := pgconn.Connect(pgconn.ConnConfig{Host: "/var/run/postgresql", User: "jack", Database: "pgx_test"}) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) pgConn.SendExec("select current_database()") From c4080cce35dcf4f76c7807f4d0e5fd98593a9521 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 30 Dec 2018 21:10:06 -0600 Subject: [PATCH 0145/1158] Move connection tests to pgconn --- helper_test.go | 13 +++++ pgconn_test.go | 146 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 159 insertions(+) create mode 100644 helper_test.go diff --git a/helper_test.go b/helper_test.go new file mode 100644 index 00000000..e6a7c73b --- /dev/null +++ b/helper_test.go @@ -0,0 +1,13 @@ +package pgconn_test + +import ( + "testing" + + "github.com/jackc/pgx/pgconn" + + "github.com/stretchr/testify/require" +) + +func closeConn(t testing.TB, conn *pgconn.PgConn) { + require.Nil(t, conn.Close()) +} diff --git a/pgconn_test.go b/pgconn_test.go index f165786e..9e16e925 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -2,15 +2,161 @@ package pgconn_test import ( "context" + "crypto/tls" + "net" "os" "testing" + "github.com/jackc/pgx" "github.com/jackc/pgx/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestConnect(t *testing.T) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } + + conn, err := pgconn.Connect(context.Background(), connString) + require.Nil(t, err) + + err = conn.Close() + require.Nil(t, err) + }) + } +} + +// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure +// connection. +func TestConnectTLS(t *testing.T) { + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } + + conn, err := pgconn.Connect(context.Background(), connString) + require.Nil(t, err) + + if _, ok := conn.NetConn.(*tls.Conn); !ok { + t.Error("not a TLS connection") + } + + err = conn.Close() + require.Nil(t, err) +} + +func TestConnectInvalidUser(t *testing.T) { + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) + require.Nil(t, err) + + config.User = "pgxinvalidusertest" + + conn, err := pgconn.ConnectConfig(context.Background(), config) + if err == nil { + conn.Close() + t.Fatal("expected err but got none") + } + pgErr, ok := err.(pgx.PgError) + if !ok { + t.Fatalf("Expected to receive a PgError, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } +} + +func TestConnectWithConnectionRefused(t *testing.T) { + t.Parallel() + + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") + if err == nil { + conn.Close() + t.Fatal("Expected error establishing connection to bad port") + } +} + +func TestConnectCustomDialer(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + require.True(t, dialed) + conn.Close() +} + +func TestConnectWithRuntimeParams(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, conn) + + // TODO - refactor these selects once there are higher level query functions + + conn.SendExec("show application_name") + conn.SendExec("show search_path") + err = conn.Flush() + require.Nil(t, err) + + result := conn.GetResult() + require.NotNil(t, result) + + rowFound := result.NextRow() + assert.True(t, rowFound) + if rowFound { + assert.Equal(t, "pgxtest", string(result.Value(0))) + } + + _, err = result.Close() + assert.Nil(t, err) + + result = conn.GetResult() + require.NotNil(t, result) + + rowFound = result.NextRow() + assert.True(t, rowFound) + if rowFound { + assert.Equal(t, "myschema", string(result.Value(0))) + } + + _, err = result.Close() + assert.Nil(t, err) +} + func TestSimple(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) From 1836f7be464fb7ce1b69c7cec17dba86d2437634 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 11:14:13 -0600 Subject: [PATCH 0146/1158] Support comma separated hosts and ports like libpq Also add test and fix the fallback config implementation. --- config.go | 138 +++++++++++++++++++++++++------------------ config_test.go | 155 ++++++++++++++++++++++++++++++++++++++++++++++++- pgconn.go | 2 +- pgconn_test.go | 31 ++++++++++ 4 files changed, 267 insertions(+), 59 deletions(-) diff --git a/config.go b/config.go index 515d6356..d2001dc5 100644 --- a/config.go +++ b/config.go @@ -55,21 +55,23 @@ func NetworkAddress(host string, port uint16) (network, address string) { return network, address } -// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. -// It uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment -// variables. connString may be a URL or a DSN. It also may be empty to only read from the -// environment. If a password is not supplied it will attempt to read the .pgpass file. +// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same +// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN. +// It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the +// .pgpass file. // -// Example DSN: "user=jack password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-ca" +// Example DSN: "user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca" // -// Example URL: "postgres://jack:secret@1.2.3.4:5432/mydb?sslmode=verify-ca" +// Example URL: "postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca" // -// Multiple configs may be returned due to sslmode settings with fallback options (e.g. -// sslmode=prefer). Future implementations may also support multiple hosts -// (https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS). +// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated +// values that will be tried in order. This can be used as part of a high availability system. See +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. // -// ParseConfig currently recognizes the following environment variable and their parameter key word -// equivalents passed via database URL or DSN: +// Example URL: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb" +// +// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed +// via database URL or DSN: // // PGHOST // PGPORT @@ -84,20 +86,18 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGAPPNAME // PGCONNECT_TIMEOUT // -// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of -// environment variables. +// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. // -// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key -// word names. They are usually but not always the environment variable name downcased and without -// the "PG" prefix. +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are +// usually but not always the environment variable name downcased and without the "PG" prefix. // // Important TLS Security Notes: // -// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to -// "prefer" behavior if not set. +// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if +// not set. // -// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on -// what level of security each sslmode provides. +// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of +// security each sslmode provides. // // "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger // security guarantees than it would with libpq. Do not rely on this behavior as it @@ -110,12 +110,7 @@ func ParseConfig(connString string) (*Config, error) { if connString != "" { // connString may be a database URL or a DSN if strings.HasPrefix(connString, "postgres://") { - url, err := url.Parse(connString) - if err != nil { - return nil, err - } - - err = addURLSettings(settings, url) + err := addURLSettings(settings, connString) if err != nil { return nil, err } @@ -128,19 +123,12 @@ func ParseConfig(connString string) (*Config, error) { } config := &Config{ - Host: settings["host"], Database: settings["database"], User: settings["user"], Password: settings["password"], RuntimeParams: make(map[string]string), } - if port, err := parsePort(settings["port"]); err == nil { - config.Port = port - } else { - return nil, fmt.Errorf("invalid port: %v", settings["port"]) - } - if connectTimeout, present := settings["connect_timeout"]; present { dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) if err != nil { @@ -173,28 +161,50 @@ func ParseConfig(connString string) (*Config, error) { config.RuntimeParams[k] = v } - var tlsConfigs []*tls.Config + fallbacks := []*FallbackConfig{} - // Ignore TLS settings if Unix domain socket like libpq - if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { - tlsConfigs = append(tlsConfigs, nil) - } else { - var err error - tlsConfigs, err = configTLS(settings) + hosts := strings.Split(settings["host"], ",") + ports := strings.Split(settings["port"], ",") + + for i, host := range hosts { + var portStr string + if i < len(ports) { + portStr = ports[i] + } else { + portStr = ports[0] + } + + port, err := parsePort(portStr) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid port: %v", settings["port"]) + } + + var tlsConfigs []*tls.Config + + // Ignore TLS settings if Unix domain socket like libpq + if network, _ := NetworkAddress(host, port); network == "unix" { + tlsConfigs = append(tlsConfigs, nil) + } else { + var err error + tlsConfigs, err = configTLS(settings) + if err != nil { + return nil, err + } + } + + for _, tlsConfig := range tlsConfigs { + fallbacks = append(fallbacks, &FallbackConfig{ + Host: host, + Port: port, + TLSConfig: tlsConfig, + }) } } - config.TLSConfig = tlsConfigs[0] - - for _, tlsConfig := range tlsConfigs[1:] { - config.Fallbacks = append(config.Fallbacks, &FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: tlsConfig, - }) - } + config.Host = fallbacks[0].Host + config.Port = fallbacks[0].Port + config.TLSConfig = fallbacks[0].TLSConfig + config.Fallbacks = fallbacks[1:] passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) if err == nil { @@ -272,7 +282,12 @@ func addEnvSettings(settings map[string]string) { } } -func addURLSettings(settings map[string]string, url *url.URL) error { +func addURLSettings(settings map[string]string, connString string) error { + url, err := url.Parse(connString) + if err != nil { + return err + } + if url.User != nil { settings["user"] = url.User.Username() if password, present := url.User.Password(); present { @@ -280,12 +295,23 @@ func addURLSettings(settings map[string]string, url *url.URL) error { } } - parts := strings.SplitN(url.Host, ":", 2) - if parts[0] != "" { - settings["host"] = parts[0] + // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. + var hosts []string + var ports []string + for _, host := range strings.Split(url.Host, ",") { + parts := strings.SplitN(host, ":", 2) + if parts[0] != "" { + hosts = append(hosts, parts[0]) + } + if len(parts) == 2 { + ports = append(ports, parts[1]) + } } - if len(parts) == 2 { - settings["port"] = parts[1] + if len(hosts) > 0 { + settings["host"] = strings.Join(hosts, ",") + } + if len(ports) > 0 { + settings["port"] = strings.Join(ports, ",") } database := strings.TrimLeft(url.Path, "/") diff --git a/config_test.go b/config_test.go index 796876f2..566a44f0 100644 --- a/config_test.go +++ b/config_test.go @@ -230,6 +230,150 @@ func TestParseConfig(t *testing.T) { }, }, }, + { + name: "URL multiple hosts", + connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "URL multiple hosts and ports", + connString: "postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 2, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 3, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "DSN multiple hosts one port", + connString: "user=jack password=secret host=foo,bar,baz port=5432 database=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "DSN multiple hosts multiple ports", + connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 database=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 2, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 3, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "multiple hosts and fallback tsl", + connString: "user=jack password=secret host=foo,bar,baz database=mydb sslmode=prefer", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "foo", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }}, + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }}, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, } for i, tt := range tests { @@ -243,6 +387,13 @@ func TestParseConfig(t *testing.T) { } func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { + if !assert.NotNil(t, expected) { + return + } + if !assert.NotNil(t, actual) { + return + } + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) @@ -257,12 +408,12 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName } } - if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks %v", testName) { + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { for i := range expected.Fallbacks { assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) - if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName) { + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { if expected.Fallbacks[i].TLSConfig != nil { assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) diff --git a/pgconn.go b/pgconn.go index 37a205dc..09860eb2 100644 --- a/pgconn.go +++ b/pgconn.go @@ -127,7 +127,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.Config = config var err error - network, address := NetworkAddress(config.Host, config.Port) + network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) pgConn.NetConn, err = config.DialFunc(ctx, network, address) if err != nil { return nil, err diff --git a/pgconn_test.go b/pgconn_test.go index 9e16e925..d53bbc09 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -157,6 +157,37 @@ func TestConnectWithRuntimeParams(t *testing.T) { assert.Nil(t, err) } +func TestConnectWithFallback(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) + + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here + + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + closeConn(t, conn) +} + func TestSimple(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) From 5ae6310b058d73bc8fe19e6e71a857a4d3796eff Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 11:39:22 -0600 Subject: [PATCH 0147/1158] Add AcceptConnFunc for filtering HA connections --- config.go | 7 +++++++ pgconn.go | 11 ++++++++++- pgconn_test.go | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index d2001dc5..a07fa533 100644 --- a/config.go +++ b/config.go @@ -20,6 +20,8 @@ import ( "github.com/pkg/errors" ) +type AcceptConnFunc func(pgconn *PgConn) bool + // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) @@ -32,6 +34,11 @@ type Config struct { RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) Fallbacks []*FallbackConfig + + // AcceptConnFunc is called after successful connection allow custom logic for determining if the connection is + // acceptable. If AcceptConnFunc returns false the connection is closed and the next fallback config is tried. This + // allows implementing high availability behavior such as libpq does with target_session_attrs. + AcceptConnFunc AcceptConnFunc } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a diff --git a/pgconn.go b/pgconn.go index 09860eb2..ac48f870 100644 --- a/pgconn.go +++ b/pgconn.go @@ -137,6 +137,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if config.TLSConfig != nil { if err := pgConn.startTLS(config.TLSConfig); err != nil { + pgConn.NetConn.Close() return nil, err } } @@ -162,6 +163,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { + pgConn.NetConn.Close() return nil, err } @@ -177,13 +179,19 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.SecretKey = msg.SecretKey case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { + pgConn.NetConn.Close() return nil, err } case *pgproto3.ReadyForQuery: - return pgConn, nil + if config.AcceptConnFunc == nil || config.AcceptConnFunc(pgConn) { + return pgConn, nil + } + pgConn.NetConn.Close() + return nil, errors.New("AcceptConnFunc rejected connection") case *pgproto3.ParameterStatus: // handled by ReceiveMessage case *pgproto3.ErrorResponse: + pgConn.NetConn.Close() return nil, PgError{ Severity: msg.Severity, Code: msg.Code, @@ -204,6 +212,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig Routine: msg.Routine, } default: + pgConn.NetConn.Close() return nil, errors.New("unexpected message") } } diff --git a/pgconn_test.go b/pgconn_test.go index d53bbc09..ad06ae7b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -188,6 +188,40 @@ func TestConnectWithFallback(t *testing.T) { closeConn(t, conn) } +func TestConnectWithAcceptConnFunc(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount += 1 + return net.Dial(network, address) + } + + acceptConnCount := 0 + config.AcceptConnFunc = func(conn *pgconn.PgConn) bool { + acceptConnCount += 1 + return acceptConnCount > 1 + } + + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) +} + func TestSimple(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) From 8c574c39f830d10c0b5a1c4ad46cc1e010646071 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 12:14:41 -0600 Subject: [PATCH 0148/1158] Add support for libpq target_session_attrs Generalize AcceptConnFunc into AfterConnectFunc. --- config.go | 93 +++++++++++++++++++++++++++++++++++--------------- config_test.go | 17 +++++++++ pgconn.go | 12 ++++--- pgconn_test.go | 23 +++++++++++-- 4 files changed, 111 insertions(+), 34 deletions(-) diff --git a/config.go b/config.go index a07fa533..38144be7 100644 --- a/config.go +++ b/config.go @@ -20,7 +20,7 @@ import ( "github.com/pkg/errors" ) -type AcceptConnFunc func(pgconn *PgConn) bool +type AfterConnectFunc func(pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { @@ -35,10 +35,10 @@ type Config struct { Fallbacks []*FallbackConfig - // AcceptConnFunc is called after successful connection allow custom logic for determining if the connection is - // acceptable. If AcceptConnFunc returns false the connection is closed and the next fallback config is tried. This + // AfterConnectFunc is called after successful connection. It can be used to set up the connection or to validate that + // server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This // allows implementing high availability behavior such as libpq does with target_session_attrs. - AcceptConnFunc AcceptConnFunc + AfterConnectFunc AfterConnectFunc } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a @@ -92,6 +92,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGSSLROOTCERT // PGAPPNAME // PGCONNECT_TIMEOUT +// PGTARGETSESSIONATTRS // // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. // @@ -148,17 +149,18 @@ func ParseConfig(connString string) (*Config, error) { } notRuntimeParams := map[string]struct{}{ - "host": struct{}{}, - "port": struct{}{}, - "database": struct{}{}, - "user": struct{}{}, - "password": struct{}{}, - "passfile": struct{}{}, - "connect_timeout": struct{}{}, - "sslmode": struct{}{}, - "sslkey": struct{}{}, - "sslcert": struct{}{}, - "sslrootcert": struct{}{}, + "host": struct{}{}, + "port": struct{}{}, + "database": struct{}{}, + "user": struct{}{}, + "password": struct{}{}, + "passfile": struct{}{}, + "connect_timeout": struct{}{}, + "sslmode": struct{}{}, + "sslkey": struct{}{}, + "sslcert": struct{}{}, + "sslrootcert": struct{}{}, + "target_session_attrs": struct{}{}, } for k, v := range settings { @@ -225,6 +227,12 @@ func ParseConfig(connString string) (*Config, error) { } } + if settings["target_session_attrs"] == "read-write" { + config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite + } else if settings["target_session_attrs"] != "any" { + return nil, fmt.Errorf("unknown target_session_attrs value %v", settings["target_session_attrs"]) + } + return config, nil } @@ -243,6 +251,8 @@ func defaultSettings() map[string]string { settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") } + settings["target_session_attrs"] = "any" + return settings } @@ -267,18 +277,19 @@ func defaultHost() string { func addEnvSettings(settings map[string]string) { nameMap := map[string]string{ - "PGHOST": "host", - "PGPORT": "port", - "PGDATABASE": "database", - "PGUSER": "user", - "PGPASSWORD": "password", - "PGPASSFILE": "passfile", - "PGAPPNAME": "application_name", - "PGCONNECT_TIMEOUT": "connect_timeout", - "PGSSLMODE": "sslmode", - "PGSSLKEY": "sslkey", - "PGSSLCERT": "sslcert", - "PGSSLROOTCERT": "sslrootcert", + "PGHOST": "host", + "PGPORT": "port", + "PGDATABASE": "database", + "PGUSER": "user", + "PGPASSWORD": "password", + "PGPASSFILE": "passfile", + "PGAPPNAME": "application_name", + "PGCONNECT_TIMEOUT": "connect_timeout", + "PGSSLMODE": "sslmode", + "PGSSLKEY": "sslkey", + "PGSSLCERT": "sslcert", + "PGSSLROOTCERT": "sslrootcert", + "PGTARGETSESSIONATTRS": "target_session_attrs", } for envname, realname := range nameMap { @@ -452,3 +463,31 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { d.Timeout = time.Duration(timeout) * time.Second return d.DialContext, nil } + +// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible +// target_session_attrs=read-write. +func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error { + pgConn.SendExec("show transaction_read_only") + err := pgConn.Flush() + if err != nil { + return err + } + + result := pgConn.GetResult() + if err != nil { + return err + } + + rowFound := result.NextRow() + if !rowFound { + return errors.New("show transaction_read_only failed") + } + + if string(result.Value(0)) == "on" { + return errors.New("read only connection") + } + + _, err = result.Close() + + return err +} diff --git a/config_test.go b/config_test.go index 566a44f0..36f3fee2 100644 --- a/config_test.go +++ b/config_test.go @@ -374,6 +374,20 @@ func TestParseConfig(t *testing.T) { }, }, }, + { + name: "target_session_attrs", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + AfterConnectFunc: pgconn.AfterConnectTargetSessionAttrsReadWrite, + }, + }, } for i, tt := range tests { @@ -401,6 +415,9 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.AfterConnectFunc == nil, actual.AfterConnectFunc == nil, "%s - AfterConnectFunc", testName) + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { if expected.TLSConfig != nil { assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) diff --git a/pgconn.go b/pgconn.go index ac48f870..94397759 100644 --- a/pgconn.go +++ b/pgconn.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "encoding/hex" "errors" + "fmt" "io" "net" "strconv" @@ -183,11 +184,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, err } case *pgproto3.ReadyForQuery: - if config.AcceptConnFunc == nil || config.AcceptConnFunc(pgConn) { - return pgConn, nil + if config.AfterConnectFunc != nil { + err := config.AfterConnectFunc(pgConn) + if err != nil { + pgConn.NetConn.Close() + return nil, fmt.Errorf("AfterConnectFunc: %v", err) + } } - pgConn.NetConn.Close() - return nil, errors.New("AcceptConnFunc rejected connection") + return pgConn, nil case *pgproto3.ParameterStatus: // handled by ReceiveMessage case *pgproto3.ErrorResponse: diff --git a/pgconn_test.go b/pgconn_test.go index ad06ae7b..0dccc99f 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -9,6 +9,7 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgconn" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -188,7 +189,7 @@ func TestConnectWithFallback(t *testing.T) { closeConn(t, conn) } -func TestConnectWithAcceptConnFunc(t *testing.T) { +func TestConnectWithAfterConnectFunc(t *testing.T) { config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) @@ -199,9 +200,12 @@ func TestConnectWithAcceptConnFunc(t *testing.T) { } acceptConnCount := 0 - config.AcceptConnFunc = func(conn *pgconn.PgConn) bool { + config.AfterConnectFunc = func(conn *pgconn.PgConn) error { acceptConnCount += 1 - return acceptConnCount > 1 + if acceptConnCount < 2 { + return errors.New("reject first conn") + } + return nil } // Append current primary config to fallbacks @@ -222,6 +226,19 @@ func TestConnectWithAcceptConnFunc(t *testing.T) { assert.True(t, acceptConnCount > 1) } +func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" + + conn, err := pgconn.ConnectConfig(context.Background(), config) + if !assert.NotNil(t, err) { + conn.Close() + } +} + func TestSimple(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) From b419493e5ca130ab9c4eae74eb64c23467d73843 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 13:32:26 -0600 Subject: [PATCH 0149/1158] Add pgconn.Exec --- config.go | 2 +- pgconn.go | 109 +++++++++++++++++++++++++++++++++++++++++-------- pgconn_test.go | 76 +++++++++++++++------------------- 3 files changed, 128 insertions(+), 59 deletions(-) diff --git a/config.go b/config.go index 38144be7..a446a67e 100644 --- a/config.go +++ b/config.go @@ -483,7 +483,7 @@ func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error { return errors.New("show transaction_read_only failed") } - if string(result.Value(0)) == "on" { + if string(result.Values()[0]) == "on" { return errors.New("read only connection") } diff --git a/pgconn.go b/pgconn.go index 94397759..c243d2f6 100644 --- a/pgconn.go +++ b/pgconn.go @@ -340,10 +340,9 @@ func (ct CommandTag) RowsAffected() int64 { return n } -// SendExec enqueues the execution of sql via the PostgreSQL simple query -// protocol. sql may contain multipe queries. Multiple queries will be processed -// within a single transation. It is only sent to the PostgreSQL server when -// Flush is called. +// SendExec enqueues the execution of sql via the PostgreSQL simple query protocol. sql may contain multiple queries. +// Execution is implicitly wrapped in a transactions unless a transaction is already in progress or sql contains +// transaction control statements. It is only sent to the PostgreSQL server when Flush is called. func (pgConn *PgConn) SendExec(sql string) { pgConn.batchBuf = appendQuery(pgConn.batchBuf, sql) pgConn.batchCount += 1 @@ -359,30 +358,51 @@ func appendQuery(buf []byte, query string) []byte { } type PgResultReader struct { - pgConn *PgConn - fieldDescriptions []pgproto3.FieldDescription - rowValues [][]byte - commandTag CommandTag - err error - complete bool + pgConn *PgConn + fieldDescriptions []pgproto3.FieldDescription + rowValues [][]byte + commandTag CommandTag + err error + complete bool + preloadedRowValues bool } // GetResult returns a PgResultReader for the next result. If all results are // consumed it returns nil. If an error occurs it will be reported on the // returned PgResultReader. func (pgConn *PgConn) GetResult() *PgResultReader { - if pgConn.pendingReadyForQueryCount == 0 { - return nil + for pgConn.pendingReadyForQueryCount > 0 { + msg, err := pgConn.ReceiveMessage() + if err != nil { + return &PgResultReader{pgConn: pgConn, err: err, complete: true} + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + return &PgResultReader{pgConn: pgConn, fieldDescriptions: msg.Fields} + case *pgproto3.DataRow: + return &PgResultReader{pgConn: pgConn, rowValues: msg.Values, preloadedRowValues: true} + case *pgproto3.CommandComplete: + return &PgResultReader{pgConn: pgConn, commandTag: CommandTag(msg.CommandTag), complete: true} + case *pgproto3.ErrorResponse: + return &PgResultReader{pgConn: pgConn, err: errorResponseToPgError(msg), complete: true} + } } - return &PgResultReader{pgConn: pgConn} + return nil } -func (rr *PgResultReader) NextRow() (present bool) { +// NextRow returns advances the PgResultReader to the next row and returns true if a row is available. +func (rr *PgResultReader) NextRow() bool { if rr.complete { return false } + if rr.preloadedRowValues { + rr.preloadedRowValues = false + return true + } + for { msg, err := rr.pgConn.ReceiveMessage() if err != nil { @@ -396,6 +416,7 @@ func (rr *PgResultReader) NextRow() (present bool) { rr.rowValues = msg.Values return true case *pgproto3.CommandComplete: + rr.rowValues = nil rr.commandTag = CommandTag(msg.CommandTag) rr.complete = true return false @@ -407,8 +428,11 @@ func (rr *PgResultReader) NextRow() (present bool) { } } -func (rr *PgResultReader) Value(c int) []byte { - return rr.rowValues[c] +// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only +// valid until the next NextRow call or the PgResultReader is closed. However, the underlying byte data is safe to +// retain a reference to and mutate. +func (rr *PgResultReader) Values() [][]byte { + return rr.rowValues } // Close consumes any remaining result data and returns the command tag or @@ -418,6 +442,8 @@ func (rr *PgResultReader) Close() (CommandTag, error) { return rr.commandTag, rr.err } + rr.rowValues = nil + for { msg, err := rr.pgConn.ReceiveMessage() if err != nil { @@ -464,6 +490,57 @@ func (pgConn *PgConn) resetBatch() { } } +type PgResult struct { + Rows [][][]byte + CommandTag CommandTag +} + +// Exec executes sql via the PostgreSQL simple query protocol, buffers the entire result, and returns it. sql may +// contain multiple queries, but only the last results will be returned. Execution is implicitly wrapped in a +// transactions unless a transaction is already in progress or sql contains transaction control statements. +// +// Exec must not be called when there are pending results from previous Send* methods (e.g. SendExec). +func (pgConn *PgConn) Exec(sql string) (*PgResult, error) { + if pgConn.batchCount != 0 { + return nil, errors.New("unflushed previous sends") + } + if pgConn.pendingReadyForQueryCount != 0 { + return nil, errors.New("unread previous results") + } + + pgConn.SendExec(sql) + err := pgConn.Flush() + if err != nil { + return nil, err + } + + var result *PgResult + + for resultReader := pgConn.GetResult(); resultReader != nil; resultReader = pgConn.GetResult() { + rows := [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + + commandTag, err := resultReader.Close() + if err != nil { + return nil, err + } + + result = &PgResult{ + Rows: rows, + CommandTag: commandTag, + } + } + if result == nil { + return nil, errors.New("unexpected missing result") + } + + return result, nil +} + func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError { return PgError{ Severity: msg.Severity, diff --git a/pgconn_test.go b/pgconn_test.go index 0dccc99f..f3f22d42 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -126,36 +126,15 @@ func TestConnectWithRuntimeParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, conn) - // TODO - refactor these selects once there are higher level query functions - - conn.SendExec("show application_name") - conn.SendExec("show search_path") - err = conn.Flush() + result, err := conn.Exec("show application_name") require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result := conn.GetResult() - require.NotNil(t, result) - - rowFound := result.NextRow() - assert.True(t, rowFound) - if rowFound { - assert.Equal(t, "pgxtest", string(result.Value(0))) - } - - _, err = result.Close() - assert.Nil(t, err) - - result = conn.GetResult() - require.NotNil(t, result) - - rowFound = result.NextRow() - assert.True(t, rowFound) - if rowFound { - assert.Equal(t, "myschema", string(result.Value(0))) - } - - _, err = result.Close() - assert.Nil(t, err) + result, err = conn.Exec("show search_path") + require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) } func TestConnectWithFallback(t *testing.T) { @@ -239,26 +218,39 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } } -func TestSimple(t *testing.T) { +func TestExec(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) + defer closeConn(t, pgConn) - pgConn.SendExec("select current_database()") - err = pgConn.Flush() + result, err := pgConn.Exec("select current_database()") require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) +} - result := pgConn.GetResult() - require.NotNil(t, result) +func TestExecMultipleQueries(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) - rowFound := result.NextRow() - assert.True(t, rowFound) - if rowFound { - assert.Equal(t, "pgx_test", string(result.Value(0))) - } + result, err := pgConn.Exec("select current_database(); select 1") + require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "1", string(result.Rows[0][0])) +} - _, err = result.Close() - assert.Nil(t, err) +func TestExecMultipleQueriesError(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) - err = pgConn.Close() - assert.Nil(t, err) + result, err := pgConn.Exec("select 1; select 1/0; select 1") + require.NotNil(t, err) + require.Nil(t, result) + if pgErr, ok := err.(pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } } From 4e12c08b04a441cfbedf8c0f7dcaac9414ca9f26 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 14:14:40 -0600 Subject: [PATCH 0150/1158] Use buffered exec --- config.go | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/config.go b/config.go index a446a67e..4d8bee4c 100644 --- a/config.go +++ b/config.go @@ -467,27 +467,14 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { // AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible // target_session_attrs=read-write. func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error { - pgConn.SendExec("show transaction_read_only") - err := pgConn.Flush() + result, err := pgConn.Exec("show transaction_read_only") if err != nil { return err } - result := pgConn.GetResult() - if err != nil { - return err - } - - rowFound := result.NextRow() - if !rowFound { - return errors.New("show transaction_read_only failed") - } - - if string(result.Values()[0]) == "on" { + if string(result.Rows[0][0]) == "on" { return errors.New("read only connection") } - _, err = result.Close() - - return err + return nil } From 4ee6fef45286e5a0056b0d07a1a388be151b92cd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 17:17:11 -0600 Subject: [PATCH 0151/1158] Add context to potentially blocking methods --- config.go | 7 ++- helper_test.go | 6 +- pgconn.go | 165 ++++++++++++++++++++++++++++++++++++++++--------- pgconn_test.go | 67 +++++++++++++++----- 4 files changed, 195 insertions(+), 50 deletions(-) diff --git a/config.go b/config.go index 4d8bee4c..d8872f66 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,7 @@ package pgconn import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -20,7 +21,7 @@ import ( "github.com/pkg/errors" ) -type AfterConnectFunc func(pgconn *PgConn) error +type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { @@ -466,8 +467,8 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { // AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible // target_session_attrs=read-write. -func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error { - result, err := pgConn.Exec("show transaction_read_only") +func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { + result, err := pgConn.Exec(ctx, "show transaction_read_only") if err != nil { return err } diff --git a/helper_test.go b/helper_test.go index e6a7c73b..8e7ca92f 100644 --- a/helper_test.go +++ b/helper_test.go @@ -1,7 +1,9 @@ package pgconn_test import ( + "context" "testing" + "time" "github.com/jackc/pgx/pgconn" @@ -9,5 +11,7 @@ import ( ) func closeConn(t testing.TB, conn *pgconn.PgConn) { - require.Nil(t, conn.Close()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.Nil(t, conn.Close(ctx)) } diff --git a/pgconn.go b/pgconn.go index c243d2f6..311b06a3 100644 --- a/pgconn.go +++ b/pgconn.go @@ -12,6 +12,7 @@ import ( "net" "strconv" "strings" + "time" "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" @@ -19,6 +20,8 @@ import ( const batchBufferSize = 4096 +var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) + // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. @@ -185,7 +188,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } case *pgproto3.ReadyForQuery: if config.AfterConnectFunc != nil { - err := config.AfterConnectFunc(pgConn) + err := config.AfterConnectFunc(ctx, pgConn) if err != nil { pgConn.NetConn.Close() return nil, fmt.Errorf("AfterConnectFunc: %v", err) @@ -296,24 +299,28 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } -// Close closes a connection. It is safe to call Close on a already closed -// connection. -func (pgConn *PgConn) Close() error { +// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by +// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The +// underlying net.Conn.Close() will always be called regardless of any other errors. +func (pgConn *PgConn) Close(ctx context.Context) error { if pgConn.closed { return nil } pgConn.closed = true + defer pgConn.NetConn.Close() + + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + defer cleanupContext() + _, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { - pgConn.NetConn.Close() - return err + return preferContextOverNetTimeoutError(ctx, err) } _, err = pgConn.NetConn.Read(make([]byte, 1)) if err != io.EOF { - pgConn.NetConn.Close() - return err + return preferContextOverNetTimeoutError(ctx, err) } return pgConn.NetConn.Close() @@ -365,30 +372,38 @@ type PgResultReader struct { err error complete bool preloadedRowValues bool + ctx context.Context + cleanupContext func() } // GetResult returns a PgResultReader for the next result. If all results are // consumed it returns nil. If an error occurs it will be reported on the // returned PgResultReader. -func (pgConn *PgConn) GetResult() *PgResultReader { +func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + for pgConn.pendingReadyForQueryCount > 0 { msg, err := pgConn.ReceiveMessage() if err != nil { - return &PgResultReader{pgConn: pgConn, err: err, complete: true} + cleanupContext() + return &PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true} } switch msg := msg.(type) { case *pgproto3.RowDescription: - return &PgResultReader{pgConn: pgConn, fieldDescriptions: msg.Fields} + return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields} case *pgproto3.DataRow: - return &PgResultReader{pgConn: pgConn, rowValues: msg.Values, preloadedRowValues: true} + return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true} case *pgproto3.CommandComplete: - return &PgResultReader{pgConn: pgConn, commandTag: CommandTag(msg.CommandTag), complete: true} + cleanupContext() + return &PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} case *pgproto3.ErrorResponse: - return &PgResultReader{pgConn: pgConn, err: errorResponseToPgError(msg), complete: true} + cleanupContext() + return &PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} } } + cleanupContext() return nil } @@ -406,6 +421,8 @@ func (rr *PgResultReader) NextRow() bool { for { msg, err := rr.pgConn.ReceiveMessage() if err != nil { + rr.err = preferContextOverNetTimeoutError(rr.ctx, err) + rr.close() return false } @@ -416,13 +433,12 @@ func (rr *PgResultReader) NextRow() bool { rr.rowValues = msg.Values return true case *pgproto3.CommandComplete: - rr.rowValues = nil rr.commandTag = CommandTag(msg.CommandTag) - rr.complete = true + rr.close() return false case *pgproto3.ErrorResponse: rr.err = errorResponseToPgError(msg) - rr.complete = true + rr.close() return false } } @@ -441,46 +457,137 @@ func (rr *PgResultReader) Close() (CommandTag, error) { if rr.complete { return rr.commandTag, rr.err } - - rr.rowValues = nil + defer rr.close() for { msg, err := rr.pgConn.ReceiveMessage() if err != nil { - rr.err = err - rr.complete = true + rr.err = preferContextOverNetTimeoutError(rr.ctx, err) return rr.commandTag, rr.err } switch msg := msg.(type) { case *pgproto3.CommandComplete: rr.commandTag = CommandTag(msg.CommandTag) - rr.complete = true return rr.commandTag, rr.err case *pgproto3.ErrorResponse: rr.err = errorResponseToPgError(msg) - rr.complete = true return rr.commandTag, rr.err } } } +func (rr *PgResultReader) close() { + if rr.complete { + return + } + + rr.cleanupContext() + rr.rowValues = nil + rr.complete = true +} + // Flush sends the enqueued execs to the server. -func (pgConn *PgConn) Flush() error { +func (pgConn *PgConn) Flush(ctx context.Context) error { defer pgConn.resetBatch() + cleanup := contextDoneToConnDeadline(ctx, pgConn.NetConn) + defer cleanup() + n, err := pgConn.NetConn.Write(pgConn.batchBuf) if err != nil { if n > 0 { - // TODO - kill connection - we sent a partial message + // Close connection because cannot recover from partially sent message. + pgConn.NetConn.Close() + pgConn.closed = true } - return err + return preferContextOverNetTimeoutError(ctx, err) } pgConn.pendingReadyForQueryCount += pgConn.batchCount return nil } +// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from +// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to +// call multiple times. +func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) { + if ctx.Done() != nil { + deadlineWasSet := false + doneChan := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + conn.SetDeadline(deadlineTime) + deadlineWasSet = true + <-doneChan + // TODO + case <-doneChan: + } + }() + + finished := false + return func() { + if !finished { + doneChan <- struct{}{} + if deadlineWasSet { + conn.SetDeadline(time.Time{}) + } + finished = true + } + } + } + + return func() {} +} + +// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == +// true. Otherwise returns err. +func preferContextOverNetTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { + return ctx.Err() + } + return err +} + +// RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. If recovery is +// successful true is returned. If recovery is not successful the connection is closed and false it returned. Recovery +// should usually be possible except in the case of a partial write. This must be called after any context cancellation. +// +// As RecoverFromTimeout may need to read and ignored data already sent from the server, it potentially can block +// indefinitely. Use ctx to guard against this. +func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { + if pgConn.closed { + return false + } + pgConn.resetBatch() + + pgConn.NetConn.SetDeadline(time.Time{}) + + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + defer cleanupContext() + + for pgConn.pendingReadyForQueryCount > 0 { + _, err := pgConn.ReceiveMessage() + if err != nil { + preferContextOverNetTimeoutError(ctx, err) + pgConn.Close(context.Background()) + return false + } + } + + result, err := pgConn.Exec( + context.Background(), // do not use ctx again because deadline goroutine already started above + "select 'RecoverFromTimeout'", + ) + if err != nil || len(result.Rows) != 1 || len(result.Rows[0]) != 1 || string(result.Rows[0][0]) != "RecoverFromTimeout" { + pgConn.Close(context.Background()) + return false + } + + return true +} + func (pgConn *PgConn) resetBatch() { pgConn.batchCount = 0 if len(pgConn.batchBuf) > batchBufferSize { @@ -500,7 +607,7 @@ type PgResult struct { // transactions unless a transaction is already in progress or sql contains transaction control statements. // // Exec must not be called when there are pending results from previous Send* methods (e.g. SendExec). -func (pgConn *PgConn) Exec(sql string) (*PgResult, error) { +func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } @@ -509,14 +616,14 @@ func (pgConn *PgConn) Exec(sql string) (*PgResult, error) { } pgConn.SendExec(sql) - err := pgConn.Flush() + err := pgConn.Flush(ctx) if err != nil { return nil, err } var result *PgResult - for resultReader := pgConn.GetResult(); resultReader != nil; resultReader = pgConn.GetResult() { + for resultReader := pgConn.GetResult(ctx); resultReader != nil; resultReader = pgConn.GetResult(ctx) { rows := [][][]byte{} for resultReader.NextRow() { row := make([][]byte, len(resultReader.Values())) diff --git a/pgconn_test.go b/pgconn_test.go index f3f22d42..98fd198e 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -6,6 +6,7 @@ import ( "net" "os" "testing" + "time" "github.com/jackc/pgx" "github.com/jackc/pgx/pgconn" @@ -36,8 +37,7 @@ func TestConnect(t *testing.T) { conn, err := pgconn.Connect(context.Background(), connString) require.Nil(t, err) - err = conn.Close() - require.Nil(t, err) + closeConn(t, conn) }) } } @@ -57,8 +57,7 @@ func TestConnectTLS(t *testing.T) { t.Error("not a TLS connection") } - err = conn.Close() - require.Nil(t, err) + closeConn(t, conn) } func TestConnectInvalidUser(t *testing.T) { @@ -74,7 +73,7 @@ func TestConnectInvalidUser(t *testing.T) { conn, err := pgconn.ConnectConfig(context.Background(), config) if err == nil { - conn.Close() + conn.Close(context.Background()) t.Fatal("expected err but got none") } pgErr, ok := err.(pgx.PgError) @@ -92,7 +91,7 @@ func TestConnectWithConnectionRefused(t *testing.T) { // Presumably nothing is listening on 127.0.0.1:1 conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") if err == nil { - conn.Close() + conn.Close(context.Background()) t.Fatal("Expected error establishing connection to bad port") } } @@ -110,7 +109,7 @@ func TestConnectCustomDialer(t *testing.T) { conn, err := pgconn.ConnectConfig(context.Background(), config) require.Nil(t, err) require.True(t, dialed) - conn.Close() + closeConn(t, conn) } func TestConnectWithRuntimeParams(t *testing.T) { @@ -126,12 +125,12 @@ func TestConnectWithRuntimeParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, conn) - result, err := conn.Exec("show application_name") + result, err := conn.Exec(context.Background(), "show application_name") require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result, err = conn.Exec("show search_path") + result, err = conn.Exec(context.Background(), "show search_path") require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "myschema", string(result.Rows[0][0])) @@ -179,7 +178,7 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { } acceptConnCount := 0 - config.AfterConnectFunc = func(conn *pgconn.PgConn) error { + config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error { acceptConnCount += 1 if acceptConnCount < 2 { return errors.New("reject first conn") @@ -214,38 +213,38 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { conn, err := pgconn.ConnectConfig(context.Background(), config) if !assert.NotNil(t, err) { - conn.Close() + conn.Close(context.Background()) } } -func TestExec(t *testing.T) { +func TestConnExec(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec("select current_database()") + result, err := pgConn.Exec(context.Background(), "select current_database()") require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) } -func TestExecMultipleQueries(t *testing.T) { +func TestConnExecMultipleQueries(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec("select current_database(); select 1") + result, err := pgConn.Exec(context.Background(), "select current_database(); select 1") require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "1", string(result.Rows[0][0])) } -func TestExecMultipleQueriesError(t *testing.T) { +func TestConnExecMultipleQueriesError(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec("select 1; select 1/0; select 1") + result, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1") require.NotNil(t, err) require.Nil(t, result) if pgErr, ok := err.(pgconn.PgError); ok { @@ -254,3 +253,37 @@ func TestExecMultipleQueriesError(t *testing.T) { t.Errorf("unexpected error: %v", err) } } + +func TestConnExecContextCanceled(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") + require.Nil(t, result) + assert.Equal(t, context.DeadlineExceeded, err) +} + +func TestConnRecoverFromTimeout(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") + cancel() + require.Nil(t, result) + assert.Equal(t, context.DeadlineExceeded, err) + + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + if assert.True(t, pgConn.RecoverFromTimeout(ctx)) { + result, err := pgConn.Exec(ctx, "select 1") + require.Nil(t, err) + assert.Len(t, result.Rows, 1) + assert.Len(t, result.Rows[0], 1) + assert.Equal(t, "1", string(result.Rows[0][0])) + } + cancel() +} From 53175a7badc5a1a035517900f0d42a323d03f04b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 17:32:04 -0600 Subject: [PATCH 0152/1158] Add cancel request to PgConn RecoverFromTimeout automatically tries to cancel in progress requests. --- pgconn.go | 41 +++++++++++++++++++++++++++++++++++++++++ pgconn_test.go | 22 ++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/pgconn.go b/pgconn.go index 311b06a3..a7c4eea3 100644 --- a/pgconn.go +++ b/pgconn.go @@ -562,8 +562,14 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { } pgConn.resetBatch() + // Clear any existing timeout pgConn.NetConn.SetDeadline(time.Time{}) + // Try to cancel any in-progress requests + for i := 0; i < int(pgConn.pendingReadyForQueryCount); i++ { + pgConn.CancelRequest(ctx) + } + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) defer cleanupContext() @@ -669,3 +675,38 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError { Routine: msg.Routine, } } + +// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel +// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there +// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 +func (pgConn *PgConn) CancelRequest(ctx context.Context) error { + // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing + // the connection config. This is important in high availability configurations where fallback connections may be + // specified or DNS may be used to load balance. + serverAddr := pgConn.NetConn.RemoteAddr() + cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + if err != nil { + return err + } + defer cancelConn.Close() + + cleanupContext := contextDoneToConnDeadline(ctx, cancelConn) + defer cleanupContext() + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.PID)) + binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.SecretKey)) + _, err = cancelConn.Write(buf) + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + _, err = cancelConn.Read(buf) + if err != io.EOF { + return fmt.Errorf("Server failed to close connection after cancel query request: %v", preferContextOverNetTimeoutError(ctx, err)) + } + + return nil +} diff --git a/pgconn_test.go b/pgconn_test.go index 98fd198e..9873013c 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -264,6 +264,8 @@ func TestConnExecContextCanceled(t *testing.T) { result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") require.Nil(t, result) assert.Equal(t, context.DeadlineExceeded, err) + + assert.True(t, pgConn.RecoverFromTimeout(context.Background())) } func TestConnRecoverFromTimeout(t *testing.T) { @@ -287,3 +289,23 @@ func TestConnRecoverFromTimeout(t *testing.T) { } cancel() } + +func TestConnCancelQuery(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + pgConn.SendExec("select current_database(), pg_sleep(5)") + err = pgConn.Flush(context.Background()) + require.Nil(t, err) + + err = pgConn.CancelRequest(context.Background()) + require.Nil(t, err) + + _, err = pgConn.GetResult(context.Background()).Close() + if err, ok := err.(pgconn.PgError); ok { + assert.Equal(t, "57014", err.Code) + } else { + t.Errorf("expected pgconn.PgError got %v", err) + } +} From bcc3da490cd2c06889c03f18c0a0e41eea51e45d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 17:34:44 -0600 Subject: [PATCH 0153/1158] Run tests in parallel --- config_test.go | 2 ++ pgconn_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/config_test.go b/config_test.go index 36f3fee2..e7a5bb44 100644 --- a/config_test.go +++ b/config_test.go @@ -533,6 +533,8 @@ func TestParseConfigEnvLibpq(t *testing.T) { } func TestParseConfigReadsPgPassfile(t *testing.T) { + t.Parallel() + tf, err := ioutil.TempFile("", "") require.Nil(t, err) diff --git a/pgconn_test.go b/pgconn_test.go index 9873013c..741c1b4b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -29,6 +29,8 @@ func TestConnect(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() + connString := os.Getenv(tt.env) if connString == "" { t.Skipf("Skipping due to missing environment variable %v", tt.env) @@ -45,6 +47,8 @@ func TestConnect(t *testing.T) { // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure // connection. func TestConnectTLS(t *testing.T) { + t.Parallel() + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") if connString == "" { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") @@ -61,6 +65,8 @@ func TestConnectTLS(t *testing.T) { } func TestConnectInvalidUser(t *testing.T) { + t.Parallel() + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") if connString == "" { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") @@ -97,6 +103,8 @@ func TestConnectWithConnectionRefused(t *testing.T) { } func TestConnectCustomDialer(t *testing.T) { + t.Parallel() + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) @@ -113,6 +121,8 @@ func TestConnectCustomDialer(t *testing.T) { } func TestConnectWithRuntimeParams(t *testing.T) { + t.Parallel() + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) @@ -137,6 +147,8 @@ func TestConnectWithRuntimeParams(t *testing.T) { } func TestConnectWithFallback(t *testing.T) { + t.Parallel() + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) @@ -168,6 +180,8 @@ func TestConnectWithFallback(t *testing.T) { } func TestConnectWithAfterConnectFunc(t *testing.T) { + t.Parallel() + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) @@ -205,6 +219,8 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { } func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { + t.Parallel() + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) @@ -218,6 +234,8 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } func TestConnExec(t *testing.T) { + t.Parallel() + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) @@ -229,6 +247,8 @@ func TestConnExec(t *testing.T) { } func TestConnExecMultipleQueries(t *testing.T) { + t.Parallel() + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) @@ -240,6 +260,8 @@ func TestConnExecMultipleQueries(t *testing.T) { } func TestConnExecMultipleQueriesError(t *testing.T) { + t.Parallel() + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) @@ -255,6 +277,8 @@ func TestConnExecMultipleQueriesError(t *testing.T) { } func TestConnExecContextCanceled(t *testing.T) { + t.Parallel() + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) @@ -269,6 +293,8 @@ func TestConnExecContextCanceled(t *testing.T) { } func TestConnRecoverFromTimeout(t *testing.T) { + t.Parallel() + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) @@ -291,6 +317,8 @@ func TestConnRecoverFromTimeout(t *testing.T) { } func TestConnCancelQuery(t *testing.T) { + t.Parallel() + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) From 49c9674102c3c151a004f2ef1d54072c9cb8244d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 17:46:56 -0600 Subject: [PATCH 0154/1158] PG error type is *pgconn.PgError --- pgconn.go | 10 +++++----- pgconn_test.go | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pgconn.go b/pgconn.go index a7c4eea3..fef113e0 100644 --- a/pgconn.go +++ b/pgconn.go @@ -45,7 +45,7 @@ type PgError struct { Routine string } -func (pe PgError) Error() string { +func (pe *PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } @@ -118,7 +118,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err pgConn, err = connect(ctx, config, fc) if err == nil { return pgConn, nil - } else if err, ok := err.(PgError); ok { + } else if err, ok := err.(*PgError); ok { return nil, err } } @@ -199,7 +199,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.NetConn.Close() - return nil, PgError{ + return nil, &PgError{ Severity: msg.Severity, Code: msg.Code, Message: msg.Message, @@ -654,8 +654,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { return result, nil } -func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError { - return PgError{ +func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { + return &PgError{ Severity: msg.Severity, Code: msg.Code, Message: msg.Message, diff --git a/pgconn_test.go b/pgconn_test.go index 741c1b4b..e46093b0 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -269,7 +269,7 @@ func TestConnExecMultipleQueriesError(t *testing.T) { result, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1") require.NotNil(t, err) require.Nil(t, result) - if pgErr, ok := err.(pgconn.PgError); ok { + if pgErr, ok := err.(*pgconn.PgError); ok { assert.Equal(t, "22012", pgErr.Code) } else { t.Errorf("unexpected error: %v", err) @@ -331,7 +331,7 @@ func TestConnCancelQuery(t *testing.T) { require.Nil(t, err) _, err = pgConn.GetResult(context.Background()).Close() - if err, ok := err.(pgconn.PgError); ok { + if err, ok := err.(*pgconn.PgError); ok { assert.Equal(t, "57014", err.Code) } else { t.Errorf("expected pgconn.PgError got %v", err) From f5faed65688c703f48a5712b2a41fc7db928fea9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 18:00:08 -0600 Subject: [PATCH 0155/1158] Access underlying net.Conn via method Also remove some dead code. --- pgconn.go | 57 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/pgconn.go b/pgconn.go index fef113e0..776141f9 100644 --- a/pgconn.go +++ b/pgconn.go @@ -58,7 +58,7 @@ var ErrTLSRefused = errors.New("server refused TLS connection") // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { - NetConn net.Conn // the underlying TCP or unix domain socket connection + conn net.Conn // the underlying TCP or unix domain socket connection PID uint32 // backend pid SecretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server @@ -132,7 +132,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) - pgConn.NetConn, err = config.DialFunc(ctx, network, address) + pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { return nil, err } @@ -141,12 +141,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if config.TLSConfig != nil { if err := pgConn.startTLS(config.TLSConfig); err != nil { - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, err } } - pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.NetConn, pgConn.NetConn) + pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.conn, pgConn.conn) if err != nil { return nil, err } @@ -166,8 +166,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig startupMsg.Parameters["database"] = config.Database } - if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { - pgConn.NetConn.Close() + if _, err := pgConn.conn.Write(startupMsg.Encode(nil)); err != nil { + pgConn.conn.Close() return nil, err } @@ -183,14 +183,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.SecretKey = msg.SecretKey case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, err } case *pgproto3.ReadyForQuery: if config.AfterConnectFunc != nil { err := config.AfterConnectFunc(ctx, pgConn) if err != nil { - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, fmt.Errorf("AfterConnectFunc: %v", err) } } @@ -198,7 +198,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.ParameterStatus: // handled by ReceiveMessage case *pgproto3.ErrorResponse: - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, &PgError{ Severity: msg.Severity, Code: msg.Code, @@ -219,20 +219,20 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig Routine: msg.Routine, } default: - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, errors.New("unexpected message") } } } func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { - err = binary.Write(pgConn.NetConn, binary.BigEndian, []int32{8, 80877103}) + err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return } response := make([]byte, 1) - if _, err = io.ReadFull(pgConn.NetConn, response); err != nil { + if _, err = io.ReadFull(pgConn.conn, response); err != nil { return } @@ -240,7 +240,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return ErrTLSRefused } - pgConn.NetConn = tls.Client(pgConn.NetConn, tlsConfig) + pgConn.conn = tls.Client(pgConn.conn, tlsConfig) return nil } @@ -262,7 +262,7 @@ func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} - _, err = pgConn.NetConn.Write(msg.Encode(nil)) + _, err = pgConn.conn.Write(msg.Encode(nil)) return err } @@ -299,6 +299,11 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } +// Conn returns the underlying net.Conn. +func (pgConn *PgConn) Conn() net.Conn { + return pgConn.conn +} + // Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. @@ -308,22 +313,22 @@ func (pgConn *PgConn) Close(ctx context.Context) error { } pgConn.closed = true - defer pgConn.NetConn.Close() + defer pgConn.conn.Close() - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContext() - _, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) + _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { return preferContextOverNetTimeoutError(ctx, err) } - _, err = pgConn.NetConn.Read(make([]byte, 1)) + _, err = pgConn.conn.Read(make([]byte, 1)) if err != io.EOF { return preferContextOverNetTimeoutError(ctx, err) } - return pgConn.NetConn.Close() + return pgConn.conn.Close() } // ParameterStatus returns the value of a parameter reported by the server (e.g. @@ -380,7 +385,7 @@ type PgResultReader struct { // consumed it returns nil. If an error occurs it will be reported on the // returned PgResultReader. func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) for pgConn.pendingReadyForQueryCount > 0 { msg, err := pgConn.ReceiveMessage() @@ -491,14 +496,14 @@ func (rr *PgResultReader) close() { func (pgConn *PgConn) Flush(ctx context.Context) error { defer pgConn.resetBatch() - cleanup := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanup() - n, err := pgConn.NetConn.Write(pgConn.batchBuf) + n, err := pgConn.conn.Write(pgConn.batchBuf) if err != nil { if n > 0 { // Close connection because cannot recover from partially sent message. - pgConn.NetConn.Close() + pgConn.conn.Close() pgConn.closed = true } return preferContextOverNetTimeoutError(ctx, err) @@ -563,14 +568,14 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { pgConn.resetBatch() // Clear any existing timeout - pgConn.NetConn.SetDeadline(time.Time{}) + pgConn.conn.SetDeadline(time.Time{}) // Try to cancel any in-progress requests for i := 0; i < int(pgConn.pendingReadyForQueryCount); i++ { pgConn.CancelRequest(ctx) } - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContext() for pgConn.pendingReadyForQueryCount > 0 { @@ -683,7 +688,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. - serverAddr := pgConn.NetConn.RemoteAddr() + serverAddr := pgConn.conn.RemoteAddr() cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err From 2f156c7add3a4026af92c7a9626ec2ea85e17f61 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 18:03:55 -0600 Subject: [PATCH 0156/1158] Access PID and SecretKey via method --- pgconn.go | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/pgconn.go b/pgconn.go index 776141f9..87ba0096 100644 --- a/pgconn.go +++ b/pgconn.go @@ -59,8 +59,8 @@ var ErrTLSRefused = errors.New("server refused TLS connection") // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection - PID uint32 // backend pid - SecretKey uint32 // key to use to send a cancel query message to the server + pid uint32 // backend pid + secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server TxStatus byte Frontend *pgproto3.Frontend @@ -179,8 +179,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig switch msg := msg.(type) { case *pgproto3.BackendKeyData: - pgConn.PID = msg.ProcessID - pgConn.SecretKey = msg.SecretKey + pgConn.pid = msg.ProcessID + pgConn.secretKey = msg.SecretKey case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { pgConn.conn.Close() @@ -304,6 +304,16 @@ func (pgConn *PgConn) Conn() net.Conn { return pgConn.conn } +// PID returns the backend PID. +func (pgConn *PgConn) PID() uint32 { + return pgConn.pid +} + +// SecretKey returns the backend secret key used to send a cancel query message to the server. +func (pgConn *PgConn) SecretKey() uint32 { + return pgConn.secretKey +} + // Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. @@ -701,8 +711,8 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) binary.BigEndian.PutUint32(buf[4:8], 80877102) - binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.PID)) - binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.SecretKey)) + binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) _, err = cancelConn.Write(buf) if err != nil { return preferContextOverNetTimeoutError(ctx, err) From 650aa7059a3ac45b9e4215ecc8cdb21e35846e8a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 18:45:51 -0600 Subject: [PATCH 0157/1158] Fix broken tests --- pgconn_test.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pgconn_test.go b/pgconn_test.go index e46093b0..05318dac 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/jackc/pgx" "github.com/jackc/pgx/pgconn" "github.com/pkg/errors" @@ -57,7 +56,7 @@ func TestConnectTLS(t *testing.T) { conn, err := pgconn.Connect(context.Background(), connString) require.Nil(t, err) - if _, ok := conn.NetConn.(*tls.Conn); !ok { + if _, ok := conn.Conn().(*tls.Conn); !ok { t.Error("not a TLS connection") } @@ -82,7 +81,7 @@ func TestConnectInvalidUser(t *testing.T) { conn.Close(context.Background()) t.Fatal("expected err but got none") } - pgErr, ok := err.(pgx.PgError) + pgErr, ok := err.(*pgconn.PgError) if !ok { t.Fatalf("Expected to receive a PgError, instead received: %v", err) } From 5f69253174a8f3712c90f5b791a0467a89d347a2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 19:59:32 -0600 Subject: [PATCH 0158/1158] Added ExecParams --- pgconn.go | 171 +++++++++++++++++++++++++++++++++++++++++++++++++ pgconn_test.go | 31 ++++++++- 2 files changed, 201 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 87ba0096..db9c758d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -20,6 +20,12 @@ import ( const batchBufferSize = 4096 +// PostgreSQL extended protocol format codes +const ( + TextFormatCode = 0 + BinaryFormatCode = 1 +) + var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) // PgError represents an error reported by the PostgreSQL server. See @@ -379,6 +385,127 @@ func appendQuery(buf []byte, query string) []byte { return buf } +// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. +func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []byte { + buf = append(buf, 'P') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, name...) + buf = append(buf, 0) + buf = append(buf, query...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(paramOIDs))) + for _, oid := range paramOIDs { + buf = pgio.AppendUint32(buf, oid) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it. +func appendSync(buf []byte) []byte { + buf = append(buf, 'S') + buf = pgio.AppendInt32(buf, 4) + + return buf +} + +// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it. +func appendBind( + buf []byte, + destinationPortal, + preparedStatement string, + paramFormats []int16, + paramValues [][]byte, + resultFormatCodes []int16, +) []byte { + if len(paramFormats) != 0 && len(paramFormats) != len(paramValues) && len(paramFormats) != len(paramValues) { + panic(fmt.Sprintf("len(paramFormats) must be 0, 1, or len(paramValues), received %d", len(paramFormats))) + } + + buf = append(buf, 'B') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, destinationPortal...) + buf = append(buf, 0) + buf = append(buf, preparedStatement...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(paramFormats))) + for _, f := range paramFormats { + buf = pgio.AppendInt16(buf, f) + } + + buf = pgio.AppendInt16(buf, int16(len(paramValues))) + for _, p := range paramValues { + if p == nil { + buf = pgio.AppendInt32(buf, -1) + continue + } + + buf = pgio.AppendInt32(buf, int32(len(p))) + buf = append(buf, p...) + } + + buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes))) + for _, fc := range resultFormatCodes { + buf = pgio.AppendInt16(buf, fc) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it. +func appendExecute(buf []byte, portal string, maxRows uint32) []byte { + buf = append(buf, 'E') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf = append(buf, portal...) + buf = append(buf, 0) + buf = pgio.AppendUint32(buf, maxRows) + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// SendExecParams enqueues the execution of sql via the PostgreSQL extended query protocol. +// +// sql is a SQL command string. It may only contain one query. Parameter substitution is position using $1, $2, $3, etc. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for +// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. +// SendExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text protocol. +// +// Query is only sent to the PostgreSQL server when Flush is called. +func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { + if len(paramValues) > 65535 { + panic(fmt.Sprintf("Number of params 0 and 65535, received %d", len(paramValues))) + } + if len(paramOIDs) != 0 && len(paramOIDs) != len(paramValues) && len(paramOIDs) != len(paramValues) { + panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs))) + } + + pgConn.batchBuf = appendParse(pgConn.batchBuf, "", sql, paramOIDs) + pgConn.batchBuf = appendBind(pgConn.batchBuf, "", "", paramFormats, paramValues, resultFormats) + pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) + pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchCount += 1 +} + type PgResultReader struct { pgConn *PgConn fieldDescriptions []pgproto3.FieldDescription @@ -669,6 +796,50 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { return result, nil } +// ExecParams executes sql via the PostgreSQL extended query protocol, buffers the entire result, and returns it. See +// SendExecParams for parameter descriptions. +// +// ExecParams must not be called when there are pending results from previous Send* methods (e.g. SendExec). +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) (*PgResult, error) { + if pgConn.batchCount != 0 { + return nil, errors.New("unflushed previous sends") + } + if pgConn.pendingReadyForQueryCount != 0 { + return nil, errors.New("unread previous results") + } + + pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats) + err := pgConn.Flush(ctx) + if err != nil { + return nil, err + } + + resultReader := pgConn.GetResult(ctx) + if resultReader == nil { + return nil, errors.New("unexpected missing result") + } + + var result *PgResult + rows := [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + + commandTag, err := resultReader.Close() + if err != nil { + return nil, err + } + + result = &PgResult{ + Rows: rows, + CommandTag: commandTag, + } + + return result, nil +} + func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ Severity: msg.Severity, diff --git a/pgconn_test.go b/pgconn_test.go index 05318dac..fa1ec5fc 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -285,7 +285,36 @@ func TestConnExecContextCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") - require.Nil(t, result) + assert.Nil(t, result) + assert.Equal(t, context.DeadlineExceeded, err) + + assert.True(t, pgConn.RecoverFromTimeout(context.Background())) +} + +func TestConnExecParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + result, err := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "Hello, world", string(result.Rows[0][0])) +} + +func TestConnExecParamsCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result, err := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil) + assert.Nil(t, result) assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgConn.RecoverFromTimeout(context.Background())) From 13323df0dd20310714151bada713d8f168a672df Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 20:08:11 -0600 Subject: [PATCH 0159/1158] Add batched query test --- pgconn_test.go | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index fa1ec5fc..a765dc4c 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -320,6 +320,96 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.True(t, pgConn.RecoverFromTimeout(context.Background())) } +func TestConnBatchedQueries(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + pgConn.SendExec("select 'SendExec 1'") + pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 1")}, nil, nil, nil) + pgConn.SendExec("select 'SendExec 2'") + pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 2")}, nil, nil, nil) + err = pgConn.Flush(context.Background()) + + // "select 'SendExec 1'" + resultReader := pgConn.GetResult(context.Background()) + require.NotNil(t, resultReader) + + rows := [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + require.Len(t, rows, 1) + require.Len(t, rows[0], 1) + assert.Equal(t, "SendExec 1", string(rows[0][0])) + + commandTag, err := resultReader.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) + + // "SendExecParams 1" + resultReader = pgConn.GetResult(context.Background()) + require.NotNil(t, resultReader) + + rows = [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + require.Len(t, rows, 1) + require.Len(t, rows[0], 1) + assert.Equal(t, "SendExecParams 1", string(rows[0][0])) + + commandTag, err = resultReader.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) + + // "SendExec 2" + resultReader = pgConn.GetResult(context.Background()) + require.NotNil(t, resultReader) + + rows = [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + require.Len(t, rows, 1) + require.Len(t, rows[0], 1) + assert.Equal(t, "SendExec 2", string(rows[0][0])) + + commandTag, err = resultReader.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) + + // "SendExecParams 2" + resultReader = pgConn.GetResult(context.Background()) + require.NotNil(t, resultReader) + + rows = [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + require.Len(t, rows, 1) + require.Len(t, rows[0], 1) + assert.Equal(t, "SendExecParams 2", string(rows[0][0])) + + commandTag, err = resultReader.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) + + // Done + resultReader = pgConn.GetResult(context.Background()) + assert.Nil(t, resultReader) +} + func TestConnRecoverFromTimeout(t *testing.T) { t.Parallel() From 54df8c691874b91855671462f081a4dd3ce9df42 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 11:32:56 -0600 Subject: [PATCH 0160/1158] Add ExecPrepared --- pgconn.go | 126 +++++++++++++++++++++++++++++++++++++++++++++++-- pgconn_test.go | 57 ++++++++++++++++++++++ 2 files changed, 180 insertions(+), 3 deletions(-) diff --git a/pgconn.go b/pgconn.go index db9c758d..f2e46539 100644 --- a/pgconn.go +++ b/pgconn.go @@ -387,6 +387,10 @@ func appendQuery(buf []byte, query string) []byte { // appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []byte { + if len(paramOIDs) > 65535 { + panic(fmt.Sprintf("len(paramOIDs) must be between 0 and 65535, received %d", len(paramOIDs))) + } + buf = append(buf, 'P') sp := len(buf) buf = pgio.AppendInt32(buf, -1) @@ -404,6 +408,19 @@ func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []by return buf } +// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it. +func appendDescribe(buf []byte, objectType byte, name string) []byte { + buf = append(buf, 'D') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, objectType) + buf = append(buf, name...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + // appendSync appends a PostgreSQL wire protocol sync message to buf and returns it. func appendSync(buf []byte) []byte { buf = append(buf, 'S') @@ -424,6 +441,9 @@ func appendBind( if len(paramFormats) != 0 && len(paramFormats) != len(paramValues) && len(paramFormats) != len(paramValues) { panic(fmt.Sprintf("len(paramFormats) must be 0, 1, or len(paramValues), received %d", len(paramFormats))) } + if len(paramValues) > 65535 { + panic(fmt.Sprintf("len(paramValues) must be between 0 and 65535, received %d", len(paramValues))) + } buf = append(buf, 'B') sp := len(buf) @@ -492,9 +512,6 @@ func appendExecute(buf []byte, portal string, maxRows uint32) []byte { // // Query is only sent to the PostgreSQL server when Flush is called. func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - if len(paramValues) > 65535 { - panic(fmt.Sprintf("Number of params 0 and 65535, received %d", len(paramValues))) - } if len(paramOIDs) != 0 && len(paramOIDs) != len(paramValues) && len(paramOIDs) != len(paramValues) { panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs))) } @@ -506,6 +523,25 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs pgConn.batchCount += 1 } +// SendExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text protocol. +// +// Query is only sent to the PostgreSQL server when Flush is called. +func (pgConn *PgConn) SendExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + pgConn.batchBuf = appendBind(pgConn.batchBuf, "", stmtName, paramFormats, paramValues, resultFormats) + pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) + pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchCount += 1 +} + type PgResultReader struct { pgConn *PgConn fieldDescriptions []pgproto3.FieldDescription @@ -840,6 +876,90 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return result, nil } +// ExecPrepared executes a prepared statement via the PostgreSQL extended query protocol, buffers the entire result, and +// returns it. See SendExecPrepared for parameter descriptions. +// +// ExecPrepared must not be called when there are pending results from previous Send* methods (e.g. SendExec). +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) (*PgResult, error) { + if pgConn.batchCount != 0 { + return nil, errors.New("unflushed previous sends") + } + if pgConn.pendingReadyForQueryCount != 0 { + return nil, errors.New("unread previous results") + } + + pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats) + err := pgConn.Flush(ctx) + if err != nil { + return nil, err + } + + resultReader := pgConn.GetResult(ctx) + if resultReader == nil { + return nil, errors.New("unexpected missing result") + } + + var result *PgResult + rows := [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + + commandTag, err := resultReader.Close() + if err != nil { + return nil, err + } + + result = &PgResult{ + Rows: rows, + CommandTag: commandTag, + } + + return result, nil +} + +// Prepare creates a prepared statement. +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) error { + if pgConn.batchCount != 0 { + return errors.New("unflushed previous sends") + } + if pgConn.pendingReadyForQueryCount != 0 { + return errors.New("unread previous results") + } + + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContext() + + pgConn.batchBuf = appendParse(pgConn.batchBuf, name, sql, paramOIDs) + pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', name) + pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchCount += 1 + err := pgConn.Flush(context.Background()) + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + for pgConn.pendingReadyForQueryCount > 0 { + msg, err := pgConn.ReceiveMessage() + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + // TODO + case *pgproto3.RowDescription: + // TODO + case *pgproto3.ErrorResponse: + return errorResponseToPgError(msg) + } + } + + return nil +} + func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ Severity: msg.Severity, diff --git a/pgconn_test.go b/pgconn_test.go index a765dc4c..35f5b536 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -320,6 +320,41 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.True(t, pgConn.RecoverFromTimeout(context.Background())) } +func TestConnExecPrepared(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.Nil(t, err) + + result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "Hello, world", string(result.Rows[0][0])) +} + +func TestConnExecPreparedCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) + require.Nil(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result, err := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil) + assert.Nil(t, result) + assert.Equal(t, context.DeadlineExceeded, err) + + assert.True(t, pgConn.RecoverFromTimeout(context.Background())) +} + func TestConnBatchedQueries(t *testing.T) { t.Parallel() @@ -327,8 +362,12 @@ func TestConnBatchedQueries(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) + err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.Nil(t, err) + pgConn.SendExec("select 'SendExec 1'") pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 1")}, nil, nil, nil) + pgConn.SendExecPrepared("ps1", [][]byte{[]byte("SendExecPrepared 1")}, nil, nil) pgConn.SendExec("select 'SendExec 2'") pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 2")}, nil, nil, nil) err = pgConn.Flush(context.Background()) @@ -369,6 +408,24 @@ func TestConnBatchedQueries(t *testing.T) { assert.Equal(t, "SELECT 1", string(commandTag)) assert.Nil(t, err) + // "SendExecPrepared 1" + resultReader = pgConn.GetResult(context.Background()) + require.NotNil(t, resultReader) + + rows = [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + require.Len(t, rows, 1) + require.Len(t, rows[0], 1) + assert.Equal(t, "SendExecPrepared 1", string(rows[0][0])) + + commandTag, err = resultReader.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) + // "SendExec 2" resultReader = pgConn.GetResult(context.Background()) require.NotNil(t, resultReader) From 51d654d32a9c4975ff27f499fa5cb1149d750cf0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 11:35:39 -0600 Subject: [PATCH 0161/1158] Format code constants already in pgproto3 --- pgconn.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pgconn.go b/pgconn.go index f2e46539..9aeba757 100644 --- a/pgconn.go +++ b/pgconn.go @@ -20,12 +20,6 @@ import ( const batchBufferSize = 4096 -// PostgreSQL extended protocol format codes -const ( - TextFormatCode = 0 - BinaryFormatCode = 1 -) - var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) // PgError represents an error reported by the PostgreSQL server. See From b793875c1ffbdf077c20d2eb36fe3346ae6d77a4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 13:16:50 -0600 Subject: [PATCH 0162/1158] Extract bufferLastResult Buffered exec methods need to read until pending ready for queries is 0. Factor this common logic out. Add stress test for PgConn. --- pgconn.go | 54 ++++++------------------------------------------------ 1 file changed, 6 insertions(+), 48 deletions(-) diff --git a/pgconn.go b/pgconn.go index 9aeba757..ec5413de 100644 --- a/pgconn.go +++ b/pgconn.go @@ -799,6 +799,10 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { return nil, err } + return pgConn.bufferLastResult(ctx) +} + +func (pgConn *PgConn) bufferLastResult(ctx context.Context) (*PgResult, error) { var result *PgResult for resultReader := pgConn.GetResult(ctx); resultReader != nil; resultReader = pgConn.GetResult(ctx) { @@ -844,30 +848,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return nil, err } - resultReader := pgConn.GetResult(ctx) - if resultReader == nil { - return nil, errors.New("unexpected missing result") - } - - var result *PgResult - rows := [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - - commandTag, err := resultReader.Close() - if err != nil { - return nil, err - } - - result = &PgResult{ - Rows: rows, - CommandTag: commandTag, - } - - return result, nil + return pgConn.bufferLastResult(ctx) } // ExecPrepared executes a prepared statement via the PostgreSQL extended query protocol, buffers the entire result, and @@ -888,30 +869,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return nil, err } - resultReader := pgConn.GetResult(ctx) - if resultReader == nil { - return nil, errors.New("unexpected missing result") - } - - var result *PgResult - rows := [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - - commandTag, err := resultReader.Close() - if err != nil { - return nil, err - } - - result = &PgResult{ - Rows: rows, - CommandTag: commandTag, - } - - return result, nil + return pgConn.bufferLastResult(ctx) } // Prepare creates a prepared statement. From 8df3f2010f3b448bcbf5499e889df94223c8d7fd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 13:47:37 -0600 Subject: [PATCH 0163/1158] Avoid allocating strings in common message types --- pgconn.go | 62 +++++++++++++++++++++++++------------------------------ 1 file changed, 28 insertions(+), 34 deletions(-) diff --git a/pgconn.go b/pgconn.go index ec5413de..df823042 100644 --- a/pgconn.go +++ b/pgconn.go @@ -199,25 +199,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.conn.Close() - return nil, &PgError{ - Severity: msg.Severity, - Code: msg.Code, - Message: msg.Message, - Detail: msg.Detail, - Hint: msg.Hint, - Position: msg.Position, - InternalPosition: msg.InternalPosition, - InternalQuery: msg.InternalQuery, - Where: msg.Where, - SchemaName: msg.SchemaName, - TableName: msg.TableName, - ColumnName: msg.ColumnName, - DataTypeName: msg.DataTypeName, - ConstraintName: msg.ConstraintName, - File: msg.File, - Line: msg.Line, - Routine: msg.Routine, - } + return nil, errorResponseToPgError(msg) default: pgConn.conn.Close() return nil, errors.New("unexpected message") @@ -348,7 +330,7 @@ func (pgConn *PgConn) ParameterStatus(key string) string { } // CommandTag is the result of an Exec function -type CommandTag string +type CommandTag []byte // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. @@ -362,6 +344,10 @@ func (ct CommandTag) RowsAffected() int64 { return n } +func (ct CommandTag) String() string { + return string(ct) +} + // SendExec enqueues the execution of sql via the PostgreSQL simple query protocol. sql may contain multiple queries. // Execution is implicitly wrapped in a transactions unless a transaction is already in progress or sql contains // transaction control statements. It is only sent to the PostgreSQL server when Flush is called. @@ -511,6 +497,7 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs } pgConn.batchBuf = appendParse(pgConn.batchBuf, "", sql, paramOIDs) + pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', "") pgConn.batchBuf = appendBind(pgConn.batchBuf, "", "", paramFormats, paramValues, resultFormats) pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) pgConn.batchBuf = appendSync(pgConn.batchBuf) @@ -530,6 +517,7 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs // // Query is only sent to the PostgreSQL server when Flush is called. func (pgConn *PgConn) SendExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', stmtName) pgConn.batchBuf = appendBind(pgConn.batchBuf, "", stmtName, paramFormats, paramValues, resultFormats) pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) pgConn.batchBuf = appendSync(pgConn.batchBuf) @@ -616,6 +604,12 @@ func (rr *PgResultReader) NextRow() bool { } } +// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until +// the PgResultReader is closed. +func (rr *PgResultReader) FieldDescriptions() []pgproto3.FieldDescription { + return rr.fieldDescriptions +} + // Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only // valid until the next NextRow call or the PgResultReader is closed. However, the underlying byte data is safe to // retain a reference to and mutate. @@ -914,23 +908,23 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ - Severity: msg.Severity, - Code: msg.Code, - Message: msg.Message, - Detail: msg.Detail, - Hint: msg.Hint, + Severity: string(msg.Severity), + Code: string(msg.Code), + Message: string(msg.Message), + Detail: string(msg.Detail), + Hint: string(msg.Hint), Position: msg.Position, InternalPosition: msg.InternalPosition, - InternalQuery: msg.InternalQuery, - Where: msg.Where, - SchemaName: msg.SchemaName, - TableName: msg.TableName, - ColumnName: msg.ColumnName, - DataTypeName: msg.DataTypeName, - ConstraintName: msg.ConstraintName, - File: msg.File, + InternalQuery: string(msg.InternalQuery), + Where: string(msg.Where), + SchemaName: string(msg.SchemaName), + TableName: string(msg.TableName), + ColumnName: string(msg.ColumnName), + DataTypeName: string(msg.DataTypeName), + ConstraintName: string(msg.ConstraintName), + File: string(msg.File), Line: msg.Line, - Routine: msg.Routine, + Routine: string(msg.Routine), } } From f225b3d4a1d815e81ad1c37cc95e92be4dfad253 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 13:47:37 -0600 Subject: [PATCH 0164/1158] Avoid allocating strings in common message types --- command_complete.go | 6 +++--- row_description.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/command_complete.go b/command_complete.go index 85848532..ba5a6a63 100644 --- a/command_complete.go +++ b/command_complete.go @@ -8,7 +8,7 @@ import ( ) type CommandComplete struct { - CommandTag string + CommandTag []byte } func (*CommandComplete) Backend() {} @@ -19,7 +19,7 @@ func (dst *CommandComplete) Decode(src []byte) error { return &invalidMessageFormatErr{messageType: "CommandComplete"} } - dst.CommandTag = string(src[:idx]) + dst.CommandTag = src[:idx] return nil } @@ -43,6 +43,6 @@ func (src *CommandComplete) MarshalJSON() ([]byte, error) { CommandTag string }{ Type: "CommandComplete", - CommandTag: src.CommandTag, + CommandTag: string(src.CommandTag), }) } diff --git a/row_description.go b/row_description.go index 3c5a6faa..eb504c60 100644 --- a/row_description.go +++ b/row_description.go @@ -14,7 +14,7 @@ const ( ) type FieldDescription struct { - Name string + Name []byte TableOID uint32 TableAttributeNumber uint16 DataTypeOID uint32 @@ -45,7 +45,7 @@ func (dst *RowDescription) Decode(src []byte) error { if err != nil { return err } - fd.Name = string(bName[:len(bName)-1]) + fd.Name = bName[:len(bName)-1] // Since buf.Next() doesn't return an error if we hit the end of the buffer // check Len ahead of time From 4f00c6aebdee5322c29a5672eef1fef53d058bdb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 13:49:12 -0600 Subject: [PATCH 0165/1158] Add pgconn stress test --- pgconn_stress_test.go | 199 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 pgconn_stress_test.go diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go new file mode 100644 index 00000000..cc6acab8 --- /dev/null +++ b/pgconn_stress_test.go @@ -0,0 +1,199 @@ +package pgconn_test + +import ( + "context" + "math/rand" + "os" + "strconv" + "testing" + "time" + + "github.com/jackc/pgx/pgconn" + "github.com/pkg/errors" + + "github.com/stretchr/testify/require" +) + +func TestConnStress(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + actionCount := 100 + if s := os.Getenv("PTX_TEST_STRESS_FACTOR"); s != "" { + stressFactor, err := strconv.ParseInt(s, 10, 64) + require.Nil(t, err, "Failed to parse PTX_TEST_STRESS_FACTOR") + actionCount *= int(stressFactor) + } + + setupStressDB(t, pgConn) + + actions := []struct { + name string + fn func(*pgconn.PgConn) error + }{ + {"Exec Select", stressExecSelect}, + {"ExecParams Select", stressExecParamsSelect}, + {"Batch", stressBatch}, + {"ExecCanceled", stressExecSelectCanceled}, + {"ExecParamsCanceled", stressExecParamsSelectCanceled}, + {"BatchCanceled", stressBatchCanceled}, + } + + for i := 0; i < actionCount; i++ { + action := actions[rand.Intn(len(actions))] + err := action.fn(pgConn) + require.Nilf(t, err, "%d: %s", i, action.name) + } +} + +func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { + _, err := pgConn.Exec(context.Background(), ` + create temporary table widgets( + id serial primary key, + name varchar not null, + description text, + creation_time timestamptz default now() + ); + + insert into widgets(name, description) values + ('Foo', 'bar'), + ('baz', 'Something really long Something really long Something really long Something really long Something really long'), + ('a', 'b')`) + require.Nil(t, err) +} + +func stressExecSelect(pgConn *pgconn.PgConn) error { + _, err := pgConn.Exec(context.Background(), "select * from widgets") + return err +} + +func stressExecParamsSelect(pgConn *pgconn.PgConn) error { + _, err := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + return err +} + +func stressBatch(pgConn *pgconn.PgConn) error { + pgConn.SendExec("select * from widgets") + pgConn.SendExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + err := pgConn.Flush(context.Background()) + if err != nil { + return err + } + + // Query 1 + resultReader := pgConn.GetResult(context.Background()) + if resultReader == nil { + return errors.New("missing resultReader") + } + + for resultReader.NextRow() { + } + _, err = resultReader.Close() + if err != nil { + return err + } + + // Query 2 + resultReader = pgConn.GetResult(context.Background()) + if resultReader == nil { + return errors.New("missing resultReader") + } + + for resultReader.NextRow() { + } + _, err = resultReader.Close() + if err != nil { + return err + } + + // No more + resultReader = pgConn.GetResult(context.Background()) + if resultReader != nil { + return errors.New("unexpected result reader") + } + + return nil +} + +func stressExecSelectCanceled(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + _, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets") + cancel() + if err != context.DeadlineExceeded { + return err + } + + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + recovered := pgConn.RecoverFromTimeout(ctx) + cancel() + if !recovered { + return errors.New("unable to recover from timeout") + } + return nil +} + +func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + _, err := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + cancel() + if err != context.DeadlineExceeded { + return err + } + + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + recovered := pgConn.RecoverFromTimeout(ctx) + cancel() + if !recovered { + return errors.New("unable to recover from timeout") + } + return nil +} + +func stressBatchCanceled(pgConn *pgconn.PgConn) error { + + pgConn.SendExec("select * from widgets") + pgConn.SendExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + err := pgConn.Flush(context.Background()) + if err != nil { + return err + } + + // Query 1 + resultReader := pgConn.GetResult(context.Background()) + if resultReader == nil { + return errors.New("missing resultReader") + } + + for resultReader.NextRow() { + } + _, err = resultReader.Close() + if err != nil { + return err + } + + // Query 2 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + resultReader = pgConn.GetResult(ctx) + cancel() + if resultReader == nil { + return errors.New("missing resultReader") + } + + for resultReader.NextRow() { + } + _, err = resultReader.Close() + if err != context.DeadlineExceeded { + return err + } + + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + recovered := pgConn.RecoverFromTimeout(ctx) + cancel() + if !recovered { + return errors.New("unable to recover from timeout") + } + return nil +} From 7bd9b776cd753ee5b13adc4872858dd2be4fe650 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 13:52:04 -0600 Subject: [PATCH 0166/1158] Remove another allocation --- row_description.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/row_description.go b/row_description.go index eb504c60..c7f3477f 100644 --- a/row_description.go +++ b/row_description.go @@ -37,7 +37,7 @@ func (dst *RowDescription) Decode(src []byte) error { } fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) - *dst = RowDescription{Fields: make([]FieldDescription, fieldCount)} + dst.Fields = dst.Fields[0:0] for i := 0; i < fieldCount; i++ { var fd FieldDescription @@ -60,7 +60,7 @@ func (dst *RowDescription) Decode(src []byte) error { fd.TypeModifier = int32(binary.BigEndian.Uint32(buf.Next(4))) fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2))) - dst.Fields[i] = fd + dst.Fields = append(dst.Fields, fd) } return nil From 9af9f57f1575f0ae0c35a473f9125056acb90cba Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 13:56:09 -0600 Subject: [PATCH 0167/1158] Remove another allocation --- pgconn.go | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/pgconn.go b/pgconn.go index df823042..d9755f6c 100644 --- a/pgconn.go +++ b/pgconn.go @@ -73,6 +73,8 @@ type PgConn struct { pendingReadyForQueryCount int32 closed bool + + resultReader PgResultReader } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -536,9 +538,9 @@ type PgResultReader struct { cleanupContext func() } -// GetResult returns a PgResultReader for the next result. If all results are -// consumed it returns nil. If an error occurs it will be reported on the -// returned PgResultReader. +// GetResult returns a PgResultReader for the next result. If all results are consumed it returns nil. If an error +// occurs it will be reported on the returned PgResultReader. Returned PgResultReader is only valid until next call of +// GetResult. func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) @@ -546,20 +548,25 @@ func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { msg, err := pgConn.ReceiveMessage() if err != nil { cleanupContext() - return &PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true} + pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true} + return &pgConn.resultReader } switch msg := msg.(type) { case *pgproto3.RowDescription: - return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields} + pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields} + return &pgConn.resultReader case *pgproto3.DataRow: - return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true} + pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true} + return &pgConn.resultReader case *pgproto3.CommandComplete: cleanupContext() - return &PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} + pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} + return &pgConn.resultReader case *pgproto3.ErrorResponse: cleanupContext() - return &PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} + pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} + return &pgConn.resultReader } } From 914766af9b867dd946da258bb19269a632ba5296 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 14:10:16 -0600 Subject: [PATCH 0168/1158] Use result readers in next/get fashion --- pgconn.go | 27 ++++++++++++++++----------- pgconn_stress_test.go | 29 ++++++++++++++--------------- pgconn_test.go | 26 +++++++++++++------------- 3 files changed, 43 insertions(+), 39 deletions(-) diff --git a/pgconn.go b/pgconn.go index d9755f6c..8511d5b9 100644 --- a/pgconn.go +++ b/pgconn.go @@ -538,10 +538,9 @@ type PgResultReader struct { cleanupContext func() } -// GetResult returns a PgResultReader for the next result. If all results are consumed it returns nil. If an error -// occurs it will be reported on the returned PgResultReader. Returned PgResultReader is only valid until next call of -// GetResult. -func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { +// NextResult reads until a result is ready to be read or no results are pending. Returns true if a result is available. +// Use ResultReader() to acquire a reader for the result. +func (pgConn *PgConn) NextResult(ctx context.Context) bool { cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) for pgConn.pendingReadyForQueryCount > 0 { @@ -549,29 +548,34 @@ func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { if err != nil { cleanupContext() pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true} - return &pgConn.resultReader + return true } switch msg := msg.(type) { case *pgproto3.RowDescription: pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields} - return &pgConn.resultReader + return true case *pgproto3.DataRow: pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true} - return &pgConn.resultReader + return true case *pgproto3.CommandComplete: cleanupContext() pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} - return &pgConn.resultReader + return true case *pgproto3.ErrorResponse: cleanupContext() pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} - return &pgConn.resultReader + return true } } cleanupContext() - return nil + return false +} + +// ResultReader returns the result reader prepared by next result. It is only valid until the result is completed. +func (pgConn *PgConn) ResultReader() *PgResultReader { + return &pgConn.resultReader } // NextRow returns advances the PgResultReader to the next row and returns true if a row is available. @@ -806,7 +810,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { func (pgConn *PgConn) bufferLastResult(ctx context.Context) (*PgResult, error) { var result *PgResult - for resultReader := pgConn.GetResult(ctx); resultReader != nil; resultReader = pgConn.GetResult(ctx) { + for pgConn.NextResult(ctx) { + resultReader := pgConn.ResultReader() rows := [][][]byte{} for resultReader.NextRow() { row := make([][]byte, len(resultReader.Values())) diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index cc6acab8..9aa94539 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -84,10 +84,10 @@ func stressBatch(pgConn *pgconn.PgConn) error { } // Query 1 - resultReader := pgConn.GetResult(context.Background()) - if resultReader == nil { - return errors.New("missing resultReader") + if !pgConn.NextResult(context.Background()) { + return errors.New("missing result") } + resultReader := pgConn.ResultReader() for resultReader.NextRow() { } @@ -97,10 +97,10 @@ func stressBatch(pgConn *pgconn.PgConn) error { } // Query 2 - resultReader = pgConn.GetResult(context.Background()) - if resultReader == nil { - return errors.New("missing resultReader") + if !pgConn.NextResult(context.Background()) { + return errors.New("missing result") } + resultReader = pgConn.ResultReader() for resultReader.NextRow() { } @@ -110,8 +110,7 @@ func stressBatch(pgConn *pgconn.PgConn) error { } // No more - resultReader = pgConn.GetResult(context.Background()) - if resultReader != nil { + if pgConn.NextResult(context.Background()) { return errors.New("unexpected result reader") } @@ -162,10 +161,10 @@ func stressBatchCanceled(pgConn *pgconn.PgConn) error { } // Query 1 - resultReader := pgConn.GetResult(context.Background()) - if resultReader == nil { - return errors.New("missing resultReader") + if !pgConn.NextResult(context.Background()) { + return errors.New("missing result") } + resultReader := pgConn.ResultReader() for resultReader.NextRow() { } @@ -176,11 +175,11 @@ func stressBatchCanceled(pgConn *pgconn.PgConn) error { // Query 2 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - resultReader = pgConn.GetResult(ctx) - cancel() - if resultReader == nil { - return errors.New("missing resultReader") + if !pgConn.NextResult(ctx) { + return errors.New("missing result") } + cancel() + resultReader = pgConn.ResultReader() for resultReader.NextRow() { } diff --git a/pgconn_test.go b/pgconn_test.go index 35f5b536..8b578d42 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -373,8 +373,8 @@ func TestConnBatchedQueries(t *testing.T) { err = pgConn.Flush(context.Background()) // "select 'SendExec 1'" - resultReader := pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader := pgConn.ResultReader() rows := [][][]byte{} for resultReader.NextRow() { @@ -391,8 +391,8 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // "SendExecParams 1" - resultReader = pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader = pgConn.ResultReader() rows = [][][]byte{} for resultReader.NextRow() { @@ -409,8 +409,8 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // "SendExecPrepared 1" - resultReader = pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader = pgConn.ResultReader() rows = [][][]byte{} for resultReader.NextRow() { @@ -427,8 +427,8 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // "SendExec 2" - resultReader = pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader = pgConn.ResultReader() rows = [][][]byte{} for resultReader.NextRow() { @@ -445,8 +445,8 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // "SendExecParams 2" - resultReader = pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader = pgConn.ResultReader() rows = [][][]byte{} for resultReader.NextRow() { @@ -463,8 +463,7 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // Done - resultReader = pgConn.GetResult(context.Background()) - assert.Nil(t, resultReader) + require.False(t, pgConn.NextResult(context.Background())) } func TestConnRecoverFromTimeout(t *testing.T) { @@ -505,7 +504,8 @@ func TestConnCancelQuery(t *testing.T) { err = pgConn.CancelRequest(context.Background()) require.Nil(t, err) - _, err = pgConn.GetResult(context.Background()).Close() + require.True(t, pgConn.NextResult(context.Background())) + _, err = pgConn.ResultReader().Close() if err, ok := err.(*pgconn.PgError); ok { assert.Equal(t, "57014", err.Code) } else { From bd2a5d97d0f850520f40ef02a5d328f39fd94d7f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 14:10:24 -0600 Subject: [PATCH 0169/1158] Add benchmark to pgconn --- benchmark_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 benchmark_test.go diff --git a/benchmark_test.go b/benchmark_test.go new file mode 100644 index 00000000..da5bd4fc --- /dev/null +++ b/benchmark_test.go @@ -0,0 +1,52 @@ +package pgconn_test + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/pgconn" + "github.com/stretchr/testify/require" +) + +func BenchmarkConnect(b *testing.B) { + benchmarks := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + } + + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + connString := os.Getenv(bm.env) + if connString == "" { + b.Skipf("Skipping due to missing environment variable %v", bm.env) + } + + for i := 0; i < b.N; i++ { + conn, err := pgconn.Connect(context.Background(), connString) + require.Nil(b, err) + + err = conn.Close(context.Background()) + require.Nil(b, err) + } + }) + } +} + +func BenchmarkExecPrepared(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil) + require.Nil(b, err) + } +} From 11964a6ec38e6b899e826ffe721817e634665c81 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 14:17:17 -0600 Subject: [PATCH 0170/1158] Add non-buffered benchmark --- benchmark_test.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/benchmark_test.go b/benchmark_test.go index da5bd4fc..aff21216 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -50,3 +50,24 @@ func BenchmarkExecPrepared(b *testing.B) { require.Nil(b, err) } } + +func BenchmarkSendExecPrepared(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + conn.SendExecPrepared("ps1", nil, nil, nil) + err := conn.Flush(context.Background()) + require.Nil(b, err) + + for conn.NextResult(context.Background()) { + _, err := conn.ResultReader().Close() + require.Nil(b, err) + } + } +} From fdbf2ba728332987aa341ecef2a2b9e0232b5654 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 14:32:42 -0600 Subject: [PATCH 0171/1158] Use pgproto3 instead of custom message encoders --- benchmark_test.go | 13 +++++ pgconn.go | 141 ++++------------------------------------------ 2 files changed, 23 insertions(+), 131 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index aff21216..bdc550cb 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -36,6 +36,19 @@ func BenchmarkConnect(b *testing.B) { } } +func BenchmarkExec(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + require.Nil(b, err) + } +} + func BenchmarkExecPrepared(b *testing.B) { conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(b, err) diff --git a/pgconn.go b/pgconn.go index 8511d5b9..d7a99676 100644 --- a/pgconn.go +++ b/pgconn.go @@ -14,7 +14,6 @@ import ( "strings" "time" - "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" ) @@ -354,127 +353,10 @@ func (ct CommandTag) String() string { // Execution is implicitly wrapped in a transactions unless a transaction is already in progress or sql contains // transaction control statements. It is only sent to the PostgreSQL server when Flush is called. func (pgConn *PgConn) SendExec(sql string) { - pgConn.batchBuf = appendQuery(pgConn.batchBuf, sql) + pgConn.batchBuf = (&pgproto3.Query{String: sql}).Encode(pgConn.batchBuf) pgConn.batchCount += 1 } -// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it. -func appendQuery(buf []byte, query string) []byte { - buf = append(buf, 'Q') - buf = pgio.AppendInt32(buf, int32(len(query)+5)) - buf = append(buf, query...) - buf = append(buf, 0) - return buf -} - -// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. -func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []byte { - if len(paramOIDs) > 65535 { - panic(fmt.Sprintf("len(paramOIDs) must be between 0 and 65535, received %d", len(paramOIDs))) - } - - buf = append(buf, 'P') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, name...) - buf = append(buf, 0) - buf = append(buf, query...) - buf = append(buf, 0) - - buf = pgio.AppendInt16(buf, int16(len(paramOIDs))) - for _, oid := range paramOIDs { - buf = pgio.AppendUint32(buf, oid) - } - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - return buf -} - -// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it. -func appendDescribe(buf []byte, objectType byte, name string) []byte { - buf = append(buf, 'D') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, objectType) - buf = append(buf, name...) - buf = append(buf, 0) - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - return buf -} - -// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it. -func appendSync(buf []byte) []byte { - buf = append(buf, 'S') - buf = pgio.AppendInt32(buf, 4) - - return buf -} - -// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it. -func appendBind( - buf []byte, - destinationPortal, - preparedStatement string, - paramFormats []int16, - paramValues [][]byte, - resultFormatCodes []int16, -) []byte { - if len(paramFormats) != 0 && len(paramFormats) != len(paramValues) && len(paramFormats) != len(paramValues) { - panic(fmt.Sprintf("len(paramFormats) must be 0, 1, or len(paramValues), received %d", len(paramFormats))) - } - if len(paramValues) > 65535 { - panic(fmt.Sprintf("len(paramValues) must be between 0 and 65535, received %d", len(paramValues))) - } - - buf = append(buf, 'B') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, destinationPortal...) - buf = append(buf, 0) - buf = append(buf, preparedStatement...) - buf = append(buf, 0) - - buf = pgio.AppendInt16(buf, int16(len(paramFormats))) - for _, f := range paramFormats { - buf = pgio.AppendInt16(buf, f) - } - - buf = pgio.AppendInt16(buf, int16(len(paramValues))) - for _, p := range paramValues { - if p == nil { - buf = pgio.AppendInt32(buf, -1) - continue - } - - buf = pgio.AppendInt32(buf, int32(len(p))) - buf = append(buf, p...) - } - - buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes))) - for _, fc := range resultFormatCodes { - buf = pgio.AppendInt16(buf, fc) - } - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - return buf -} - -// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it. -func appendExecute(buf []byte, portal string, maxRows uint32) []byte { - buf = append(buf, 'E') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf = append(buf, portal...) - buf = append(buf, 0) - buf = pgio.AppendUint32(buf, maxRows) - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - return buf -} - // SendExecParams enqueues the execution of sql via the PostgreSQL extended query protocol. // // sql is a SQL command string. It may only contain one query. Parameter substitution is position using $1, $2, $3, etc. @@ -498,11 +380,8 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs))) } - pgConn.batchBuf = appendParse(pgConn.batchBuf, "", sql, paramOIDs) - pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', "") - pgConn.batchBuf = appendBind(pgConn.batchBuf, "", "", paramFormats, paramValues, resultFormats) - pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) - pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) + pgConn.SendExecPrepared("", paramValues, paramFormats, resultFormats) pgConn.batchCount += 1 } @@ -519,10 +398,10 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs // // Query is only sent to the PostgreSQL server when Flush is called. func (pgConn *PgConn) SendExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { - pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', stmtName) - pgConn.batchBuf = appendBind(pgConn.batchBuf, "", stmtName, paramFormats, paramValues, resultFormats) - pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) - pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: stmtName}).Encode(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Execute{}).Encode(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) pgConn.batchCount += 1 } @@ -890,9 +769,9 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContext() - pgConn.batchBuf = appendParse(pgConn.batchBuf, name, sql, paramOIDs) - pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', name) - pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) pgConn.batchCount += 1 err := pgConn.Flush(context.Background()) if err != nil { From 7986e2726d1679e78d3cce4c3df19e3f7bd3a866 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 16:55:48 -0600 Subject: [PATCH 0172/1158] pgx uses pgconn.CommandTag instead of own definition --- pgconn_test.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index 8b578d42..8f976d87 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -512,3 +512,26 @@ func TestConnCancelQuery(t *testing.T) { t.Errorf("expected pgconn.PgError got %v", err) } } + +func TestCommandTag(t *testing.T) { + t.Parallel() + + var tests = []struct { + commandTag pgconn.CommandTag + rowsAffected int64 + }{ + {commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5}, + {commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0}, + {commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1}, + {commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0}, + {commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1}, + {commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0}, + {commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0}, + {commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0}, + } + + for i, tt := range tests { + actual := tt.commandTag.RowsAffected() + assert.Equalf(t, tt.rowsAffected, actual, "%d. %v", i, tt.commandTag) + } +} From 547741ae6aac6cec4b7c7a724d746a30c2ade465 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 17:08:56 -0600 Subject: [PATCH 0173/1158] Fix bug with ready for query counter --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index d7a99676..1e70a82b 100644 --- a/pgconn.go +++ b/pgconn.go @@ -382,7 +382,6 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs pgConn.batchBuf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) pgConn.SendExecPrepared("", paramValues, paramFormats, resultFormats) - pgConn.batchCount += 1 } // SendExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. @@ -708,6 +707,7 @@ func (pgConn *PgConn) bufferLastResult(ctx context.Context) (*PgResult, error) { CommandTag: commandTag, } } + if result == nil { return nil, errors.New("unexpected missing result") } From d545e0704e4a57498c357254d3cfb8b528d19697 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 18:03:20 -0600 Subject: [PATCH 0174/1158] Prepare returns description --- benchmark_test.go | 4 ++-- pgconn.go | 53 +++++++++++++++++++++++++++++++++++++++-------- pgconn_test.go | 9 +++++--- 3 files changed, 52 insertions(+), 14 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index bdc550cb..269ac59b 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -54,7 +54,7 @@ func BenchmarkExecPrepared(b *testing.B) { require.Nil(b, err) defer closeConn(b, conn) - err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) b.ResetTimer() @@ -69,7 +69,7 @@ func BenchmarkSendExecPrepared(b *testing.B) { require.Nil(b, err) defer closeConn(b, conn) - err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) b.ResetTimer() diff --git a/pgconn.go b/pgconn.go index 1e70a82b..de7020b2 100644 --- a/pgconn.go +++ b/pgconn.go @@ -757,13 +757,42 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return pgConn.bufferLastResult(ctx) } +type FieldDescription struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + FormatCode int16 +} + +// pgproto3FieldDescriptionToPgconnFieldDescription copies and converts the data from a pgproto3.FieldDescription to a +// FieldDescription. +func pgproto3FieldDescriptionToPgconnFieldDescription(src *pgproto3.FieldDescription, dst *FieldDescription) { + dst.Name = string(src.Name) + dst.TableOID = src.TableOID + dst.TableAttributeNumber = src.TableAttributeNumber + dst.DataTypeOID = src.DataTypeOID + dst.DataTypeSize = src.DataTypeSize + dst.TypeModifier = src.TypeModifier + dst.FormatCode = src.Format +} + +type PreparedStatementDescription struct { + Name string + SQL string + ParamOIDs []uint32 + Fields []FieldDescription +} + // Prepare creates a prepared statement. -func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) error { +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { if pgConn.batchCount != 0 { - return errors.New("unflushed previous sends") + return nil, errors.New("unflushed previous sends") } if pgConn.pendingReadyForQueryCount != 0 { - return errors.New("unread previous results") + return nil, errors.New("unread previous results") } cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) @@ -775,26 +804,32 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ pgConn.batchCount += 1 err := pgConn.Flush(context.Background()) if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } + psd := &PreparedStatementDescription{Name: name, SQL: sql} + for pgConn.pendingReadyForQueryCount > 0 { msg, err := pgConn.ReceiveMessage() if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { case *pgproto3.ParameterDescription: - // TODO + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: - // TODO + psd.Fields = make([]FieldDescription, len(msg.Fields)) + for i := range msg.Fields { + pgproto3FieldDescriptionToPgconnFieldDescription(&msg.Fields[i], &psd.Fields[i]) + } case *pgproto3.ErrorResponse: - return errorResponseToPgError(msg) + return nil, errorResponseToPgError(msg) } } - return nil + return psd, nil } func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { diff --git a/pgconn_test.go b/pgconn_test.go index 8f976d87..ee573d42 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -327,8 +327,11 @@ func TestConnExecPrepared(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) require.Nil(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) require.Nil(t, err) @@ -343,7 +346,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) + _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) require.Nil(t, err) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -362,7 +365,7 @@ func TestConnBatchedQueries(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) require.Nil(t, err) pgConn.SendExec("select 'SendExec 1'") From 6d2fa9c5cf5f09b696faf2597341c132139258fb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 12:28:11 -0600 Subject: [PATCH 0175/1158] Handle empty query response --- pgconn.go | 4 ++++ pgconn_test.go | 13 +++++++++++++ 2 files changed, 17 insertions(+) diff --git a/pgconn.go b/pgconn.go index de7020b2..b3abe8e0 100644 --- a/pgconn.go +++ b/pgconn.go @@ -440,6 +440,10 @@ func (pgConn *PgConn) NextResult(ctx context.Context) bool { cleanupContext() pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} return true + case *pgproto3.EmptyQueryResponse: + cleanupContext() + pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, complete: true} + return true case *pgproto3.ErrorResponse: cleanupContext() pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} diff --git a/pgconn_test.go b/pgconn_test.go index ee573d42..8d6b606a 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -245,6 +245,19 @@ func TestConnExec(t *testing.T) { assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) } +func TestConnExecEmpty(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + result, err := pgConn.Exec(context.Background(), ";") + require.Nil(t, err) + assert.Nil(t, result.CommandTag) + assert.Equal(t, 0, len(result.Rows)) +} + func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() From 460946d66256f11b10bbd3b3bac99def937143f2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 13:14:34 -0600 Subject: [PATCH 0176/1158] Move notice handling to pgconn --- config.go | 2 ++ pgconn.go | 19 +++++++++++++++++++ pgconn_test.go | 23 +++++++++++++++++++++++ 3 files changed, 44 insertions(+) diff --git a/config.go b/config.go index d8872f66..bd1fec9b 100644 --- a/config.go +++ b/config.go @@ -40,6 +40,8 @@ type Config struct { // server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This // allows implementing high availability behavior such as libpq does with target_session_attrs. AfterConnectFunc AfterConnectFunc + + OnNotice NoticeHandler // Callback function called when a notice response is received. } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a diff --git a/pgconn.go b/pgconn.go index b3abe8e0..6b6330dc 100644 --- a/pgconn.go +++ b/pgconn.go @@ -48,9 +48,19 @@ func (pe *PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } +// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from +// LISTEN/NOTIFY notification. +type Notice PgError + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) +// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at +// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin +// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY +// notification. +type NoticeHandler func(*PgConn, *Notice) + // ErrTLSRefused occurs when the connection attempt requires TLS and the // PostgreSQL server refuses to use TLS var ErrTLSRefused = errors.New("server refused TLS connection") @@ -277,6 +287,10 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { // TODO - close pgConn return nil, errorResponseToPgError(msg) } + case *pgproto3.NoticeResponse: + if pgConn.Config.OnNotice != nil { + pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg)) + } } return msg, nil @@ -858,6 +872,11 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { } } +func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { + pgerr := errorResponseToPgError((*pgproto3.ErrorResponse)(msg)) + return (*Notice)(pgerr) +} + // CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 diff --git a/pgconn_test.go b/pgconn_test.go index 8d6b606a..98ec9664 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -551,3 +551,26 @@ func TestCommandTag(t *testing.T) { assert.Equalf(t, tt.rowsAffected, actual, "%d. %v", i, tt.commandTag) } } + +func TestConnOnNotice(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + var msg string + config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { + msg = notice.Message + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `do $$ +begin + raise notice 'hello, world'; +end$$;`) + require.Nil(t, err) + assert.Equal(t, "hello, world", msg) +} From b213299a9261bb9845b0785a9adcc3d7aebe12f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 13:59:00 -0600 Subject: [PATCH 0177/1158] Add ensureReadyForQuery to pgconn --- helper_test.go | 14 +++++++ pgconn.go | 102 ++++++++++++++++++++++++++++++++----------------- pgconn_test.go | 28 ++++++++++++++ 3 files changed, 109 insertions(+), 35 deletions(-) diff --git a/helper_test.go b/helper_test.go index 8e7ca92f..1053310b 100644 --- a/helper_test.go +++ b/helper_test.go @@ -7,6 +7,7 @@ import ( "github.com/jackc/pgx/pgconn" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -15,3 +16,16 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) { defer cancel() require.Nil(t, conn.Close(ctx)) } + +// Do a simple query to ensure the connection is still usable +func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + result, err := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil) + cancel() + + require.Nil(t, err) + assert.Equal(t, 3, len(result.Rows)) + assert.Equal(t, "1", string(result.Rows[0][0])) + assert.Equal(t, "2", string(result.Rows[1][0])) + assert.Equal(t, "3", string(result.Rows[2][0])) +} diff --git a/pgconn.go b/pgconn.go index 6b6330dc..76836b9c 100644 --- a/pgconn.go +++ b/pgconn.go @@ -562,23 +562,28 @@ func (rr *PgResultReader) close() { // Flush sends the enqueued execs to the server. func (pgConn *PgConn) Flush(ctx context.Context) error { - defer pgConn.resetBatch() - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() + err := pgConn.flush() + cleanup() + return preferContextOverNetTimeoutError(ctx, err) +} +// flush sends the enqueued execs to the server without handling a context. +func (pgConn *PgConn) flush() error { n, err := pgConn.conn.Write(pgConn.batchBuf) - if err != nil { - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - return preferContextOverNetTimeoutError(ctx, err) + if err != nil && n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true } - pgConn.pendingReadyForQueryCount += pgConn.batchCount - return nil + if err == nil { + pgConn.pendingReadyForQueryCount += pgConn.batchCount + } + + pgConn.resetBatch() + + return err } // contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from @@ -646,13 +651,11 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContext() - for pgConn.pendingReadyForQueryCount > 0 { - _, err := pgConn.ReceiveMessage() - if err != nil { - preferContextOverNetTimeoutError(ctx, err) - pgConn.Close(context.Background()) - return false - } + err := pgConn.ensureReadyForQuery() + if err != nil { + preferContextOverNetTimeoutError(ctx, err) + pgConn.Close(context.Background()) + return false } result, err := pgConn.Exec( @@ -667,6 +670,18 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { return true } +// ensureReadyForQuery reads until pendingReadyForQueryCount == 0. +func (pgConn *PgConn) ensureReadyForQuery() error { + for pgConn.pendingReadyForQueryCount > 0 { + _, err := pgConn.ReceiveMessage() + if err != nil { + return err + } + } + + return nil +} + func (pgConn *PgConn) resetBatch() { pgConn.batchCount = 0 if len(pgConn.batchBuf) > batchBufferSize { @@ -690,14 +705,19 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } - if pgConn.pendingReadyForQueryCount != 0 { - return nil, errors.New("unread previous results") + + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanup() + + err := pgConn.ensureReadyForQuery() + if err != nil { + return nil, preferContextOverNetTimeoutError(ctx, err) } pgConn.SendExec(sql) - err := pgConn.Flush(ctx) + err = pgConn.flush() if err != nil { - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } return pgConn.bufferLastResult(ctx) @@ -741,12 +761,17 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } - if pgConn.pendingReadyForQueryCount != 0 { - return nil, errors.New("unread previous results") + + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanup() + + err := pgConn.ensureReadyForQuery() + if err != nil { + return nil, preferContextOverNetTimeoutError(ctx, err) } pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats) - err := pgConn.Flush(ctx) + err = pgConn.flush() if err != nil { return nil, err } @@ -762,12 +787,17 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } - if pgConn.pendingReadyForQueryCount != 0 { - return nil, errors.New("unread previous results") + + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanup() + + err := pgConn.ensureReadyForQuery() + if err != nil { + return nil, preferContextOverNetTimeoutError(ctx, err) } pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats) - err := pgConn.Flush(ctx) + err = pgConn.flush() if err != nil { return nil, err } @@ -809,18 +839,20 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } - if pgConn.pendingReadyForQueryCount != 0 { - return nil, errors.New("unread previous results") - } - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContext() + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanup() + + err := pgConn.ensureReadyForQuery() + if err != nil { + return nil, preferContextOverNetTimeoutError(ctx, err) + } pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf) pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) pgConn.batchCount += 1 - err := pgConn.Flush(context.Background()) + err = pgConn.flush() if err != nil { return nil, preferContextOverNetTimeoutError(ctx, err) } diff --git a/pgconn_test.go b/pgconn_test.go index 98ec9664..e436d739 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -243,6 +243,8 @@ func TestConnExec(t *testing.T) { require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) + + ensureConnValid(t, pgConn) } func TestConnExecEmpty(t *testing.T) { @@ -256,6 +258,8 @@ func TestConnExecEmpty(t *testing.T) { require.Nil(t, err) assert.Nil(t, result.CommandTag) assert.Equal(t, 0, len(result.Rows)) + + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueries(t *testing.T) { @@ -269,6 +273,8 @@ func TestConnExecMultipleQueries(t *testing.T) { require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "1", string(result.Rows[0][0])) + + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueriesError(t *testing.T) { @@ -286,6 +292,8 @@ func TestConnExecMultipleQueriesError(t *testing.T) { } else { t.Errorf("unexpected error: %v", err) } + + ensureConnValid(t, pgConn) } func TestConnExecContextCanceled(t *testing.T) { @@ -302,6 +310,8 @@ func TestConnExecContextCanceled(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgConn.RecoverFromTimeout(context.Background())) + + ensureConnValid(t, pgConn) } func TestConnExecParams(t *testing.T) { @@ -315,6 +325,8 @@ func TestConnExecParams(t *testing.T) { require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "Hello, world", string(result.Rows[0][0])) + + ensureConnValid(t, pgConn) } func TestConnExecParamsCanceled(t *testing.T) { @@ -331,6 +343,8 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgConn.RecoverFromTimeout(context.Background())) + + ensureConnValid(t, pgConn) } func TestConnExecPrepared(t *testing.T) { @@ -350,6 +364,8 @@ func TestConnExecPrepared(t *testing.T) { require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "Hello, world", string(result.Rows[0][0])) + + ensureConnValid(t, pgConn) } func TestConnExecPreparedCanceled(t *testing.T) { @@ -369,6 +385,8 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgConn.RecoverFromTimeout(context.Background())) + + ensureConnValid(t, pgConn) } func TestConnBatchedQueries(t *testing.T) { @@ -480,6 +498,8 @@ func TestConnBatchedQueries(t *testing.T) { // Done require.False(t, pgConn.NextResult(context.Background())) + + ensureConnValid(t, pgConn) } func TestConnRecoverFromTimeout(t *testing.T) { @@ -504,6 +524,8 @@ func TestConnRecoverFromTimeout(t *testing.T) { assert.Equal(t, "1", string(result.Rows[0][0])) } cancel() + + ensureConnValid(t, pgConn) } func TestConnCancelQuery(t *testing.T) { @@ -527,6 +549,10 @@ func TestConnCancelQuery(t *testing.T) { } else { t.Errorf("expected pgconn.PgError got %v", err) } + + require.False(t, pgConn.NextResult(context.Background())) + + ensureConnValid(t, pgConn) } func TestCommandTag(t *testing.T) { @@ -573,4 +599,6 @@ begin end$$;`) require.Nil(t, err) assert.Equal(t, "hello, world", msg) + + ensureConnValid(t, pgConn) } From 475720d172af1aaff5711146b60bb8839e2952f8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 14:10:57 -0600 Subject: [PATCH 0178/1158] Fix typo --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 76836b9c..ff43f8a8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -629,7 +629,7 @@ func preferContextOverNetTimeoutError(ctx context.Context, err error) error { } // RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. If recovery is -// successful true is returned. If recovery is not successful the connection is closed and false it returned. Recovery +// successful true is returned. If recovery is not successful the connection is closed and false is returned. Recovery // should usually be possible except in the case of a partial write. This must be called after any context cancellation. // // As RecoverFromTimeout may need to read and ignored data already sent from the server, it potentially can block From de2b9bb301c52f92abdd4f3caf13520e6e4855a9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 14:20:10 -0600 Subject: [PATCH 0179/1158] Tweak RecoverFromTimeout docs --- pgconn.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pgconn.go b/pgconn.go index ff43f8a8..9661f99e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -628,12 +628,12 @@ func preferContextOverNetTimeoutError(ctx context.Context, err error) error { return err } -// RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. If recovery is -// successful true is returned. If recovery is not successful the connection is closed and false is returned. Recovery -// should usually be possible except in the case of a partial write. This must be called after any context cancellation. -// -// As RecoverFromTimeout may need to read and ignored data already sent from the server, it potentially can block -// indefinitely. Use ctx to guard against this. +// RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. This must be +// called after any context cancellation. This is not done automatically as RecoverFromTimeout may need to signal the +// server to abort the in-progress query and read and ignore data already sent from the server. This potentially can +// block indefinitely. Use ctx to guard against this. If recovery is successful true is returned. If recovery is not +// successful the connection is closed and false is returned. Recovery should usually be possible except in the case of +// a partial write. func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { if pgConn.closed { return false From ec622237e97fe4258b9f33e60af1aae33f622c27 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 14:56:24 -0600 Subject: [PATCH 0180/1158] Extract startOperation --- pgconn.go | 113 +++++++++++++++++++++++++++++------------------------- 1 file changed, 60 insertions(+), 53 deletions(-) diff --git a/pgconn.go b/pgconn.go index 9661f99e..e22a0de8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -586,39 +586,6 @@ func (pgConn *PgConn) flush() error { return err } -// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from -// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to -// call multiple times. -func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) { - if ctx.Done() != nil { - deadlineWasSet := false - doneChan := make(chan struct{}) - go func() { - select { - case <-ctx.Done(): - conn.SetDeadline(deadlineTime) - deadlineWasSet = true - <-doneChan - // TODO - case <-doneChan: - } - }() - - finished := false - return func() { - if !finished { - doneChan <- struct{}{} - if deadlineWasSet { - conn.SetDeadline(time.Time{}) - } - finished = true - } - } - } - - return func() {} -} - // preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == // true. Otherwise returns err. func preferContextOverNetTimeoutError(ctx context.Context, err error) error { @@ -670,6 +637,54 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { return true } +// startOperation gets the connection ready for a new operation. It should be called at the beginning of every public +// method that communicates with the server. The returned cleanup function must be called if err == nil or a goroutine may +// be leaked. The cleanup function is safe to call multiple times. +func (pgConn *PgConn) startOperation(ctx context.Context) (cleanup func(), err error) { + cleanup = contextDoneToConnDeadline(ctx, pgConn.conn) + + err = pgConn.ensureReadyForQuery() + if err != nil { + cleanup() + return cleanup, preferContextOverNetTimeoutError(ctx, err) + } + + return cleanup, nil +} + +// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from +// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to +// call multiple times. +func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) { + if ctx.Done() != nil { + deadlineWasSet := false + doneChan := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + conn.SetDeadline(deadlineTime) + deadlineWasSet = true + <-doneChan + // TODO + case <-doneChan: + } + }() + + finished := false + return func() { + if !finished { + doneChan <- struct{}{} + if deadlineWasSet { + conn.SetDeadline(time.Time{}) + } + finished = true + } + } + } + + return func() {} +} + // ensureReadyForQuery reads until pendingReadyForQueryCount == 0. func (pgConn *PgConn) ensureReadyForQuery() error { for pgConn.pendingReadyForQueryCount > 0 { @@ -706,13 +721,11 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { return nil, errors.New("unflushed previous sends") } - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() - - err := pgConn.ensureReadyForQuery() + cleanup, err := pgConn.startOperation(ctx) if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, err } + defer cleanup() pgConn.SendExec(sql) err = pgConn.flush() @@ -762,13 +775,11 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return nil, errors.New("unflushed previous sends") } - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() - - err := pgConn.ensureReadyForQuery() + cleanup, err := pgConn.startOperation(ctx) if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, err } + defer cleanup() pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats) err = pgConn.flush() @@ -788,13 +799,11 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return nil, errors.New("unflushed previous sends") } - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() - - err := pgConn.ensureReadyForQuery() + cleanup, err := pgConn.startOperation(ctx) if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, err } + defer cleanup() pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats) err = pgConn.flush() @@ -840,13 +849,11 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ return nil, errors.New("unflushed previous sends") } - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() - - err := pgConn.ensureReadyForQuery() + cleanup, err := pgConn.startOperation(ctx) if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, err } + defer cleanup() pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf) From fa5e1d3ec4ad6453e36af71c042ffc261b379cfa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 18:16:08 -0600 Subject: [PATCH 0181/1158] Back out of some over optimization --- pgconn.go | 30 +++--------------------------- 1 file changed, 3 insertions(+), 27 deletions(-) diff --git a/pgconn.go b/pgconn.go index e22a0de8..ee8127bf 100644 --- a/pgconn.go +++ b/pgconn.go @@ -814,33 +814,11 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return pgConn.bufferLastResult(ctx) } -type FieldDescription struct { - Name string - TableOID uint32 - TableAttributeNumber uint16 - DataTypeOID uint32 - DataTypeSize int16 - TypeModifier int32 - FormatCode int16 -} - -// pgproto3FieldDescriptionToPgconnFieldDescription copies and converts the data from a pgproto3.FieldDescription to a -// FieldDescription. -func pgproto3FieldDescriptionToPgconnFieldDescription(src *pgproto3.FieldDescription, dst *FieldDescription) { - dst.Name = string(src.Name) - dst.TableOID = src.TableOID - dst.TableAttributeNumber = src.TableAttributeNumber - dst.DataTypeOID = src.DataTypeOID - dst.DataTypeSize = src.DataTypeSize - dst.TypeModifier = src.TypeModifier - dst.FormatCode = src.Format -} - type PreparedStatementDescription struct { Name string SQL string ParamOIDs []uint32 - Fields []FieldDescription + Fields []pgproto3.FieldDescription } // Prepare creates a prepared statement. @@ -877,10 +855,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: - psd.Fields = make([]FieldDescription, len(msg.Fields)) - for i := range msg.Fields { - pgproto3FieldDescriptionToPgconnFieldDescription(&msg.Fields[i], &psd.Fields[i]) - } + psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) + copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: return nil, errorResponseToPgError(msg) } From a24d764440ff7448d32302091ae873884c93dd88 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 18:16:08 -0600 Subject: [PATCH 0182/1158] Back out of some over optimization --- command_complete.go | 6 +++--- row_description.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/command_complete.go b/command_complete.go index ba5a6a63..85848532 100644 --- a/command_complete.go +++ b/command_complete.go @@ -8,7 +8,7 @@ import ( ) type CommandComplete struct { - CommandTag []byte + CommandTag string } func (*CommandComplete) Backend() {} @@ -19,7 +19,7 @@ func (dst *CommandComplete) Decode(src []byte) error { return &invalidMessageFormatErr{messageType: "CommandComplete"} } - dst.CommandTag = src[:idx] + dst.CommandTag = string(src[:idx]) return nil } @@ -43,6 +43,6 @@ func (src *CommandComplete) MarshalJSON() ([]byte, error) { CommandTag string }{ Type: "CommandComplete", - CommandTag: string(src.CommandTag), + CommandTag: src.CommandTag, }) } diff --git a/row_description.go b/row_description.go index c7f3477f..7deba379 100644 --- a/row_description.go +++ b/row_description.go @@ -14,7 +14,7 @@ const ( ) type FieldDescription struct { - Name []byte + Name string TableOID uint32 TableAttributeNumber uint16 DataTypeOID uint32 @@ -45,7 +45,7 @@ func (dst *RowDescription) Decode(src []byte) error { if err != nil { return err } - fd.Name = bName[:len(bName)-1] + fd.Name = string(bName[:len(bName)-1]) // Since buf.Next() doesn't return an error if we hit the end of the buffer // check Len ahead of time From 64e80f1f723cc2edc3495db56b23755b164abf62 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 18:16:20 -0600 Subject: [PATCH 0183/1158] Add benchmarks when cancellable --- benchmark_test.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/benchmark_test.go b/benchmark_test.go index 269ac59b..fc4b6057 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -49,6 +49,22 @@ func BenchmarkExec(b *testing.B) { } } +func BenchmarkExecPossibleToCancel(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + b.ResetTimer() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < b.N; i++ { + _, err := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + require.Nil(b, err) + } +} + func BenchmarkExecPrepared(b *testing.B) { conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(b, err) @@ -64,6 +80,24 @@ func BenchmarkExecPrepared(b *testing.B) { } } +func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := conn.ExecPrepared(ctx, "ps1", nil, nil, nil) + require.Nil(b, err) + } +} + func BenchmarkSendExecPrepared(b *testing.B) { conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(b, err) From cddf01180659a163df1e7ca9cd03dd648ea8c153 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 17:37:28 -0600 Subject: [PATCH 0184/1158] Big restructure to better handle context cancel --- benchmark_test.go | 33 +- config.go | 6 +- helper_test.go | 4 +- pgconn.go | 995 +++++++++++++++++++++++------------------- pgconn_stress_test.go | 116 +---- pgconn_test.go | 289 +++++------- 6 files changed, 686 insertions(+), 757 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index fc4b6057..ffb1455c 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -44,7 +44,7 @@ func BenchmarkExec(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + _, err := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll() require.Nil(b, err) } } @@ -60,7 +60,7 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { defer cancel() for i := 0; i < b.N; i++ { - _, err := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + _, err := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll() require.Nil(b, err) } } @@ -71,12 +71,13 @@ func BenchmarkExecPrepared(b *testing.B) { defer closeConn(b, conn) _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + require.Nil(b, err) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil) - require.Nil(b, err) + result := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil).ReadAll() + require.Nil(b, result.Err) } } @@ -89,32 +90,12 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { defer cancel() _, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _, err := conn.ExecPrepared(ctx, "ps1", nil, nil, nil) - require.Nil(b, err) - } -} - -func BenchmarkSendExecPrepared(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(b, err) - defer closeConn(b, conn) - - _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) b.ResetTimer() for i := 0; i < b.N; i++ { - conn.SendExecPrepared("ps1", nil, nil, nil) - err := conn.Flush(context.Background()) - require.Nil(b, err) - - for conn.NextResult(context.Background()) { - _, err := conn.ResultReader().Close() - require.Nil(b, err) - } + result := conn.ExecPrepared(ctx, "ps1", nil, nil, nil).ReadAll() + require.Nil(b, result.Err) } } diff --git a/config.go b/config.go index bd1fec9b..fb0719cd 100644 --- a/config.go +++ b/config.go @@ -470,9 +470,9 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { // AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible // target_session_attrs=read-write. func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { - result, err := pgConn.Exec(ctx, "show transaction_read_only") - if err != nil { - return err + result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).ReadAll() + if result.Err != nil { + return result.Err } if string(result.Rows[0][0]) == "on" { diff --git a/helper_test.go b/helper_test.go index 1053310b..a50f7cb1 100644 --- a/helper_test.go +++ b/helper_test.go @@ -20,10 +20,10 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) { // Do a simple query to ensure the connection is still usable func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) - result, err := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil) + result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).ReadAll() cancel() - require.Nil(t, err) + require.Nil(t, result.Err) assert.Equal(t, 3, len(result.Rows)) assert.Equal(t, "1", string(result.Rows[0][0])) assert.Equal(t, "2", string(result.Rows[1][0])) diff --git a/pgconn.go b/pgconn.go index ee8127bf..cfacc7bb 100644 --- a/pgconn.go +++ b/pgconn.go @@ -17,8 +17,6 @@ import ( "github.com/jackc/pgx/pgproto3" ) -const batchBufferSize = 4096 - var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) // PgError represents an error reported by the PostgreSQL server. See @@ -76,14 +74,9 @@ type PgConn struct { Config *Config - batchBuf []byte - batchCount int32 - - pendingReadyForQueryCount int32 + controller chan interface{} closed bool - - resultReader PgResultReader } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -140,6 +133,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.Config = config + pgConn.controller = make(chan interface{}, 1) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -268,23 +262,22 @@ func hexMD5(s string) string { func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { msg, err := pgConn.Frontend.Receive() if err != nil { + // Close on anything other than timeout error - everything else is fatal + if err, ok := err.(net.Error); !ok && err.Timeout() { + pgConn.hardClose() + } + return nil, err } switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - // Under normal circumstances pendingReadyForQueryCount will be > 0 when a - // ReadyForQuery is received. However, this is not the case on initial - // connection. - if pgConn.pendingReadyForQueryCount > 0 { - pgConn.pendingReadyForQueryCount -= 1 - } pgConn.TxStatus = msg.TxStatus case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { - // TODO - close pgConn + pgConn.hardClose() return nil, errorResponseToPgError(msg) } case *pgproto3.NoticeResponse: @@ -338,6 +331,15 @@ func (pgConn *PgConn) Close(ctx context.Context) error { return pgConn.conn.Close() } +// hardClose closes the underlying connection without sending the exit message. +func (pgConn *PgConn) hardClose() error { + if pgConn.closed { + return nil + } + pgConn.closed = true + return pgConn.conn.Close() +} + // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { @@ -363,229 +365,6 @@ func (ct CommandTag) String() string { return string(ct) } -// SendExec enqueues the execution of sql via the PostgreSQL simple query protocol. sql may contain multiple queries. -// Execution is implicitly wrapped in a transactions unless a transaction is already in progress or sql contains -// transaction control statements. It is only sent to the PostgreSQL server when Flush is called. -func (pgConn *PgConn) SendExec(sql string) { - pgConn.batchBuf = (&pgproto3.Query{String: sql}).Encode(pgConn.batchBuf) - pgConn.batchCount += 1 -} - -// SendExecParams enqueues the execution of sql via the PostgreSQL extended query protocol. -// -// sql is a SQL command string. It may only contain one query. Parameter substitution is position using $1, $2, $3, etc. -// -// paramValues are the parameter values. It must be encoded in the format given by paramFormats. -// -// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for -// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. -// SendExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). -// -// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or -// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if -// len(paramFormats) is not 0, 1, or len(paramValues). -// -// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or -// binary format. If resultFormats is nil all results will be in text protocol. -// -// Query is only sent to the PostgreSQL server when Flush is called. -func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - if len(paramOIDs) != 0 && len(paramOIDs) != len(paramValues) && len(paramOIDs) != len(paramValues) { - panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs))) - } - - pgConn.batchBuf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) - pgConn.SendExecPrepared("", paramValues, paramFormats, resultFormats) -} - -// SendExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. -// -// paramValues are the parameter values. It must be encoded in the format given by paramFormats. -// -// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or -// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if -// len(paramFormats) is not 0, 1, or len(paramValues). -// -// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or -// binary format. If resultFormats is nil all results will be in text protocol. -// -// Query is only sent to the PostgreSQL server when Flush is called. -func (pgConn *PgConn) SendExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { - pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: stmtName}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Execute{}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) - pgConn.batchCount += 1 -} - -type PgResultReader struct { - pgConn *PgConn - fieldDescriptions []pgproto3.FieldDescription - rowValues [][]byte - commandTag CommandTag - err error - complete bool - preloadedRowValues bool - ctx context.Context - cleanupContext func() -} - -// NextResult reads until a result is ready to be read or no results are pending. Returns true if a result is available. -// Use ResultReader() to acquire a reader for the result. -func (pgConn *PgConn) NextResult(ctx context.Context) bool { - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) - - for pgConn.pendingReadyForQueryCount > 0 { - msg, err := pgConn.ReceiveMessage() - if err != nil { - cleanupContext() - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true} - return true - } - - switch msg := msg.(type) { - case *pgproto3.RowDescription: - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields} - return true - case *pgproto3.DataRow: - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true} - return true - case *pgproto3.CommandComplete: - cleanupContext() - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} - return true - case *pgproto3.EmptyQueryResponse: - cleanupContext() - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, complete: true} - return true - case *pgproto3.ErrorResponse: - cleanupContext() - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} - return true - } - } - - cleanupContext() - return false -} - -// ResultReader returns the result reader prepared by next result. It is only valid until the result is completed. -func (pgConn *PgConn) ResultReader() *PgResultReader { - return &pgConn.resultReader -} - -// NextRow returns advances the PgResultReader to the next row and returns true if a row is available. -func (rr *PgResultReader) NextRow() bool { - if rr.complete { - return false - } - - if rr.preloadedRowValues { - rr.preloadedRowValues = false - return true - } - - for { - msg, err := rr.pgConn.ReceiveMessage() - if err != nil { - rr.err = preferContextOverNetTimeoutError(rr.ctx, err) - rr.close() - return false - } - - switch msg := msg.(type) { - case *pgproto3.RowDescription: - rr.fieldDescriptions = msg.Fields - case *pgproto3.DataRow: - rr.rowValues = msg.Values - return true - case *pgproto3.CommandComplete: - rr.commandTag = CommandTag(msg.CommandTag) - rr.close() - return false - case *pgproto3.ErrorResponse: - rr.err = errorResponseToPgError(msg) - rr.close() - return false - } - } -} - -// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until -// the PgResultReader is closed. -func (rr *PgResultReader) FieldDescriptions() []pgproto3.FieldDescription { - return rr.fieldDescriptions -} - -// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only -// valid until the next NextRow call or the PgResultReader is closed. However, the underlying byte data is safe to -// retain a reference to and mutate. -func (rr *PgResultReader) Values() [][]byte { - return rr.rowValues -} - -// Close consumes any remaining result data and returns the command tag or -// error. -func (rr *PgResultReader) Close() (CommandTag, error) { - if rr.complete { - return rr.commandTag, rr.err - } - defer rr.close() - - for { - msg, err := rr.pgConn.ReceiveMessage() - if err != nil { - rr.err = preferContextOverNetTimeoutError(rr.ctx, err) - return rr.commandTag, rr.err - } - - switch msg := msg.(type) { - case *pgproto3.CommandComplete: - rr.commandTag = CommandTag(msg.CommandTag) - return rr.commandTag, rr.err - case *pgproto3.ErrorResponse: - rr.err = errorResponseToPgError(msg) - return rr.commandTag, rr.err - } - } -} - -func (rr *PgResultReader) close() { - if rr.complete { - return - } - - rr.cleanupContext() - rr.rowValues = nil - rr.complete = true -} - -// Flush sends the enqueued execs to the server. -func (pgConn *PgConn) Flush(ctx context.Context) error { - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - err := pgConn.flush() - cleanup() - return preferContextOverNetTimeoutError(ctx, err) -} - -// flush sends the enqueued execs to the server without handling a context. -func (pgConn *PgConn) flush() error { - n, err := pgConn.conn.Write(pgConn.batchBuf) - if err != nil && n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - - if err == nil { - pgConn.pendingReadyForQueryCount += pgConn.batchCount - } - - pgConn.resetBatch() - - return err -} - // preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == // true. Otherwise returns err. func preferContextOverNetTimeoutError(ctx context.Context, err error) error { @@ -595,63 +374,6 @@ func preferContextOverNetTimeoutError(ctx context.Context, err error) error { return err } -// RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. This must be -// called after any context cancellation. This is not done automatically as RecoverFromTimeout may need to signal the -// server to abort the in-progress query and read and ignore data already sent from the server. This potentially can -// block indefinitely. Use ctx to guard against this. If recovery is successful true is returned. If recovery is not -// successful the connection is closed and false is returned. Recovery should usually be possible except in the case of -// a partial write. -func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { - if pgConn.closed { - return false - } - pgConn.resetBatch() - - // Clear any existing timeout - pgConn.conn.SetDeadline(time.Time{}) - - // Try to cancel any in-progress requests - for i := 0; i < int(pgConn.pendingReadyForQueryCount); i++ { - pgConn.CancelRequest(ctx) - } - - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContext() - - err := pgConn.ensureReadyForQuery() - if err != nil { - preferContextOverNetTimeoutError(ctx, err) - pgConn.Close(context.Background()) - return false - } - - result, err := pgConn.Exec( - context.Background(), // do not use ctx again because deadline goroutine already started above - "select 'RecoverFromTimeout'", - ) - if err != nil || len(result.Rows) != 1 || len(result.Rows[0]) != 1 || string(result.Rows[0][0]) != "RecoverFromTimeout" { - pgConn.Close(context.Background()) - return false - } - - return true -} - -// startOperation gets the connection ready for a new operation. It should be called at the beginning of every public -// method that communicates with the server. The returned cleanup function must be called if err == nil or a goroutine may -// be leaked. The cleanup function is safe to call multiple times. -func (pgConn *PgConn) startOperation(ctx context.Context) (cleanup func(), err error) { - cleanup = contextDoneToConnDeadline(ctx, pgConn.conn) - - err = pgConn.ensureReadyForQuery() - if err != nil { - cleanup() - return cleanup, preferContextOverNetTimeoutError(ctx, err) - } - - return cleanup, nil -} - // contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from // ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to // call multiple times. @@ -665,7 +387,6 @@ func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func conn.SetDeadline(deadlineTime) deadlineWasSet = true <-doneChan - // TODO case <-doneChan: } }() @@ -685,135 +406,6 @@ func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func return func() {} } -// ensureReadyForQuery reads until pendingReadyForQueryCount == 0. -func (pgConn *PgConn) ensureReadyForQuery() error { - for pgConn.pendingReadyForQueryCount > 0 { - _, err := pgConn.ReceiveMessage() - if err != nil { - return err - } - } - - return nil -} - -func (pgConn *PgConn) resetBatch() { - pgConn.batchCount = 0 - if len(pgConn.batchBuf) > batchBufferSize { - pgConn.batchBuf = make([]byte, 0, batchBufferSize) - } else { - pgConn.batchBuf = pgConn.batchBuf[0:0] - } -} - -type PgResult struct { - Rows [][][]byte - CommandTag CommandTag -} - -// Exec executes sql via the PostgreSQL simple query protocol, buffers the entire result, and returns it. sql may -// contain multiple queries, but only the last results will be returned. Execution is implicitly wrapped in a -// transactions unless a transaction is already in progress or sql contains transaction control statements. -// -// Exec must not be called when there are pending results from previous Send* methods (e.g. SendExec). -func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { - if pgConn.batchCount != 0 { - return nil, errors.New("unflushed previous sends") - } - - cleanup, err := pgConn.startOperation(ctx) - if err != nil { - return nil, err - } - defer cleanup() - - pgConn.SendExec(sql) - err = pgConn.flush() - if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) - } - - return pgConn.bufferLastResult(ctx) -} - -func (pgConn *PgConn) bufferLastResult(ctx context.Context) (*PgResult, error) { - var result *PgResult - - for pgConn.NextResult(ctx) { - resultReader := pgConn.ResultReader() - rows := [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - - commandTag, err := resultReader.Close() - if err != nil { - return nil, err - } - - result = &PgResult{ - Rows: rows, - CommandTag: commandTag, - } - } - - if result == nil { - return nil, errors.New("unexpected missing result") - } - - return result, nil -} - -// ExecParams executes sql via the PostgreSQL extended query protocol, buffers the entire result, and returns it. See -// SendExecParams for parameter descriptions. -// -// ExecParams must not be called when there are pending results from previous Send* methods (e.g. SendExec). -func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) (*PgResult, error) { - if pgConn.batchCount != 0 { - return nil, errors.New("unflushed previous sends") - } - - cleanup, err := pgConn.startOperation(ctx) - if err != nil { - return nil, err - } - defer cleanup() - - pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats) - err = pgConn.flush() - if err != nil { - return nil, err - } - - return pgConn.bufferLastResult(ctx) -} - -// ExecPrepared executes a prepared statement via the PostgreSQL extended query protocol, buffers the entire result, and -// returns it. See SendExecPrepared for parameter descriptions. -// -// ExecPrepared must not be called when there are pending results from previous Send* methods (e.g. SendExec). -func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) (*PgResult, error) { - if pgConn.batchCount != 0 { - return nil, errors.New("unflushed previous sends") - } - - cleanup, err := pgConn.startOperation(ctx) - if err != nil { - return nil, err - } - defer cleanup() - - pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats) - err = pgConn.flush() - if err != nil { - return nil, err - } - - return pgConn.bufferLastResult(ctx) -} - type PreparedStatementDescription struct { Name string SQL string @@ -823,30 +415,38 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { - if pgConn.batchCount != 0 { - return nil, errors.New("unflushed previous sends") + select { + case <-ctx.Done(): + return nil, ctx.Err() + case pgConn.controller <- pgConn: } + cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContextDeadline() - cleanup, err := pgConn.startOperation(ctx) - if err != nil { - return nil, err - } - defer cleanup() + var buf []byte + buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) + buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) + buf = (&pgproto3.Sync{}).Encode(buf) - pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) - pgConn.batchCount += 1 - err = pgConn.flush() + n, err := pgConn.conn.Write(buf) if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + return nil, preferContextOverNetTimeoutError(ctx, err) } psd := &PreparedStatementDescription{Name: name, SQL: sql} - for pgConn.pendingReadyForQueryCount > 0 { +readloop: + for { msg, err := pgConn.ReceiveMessage() if err != nil { + go pgConn.recoverFromTimeout() return nil, preferContextOverNetTimeoutError(ctx, err) } @@ -858,10 +458,14 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: + go pgConn.recoverFromTimeout() return nil, errorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + break readloop } } + <-pgConn.controller return psd, nil } @@ -892,10 +496,10 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { return (*Notice)(pgerr) } -// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel +// cancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 -func (pgConn *PgConn) CancelRequest(ctx context.Context) error { +func (pgConn *PgConn) cancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. @@ -926,3 +530,514 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { return nil } + +// Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is +// implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control +// statements. +// +// Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. +func (pgConn *PgConn) Exec(ctx context.Context, sql string) *PgMultiResult { + multiResult := &PgMultiResult{ + pgConn: pgConn, + ctx: ctx, + cleanupContextDeadline: func() {}, + } + + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = ctx.Err() + return multiResult + case pgConn.controller <- multiResult: + } + multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + + var buf []byte + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + multiResult.cleanupContextDeadline() + multiResult.closed = true + multiResult.err = preferContextOverNetTimeoutError(ctx, err) + <-pgConn.controller + return multiResult + } + + return multiResult +} + +// ExecParams executes a command via the PostgreSQL extended query protocol. +// +// sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, +// etc. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for +// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. +// ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all results will be in text protocol. ExecParams will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text protocol. +// +// Result must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *PgResult { + result := &PgResult{ + pgConn: pgConn, + ctx: ctx, + cleanupContextDeadline: func() {}, + } + + select { + case <-ctx.Done(): + result.concludeCommand(nil, ctx.Err()) + result.closed = true + return result + case pgConn.controller <- result: + } + result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + + var buf []byte + + // TODO - refactor ExecParams and ExecPrepared - these lines only difference + buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) + buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + + buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) + buf = (&pgproto3.Execute{}).Encode(buf) + buf = (&pgproto3.Sync{}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + result.concludeCommand(nil, err) + result.cleanupContextDeadline() + result.closed = true + <-pgConn.controller + } + + return result +} + +// ExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all results will be in text protocol. ExecPrepared will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text protocol. +// +// Result must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *PgResult { + result := &PgResult{ + pgConn: pgConn, + ctx: ctx, + cleanupContextDeadline: func() {}, + } + + select { + case <-ctx.Done(): + result.concludeCommand(nil, ctx.Err()) + result.closed = true + return result + case pgConn.controller <- result: + } + result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + + var buf []byte + buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) + buf = (&pgproto3.Execute{}).Encode(buf) + buf = (&pgproto3.Sync{}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + result.concludeCommand(nil, err) + result.cleanupContextDeadline() + result.closed = true + <-pgConn.controller + } + + return result +} + +type PgMultiResult struct { + pgConn *PgConn + ctx context.Context + cleanupContextDeadline func() + + pgResult *PgResult + + closed bool + err error +} + +func (mr *PgMultiResult) ReadAll() ([]*BufferedResult, error) { + var results []*BufferedResult + + for mr.NextResult() { + results = append(results, mr.Result().ReadAll()) + } + err := mr.Close() + + return results, err +} + +func (mr *PgMultiResult) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := mr.pgConn.ReceiveMessage() + + if err != nil { + mr.cleanupContextDeadline() + mr.err = preferContextOverNetTimeoutError(mr.ctx, err) + mr.closed = true + + if err, ok := err.(net.Error); ok && err.Timeout() { + go mr.pgConn.recoverFromTimeout() + } else { + <-mr.pgConn.controller + } + + return nil, mr.err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + mr.cleanupContextDeadline() + mr.closed = true + <-mr.pgConn.controller + case *pgproto3.ErrorResponse: + mr.err = errorResponseToPgError(msg) + } + + return msg, nil +} + +// NextResult returns advances the PgMultiResult to the next result and returns true if a result is available. +func (mr *PgMultiResult) NextResult() bool { + for !mr.closed && mr.err == nil { + msg, err := mr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + mr.pgResult = &PgResult{ + pgConn: mr.pgConn, + pgMultiResult: mr, + ctx: mr.ctx, + cleanupContextDeadline: func() {}, + fieldDescriptions: msg.Fields, + } + return true + case *pgproto3.CommandComplete: + mr.pgResult = &PgResult{ + commandTag: CommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + return true + case *pgproto3.EmptyQueryResponse: + return false + } + } + + return false +} + +func (mr *PgMultiResult) Result() *PgResult { + return mr.pgResult +} + +func (mr *PgMultiResult) Close() error { + for !mr.closed { + _, err := mr.receiveMessage() + if err != nil { + return mr.err + } + } + + return mr.err +} + +type PgResult struct { + pgConn *PgConn + pgMultiResult *PgMultiResult + ctx context.Context + cleanupContextDeadline func() + + fieldDescriptions []pgproto3.FieldDescription + rowValues [][]byte + commandTag CommandTag + commandConcluded bool + closed bool + err error +} + +type BufferedResult struct { + FieldDescriptions []pgproto3.FieldDescription + Rows [][][]byte + CommandTag CommandTag + Err error +} + +func (rr *PgResult) ReadAll() *BufferedResult { + br := &BufferedResult{} + + for rr.NextRow() { + if br.FieldDescriptions == nil { + br.FieldDescriptions = make([]pgproto3.FieldDescription, len(rr.FieldDescriptions())) + copy(br.FieldDescriptions, rr.FieldDescriptions()) + } + + row := make([][]byte, len(rr.Values())) + copy(row, rr.Values()) + br.Rows = append(br.Rows, row) + } + + br.CommandTag, br.Err = rr.Close() + + return br +} + +// NextRow advances the PgResult to the next row and returns true if a row is available. +func (rr *PgResult) NextRow() bool { + for !rr.commandConcluded { + msg, err := rr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.DataRow: + rr.rowValues = msg.Values + return true + } + } + + return false +} + +// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until +// the PgResult is closed. +func (rr *PgResult) FieldDescriptions() []pgproto3.FieldDescription { + return rr.fieldDescriptions +} + +// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only +// valid until the next NextRow call or the PgResult is closed. However, the underlying byte data is safe to +// retain a reference to and mutate. +func (rr *PgResult) Values() [][]byte { + return rr.rowValues +} + +// Close consumes any remaining result data and returns the command tag or +// error. +func (rr *PgResult) Close() (CommandTag, error) { + if rr.closed { + return rr.commandTag, rr.err + } + rr.closed = true + + for !rr.commandConcluded { + _, err := rr.receiveMessage() + if err != nil { + return nil, rr.err + } + } + + if rr.pgMultiResult == nil { + for { + msg, err := rr.receiveMessage() + if err != nil { + return nil, rr.err + } + + switch msg.(type) { + case *pgproto3.ReadyForQuery: + rr.cleanupContextDeadline() + <-rr.pgConn.controller + return rr.commandTag, rr.err + } + } + } + + return rr.commandTag, rr.err +} + +func (rr *PgResult) receiveMessage() (msg pgproto3.BackendMessage, err error) { + if rr.pgMultiResult == nil { + msg, err = rr.pgConn.ReceiveMessage() + } else { + msg, err = rr.pgMultiResult.receiveMessage() + } + + if err != nil { + rr.concludeCommand(nil, err) + rr.cleanupContextDeadline() + rr.closed = true + if rr.pgMultiResult == nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + go rr.pgConn.recoverFromTimeout() + } else { + <-rr.pgConn.controller + } + } + + return nil, rr.err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rr.fieldDescriptions = msg.Fields + case *pgproto3.CommandComplete: + rr.concludeCommand(CommandTag(msg.CommandTag), nil) + case *pgproto3.ErrorResponse: + rr.concludeCommand(nil, errorResponseToPgError(msg)) + } + + return msg, nil +} + +func (rr *PgResult) concludeCommand(commandTag CommandTag, err error) { + if rr.commandConcluded { + return + } + + rr.commandTag = commandTag + rr.err = preferContextOverNetTimeoutError(rr.ctx, err) + rr.fieldDescriptions = nil + rr.rowValues = nil + rr.commandConcluded = true +} + +func (pgConn *PgConn) recoverFromTimeout() { + // Regardless of recovery outcome the lock on the pgConn must be released. + defer func() { <-pgConn.controller }() + + // Send a cancellation request to the PostgreSQL server. If it is not successful in a reasonable amount of time do not + // try further to recover the connection. + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + err := pgConn.cancelRequest(ctx) + cancel() + if err != nil { + pgConn.hardClose() + return + } + + // Limit time to wait for ReadyForQuery message. + err = pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second)) + if err != nil { + pgConn.hardClose() + return + } + + // A cancel query request will always return a "57014" error response, even if no query was in progress. This error + // may be returned before or after the ReadyForQuery message. Must ensure both messages are read. + needError57014 := true + needReadyForQuery := true + + for needError57014 || needReadyForQuery { + msg, err := pgConn.ReceiveMessage() + if err != nil { + pgConn.hardClose() + return + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + if msg.Code == "57014" { + needError57014 = false + } + case *pgproto3.ReadyForQuery: + needReadyForQuery = false + } + } + + err = pgConn.conn.SetDeadline(time.Time{}) + if err != nil { + pgConn.hardClose() + } +} + +type Batch struct { + buf []byte +} + +// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. +func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { + // TODO - refactor ExecParams and ExecPrepared - these lines only difference + batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + batch.ExecPrepared("", paramValues, paramFormats, resultFormats) +} + +// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. +func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) + batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) +} + +func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *PgMultiResult { + multiResult := &PgMultiResult{ + pgConn: pgConn, + ctx: ctx, + cleanupContextDeadline: func() {}, + } + + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = ctx.Err() + return multiResult + case pgConn.controller <- multiResult: + } + multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + + batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + n, err := pgConn.conn.Write(batch.buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + multiResult.cleanupContextDeadline() + multiResult.closed = true + multiResult.err = preferContextOverNetTimeoutError(ctx, err) + <-pgConn.controller + return multiResult + } + + return multiResult +} diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index 9aa94539..17d344b7 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -9,7 +9,6 @@ import ( "time" "github.com/jackc/pgx/pgconn" - "github.com/pkg/errors" "github.com/stretchr/testify/require" ) @@ -22,9 +21,9 @@ func TestConnStress(t *testing.T) { defer closeConn(t, pgConn) actionCount := 100 - if s := os.Getenv("PTX_TEST_STRESS_FACTOR"); s != "" { + if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" { stressFactor, err := strconv.ParseInt(s, 10, 64) - require.Nil(t, err, "Failed to parse PTX_TEST_STRESS_FACTOR") + require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR") actionCount *= int(stressFactor) } @@ -61,138 +60,61 @@ func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { insert into widgets(name, description) values ('Foo', 'bar'), ('baz', 'Something really long Something really long Something really long Something really long Something really long'), - ('a', 'b')`) + ('a', 'b')`).ReadAll() require.Nil(t, err) } func stressExecSelect(pgConn *pgconn.PgConn) error { - _, err := pgConn.Exec(context.Background(), "select * from widgets") + _, err := pgConn.Exec(context.Background(), "select * from widgets").ReadAll() return err } func stressExecParamsSelect(pgConn *pgconn.PgConn) error { - _, err := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - return err + result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).ReadAll() + return result.Err } func stressBatch(pgConn *pgconn.PgConn) error { - pgConn.SendExec("select * from widgets") - pgConn.SendExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - err := pgConn.Flush(context.Background()) - if err != nil { - return err - } + batch := &pgconn.Batch{} - // Query 1 - if !pgConn.NextResult(context.Background()) { - return errors.New("missing result") - } - resultReader := pgConn.ResultReader() - - for resultReader.NextRow() { - } - _, err = resultReader.Close() - if err != nil { - return err - } - - // Query 2 - if !pgConn.NextResult(context.Background()) { - return errors.New("missing result") - } - resultReader = pgConn.ResultReader() - - for resultReader.NextRow() { - } - _, err = resultReader.Close() - if err != nil { - return err - } - - // No more - if pgConn.NextResult(context.Background()) { - return errors.New("unexpected result reader") - } - - return nil + batch.ExecParams("select * from widgets", nil, nil, nil, nil) + batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + _, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + return err } func stressExecSelectCanceled(pgConn *pgconn.PgConn) error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets") + _, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets").ReadAll() cancel() if err != context.DeadlineExceeded { return err } - ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) - recovered := pgConn.RecoverFromTimeout(ctx) - cancel() - if !recovered { - return errors.New("unable to recover from timeout") - } return nil } func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).ReadAll() cancel() - if err != context.DeadlineExceeded { - return err + if result.Err != context.DeadlineExceeded { + return result.Err } - ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) - recovered := pgConn.RecoverFromTimeout(ctx) - cancel() - if !recovered { - return errors.New("unable to recover from timeout") - } return nil } func stressBatchCanceled(pgConn *pgconn.PgConn) error { - - pgConn.SendExec("select * from widgets") - pgConn.SendExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - err := pgConn.Flush(context.Background()) - if err != nil { - return err - } - - // Query 1 - if !pgConn.NextResult(context.Background()) { - return errors.New("missing result") - } - resultReader := pgConn.ResultReader() - - for resultReader.NextRow() { - } - _, err = resultReader.Close() - if err != nil { - return err - } - - // Query 2 + batch := &pgconn.Batch{} + batch.ExecParams("select * from widgets", nil, nil, nil, nil) + batch.ExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - if !pgConn.NextResult(ctx) { - return errors.New("missing result") - } + _, err := pgConn.ExecBatch(ctx, batch).ReadAll() cancel() - resultReader = pgConn.ResultReader() - - for resultReader.NextRow() { - } - _, err = resultReader.Close() if err != context.DeadlineExceeded { return err } - ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) - recovered := pgConn.RecoverFromTimeout(ctx) - cancel() - if !recovered { - return errors.New("unable to recover from timeout") - } return nil } diff --git a/pgconn_test.go b/pgconn_test.go index e436d739..a2eb7838 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -134,13 +134,13 @@ func TestConnectWithRuntimeParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, conn) - result, err := conn.Exec(context.Background(), "show application_name") - require.Nil(t, err) + result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).ReadAll() + require.Nil(t, result.Err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result, err = conn.Exec(context.Background(), "show search_path") - require.Nil(t, err) + result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).ReadAll() + require.Nil(t, result.Err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "myschema", string(result.Rows[0][0])) } @@ -239,10 +239,14 @@ func TestConnExec(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec(context.Background(), "select current_database()") - require.Nil(t, err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Nil(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) ensureConnValid(t, pgConn) } @@ -254,10 +258,16 @@ func TestConnExecEmpty(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec(context.Background(), ";") - require.Nil(t, err) - assert.Nil(t, result.CommandTag) - assert.Equal(t, 0, len(result.Rows)) + multiResult := pgConn.Exec(context.Background(), ";") + + resultCount := 0 + for multiResult.NextResult() { + resultCount += 1 + multiResult.Result().Close() + } + assert.Equal(t, 0, resultCount) + err = multiResult.Close() + assert.Nil(t, err) ensureConnValid(t, pgConn) } @@ -269,10 +279,20 @@ func TestConnExecMultipleQueries(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec(context.Background(), "select current_database(); select 1") - require.Nil(t, err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "1", string(result.Rows[0][0])) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() + assert.Nil(t, err) + + assert.Len(t, results, 2) + + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) ensureConnValid(t, pgConn) } @@ -284,15 +304,18 @@ func TestConnExecMultipleQueriesError(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1") + results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() require.NotNil(t, err) - require.Nil(t, result) if pgErr, ok := err.(*pgconn.PgError); ok { assert.Equal(t, "22012", pgErr.Code) } else { t.Errorf("unexpected error: %v", err) } + assert.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) + ensureConnValid(t, pgConn) } @@ -305,11 +328,12 @@ func TestConnExecContextCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") - assert.Nil(t, result) - assert.Equal(t, context.DeadlineExceeded, err) + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)") - assert.True(t, pgConn.RecoverFromTimeout(context.Background())) + for multiResult.NextResult() { + } + err = multiResult.Close() + assert.Equal(t, context.DeadlineExceeded, err) ensureConnValid(t, pgConn) } @@ -321,10 +345,16 @@ func TestConnExecParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) - require.Nil(t, err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "Hello, world", string(result.Rows[0][0])) + result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) ensureConnValid(t, pgConn) } @@ -338,12 +368,16 @@ func TestConnExecParamsCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - result, err := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil) - assert.Nil(t, result) + result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + } + assert.Equal(t, 0, rowCount) + commandTag, err := result.Close() + assert.Nil(t, commandTag) assert.Equal(t, context.DeadlineExceeded, err) - assert.True(t, pgConn.RecoverFromTimeout(context.Background())) - ensureConnValid(t, pgConn) } @@ -360,10 +394,16 @@ func TestConnExecPrepared(t *testing.T) { assert.Len(t, psd.ParamOIDs, 1) assert.Len(t, psd.Fields, 1) - result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) - require.Nil(t, err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "Hello, world", string(result.Rows[0][0])) + result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) ensureConnValid(t, pgConn) } @@ -380,16 +420,20 @@ func TestConnExecPreparedCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - result, err := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil) - assert.Nil(t, result) + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + } + assert.Equal(t, 0, rowCount) + commandTag, err := result.Close() + assert.Nil(t, commandTag) assert.Equal(t, context.DeadlineExceeded, err) - assert.True(t, pgConn.RecoverFromTimeout(context.Background())) - ensureConnValid(t, pgConn) } -func TestConnBatchedQueries(t *testing.T) { +func TestConnExecBatch(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) @@ -399,160 +443,26 @@ func TestConnBatchedQueries(t *testing.T) { _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) require.Nil(t, err) - pgConn.SendExec("select 'SendExec 1'") - pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 1")}, nil, nil, nil) - pgConn.SendExecPrepared("ps1", [][]byte{[]byte("SendExecPrepared 1")}, nil, nil) - pgConn.SendExec("select 'SendExec 2'") - pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 2")}, nil, nil, nil) - err = pgConn.Flush(context.Background()) + batch := &pgconn.Batch{} - // "select 'SendExec 1'" - require.True(t, pgConn.NextResult(context.Background())) - resultReader := pgConn.ResultReader() - - rows := [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExec 1", string(rows[0][0])) - - commandTag, err := resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // "SendExecParams 1" - require.True(t, pgConn.NextResult(context.Background())) - resultReader = pgConn.ResultReader() - - rows = [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExecParams 1", string(rows[0][0])) - - commandTag, err = resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // "SendExecPrepared 1" - require.True(t, pgConn.NextResult(context.Background())) - resultReader = pgConn.ResultReader() - - rows = [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExecPrepared 1", string(rows[0][0])) - - commandTag, err = resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // "SendExec 2" - require.True(t, pgConn.NextResult(context.Background())) - resultReader = pgConn.ResultReader() - - rows = [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExec 2", string(rows[0][0])) - - commandTag, err = resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // "SendExecParams 2" - require.True(t, pgConn.NextResult(context.Background())) - resultReader = pgConn.ResultReader() - - rows = [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExecParams 2", string(rows[0][0])) - - commandTag, err = resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // Done - require.False(t, pgConn.NextResult(context.Background())) - - ensureConnValid(t, pgConn) -} - -func TestConnRecoverFromTimeout(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() require.Nil(t, err) - defer closeConn(t, pgConn) + require.Len(t, results, 3) - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") - cancel() - require.Nil(t, result) - assert.Equal(t, context.DeadlineExceeded, err) + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - if assert.True(t, pgConn.RecoverFromTimeout(ctx)) { - result, err := pgConn.Exec(ctx, "select 1") - require.Nil(t, err) - assert.Len(t, result.Rows, 1) - assert.Len(t, result.Rows[0], 1) - assert.Equal(t, "1", string(result.Rows[0][0])) - } - cancel() + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - ensureConnValid(t, pgConn) -} - -func TestConnCancelQuery(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) - defer closeConn(t, pgConn) - - pgConn.SendExec("select current_database(), pg_sleep(5)") - err = pgConn.Flush(context.Background()) - require.Nil(t, err) - - err = pgConn.CancelRequest(context.Background()) - require.Nil(t, err) - - require.True(t, pgConn.NextResult(context.Background())) - _, err = pgConn.ResultReader().Close() - if err, ok := err.(*pgconn.PgError); ok { - assert.Equal(t, "57014", err.Code) - } else { - t.Errorf("expected pgconn.PgError got %v", err) - } - - require.False(t, pgConn.NextResult(context.Background())) - - ensureConnValid(t, pgConn) + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } func TestCommandTag(t *testing.T) { @@ -593,10 +503,11 @@ func TestConnOnNotice(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), `do $$ + multiResult := pgConn.Exec(context.Background(), `do $$ begin raise notice 'hello, world'; end$$;`) + err = multiResult.Close() require.Nil(t, err) assert.Equal(t, "hello, world", msg) From 04ee3b8cbd64e2acbab4ceff2b1369677f3cb2d5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 17:41:43 -0600 Subject: [PATCH 0185/1158] Remove Pg prefix for a couple types --- pgconn.go | 60 +++++++++++++++++++++++++++---------------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/pgconn.go b/pgconn.go index cfacc7bb..d4563086 100644 --- a/pgconn.go +++ b/pgconn.go @@ -536,8 +536,8 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error { // statements. // // Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. -func (pgConn *PgConn) Exec(ctx context.Context, sql string) *PgMultiResult { - multiResult := &PgMultiResult{ +func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResult { + multiResult := &MultiResult{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, @@ -593,8 +593,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *PgMultiResult { // binary format. If resultFormats is nil all results will be in text protocol. // // Result must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *PgResult { - result := &PgResult{ +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *Result { + result := &Result{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, @@ -649,8 +649,8 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] // binary format. If resultFormats is nil all results will be in text protocol. // // Result must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *PgResult { - result := &PgResult{ +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *Result { + result := &Result{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, @@ -689,18 +689,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } -type PgMultiResult struct { +type MultiResult struct { pgConn *PgConn ctx context.Context cleanupContextDeadline func() - pgResult *PgResult + pgResult *Result closed bool err error } -func (mr *PgMultiResult) ReadAll() ([]*BufferedResult, error) { +func (mr *MultiResult) ReadAll() ([]*BufferedResult, error) { var results []*BufferedResult for mr.NextResult() { @@ -711,7 +711,7 @@ func (mr *PgMultiResult) ReadAll() ([]*BufferedResult, error) { return results, err } -func (mr *PgMultiResult) receiveMessage() (pgproto3.BackendMessage, error) { +func (mr *MultiResult) receiveMessage() (pgproto3.BackendMessage, error) { msg, err := mr.pgConn.ReceiveMessage() if err != nil { @@ -740,8 +740,8 @@ func (mr *PgMultiResult) receiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } -// NextResult returns advances the PgMultiResult to the next result and returns true if a result is available. -func (mr *PgMultiResult) NextResult() bool { +// NextResult returns advances the MultiResult to the next result and returns true if a result is available. +func (mr *MultiResult) NextResult() bool { for !mr.closed && mr.err == nil { msg, err := mr.receiveMessage() if err != nil { @@ -750,7 +750,7 @@ func (mr *PgMultiResult) NextResult() bool { switch msg := msg.(type) { case *pgproto3.RowDescription: - mr.pgResult = &PgResult{ + mr.pgResult = &Result{ pgConn: mr.pgConn, pgMultiResult: mr, ctx: mr.ctx, @@ -759,7 +759,7 @@ func (mr *PgMultiResult) NextResult() bool { } return true case *pgproto3.CommandComplete: - mr.pgResult = &PgResult{ + mr.pgResult = &Result{ commandTag: CommandTag(msg.CommandTag), commandConcluded: true, closed: true, @@ -773,11 +773,11 @@ func (mr *PgMultiResult) NextResult() bool { return false } -func (mr *PgMultiResult) Result() *PgResult { +func (mr *MultiResult) Result() *Result { return mr.pgResult } -func (mr *PgMultiResult) Close() error { +func (mr *MultiResult) Close() error { for !mr.closed { _, err := mr.receiveMessage() if err != nil { @@ -788,9 +788,9 @@ func (mr *PgMultiResult) Close() error { return mr.err } -type PgResult struct { +type Result struct { pgConn *PgConn - pgMultiResult *PgMultiResult + pgMultiResult *MultiResult ctx context.Context cleanupContextDeadline func() @@ -809,7 +809,7 @@ type BufferedResult struct { Err error } -func (rr *PgResult) ReadAll() *BufferedResult { +func (rr *Result) ReadAll() *BufferedResult { br := &BufferedResult{} for rr.NextRow() { @@ -828,8 +828,8 @@ func (rr *PgResult) ReadAll() *BufferedResult { return br } -// NextRow advances the PgResult to the next row and returns true if a row is available. -func (rr *PgResult) NextRow() bool { +// NextRow advances the Result to the next row and returns true if a row is available. +func (rr *Result) NextRow() bool { for !rr.commandConcluded { msg, err := rr.receiveMessage() if err != nil { @@ -847,21 +847,21 @@ func (rr *PgResult) NextRow() bool { } // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until -// the PgResult is closed. -func (rr *PgResult) FieldDescriptions() []pgproto3.FieldDescription { +// the Result is closed. +func (rr *Result) FieldDescriptions() []pgproto3.FieldDescription { return rr.fieldDescriptions } // Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only -// valid until the next NextRow call or the PgResult is closed. However, the underlying byte data is safe to +// valid until the next NextRow call or the Result is closed. However, the underlying byte data is safe to // retain a reference to and mutate. -func (rr *PgResult) Values() [][]byte { +func (rr *Result) Values() [][]byte { return rr.rowValues } // Close consumes any remaining result data and returns the command tag or // error. -func (rr *PgResult) Close() (CommandTag, error) { +func (rr *Result) Close() (CommandTag, error) { if rr.closed { return rr.commandTag, rr.err } @@ -893,7 +893,7 @@ func (rr *PgResult) Close() (CommandTag, error) { return rr.commandTag, rr.err } -func (rr *PgResult) receiveMessage() (msg pgproto3.BackendMessage, err error) { +func (rr *Result) receiveMessage() (msg pgproto3.BackendMessage, err error) { if rr.pgMultiResult == nil { msg, err = rr.pgConn.ReceiveMessage() } else { @@ -927,7 +927,7 @@ func (rr *PgResult) receiveMessage() (msg pgproto3.BackendMessage, err error) { return msg, nil } -func (rr *PgResult) concludeCommand(commandTag CommandTag, err error) { +func (rr *Result) concludeCommand(commandTag CommandTag, err error) { if rr.commandConcluded { return } @@ -1006,8 +1006,8 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) } -func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *PgMultiResult { - multiResult := &PgMultiResult{ +func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResult { + multiResult := &MultiResult{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, From 379be3508b5e79eba7dcd7ac4a47f80cfeba8058 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 17:46:47 -0600 Subject: [PATCH 0186/1158] Add some docs for batch --- pgconn.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgconn.go b/pgconn.go index d4563086..09d87b31 100644 --- a/pgconn.go +++ b/pgconn.go @@ -988,6 +988,7 @@ func (pgConn *PgConn) recoverFromTimeout() { } } +// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte } @@ -1006,6 +1007,7 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) } +// ExecBatch executes all the queries in batch in a single round-trip. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResult { multiResult := &MultiResult{ pgConn: pgConn, From 2c8971b38263182f5644c7c6f65ec19a89e6f428 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 18:01:57 -0600 Subject: [PATCH 0187/1158] Rename some types and methods --- benchmark_test.go | 4 +- config.go | 2 +- helper_test.go | 2 +- pgconn.go | 126 +++++++++++++++++++++--------------------- pgconn_stress_test.go | 4 +- pgconn_test.go | 6 +- 6 files changed, 72 insertions(+), 72 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index ffb1455c..d2576324 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -76,7 +76,7 @@ func BenchmarkExecPrepared(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - result := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil).ReadAll() + result := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil).Read() require.Nil(b, result.Err) } } @@ -95,7 +95,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - result := conn.ExecPrepared(ctx, "ps1", nil, nil, nil).ReadAll() + result := conn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() require.Nil(b, result.Err) } } diff --git a/config.go b/config.go index fb0719cd..b85bcaec 100644 --- a/config.go +++ b/config.go @@ -470,7 +470,7 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { // AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible // target_session_attrs=read-write. func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).ReadAll() + result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() if result.Err != nil { return result.Err } diff --git a/helper_test.go b/helper_test.go index a50f7cb1..c5ac6e01 100644 --- a/helper_test.go +++ b/helper_test.go @@ -20,7 +20,7 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) { // Do a simple query to ensure the connection is still usable func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) - result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).ReadAll() + result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read() cancel() require.Nil(t, result.Err) diff --git a/pgconn.go b/pgconn.go index 09d87b31..be7d37ae 100644 --- a/pgconn.go +++ b/pgconn.go @@ -536,8 +536,8 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error { // statements. // // Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. -func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResult { - multiResult := &MultiResult{ +func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { + multiResult := &MultiResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, @@ -592,9 +592,9 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResult { // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or // binary format. If resultFormats is nil all results will be in text protocol. // -// Result must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *Result { - result := &Result{ +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { + result := &ResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, @@ -648,9 +648,9 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or // binary format. If resultFormats is nil all results will be in text protocol. // -// Result must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *Result { - result := &Result{ +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { + result := &ResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, @@ -689,77 +689,77 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } -type MultiResult struct { +type MultiResultReader struct { pgConn *PgConn ctx context.Context cleanupContextDeadline func() - pgResult *Result + rr *ResultReader closed bool err error } -func (mr *MultiResult) ReadAll() ([]*BufferedResult, error) { - var results []*BufferedResult +func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { + var results []*Result - for mr.NextResult() { - results = append(results, mr.Result().ReadAll()) + for mrr.NextResult() { + results = append(results, mrr.ResultReader().Read()) } - err := mr.Close() + err := mrr.Close() return results, err } -func (mr *MultiResult) receiveMessage() (pgproto3.BackendMessage, error) { - msg, err := mr.pgConn.ReceiveMessage() +func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := mrr.pgConn.ReceiveMessage() if err != nil { - mr.cleanupContextDeadline() - mr.err = preferContextOverNetTimeoutError(mr.ctx, err) - mr.closed = true + mrr.cleanupContextDeadline() + mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) + mrr.closed = true if err, ok := err.(net.Error); ok && err.Timeout() { - go mr.pgConn.recoverFromTimeout() + go mrr.pgConn.recoverFromTimeout() } else { - <-mr.pgConn.controller + <-mrr.pgConn.controller } - return nil, mr.err + return nil, mrr.err } switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - mr.cleanupContextDeadline() - mr.closed = true - <-mr.pgConn.controller + mrr.cleanupContextDeadline() + mrr.closed = true + <-mrr.pgConn.controller case *pgproto3.ErrorResponse: - mr.err = errorResponseToPgError(msg) + mrr.err = errorResponseToPgError(msg) } return msg, nil } -// NextResult returns advances the MultiResult to the next result and returns true if a result is available. -func (mr *MultiResult) NextResult() bool { - for !mr.closed && mr.err == nil { - msg, err := mr.receiveMessage() +// NextResult returns advances the MultiResultReader to the next result and returns true if a result is available. +func (mrr *MultiResultReader) NextResult() bool { + for !mrr.closed && mrr.err == nil { + msg, err := mrr.receiveMessage() if err != nil { return false } switch msg := msg.(type) { case *pgproto3.RowDescription: - mr.pgResult = &Result{ - pgConn: mr.pgConn, - pgMultiResult: mr, - ctx: mr.ctx, + mrr.rr = &ResultReader{ + pgConn: mrr.pgConn, + multiResultReader: mrr, + ctx: mrr.ctx, cleanupContextDeadline: func() {}, fieldDescriptions: msg.Fields, } return true case *pgproto3.CommandComplete: - mr.pgResult = &Result{ + mrr.rr = &ResultReader{ commandTag: CommandTag(msg.CommandTag), commandConcluded: true, closed: true, @@ -773,24 +773,24 @@ func (mr *MultiResult) NextResult() bool { return false } -func (mr *MultiResult) Result() *Result { - return mr.pgResult +func (mrr *MultiResultReader) ResultReader() *ResultReader { + return mrr.rr } -func (mr *MultiResult) Close() error { - for !mr.closed { - _, err := mr.receiveMessage() +func (mrr *MultiResultReader) Close() error { + for !mrr.closed { + _, err := mrr.receiveMessage() if err != nil { - return mr.err + return mrr.err } } - return mr.err + return mrr.err } -type Result struct { +type ResultReader struct { pgConn *PgConn - pgMultiResult *MultiResult + multiResultReader *MultiResultReader ctx context.Context cleanupContextDeadline func() @@ -802,15 +802,15 @@ type Result struct { err error } -type BufferedResult struct { +type Result struct { FieldDescriptions []pgproto3.FieldDescription Rows [][][]byte CommandTag CommandTag Err error } -func (rr *Result) ReadAll() *BufferedResult { - br := &BufferedResult{} +func (rr *ResultReader) Read() *Result { + br := &Result{} for rr.NextRow() { if br.FieldDescriptions == nil { @@ -828,8 +828,8 @@ func (rr *Result) ReadAll() *BufferedResult { return br } -// NextRow advances the Result to the next row and returns true if a row is available. -func (rr *Result) NextRow() bool { +// NextRow advances the ResultReader to the next row and returns true if a row is available. +func (rr *ResultReader) NextRow() bool { for !rr.commandConcluded { msg, err := rr.receiveMessage() if err != nil { @@ -847,21 +847,21 @@ func (rr *Result) NextRow() bool { } // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until -// the Result is closed. -func (rr *Result) FieldDescriptions() []pgproto3.FieldDescription { +// the ResultReader is closed. +func (rr *ResultReader) FieldDescriptions() []pgproto3.FieldDescription { return rr.fieldDescriptions } // Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only -// valid until the next NextRow call or the Result is closed. However, the underlying byte data is safe to +// valid until the next NextRow call or the ResultReader is closed. However, the underlying byte data is safe to // retain a reference to and mutate. -func (rr *Result) Values() [][]byte { +func (rr *ResultReader) Values() [][]byte { return rr.rowValues } // Close consumes any remaining result data and returns the command tag or // error. -func (rr *Result) Close() (CommandTag, error) { +func (rr *ResultReader) Close() (CommandTag, error) { if rr.closed { return rr.commandTag, rr.err } @@ -874,7 +874,7 @@ func (rr *Result) Close() (CommandTag, error) { } } - if rr.pgMultiResult == nil { + if rr.multiResultReader == nil { for { msg, err := rr.receiveMessage() if err != nil { @@ -893,18 +893,18 @@ func (rr *Result) Close() (CommandTag, error) { return rr.commandTag, rr.err } -func (rr *Result) receiveMessage() (msg pgproto3.BackendMessage, err error) { - if rr.pgMultiResult == nil { +func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { + if rr.multiResultReader == nil { msg, err = rr.pgConn.ReceiveMessage() } else { - msg, err = rr.pgMultiResult.receiveMessage() + msg, err = rr.multiResultReader.receiveMessage() } if err != nil { rr.concludeCommand(nil, err) rr.cleanupContextDeadline() rr.closed = true - if rr.pgMultiResult == nil { + if rr.multiResultReader == nil { if err, ok := err.(net.Error); ok && err.Timeout() { go rr.pgConn.recoverFromTimeout() } else { @@ -927,7 +927,7 @@ func (rr *Result) receiveMessage() (msg pgproto3.BackendMessage, err error) { return msg, nil } -func (rr *Result) concludeCommand(commandTag CommandTag, err error) { +func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { if rr.commandConcluded { return } @@ -1008,8 +1008,8 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor } // ExecBatch executes all the queries in batch in a single round-trip. -func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResult { - multiResult := &MultiResult{ +func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { + multiResult := &MultiResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index 17d344b7..6b5efd9f 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -70,7 +70,7 @@ func stressExecSelect(pgConn *pgconn.PgConn) error { } func stressExecParamsSelect(pgConn *pgconn.PgConn) error { - result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).ReadAll() + result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() return result.Err } @@ -96,7 +96,7 @@ func stressExecSelectCanceled(pgConn *pgconn.PgConn) error { func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).ReadAll() + result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() cancel() if result.Err != context.DeadlineExceeded { return result.Err diff --git a/pgconn_test.go b/pgconn_test.go index a2eb7838..a524d18f 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -134,12 +134,12 @@ func TestConnectWithRuntimeParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, conn) - result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).ReadAll() + result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() require.Nil(t, result.Err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).ReadAll() + result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read() require.Nil(t, result.Err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "myschema", string(result.Rows[0][0])) @@ -263,7 +263,7 @@ func TestConnExecEmpty(t *testing.T) { resultCount := 0 for multiResult.NextResult() { resultCount += 1 - multiResult.Result().Close() + multiResult.ResultReader().Close() } assert.Equal(t, 0, resultCount) err = multiResult.Close() From 2959411c419c147d5eef0d1d8ae14b611b7850ac Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 18:06:00 -0600 Subject: [PATCH 0188/1158] CommandTag is string --- pgconn.go | 22 +++++++++------------- pgconn_test.go | 4 ++-- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/pgconn.go b/pgconn.go index be7d37ae..7cf3c91d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -347,7 +347,7 @@ func (pgConn *PgConn) ParameterStatus(key string) string { } // CommandTag is the result of an Exec function -type CommandTag []byte +type CommandTag string // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. @@ -361,10 +361,6 @@ func (ct CommandTag) RowsAffected() int64 { return n } -func (ct CommandTag) String() string { - return string(ct) -} - // preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == // true. Otherwise returns err. func preferContextOverNetTimeoutError(ctx context.Context, err error) error { @@ -602,7 +598,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] select { case <-ctx.Done(): - result.concludeCommand(nil, ctx.Err()) + result.concludeCommand("", ctx.Err()) result.closed = true return result case pgConn.controller <- result: @@ -628,7 +624,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] pgConn.closed = true } - result.concludeCommand(nil, err) + result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true <-pgConn.controller @@ -658,7 +654,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa select { case <-ctx.Done(): - result.concludeCommand(nil, ctx.Err()) + result.concludeCommand("", ctx.Err()) result.closed = true return result case pgConn.controller <- result: @@ -680,7 +676,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa pgConn.closed = true } - result.concludeCommand(nil, err) + result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true <-pgConn.controller @@ -870,7 +866,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for !rr.commandConcluded { _, err := rr.receiveMessage() if err != nil { - return nil, rr.err + return "", rr.err } } @@ -878,7 +874,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for { msg, err := rr.receiveMessage() if err != nil { - return nil, rr.err + return "", rr.err } switch msg.(type) { @@ -901,7 +897,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } if err != nil { - rr.concludeCommand(nil, err) + rr.concludeCommand("", err) rr.cleanupContextDeadline() rr.closed = true if rr.multiResultReader == nil { @@ -921,7 +917,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.CommandComplete: rr.concludeCommand(CommandTag(msg.CommandTag), nil) case *pgproto3.ErrorResponse: - rr.concludeCommand(nil, errorResponseToPgError(msg)) + rr.concludeCommand("", errorResponseToPgError(msg)) } return msg, nil diff --git a/pgconn_test.go b/pgconn_test.go index a524d18f..a63aee38 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -375,7 +375,7 @@ func TestConnExecParamsCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Nil(t, commandTag) + assert.Equal(t, pgconn.CommandTag(""), commandTag) assert.Equal(t, context.DeadlineExceeded, err) ensureConnValid(t, pgConn) @@ -427,7 +427,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Nil(t, commandTag) + assert.Equal(t, pgconn.CommandTag(""), commandTag) assert.Equal(t, context.DeadlineExceeded, err) ensureConnValid(t, pgConn) From 406e95650a8823d98074dd3e08bdcf097e4b50cc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 18:40:33 -0600 Subject: [PATCH 0189/1158] Add more docs --- config.go | 35 +++++++++++++++++++---------------- doc.go | 29 +++++++++++++++++++++++++++++ pgconn.go | 10 +++++++++- 3 files changed, 57 insertions(+), 17 deletions(-) create mode 100644 doc.go diff --git a/config.go b/config.go index b85bcaec..13167729 100644 --- a/config.go +++ b/config.go @@ -70,32 +70,35 @@ func NetworkAddress(host string, port uint16) (network, address string) { // It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the // .pgpass file. // -// Example DSN: "user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca" +// # Example DSN +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca // -// Example URL: "postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca" +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca // // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated // values that will be tried in order. This can be used as part of a high availability system. See // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. // -// Example URL: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb" +// # Example URL +// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb // // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed // via database URL or DSN: // -// PGHOST -// PGPORT -// PGDATABASE -// PGUSER -// PGPASSWORD -// PGPASSFILE -// PGSSLMODE -// PGSSLCERT -// PGSSLKEY -// PGSSLROOTCERT -// PGAPPNAME -// PGCONNECT_TIMEOUT -// PGTARGETSESSIONATTRS +// PGHOST +// PGPORT +// PGDATABASE +// PGUSER +// PGPASSWORD +// PGPASSFILE +// PGSSLMODE +// PGSSLCERT +// PGSSLKEY +// PGSSLROOTCERT +// PGAPPNAME +// PGCONNECT_TIMEOUT +// PGTARGETSESSIONATTRS // // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. // diff --git a/doc.go b/doc.go new file mode 100644 index 00000000..89e47536 --- /dev/null +++ b/doc.go @@ -0,0 +1,29 @@ +// Package pgconn is a low-level PostgreSQL database driver. +/* +pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at +nearly the same level is the C library libpq. + +Establishing a Connection + +Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for +libpq style environment variables. + +Executing a Query + +ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method +reads all rows into memory. + +Executing Multiple Queries in a Single Round Trip + +Exec and ExecBatch can execute multiple queries in a single round trip. The return readers that iterate over each query +result. The ReadAll method reads all query results into memory. + +Context Support + +All potentially blocking operations take a context.Context. If a context is canceled while a query is in progress the +method immediately returns. In the background a cancel request will be sent to the PostgreSQL server. If the +cancellation fails or hangs for more than a short time (approximately 15 seconds) the connection will be closed. It is +safe to use the connection while this background cancellation is in progress. Any calls will block until the +cancellation and resynchronization is complete (and those calls can be aborted by a context cancellation). +*/ +package pgconn diff --git a/pgconn.go b/pgconn.go index 7cf3c91d..bab4370a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -685,6 +685,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } +// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { pgConn *PgConn ctx context.Context @@ -696,6 +697,7 @@ type MultiResultReader struct { err error } +// ReadAll reads all available results. Calling ReadAll is mutually exclusive with all other MultiResultReader methods. func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { var results []*Result @@ -769,10 +771,12 @@ func (mrr *MultiResultReader) NextResult() bool { return false } +// ResultReader returns the current ResultReader. func (mrr *MultiResultReader) ResultReader() *ResultReader { return mrr.rr } +// Close closes the MultiResultReader and returns the first error that occurred during the MultiResultReader's use. func (mrr *MultiResultReader) Close() error { for !mrr.closed { _, err := mrr.receiveMessage() @@ -784,6 +788,7 @@ func (mrr *MultiResultReader) Close() error { return mrr.err } +// ResultReader is a reader for the result of a single query. type ResultReader struct { pgConn *PgConn multiResultReader *MultiResultReader @@ -798,6 +803,7 @@ type ResultReader struct { err error } +// Result is the saved query response that is returned by calling Read on a ResultReader. type Result struct { FieldDescriptions []pgproto3.FieldDescription Rows [][][]byte @@ -805,6 +811,7 @@ type Result struct { Err error } +// Read saves the query response to a Result. func (rr *ResultReader) Read() *Result { br := &Result{} @@ -1003,7 +1010,8 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) } -// ExecBatch executes all the queries in batch in a single round-trip. +// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a +// transaction is already in progress or SQL contains transaction control statements. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { multiResult := &MultiResultReader{ pgConn: pgConn, From c6a73a469a84661171e31116a8228e54f3f52aa6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 18:47:50 -0600 Subject: [PATCH 0190/1158] Add example --- pgconn_test.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index a63aee38..2d8cc784 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -3,6 +3,8 @@ package pgconn_test import ( "context" "crypto/tls" + "fmt" + "log" "net" "os" "testing" @@ -513,3 +515,27 @@ end$$;`) ensureConnValid(t, pgConn) } + +func Example() { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + log.Fatalln(err) + } + defer pgConn.Close(context.Background()) + + result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() + if result.Err != nil { + log.Fatalln(result.Err) + } + + for _, row := range result.Rows { + fmt.Println(string(row[0])) + } + + fmt.Println(result.CommandTag) + // Output: + // 1 + // 2 + // 3 + // SELECT 3 +} From bd777fe20c73cf2eea37d2ada1a62164f0074bd1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Jan 2019 11:37:13 -0600 Subject: [PATCH 0191/1158] Add custom context cancellation hook --- config.go | 14 ++++++++++- pgconn.go | 38 +++++++++++++++++++++++++++- pgconn_test.go | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 13167729..40cbd0bb 100644 --- a/config.go +++ b/config.go @@ -41,7 +41,19 @@ type Config struct { // allows implementing high availability behavior such as libpq does with target_session_attrs. AfterConnectFunc AfterConnectFunc - OnNotice NoticeHandler // Callback function called when a notice response is received. + // OnContextCancel is a callback function used to override cancellation behavior. It is called when a context.Context + // is canceled. Default cancellation behavior is to establish another connection to the PostgreSQL server and send a + // query cancel request. Some non-PostgreSQL servers (e.g. CockroachDB) that speak a subset of the PostgreSQL wire + // protocol do not support this cancellation method. + // + // It is called from a background goroutine. When the cancellation process has finished ContextCancel.Finish must be + // called whether it was successful or not. If an error occurs the connection should be closed. The connection must be + // in a ready for query state or be closed when ContextCancel.Finish is called. Use PgConn.ReceiveMessage() to read + // the connection until a ready for query message is received. + OnContextCancel func(*ContextCancel) + + // OnNotice is a callback function called when a notice response is received. + OnNotice NoticeHandler } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a diff --git a/pgconn.go b/pgconn.go index bab4370a..08fce16e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -527,6 +527,22 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error { return nil } +// WaitUntilReady waits until a previous context cancellation has been competed processed and the connection is ready +// for use. This is done automatically by all methods that need the connection to be ready for use. The only expected +// use for this method is for a connection pool to wait for a returned connection to be usable again before making it +// available. +func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case pgConn.controller <- pgConn: + // The connection must be ready since it was locked. Immediately unlock it. + <-pgConn.controller + } + + return nil +} + // Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is // implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control // statements. @@ -942,7 +958,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { rr.commandConcluded = true } -func (pgConn *PgConn) recoverFromTimeout() { +func (pgConn *PgConn) defaultCancel() { // Regardless of recovery outcome the lock on the pgConn must be released. defer func() { <-pgConn.controller }() @@ -991,6 +1007,26 @@ func (pgConn *PgConn) recoverFromTimeout() { } } +type ContextCancel struct { + PgConn *PgConn +} + +// Finish must be called when the cancellation request has finished processing. The connection must be in a ready for +// query state or the connection must be closed. This must be called regardless of the success of the cancellation and +// whether the connection is still valid or not. It releases an internal busy lock on the connection. +func (cc *ContextCancel) Finish() { + <-cc.PgConn.controller +} + +func (pgConn *PgConn) recoverFromTimeout() { + if pgConn.Config.OnContextCancel == nil { + pgConn.defaultCancel() + } else { + cc := &ContextCancel{PgConn: pgConn} + pgConn.Config.OnContextCancel(cc) + } +} + // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte diff --git a/pgconn_test.go b/pgconn_test.go index 2d8cc784..9452ffc0 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgx/pgproto3" "github.com/pkg/errors" "github.com/stretchr/testify/assert" @@ -490,6 +491,72 @@ func TestCommandTag(t *testing.T) { } } +func TestConnContextCancelWithOnContextCancel(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + calledChan := make(chan struct{}) + + config.OnContextCancel = func(cc *pgconn.ContextCancel) { + defer cc.Finish() + close(calledChan) + + for { + msg, err := cc.PgConn.ReceiveMessage() + if err != nil { + cc.PgConn.Close(context.Background()) + return + } + + switch msg.(type) { + case *pgproto3.ReadyForQuery: + return + } + } + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result := pgConn.ExecParams(ctx, "select 'Hello, world', pg_sleep(0.25)", nil, nil, nil, nil) + _, err = result.Close() + assert.Equal(t, context.DeadlineExceeded, err) + + called := false + select { + case <-calledChan: + called = true + case <-time.NewTimer(time.Second).C: + } + + assert.True(t, called) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitUntilReady(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil).Read() + assert.Equal(t, context.DeadlineExceeded, result.Err) + + err = pgConn.WaitUntilReady(context.Background()) + require.Nil(t, err) + + ensureConnValid(t, pgConn) +} + func TestConnOnNotice(t *testing.T) { t.Parallel() From 9c36fa1e5038662693788b43752bceae00f00417 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Jan 2019 15:38:20 -0600 Subject: [PATCH 0192/1158] Fix prepare failure --- pgconn.go | 9 +++++++-- pgconn_test.go | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index 08fce16e..2a3c5936 100644 --- a/pgconn.go +++ b/pgconn.go @@ -438,6 +438,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ psd := &PreparedStatementDescription{Name: name, SQL: sql} + var parseErr error + readloop: for { msg, err := pgConn.ReceiveMessage() @@ -454,14 +456,17 @@ readloop: psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: - go pgConn.recoverFromTimeout() - return nil, errorResponseToPgError(msg) + parseErr = errorResponseToPgError(msg) case *pgproto3.ReadyForQuery: break readloop } } <-pgConn.controller + + if parseErr != nil { + return nil, parseErr + } return psd, nil } diff --git a/pgconn_test.go b/pgconn_test.go index 9452ffc0..90f99325 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -235,6 +235,20 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } } +func TestConnPrepareFailure(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) + + ensureConnValid(t, pgConn) +} + func TestConnExec(t *testing.T) { t.Parallel() From b3cde6830f0ae451d08f0854a421cc538e2d2e6e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Jan 2019 16:17:03 -0600 Subject: [PATCH 0193/1158] Fix die on receive message error --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 2a3c5936..9277d4a8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -263,7 +263,7 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { msg, err := pgConn.Frontend.Receive() if err != nil { // Close on anything other than timeout error - everything else is fatal - if err, ok := err.(net.Error); !ok && err.Timeout() { + if err, ok := err.(net.Error); !(ok && err.Timeout()) { pgConn.hardClose() } From cd4b0025c3d322ac21fe481a8e55a0743e37f27b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 14 Jan 2019 20:39:10 -0600 Subject: [PATCH 0194/1158] Add listen/notify to pgconn --- config.go | 3 +++ pgconn.go | 17 +++++++++++++++++ pgconn_test.go | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+) diff --git a/config.go b/config.go index 40cbd0bb..fec1fedf 100644 --- a/config.go +++ b/config.go @@ -54,6 +54,9 @@ type Config struct { // OnNotice is a callback function called when a notice response is received. OnNotice NoticeHandler + + // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. + OnNotification NotificationHandler } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a diff --git a/pgconn.go b/pgconn.go index 9277d4a8..b2ffe7ca 100644 --- a/pgconn.go +++ b/pgconn.go @@ -50,6 +50,13 @@ func (pe *PgError) Error() string { // LISTEN/NOTIFY notification. type Notice PgError +// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system +type Notification struct { + PID uint32 // backend pid that sent the notification + Channel string // channel from which notification was received + Payload string +} + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) @@ -59,6 +66,12 @@ type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // notification. type NoticeHandler func(*PgConn, *Notice) +// NotificationHandler is a function that can handle notifications received from the PostgreSQL server. Notifications +// can be received at any time, usually during handling of a query response. The *PgConn is provided so the handler is +// aware of the origin of the notice, but it must not invoke any query method. Be aware that this is distinct from a +// notice event. +type NotificationHandler func(*PgConn, *Notification) + // ErrTLSRefused occurs when the connection attempt requires TLS and the // PostgreSQL server refuses to use TLS var ErrTLSRefused = errors.New("server refused TLS connection") @@ -284,6 +297,10 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { if pgConn.Config.OnNotice != nil { pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg)) } + case *pgproto3.NotificationResponse: + if pgConn.Config.OnNotification != nil { + pgConn.Config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) + } } return msg, nil diff --git a/pgconn_test.go b/pgconn_test.go index 90f99325..ad538257 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -597,6 +597,38 @@ end$$;`) ensureConnValid(t, pgConn) } +func TestConnOnNotification(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.Nil(t, err) + + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.Nil(t, err) + + _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() + require.Nil(t, err) + + assert.Equal(t, "bar", msg) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { From edfd837ba4192c55770f6a18bd1fcfb49ed07f4f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 14 Jan 2019 20:51:53 -0600 Subject: [PATCH 0195/1158] Add PgConn.WaitForNotification --- pgconn.go | 25 +++++++++++++++++++++++++ pgconn_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/pgconn.go b/pgconn.go index b2ffe7ca..efd7686f 100644 --- a/pgconn.go +++ b/pgconn.go @@ -565,6 +565,31 @@ func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error { return nil } +// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not +// received. +func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case pgConn.controller <- pgConn: + } + cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContextDeadline() + defer func() { <-pgConn.controller }() + + for { + msg, err := pgConn.ReceiveMessage() + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + switch msg.(type) { + case *pgproto3.NotificationResponse: + return nil + } + } +} + // Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is // implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control // statements. diff --git a/pgconn_test.go b/pgconn_test.go index ad538257..07e54c75 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -629,6 +629,56 @@ func TestConnOnNotification(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnWaitForNotification(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.Nil(t, err) + + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.Nil(t, err) + + err = pgConn.WaitForNotification(context.Background()) + require.Nil(t, err) + + assert.Equal(t, "bar", msg) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitForNotificationTimeout(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + err = pgConn.WaitForNotification(ctx) + cancel() + require.Equal(t, context.DeadlineExceeded, err) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { From 66af2227c0d5c1d7af2412ddaebb04e0f974ee48 Mon Sep 17 00:00:00 2001 From: Josh Leverette Date: Thu, 17 Jan 2019 22:19:08 -0800 Subject: [PATCH 0196/1158] Fix encoding of ErrorResponse --- error_response.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/error_response.go b/error_response.go index 160234f2..987fe38a 100644 --- a/error_response.go +++ b/error_response.go @@ -115,70 +115,87 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { buf.Write(bigEndian.Uint32(0)) if src.Severity != "" { + buf.WriteByte('S') buf.WriteString(src.Severity) buf.WriteByte(0) } if src.Code != "" { + buf.WriteByte('C') buf.WriteString(src.Code) buf.WriteByte(0) } if src.Message != "" { + buf.WriteByte('M') buf.WriteString(src.Message) buf.WriteByte(0) } if src.Detail != "" { + buf.WriteByte('D') buf.WriteString(src.Detail) buf.WriteByte(0) } if src.Hint != "" { + buf.WriteByte('H') buf.WriteString(src.Hint) buf.WriteByte(0) } if src.Position != 0 { + buf.WriteByte('P') buf.WriteString(strconv.Itoa(int(src.Position))) buf.WriteByte(0) } if src.InternalPosition != 0 { + buf.WriteByte('p') buf.WriteString(strconv.Itoa(int(src.InternalPosition))) buf.WriteByte(0) } if src.InternalQuery != "" { + buf.WriteByte('q') buf.WriteString(src.InternalQuery) buf.WriteByte(0) } if src.Where != "" { + buf.WriteByte('W') buf.WriteString(src.Where) buf.WriteByte(0) } if src.SchemaName != "" { + buf.WriteByte('s') buf.WriteString(src.SchemaName) buf.WriteByte(0) } if src.TableName != "" { + buf.WriteByte('t') buf.WriteString(src.TableName) buf.WriteByte(0) } if src.ColumnName != "" { + buf.WriteByte('c') buf.WriteString(src.ColumnName) buf.WriteByte(0) } if src.DataTypeName != "" { + buf.WriteByte('d') buf.WriteString(src.DataTypeName) buf.WriteByte(0) } if src.ConstraintName != "" { + buf.WriteByte('n') buf.WriteString(src.ConstraintName) buf.WriteByte(0) } if src.File != "" { + buf.WriteByte('F') buf.WriteString(src.File) buf.WriteByte(0) } if src.Line != 0 { + buf.WriteByte('L') buf.WriteString(strconv.Itoa(int(src.Line))) buf.WriteByte(0) } if src.Routine != "" { + buf.WriteByte('R') buf.WriteString(src.Routine) buf.WriteByte(0) } From 738f3a1027b04d0b018d4adf0b9e1acc842a51b3 Mon Sep 17 00:00:00 2001 From: David Bariod Date: Tue, 15 Jan 2019 11:01:18 +0100 Subject: [PATCH 0197/1158] support binding of []int type to array integer --- int4_array.go | 19 +++++++++++++++++++ int4_array_test.go | 25 +++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/int4_array.go b/int4_array.go index 4e78ce71..86656524 100644 --- a/int4_array.go +++ b/int4_array.go @@ -23,6 +23,25 @@ func (dst *Int4Array) Set(src interface{}) error { switch value := src.(type) { + case []int: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int32: if value == nil { *dst = Int4Array{Status: Null} diff --git a/int4_array_test.go b/int4_array_test.go index 602a3657..f0418600 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "math" "reflect" "testing" @@ -54,8 +55,9 @@ func TestInt4ArrayTranscode(t *testing.T) { func TestInt4ArraySet(t *testing.T) { successfulTests := []struct { - source interface{} - result pgtype.Int4Array + source interface{} + result pgtype.Int4Array + expectedError bool }{ { source: []int32{1}, @@ -64,6 +66,17 @@ func TestInt4ArraySet(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, + { + source: []int{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int{1, math.MaxInt32 + 1, 2}, + expectedError: true, + }, { source: []uint32{1}, result: pgtype.Int4Array{ @@ -81,9 +94,17 @@ func TestInt4ArraySet(t *testing.T) { var r pgtype.Int4Array err := r.Set(tt.source) if err != nil { + if tt.expectedError { + continue + } t.Errorf("%d: %v", i, err) } + if tt.expectedError { + t.Errorf("%d: an error was expected, %v", i, tt) + continue + } + if !reflect.DeepEqual(r, tt.result) { t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } From e441d4828c13a21de6c8f96aa814ab0d119e639e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Jan 2019 14:49:26 -0600 Subject: [PATCH 0198/1158] Fix doc typo --- pgconn.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pgconn.go b/pgconn.go index efd7686f..13301364 100644 --- a/pgconn.go +++ b/pgconn.go @@ -549,10 +549,9 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error { return nil } -// WaitUntilReady waits until a previous context cancellation has been competed processed and the connection is ready -// for use. This is done automatically by all methods that need the connection to be ready for use. The only expected -// use for this method is for a connection pool to wait for a returned connection to be usable again before making it -// available. +// WaitUntilReady waits until a previous context cancellation has been completed and the connection is ready for use. +// This is done automatically by all methods that need the connection to be ready for use. The only expected use for +// this method is for a connection pool to wait for a returned connection to be usable again before making it available. func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error { select { case <-ctx.Done(): From 19ef57ad9a7392ef6cd4ae96665ec2e32d1caa0c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Jan 2019 14:49:39 -0600 Subject: [PATCH 0199/1158] Add PgConn.CopyTo --- pgconn.go | 65 ++++++++++++++++++++++++++++ pgconn_test.go | 112 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+) diff --git a/pgconn.go b/pgconn.go index 13301364..476cd046 100644 --- a/pgconn.go +++ b/pgconn.go @@ -747,6 +747,71 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } +// CopyTo executes the copy command sql and copies the results to w. +func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + case pgConn.controller <- pgConn: + } + cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + + // Send copy to command + var buf []byte + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + cleanupContextDeadline() + <-pgConn.controller + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + // Read results + var commandTag CommandTag + var pgErr error + for { + msg, err := pgConn.ReceiveMessage() + if err != nil { + cleanupContextDeadline() + if err, ok := err.(net.Error); ok && err.Timeout() { + go pgConn.recoverFromTimeout() + } else { + <-pgConn.controller + } + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.CopyDone: + case *pgproto3.CopyData: + _, err := w.Write(msg.Data) + if err != nil { + // This isn't actually a timeout, but we want the same behavior. Abort the request and cleanup. + cleanupContextDeadline() + go pgConn.recoverFromTimeout() + return "", err + } + case *pgproto3.ReadyForQuery: + <-pgConn.controller + return commandTag, pgErr + case *pgproto3.CommandComplete: + commandTag = CommandTag(msg.CommandTag) + case *pgproto3.ErrorResponse: + pgErr = errorResponseToPgError(msg) + } + } +} + // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { pgConn *PgConn diff --git a/pgconn_test.go b/pgconn_test.go index 07e54c75..ab7cfa72 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1,6 +1,7 @@ package pgconn_test import ( + "bytes" "context" "crypto/tls" "fmt" @@ -679,6 +680,117 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnCopyToSmall(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`).ReadAll() + require.Nil(t, err) + + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() + require.Nil(t, err) + + _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() + require.Nil(t, err) + + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.Nil(t, err) + + assert.Equal(t, int64(2), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToLarge(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`).ReadAll() + require.Nil(t, err) + + inputBytes := make([]byte, 0) + + for i := 0; i < 1000; i++ { + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() + require.Nil(t, err) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } + + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.Nil(t, err) + + assert.Equal(t, int64(1000), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToQueryError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + outputWriter := bytes.NewBuffer(make([]byte, 0)) + + res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + outputWriter := &bytes.Buffer{} + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") + assert.Equal(t, context.DeadlineExceeded, err) + assert.Equal(t, pgconn.CommandTag(""), res) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { From c447ff4e797dc10be183fed254cbed82c61cc4f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Jan 2019 14:51:07 -0600 Subject: [PATCH 0200/1158] Use NoError instead of Nil for assertions --- config_test.go | 10 ++-- pgconn_stress_test.go | 4 +- pgconn_test.go | 128 +++++++++++++++++++++--------------------- 3 files changed, 71 insertions(+), 71 deletions(-) diff --git a/config_test.go b/config_test.go index e7a5bb44..c7b65861 100644 --- a/config_test.go +++ b/config_test.go @@ -515,12 +515,12 @@ func TestParseConfigEnvLibpq(t *testing.T) { for i, tt := range tests { for _, n := range pgEnvvars { err := os.Unsetenv(n) - require.Nil(t, err) + require.NoError(t, err) } for k, v := range tt.envvars { err := os.Setenv(k, v) - require.Nil(t, err) + require.NoError(t, err) } config, err := pgconn.ParseConfig("") @@ -536,13 +536,13 @@ func TestParseConfigReadsPgPassfile(t *testing.T) { t.Parallel() tf, err := ioutil.TempFile("", "") - require.Nil(t, err) + require.NoError(t, err) defer tf.Close() defer os.Remove(tf.Name()) _, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk")) - require.Nil(t, err) + require.NoError(t, err) connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name()) expected := &pgconn.Config{ @@ -556,7 +556,7 @@ func TestParseConfigReadsPgPassfile(t *testing.T) { } actual, err := pgconn.ParseConfig(connString) - assert.Nil(t, err) + assert.NoError(t, err) assertConfigsEqual(t, expected, actual, "passfile") } diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index 6b5efd9f..7a95fa98 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -17,7 +17,7 @@ func TestConnStress(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) actionCount := 100 @@ -61,7 +61,7 @@ func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { ('Foo', 'bar'), ('baz', 'Something really long Something really long Something really long Something really long Something really long'), ('a', 'b')`).ReadAll() - require.Nil(t, err) + require.NoError(t, err) } func stressExecSelect(pgConn *pgconn.PgConn) error { diff --git a/pgconn_test.go b/pgconn_test.go index ab7cfa72..f3ed04df 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -40,7 +40,7 @@ func TestConnect(t *testing.T) { } conn, err := pgconn.Connect(context.Background(), connString) - require.Nil(t, err) + require.NoError(t, err) closeConn(t, conn) }) @@ -58,7 +58,7 @@ func TestConnectTLS(t *testing.T) { } conn, err := pgconn.Connect(context.Background(), connString) - require.Nil(t, err) + require.NoError(t, err) if _, ok := conn.Conn().(*tls.Conn); !ok { t.Error("not a TLS connection") @@ -76,7 +76,7 @@ func TestConnectInvalidUser(t *testing.T) { } config, err := pgconn.ParseConfig(connString) - require.Nil(t, err) + require.NoError(t, err) config.User = "pgxinvalidusertest" @@ -109,7 +109,7 @@ func TestConnectCustomDialer(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) dialed := false config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { @@ -118,7 +118,7 @@ func TestConnectCustomDialer(t *testing.T) { } conn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) require.True(t, dialed) closeConn(t, conn) } @@ -127,7 +127,7 @@ func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) config.RuntimeParams = map[string]string{ "application_name": "pgxtest", @@ -135,7 +135,7 @@ func TestConnectWithRuntimeParams(t *testing.T) { } conn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, conn) result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() @@ -153,7 +153,7 @@ func TestConnectWithFallback(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) // Prepend current primary config to fallbacks config.Fallbacks = append([]*pgconn.FallbackConfig{ @@ -178,7 +178,7 @@ func TestConnectWithFallback(t *testing.T) { }, config.Fallbacks...) conn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) closeConn(t, conn) } @@ -186,7 +186,7 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) dialCount := 0 config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { @@ -214,7 +214,7 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) conn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) closeConn(t, conn) assert.True(t, dialCount > 1) @@ -225,7 +225,7 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite config.RuntimeParams["default_transaction_read_only"] = "on" @@ -240,7 +240,7 @@ func TestConnPrepareFailure(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) @@ -254,11 +254,11 @@ func TestConnExec(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err) @@ -273,7 +273,7 @@ func TestConnExecEmpty(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) multiResult := pgConn.Exec(context.Background(), ";") @@ -285,7 +285,7 @@ func TestConnExecEmpty(t *testing.T) { } assert.Equal(t, 0, resultCount) err = multiResult.Close() - assert.Nil(t, err) + assert.NoError(t, err) ensureConnValid(t, pgConn) } @@ -294,11 +294,11 @@ func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, results, 2) @@ -319,7 +319,7 @@ func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() @@ -341,7 +341,7 @@ func TestConnExecContextCanceled(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -360,7 +360,7 @@ func TestConnExecParams(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) @@ -372,7 +372,7 @@ func TestConnExecParams(t *testing.T) { assert.Equal(t, 1, rowCount) commandTag, err := result.Close() assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) + assert.NoError(t, err) ensureConnValid(t, pgConn) } @@ -381,7 +381,7 @@ func TestConnExecParamsCanceled(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -403,11 +403,11 @@ func TestConnExecPrepared(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, psd) assert.Len(t, psd.ParamOIDs, 1) assert.Len(t, psd.Fields, 1) @@ -421,7 +421,7 @@ func TestConnExecPrepared(t *testing.T) { assert.Equal(t, 1, rowCount) commandTag, err := result.Close() assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) + assert.NoError(t, err) ensureConnValid(t, pgConn) } @@ -430,11 +430,11 @@ func TestConnExecPreparedCanceled(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) - require.Nil(t, err) + require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() @@ -455,11 +455,11 @@ func TestConnExecBatch(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.Nil(t, err) + require.NoError(t, err) batch := &pgconn.Batch{} @@ -467,7 +467,7 @@ func TestConnExecBatch(t *testing.T) { batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.Nil(t, err) + require.NoError(t, err) require.Len(t, results, 3) require.Len(t, results[0].Rows, 1) @@ -510,7 +510,7 @@ func TestConnContextCancelWithOnContextCancel(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) calledChan := make(chan struct{}) @@ -533,7 +533,7 @@ func TestConnContextCancelWithOnContextCancel(t *testing.T) { } pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -558,7 +558,7 @@ func TestConnWaitUntilReady(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -567,7 +567,7 @@ func TestConnWaitUntilReady(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, result.Err) err = pgConn.WaitUntilReady(context.Background()) - require.Nil(t, err) + require.NoError(t, err) ensureConnValid(t, pgConn) } @@ -576,7 +576,7 @@ func TestConnOnNotice(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) var msg string config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { @@ -584,7 +584,7 @@ func TestConnOnNotice(t *testing.T) { } pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) multiResult := pgConn.Exec(context.Background(), `do $$ @@ -592,7 +592,7 @@ begin raise notice 'hello, world'; end$$;`) err = multiResult.Close() - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "hello, world", msg) ensureConnValid(t, pgConn) @@ -602,7 +602,7 @@ func TestConnOnNotification(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) var msg string config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { @@ -610,20 +610,20 @@ func TestConnOnNotification(t *testing.T) { } pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.Nil(t, err) + require.NoError(t, err) notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, notifier) _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.Nil(t, err) + require.NoError(t, err) _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "bar", msg) @@ -634,7 +634,7 @@ func TestConnWaitForNotification(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) var msg string config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { @@ -642,20 +642,20 @@ func TestConnWaitForNotification(t *testing.T) { } pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.Nil(t, err) + require.NoError(t, err) notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, notifier) _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.Nil(t, err) + require.NoError(t, err) err = pgConn.WaitForNotification(context.Background()) - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "bar", msg) @@ -666,10 +666,10 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) @@ -684,7 +684,7 @@ func TestConnCopyToSmall(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Exec(context.Background(), `create temporary table foo( @@ -696,13 +696,13 @@ func TestConnCopyToSmall(t *testing.T) { f date, g json )`).ReadAll() - require.Nil(t, err) + require.NoError(t, err) _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() - require.Nil(t, err) + require.NoError(t, err) _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() - require.Nil(t, err) + require.NoError(t, err) inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") @@ -710,7 +710,7 @@ func TestConnCopyToSmall(t *testing.T) { outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, int64(2), res.RowsAffected()) assert.Equal(t, inputBytes, outputWriter.Bytes()) @@ -722,7 +722,7 @@ func TestConnCopyToLarge(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Exec(context.Background(), `create temporary table foo( @@ -735,20 +735,20 @@ func TestConnCopyToLarge(t *testing.T) { g json, h bytea )`).ReadAll() - require.Nil(t, err) + require.NoError(t, err) inputBytes := make([]byte, 0) for i := 0; i < 1000; i++ { _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() - require.Nil(t, err) + require.NoError(t, err) inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) } outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, int64(1000), res.RowsAffected()) assert.Equal(t, inputBytes, outputWriter.Bytes()) @@ -760,7 +760,7 @@ func TestConnCopyToQueryError(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) outputWriter := bytes.NewBuffer(make([]byte, 0)) @@ -777,7 +777,7 @@ func TestConnCopyToCanceled(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) outputWriter := &bytes.Buffer{} From e15528c4195b2e3b8cb2e9f8b0eacf80d5a5fba3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Jan 2019 15:41:42 -0600 Subject: [PATCH 0201/1158] Remove obsolete comment --- pgconn.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 476cd046..aa246614 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1145,7 +1145,6 @@ type Batch struct { // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - // TODO - refactor ExecParams and ExecPrepared - these lines only difference batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) batch.ExecPrepared("", paramValues, paramFormats, resultFormats) } From c9f985c1e40fea85c7acc1a404063d3f4d94b001 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Jan 2019 15:44:03 -0600 Subject: [PATCH 0202/1158] Add PgConn.EscapeString --- pgconn.go | 17 +++++++++++++++++ pgconn_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/pgconn.go b/pgconn.go index aa246614..49062f23 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1193,3 +1193,20 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR return multiResult } + +// EscapeString escapes a string such that it can safely be interpolated into a SQL command string. It does not include +// the surrounding single quotes. +// +// The current implementation requires that standard_conforming_strings=on and client_encoding="UTF8". If these +// conditions are not met an error will be returned. It is possible these restrictions will be lifted in the future. +func (pgConn *PgConn) EscapeString(s string) (string, error) { + if pgConn.ParameterStatus("standard_conforming_strings") != "on" { + return "", errors.New("EscapeString must be run with standard_conforming_strings=on") + } + + if pgConn.ParameterStatus("client_encoding") != "UTF8" { + return "", errors.New("EscapeString must be run with client_encoding=UTF8") + } + + return strings.Replace(s, "'", "''", -1), nil +} diff --git a/pgconn_test.go b/pgconn_test.go index f3ed04df..587acc57 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -791,6 +791,34 @@ func TestConnCopyToCanceled(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnEscapeString(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + tests := []struct { + in string + out string + }{ + {in: "", out: ""}, + {in: "42", out: "42"}, + {in: "'", out: "''"}, + {in: "hi'there", out: "hi''there"}, + {in: "'hi there'", out: "''hi there''"}, + } + + for i, tt := range tests { + value, err := pgConn.EscapeString(tt.in) + if assert.NoErrorf(t, err, "%d.", i) { + assert.Equalf(t, tt.out, value, "%d.", i) + } + } + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { From 3683e4a0a16d4508f14ac54a8986a8e5dc658a59 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Jan 2019 17:24:48 -0600 Subject: [PATCH 0203/1158] Move CopyFrom to pgconn --- pgconn.go | 129 ++++++++++++++++++++++++++++++++++++++++++++++ pgconn_test.go | 136 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 265 insertions(+) diff --git a/pgconn.go b/pgconn.go index 49062f23..e8baffa2 100644 --- a/pgconn.go +++ b/pgconn.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" ) @@ -812,6 +813,134 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm } } +// CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server. +// +// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r +// could still block. +func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + case pgConn.controller <- pgConn: + } + cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + + // Send copy to command + var buf []byte + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + cleanupContextDeadline() + <-pgConn.controller + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + // Read until copy in response or error. + var commandTag CommandTag + var pgErr error + pendingCopyInResponse := true + for pendingCopyInResponse { + msg, err := pgConn.ReceiveMessage() + if err != nil { + cleanupContextDeadline() + if err, ok := err.(net.Error); ok && err.Timeout() { + go pgConn.recoverFromTimeout() + } else { + <-pgConn.controller + } + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.CopyInResponse: + pendingCopyInResponse = false + case *pgproto3.ErrorResponse: + pgErr = errorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + <-pgConn.controller + return commandTag, pgErr + } + } + + // Send copy data + buf = make([]byte, 0, 65536) + buf = append(buf, 'd') + sp := len(buf) + for { + n, err := r.Read(buf[5:cap(buf)]) + if err == io.EOF && n == 0 { + break + } + buf = buf[0 : n+5] + pgio.SetInt32(buf[sp:], int32(n+4)) + + _, err = pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to + // recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to + // close the connection. + pgConn.conn.Close() + pgConn.closed = true + + cleanupContextDeadline() + <-pgConn.controller + + return "", preferContextOverNetTimeoutError(ctx, err) + } + } + + // Send copy done + buf = buf[:0] + copyDone := &pgproto3.CopyDone{} + buf = copyDone.Encode(buf) + + _, err = pgConn.conn.Write(buf) + if err != nil { + pgConn.conn.Close() + pgConn.closed = true + + cleanupContextDeadline() + <-pgConn.controller + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + // Read results + for { + msg, err := pgConn.ReceiveMessage() + if err != nil { + cleanupContextDeadline() + if err, ok := err.(net.Error); ok && err.Timeout() { + go pgConn.recoverFromTimeout() + } else { + <-pgConn.controller + } + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + <-pgConn.controller + return commandTag, pgErr + case *pgproto3.CommandComplete: + commandTag = CommandTag(msg.CommandTag) + case *pgproto3.ErrorResponse: + pgErr = errorResponseToPgError(msg) + } + } +} + // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { pgConn *PgConn diff --git a/pgconn_test.go b/pgconn_test.go index 587acc57..47b3b3fb 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -2,12 +2,15 @@ package pgconn_test import ( "bytes" + "compress/gzip" "context" "crypto/tls" "fmt" + "io/ioutil" "log" "net" "os" + "strconv" "testing" "time" @@ -791,6 +794,139 @@ func TestConnCopyToCanceled(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnCopyFrom(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromGzipReader(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + f, err := ioutil.TempFile("", "*") + require.NoError(t, err) + + gw := gzip.NewWriter(f) + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + err = gw.Close() + require.NoError(t, err) + + _, err = f.Seek(0, 0) + require.NoError(t, err) + + gr, err := gzip.NewReader(f) + require.NoError(t, err) + + ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + err = gr.Close() + require.NoError(t, err) + + err = f.Close() + require.NoError(t, err) + + err = os.Remove(f.Name()) + require.NoError(t, err) + + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromQuerySyntaxError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromQueryNoTableError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + srcBuf := &bytes.Buffer{} + + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + func TestConnEscapeString(t *testing.T) { t.Parallel() From 01b54c7cb6f204983e3ece13262e4560b798eab9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 10:21:16 -0600 Subject: [PATCH 0204/1158] Properly abort CopyFrom on reader error --- pgconn.go | 45 ++++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/pgconn.go b/pgconn.go index e8baffa2..d8ec6b07 100644 --- a/pgconn.go +++ b/pgconn.go @@ -876,34 +876,37 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf = make([]byte, 0, 65536) buf = append(buf, 'd') sp := len(buf) - for { - n, err := r.Read(buf[5:cap(buf)]) - if err == io.EOF && n == 0 { - break - } - buf = buf[0 : n+5] - pgio.SetInt32(buf[sp:], int32(n+4)) + var readErr error + for readErr == nil { + n, readErr = r.Read(buf[5:cap(buf)]) + if n > 0 { + buf = buf[0 : n+5] + pgio.SetInt32(buf[sp:], int32(n+4)) - _, err = pgConn.conn.Write(buf) - if err != nil { - // Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to - // recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to - // close the connection. - pgConn.conn.Close() - pgConn.closed = true + _, err = pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to + // recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to + // close the connection. + pgConn.conn.Close() + pgConn.closed = true - cleanupContextDeadline() - <-pgConn.controller + cleanupContextDeadline() + <-pgConn.controller - return "", preferContextOverNetTimeoutError(ctx, err) + return "", preferContextOverNetTimeoutError(ctx, err) + } } } - // Send copy done buf = buf[:0] - copyDone := &pgproto3.CopyDone{} - buf = copyDone.Encode(buf) - + if readErr == io.EOF { + copyDone := &pgproto3.CopyDone{} + buf = copyDone.Encode(buf) + } else { + copyFail := &pgproto3.CopyFail{Error: readErr.Error()} + buf = copyFail.Encode(buf) + } _, err = pgConn.conn.Write(buf) if err != nil { pgConn.conn.Close() From 38671ea10678734cddfabb210318a6e84f49032a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 10:21:16 -0600 Subject: [PATCH 0205/1158] Properly abort CopyFrom on reader error --- copy_fail.go | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 copy_fail.go diff --git a/copy_fail.go b/copy_fail.go new file mode 100644 index 00000000..432a311b --- /dev/null +++ b/copy_fail.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type CopyFail struct { + Error string +} + +func (*CopyFail) Frontend() {} +func (*CopyFail) Backend() {} + +func (dst *CopyFail) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx != len(src)-1 { + return &invalidMessageFormatErr{messageType: "CopyFail"} + } + + dst.Error = string(src[:idx]) + + return nil +} + +func (src *CopyFail) Encode(dst []byte) []byte { + dst = append(dst, 'C') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Error...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *CopyFail) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Error string + }{ + Type: "CopyFail", + Error: src.Error, + }) +} From 96c85cf0c3981d8e35cd3c5fd34a9d1c1ddad313 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 12:20:36 -0600 Subject: [PATCH 0206/1158] Recover from context cancellation during CopyFrom --- pgconn.go | 131 ++++++++++++++++++++++++++++++++++++++++++++----- pgconn_test.go | 36 ++++++++++++++ 2 files changed, 155 insertions(+), 12 deletions(-) diff --git a/pgconn.go b/pgconn.go index d8ec6b07..e34853a0 100644 --- a/pgconn.go +++ b/pgconn.go @@ -12,6 +12,7 @@ import ( "net" "strconv" "strings" + "sync" "time" "github.com/jackc/pgx/pgio" @@ -91,6 +92,11 @@ type PgConn struct { controller chan interface{} closed bool + + bufferingReceive bool + bufferingReceiveMux sync.Mutex + bufferingReceiveMsg pgproto3.BackendMessage + bufferingReceiveErr error } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -273,8 +279,42 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } +func (pgConn *PgConn) signalMessage() chan struct{} { + if pgConn.bufferingReceive { + panic("BUG: signalMessage when already in progress") + } + + pgConn.bufferingReceive = true + pgConn.bufferingReceiveMux.Lock() + + ch := make(chan struct{}) + go func() { + pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive() + pgConn.bufferingReceiveMux.Unlock() + close(ch) + }() + + return ch +} + func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { - msg, err := pgConn.Frontend.Receive() + var msg pgproto3.BackendMessage + var err error + if pgConn.bufferingReceive { + pgConn.bufferingReceiveMux.Lock() + msg = pgConn.bufferingReceiveMsg + err = pgConn.bufferingReceiveErr + pgConn.bufferingReceiveMux.Unlock() + pgConn.bufferingReceive = false + + // If a timeout error happened in the background try the read again. + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + msg, err = pgConn.Frontend.Receive() + } + } else { + msg, err = pgConn.Frontend.Receive() + } + if err != nil { // Close on anything other than timeout error - everything else is fatal if err, ok := err.(net.Error); !(ok && err.Timeout()) { @@ -853,7 +893,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co if err != nil { cleanupContextDeadline() if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeout() + go pgConn.recoverFromTimeoutDuringCopyFrom() } else { <-pgConn.controller } @@ -877,30 +917,56 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf = append(buf, 'd') sp := len(buf) var readErr error - for readErr == nil { + signalMessageChan := pgConn.signalMessage() + for readErr == nil && pgErr == nil { n, readErr = r.Read(buf[5:cap(buf)]) if n > 0 { buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) - _, err = pgConn.conn.Write(buf) + n, err = pgConn.conn.Write(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to - // recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to - // close the connection. - pgConn.conn.Close() - pgConn.closed = true - + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } cleanupContextDeadline() - <-pgConn.controller + if err, ok := err.(net.Error); ok && err.Timeout() { + go pgConn.recoverFromTimeoutDuringCopyFrom() + } else { + <-pgConn.controller + } return "", preferContextOverNetTimeoutError(ctx, err) } } + + select { + case <-signalMessageChan: + msg, err := pgConn.ReceiveMessage() + if err != nil { + cleanupContextDeadline() + if err, ok := err.(net.Error); ok && err.Timeout() { + go pgConn.recoverFromTimeoutDuringCopyFrom() + } else { + <-pgConn.controller + } + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + pgErr = errorResponseToPgError(msg) + } + default: + } } buf = buf[:0] - if readErr == io.EOF { + if readErr == io.EOF || pgErr != nil { copyDone := &pgproto3.CopyDone{} buf = copyDone.Encode(buf) } else { @@ -944,6 +1010,47 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } } +func (pgConn *PgConn) recoverFromTimeoutDuringCopyFrom() { + // Regardless of recovery outcome the lock on the pgConn must be released. + defer func() { <-pgConn.controller }() + + // Limit time to wait for entire cancellation process. + err := pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second)) + if err != nil { + pgConn.hardClose() + return + } + + copyFail := &pgproto3.CopyFail{Error: "client cancel"} + buf := copyFail.Encode(nil) + + _, err = pgConn.conn.Write(buf) + if err != nil { + pgConn.hardClose() + return + } + + pendingReadyForQuery := true + + for pendingReadyForQuery { + msg, err := pgConn.ReceiveMessage() + if err != nil { + pgConn.hardClose() + return + } + + switch msg.(type) { + case *pgproto3.ReadyForQuery: + pendingReadyForQuery = false + } + } + + err = pgConn.conn.SetDeadline(time.Time{}) + if err != nil { + pgConn.hardClose() + } +} + // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { pgConn *PgConn diff --git a/pgconn_test.go b/pgconn_test.go index 47b3b3fb..7fb01e2c 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -6,6 +6,7 @@ import ( "context" "crypto/tls" "fmt" + "io" "io/ioutil" "log" "net" @@ -830,6 +831,41 @@ func TestConnCopyFrom(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnCopyFromCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + r, w := io.Pipe() + go func() { + for i := 0; i < 1000000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + if err != nil { + return + } + time.Sleep(time.Microsecond) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") + cancel() + assert.Equal(t, int64(0), ct.RowsAffected()) + require.Equal(t, context.DeadlineExceeded, err) + + ensureConnValid(t, pgConn) +} + func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() From 440fbf158199b9de56000f1fee77cf8c0ee59b47 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 12:21:54 -0600 Subject: [PATCH 0207/1158] Include missed changes --- backend.go | 3 +++ frontend.go | 3 +++ 2 files changed, 6 insertions(+) diff --git a/backend.go b/backend.go index 8f3c3478..ea44d1d1 100644 --- a/backend.go +++ b/backend.go @@ -15,6 +15,7 @@ type Backend struct { // Frontend message flyweights bind Bind _close Close + copyFail CopyFail describe Describe execute Execute flush Flush @@ -82,6 +83,8 @@ func (b *Backend) Receive() (FrontendMessage, error) { msg = &b.describe case 'E': msg = &b.execute + case 'f': + msg = &b.copyFail case 'H': msg = &b.flush case 'P': diff --git a/frontend.go b/frontend.go index d1541c74..31a955bc 100644 --- a/frontend.go +++ b/frontend.go @@ -23,6 +23,7 @@ type Frontend struct { copyInResponse CopyInResponse copyOutResponse CopyOutResponse copyDone CopyDone + copyFail CopyFail dataRow DataRow emptyQueryResponse EmptyQueryResponse errorResponse ErrorResponse @@ -83,6 +84,8 @@ func (b *Frontend) Receive() (BackendMessage, error) { msg = &b.dataRow case 'E': msg = &b.errorResponse + case 'f': + msg = &b.copyFail case 'G': msg = &b.copyInResponse case 'H': From f5aecdd4992504d8344ea0730800e38d48b32f28 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 12:33:51 -0600 Subject: [PATCH 0208/1158] Extract writeAll --- pgconn.go | 87 ++++++++++++++----------------------------------------- 1 file changed, 21 insertions(+), 66 deletions(-) diff --git a/pgconn.go b/pgconn.go index e34853a0..461ff1c0 100644 --- a/pgconn.go +++ b/pgconn.go @@ -398,6 +398,15 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } +// writeAll writes the entire buffer successfully or it hard closes the connection. +func (pgConn *PgConn) writeAll(buf []byte) error { + n, err := pgConn.conn.Write(buf) + if err != nil && n > 0 { + pgConn.hardClose() + } + return err +} + // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { @@ -482,15 +491,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - return nil, preferContextOverNetTimeoutError(ctx, err) } @@ -654,15 +656,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) @@ -718,15 +713,8 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true @@ -770,15 +758,8 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true @@ -801,15 +782,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - cleanupContextDeadline() <-pgConn.controller @@ -869,15 +843,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - cleanupContextDeadline() <-pgConn.controller @@ -913,25 +880,21 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Send copy data - buf = make([]byte, 0, 65536) + buf = make([]byte, 0, 20000) + // buf = make([]byte, 0, 65536) buf = append(buf, 'd') sp := len(buf) var readErr error signalMessageChan := pgConn.signalMessage() for readErr == nil && pgErr == nil { + var n int n, readErr = r.Read(buf[5:cap(buf)]) if n > 0 { buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) - n, err = pgConn.conn.Write(buf) + err = pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } cleanupContextDeadline() if err, ok := err.(net.Error); ok && err.Timeout() { go pgConn.recoverFromTimeoutDuringCopyFrom() @@ -975,8 +938,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } _, err = pgConn.conn.Write(buf) if err != nil { - pgConn.conn.Close() - pgConn.closed = true + pgConn.hardClose() cleanupContextDeadline() <-pgConn.controller @@ -1414,15 +1376,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) - n, err := pgConn.conn.Write(batch.buf) + err := pgConn.writeAll(batch.buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) From b59437f6ecfec5b604e0ae2063134078578e1d7e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 16:45:06 -0600 Subject: [PATCH 0209/1158] writeAll dies on permanent net errors --- pgconn.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pgconn.go b/pgconn.go index 461ff1c0..06f9e833 100644 --- a/pgconn.go +++ b/pgconn.go @@ -398,11 +398,15 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } -// writeAll writes the entire buffer successfully or it hard closes the connection. +// writeAll writes the entire buffer. The connection is hard closed on a partial write or a non-temporary error. func (pgConn *PgConn) writeAll(buf []byte) error { n, err := pgConn.conn.Write(buf) - if err != nil && n > 0 { - pgConn.hardClose() + if err != nil { + if n > 0 { + pgConn.hardClose() + } else if ne, ok := err.(net.Error); ok && !ne.Temporary() { + pgConn.hardClose() + } } return err } From 9229e03d06a317a765275e7bd82f301a623b760d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 16:46:30 -0600 Subject: [PATCH 0210/1158] Partial conversion of pgx to use pgconn --- pgconn.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pgconn.go b/pgconn.go index 06f9e833..512c9a88 100644 --- a/pgconn.go +++ b/pgconn.go @@ -398,6 +398,12 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } +// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of +// underlying connection. +func (pgConn *PgConn) IsAlive() bool { + return !pgConn.closed +} + // writeAll writes the entire buffer. The connection is hard closed on a partial write or a non-temporary error. func (pgConn *PgConn) writeAll(buf []byte) error { n, err := pgConn.conn.Write(buf) From 79ffab98367fef041c34bfc3307b82f833661694 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 28 Jan 2019 23:13:03 -0600 Subject: [PATCH 0211/1158] All writes errors are fatal --- pgconn.go | 43 +++++++++++++++++-------------------------- pgconn_test.go | 2 +- 2 files changed, 18 insertions(+), 27 deletions(-) diff --git a/pgconn.go b/pgconn.go index 512c9a88..c785f367 100644 --- a/pgconn.go +++ b/pgconn.go @@ -404,19 +404,6 @@ func (pgConn *PgConn) IsAlive() bool { return !pgConn.closed } -// writeAll writes the entire buffer. The connection is hard closed on a partial write or a non-temporary error. -func (pgConn *PgConn) writeAll(buf []byte) error { - n, err := pgConn.conn.Write(buf) - if err != nil { - if n > 0 { - pgConn.hardClose() - } else if ne, ok := err.(net.Error); ok && !ne.Temporary() { - pgConn.hardClose() - } - } - return err -} - // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { @@ -501,8 +488,9 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() return nil, preferContextOverNetTimeoutError(ctx, err) } @@ -666,8 +654,9 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) @@ -723,8 +712,9 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true @@ -768,8 +758,9 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true @@ -792,8 +783,9 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() cleanupContextDeadline() <-pgConn.controller @@ -853,8 +845,9 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() cleanupContextDeadline() <-pgConn.controller @@ -903,14 +896,11 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) - err = pgConn.writeAll(buf) + _, err = pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeoutDuringCopyFrom() - } else { - <-pgConn.controller - } + <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) } @@ -1386,8 +1376,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) - err := pgConn.writeAll(batch.buf) + _, err := pgConn.conn.Write(batch.buf) if err != nil { + pgConn.hardClose() multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) diff --git a/pgconn_test.go b/pgconn_test.go index 7fb01e2c..dbf9b840 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -863,7 +863,7 @@ func TestConnCopyFromCanceled(t *testing.T) { assert.Equal(t, int64(0), ct.RowsAffected()) require.Equal(t, context.DeadlineExceeded, err) - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnCopyFromGzipReader(t *testing.T) { From 2a112595551c4551e0c6064224b95e6035a16092 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:03:04 -0500 Subject: [PATCH 0212/1158] Add readme and license --- LICENSE | 22 ++++++++++++++++++++++ README.md | 11 +++++++++++ 2 files changed, 33 insertions(+) create mode 100644 LICENSE create mode 100644 README.md diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..c1c4f50f --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2019 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 00000000..81139c3e --- /dev/null +++ b/README.md @@ -0,0 +1,11 @@ +[![](https://godoc.org/github.com/jackc/pgpassfile?status.svg)](https://godoc.org/github.com/jackc/pgpassfile) +[![Build Status](https://travis-ci.org/jackc/pgpassfile.svg)](https://travis-ci.org/jackc/pgpassfile) + +# pgio + +Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol. + +pgio provides functions for appending integers to a []byte while doing byte +order conversion. + +Extracted from original implementation in https://github.com/jackc/pgx. From 715eaaf2ed8a28a35a85d331f3b1ef08a177a5af Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:03:34 -0500 Subject: [PATCH 0213/1158] Add go module support --- go.mod | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 go.mod diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..c1efdddb --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/jackc/pgio + +go 1.12 From 8abf4a9eaab90a7470c06fdd6c048f748da0d168 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:04:23 -0500 Subject: [PATCH 0214/1158] Fix links in readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 81139c3e..1952ed86 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -[![](https://godoc.org/github.com/jackc/pgpassfile?status.svg)](https://godoc.org/github.com/jackc/pgpassfile) -[![Build Status](https://travis-ci.org/jackc/pgpassfile.svg)](https://travis-ci.org/jackc/pgpassfile) +[![](https://godoc.org/github.com/jackc/pgio?status.svg)](https://godoc.org/github.com/jackc/pgio) +[![Build Status](https://travis-ci.org/jackc/pgio.svg)](https://travis-ci.org/jackc/pgio) # pgio From 8d9c2a3dafd92d070bd758a165022fd1059e3195 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:04:38 -0500 Subject: [PATCH 0215/1158] Add travis ci --- .travis.yml | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..e176228e --- /dev/null +++ b/.travis.yml @@ -0,0 +1,9 @@ +language: go + +go: + - 1.x + - tip + +matrix: + allow_failures: + - go: tip From e2207bfbaf2d7771de52f1d0951a1e0b8cc7882e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:18:27 -0500 Subject: [PATCH 0216/1158] Add some documentation --- chunkreader.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chunkreader.go b/chunkreader.go index f8d437b2..fedd41e5 100644 --- a/chunkreader.go +++ b/chunkreader.go @@ -1,9 +1,11 @@ +// Package chunkreader provides an opinionated, efficient buffered reader. package chunkreader import ( "io" ) +// ChunkReader is a io.Reader wrapper that minimizes reads and memory allocations. type ChunkReader struct { r io.Reader From 65a3248f5c03df8019ec64664dc600e635208af7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:20:18 -0500 Subject: [PATCH 0217/1158] Add license and readme --- LICENSE | 22 ++++++++++++++++++++++ README.md | 8 ++++++++ 2 files changed, 30 insertions(+) create mode 100644 LICENSE create mode 100644 README.md diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..c1c4f50f --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2019 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 00000000..bcc9ac6b --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +[![](https://godoc.org/github.com/jackc/chunkreader?status.svg)](https://godoc.org/github.com/jackc/chunkreader) +[![Build Status](https://travis-ci.org/jackc/chunkreader.svg)](https://travis-ci.org/jackc/chunkreader) + +# chunkreader + +Package chunkreader provides an opinionated, efficient buffered reader. + +Extracted from original implementation in https://github.com/jackc/pgx. From 811a7d92d62c682d8229f63f19c0b425ef9c12aa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:21:06 -0500 Subject: [PATCH 0218/1158] Add Go module support --- go.mod | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 go.mod diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..b1ed8c92 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/jackc/chunkreader + +go 1.12 From 517cfde605cd1f91edc0a62affca418114388fc3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:21:36 -0500 Subject: [PATCH 0219/1158] Add Travis CI --- .travis.yml | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..e176228e --- /dev/null +++ b/.travis.yml @@ -0,0 +1,9 @@ +language: go + +go: + - 1.x + - tip + +matrix: + allow_failures: + - go: tip From 16176b5151770608f202298776016b427a1ddcfa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:26:24 -0500 Subject: [PATCH 0220/1158] Add go module support --- authentication.go | 2 +- backend.go | 2 +- backend_key_data.go | 2 +- backend_test.go | 2 +- bind.go | 2 +- close.go | 2 +- command_complete.go | 2 +- copy_both_response.go | 2 +- copy_data.go | 2 +- copy_fail.go | 2 +- copy_in_response.go | 2 +- copy_out_response.go | 2 +- data_row.go | 2 +- describe.go | 2 +- execute.go | 2 +- frontend.go | 2 +- frontend_test.go | 2 +- function_call_response.go | 2 +- go.mod | 9 +++++++++ go.sum | 6 ++++++ notification_response.go | 2 +- parameter_description.go | 2 +- parameter_status.go | 2 +- parse.go | 2 +- password_message.go | 2 +- query.go | 2 +- row_description.go | 2 +- startup_message.go | 2 +- 28 files changed, 41 insertions(+), 26 deletions(-) create mode 100644 go.mod create mode 100644 go.sum diff --git a/authentication.go b/authentication.go index 77750b86..14275a86 100644 --- a/authentication.go +++ b/authentication.go @@ -3,7 +3,7 @@ package pgproto3 import ( "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/backend.go b/backend.go index ea44d1d1..b64f006c 100644 --- a/backend.go +++ b/backend.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "io" - "github.com/jackc/pgx/chunkreader" + "github.com/jackc/chunkreader" "github.com/pkg/errors" ) diff --git a/backend_key_data.go b/backend_key_data.go index 5a478f10..0396379b 100644 --- a/backend_key_data.go +++ b/backend_key_data.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type BackendKeyData struct { diff --git a/backend_test.go b/backend_test.go index 02a5e9ca..a26f2c40 100644 --- a/backend_test.go +++ b/backend_test.go @@ -3,7 +3,7 @@ package pgproto3_test import ( "testing" - "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgproto3" ) func TestBackendReceiveInterrupted(t *testing.T) { diff --git a/bind.go b/bind.go index cceee6ab..459e5ff2 100644 --- a/bind.go +++ b/bind.go @@ -6,7 +6,7 @@ import ( "encoding/hex" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type Bind struct { diff --git a/close.go b/close.go index 5ff4c886..4e497549 100644 --- a/close.go +++ b/close.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type Close struct { diff --git a/command_complete.go b/command_complete.go index 85848532..0012f6f0 100644 --- a/command_complete.go +++ b/command_complete.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type CommandComplete struct { diff --git a/copy_both_response.go b/copy_both_response.go index 2862a34f..aa59d52a 100644 --- a/copy_both_response.go +++ b/copy_both_response.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type CopyBothResponse struct { diff --git a/copy_data.go b/copy_data.go index fab139e6..490d3d80 100644 --- a/copy_data.go +++ b/copy_data.go @@ -4,7 +4,7 @@ import ( "encoding/hex" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type CopyData struct { diff --git a/copy_fail.go b/copy_fail.go index 432a311b..e086207a 100644 --- a/copy_fail.go +++ b/copy_fail.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type CopyFail struct { diff --git a/copy_in_response.go b/copy_in_response.go index 54083cd6..3ddeeb40 100644 --- a/copy_in_response.go +++ b/copy_in_response.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type CopyInResponse struct { diff --git a/copy_out_response.go b/copy_out_response.go index eaa33b8b..01a64228 100644 --- a/copy_out_response.go +++ b/copy_out_response.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type CopyOutResponse struct { diff --git a/data_row.go b/data_row.go index e46d3cc0..0da18b06 100644 --- a/data_row.go +++ b/data_row.go @@ -5,7 +5,7 @@ import ( "encoding/hex" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type DataRow struct { diff --git a/describe.go b/describe.go index bb7bc056..86016ebc 100644 --- a/describe.go +++ b/describe.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type Describe struct { diff --git a/execute.go b/execute.go index 76da9943..71713f49 100644 --- a/execute.go +++ b/execute.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type Execute struct { diff --git a/frontend.go b/frontend.go index 31a955bc..00cb68b4 100644 --- a/frontend.go +++ b/frontend.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "io" - "github.com/jackc/pgx/chunkreader" + "github.com/jackc/chunkreader" "github.com/pkg/errors" ) diff --git a/frontend_test.go b/frontend_test.go index 7d6652c1..49484e01 100644 --- a/frontend_test.go +++ b/frontend_test.go @@ -5,7 +5,7 @@ import ( "github.com/pkg/errors" - "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgproto3" ) type interruptReader struct { diff --git a/function_call_response.go b/function_call_response.go index bb325b69..f14f8452 100644 --- a/function_call_response.go +++ b/function_call_response.go @@ -5,7 +5,7 @@ import ( "encoding/hex" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type FunctionCallResponse struct { diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..2c2b401e --- /dev/null +++ b/go.mod @@ -0,0 +1,9 @@ +module github.com/jackc/pgproto3 + +go 1.12 + +require ( + github.com/jackc/chunkreader v1.0.0 + github.com/jackc/pgio v1.0.0 + github.com/pkg/errors v0.8.1 +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..887dd869 --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/notification_response.go b/notification_response.go index b14007b4..2b32b10c 100644 --- a/notification_response.go +++ b/notification_response.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type NotificationResponse struct { diff --git a/parameter_description.go b/parameter_description.go index 1fa3c927..9d964129 100644 --- a/parameter_description.go +++ b/parameter_description.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type ParameterDescription struct { diff --git a/parameter_status.go b/parameter_status.go index b3bac33f..d370a4c1 100644 --- a/parameter_status.go +++ b/parameter_status.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type ParameterStatus struct { diff --git a/parse.go b/parse.go index ca4834c6..6f17175b 100644 --- a/parse.go +++ b/parse.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type Parse struct { diff --git a/password_message.go b/password_message.go index 2ad3fe4a..30377cbe 100644 --- a/password_message.go +++ b/password_message.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type PasswordMessage struct { diff --git a/query.go b/query.go index d80c0fb4..16228cb4 100644 --- a/query.go +++ b/query.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type Query struct { diff --git a/row_description.go b/row_description.go index 7deba379..7f46ede3 100644 --- a/row_description.go +++ b/row_description.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) const ( diff --git a/startup_message.go b/startup_message.go index 6c5d4f99..93a3d992 100644 --- a/startup_message.go +++ b/startup_message.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) From b9d0da5558b1daaf359e39117f7904fc80d1ab64 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:32:39 -0500 Subject: [PATCH 0221/1158] Add readme, license, and docs --- LICENSE | 22 ++++++++++++++++++++++ README.md | 8 ++++++++ doc.go | 2 ++ 3 files changed, 32 insertions(+) create mode 100644 LICENSE create mode 100644 README.md create mode 100644 doc.go diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..c1c4f50f --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2019 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 00000000..3491780e --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +[![](https://godoc.org/github.com/jackc/pgproto3?status.svg)](https://godoc.org/github.com/jackc/pgproto3) +[![Build Status](https://travis-ci.org/jackc/pgproto3.svg)](https://travis-ci.org/jackc/pgproto3) + +# pgproto3 + +Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3. + +Extracted from original implementation in https://github.com/jackc/pgx. diff --git a/doc.go b/doc.go new file mode 100644 index 00000000..75340210 --- /dev/null +++ b/doc.go @@ -0,0 +1,2 @@ +// Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3. +package pgproto3 From 127e9976962da24c9da3ec6cf97b5191aa88bebf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:33:04 -0500 Subject: [PATCH 0222/1158] Add travis CI --- .travis.yml | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..e176228e --- /dev/null +++ b/.travis.yml @@ -0,0 +1,9 @@ +language: go + +go: + - 1.x + - tip + +matrix: + allow_failures: + - go: tip From bb06e6b3ff87b1433242ae482d677bf3f6fcd176 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:46:56 -0500 Subject: [PATCH 0223/1158] Decouple github.com/jackc/chunkreader --- backend.go | 6 ++---- backend_test.go | 2 +- chunkreader.go | 18 ++++++++++++++++++ frontend.go | 6 ++---- frontend_test.go | 2 +- 5 files changed, 24 insertions(+), 10 deletions(-) create mode 100644 chunkreader.go diff --git a/backend.go b/backend.go index b64f006c..7f11bc7f 100644 --- a/backend.go +++ b/backend.go @@ -4,12 +4,11 @@ import ( "encoding/binary" "io" - "github.com/jackc/chunkreader" "github.com/pkg/errors" ) type Backend struct { - cr *chunkreader.ChunkReader + cr ChunkReader w io.Writer // Frontend message flyweights @@ -31,8 +30,7 @@ type Backend struct { partialMsg bool } -func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { - cr := chunkreader.NewChunkReader(r) +func NewBackend(cr ChunkReader, w io.Writer) (*Backend, error) { return &Backend{cr: cr, w: w}, nil } diff --git a/backend_test.go b/backend_test.go index a26f2c40..6cba81b6 100644 --- a/backend_test.go +++ b/backend_test.go @@ -12,7 +12,7 @@ func TestBackendReceiveInterrupted(t *testing.T) { server := &interruptReader{} server.push([]byte{'Q', 0, 0, 0, 6}) - backend, err := pgproto3.NewBackend(server, nil) + backend, err := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) if err != nil { t.Fatal(err) } diff --git a/chunkreader.go b/chunkreader.go new file mode 100644 index 00000000..230335c4 --- /dev/null +++ b/chunkreader.go @@ -0,0 +1,18 @@ +package pgproto3 + +import ( + "io" + + "github.com/jackc/chunkreader" +) + +// ChunkReader is an interface to decouple github.com/jackc/chunkreader from this package. +type ChunkReader interface { + // Next returns buf filled with the next n bytes. If an error occurs, buf will be nil. Next must + // not reuse buf. In case of error, Next must preserve partially read data. + Next(n int) (buf []byte, err error) +} + +func NewChunkReader(r io.Reader) ChunkReader { + return chunkreader.NewChunkReader(r) +} diff --git a/frontend.go b/frontend.go index 00cb68b4..ce94f49f 100644 --- a/frontend.go +++ b/frontend.go @@ -4,12 +4,11 @@ import ( "encoding/binary" "io" - "github.com/jackc/chunkreader" "github.com/pkg/errors" ) type Frontend struct { - cr *chunkreader.ChunkReader + cr ChunkReader w io.Writer // Backend message flyweights @@ -42,8 +41,7 @@ type Frontend struct { partialMsg bool } -func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { - cr := chunkreader.NewChunkReader(r) +func NewFrontend(cr ChunkReader, w io.Writer) (*Frontend, error) { return &Frontend{cr: cr, w: w}, nil } diff --git a/frontend_test.go b/frontend_test.go index 49484e01..d3e57f81 100644 --- a/frontend_test.go +++ b/frontend_test.go @@ -37,7 +37,7 @@ func TestFrontendReceiveInterrupted(t *testing.T) { server := &interruptReader{} server.push([]byte{'Z', 0, 0, 0, 5}) - frontend, err := pgproto3.NewFrontend(server, nil) + frontend, err := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil) if err != nil { t.Fatal(err) } From 97a0ac4ddc3f3071c9ade572666ca8fee2dd171e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:52:55 -0500 Subject: [PATCH 0224/1158] Clarify ChunkReader.Next contract --- chunkreader.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chunkreader.go b/chunkreader.go index 230335c4..0acfd4bf 100644 --- a/chunkreader.go +++ b/chunkreader.go @@ -8,8 +8,8 @@ import ( // ChunkReader is an interface to decouple github.com/jackc/chunkreader from this package. type ChunkReader interface { - // Next returns buf filled with the next n bytes. If an error occurs, buf will be nil. Next must - // not reuse buf. In case of error, Next must preserve partially read data. + // Next returns buf filled with the next n bytes. If an error (including a partial read) occurs, + // buf must be nil. Next must preserve any partially read data. Next must not reuse buf. Next(n int) (buf []byte, err error) } From fbdfccf1f91a4c0bc042cb37f3c7c2c9e27a4877 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:55:56 -0500 Subject: [PATCH 0225/1158] Use Go modules --- .gitignore | 1 + benchmark_test.go | 2 +- config.go | 2 +- config_test.go | 2 +- go.mod | 11 +++++++++++ go.sum | 19 +++++++++++++++++++ helper_test.go | 2 +- pgconn.go | 6 +++--- pgconn_stress_test.go | 2 +- pgconn_test.go | 4 ++-- 10 files changed, 41 insertions(+), 10 deletions(-) create mode 100644 .gitignore create mode 100644 go.mod create mode 100644 go.sum diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..7a6353d6 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.envrc diff --git a/benchmark_test.go b/benchmark_test.go index d2576324..959e86be 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgconn" "github.com/stretchr/testify/require" ) diff --git a/config.go b/config.go index fec1fedf..1cde9c57 100644 --- a/config.go +++ b/config.go @@ -17,7 +17,7 @@ import ( "strings" "time" - "github.com/jackc/pgx/pgpassfile" + "github.com/jackc/pgpassfile" "github.com/pkg/errors" ) diff --git a/config_test.go b/config_test.go index c7b65861..ce6f3957 100644 --- a/config_test.go +++ b/config_test.go @@ -8,7 +8,7 @@ import ( "os/user" "testing" - "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..3dc806a4 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/jackc/pgconn + +go 1.12 + +require ( + github.com/jackc/pgio v1.0.0 + github.com/jackc/pgpassfile v1.0.0 + github.com/jackc/pgproto3 v1.0.0 + github.com/pkg/errors v0.8.1 + github.com/stretchr/testify v1.3.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..5b6f835b --- /dev/null +++ b/go.sum @@ -0,0 +1,19 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v0.0.0-20190330174656-bb06e6b3ff87 h1:xueDi0R+HxuFmuOA1xyFbbF+2LSXqWQJZSPWmmMFB0A= +github.com/jackc/pgproto3 v0.0.0-20190330174656-bb06e6b3ff87/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3 v1.0.0 h1:25tUmlES7eyD96oYaUHc1dLOFbgcJtFzCdnOOoqmA1I= +github.com/jackc/pgproto3 v1.0.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/helper_test.go b/helper_test.go index c5ac6e01..5d44f3b8 100644 --- a/helper_test.go +++ b/helper_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/pgconn.go b/pgconn.go index c785f367..6490617a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -15,8 +15,8 @@ import ( "sync" "time" - "github.com/jackc/pgx/pgio" - "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgio" + "github.com/jackc/pgproto3" ) var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) @@ -171,7 +171,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } - pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.conn, pgConn.conn) + pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn) if err != nil { return nil, err } diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index 7a95fa98..1ebbe04a 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgconn" "github.com/stretchr/testify/require" ) diff --git a/pgconn_test.go b/pgconn_test.go index dbf9b840..716761ad 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -15,8 +15,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgconn" - "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgconn" + "github.com/jackc/pgproto3" "github.com/pkg/errors" "github.com/stretchr/testify/assert" From 08fcc7f2736a16192388fc08a0dc7863951e2c69 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:59:04 -0500 Subject: [PATCH 0226/1158] Add license and readme --- LICENSE | 22 ++++++++++++++++++++++ README.md | 8 ++++++++ 2 files changed, 30 insertions(+) create mode 100644 LICENSE create mode 100644 README.md diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..c1c4f50f --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2019 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 00000000..8a881009 --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +[![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn) +[![Build Status](https://travis-ci.org/jackc/pgconn.svg)](https://travis-ci.org/jackc/pgconn) + +# pgconn + +Package pgconn is a low-level PostgreSQL database driver. + +It is intended to serve as the foundation for the next generation of https://github.com/jackc/pgx. From b2fc69d32f5cdf79a4119888a36c997dcecdc073 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 13:03:28 -0500 Subject: [PATCH 0227/1158] Import pgx travis config --- .travis.yml | 38 +++++++++++++++++++++++++++++++++++++ travis/before_install.bash | 39 ++++++++++++++++++++++++++++++++++++++ travis/before_script.bash | 16 ++++++++++++++++ travis/install.bash | 14 ++++++++++++++ travis/script.bash | 10 ++++++++++ 5 files changed, 117 insertions(+) create mode 100644 .travis.yml create mode 100755 travis/before_install.bash create mode 100755 travis/before_script.bash create mode 100755 travis/install.bash create mode 100755 travis/script.bash diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..950792d1 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,38 @@ +language: go + +go: + - 1.x + - tip + +# Derived from https://github.com/lib/pq/blob/master/.travis.yml +before_install: + - ./travis/before_install.bash + +env: + global: + - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test + - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" + - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test + - PGX_TEST_TLS_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + - PGX_TEST_MD5_PASSWORD_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test + - PGX_TEST_PLAIN_PASSWORD_CONN_STRING=postgres://pgx_pw:secret@127.0.0.1/pgx_test + matrix: + - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx database=pgx_test" + - PGVERSION=10 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret database=pgx_test" + - PGVERSION=9.6 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret database=pgx_test" + - PGVERSION=9.5 + - PGVERSION=9.4 + - PGVERSION=9.3 + +before_script: + - ./travis/before_script.bash + +install: + - ./travis/install.bash + +script: + - ./travis/script.bash + +matrix: + allow_failures: + - go: tip diff --git a/travis/before_install.bash b/travis/before_install.bash new file mode 100755 index 00000000..23c7d9cf --- /dev/null +++ b/travis/before_install.bash @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +set -eux + +if [ "${PGVERSION-}" != "" ] +then + sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common + sudo rm -rf /var/lib/postgresql + wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - + sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" + sudo apt-get update -qq + sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION + sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf + if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then + echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf + echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf + echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf + fi + sudo /etc/init.d/postgresql restart +fi + +if [ "${CRATEVERSION-}" != "" ] +then + docker run \ + -p "6543:5432" \ + -d \ + crate:"$CRATEVERSION" \ + crate \ + -Cnetwork.host=0.0.0.0 \ + -Ctransport.host=localhost \ + -Clicense.enterprise=false +fi diff --git a/travis/before_script.bash b/travis/before_script.bash new file mode 100755 index 00000000..bcf748a1 --- /dev/null +++ b/travis/before_script.bash @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -eux + +if [ "${PGVERSION-}" != "" ] +then + # The tricky test user, below, has to actually exist so that it can be used in a test + # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. + psql -U postgres -c 'create database pgx_test' + psql -U postgres pgx_test -c 'create extension hstore' + psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' + psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_replication with replication password 'secret'" + psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" +fi diff --git a/travis/install.bash b/travis/install.bash new file mode 100755 index 00000000..63ba875d --- /dev/null +++ b/travis/install.bash @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +set -eux + +go get -u github.com/cockroachdb/apd +go get -u github.com/shopspring/decimal +go get -u gopkg.in/inconshreveable/log15.v2 +go get -u github.com/jackc/fake +go get -u github.com/lib/pq +go get -u github.com/hashicorp/go-version +go get -u github.com/satori/go.uuid +go get -u github.com/sirupsen/logrus +go get -u github.com/pkg/errors +go get -u go.uber.org/zap +go get -u github.com/rs/zerolog diff --git a/travis/script.bash b/travis/script.bash new file mode 100755 index 00000000..5bf1b77e --- /dev/null +++ b/travis/script.bash @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +set -eux + +if [ "${PGVERSION-}" != "" ] +then + go test -v -race ./... +elif [ "${CRATEVERSION-}" != "" ] +then + go test -v -race -run 'TestCrateDBConnect' +fi From 444bd6deaf2065c0d108dadfa042df36af88ea57 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 16:44:20 -0500 Subject: [PATCH 0228/1158] Context cancellation is fatal during query --- config.go | 11 --- doc.go | 10 +-- pgconn.go | 189 +++--------------------------------------- pgconn_stress_test.go | 62 ++++---------- pgconn_test.go | 99 ++++++---------------- 5 files changed, 60 insertions(+), 311 deletions(-) diff --git a/config.go b/config.go index 1cde9c57..d392924c 100644 --- a/config.go +++ b/config.go @@ -41,17 +41,6 @@ type Config struct { // allows implementing high availability behavior such as libpq does with target_session_attrs. AfterConnectFunc AfterConnectFunc - // OnContextCancel is a callback function used to override cancellation behavior. It is called when a context.Context - // is canceled. Default cancellation behavior is to establish another connection to the PostgreSQL server and send a - // query cancel request. Some non-PostgreSQL servers (e.g. CockroachDB) that speak a subset of the PostgreSQL wire - // protocol do not support this cancellation method. - // - // It is called from a background goroutine. When the cancellation process has finished ContextCancel.Finish must be - // called whether it was successful or not. If an error occurs the connection should be closed. The connection must be - // in a ready for query state or be closed when ContextCancel.Finish is called. Use PgConn.ReceiveMessage() to read - // the connection until a ready for query message is received. - OnContextCancel func(*ContextCancel) - // OnNotice is a callback function called when a notice response is received. OnNotice NoticeHandler diff --git a/doc.go b/doc.go index 89e47536..d36eb0fd 100644 --- a/doc.go +++ b/doc.go @@ -20,10 +20,10 @@ result. The ReadAll method reads all query results into memory. Context Support -All potentially blocking operations take a context.Context. If a context is canceled while a query is in progress the -method immediately returns. In the background a cancel request will be sent to the PostgreSQL server. If the -cancellation fails or hangs for more than a short time (approximately 15 seconds) the connection will be closed. It is -safe to use the connection while this background cancellation is in progress. Any calls will block until the -cancellation and resynchronization is complete (and those calls can be aborted by a context cancellation). +All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the +method immediately returns. In most circumstances, this will close the underlying connection. + +The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the +client to abort. */ package pgconn diff --git a/pgconn.go b/pgconn.go index 6490617a..8b0ddcb4 100644 --- a/pgconn.go +++ b/pgconn.go @@ -199,6 +199,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig for { msg, err := pgConn.ReceiveMessage() if err != nil { + pgConn.conn.Close() return nil, err } @@ -502,7 +503,7 @@ readloop: for { msg, err := pgConn.ReceiveMessage() if err != nil { - go pgConn.recoverFromTimeout() + pgConn.hardClose() return nil, preferContextOverNetTimeoutError(ctx, err) } @@ -555,10 +556,10 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { return (*Notice)(pgerr) } -// cancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel +// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 -func (pgConn *PgConn) cancelRequest(ctx context.Context) error { +func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. @@ -590,21 +591,6 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error { return nil } -// WaitUntilReady waits until a previous context cancellation has been completed and the connection is ready for use. -// This is done automatically by all methods that need the connection to be ready for use. The only expected use for -// this method is for a connection pool to wait for a returned connection to be usable again before making it available. -func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error { - select { - case <-ctx.Done(): - return ctx.Err() - case pgConn.controller <- pgConn: - // The connection must be ready since it was locked. Immediately unlock it. - <-pgConn.controller - } - - return nil -} - // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { @@ -778,6 +764,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case pgConn.controller <- pgConn: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContextDeadline() // Send copy to command var buf []byte @@ -786,7 +773,6 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - cleanupContextDeadline() <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) @@ -798,13 +784,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm for { msg, err := pgConn.ReceiveMessage() if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeout() - } else { - <-pgConn.controller - } - + pgConn.hardClose() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -813,9 +793,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case *pgproto3.CopyData: _, err := w.Write(msg.Data) if err != nil { - // This isn't actually a timeout, but we want the same behavior. Abort the request and cleanup. - cleanupContextDeadline() - go pgConn.recoverFromTimeout() + pgConn.hardClose() return "", err } case *pgproto3.ReadyForQuery: @@ -840,6 +818,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case pgConn.controller <- pgConn: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContextDeadline() // Send copy to command var buf []byte @@ -848,7 +827,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - cleanupContextDeadline() <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) @@ -861,13 +839,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for pendingCopyInResponse { msg, err := pgConn.ReceiveMessage() if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeoutDuringCopyFrom() - } else { - <-pgConn.controller - } - + pgConn.hardClose() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -899,7 +871,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - cleanupContextDeadline() <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) @@ -910,13 +881,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case <-signalMessageChan: msg, err := pgConn.ReceiveMessage() if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeoutDuringCopyFrom() - } else { - <-pgConn.controller - } - + pgConn.hardClose() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -939,8 +904,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - - cleanupContextDeadline() <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) @@ -950,13 +913,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for { msg, err := pgConn.ReceiveMessage() if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeout() - } else { - <-pgConn.controller - } - + pgConn.hardClose() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -972,47 +929,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } } -func (pgConn *PgConn) recoverFromTimeoutDuringCopyFrom() { - // Regardless of recovery outcome the lock on the pgConn must be released. - defer func() { <-pgConn.controller }() - - // Limit time to wait for entire cancellation process. - err := pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second)) - if err != nil { - pgConn.hardClose() - return - } - - copyFail := &pgproto3.CopyFail{Error: "client cancel"} - buf := copyFail.Encode(nil) - - _, err = pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - return - } - - pendingReadyForQuery := true - - for pendingReadyForQuery { - msg, err := pgConn.ReceiveMessage() - if err != nil { - pgConn.hardClose() - return - } - - switch msg.(type) { - case *pgproto3.ReadyForQuery: - pendingReadyForQuery = false - } - } - - err = pgConn.conn.SetDeadline(time.Time{}) - if err != nil { - pgConn.hardClose() - } -} - // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { pgConn *PgConn @@ -1044,13 +960,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) mrr.cleanupContextDeadline() mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.closed = true - - if err, ok := err.(net.Error); ok && err.Timeout() { - go mrr.pgConn.recoverFromTimeout() - } else { - <-mrr.pgConn.controller - } - + mrr.pgConn.hardClose() return nil, mrr.err } @@ -1236,11 +1146,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error rr.cleanupContextDeadline() rr.closed = true if rr.multiResultReader == nil { - if err, ok := err.(net.Error); ok && err.Timeout() { - go rr.pgConn.recoverFromTimeout() - } else { - <-rr.pgConn.controller - } + rr.pgConn.hardClose() } return nil, rr.err @@ -1270,75 +1176,6 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { rr.commandConcluded = true } -func (pgConn *PgConn) defaultCancel() { - // Regardless of recovery outcome the lock on the pgConn must be released. - defer func() { <-pgConn.controller }() - - // Send a cancellation request to the PostgreSQL server. If it is not successful in a reasonable amount of time do not - // try further to recover the connection. - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - err := pgConn.cancelRequest(ctx) - cancel() - if err != nil { - pgConn.hardClose() - return - } - - // Limit time to wait for ReadyForQuery message. - err = pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second)) - if err != nil { - pgConn.hardClose() - return - } - - // A cancel query request will always return a "57014" error response, even if no query was in progress. This error - // may be returned before or after the ReadyForQuery message. Must ensure both messages are read. - needError57014 := true - needReadyForQuery := true - - for needError57014 || needReadyForQuery { - msg, err := pgConn.ReceiveMessage() - if err != nil { - pgConn.hardClose() - return - } - - switch msg := msg.(type) { - case *pgproto3.ErrorResponse: - if msg.Code == "57014" { - needError57014 = false - } - case *pgproto3.ReadyForQuery: - needReadyForQuery = false - } - } - - err = pgConn.conn.SetDeadline(time.Time{}) - if err != nil { - pgConn.hardClose() - } -} - -type ContextCancel struct { - PgConn *PgConn -} - -// Finish must be called when the cancellation request has finished processing. The connection must be in a ready for -// query state or the connection must be closed. This must be called regardless of the success of the cancellation and -// whether the connection is still valid or not. It releases an internal busy lock on the connection. -func (cc *ContextCancel) Finish() { - <-cc.PgConn.controller -} - -func (pgConn *PgConn) recoverFromTimeout() { - if pgConn.Config.OnContextCancel == nil { - pgConn.defaultCancel() - } else { - cc := &ContextCancel{PgConn: pgConn} - pgConn.Config.OnContextCancel(cc) - } -} - // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index 1ebbe04a..7288c9b4 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -4,9 +4,9 @@ import ( "context" "math/rand" "os" + "runtime" "strconv" "testing" - "time" "github.com/jackc/pgconn" @@ -14,13 +14,11 @@ import ( ) func TestConnStress(t *testing.T) { - t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer closeConn(t, pgConn) - actionCount := 100 + actionCount := 10000 if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" { stressFactor, err := strconv.ParseInt(s, 10, 64) require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR") @@ -36,9 +34,6 @@ func TestConnStress(t *testing.T) { {"Exec Select", stressExecSelect}, {"ExecParams Select", stressExecParamsSelect}, {"Batch", stressBatch}, - {"ExecCanceled", stressExecSelectCanceled}, - {"ExecParamsCanceled", stressExecParamsSelectCanceled}, - {"BatchCanceled", stressBatchCanceled}, } for i := 0; i < actionCount; i++ { @@ -46,6 +41,10 @@ func TestConnStress(t *testing.T) { err := action.fn(pgConn) require.Nilf(t, err, "%d: %s", i, action.name) } + + // Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled. + numGoroutine := runtime.NumGoroutine() + require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine) } func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { @@ -65,56 +64,27 @@ func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { } func stressExecSelect(pgConn *pgconn.PgConn) error { - _, err := pgConn.Exec(context.Background(), "select * from widgets").ReadAll() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := pgConn.Exec(ctx, "select * from widgets").ReadAll() return err } func stressExecParamsSelect(pgConn *pgconn.PgConn) error { - result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() return result.Err } func stressBatch(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + batch := &pgconn.Batch{} batch.ExecParams("select * from widgets", nil, nil, nil, nil) batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - _, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + _, err := pgConn.ExecBatch(ctx, batch).ReadAll() return err } - -func stressExecSelectCanceled(pgConn *pgconn.PgConn) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets").ReadAll() - cancel() - if err != context.DeadlineExceeded { - return err - } - - return nil -} - -func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() - cancel() - if result.Err != context.DeadlineExceeded { - return result.Err - } - - return nil -} - -func stressBatchCanceled(pgConn *pgconn.PgConn) error { - batch := &pgconn.Batch{} - batch.ExecParams("select * from widgets", nil, nil, nil, nil) - batch.ExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.ExecBatch(ctx, batch).ReadAll() - cancel() - if err != context.DeadlineExceeded { - return err - } - - return nil -} diff --git a/pgconn_test.go b/pgconn_test.go index 716761ad..88c6f7c4 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -16,7 +16,6 @@ import ( "time" "github.com/jackc/pgconn" - "github.com/jackc/pgproto3" "github.com/pkg/errors" "github.com/stretchr/testify/assert" @@ -356,8 +355,7 @@ func TestConnExecContextCanceled(t *testing.T) { } err = multiResult.Close() assert.Equal(t, context.DeadlineExceeded, err) - - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnExecParams(t *testing.T) { @@ -400,7 +398,7 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, pgconn.CommandTag(""), commandTag) assert.Equal(t, context.DeadlineExceeded, err) - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnExecPrepared(t *testing.T) { @@ -451,8 +449,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(""), commandTag) assert.Equal(t, context.DeadlineExceeded, err) - - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnExecBatch(t *testing.T) { @@ -510,72 +507,6 @@ func TestCommandTag(t *testing.T) { } } -func TestConnContextCancelWithOnContextCancel(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - calledChan := make(chan struct{}) - - config.OnContextCancel = func(cc *pgconn.ContextCancel) { - defer cc.Finish() - close(calledChan) - - for { - msg, err := cc.PgConn.ReceiveMessage() - if err != nil { - cc.PgConn.Close(context.Background()) - return - } - - switch msg.(type) { - case *pgproto3.ReadyForQuery: - return - } - } - } - - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - result := pgConn.ExecParams(ctx, "select 'Hello, world', pg_sleep(0.25)", nil, nil, nil, nil) - _, err = result.Close() - assert.Equal(t, context.DeadlineExceeded, err) - - called := false - select { - case <-calledChan: - called = true - case <-time.NewTimer(time.Second).C: - } - - assert.True(t, called) - - ensureConnValid(t, pgConn) -} - -func TestConnWaitUntilReady(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil).Read() - assert.Equal(t, context.DeadlineExceeded, result.Err) - - err = pgConn.WaitUntilReady(context.Background()) - require.NoError(t, err) - - ensureConnValid(t, pgConn) -} - func TestConnOnNotice(t *testing.T) { t.Parallel() @@ -792,7 +723,7 @@ func TestConnCopyToCanceled(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, err) assert.Equal(t, pgconn.CommandTag(""), res) - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnCopyFrom(t *testing.T) { @@ -991,6 +922,28 @@ func TestConnEscapeString(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnCancelRequest(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(5)") + + err = pgConn.CancelRequest(context.Background()) + require.NoError(t, err) + + for multiResult.NextResult() { + } + err = multiResult.Close() + + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { From 3d9e42d74c14ed6f091449fc7602727c3dc49d07 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 17:09:39 -0500 Subject: [PATCH 0229/1158] Replace chan based conn locking with bool This is conceptually simpler and will lead to error messages instead of deadlocks. --- pgconn.go | 112 ++++++++++++++++++++++++++++++++++++------------- pgconn_test.go | 23 ++++++++++ 2 files changed, 106 insertions(+), 29 deletions(-) diff --git a/pgconn.go b/pgconn.go index 8b0ddcb4..e246bcdd 100644 --- a/pgconn.go +++ b/pgconn.go @@ -89,8 +89,7 @@ type PgConn struct { Config *Config - controller chan interface{} - + locked bool closed bool bufferingReceive bool @@ -153,7 +152,6 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.Config = config - pgConn.controller = make(chan interface{}, 1) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -405,6 +403,29 @@ func (pgConn *PgConn) IsAlive() bool { return !pgConn.closed } +// lock locks the connection. It returns an error if the connection is already locked or is closed. +func (pgConn *PgConn) lock() error { + if pgConn.locked { + return errors.New("connection busy") + } + + if pgConn.closed { + return errors.New("connection closed") + } + + pgConn.locked = true + + return nil +} + +func (pgConn *PgConn) unlock() { + if !pgConn.locked { + panic("BUG: cannot unlock unlocked connection") + } + + pgConn.locked = false +} + // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { @@ -476,10 +497,14 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + select { case <-ctx.Done(): return nil, ctx.Err() - case pgConn.controller <- pgConn: + default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() @@ -521,7 +546,7 @@ readloop: } } - <-pgConn.controller + pgConn.unlock() if parseErr != nil { return nil, parseErr @@ -594,14 +619,18 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { + if err := pgConn.lock(); err != nil { + return err + } + select { case <-ctx.Done(): return ctx.Err() - case pgConn.controller <- pgConn: + default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() - defer func() { <-pgConn.controller }() + defer pgConn.unlock() for { msg, err := pgConn.ReceiveMessage() @@ -628,12 +657,18 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { cleanupContextDeadline: func() {}, } + if err := pgConn.lock(); err != nil { + multiResult.closed = true + multiResult.err = err + return multiResult + } + select { case <-ctx.Done(): multiResult.closed = true multiResult.err = ctx.Err() return multiResult - case pgConn.controller <- multiResult: + default: } multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) @@ -646,7 +681,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) - <-pgConn.controller + pgConn.unlock() return multiResult } @@ -679,12 +714,18 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] cleanupContextDeadline: func() {}, } + if err := pgConn.lock(); err != nil { + result.concludeCommand("", err) + result.closed = true + return result + } + select { case <-ctx.Done(): result.concludeCommand("", ctx.Err()) result.closed = true return result - case pgConn.controller <- result: + default: } result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) @@ -704,7 +745,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true - <-pgConn.controller + pgConn.unlock() } return result @@ -729,12 +770,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa cleanupContextDeadline: func() {}, } + if err := pgConn.lock(); err != nil { + result.concludeCommand("", err) + result.closed = true + return result + } + select { case <-ctx.Done(): result.concludeCommand("", ctx.Err()) result.closed = true return result - case pgConn.controller <- result: + default: } result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) @@ -750,7 +797,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true - <-pgConn.controller + pgConn.unlock() } return result @@ -758,10 +805,14 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return "", err + } + select { case <-ctx.Done(): return "", ctx.Err() - case pgConn.controller <- pgConn: + default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() @@ -773,7 +824,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - <-pgConn.controller + pgConn.unlock() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -797,7 +848,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return "", err } case *pgproto3.ReadyForQuery: - <-pgConn.controller + pgConn.unlock() return commandTag, pgErr case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) @@ -812,10 +863,15 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return "", err + } + defer pgConn.unlock() + select { case <-ctx.Done(): return "", ctx.Err() - case pgConn.controller <- pgConn: + default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() @@ -827,8 +883,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - <-pgConn.controller - return "", preferContextOverNetTimeoutError(ctx, err) } @@ -849,7 +903,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case *pgproto3.ErrorResponse: pgErr = errorResponseToPgError(msg) case *pgproto3.ReadyForQuery: - <-pgConn.controller return commandTag, pgErr } } @@ -871,8 +924,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - <-pgConn.controller - return "", preferContextOverNetTimeoutError(ctx, err) } } @@ -904,8 +955,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - <-pgConn.controller - return "", preferContextOverNetTimeoutError(ctx, err) } @@ -919,7 +968,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - <-pgConn.controller return commandTag, pgErr case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) @@ -968,7 +1016,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) case *pgproto3.ReadyForQuery: mrr.cleanupContextDeadline() mrr.closed = true - <-mrr.pgConn.controller + mrr.pgConn.unlock() case *pgproto3.ErrorResponse: mrr.err = errorResponseToPgError(msg) } @@ -1125,7 +1173,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { switch msg.(type) { case *pgproto3.ReadyForQuery: rr.cleanupContextDeadline() - <-rr.pgConn.controller + rr.pgConn.unlock() return rr.commandTag, rr.err } } @@ -1203,12 +1251,18 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR cleanupContextDeadline: func() {}, } + if err := pgConn.lock(); err != nil { + multiResult.closed = true + multiResult.err = ctx.Err() + return multiResult + } + select { case <-ctx.Done(): multiResult.closed = true multiResult.err = ctx.Err() return multiResult - case pgConn.controller <- multiResult: + default: } multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) @@ -1219,7 +1273,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) - <-pgConn.controller + pgConn.unlock() return multiResult } diff --git a/pgconn_test.go b/pgconn_test.go index 88c6f7c4..53e3b9d8 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -484,6 +484,29 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } +func TestConnLocking(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "connection busy", err.Error()) + + results, err = mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, pgConn) +} + func TestCommandTag(t *testing.T) { t.Parallel() From ed7d91dc987364b13d9039fe83a76aa993e9cdf9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 17:13:23 -0500 Subject: [PATCH 0230/1158] Force Go modules for Travis --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 950792d1..50e81eb5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,6 +10,7 @@ before_install: env: global: + - GO111MODULE=on - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test From 0ac82007fba8770a7018f5757648b8f8a50d4af8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 5 Apr 2019 10:52:23 -0500 Subject: [PATCH 0231/1158] Use extracted packages with Go modules --- array.go | 2 +- bool_array.go | 2 +- box.go | 2 +- bpchar_array.go | 2 +- bytea_array.go | 2 +- cidr_array.go | 2 +- circle.go | 2 +- date.go | 2 +- date_array.go | 2 +- daterange.go | 2 +- float4.go | 2 +- float4_array.go | 2 +- float8.go | 2 +- float8_array.go | 2 +- hstore.go | 2 +- hstore_array.go | 2 +- inet_array.go | 2 +- int2.go | 2 +- int2_array.go | 2 +- int4.go | 2 +- int4_array.go | 2 +- int4range.go | 2 +- int8.go | 2 +- int8_array.go | 2 +- int8range.go | 2 +- interval.go | 2 +- line.go | 2 +- lseg.go | 2 +- macaddr_array.go | 2 +- numeric.go | 2 +- numeric_array.go | 2 +- numrange.go | 2 +- oid.go | 2 +- path.go | 2 +- pguint32.go | 2 +- point.go | 2 +- polygon.go | 2 +- text_array.go | 2 +- tid.go | 2 +- timestamp.go | 2 +- timestamp_array.go | 2 +- timestamptz.go | 2 +- timestamptz_array.go | 2 +- tsrange.go | 2 +- tstzrange.go | 2 +- typed_array.go.erb | 2 +- typed_range.go.erb | 2 +- uuid_array.go | 2 +- varbit.go | 2 +- varchar_array.go | 2 +- 50 files changed, 50 insertions(+), 50 deletions(-) diff --git a/array.go b/array.go index 5b852ed5..9ce0f003 100644 --- a/array.go +++ b/array.go @@ -8,7 +8,7 @@ import ( "strings" "unicode" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/bool_array.go b/bool_array.go index 4231e29d..623937dc 100644 --- a/bool_array.go +++ b/bool_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/box.go b/box.go index 4c5a4406..4c825c56 100644 --- a/box.go +++ b/box.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/bpchar_array.go b/bpchar_array.go index b3f36cb6..d1ee2419 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/bytea_array.go b/bytea_array.go index 9c094b28..68122961 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/cidr_array.go b/cidr_array.go index c254c834..338d4904 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "net" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/circle.go b/circle.go index 15ea447b..a3bb56f1 100644 --- a/circle.go +++ b/circle.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/date.go b/date.go index b1d4c11d..85c698aa 100644 --- a/date.go +++ b/date.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/date_array.go b/date_array.go index c0f5c21c..d04666f1 100644 --- a/date_array.go +++ b/date_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/daterange.go b/daterange.go index 47cd7e46..d10d34c0 100644 --- a/daterange.go +++ b/daterange.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/float4.go b/float4.go index 2207594a..c4feb0a7 100644 --- a/float4.go +++ b/float4.go @@ -6,7 +6,7 @@ import ( "math" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/float4_array.go b/float4_array.go index fba181d3..4e07ba43 100644 --- a/float4_array.go +++ b/float4_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/float8.go b/float8.go index dd34f541..63944d45 100644 --- a/float8.go +++ b/float8.go @@ -6,7 +6,7 @@ import ( "math" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/float8_array.go b/float8_array.go index 13dbf27f..e4c340b2 100644 --- a/float8_array.go +++ b/float8_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/hstore.go b/hstore.go index 71b030f9..754c5a3f 100644 --- a/hstore.go +++ b/hstore.go @@ -10,7 +10,7 @@ import ( "github.com/pkg/errors" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) // Hstore represents an hstore column that can be null or have null values diff --git a/hstore_array.go b/hstore_array.go index 2b8cf37e..239c5d9c 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/inet_array.go b/inet_array.go index dba369d2..7b4cf457 100644 --- a/inet_array.go +++ b/inet_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "net" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/int2.go b/int2.go index 6156ea77..72110684 100644 --- a/int2.go +++ b/int2.go @@ -6,7 +6,7 @@ import ( "math" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/int2_array.go b/int2_array.go index 7fefbd95..5b4c2e1a 100644 --- a/int2_array.go +++ b/int2_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/int4.go b/int4.go index 261c5118..9ad878c4 100644 --- a/int4.go +++ b/int4.go @@ -7,7 +7,7 @@ import ( "math" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/int4_array.go b/int4_array.go index 86656524..77ad8654 100644 --- a/int4_array.go +++ b/int4_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/int4range.go b/int4range.go index 95ad1521..67bbfcd2 100644 --- a/int4range.go +++ b/int4range.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/int8.go b/int8.go index 00a8cd00..39b8a0a8 100644 --- a/int8.go +++ b/int8.go @@ -7,7 +7,7 @@ import ( "math" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/int8_array.go b/int8_array.go index 15a8398a..03b169d2 100644 --- a/int8_array.go +++ b/int8_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/int8range.go b/int8range.go index 61d860d3..25839a7b 100644 --- a/int8range.go +++ b/int8range.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/interval.go b/interval.go index dc696319..75969904 100644 --- a/interval.go +++ b/interval.go @@ -8,7 +8,7 @@ import ( "strings" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/line.go b/line.go index 5fdc5604..6ac4ac2a 100644 --- a/line.go +++ b/line.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/lseg.go b/lseg.go index 4445ea51..c0e77799 100644 --- a/lseg.go +++ b/lseg.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/macaddr_array.go b/macaddr_array.go index bd8b4c5a..c6bc2450 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "net" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/numeric.go b/numeric.go index fb63df75..fb6e1a00 100644 --- a/numeric.go +++ b/numeric.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/numeric_array.go b/numeric_array.go index b5e38539..0d26f3b5 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/numrange.go b/numrange.go index aaed62ce..ff9d5372 100644 --- a/numrange.go +++ b/numrange.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/oid.go b/oid.go index 59370d66..2afc60f8 100644 --- a/oid.go +++ b/oid.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/path.go b/path.go index 69083712..c1b72322 100644 --- a/path.go +++ b/path.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pguint32.go b/pguint32.go index e441a690..37178b5c 100644 --- a/pguint32.go +++ b/pguint32.go @@ -6,7 +6,7 @@ import ( "math" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/point.go b/point.go index 98a32d34..fefe5d1f 100644 --- a/point.go +++ b/point.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/polygon.go b/polygon.go index d84a0abd..904e86e1 100644 --- a/polygon.go +++ b/polygon.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/text_array.go b/text_array.go index d53f0b7b..ec487a23 100644 --- a/text_array.go +++ b/text_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/tid.go b/tid.go index 21852a14..e859865b 100644 --- a/tid.go +++ b/tid.go @@ -7,7 +7,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/timestamp.go b/timestamp.go index 6292521a..f8a4070d 100644 --- a/timestamp.go +++ b/timestamp.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/timestamp_array.go b/timestamp_array.go index 11b32a11..493088a2 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/timestamptz.go b/timestamptz.go index 2b9d2a64..ca9b538d 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/timestamptz_array.go b/timestamptz_array.go index 31c11f94..612e9904 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/tsrange.go b/tsrange.go index 8a67d65e..d771a761 100644 --- a/tsrange.go +++ b/tsrange.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/tstzrange.go b/tstzrange.go index b5129093..9a8c782e 100644 --- a/tstzrange.go +++ b/tstzrange.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/typed_array.go.erb b/typed_array.go.erb index 6b46a23e..b33e7d99 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -5,7 +5,7 @@ import ( "fmt" "io" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type <%= pgtype_array_type %> struct { diff --git a/typed_range.go.erb b/typed_range.go.erb index 91a5cb97..035a71af 100644 --- a/typed_range.go.erb +++ b/typed_range.go.erb @@ -6,7 +6,7 @@ import ( "fmt" "io" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type <%= range_type %> struct { diff --git a/uuid_array.go b/uuid_array.go index 13efdb23..cddd62f1 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/varbit.go b/varbit.go index dfa194d2..2c25b1fb 100644 --- a/varbit.go +++ b/varbit.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/varchar_array.go b/varchar_array.go index a7f23fba..0a929920 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) From c745509c595970c1776bee941d3fd969f313b845 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 5 Apr 2019 11:27:04 -0500 Subject: [PATCH 0232/1158] Rename test --- pgconn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index 53e3b9d8..ab8ae173 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -239,7 +239,7 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } } -func TestConnPrepareFailure(t *testing.T) { +func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) From 408837dcb1e5fb4535ab313178a64a6ad79d9bbb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 5 Apr 2019 11:47:31 -0500 Subject: [PATCH 0233/1158] Handle extended protocol with too many arguments --- pgconn.go | 17 ++++++++ pgconn_test.go | 106 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+) diff --git a/pgconn.go b/pgconn.go index e246bcdd..223b8e3d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "math" "net" "strconv" "strings" @@ -720,10 +721,18 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return result } + if len(paramValues) > math.MaxUint16 { + result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.closed = true + pgConn.unlock() + return result + } + select { case <-ctx.Done(): result.concludeCommand("", ctx.Err()) result.closed = true + pgConn.unlock() return result default: } @@ -776,10 +785,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } + if len(paramValues) > math.MaxUint16 { + result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.closed = true + pgConn.unlock() + return result + } + select { case <-ctx.Done(): result.concludeCommand("", ctx.Err()) result.closed = true + pgConn.unlock() return result default: } diff --git a/pgconn_test.go b/pgconn_test.go index ab8ae173..b2514e48 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -9,9 +9,11 @@ import ( "io" "io/ioutil" "log" + "math" "net" "os" "strconv" + "strings" "testing" "time" @@ -379,6 +381,52 @@ func TestConnExecParams(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecParamsMaxNumberOfParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsTooManyParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + + ensureConnValid(t, pgConn) +} + func TestConnExecParamsCanceled(t *testing.T) { t.Parallel() @@ -428,6 +476,64 @@ func TestConnExecPrepared(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedTooManyParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + + ensureConnValid(t, pgConn) +} + func TestConnExecPreparedCanceled(t *testing.T) { t.Parallel() From 7ad3625edd3b36e00d73c0c09009d8841074daed Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 5 Apr 2019 12:06:59 -0500 Subject: [PATCH 0234/1158] unlock connection when context is pre-canceled --- pgconn.go | 5 ++ pgconn_test.go | 166 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 171 insertions(+) diff --git a/pgconn.go b/pgconn.go index 223b8e3d..db741d47 100644 --- a/pgconn.go +++ b/pgconn.go @@ -504,6 +504,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ select { case <-ctx.Done(): + pgConn.unlock() return nil, ctx.Err() default: } @@ -626,6 +627,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { select { case <-ctx.Done(): + pgConn.unlock() return ctx.Err() default: } @@ -668,6 +670,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { case <-ctx.Done(): multiResult.closed = true multiResult.err = ctx.Err() + pgConn.unlock() return multiResult default: } @@ -828,6 +831,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm select { case <-ctx.Done(): + pgConn.unlock() return "", ctx.Err() default: } @@ -1278,6 +1282,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR case <-ctx.Done(): multiResult.closed = true multiResult.err = ctx.Err() + pgConn.unlock() return multiResult default: } diff --git a/pgconn_test.go b/pgconn_test.go index b2514e48..66a4337b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -255,6 +255,23 @@ func TestConnPrepareSyntaxError(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnPrepareContextPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil) + require.Nil(t, psd) + require.Error(t, err) + require.Equal(t, context.Canceled, err) + + ensureConnValid(t, pgConn) +} + func TestConnExec(t *testing.T) { t.Parallel() @@ -360,6 +377,22 @@ func TestConnExecContextCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnExecContextPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + require.Error(t, err) + require.Equal(t, context.Canceled, err) + + ensureConnValid(t, pgConn) +} + func TestConnExecParams(t *testing.T) { t.Parallel() @@ -449,6 +482,22 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnExecParamsPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, context.Canceled, result.Err) + + ensureConnValid(t, pgConn) +} + func TestConnExecPrepared(t *testing.T) { t.Parallel() @@ -558,6 +607,25 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnExecPreparedPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, context.Canceled, result.Err) + + ensureConnValid(t, pgConn) +} + func TestConnExecBatch(t *testing.T) { t.Parallel() @@ -590,6 +658,31 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } +func TestConnExecBatchPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.Error(t, err) + require.Equal(t, context.Canceled, err) + + ensureConnValid(t, pgConn) +} + func TestConnLocking(t *testing.T) { t.Parallel() @@ -726,6 +819,24 @@ func TestConnWaitForNotification(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnWaitForNotificationPrecanceled(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = pgConn.WaitForNotification(ctx) + require.Equal(t, context.Canceled, err) + + ensureConnValid(t, pgConn) +} + func TestConnWaitForNotificationTimeout(t *testing.T) { t.Parallel() @@ -855,6 +966,25 @@ func TestConnCopyToCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnCopyToPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + outputWriter := &bytes.Buffer{} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") + require.Error(t, err) + require.Equal(t, context.Canceled, err) + assert.Equal(t, pgconn.CommandTag(""), res) + + ensureConnValid(t, pgConn) +} + func TestConnCopyFrom(t *testing.T) { t.Parallel() @@ -926,6 +1056,42 @@ func TestConnCopyFromCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnCopyFromPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + r, w := io.Pipe() + go func() { + for i := 0; i < 1000000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + if err != nil { + return + } + time.Sleep(time.Microsecond) + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.Error(t, err) + require.Equal(t, context.Canceled, err) + assert.Equal(t, pgconn.CommandTag(""), ct) + + ensureConnValid(t, pgConn) +} + func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() From 0ebe322ac3600c12c9d1989f0ad6b3322c224c4a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 5 Apr 2019 16:10:11 -0500 Subject: [PATCH 0235/1158] Extract common code from ExecParams and ExecPrepared --- pgconn.go | 65 ++++++++++++++++++------------------------------------- 1 file changed, 21 insertions(+), 44 deletions(-) diff --git a/pgconn.go b/pgconn.go index db741d47..7ddc50e6 100644 --- a/pgconn.go +++ b/pgconn.go @@ -712,53 +712,16 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { // // ResultReader must be closed before PgConn can be used again. func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { - result := &ResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, - } - - if err := pgConn.lock(); err != nil { - result.concludeCommand("", err) - result.closed = true + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { return result } - if len(paramValues) > math.MaxUint16 { - result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) - result.closed = true - pgConn.unlock() - return result - } - - select { - case <-ctx.Done(): - result.concludeCommand("", ctx.Err()) - result.closed = true - pgConn.unlock() - return result - default: - } - result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) - var buf []byte - - // TODO - refactor ExecParams and ExecPrepared - these lines only difference buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) - buf = (&pgproto3.Execute{}).Encode(buf) - buf = (&pgproto3.Sync{}).Encode(buf) - - _, err := pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - result.concludeCommand("", err) - result.cleanupContextDeadline() - result.closed = true - pgConn.unlock() - } + pgConn.execExtendedSuffix(buf, result) return result } @@ -776,6 +739,20 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] // // ResultReader must be closed before PgConn can be used again. func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + var buf []byte + buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + + pgConn.execExtendedSuffix(buf, result) + + return result +} + +func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { result := &ResultReader{ pgConn: pgConn, ctx: ctx, @@ -805,8 +782,10 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) - var buf []byte - buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + return result +} + +func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) @@ -819,8 +798,6 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa result.closed = true pgConn.unlock() } - - return result } // CopyTo executes the copy command sql and copies the results to w. From fcbd9e93fa4818c73f81fa519ab18bf101b6623d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 12 Apr 2019 16:58:42 -0500 Subject: [PATCH 0236/1158] Initial pass at fixing pgtype tests Many still failing, but at least it compiles now. --- enum_array_test.go | 7 ++++--- hstore_array_test.go | 5 +++-- jsonb_test.go | 2 +- line_test.go | 3 ++- record_test.go | 13 +++++++------ testutil/testutil.go | 28 ++++++++++++++++------------ 6 files changed, 33 insertions(+), 25 deletions(-) diff --git a/enum_array_test.go b/enum_array_test.go index 9cc950af..052a813c 100644 --- a/enum_array_test.go +++ b/enum_array_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "context" "reflect" "testing" @@ -10,12 +11,12 @@ import ( func TestEnumArrayTranscode(t *testing.T) { setupConn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, setupConn) + defer testutil.MustCloseContext(t, setupConn) - if _, err := setupConn.Exec("drop type if exists color"); err != nil { + if _, err := setupConn.Exec(context.Background(), "drop type if exists color"); err != nil { t.Fatal(err) } - if _, err := setupConn.Exec("create type color as enum ('red', 'green', 'blue')"); err != nil { + if _, err := setupConn.Exec(context.Background(), "create type color as enum ('red', 'green', 'blue')"); err != nil { t.Fatal(err) } diff --git a/hstore_array_test.go b/hstore_array_test.go index d629a04b..c8104d28 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "context" "reflect" "testing" @@ -11,7 +12,7 @@ import ( func TestHstoreArrayTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, conn) + defer testutil.MustCloseContext(t, conn) text := func(s string) pgtype.Text { return pgtype.Text{String: s, Status: pgtype.Present} @@ -77,7 +78,7 @@ func TestHstoreArrayTranscode(t *testing.T) { } var result pgtype.HstoreArray - err := conn.QueryRow("test", vEncoder).Scan(&result) + err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(&result) if err != nil { t.Errorf("%v: %v", fc.name, err) continue diff --git a/jsonb_test.go b/jsonb_test.go index ab743151..afc51019 100644 --- a/jsonb_test.go +++ b/jsonb_test.go @@ -11,7 +11,7 @@ import ( func TestJSONBTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, conn) + defer testutil.MustCloseContext(t, conn) if _, ok := conn.ConnInfo.DataTypeForName("jsonb"); !ok { t.Skip("Skipping due to no jsonb type") } diff --git a/line_test.go b/line_test.go index 200d1d4c..077afe6b 100644 --- a/line_test.go +++ b/line_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "context" "testing" "github.com/jackc/pgx/pgtype" @@ -15,7 +16,7 @@ func TestLineTranscode(t *testing.T) { // line may exist but not be usable on 9.3 :( var isPG93 bool - err := conn.QueryRow("select version() ~ '9.3'").Scan(&isPG93) + err := conn.QueryRow(context.Background(), "select version() ~ '9.3'").Scan(&isPG93) if err != nil { t.Fatal(err) } diff --git a/record_test.go b/record_test.go index 23ec2cd3..44b0e9d8 100644 --- a/record_test.go +++ b/record_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "context" "fmt" "reflect" "testing" @@ -12,7 +13,7 @@ import ( func TestRecordTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, conn) + defer testutil.MustCloseContext(t, conn) tests := []struct { sql string @@ -91,7 +92,7 @@ func TestRecordTranscode(t *testing.T) { ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode var result pgtype.Record - if err := conn.QueryRow(psName).Scan(&result); err != nil { + if err := conn.QueryRow(context.Background(), psName).Scan(&result); err != nil { t.Errorf("%d: %v", i, err) continue } @@ -104,9 +105,9 @@ func TestRecordTranscode(t *testing.T) { func TestRecordWithUnknownOID(t *testing.T) { conn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, conn) + defer testutil.MustCloseContext(t, conn) - _, err := conn.Exec(`drop type if exists floatrange; + _, err := conn.Exec(context.Background(), `drop type if exists floatrange; create type floatrange as range ( subtype = float8, @@ -115,10 +116,10 @@ create type floatrange as range ( if err != nil { t.Fatal(err) } - defer conn.Exec("drop type floatrange") + defer conn.Exec(context.Background(), "drop type floatrange") var result pgtype.Record - err = conn.QueryRow("select row('foo'::text, floatrange(1, 10), 'bar'::text)").Scan(&result) + err = conn.QueryRow(context.Background(), "select row('foo'::text, floatrange(1, 10), 'bar'::text)").Scan(&result) if err == nil { t.Errorf("expected error but none") } diff --git a/testutil/testutil.go b/testutil/testutil.go index 0effb42d..2cde9961 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -34,12 +34,7 @@ func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { } func MustConnectPgx(t testing.TB) *pgx.Conn { - config, err := pgx.ParseConnectionString(os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - t.Fatal(err) - } - - conn, err := pgx.Connect(config) + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { t.Fatal(err) } @@ -56,6 +51,15 @@ func MustClose(t testing.TB, conn interface { } } +func MustCloseContext(t testing.TB, conn interface { + Close(context.Context) error +}) { + err := conn.Close(context.Background()) + if err != nil { + t.Fatal(err) + } +} + type forceTextEncoder struct { e pgtype.TextEncoder } @@ -102,7 +106,7 @@ func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { conn := MustConnectPgx(t) - defer MustClose(t, conn) + defer MustCloseContext(t, conn) ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) if err != nil { @@ -133,7 +137,7 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRow("test", ForceEncoder(v, fc.formatCode)).Scan(result.Interface()) + err := conn.QueryRow(context.Background(), "test", ForceEncoder(v, fc.formatCode)).Scan(result.Interface()) if err != nil { t.Errorf("%v %d: %v", fc.name, i, err) } @@ -147,7 +151,7 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { conn := MustConnectPgx(t) - defer MustClose(t, conn) + defer MustCloseContext(t, conn) for i, v := range values { // Derefence value if it is a pointer @@ -158,7 +162,7 @@ func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName str } result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRowEx( + err := conn.QueryRow( context.Background(), fmt.Sprintf("select ($1)::%s", pgTypeName), &pgx.QueryExOptions{SimpleProtocol: true}, @@ -223,7 +227,7 @@ func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc f func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { conn := MustConnectPgx(t) - defer MustClose(t, conn) + defer MustCloseContext(t, conn) formats := []struct { name string @@ -254,7 +258,7 @@ func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFun } result := reflect.New(reflect.TypeOf(derefV)) - err = conn.QueryRow(psName).Scan(result.Interface()) + err = conn.QueryRow(context.Background(), psName).Scan(result.Interface()) if err != nil { t.Errorf("%v %d: %v", fc.name, i, err) } From 59003afe8c48e85b85d66453a737982cd1fc55a1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 12 Apr 2019 21:23:57 -0500 Subject: [PATCH 0237/1158] Fix encode empty value --- testutil/testutil.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/testutil/testutil.go b/testutil/testutil.go index 2cde9961..6ea3a69e 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -137,7 +137,8 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRow(context.Background(), "test", ForceEncoder(v, fc.formatCode)).Scan(result.Interface()) + + err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(result.Interface()) if err != nil { t.Errorf("%v %d: %v", fc.name, i, err) } From f779b05f367b9d34cfe82b6adb25ff9b5bb8d36c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 12 Apr 2019 21:31:59 -0500 Subject: [PATCH 0238/1158] Extract scan value to pgtype --- pgtype.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/pgtype.go b/pgtype.go index 2643314e..8f41d068 100644 --- a/pgtype.go +++ b/pgtype.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql" "reflect" "github.com/pkg/errors" @@ -84,6 +85,12 @@ func (im InfinityModifier) String() string { } } +// PostgreSQL format codes +const ( + TextFormatCode = 0 + BinaryFormatCode = 1 +) + type Value interface { // Set converts and assigns src to itself. Set(src interface{}) error @@ -207,6 +214,53 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { return ci2 } +func (ci *ConnInfo) Scan(oid OID, formatCode int16, buf []byte, dest interface{}) error { + if dest, ok := dest.(BinaryDecoder); ok && formatCode == BinaryFormatCode { + return dest.DecodeBinary(ci, buf) + } + + if dest, ok := dest.(TextDecoder); ok && formatCode == TextFormatCode { + return dest.DecodeText(ci, buf) + } + + if dt, ok := ci.DataTypeForOID(oid); ok { + value := dt.Value + switch formatCode { + case TextFormatCode: + if textDecoder, ok := value.(TextDecoder); ok { + err := textDecoder.DecodeText(ci, buf) + if err != nil { + return err + } + } else { + return errors.Errorf("%T is not a pgtype.TextDecoder", value) + } + case BinaryFormatCode: + if binaryDecoder, ok := value.(BinaryDecoder); ok { + err := binaryDecoder.DecodeBinary(ci, buf) + if err != nil { + return err + } + } else { + return errors.Errorf("%T is not a pgtype.BinaryDecoder", value) + } + default: + return errors.Errorf("unknown format code: %v", formatCode) + } + + if scanner, ok := dest.(sql.Scanner); ok { + sqlSrc, err := DatabaseSQLValue(ci, value) + if err != nil { + return err + } + return scanner.Scan(sqlSrc) + } else { + return value.AssignTo(dest) + } + } + return errors.Errorf("unknown oid: %v", oid) +} + var nameValues map[string]Value func init() { From 698bd4bf5a75e4c6386e38225e701d7a08da4c86 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Apr 2019 10:30:49 -0500 Subject: [PATCH 0239/1158] Use defer to unlock pgConn in Prepare --- pgconn.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pgconn.go b/pgconn.go index 7ddc50e6..c9891dbf 100644 --- a/pgconn.go +++ b/pgconn.go @@ -501,10 +501,10 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ if err := pgConn.lock(); err != nil { return nil, err } + defer pgConn.unlock() select { case <-ctx.Done(): - pgConn.unlock() return nil, ctx.Err() default: } @@ -548,8 +548,6 @@ readloop: } } - pgConn.unlock() - if parseErr != nil { return nil, parseErr } From 7fbae064bba4ed312fd90ad937f6a4172dad5b22 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Apr 2019 11:39:01 -0500 Subject: [PATCH 0240/1158] Remove simple protocol and one round trip query options It is impossible to guarantee that the a query executed with the simple protocol will behave the same as with the extended protocol. This is because the normal pgx path relies on knowing the OID of query parameters. Without this encoding a value can only be determined by the value instead of the combination of value and PostgreSQL type. For example, how should a []int32 be encoded? It might be encoded into a PostgreSQL int4[] or json. Removal also simplifies the core query path. The primary reason for the simple protocol is for servers like PgBouncer that may not be able to support normal prepared statements. After further research it appears that issuing a "flush" instead "sync" after preparing the unnamed statement would allow PgBouncer to work. The one round trip mode can be better handled with prepared statements. As a last resort, all original server functionality can still be accessed by dropping down to PgConn. --- cid_test.go | 3 --- testutil/testutil.go | 30 ------------------------------ xid_test.go | 3 --- 3 files changed, 36 deletions(-) diff --git a/cid_test.go b/cid_test.go index 0dfc56d4..924e4cf3 100644 --- a/cid_test.go +++ b/cid_test.go @@ -20,9 +20,6 @@ func TestCIDTranscode(t *testing.T) { testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - // No direct conversion from int to cid, convert through text - testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) } diff --git a/testutil/testutil.go b/testutil/testutil.go index 6ea3a69e..462549a7 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -98,7 +98,6 @@ func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) } @@ -150,35 +149,6 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } } -func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := MustConnectPgx(t) - defer MustCloseContext(t, conn) - - for i, v := range values { - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRow( - context.Background(), - fmt.Sprintf("select ($1)::%s", pgTypeName), - &pgx.QueryExOptions{SimpleProtocol: true}, - v, - ).Scan(result.Interface()) - if err != nil { - t.Errorf("Simple protocol %d: %v", i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface()) - } - } -} - func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { conn := MustConnectDatabaseSQL(t, driverName) defer MustClose(t, conn) diff --git a/xid_test.go b/xid_test.go index d0f3f0ab..594d1214 100644 --- a/xid_test.go +++ b/xid_test.go @@ -20,9 +20,6 @@ func TestXIDTranscode(t *testing.T) { testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - // No direct conversion from int to xid, convert through text - testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) } From ea65a92de9b9a15995ea4969ff99b3149f639e55 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Apr 2019 14:06:01 -0500 Subject: [PATCH 0241/1158] Fix long standing text array text format null bug --- text_array.go | 2 +- typed_array_gen.sh | 10 +++++----- varchar_array.go | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/text_array.go b/text_array.go index ec487a23..88171d6c 100644 --- a/text_array.go +++ b/text_array.go @@ -209,7 +209,7 @@ func (src *TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if elemBuf == nil { - buf = append(buf, `"NULL"`...) + buf = append(buf, `NULL`...) } else { buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } diff --git a/typed_array_gen.sh b/typed_array_gen.sh index bd70faa4..911fa392 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -10,16 +10,16 @@ erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]fl erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' binary_format=true typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' binary_format=true typed_array.go.erb > varchar_array.go -erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string element_type_name=bpchar text_null='NULL' binary_format=true typed_array.go.erb > bpchar_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go +erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64,[]int64,[]uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go -# While the binary format is theoretically possible it is only practical to use the text format. In addition, the text format for NULL enums is unquoted so TextArray or a possible GenericTextArray cannot be used. -erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string text_null='NULL' binary_format=false typed_array.go.erb > enum_array.go +# While the binary format is theoretically possible it is only practical to use the text format. +erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go goimports -w *_array.go diff --git a/varchar_array.go b/varchar_array.go index 0a929920..7b9257b8 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -209,7 +209,7 @@ func (src *VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if elemBuf == nil { - buf = append(buf, `"NULL"`...) + buf = append(buf, `NULL`...) } else { buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } From a0f487bc098a63937a3fdc27c8fb7f0812a5432b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Apr 2019 14:17:04 -0500 Subject: [PATCH 0242/1158] More transcoding type tests Text every combination of text and binary arguments and text and binary results. --- testutil/testutil.go | 56 +++++++++++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/testutil/testutil.go b/testutil/testutil.go index 462549a7..3711381c 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -107,7 +107,7 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] conn := MustConnectPgx(t) defer MustCloseContext(t, conn) - ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) + _, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) if err != nil { t.Fatal(err) } @@ -121,29 +121,43 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } for i, v := range values { - for _, fc := range formats { - ps.FieldDescriptions[0].FormatCode = fc.formatCode - vEncoder := ForceEncoder(v, fc.formatCode) - if vEncoder == nil { - t.Logf("Skipping: %#v does not implement %v", v, fc.name) - continue - } - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } + for _, paramFormat := range formats { + for _, resultFormat := range formats { + vEncoder := ForceEncoder(v, paramFormat.formatCode) + if vEncoder == nil { + t.Logf("Skipping Param %s Result %s: %#v does not implement %v for encoding", paramFormat.name, resultFormat.name, v, paramFormat.name) + continue + } + switch resultFormat.formatCode { + case pgx.TextFormatCode: + if _, ok := v.(pgtype.TextEncoder); !ok { + t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name) + continue + } + case pgx.BinaryFormatCode: + if _, ok := v.(pgtype.BinaryEncoder); !ok { + t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name) + continue + } + } - result := reflect.New(reflect.TypeOf(derefV)) + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } - err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", fc.name, i, err) - } + result := reflect.New(reflect.TypeOf(derefV)) - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{resultFormat.formatCode}, vEncoder).Scan(result.Interface()) + if err != nil { + t.Errorf("Param %s Result %s %d: %v", paramFormat.name, resultFormat.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("Param %s Result %s %d: expected %v, got %v", paramFormat.name, resultFormat.name, i, derefV, result.Elem().Interface()) + } } } } From bd85fe870d0ee82e9ee6c57c0e01f1319a397053 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Apr 2019 16:45:52 -0500 Subject: [PATCH 0243/1158] Hard code standard PostgreSQL types Instead of needing to instrospect the database on connection preload the standard OID / type map. Types from extensions (like hstore) and custom types can be registered by the application developer. Otherwise, they will be treated as strings. --- decimal.go | 31 --------------- hstore_array_test.go | 14 +++++++ pgtype.go | 95 +++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 103 insertions(+), 37 deletions(-) delete mode 100644 decimal.go diff --git a/decimal.go b/decimal.go deleted file mode 100644 index 79653cf3..00000000 --- a/decimal.go +++ /dev/null @@ -1,31 +0,0 @@ -package pgtype - -type Decimal Numeric - -func (dst *Decimal) Set(src interface{}) error { - return (*Numeric)(dst).Set(src) -} - -func (dst *Decimal) Get() interface{} { - return (*Numeric)(dst).Get() -} - -func (src *Decimal) AssignTo(dst interface{}) error { - return (*Numeric)(src).AssignTo(dst) -} - -func (dst *Decimal) DecodeText(ci *ConnInfo, src []byte) error { - return (*Numeric)(dst).DecodeText(ci, src) -} - -func (dst *Decimal) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Numeric)(dst).DecodeBinary(ci, src) -} - -func (src *Decimal) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Numeric)(src).EncodeText(ci, buf) -} - -func (src *Decimal) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Numeric)(src).EncodeBinary(ci, buf) -} diff --git a/hstore_array_test.go b/hstore_array_test.go index c8104d28..03dc2ff1 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -14,6 +14,20 @@ func TestHstoreArrayTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) + var hstoreOID pgtype.OID + err := conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='hstore';").Scan(&hstoreOID) + if err != nil { + t.Fatalf("did not find hstore OID, %v", err) + } + conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.Hstore{}, Name: "hstore", OID: hstoreOID}) + + var hstoreArrayOID pgtype.OID + err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='_hstore';").Scan(&hstoreArrayOID) + if err != nil { + t.Fatalf("did not find _hstore OID, %v", err) + } + conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.HstoreArray{}, Name: "_hstore", OID: hstoreArrayOID}) + text := func(s string) pgtype.Text { return pgtype.Text{String: s, Status: pgtype.Present} } diff --git a/pgtype.go b/pgtype.go index 8f41d068..4faf23e1 100644 --- a/pgtype.go +++ b/pgtype.go @@ -11,7 +11,7 @@ import ( const ( BoolOID = 16 ByteaOID = 17 - CharOID = 18 + QCharOID = 18 NameOID = 19 Int8OID = 20 Int2OID = 21 @@ -22,11 +22,19 @@ const ( XIDOID = 28 CIDOID = 29 JSONOID = 114 + PointOID = 600 + LsegOID = 601 + PathOID = 602 + BoxOID = 603 + PolygonOID = 604 + LineOID = 628 CIDROID = 650 CIDRArrayOID = 651 Float4OID = 700 Float8OID = 701 + CircleOID = 718 UnknownOID = 705 + MacaddrOID = 829 InetOID = 869 BoolArrayOID = 1000 Int2ArrayOID = 1005 @@ -49,11 +57,21 @@ const ( DateArrayOID = 1182 TimestamptzOID = 1184 TimestamptzArrayOID = 1185 + IntervalOID = 1186 + NumericArrayOID = 1231 + BitOID = 1560 + VarbitOID = 1562 NumericOID = 1700 RecordOID = 2249 UUIDOID = 2950 UUIDArrayOID = 2951 JSONBOID = 3802 + DaterangeOID = 3812 + Int4rangeOID = 3904 + NumrangeOID = 3906 + TsrangeOID = 3908 + TstzrangeOID = 3910 + Int8rangeOID = 3926 ) type Status byte @@ -155,11 +173,77 @@ type ConnInfo struct { } func NewConnInfo() *ConnInfo { - return &ConnInfo{ - oidToDataType: make(map[OID]*DataType, 256), - nameToDataType: make(map[string]*DataType, 256), - reflectTypeToDataType: make(map[reflect.Type]*DataType, 256), + ci := &ConnInfo{ + oidToDataType: make(map[OID]*DataType, 128), + nameToDataType: make(map[string]*DataType, 128), + reflectTypeToDataType: make(map[reflect.Type]*DataType, 128), } + + ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID}) + ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID}) + ci.RegisterDataType(DataType{Value: &BPCharArray{}, Name: "_bpchar", OID: BPCharArrayOID}) + ci.RegisterDataType(DataType{Value: &ByteaArray{}, Name: "_bytea", OID: ByteaArrayOID}) + ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID}) + ci.RegisterDataType(DataType{Value: &DateArray{}, Name: "_date", OID: DateArrayOID}) + ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) + ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) + ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID}) + ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID}) + ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID}) + ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID}) + ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) + ci.RegisterDataType(DataType{Value: &TextArray{}, Name: "_text", OID: TextArrayOID}) + ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) + ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID}) + ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) + ci.RegisterDataType(DataType{Value: &VarcharArray{}, Name: "_varchar", OID: VarcharArrayOID}) + ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID}) + ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID}) + ci.RegisterDataType(DataType{Value: &Bool{}, Name: "bool", OID: BoolOID}) + ci.RegisterDataType(DataType{Value: &Box{}, Name: "box", OID: BoxOID}) + ci.RegisterDataType(DataType{Value: &BPChar{}, Name: "bpchar", OID: BPCharOID}) + ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID}) + ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) + ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID}) + ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) + ci.RegisterDataType(DataType{Value: &Circle{}, Name: "circle", OID: CircleOID}) + ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID}) + ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) + ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) + ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) + ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID}) + ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID}) + ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID}) + ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) + ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID}) + ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) + ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID}) + ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) + ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) + ci.RegisterDataType(DataType{Value: &Line{}, Name: "line", OID: LineOID}) + ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID}) + ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) + ci.RegisterDataType(DataType{Value: &Name{}, Name: "name", OID: NameOID}) + ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) + ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) + ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID}) + ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID}) + ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID}) + ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID}) + ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) + ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID}) + ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID}) + ci.RegisterDataType(DataType{Value: &Timestamp{}, Name: "timestamp", OID: TimestampOID}) + ci.RegisterDataType(DataType{Value: &Timestamptz{}, Name: "timestamptz", OID: TimestamptzOID}) + ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) + ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) + ci.RegisterDataType(DataType{Value: &Unknown{}, Name: "unknown", OID: UnknownOID}) + ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID}) + ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID}) + ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID}) + ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID}) + + return ci } func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) { @@ -295,7 +379,6 @@ func init() { "circle": &Circle{}, "date": &Date{}, "daterange": &Daterange{}, - "decimal": &Decimal{}, "float4": &Float4{}, "float8": &Float8{}, "hstore": &Hstore{}, From 4e79a104f716343dcc38d011de626c3a65081025 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Apr 2019 17:09:51 -0500 Subject: [PATCH 0244/1158] Test domains when registered and unregistered Fix bug assigning to unknown type. --- aclitem.go | 1 + aclitem_array.go | 1 + bool.go | 1 + bool_array.go | 1 + bpchar_array.go | 1 + bytea.go | 1 + bytea_array.go | 1 + cidr_array.go | 1 + date.go | 1 + date_array.go | 1 + enum_array.go | 1 + ext/satori-uuid/uuid.go | 1 + ext/shopspring-numeric/decimal.go | 1 + float4_array.go | 1 + float8_array.go | 1 + hstore.go | 1 + hstore_array.go | 1 + inet.go | 1 + inet_array.go | 1 + int2_array.go | 1 + int4_array.go | 1 + int8_array.go | 1 + interval.go | 1 + macaddr.go | 1 + macaddr_array.go | 1 + numeric.go | 1 + numeric_array.go | 1 + record.go | 1 + text.go | 1 + text_array.go | 1 + timestamp.go | 1 + timestamp_array.go | 1 + timestamptz.go | 1 + timestamptz_array.go | 1 + typed_array.go.erb | 1 + uuid_array.go | 1 + varchar_array.go | 1 + 37 files changed, 37 insertions(+) diff --git a/aclitem.go b/aclitem.go index 4da962dd..a54955eb 100644 --- a/aclitem.go +++ b/aclitem.go @@ -65,6 +65,7 @@ func (src *ACLItem) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/aclitem_array.go b/aclitem_array.go index d8bf3303..2671022b 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -79,6 +79,7 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/bool.go b/bool.go index 0574588d..22774970 100644 --- a/bool.go +++ b/bool.go @@ -59,6 +59,7 @@ func (src *Bool) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/bool_array.go b/bool_array.go index 623937dc..1aefcd27 100644 --- a/bool_array.go +++ b/bool_array.go @@ -81,6 +81,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/bpchar_array.go b/bpchar_array.go index d1ee2419..dd4a8363 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -81,6 +81,7 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/bytea.go b/bytea.go index 4506dc31..064f199a 100644 --- a/bytea.go +++ b/bytea.go @@ -59,6 +59,7 @@ func (src *Bytea) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/bytea_array.go b/bytea_array.go index 68122961..fc07d103 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -81,6 +81,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/cidr_array.go b/cidr_array.go index 338d4904..62b0ca65 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -110,6 +110,7 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/date.go b/date.go index 85c698aa..3f8d188a 100644 --- a/date.go +++ b/date.go @@ -67,6 +67,7 @@ func (src *Date) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/date_array.go b/date_array.go index d04666f1..6d6c0899 100644 --- a/date_array.go +++ b/date_array.go @@ -82,6 +82,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/enum_array.go b/enum_array.go index 7168cb8a..5de2badf 100644 --- a/enum_array.go +++ b/enum_array.go @@ -79,6 +79,7 @@ func (src *EnumArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/ext/satori-uuid/uuid.go b/ext/satori-uuid/uuid.go index 78a90035..baebc5ed 100644 --- a/ext/satori-uuid/uuid.go +++ b/ext/satori-uuid/uuid.go @@ -78,6 +78,7 @@ func (src *UUID) AssignTo(dst interface{}) error { if nextDst, retry := pgtype.GetAssignToDstType(v); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case pgtype.Null: return pgtype.NullAssignTo(dst) diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index 507a93dc..7c1cd770 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -205,6 +205,7 @@ func (src *Numeric) AssignTo(dst interface{}) error { if nextDst, retry := pgtype.GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case pgtype.Null: return pgtype.NullAssignTo(dst) diff --git a/float4_array.go b/float4_array.go index 4e07ba43..b14161e8 100644 --- a/float4_array.go +++ b/float4_array.go @@ -81,6 +81,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/float8_array.go b/float8_array.go index e4c340b2..60e87236 100644 --- a/float8_array.go +++ b/float8_array.go @@ -81,6 +81,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/hstore.go b/hstore.go index 754c5a3f..8a84fe2a 100644 --- a/hstore.go +++ b/hstore.go @@ -68,6 +68,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/hstore_array.go b/hstore_array.go index 239c5d9c..19d07686 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -81,6 +81,7 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/inet.go b/inet.go index d93e6347..dfdd8868 100644 --- a/inet.go +++ b/inet.go @@ -86,6 +86,7 @@ func (src *Inet) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/inet_array.go b/inet_array.go index 7b4cf457..51ad7988 100644 --- a/inet_array.go +++ b/inet_array.go @@ -110,6 +110,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/int2_array.go b/int2_array.go index 5b4c2e1a..e3b9f64b 100644 --- a/int2_array.go +++ b/int2_array.go @@ -109,6 +109,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/int4_array.go b/int4_array.go index 77ad8654..ad75c4b5 100644 --- a/int4_array.go +++ b/int4_array.go @@ -128,6 +128,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/int8_array.go b/int8_array.go index 03b169d2..ae8d8e0f 100644 --- a/int8_array.go +++ b/int8_array.go @@ -109,6 +109,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/interval.go b/interval.go index 75969904..9172e14a 100644 --- a/interval.go +++ b/interval.go @@ -69,6 +69,7 @@ func (src *Interval) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/macaddr.go b/macaddr.go index 79004be4..6854400b 100644 --- a/macaddr.go +++ b/macaddr.go @@ -65,6 +65,7 @@ func (src *Macaddr) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/macaddr_array.go b/macaddr_array.go index c6bc2450..2d0439e9 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -82,6 +82,7 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/numeric.go b/numeric.go index fb6e1a00..91aff123 100644 --- a/numeric.go +++ b/numeric.go @@ -250,6 +250,7 @@ func (src *Numeric) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/numeric_array.go b/numeric_array.go index 0d26f3b5..ec892cc8 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -165,6 +165,7 @@ func (src *NumericArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/record.go b/record.go index 64c6f13a..315deda5 100644 --- a/record.go +++ b/record.go @@ -62,6 +62,7 @@ func (src *Record) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/text.go b/text.go index 919743fe..648bbd58 100644 --- a/text.go +++ b/text.go @@ -69,6 +69,7 @@ func (src *Text) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/text_array.go b/text_array.go index 88171d6c..1556fec8 100644 --- a/text_array.go +++ b/text_array.go @@ -81,6 +81,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/timestamp.go b/timestamp.go index f8a4070d..93383e35 100644 --- a/timestamp.go +++ b/timestamp.go @@ -71,6 +71,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/timestamp_array.go b/timestamp_array.go index 493088a2..1fd1eefe 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -82,6 +82,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/timestamptz.go b/timestamptz.go index ca9b538d..c2c91c29 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -72,6 +72,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/timestamptz_array.go b/timestamptz_array.go index 612e9904..b87238ae 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -82,6 +82,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/typed_array.go.erb b/typed_array.go.erb index b33e7d99..3ee637aa 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -81,6 +81,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/uuid_array.go b/uuid_array.go index cddd62f1..fac838af 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -137,6 +137,7 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) diff --git a/varchar_array.go b/varchar_array.go index 7b9257b8..d2359d03 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -81,6 +81,7 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) From 78eda7d56799723cd9f6c3b2f599928033470f02 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Apr 2019 18:06:09 -0500 Subject: [PATCH 0245/1158] Remove unused scan float into numeric --- numeric.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/numeric.go b/numeric.go index 91aff123..887ad1f8 100644 --- a/numeric.go +++ b/numeric.go @@ -568,10 +568,6 @@ func (dst *Numeric) Scan(src interface{}) error { } switch src := src.(type) { - case float64: - // TODO - // *dst = Numeric{Float: src, Status: Present} - return nil case string: return dst.DecodeText(nil, []byte(src)) case []byte: From b2a540ca814e2103bdb45188625a035e2ef3f6b6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 16 Apr 2019 20:30:55 -0500 Subject: [PATCH 0246/1158] Add sufficient support for SCRAM --- authentication.go | 30 +++++++++++++++++++ sasl_initial_response.go | 64 ++++++++++++++++++++++++++++++++++++++++ sasl_response.go | 38 ++++++++++++++++++++++++ 3 files changed, 132 insertions(+) create mode 100644 sasl_initial_response.go create mode 100644 sasl_response.go diff --git a/authentication.go b/authentication.go index 14275a86..2078c87c 100644 --- a/authentication.go +++ b/authentication.go @@ -1,6 +1,7 @@ package pgproto3 import ( + "bytes" "encoding/binary" "github.com/jackc/pgio" @@ -11,6 +12,9 @@ const ( AuthTypeOk = 0 AuthTypeCleartextPassword = 3 AuthTypeMD5Password = 5 + AuthTypeSASL = 10 + AuthTypeSASLContinue = 11 + AuthTypeSASLFinal = 12 ) type Authentication struct { @@ -18,6 +22,12 @@ type Authentication struct { // MD5Password fields Salt [4]byte + + // SASL fields + SASLAuthMechanisms []string + + // SASLContinue and SASLFinal data + SASLData []byte } func (*Authentication) Backend() {} @@ -30,6 +40,17 @@ func (dst *Authentication) Decode(src []byte) error { case AuthTypeCleartextPassword: case AuthTypeMD5Password: copy(dst.Salt[:], src[4:8]) + case AuthTypeSASL: + authMechanisms := src[4:] + for len(authMechanisms) > 1 { + idx := bytes.IndexByte(authMechanisms, 0) + if idx > 0 { + dst.SASLAuthMechanisms = append(dst.SASLAuthMechanisms, string(authMechanisms[:idx])) + authMechanisms = authMechanisms[idx+1:] + } + } + case AuthTypeSASLContinue, AuthTypeSASLFinal: + dst.SASLData = src[4:] default: return errors.Errorf("unknown authentication type: %d", dst.Type) } @@ -46,6 +67,15 @@ func (src *Authentication) Encode(dst []byte) []byte { switch src.Type { case AuthTypeMD5Password: dst = append(dst, src.Salt[:]...) + case AuthTypeSASL: + for _, s := range src.SASLAuthMechanisms { + dst = append(dst, []byte(s)...) + dst = append(dst, 0) + } + dst = append(dst, 0) + case AuthTypeSASLContinue: + dst = pgio.AppendInt32(dst, int32(len(src.SASLData))) + dst = append(dst, src.SASLData...) } pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) diff --git a/sasl_initial_response.go b/sasl_initial_response.go new file mode 100644 index 00000000..63766131 --- /dev/null +++ b/sasl_initial_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/hex" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +type SASLInitialResponse struct { + AuthMechanism string + Data []byte +} + +func (*SASLInitialResponse) Frontend() {} + +func (dst *SASLInitialResponse) Decode(src []byte) error { + *dst = SASLInitialResponse{} + + rp := 0 + + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return errors.New("invalid SASLInitialResponse") + } + + dst.AuthMechanism = string(src[rp:idx]) + rp = idx + 1 + + rp += 4 // The rest of the message is data so we can just skip the size + dst.Data = src[rp:] + + return nil +} + +func (src *SASLInitialResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, []byte(src.AuthMechanism)...) + dst = append(dst, 0) + + dst = pgio.AppendInt32(dst, int32(len(src.Data))) + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *SASLInitialResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + AuthMechanism string + Data string + }{ + Type: "SASLInitialResponse", + AuthMechanism: src.AuthMechanism, + Data: hex.EncodeToString(src.Data), + }) +} diff --git a/sasl_response.go b/sasl_response.go new file mode 100644 index 00000000..1e8d3bd3 --- /dev/null +++ b/sasl_response.go @@ -0,0 +1,38 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgio" +) + +type SASLResponse struct { + Data []byte +} + +func (*SASLResponse) Frontend() {} + +func (dst *SASLResponse) Decode(src []byte) error { + *dst = SASLResponse{Data: src} + return nil +} + +func (src *SASLResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) + + dst = append(dst, src.Data...) + + return dst +} + +func (src *SASLResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "SASLResponse", + Data: hex.EncodeToString(src.Data), + }) +} From 244e114435d4afb7934392284613255b639d6fb9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 16 Apr 2019 20:41:38 -0500 Subject: [PATCH 0247/1158] Add SCRAM authentication --- auth_scram.go | 255 +++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 4 +- go.sum | 11 ++- pgconn.go | 2 + pgconn_test.go | 3 +- 5 files changed, 268 insertions(+), 7 deletions(-) create mode 100644 auth_scram.go diff --git a/auth_scram.go b/auth_scram.go new file mode 100644 index 00000000..b78a236a --- /dev/null +++ b/auth_scram.go @@ -0,0 +1,255 @@ +// SCRAM-SHA-256 authentication +// +// Resources: +// https://tools.ietf.org/html/rfc5802 +// https://tools.ietf.org/html/rfc8265 +// https://www.postgresql.org/docs/current/sasl-authentication.html +// +// Inspiration drawn from other implementations: +// https://github.com/lib/pq/pull/608 +// https://github.com/lib/pq/pull/788 +// https://github.com/lib/pq/pull/833 +package pgconn + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "strconv" + + "github.com/jackc/pgproto3" + "golang.org/x/crypto/pbkdf2" + "golang.org/x/text/secure/precis" +) + +const clientNonceLen = 18 + +// Perform SCRAM authentication. +func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { + sc, err := newScramClient(serverAuthMechanisms, c.Config.Password) + if err != nil { + return err + } + + // Send client-first-message in a SASLInitialResponse + saslInitialResponse := &pgproto3.SASLInitialResponse{ + AuthMechanism: "SCRAM-SHA-256", + Data: sc.clientFirstMessage(), + } + _, err = c.conn.Write(saslInitialResponse.Encode(nil)) + if err != nil { + return err + } + + // Receive server-first-message payload in a AuthenticationSASLContinue. + authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue) + if err != nil { + return err + } + err = sc.recvServerFirstMessage(authMsg.SASLData) + if err != nil { + return err + } + + // Send client-final-message in a SASLResponse + saslResponse := &pgproto3.SASLResponse{ + Data: []byte(sc.clientFinalMessage()), + } + _, err = c.conn.Write(saslResponse.Encode(nil)) + if err != nil { + return err + } + + // Receive server-final-message payload in a AuthenticationSASLFinal. + authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal) + if err != nil { + return err + } + return sc.recvServerFinalMessage(authMsg.SASLData) +} + +func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { + msg, err := c.ReceiveMessage() + if err != nil { + return nil, err + } + authMsg, ok := msg.(*pgproto3.Authentication) + if !ok { + return nil, errors.New("unexpected message type") + } + if authMsg.Type != typ { + return nil, errors.New("unexpected auth type") + } + + return authMsg, nil +} + +type scramClient struct { + serverAuthMechanisms []string + password []byte + clientNonce []byte + + clientFirstMessageBare []byte + + serverFirstMessage []byte + clientAndServerNonce []byte + salt []byte + iterations int + + saltedPassword []byte + authMessage []byte +} + +func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { + sc := &scramClient{ + serverAuthMechanisms: serverAuthMechanisms, + } + + // Ensure server supports SCRAM-SHA-256 + hasScramSHA256 := false + for _, mech := range sc.serverAuthMechanisms { + if mech == "SCRAM-SHA-256" { + hasScramSHA256 = true + break + } + } + if !hasScramSHA256 { + return nil, errors.New("server does not support SCRAM-SHA-256") + } + + // precis.OpaqueString is equivalent to SASLprep for password. + var err error + sc.password, err = precis.OpaqueString.Bytes([]byte(password)) + if err != nil { + // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. + sc.password = []byte(password) + } + + buf := make([]byte, clientNonceLen) + _, err = rand.Read(buf) + if err != nil { + return nil, err + } + sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf))) + base64.RawStdEncoding.Encode(sc.clientNonce, buf) + + return sc, nil +} + +func (sc *scramClient) clientFirstMessage() []byte { + sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) + return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) +} + +func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { + sc.serverFirstMessage = serverFirstMessage + buf := serverFirstMessage + if !bytes.HasPrefix(buf, []byte("r=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include r=") + } + buf = buf[2:] + + idx := bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + sc.clientAndServerNonce = buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("s=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + buf = buf[2:] + + idx = bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + saltStr := buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("i=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + buf = buf[2:] + iterationsStr := buf + + var err error + sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) + if err != nil { + return fmt.Errorf("invalid SCRAM salt received from server: %v", err) + } + + sc.iterations, err = strconv.Atoi(string(iterationsStr)) + if err != nil || sc.iterations <= 0 { + return fmt.Errorf("invalid SCRAM iteration count received from server: %s", iterationsStr) + } + + if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not start with client nonce") + } + + if len(sc.clientAndServerNonce) <= len(sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not include server nonce") + } + + return nil +} + +func (sc *scramClient) clientFinalMessage() string { + clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) + + sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New) + sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) + + clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) + + return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof) +} + +func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error { + if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) { + return errors.New("invalid SCRAM server-final-message received from server") + } + + serverSignature := serverFinalMessage[2:] + + if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) { + return errors.New("invalid SCRAM ServerSignature received from server") + } + + return nil +} + +func computeHMAC(key, msg []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write(msg) + return mac.Sum(nil) +} + +func computeClientProof(saltedPassword, authMessage []byte) []byte { + clientKey := computeHMAC(saltedPassword, []byte("Client Key")) + storedKey := sha256.Sum256(clientKey) + clientSignature := computeHMAC(storedKey[:], authMessage) + + clientProof := make([]byte, len(clientSignature)) + for i := 0; i < len(clientSignature); i++ { + clientProof[i] = clientKey[i] ^ clientSignature[i] + } + + buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof))) + base64.StdEncoding.Encode(buf, clientProof) + return buf +} + +func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { + serverKey := computeHMAC(saltedPassword, []byte("Server Key")) + serverSignature := computeHMAC(serverKey[:], authMessage) + buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) + base64.StdEncoding.Encode(buf, serverSignature) + return buf +} diff --git a/go.mod b/go.mod index 3dc806a4..09b4471d 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,9 @@ go 1.12 require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3 v1.0.0 + github.com/jackc/pgproto3 v1.1.0 github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.3.0 + golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a + golang.org/x/text v0.3.0 ) diff --git a/go.sum b/go.sum index 5b6f835b..8872aac1 100644 --- a/go.sum +++ b/go.sum @@ -6,10 +6,8 @@ github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v0.0.0-20190330174656-bb06e6b3ff87 h1:xueDi0R+HxuFmuOA1xyFbbF+2LSXqWQJZSPWmmMFB0A= -github.com/jackc/pgproto3 v0.0.0-20190330174656-bb06e6b3ff87/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= -github.com/jackc/pgproto3 v1.0.0 h1:25tUmlES7eyD96oYaUHc1dLOFbgcJtFzCdnOOoqmA1I= -github.com/jackc/pgproto3 v1.0.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -17,3 +15,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/pgconn.go b/pgconn.go index c9891dbf..264d9e8c 100644 --- a/pgconn.go +++ b/pgconn.go @@ -260,6 +260,8 @@ func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { case pgproto3.AuthTypeMD5Password: digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:])) err = c.txPasswordMessage(digestedPassword) + case pgproto3.AuthTypeSASL: + err = c.scramAuth(msg.SASLAuthMechanisms) default: err = errors.New("Received unknown authentication message") } diff --git a/pgconn_test.go b/pgconn_test.go index 66a4337b..fd57face 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -33,12 +33,11 @@ func TestConnect(t *testing.T) { {"TCP", "PGX_TEST_TCP_CONN_STRING"}, {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() - connString := os.Getenv(tt.env) if connString == "" { t.Skipf("Skipping due to missing environment variable %v", tt.env) From 0174907e04e75b23e393d98c31593b886e599e5f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 16 Apr 2019 20:58:10 -0500 Subject: [PATCH 0248/1158] Fix travis unix domain socket test --- travis/before_script.bash | 1 + 1 file changed, 1 insertion(+) diff --git a/travis/before_script.bash b/travis/before_script.bash index bcf748a1..923b7d06 100755 --- a/travis/before_script.bash +++ b/travis/before_script.bash @@ -11,6 +11,7 @@ then psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user travis" psql -U postgres -c "create user pgx_replication with replication password 'secret'" psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" fi From e948dc3246b579f05e9d508be07bf2816ba96b3d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 21:51:58 -0500 Subject: [PATCH 0249/1158] Reuse buffer for writing --- pgconn.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/pgconn.go b/pgconn.go index 264d9e8c..82a010b8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -97,6 +97,8 @@ type PgConn struct { bufferingReceiveMux sync.Mutex bufferingReceiveMsg pgproto3.BackendMessage bufferingReceiveErr error + + wbuf []byte // Reusable write buffer } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -153,6 +155,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.Config = config + pgConn.wbuf = make([]byte, 0, 1024) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -190,7 +193,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig startupMsg.Parameters["database"] = config.Database } - if _, err := pgConn.conn.Write(startupMsg.Encode(nil)); err != nil { + if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { pgConn.conn.Close() return nil, err } @@ -271,7 +274,7 @@ func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} - _, err = pgConn.conn.Write(msg.Encode(nil)) + _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) return err } @@ -513,7 +516,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() - var buf []byte + buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) @@ -676,7 +679,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { } multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) - var buf []byte + buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) _, err := pgConn.conn.Write(buf) @@ -717,7 +720,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return result } - var buf []byte + buf := pgConn.wbuf buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) @@ -744,7 +747,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } - var buf []byte + buf := pgConn.wbuf buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) pgConn.execExtendedSuffix(buf, result) @@ -816,7 +819,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm defer cleanupContextDeadline() // Send copy to command - var buf []byte + buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) _, err := pgConn.conn.Write(buf) @@ -875,7 +878,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co defer cleanupContextDeadline() // Send copy to command - var buf []byte + buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) _, err := pgConn.conn.Write(buf) From bc139fadb5b49cf4159b33c1312cc66ef0582c7e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 22:01:47 -0500 Subject: [PATCH 0250/1158] Reuse one ResultReader per connection --- pgconn.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pgconn.go b/pgconn.go index 82a010b8..4cf4d745 100644 --- a/pgconn.go +++ b/pgconn.go @@ -98,7 +98,9 @@ type PgConn struct { bufferingReceiveMsg pgproto3.BackendMessage bufferingReceiveErr error - wbuf []byte // Reusable write buffer + // Reusable / preallocated resources + wbuf []byte // write buffer + resultReader ResultReader } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -756,11 +758,12 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { - result := &ResultReader{ + pgConn.resultReader = ResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, } + result := &pgConn.resultReader if err := pgConn.lock(); err != nil { result.concludeCommand("", err) @@ -1035,20 +1038,22 @@ func (mrr *MultiResultReader) NextResult() bool { switch msg := msg.(type) { case *pgproto3.RowDescription: - mrr.rr = &ResultReader{ + mrr.pgConn.resultReader = ResultReader{ pgConn: mrr.pgConn, multiResultReader: mrr, ctx: mrr.ctx, cleanupContextDeadline: func() {}, fieldDescriptions: msg.Fields, } + mrr.rr = &mrr.pgConn.resultReader return true case *pgproto3.CommandComplete: - mrr.rr = &ResultReader{ + mrr.pgConn.resultReader = ResultReader{ commandTag: CommandTag(msg.CommandTag), commandConcluded: true, closed: true, } + mrr.rr = &mrr.pgConn.resultReader return true case *pgproto3.EmptyQueryResponse: return false From 2acb7b6d4e0b4478597bc19bacf52ea6033314e0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 22:33:11 -0500 Subject: [PATCH 0251/1158] Reduce mallocs in RowDescription.Decode --- row_description.go | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/row_description.go b/row_description.go index 7f46ede3..dadfc80f 100644 --- a/row_description.go +++ b/row_description.go @@ -30,35 +30,44 @@ type RowDescription struct { func (*RowDescription) Backend() {} func (dst *RowDescription) Decode(src []byte) error { - buf := bytes.NewBuffer(src) - if buf.Len() < 2 { + if len(src) < 2 { return &invalidMessageFormatErr{messageType: "RowDescription"} } - fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + fieldCount := int(binary.BigEndian.Uint16(src)) + rp := 2 dst.Fields = dst.Fields[0:0] for i := 0; i < fieldCount; i++ { var fd FieldDescription - bName, err := buf.ReadBytes(0) - if err != nil { - return err + + idx := bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "RowDescription"} } + bName := string(src[rp : rp+idx]) + rp += idx + 1 fd.Name = string(bName[:len(bName)-1]) // Since buf.Next() doesn't return an error if we hit the end of the buffer // check Len ahead of time - if buf.Len() < 18 { + if len(src[rp:]) < 18 { return &invalidMessageFormatErr{messageType: "RowDescription"} } - fd.TableOID = binary.BigEndian.Uint32(buf.Next(4)) - fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2)) - fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4)) - fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2))) - fd.TypeModifier = int32(binary.BigEndian.Uint32(buf.Next(4))) - fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2))) + fd.TableOID = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + fd.TableAttributeNumber = binary.BigEndian.Uint16(src[rp:]) + rp += 2 + fd.DataTypeOID = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + fd.DataTypeSize = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + fd.TypeModifier = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + fd.Format = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 dst.Fields = append(dst.Fields, fd) } From b6e5b74e2c82dc3305355453ec86dc002bf577b4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 22:50:36 -0500 Subject: [PATCH 0252/1158] Reuse one MultiResultReader per connection Using a PgConn while locked now panics. i.e. You must Close any ResultReader or MultiResultReader. --- pgconn.go | 65 ++++++++++++++++---------------------------------- pgconn_test.go | 6 ++--- 2 files changed, 23 insertions(+), 48 deletions(-) diff --git a/pgconn.go b/pgconn.go index 4cf4d745..7e8909ea 100644 --- a/pgconn.go +++ b/pgconn.go @@ -99,8 +99,9 @@ type PgConn struct { bufferingReceiveErr error // Reusable / preallocated resources - wbuf []byte // write buffer - resultReader ResultReader + wbuf []byte // write buffer + resultReader ResultReader + multiResultReader MultiResultReader } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -411,24 +412,18 @@ func (pgConn *PgConn) IsAlive() bool { return !pgConn.closed } -// lock locks the connection. It returns an error if the connection is already locked or is closed. -func (pgConn *PgConn) lock() error { +// lock locks the connection. It panics if the connection is already locked or is closed. +func (pgConn *PgConn) lock() { if pgConn.locked { - return errors.New("connection busy") - } - - if pgConn.closed { - return errors.New("connection closed") + panic("connection busy") // This only should be possible in case of an application bug. } pgConn.locked = true - - return nil } func (pgConn *PgConn) unlock() { if !pgConn.locked { - panic("BUG: cannot unlock unlocked connection") + panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package. } pgConn.locked = false @@ -505,9 +500,7 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { - if err := pgConn.lock(); err != nil { - return nil, err - } + pgConn.lock() defer pgConn.unlock() select { @@ -626,9 +619,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { - if err := pgConn.lock(); err != nil { - return err - } + pgConn.lock() select { case <-ctx.Done(): @@ -659,17 +650,14 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { // // Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { - multiResult := &MultiResultReader{ + pgConn.lock() + + pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, } - - if err := pgConn.lock(); err != nil { - multiResult.closed = true - multiResult.err = err - return multiResult - } + multiResult := &pgConn.multiResultReader select { case <-ctx.Done(): @@ -758,6 +746,8 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { + pgConn.lock() + pgConn.resultReader = ResultReader{ pgConn: pgConn, ctx: ctx, @@ -765,12 +755,6 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by } result := &pgConn.resultReader - if err := pgConn.lock(); err != nil { - result.concludeCommand("", err) - result.closed = true - return result - } - if len(paramValues) > math.MaxUint16 { result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true @@ -808,9 +792,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { - if err := pgConn.lock(); err != nil { - return "", err - } + pgConn.lock() select { case <-ctx.Done(): @@ -867,9 +849,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { - if err := pgConn.lock(); err != nil { - return "", err - } + pgConn.lock() defer pgConn.unlock() select { @@ -1251,17 +1231,14 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { - multiResult := &MultiResultReader{ + pgConn.lock() + + pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, } - - if err := pgConn.lock(); err != nil { - multiResult.closed = true - multiResult.err = ctx.Err() - return multiResult - } + multiResult := &pgConn.multiResultReader select { case <-ctx.Done(): diff --git a/pgconn_test.go b/pgconn_test.go index fd57face..3be61be8 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -690,11 +690,9 @@ func TestConnLocking(t *testing.T) { defer closeConn(t, pgConn) mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() - assert.Error(t, err) - assert.Equal(t, "connection busy", err.Error()) + require.Panics(t, func() { pgConn.Exec(context.Background(), "select 'Hello, world'") }) - results, err = mrr.ReadAll() + results, err := mrr.ReadAll() assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err) From 9d30dad837720c6b53dbf37fcb413afbd1d94045 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 22:52:07 -0500 Subject: [PATCH 0253/1158] Do not buffer results in benchmarks --- benchmark_test.go | 123 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 115 insertions(+), 8 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 959e86be..000dfd1b 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -1,6 +1,7 @@ package pgconn_test import ( + "bytes" "context" "os" "testing" @@ -41,11 +42,42 @@ func BenchmarkExec(b *testing.B) { require.Nil(b, err) defer closeConn(b, conn) + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll() - require.Nil(b, err) + mrr := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount += 1 + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } + + err := mrr.Close() + if err != nil { + b.Fatal(err) + } } } @@ -54,14 +86,45 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { require.Nil(b, err) defer closeConn(b, conn) + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + b.ResetTimer() ctx, cancel := context.WithCancel(context.Background()) defer cancel() for i := 0; i < b.N; i++ { - _, err := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll() - require.Nil(b, err) + mrr := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount += 1 + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } + + err := mrr.Close() + if err != nil { + b.Fatal(err) + } } } @@ -73,11 +136,33 @@ func BenchmarkExecPrepared(b *testing.B) { _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) require.Nil(b, err) + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + b.ResetTimer() for i := 0; i < b.N; i++ { - result := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil).Read() - require.Nil(b, result.Err) + rr := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount += 1 + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } } } @@ -92,10 +177,32 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { _, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) require.Nil(b, err) + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + b.ResetTimer() for i := 0; i < b.N; i++ { - result := conn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() - require.Nil(b, result.Err) + rr := conn.ExecPrepared(ctx, "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount += 1 + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } } } From 8d43b38287d2efd59774e43b80ad3bbb7ca33b54 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 23:12:00 -0500 Subject: [PATCH 0254/1158] RowDescription.Name is now []byte Avoid allocation --- row_description.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/row_description.go b/row_description.go index dadfc80f..1b1734dc 100644 --- a/row_description.go +++ b/row_description.go @@ -14,7 +14,7 @@ const ( ) type FieldDescription struct { - Name string + Name []byte TableOID uint32 TableAttributeNumber uint16 DataTypeOID uint32 @@ -46,9 +46,8 @@ func (dst *RowDescription) Decode(src []byte) error { if idx < 0 { return &invalidMessageFormatErr{messageType: "RowDescription"} } - bName := string(src[rp : rp+idx]) + fd.Name = src[rp : rp+idx] rp += idx + 1 - fd.Name = string(bName[:len(bName)-1]) // Since buf.Next() doesn't return an error if we hit the end of the buffer // check Len ahead of time From 76e904a5a4549a8f572b8980a86d6b9fc9c783af Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 23:12:18 -0500 Subject: [PATCH 0255/1158] CommandComplete.CommandTag is now []byte Avoid allocation --- command_complete.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/command_complete.go b/command_complete.go index 0012f6f0..adcc56c8 100644 --- a/command_complete.go +++ b/command_complete.go @@ -8,7 +8,7 @@ import ( ) type CommandComplete struct { - CommandTag string + CommandTag []byte } func (*CommandComplete) Backend() {} @@ -19,7 +19,7 @@ func (dst *CommandComplete) Decode(src []byte) error { return &invalidMessageFormatErr{messageType: "CommandComplete"} } - dst.CommandTag = string(src[:idx]) + dst.CommandTag = src[:idx] return nil } @@ -43,6 +43,6 @@ func (src *CommandComplete) MarshalJSON() ([]byte, error) { CommandTag string }{ Type: "CommandComplete", - CommandTag: src.CommandTag, + CommandTag: string(src.CommandTag), }) } From 9b6a681f50bf8aa2372c54523d64346b1a35d46a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 23:15:44 -0500 Subject: [PATCH 0256/1158] Update go.mod version --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 2c2b401e..37cc5114 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/jackc/pgproto3 +module github.com/jackc/pgproto3/v2 go 1.12 From 2383561e4d1bbf50fde6a214aa04f296764e265f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 23:17:28 -0500 Subject: [PATCH 0257/1158] Use 0-alloc pgproto3/v2 --- auth_scram.go | 2 +- go.mod | 1 + go.sum | 2 ++ pgconn.go | 52 +++++++++++++++++++++++++++----------------------- pgconn_test.go | 10 +++++----- 5 files changed, 37 insertions(+), 30 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index b78a236a..50fbff40 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -21,7 +21,7 @@ import ( "fmt" "strconv" - "github.com/jackc/pgproto3" + "github.com/jackc/pgproto3/v2" "golang.org/x/crypto/pbkdf2" "golang.org/x/text/secure/precis" ) diff --git a/go.mod b/go.mod index 09b4471d..232df737 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3 v1.1.0 + github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a diff --git a/go.sum b/go.sum index 8872aac1..8e0e2c9f 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf h1:wI8d/uq9/RfZOe6bKOpC4Skd4VgkTIGZqxmHu6IQGb8= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/pgconn.go b/pgconn.go index 7e8909ea..7bc93435 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1,6 +1,7 @@ package pgconn import ( + "bytes" "context" "crypto/md5" "crypto/tls" @@ -17,7 +18,7 @@ import ( "time" "github.com/jackc/pgio" - "github.com/jackc/pgproto3" + "github.com/jackc/pgproto3/v2" ) var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) @@ -436,20 +437,23 @@ func (pgConn *PgConn) ParameterStatus(key string) string { } // CommandTag is the result of an Exec function -type CommandTag string +type CommandTag []byte // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. func (ct CommandTag) RowsAffected() int64 { - s := string(ct) - index := strings.LastIndex(s, " ") - if index == -1 { + idx := bytes.LastIndexByte([]byte(ct), ' ') + if idx == -1 { return 0 } - n, _ := strconv.ParseInt(s[index+1:], 10, 64) + n, _ := strconv.ParseInt(string([]byte(ct)[idx+1:]), 10, 64) return n } +func (ct CommandTag) String() string { + return string(ct) +} + // preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == // true. Otherwise returns err. func preferContextOverNetTimeoutError(ctx context.Context, err error) error { @@ -756,7 +760,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by result := &pgConn.resultReader if len(paramValues) > math.MaxUint16 { - result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true pgConn.unlock() return result @@ -764,7 +768,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by select { case <-ctx.Done(): - result.concludeCommand("", ctx.Err()) + result.concludeCommand(nil, ctx.Err()) result.closed = true pgConn.unlock() return result @@ -783,7 +787,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - result.concludeCommand("", err) + result.concludeCommand(nil, err) result.cleanupContextDeadline() result.closed = true pgConn.unlock() @@ -797,7 +801,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm select { case <-ctx.Done(): pgConn.unlock() - return "", ctx.Err() + return nil, ctx.Err() default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) @@ -812,7 +816,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm pgConn.hardClose() pgConn.unlock() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } // Read results @@ -822,7 +826,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -831,7 +835,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm _, err := w.Write(msg.Data) if err != nil { pgConn.hardClose() - return "", err + return nil, err } case *pgproto3.ReadyForQuery: pgConn.unlock() @@ -854,7 +858,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co select { case <-ctx.Done(): - return "", ctx.Err() + return nil, ctx.Err() default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) @@ -867,7 +871,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } // Read until copy in response or error. @@ -878,7 +882,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -908,7 +912,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } } @@ -917,7 +921,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -939,7 +943,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } // Read results @@ -947,7 +951,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1145,7 +1149,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for !rr.commandConcluded { _, err := rr.receiveMessage() if err != nil { - return "", rr.err + return nil, rr.err } } @@ -1153,7 +1157,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for { msg, err := rr.receiveMessage() if err != nil { - return "", rr.err + return nil, rr.err } switch msg.(type) { @@ -1176,7 +1180,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } if err != nil { - rr.concludeCommand("", err) + rr.concludeCommand(nil, err) rr.cleanupContextDeadline() rr.closed = true if rr.multiResultReader == nil { @@ -1192,7 +1196,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.CommandComplete: rr.concludeCommand(CommandTag(msg.CommandTag), nil) case *pgproto3.ErrorResponse: - rr.concludeCommand("", errorResponseToPgError(msg)) + rr.concludeCommand(nil, errorResponseToPgError(msg)) } return msg, nil diff --git a/pgconn_test.go b/pgconn_test.go index 3be61be8..2b1e68a3 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -475,7 +475,7 @@ func TestConnExecParamsCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Equal(t, pgconn.CommandTag(""), commandTag) + assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, context.DeadlineExceeded, err) assert.False(t, pgConn.IsAlive()) @@ -601,7 +601,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Equal(t, pgconn.CommandTag(""), commandTag) + assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, context.DeadlineExceeded, err) assert.False(t, pgConn.IsAlive()) } @@ -958,7 +958,7 @@ func TestConnCopyToCanceled(t *testing.T) { defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") assert.Equal(t, context.DeadlineExceeded, err) - assert.Equal(t, pgconn.CommandTag(""), res) + assert.Equal(t, pgconn.CommandTag(nil), res) assert.False(t, pgConn.IsAlive()) } @@ -977,7 +977,7 @@ func TestConnCopyToPrecanceled(t *testing.T) { res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") require.Error(t, err) require.Equal(t, context.Canceled, err) - assert.Equal(t, pgconn.CommandTag(""), res) + assert.Equal(t, pgconn.CommandTag(nil), res) ensureConnValid(t, pgConn) } @@ -1084,7 +1084,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") require.Error(t, err) require.Equal(t, context.Canceled, err) - assert.Equal(t, pgconn.CommandTag(""), ct) + assert.Equal(t, pgconn.CommandTag(nil), ct) ensureConnValid(t, pgConn) } From 16412e56e22d0ae96c5c8bf95b512562c71cbc80 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 Apr 2019 14:24:51 -0500 Subject: [PATCH 0258/1158] 0 alloc context to deadline --- chan_to_set_deadline.go | 51 ++++++++++++++++ pgconn.go | 127 ++++++++++++++-------------------------- 2 files changed, 95 insertions(+), 83 deletions(-) create mode 100644 chan_to_set_deadline.go diff --git a/chan_to_set_deadline.go b/chan_to_set_deadline.go new file mode 100644 index 00000000..04bb8fde --- /dev/null +++ b/chan_to_set_deadline.go @@ -0,0 +1,51 @@ +package pgconn + +import ( + "time" +) + +var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) + +type setDeadliner interface { + SetDeadline(time.Time) error +} + +type chanToSetDeadline struct { + cleanupChan chan struct{} + conn setDeadliner + deadlineWasSet bool + cleanupComplete bool +} + +func (this *chanToSetDeadline) start(doneChan <-chan struct{}, conn setDeadliner) { + if this.cleanupChan == nil { + this.cleanupChan = make(chan struct{}) + } + this.conn = conn + this.deadlineWasSet = false + this.cleanupComplete = false + + if doneChan != nil { + go func() { + select { + case <-doneChan: + conn.SetDeadline(deadlineTime) + this.deadlineWasSet = true + <-this.cleanupChan + case <-this.cleanupChan: + } + }() + } else { + this.cleanupComplete = true + } +} + +func (this *chanToSetDeadline) cleanup() { + if !this.cleanupComplete { + this.cleanupChan <- struct{}{} + if this.deadlineWasSet { + this.conn.SetDeadline(time.Time{}) + } + this.cleanupComplete = true + } +} diff --git a/pgconn.go b/pgconn.go index 7bc93435..6ff0d39f 100644 --- a/pgconn.go +++ b/pgconn.go @@ -15,14 +15,11 @@ import ( "strconv" "strings" "sync" - "time" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" ) -var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) - // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. @@ -100,9 +97,10 @@ type PgConn struct { bufferingReceiveErr error // Reusable / preallocated resources - wbuf []byte // write buffer - resultReader ResultReader - multiResultReader MultiResultReader + wbuf []byte // write buffer + resultReader ResultReader + multiResultReader MultiResultReader + doneChanToDeadline chanToSetDeadline } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -382,8 +380,8 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContext() + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { @@ -463,38 +461,6 @@ func preferContextOverNetTimeoutError(ctx context.Context, err error) error { return err } -// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from -// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to -// call multiple times. -func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) { - if ctx.Done() != nil { - deadlineWasSet := false - doneChan := make(chan struct{}) - go func() { - select { - case <-ctx.Done(): - conn.SetDeadline(deadlineTime) - deadlineWasSet = true - <-doneChan - case <-doneChan: - } - }() - - finished := false - return func() { - if !finished { - doneChan <- struct{}{} - if deadlineWasSet { - conn.SetDeadline(time.Time{}) - } - finished = true - } - } - } - - return func() {} -} - type PreparedStatementDescription struct { Name string SQL string @@ -512,8 +478,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ return nil, ctx.Err() default: } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) @@ -599,8 +565,9 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { } defer cancelConn.Close() - cleanupContext := contextDoneToConnDeadline(ctx, cancelConn) - defer cleanupContext() + var doneChanToDeadline chanToSetDeadline + doneChanToDeadline.start(ctx.Done(), cancelConn) + defer doneChanToDeadline.cleanup() buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) @@ -624,16 +591,16 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { pgConn.lock() + defer pgConn.unlock() select { case <-ctx.Done(): - pgConn.unlock() return ctx.Err() default: } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() - defer pgConn.unlock() + + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() for { msg, err := pgConn.ReceiveMessage() @@ -657,9 +624,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.lock() pgConn.multiResultReader = MultiResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, + pgConn: pgConn, + ctx: ctx, } multiResult := &pgConn.multiResultReader @@ -671,7 +637,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { return multiResult default: } - multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) @@ -679,7 +645,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - multiResult.cleanupContextDeadline() + pgConn.doneChanToDeadline.cleanup() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) pgConn.unlock() @@ -753,9 +719,8 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by pgConn.lock() pgConn.resultReader = ResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, + pgConn: pgConn, + ctx: ctx, } result := &pgConn.resultReader @@ -774,7 +739,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result default: } - result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) return result } @@ -788,7 +753,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { if err != nil { pgConn.hardClose() result.concludeCommand(nil, err) - result.cleanupContextDeadline() + pgConn.doneChanToDeadline.cleanup() result.closed = true pgConn.unlock() } @@ -804,8 +769,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, ctx.Err() default: } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() // Send copy to command buf := pgConn.wbuf @@ -861,8 +826,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co return nil, ctx.Err() default: } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() // Send copy to command buf := pgConn.wbuf @@ -967,9 +932,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { - pgConn *PgConn - ctx context.Context - cleanupContextDeadline func() + pgConn *PgConn + ctx context.Context rr *ResultReader @@ -993,7 +957,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) msg, err := mrr.pgConn.ReceiveMessage() if err != nil { - mrr.cleanupContextDeadline() + mrr.pgConn.doneChanToDeadline.cleanup() mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.closed = true mrr.pgConn.hardClose() @@ -1002,7 +966,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - mrr.cleanupContextDeadline() + mrr.pgConn.doneChanToDeadline.cleanup() mrr.closed = true mrr.pgConn.unlock() case *pgproto3.ErrorResponse: @@ -1023,11 +987,10 @@ func (mrr *MultiResultReader) NextResult() bool { switch msg := msg.(type) { case *pgproto3.RowDescription: mrr.pgConn.resultReader = ResultReader{ - pgConn: mrr.pgConn, - multiResultReader: mrr, - ctx: mrr.ctx, - cleanupContextDeadline: func() {}, - fieldDescriptions: msg.Fields, + pgConn: mrr.pgConn, + multiResultReader: mrr, + ctx: mrr.ctx, + fieldDescriptions: msg.Fields, } mrr.rr = &mrr.pgConn.resultReader return true @@ -1066,10 +1029,9 @@ func (mrr *MultiResultReader) Close() error { // ResultReader is a reader for the result of a single query. type ResultReader struct { - pgConn *PgConn - multiResultReader *MultiResultReader - ctx context.Context - cleanupContextDeadline func() + pgConn *PgConn + multiResultReader *MultiResultReader + ctx context.Context fieldDescriptions []pgproto3.FieldDescription rowValues [][]byte @@ -1162,7 +1124,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { switch msg.(type) { case *pgproto3.ReadyForQuery: - rr.cleanupContextDeadline() + rr.pgConn.doneChanToDeadline.cleanup() rr.pgConn.unlock() return rr.commandTag, rr.err } @@ -1181,7 +1143,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error if err != nil { rr.concludeCommand(nil, err) - rr.cleanupContextDeadline() + rr.pgConn.doneChanToDeadline.cleanup() rr.closed = true if rr.multiResultReader == nil { rr.pgConn.hardClose() @@ -1238,9 +1200,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR pgConn.lock() pgConn.multiResultReader = MultiResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, + pgConn: pgConn, + ctx: ctx, } multiResult := &pgConn.multiResultReader @@ -1252,13 +1213,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR return multiResult default: } - multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) _, err := pgConn.conn.Write(batch.buf) if err != nil { pgConn.hardClose() - multiResult.cleanupContextDeadline() + pgConn.doneChanToDeadline.cleanup() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) pgConn.unlock() From 7bb6c2f3e9826f233e799c894439f87ac93e007f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 Apr 2019 15:52:12 -0500 Subject: [PATCH 0259/1158] Unify locked and closed into status No longer panic on locking busy conn --- pgconn.go | 81 ++++++++++++++++++++++++++++++++++++-------------- pgconn_test.go | 6 ++-- 2 files changed, 63 insertions(+), 24 deletions(-) diff --git a/pgconn.go b/pgconn.go index 6ff0d39f..7a9a42e4 100644 --- a/pgconn.go +++ b/pgconn.go @@ -20,6 +20,13 @@ import ( "github.com/jackc/pgproto3/v2" ) +const ( + connStatusUninitialized = iota + connStatusClosed + connStatusIdle + connStatusBusy +) + // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. @@ -88,8 +95,7 @@ type PgConn struct { Config *Config - locked bool - closed bool + status byte // One of connStatus* constants bufferingReceive bool bufferingReceiveMux sync.Mutex @@ -217,6 +223,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, err } case *pgproto3.ReadyForQuery: + pgConn.status = connStatusIdle if config.AfterConnectFunc != nil { err := config.AfterConnectFunc(ctx, pgConn) if err != nil { @@ -373,10 +380,10 @@ func (pgConn *PgConn) SecretKey() uint32 { // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. func (pgConn *PgConn) Close(ctx context.Context) error { - if pgConn.closed { + if pgConn.status == connStatusClosed { return nil } - pgConn.closed = true + pgConn.status = connStatusClosed defer pgConn.conn.Close() @@ -398,34 +405,41 @@ func (pgConn *PgConn) Close(ctx context.Context) error { // hardClose closes the underlying connection without sending the exit message. func (pgConn *PgConn) hardClose() error { - if pgConn.closed { + if pgConn.status == connStatusClosed { return nil } - pgConn.closed = true + pgConn.status = connStatusClosed return pgConn.conn.Close() } // TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of // underlying connection. func (pgConn *PgConn) IsAlive() bool { - return !pgConn.closed + return pgConn.status >= connStatusIdle } // lock locks the connection. It panics if the connection is already locked or is closed. -func (pgConn *PgConn) lock() { - if pgConn.locked { - panic("connection busy") // This only should be possible in case of an application bug. +func (pgConn *PgConn) lock() error { + switch pgConn.status { + case connStatusBusy: + return errors.New("connection busy") // This only should be possible in case of an application bug. + case connStatusClosed: + return errors.New("conn closed") + case connStatusUninitialized: + return errors.New("conn uninitialized") } - - pgConn.locked = true + pgConn.status = connStatusBusy + return nil } func (pgConn *PgConn) unlock() { - if !pgConn.locked { + switch pgConn.status { + case connStatusBusy: + pgConn.status = connStatusIdle + case connStatusClosed: + default: panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package. } - - pgConn.locked = false } // ParameterStatus returns the value of a parameter reported by the server (e.g. @@ -470,7 +484,9 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return nil, err + } defer pgConn.unlock() select { @@ -590,7 +606,9 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return err + } defer pgConn.unlock() select { @@ -621,7 +639,12 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { // // Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, @@ -716,7 +739,12 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return &ResultReader{ + closed: true, + err: err, + } + } pgConn.resultReader = ResultReader{ pgConn: pgConn, @@ -761,7 +789,9 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return nil, err + } select { case <-ctx.Done(): @@ -818,7 +848,9 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return nil, err + } defer pgConn.unlock() select { @@ -1197,7 +1229,12 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, diff --git a/pgconn_test.go b/pgconn_test.go index 2b1e68a3..2ad02830 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -690,9 +690,11 @@ func TestConnLocking(t *testing.T) { defer closeConn(t, pgConn) mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - require.Panics(t, func() { pgConn.Exec(context.Background(), "select 'Hello, world'") }) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "connection busy", err.Error()) - results, err := mrr.ReadAll() + results, err = mrr.ReadAll() assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err) From 3710e52a9a125c406f4a6f682ca9a67e695c38f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 Apr 2019 16:16:55 -0500 Subject: [PATCH 0260/1158] Add named error for conn busy --- pgconn.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 7a9a42e4..7d437434 100644 --- a/pgconn.go +++ b/pgconn.go @@ -84,6 +84,10 @@ type NotificationHandler func(*PgConn, *Notification) // PostgreSQL server refuses to use TLS var ErrTLSRefused = errors.New("server refused TLS connection") +// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another +// action is attempted. +var ErrConnBusy = errors.New("conn is busy") + // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection @@ -422,7 +426,7 @@ func (pgConn *PgConn) IsAlive() bool { func (pgConn *PgConn) lock() error { switch pgConn.status { case connStatusBusy: - return errors.New("connection busy") // This only should be possible in case of an application bug. + return ErrConnBusy // This only should be possible in case of an application bug. case connStatusClosed: return errors.New("conn closed") case connStatusUninitialized: From 9f774761bacc37fb32d6a8718e3aa9ccd9035de2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 10:59:50 -0500 Subject: [PATCH 0261/1158] Fix TestConnLocking --- pgconn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index 2ad02830..d31e8cc9 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -692,7 +692,7 @@ func TestConnLocking(t *testing.T) { mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) - assert.Equal(t, "connection busy", err.Error()) + assert.Equal(t, pgconn.ErrConnBusy, err) results, err = mrr.ReadAll() assert.NoError(t, err) From 39e6ff5766bde4b27a085b021b11b5b3be18a276 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 11:11:09 -0500 Subject: [PATCH 0262/1158] Prevent deadlock with huge batches --- pgconn.go | 22 +++++++++++--------- pgconn_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 9 deletions(-) diff --git a/pgconn.go b/pgconn.go index 7d437434..4f3cdd66 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1257,15 +1257,19 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) - _, err := pgConn.conn.Write(batch.buf) - if err != nil { - pgConn.hardClose() - pgConn.doneChanToDeadline.cleanup() - multiResult.closed = true - multiResult.err = preferContextOverNetTimeoutError(ctx, err) - pgConn.unlock() - return multiResult - } + + // A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is + // closed. This is all that can be done without introducing a race condition or adding a concurrent safe communication + // channel to relay the error back. The practical effect of this is that the underlying Write error is not reported. + // The error the code reading the batch results receives will be a closed connection error. + // + // See https://github.com/jackc/pgx/issues/374. + go func() { + _, err := pgConn.conn.Write(batch.buf) + if err != nil { + pgConn.conn.Close() + } + }() return multiResult } diff --git a/pgconn_test.go b/pgconn_test.go index d31e8cc9..25cc3ee3 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -682,6 +682,60 @@ func TestConnExecBatchPrecanceled(t *testing.T) { ensureConnValid(t, pgConn) } +// Without concurrent reading and writing large batches can deadlock. +// +// See https://github.com/jackc/pgx/issues/374. +func TestConnExecBatchHuge(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + batch := &pgconn.Batch{} + + queryCount := 100000 + args := make([]string, queryCount) + + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } + + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) + + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + } +} + +func TestConnExecBatchImplicitTransaction(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.Error(t, err) + + result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) +} + func TestConnLocking(t *testing.T) { t.Parallel() From 6161728ff9ce457b0dd20c782698d6faf5e7833d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 11:47:16 -0500 Subject: [PATCH 0263/1158] Prepare takes context Also remove PrepareEx. It's primary usage was for context. Supplying parameter OIDs is unnecessary when you can type cast in the query SQL. If it does become necessary or desirable to add options back it can be added in a backwards compatible way by adding a varargs as last argument. --- hstore_array_test.go | 2 +- record_test.go | 2 +- testutil/testutil.go | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/hstore_array_test.go b/hstore_array_test.go index 03dc2ff1..849b5835 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -70,7 +70,7 @@ func TestHstoreArrayTranscode(t *testing.T) { Status: pgtype.Present, } - ps, err := conn.Prepare("test", "select $1::hstore[]") + ps, err := conn.Prepare(context.Background(), "test", "select $1::hstore[]") if err != nil { t.Fatal(err) } diff --git a/record_test.go b/record_test.go index 44b0e9d8..a4fc1e5d 100644 --- a/record_test.go +++ b/record_test.go @@ -85,7 +85,7 @@ func TestRecordTranscode(t *testing.T) { for i, tt := range tests { psName := fmt.Sprintf("test%d", i) - ps, err := conn.Prepare(psName, tt.sql) + ps, err := conn.Prepare(context.Background(), psName, tt.sql) if err != nil { t.Fatal(err) } diff --git a/testutil/testutil.go b/testutil/testutil.go index 3711381c..0d653394 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -107,7 +107,7 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] conn := MustConnectPgx(t) defer MustCloseContext(t, conn) - _, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) if err != nil { t.Fatal(err) } @@ -225,7 +225,7 @@ func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFun for i, tt := range tests { for _, fc := range formats { psName := fmt.Sprintf("test%d", i) - ps, err := conn.Prepare(psName, tt.SQL) + ps, err := conn.Prepare(context.Background(), psName, tt.SQL) if err != nil { t.Fatal(err) } From cd629965e6c1920f124691c4004507467fe2069c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 12:57:52 -0500 Subject: [PATCH 0264/1158] Use golang.org/x/xerrors --- auth_scram.go | 6 +++--- config.go | 14 +++++++------- go.mod | 3 ++- go.sum | 2 ++ pgconn.go | 9 ++++----- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index 50fbff40..5baa680b 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -17,13 +17,13 @@ import ( "crypto/rand" "crypto/sha256" "encoding/base64" - "errors" "fmt" "strconv" "github.com/jackc/pgproto3/v2" "golang.org/x/crypto/pbkdf2" "golang.org/x/text/secure/precis" + errors "golang.org/x/xerrors" ) const clientNonceLen = 18 @@ -181,12 +181,12 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { var err error sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) if err != nil { - return fmt.Errorf("invalid SCRAM salt received from server: %v", err) + return errors.Errorf("invalid SCRAM salt received from server: %w", err) } sc.iterations, err = strconv.Atoi(string(iterationsStr)) if err != nil || sc.iterations <= 0 { - return fmt.Errorf("invalid SCRAM iteration count received from server: %s", iterationsStr) + return errors.Errorf("invalid SCRAM iteration count received from server: %w", err) } if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { diff --git a/config.go b/config.go index d392924c..c751cc0d 100644 --- a/config.go +++ b/config.go @@ -18,7 +18,7 @@ import ( "time" "github.com/jackc/pgpassfile" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error @@ -195,7 +195,7 @@ func ParseConfig(connString string) (*Config, error) { port, err := parsePort(portStr) if err != nil { - return nil, fmt.Errorf("invalid port: %v", settings["port"]) + return nil, errors.Errorf("invalid port: %w", err) } var tlsConfigs []*tls.Config @@ -240,7 +240,7 @@ func ParseConfig(connString string) (*Config, error) { if settings["target_session_attrs"] == "read-write" { config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite } else if settings["target_session_attrs"] != "any" { - return nil, fmt.Errorf("unknown target_session_attrs value %v", settings["target_session_attrs"]) + return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) } return config, nil @@ -409,11 +409,11 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { caPath := sslrootcert caCert, err := ioutil.ReadFile(caPath) if err != nil { - return nil, errors.Wrapf(err, "unable to read CA file %q", caPath) + return nil, errors.Errorf("unable to read CA file: %w", err) } if !caCertPool.AppendCertsFromPEM(caCert) { - return nil, errors.Wrap(err, "unable to add CA to cert pool") + return nil, errors.Errorf("unable to add CA to cert pool: %w", err) } tlsConfig.RootCAs = caCertPool @@ -421,13 +421,13 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { } if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { - return nil, fmt.Errorf(`both "sslcert" and "sslkey" are required`) + return nil, errors.New(`both "sslcert" and "sslkey" are required`) } if sslcert != "" && sslkey != "" { cert, err := tls.LoadX509KeyPair(sslcert, sslkey) if err != nil { - return nil, errors.Wrap(err, "unable to read cert") + return nil, errors.Errorf("unable to read cert: %w", err) } tlsConfig.Certificates = []tls.Certificate{cert} diff --git a/go.mod b/go.mod index 232df737..dda76fe1 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,11 @@ go 1.12 require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3 v1.1.0 + github.com/jackc/pgproto3 v1.1.0 // indirect github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 + golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 ) diff --git a/go.sum b/go.sum index 8e0e2c9f..5a100ff0 100644 --- a/go.sum +++ b/go.sum @@ -22,3 +22,5 @@ golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaE golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pgconn.go b/pgconn.go index 4f3cdd66..14377beb 100644 --- a/pgconn.go +++ b/pgconn.go @@ -7,8 +7,6 @@ import ( "crypto/tls" "encoding/binary" "encoding/hex" - "errors" - "fmt" "io" "math" "net" @@ -18,6 +16,7 @@ import ( "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" + errors "golang.org/x/xerrors" ) const ( @@ -232,7 +231,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err := config.AfterConnectFunc(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, fmt.Errorf("AfterConnectFunc: %v", err) + return nil, errors.Errorf("AfterConnectFunc: %v", err) } } return pgConn, nil @@ -601,7 +600,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { _, err = cancelConn.Read(buf) if err != io.EOF { - return fmt.Errorf("Server failed to close connection after cancel query request: %v", preferContextOverNetTimeoutError(ctx, err)) + return errors.Errorf("Server failed to close connection after cancel query request: %w", preferContextOverNetTimeoutError(ctx, err)) } return nil @@ -757,7 +756,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by result := &pgConn.resultReader if len(paramValues) > math.MaxUint16 { - result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.concludeCommand(nil, errors.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true pgConn.unlock() return result From c116219b62db0c6cd4f646fd66a94b529020fc6d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 13:01:11 -0500 Subject: [PATCH 0265/1158] Update tests to use v2 --- backend_test.go | 2 +- frontend_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend_test.go b/backend_test.go index 6cba81b6..cd8788b8 100644 --- a/backend_test.go +++ b/backend_test.go @@ -3,7 +3,7 @@ package pgproto3_test import ( "testing" - "github.com/jackc/pgproto3" + "github.com/jackc/pgproto3/v2" ) func TestBackendReceiveInterrupted(t *testing.T) { diff --git a/frontend_test.go b/frontend_test.go index d3e57f81..2d5c8de7 100644 --- a/frontend_test.go +++ b/frontend_test.go @@ -5,7 +5,7 @@ import ( "github.com/pkg/errors" - "github.com/jackc/pgproto3" + "github.com/jackc/pgproto3/v2" ) type interruptReader struct { From 7a520059d9115a271068a920b7583a911bd3a509 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 13:01:59 -0500 Subject: [PATCH 0266/1158] Update to remove pgprotov3 ref --- go.mod | 3 +-- go.sum | 6 ++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index dda76fe1..acbee593 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,7 @@ go 1.12 require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3 v1.1.0 // indirect - github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf + github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a diff --git a/go.sum b/go.sum index 5a100ff0..9160f187 100644 --- a/go.sum +++ b/go.sum @@ -6,10 +6,8 @@ github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= -github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf h1:wI8d/uq9/RfZOe6bKOpC4Skd4VgkTIGZqxmHu6IQGb8= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= From f3b5f6b2753fb81b66507c5f42a55af75241d01c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 15:34:49 -0500 Subject: [PATCH 0267/1158] Allow skipping TestConnExecBatchHuge in short mode --- pgconn_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index 25cc3ee3..3fc15e7a 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -686,6 +686,10 @@ func TestConnExecBatchPrecanceled(t *testing.T) { // // See https://github.com/jackc/pgx/issues/374. func TestConnExecBatchHuge(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) From 0f8e1d30e2dc1a4f359761d5418126bb0e0685d5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 15:53:30 -0500 Subject: [PATCH 0268/1158] Link context errors and underlying conn errors Using golang.org/x/xerrors type errors both errors can be exposed. --- errors.go | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++ pgconn.go | 76 ++++++++++---------------------------------- pgconn_test.go | 8 ++--- 3 files changed, 105 insertions(+), 64 deletions(-) create mode 100644 errors.go diff --git a/errors.go b/errors.go new file mode 100644 index 00000000..e42dae16 --- /dev/null +++ b/errors.go @@ -0,0 +1,85 @@ +package pgconn + +import ( + "context" + "net" + + errors "golang.org/x/xerrors" +) + +// ErrTLSRefused occurs when the connection attempt requires TLS and the +// PostgreSQL server refuses to use TLS +var ErrTLSRefused = errors.New("server refused TLS connection") + +// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another +// action is attempted. +var ErrConnBusy = errors.New("conn is busy") + +// PgError represents an error reported by the PostgreSQL server. See +// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for +// detailed field description. +type PgError struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string +} + +func (pe *PgError) Error() string { + return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" +} + +// linkedError connects two errors as if err wrapped next. +type linkedError struct { + err error + next error +} + +func (le *linkedError) Error() string { + return le.err.Error() +} + +func (le *linkedError) Is(target error) bool { + return errors.Is(le.err, target) +} + +func (le *linkedError) As(target interface{}) bool { + return errors.As(le.err, target) +} + +func (le *linkedError) Unwrap() error { + return le.next +} + +// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == +// true. Otherwise returns err. +func preferContextOverNetTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { + return ctx.Err() + } + return err +} + +// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned. +func linkErrors(outer, inner error) error { + if outer == nil { + return inner + } + if inner == nil { + return outer + } + return &linkedError{err: outer, next: inner} +} diff --git a/pgconn.go b/pgconn.go index 14377beb..2911211c 100644 --- a/pgconn.go +++ b/pgconn.go @@ -26,33 +26,6 @@ const ( connStatusBusy ) -// PgError represents an error reported by the PostgreSQL server. See -// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for -// detailed field description. -type PgError struct { - Severity string - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string -} - -func (pe *PgError) Error() string { - return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" -} - // Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from // LISTEN/NOTIFY notification. type Notice PgError @@ -79,14 +52,6 @@ type NoticeHandler func(*PgConn, *Notice) // notice event. type NotificationHandler func(*PgConn, *Notification) -// ErrTLSRefused occurs when the connection attempt requires TLS and the -// PostgreSQL server refuses to use TLS -var ErrTLSRefused = errors.New("server refused TLS connection") - -// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another -// action is attempted. -var ErrConnBusy = errors.New("conn is busy") - // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection @@ -395,12 +360,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error { _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return linkErrors(ctx.Err(), err) } _, err = pgConn.conn.Read(make([]byte, 1)) if err != io.EOF { - return preferContextOverNetTimeoutError(ctx, err) + return linkErrors(ctx.Err(), err) } return pgConn.conn.Close() @@ -469,15 +434,6 @@ func (ct CommandTag) String() string { return string(ct) } -// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == -// true. Otherwise returns err. -func preferContextOverNetTimeoutError(ctx context.Context, err error) error { - if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { - return ctx.Err() - } - return err -} - type PreparedStatementDescription struct { Name string SQL string @@ -508,7 +464,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } psd := &PreparedStatementDescription{Name: name, SQL: sql} @@ -520,7 +476,7 @@ readloop: msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { @@ -595,12 +551,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) _, err = cancelConn.Write(buf) if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return linkErrors(ctx.Err(), err) } _, err = cancelConn.Read(buf) if err != io.EOF { - return errors.Errorf("Server failed to close connection after cancel query request: %w", preferContextOverNetTimeoutError(ctx, err)) + return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err)) } return nil @@ -626,7 +582,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { for { msg, err := pgConn.ReceiveMessage() if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return linkErrors(ctx.Err(), err) } switch msg.(type) { @@ -673,7 +629,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.hardClose() pgConn.doneChanToDeadline.cleanup() multiResult.closed = true - multiResult.err = preferContextOverNetTimeoutError(ctx, err) + multiResult.err = linkErrors(ctx.Err(), err) pgConn.unlock() return multiResult } @@ -814,7 +770,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm pgConn.hardClose() pgConn.unlock() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } // Read results @@ -824,7 +780,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { @@ -871,7 +827,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } // Read until copy in response or error. @@ -882,7 +838,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { @@ -912,7 +868,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } } @@ -921,7 +877,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { @@ -943,7 +899,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } // Read results @@ -951,7 +907,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { diff --git a/pgconn_test.go b/pgconn_test.go index 3fc15e7a..30e6a425 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -18,7 +18,7 @@ import ( "time" "github.com/jackc/pgconn" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -907,7 +907,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) err = pgConn.WaitForNotification(ctx) cancel() - require.Equal(t, context.DeadlineExceeded, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) ensureConnValid(t, pgConn) } @@ -1017,7 +1017,7 @@ func TestConnCopyToCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.Equal(t, pgconn.CommandTag(nil), res) assert.False(t, pgConn.IsAlive()) @@ -1108,7 +1108,7 @@ func TestConnCopyFromCanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") cancel() assert.Equal(t, int64(0), ct.RowsAffected()) - require.Equal(t, context.DeadlineExceeded, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.False(t, pgConn.IsAlive()) } From 7e0022ef6ba389ca1b8140e50e42624af1df312e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 16:48:24 -0500 Subject: [PATCH 0269/1158] Tag errors if no bytes sent to server --- errors.go | 4 ++++ pgconn.go | 58 +++++++++++++++++++++++++++++++------------------- pgconn_test.go | 30 ++++++++++++++++---------- 3 files changed, 59 insertions(+), 33 deletions(-) diff --git a/errors.go b/errors.go index e42dae16..4f8af407 100644 --- a/errors.go +++ b/errors.go @@ -15,6 +15,10 @@ var ErrTLSRefused = errors.New("server refused TLS connection") // action is attempted. var ErrConnBusy = errors.New("conn is busy") +// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used +// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error. +var ErrNoBytesSent = errors.New("no bytes sent to server") + // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. diff --git a/pgconn.go b/pgconn.go index 2911211c..a4402a7d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -444,13 +444,13 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { if err := pgConn.lock(); err != nil { - return nil, err + return nil, linkErrors(err, ErrNoBytesSent) } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) @@ -461,9 +461,12 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } return nil, linkErrors(ctx.Err(), err) } @@ -601,7 +604,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: err, + err: linkErrors(err, ErrNoBytesSent), } } @@ -614,7 +617,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = ctx.Err() + multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) pgConn.unlock() return multiResult default: @@ -624,11 +627,14 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() pgConn.doneChanToDeadline.cleanup() multiResult.closed = true + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } multiResult.err = linkErrors(ctx.Err(), err) pgConn.unlock() return multiResult @@ -666,7 +672,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(buf, result) + pgConn.execExtendedSuffix(ctx, buf, result) return result } @@ -692,7 +698,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf := pgConn.wbuf buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(buf, result) + pgConn.execExtendedSuffix(ctx, buf, result) return result } @@ -701,7 +707,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by if err := pgConn.lock(); err != nil { return &ResultReader{ closed: true, - err: err, + err: linkErrors(err, ErrNoBytesSent), } } @@ -720,7 +726,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by select { case <-ctx.Done(): - result.concludeCommand(nil, ctx.Err()) + result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent)) result.closed = true pgConn.unlock() return result @@ -731,15 +737,18 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } -func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { +func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) { buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - result.concludeCommand(nil, err) + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } + result.concludeCommand(nil, linkErrors(ctx.Err(), err)) pgConn.doneChanToDeadline.cleanup() result.closed = true pgConn.unlock() @@ -749,13 +758,13 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, err + return nil, linkErrors(err, ErrNoBytesSent) } select { case <-ctx.Done(): pgConn.unlock() - return nil, ctx.Err() + return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) @@ -765,11 +774,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() pgConn.unlock() - + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } return nil, linkErrors(ctx.Err(), err) } @@ -808,13 +819,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, err + return nil, linkErrors(err, ErrNoBytesSent) } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) @@ -824,9 +835,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } return nil, linkErrors(ctx.Err(), err) } @@ -1191,7 +1205,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: err, + err: linkErrors(err, ErrNoBytesSent), } } @@ -1204,7 +1218,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = ctx.Err() + multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) pgConn.unlock() return multiResult default: diff --git a/pgconn_test.go b/pgconn_test.go index 30e6a425..b7cb4036 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -264,9 +264,10 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil) - require.Nil(t, psd) - require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.Nil(t, psd) + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -386,8 +387,9 @@ func TestConnExecContextPrecanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() - require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -492,7 +494,8 @@ func TestConnExecParamsPrecanceled(t *testing.T) { cancel() result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() require.Error(t, result.Err) - require.Equal(t, context.Canceled, result.Err) + assert.True(t, errors.Is(result.Err, context.Canceled)) + assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -620,7 +623,8 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { cancel() result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() require.Error(t, result.Err) - require.Equal(t, context.Canceled, result.Err) + assert.True(t, errors.Is(result.Err, context.Canceled)) + assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -677,7 +681,8 @@ func TestConnExecBatchPrecanceled(t *testing.T) { cancel() _, err = pgConn.ExecBatch(ctx, batch).ReadAll() require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -750,7 +755,8 @@ func TestConnLocking(t *testing.T) { mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) - assert.Equal(t, pgconn.ErrConnBusy, err) + assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) results, err = mrr.ReadAll() assert.NoError(t, err) @@ -1036,7 +1042,8 @@ func TestConnCopyToPrecanceled(t *testing.T) { cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.Equal(t, pgconn.CommandTag(nil), res) ensureConnValid(t, pgConn) @@ -1143,7 +1150,8 @@ func TestConnCopyFromPrecanceled(t *testing.T) { cancel() ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.Equal(t, pgconn.CommandTag(nil), ct) ensureConnValid(t, pgConn) From 8502a12ac7723377772c6f70dd3b72491fd9d31e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 17:41:08 -0500 Subject: [PATCH 0270/1158] Fix go modules Wow. This is fun. Sure is easy to get modules wrong when upgrading a v2+ project. --- aclitem_array_test.go | 4 ++-- aclitem_test.go | 4 ++-- array_test.go | 2 +- bit_test.go | 4 ++-- bool_array_test.go | 4 ++-- bool_test.go | 4 ++-- box_test.go | 4 ++-- bpchar_array_test.go | 4 ++-- bpchar_test.go | 4 ++-- bytea_array_test.go | 4 ++-- bytea_test.go | 4 ++-- cid_test.go | 4 ++-- cidr_array_test.go | 4 ++-- circle_test.go | 4 ++-- date_array_test.go | 4 ++-- date_test.go | 4 ++-- daterange_test.go | 4 ++-- enum_array_test.go | 4 ++-- ext/satori-uuid/uuid.go | 2 +- ext/satori-uuid/uuid_test.go | 6 +++--- ext/shopspring-numeric/decimal.go | 2 +- ext/shopspring-numeric/decimal_test.go | 6 +++--- float4_array_test.go | 4 ++-- float4_test.go | 4 ++-- float8_array_test.go | 4 ++-- float8_test.go | 4 ++-- hstore_array_test.go | 6 +++--- hstore_test.go | 4 ++-- inet_array_test.go | 4 ++-- inet_test.go | 4 ++-- int2_array_test.go | 4 ++-- int2_test.go | 4 ++-- int4_array_test.go | 4 ++-- int4_test.go | 4 ++-- int4range_test.go | 4 ++-- int8_array_test.go | 4 ++-- int8_test.go | 4 ++-- int8range_test.go | 4 ++-- interval_test.go | 4 ++-- json_test.go | 4 ++-- jsonb_test.go | 4 ++-- line_test.go | 4 ++-- lseg_test.go | 4 ++-- macaddr_array_test.go | 4 ++-- macaddr_test.go | 4 ++-- name_test.go | 4 ++-- numeric_array_test.go | 4 ++-- numeric_test.go | 4 ++-- numrange_test.go | 4 ++-- oid_value_test.go | 4 ++-- path_test.go | 4 ++-- pgtype_test.go | 2 +- point_test.go | 4 ++-- polygon_test.go | 4 ++-- qchar_test.go | 4 ++-- record_test.go | 6 +++--- testutil/testutil.go | 6 +++--- text_array_test.go | 4 ++-- text_test.go | 4 ++-- tid_test.go | 4 ++-- timestamp_array_test.go | 4 ++-- timestamp_test.go | 4 ++-- timestamptz_array_test.go | 4 ++-- timestamptz_test.go | 4 ++-- tsrange_test.go | 4 ++-- tstzrange_test.go | 4 ++-- uuid_array_test.go | 4 ++-- uuid_test.go | 4 ++-- varbit_test.go | 4 ++-- varchar_array_test.go | 4 ++-- xid_test.go | 4 ++-- 71 files changed, 143 insertions(+), 143 deletions(-) diff --git a/aclitem_array_test.go b/aclitem_array_test.go index 4e60afca..5f16ab28 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestACLItemArrayTranscode(t *testing.T) { diff --git a/aclitem_test.go b/aclitem_test.go index 65399a30..92dfc7a5 100644 --- a/aclitem_test.go +++ b/aclitem_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestACLItemTranscode(t *testing.T) { diff --git a/array_test.go b/array_test.go index d1cdb4c5..d17d753c 100644 --- a/array_test.go +++ b/array_test.go @@ -4,7 +4,7 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/v4/pgtype" ) func TestParseUntypedTextArray(t *testing.T) { diff --git a/bit_test.go b/bit_test.go index 19492bc9..05729323 100644 --- a/bit_test.go +++ b/bit_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestBitTranscode(t *testing.T) { diff --git a/bool_array_test.go b/bool_array_test.go index b529555e..6d2d7c06 100644 --- a/bool_array_test.go +++ b/bool_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestBoolArrayTranscode(t *testing.T) { diff --git a/bool_test.go b/bool_test.go index 04d9337d..5228e280 100644 --- a/bool_test.go +++ b/bool_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestBoolTranscode(t *testing.T) { diff --git a/box_test.go b/box_test.go index 197401f3..aad10262 100644 --- a/box_test.go +++ b/box_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestBoxTranscode(t *testing.T) { diff --git a/bpchar_array_test.go b/bpchar_array_test.go index e4f2e7eb..820dfa5b 100644 --- a/bpchar_array_test.go +++ b/bpchar_array_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestBPCharArrayTranscode(t *testing.T) { diff --git a/bpchar_test.go b/bpchar_test.go index c076ca1b..e8981e52 100644 --- a/bpchar_test.go +++ b/bpchar_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestChar3Transcode(t *testing.T) { diff --git a/bytea_array_test.go b/bytea_array_test.go index 8450b71b..00dc0a1f 100644 --- a/bytea_array_test.go +++ b/bytea_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestByteaArrayTranscode(t *testing.T) { diff --git a/bytea_test.go b/bytea_test.go index fd5a0dec..75b55de4 100644 --- a/bytea_test.go +++ b/bytea_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestByteaTranscode(t *testing.T) { diff --git a/cid_test.go b/cid_test.go index 924e4cf3..588e6c66 100644 --- a/cid_test.go +++ b/cid_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestCIDTranscode(t *testing.T) { diff --git a/cidr_array_test.go b/cidr_array_test.go index 206a590f..71125bdb 100644 --- a/cidr_array_test.go +++ b/cidr_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestCIDRArrayTranscode(t *testing.T) { diff --git a/circle_test.go b/circle_test.go index 634c5832..82598620 100644 --- a/circle_test.go +++ b/circle_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestCircleTranscode(t *testing.T) { diff --git a/date_array_test.go b/date_array_test.go index 2ba19d1a..24a8282c 100644 --- a/date_array_test.go +++ b/date_array_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestDateArrayTranscode(t *testing.T) { diff --git a/date_test.go b/date_test.go index d98e1652..ac7aadfe 100644 --- a/date_test.go +++ b/date_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestDateTranscode(t *testing.T) { diff --git a/daterange_test.go b/daterange_test.go index d2af5986..4d3119ee 100644 --- a/daterange_test.go +++ b/daterange_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestDaterangeTranscode(t *testing.T) { diff --git a/enum_array_test.go b/enum_array_test.go index 052a813c..dbe09751 100644 --- a/enum_array_test.go +++ b/enum_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestEnumArrayTranscode(t *testing.T) { diff --git a/ext/satori-uuid/uuid.go b/ext/satori-uuid/uuid.go index baebc5ed..8713b4d6 100644 --- a/ext/satori-uuid/uuid.go +++ b/ext/satori-uuid/uuid.go @@ -5,7 +5,7 @@ import ( "github.com/pkg/errors" - "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/v4/pgtype" uuid "github.com/satori/go.uuid" ) diff --git a/ext/satori-uuid/uuid_test.go b/ext/satori-uuid/uuid_test.go index 02ebb770..7a770b84 100644 --- a/ext/satori-uuid/uuid_test.go +++ b/ext/satori-uuid/uuid_test.go @@ -4,9 +4,9 @@ import ( "bytes" "testing" - "github.com/jackc/pgx/pgtype" - satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + satori "github.com/jackc/pgx/v4/pgtype/ext/satori-uuid" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestUUIDTranscode(t *testing.T) { diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index 7c1cd770..0b63999b 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -6,7 +6,7 @@ import ( "github.com/pkg/errors" - "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/v4/pgtype" "github.com/shopspring/decimal" ) diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index b237478d..2af39e1d 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -7,9 +7,9 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - shopspring "github.com/jackc/pgx/pgtype/ext/shopspring-numeric" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + shopspring "github.com/jackc/pgx/v4/pgtype/ext/shopspring-numeric" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/shopspring/decimal" ) diff --git a/float4_array_test.go b/float4_array_test.go index 4d6511b4..24d544b6 100644 --- a/float4_array_test.go +++ b/float4_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestFloat4ArrayTranscode(t *testing.T) { diff --git a/float4_test.go b/float4_test.go index 2ed8d05d..4779b357 100644 --- a/float4_test.go +++ b/float4_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestFloat4Transcode(t *testing.T) { diff --git a/float8_array_test.go b/float8_array_test.go index ff8e3b26..b3e7a197 100644 --- a/float8_array_test.go +++ b/float8_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestFloat8ArrayTranscode(t *testing.T) { diff --git a/float8_test.go b/float8_test.go index 46fc8d5d..15092916 100644 --- a/float8_test.go +++ b/float8_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestFloat8Transcode(t *testing.T) { diff --git a/hstore_array_test.go b/hstore_array_test.go index 849b5835..bc45cbdf 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -5,9 +5,9 @@ import ( "reflect" "testing" - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestHstoreArrayTranscode(t *testing.T) { diff --git a/hstore_test.go b/hstore_test.go index d76c9942..71fd2355 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestHstoreTranscode(t *testing.T) { diff --git a/inet_array_test.go b/inet_array_test.go index ca528ed3..4e93d0f5 100644 --- a/inet_array_test.go +++ b/inet_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInetArrayTranscode(t *testing.T) { diff --git a/inet_test.go b/inet_test.go index 32d66999..ee93873b 100644 --- a/inet_test.go +++ b/inet_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInetTranscode(t *testing.T) { diff --git a/int2_array_test.go b/int2_array_test.go index 0fe763c1..fb4f0d60 100644 --- a/int2_array_test.go +++ b/int2_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt2ArrayTranscode(t *testing.T) { diff --git a/int2_test.go b/int2_test.go index d20bf0ed..ff4732f7 100644 --- a/int2_test.go +++ b/int2_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt2Transcode(t *testing.T) { diff --git a/int4_array_test.go b/int4_array_test.go index f0418600..06772cf6 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt4ArrayTranscode(t *testing.T) { diff --git a/int4_test.go b/int4_test.go index 02f5409f..6b23c5a9 100644 --- a/int4_test.go +++ b/int4_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt4Transcode(t *testing.T) { diff --git a/int4range_test.go b/int4range_test.go index 961678bb..95d448f0 100644 --- a/int4range_test.go +++ b/int4range_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt4rangeTranscode(t *testing.T) { diff --git a/int8_array_test.go b/int8_array_test.go index 2ca65173..c2d914ab 100644 --- a/int8_array_test.go +++ b/int8_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt8ArrayTranscode(t *testing.T) { diff --git a/int8_test.go b/int8_test.go index 0b3bb3eb..a5f80f42 100644 --- a/int8_test.go +++ b/int8_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt8Transcode(t *testing.T) { diff --git a/int8range_test.go b/int8range_test.go index f33ae4d8..01af48bb 100644 --- a/int8range_test.go +++ b/int8range_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt8rangeTranscode(t *testing.T) { diff --git a/interval_test.go b/interval_test.go index 76ea3240..7cafb0ae 100644 --- a/interval_test.go +++ b/interval_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestIntervalTranscode(t *testing.T) { diff --git a/json_test.go b/json_test.go index 38494841..bb0f1b20 100644 --- a/json_test.go +++ b/json_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestJSONTranscode(t *testing.T) { diff --git a/jsonb_test.go b/jsonb_test.go index afc51019..73656c76 100644 --- a/jsonb_test.go +++ b/jsonb_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestJSONBTranscode(t *testing.T) { diff --git a/line_test.go b/line_test.go index 077afe6b..5f0a58a3 100644 --- a/line_test.go +++ b/line_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestLineTranscode(t *testing.T) { diff --git a/lseg_test.go b/lseg_test.go index 0a25090a..100bdf0f 100644 --- a/lseg_test.go +++ b/lseg_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestLsegTranscode(t *testing.T) { diff --git a/macaddr_array_test.go b/macaddr_array_test.go index d4bb2f01..cf07ebf6 100644 --- a/macaddr_array_test.go +++ b/macaddr_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestMacaddrArrayTranscode(t *testing.T) { diff --git a/macaddr_test.go b/macaddr_test.go index 5d329249..a08671c0 100644 --- a/macaddr_test.go +++ b/macaddr_test.go @@ -6,8 +6,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestMacaddrTranscode(t *testing.T) { diff --git a/name_test.go b/name_test.go index ec0820c4..75d7b95a 100644 --- a/name_test.go +++ b/name_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestNameTranscode(t *testing.T) { diff --git a/numeric_array_test.go b/numeric_array_test.go index 22ee1bc4..b17a6461 100644 --- a/numeric_array_test.go +++ b/numeric_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestNumericArrayTranscode(t *testing.T) { diff --git a/numeric_test.go b/numeric_test.go index 9d7d83d6..b723cc56 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -6,8 +6,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) // For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) diff --git a/numrange_test.go b/numrange_test.go index ccc794d5..610447fe 100644 --- a/numrange_test.go +++ b/numrange_test.go @@ -4,8 +4,8 @@ import ( "math/big" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestNumrangeTranscode(t *testing.T) { diff --git a/oid_value_test.go b/oid_value_test.go index f5ff16cf..462a5a28 100644 --- a/oid_value_test.go +++ b/oid_value_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestOIDValueTranscode(t *testing.T) { diff --git a/path_test.go b/path_test.go index bc2d7435..16e781f5 100644 --- a/path_test.go +++ b/path_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestPathTranscode(t *testing.T) { diff --git a/pgtype_test.go b/pgtype_test.go index f7e743b2..400c0591 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -4,7 +4,7 @@ import ( "net" "testing" - _ "github.com/jackc/pgx/stdlib" + _ "github.com/jackc/pgx/v4/stdlib" _ "github.com/lib/pq" ) diff --git a/point_test.go b/point_test.go index af70b38b..017bfc03 100644 --- a/point_test.go +++ b/point_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestPointTranscode(t *testing.T) { diff --git a/polygon_test.go b/polygon_test.go index 5ff3bbb3..3bafebfc 100644 --- a/polygon_test.go +++ b/polygon_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestPolygonTranscode(t *testing.T) { diff --git a/qchar_test.go b/qchar_test.go index 057a557f..3b50bb3e 100644 --- a/qchar_test.go +++ b/qchar_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestQCharTranscode(t *testing.T) { diff --git a/record_test.go b/record_test.go index a4fc1e5d..5de8af31 100644 --- a/record_test.go +++ b/record_test.go @@ -6,9 +6,9 @@ import ( "reflect" "testing" - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestRecordTranscode(t *testing.T) { diff --git a/testutil/testutil.go b/testutil/testutil.go index 0d653394..121eb754 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -8,9 +8,9 @@ import ( "reflect" "testing" - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" - _ "github.com/jackc/pgx/stdlib" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" + _ "github.com/jackc/pgx/v4/stdlib" _ "github.com/lib/pq" ) diff --git a/text_array_test.go b/text_array_test.go index 105d9353..b03312d9 100644 --- a/text_array_test.go +++ b/text_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTextArrayTranscode(t *testing.T) { diff --git a/text_test.go b/text_test.go index bd971807..53f4bd7e 100644 --- a/text_test.go +++ b/text_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTextTranscode(t *testing.T) { diff --git a/tid_test.go b/tid_test.go index 9185cb31..cd753ab4 100644 --- a/tid_test.go +++ b/tid_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTIDTranscode(t *testing.T) { diff --git a/timestamp_array_test.go b/timestamp_array_test.go index 5821f43a..002d1ca4 100644 --- a/timestamp_array_test.go +++ b/timestamp_array_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTimestampArrayTranscode(t *testing.T) { diff --git a/timestamp_test.go b/timestamp_test.go index 267f1a7e..732f3cc2 100644 --- a/timestamp_test.go +++ b/timestamp_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTimestampTranscode(t *testing.T) { diff --git a/timestamptz_array_test.go b/timestamptz_array_test.go index 8d7ea4c9..ac9975f0 100644 --- a/timestamptz_array_test.go +++ b/timestamptz_array_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTimestamptzArrayTranscode(t *testing.T) { diff --git a/timestamptz_test.go b/timestamptz_test.go index c326802d..f522117b 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTimestamptzTranscode(t *testing.T) { diff --git a/tsrange_test.go b/tsrange_test.go index 78eb1cd3..6215e318 100644 --- a/tsrange_test.go +++ b/tsrange_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTsrangeTranscode(t *testing.T) { diff --git a/tstzrange_test.go b/tstzrange_test.go index a27ddd3a..ddaf798b 100644 --- a/tstzrange_test.go +++ b/tstzrange_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTstzrangeTranscode(t *testing.T) { diff --git a/uuid_array_test.go b/uuid_array_test.go index ee9d3dfa..6ec6acfb 100644 --- a/uuid_array_test.go +++ b/uuid_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestUUIDArrayTranscode(t *testing.T) { diff --git a/uuid_test.go b/uuid_test.go index 162d999f..9d95c10c 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -4,8 +4,8 @@ import ( "bytes" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestUUIDTranscode(t *testing.T) { diff --git a/varbit_test.go b/varbit_test.go index 6c813aae..8ea282eb 100644 --- a/varbit_test.go +++ b/varbit_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestVarbitTranscode(t *testing.T) { diff --git a/varchar_array_test.go b/varchar_array_test.go index 9fb0960f..b836664f 100644 --- a/varchar_array_test.go +++ b/varchar_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestVarcharArrayTranscode(t *testing.T) { diff --git a/xid_test.go b/xid_test.go index 594d1214..34801e1f 100644 --- a/xid_test.go +++ b/xid_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestXIDTranscode(t *testing.T) { From f25878662dcdd36e4d5ba1dbff179178d5e4f494 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 17:43:44 -0500 Subject: [PATCH 0271/1158] Use golang.org/x/xerrors --- aclitem.go | 2 +- aclitem_array.go | 2 +- array.go | 2 +- bool.go | 2 +- bool_array.go | 2 +- box.go | 2 +- bpchar_array.go | 2 +- bytea.go | 2 +- bytea_array.go | 2 +- cidr_array.go | 2 +- circle.go | 2 +- convert.go | 2 +- database_sql.go | 2 +- date.go | 2 +- date_array.go | 2 +- daterange.go | 2 +- enum_array.go | 2 +- ext/satori-uuid/uuid.go | 2 +- ext/shopspring-numeric/decimal.go | 2 +- float4.go | 2 +- float4_array.go | 2 +- float8.go | 2 +- float8_array.go | 2 +- hstore.go | 2 +- hstore_array.go | 2 +- inet.go | 2 +- inet_array.go | 2 +- int2.go | 2 +- int2_array.go | 2 +- int4.go | 2 +- int4_array.go | 2 +- int4range.go | 2 +- int8.go | 2 +- int8_array.go | 2 +- int8range.go | 2 +- interval.go | 2 +- json.go | 2 +- jsonb.go | 2 +- line.go | 2 +- lseg.go | 2 +- macaddr.go | 2 +- macaddr_array.go | 2 +- numeric.go | 2 +- numeric_array.go | 2 +- numrange.go | 2 +- oid.go | 2 +- path.go | 2 +- pgtype.go | 2 +- pguint32.go | 2 +- point.go | 2 +- polygon.go | 2 +- qchar.go | 2 +- range.go | 2 +- record.go | 2 +- text.go | 2 +- text_array.go | 2 +- tid.go | 2 +- timestamp.go | 2 +- timestamp_array.go | 2 +- timestamptz.go | 2 +- timestamptz_array.go | 2 +- tsrange.go | 2 +- tstzrange.go | 2 +- uuid.go | 2 +- uuid_array.go | 2 +- varbit.go | 2 +- varchar_array.go | 2 +- 67 files changed, 67 insertions(+), 67 deletions(-) diff --git a/aclitem.go b/aclitem.go index a54955eb..c801eb83 100644 --- a/aclitem.go +++ b/aclitem.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) // ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem diff --git a/aclitem_array.go b/aclitem_array.go index 2671022b..c8421153 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type ACLItemArray struct { diff --git a/array.go b/array.go index 9ce0f003..69456782 100644 --- a/array.go +++ b/array.go @@ -9,7 +9,7 @@ import ( "unicode" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) // Information on the internals of PostgreSQL arrays can be found in diff --git a/bool.go b/bool.go index 22774970..f622061b 100644 --- a/bool.go +++ b/bool.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "strconv" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Bool struct { diff --git a/bool_array.go b/bool_array.go index 1aefcd27..3dde8dc0 100644 --- a/bool_array.go +++ b/bool_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type BoolArray struct { diff --git a/box.go b/box.go index 4c825c56..ce5300e5 100644 --- a/box.go +++ b/box.go @@ -9,7 +9,7 @@ import ( "strings" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Box struct { diff --git a/bpchar_array.go b/bpchar_array.go index dd4a8363..547b4e80 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type BPCharArray struct { diff --git a/bytea.go b/bytea.go index 064f199a..e6c28dc7 100644 --- a/bytea.go +++ b/bytea.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/hex" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Bytea struct { diff --git a/bytea_array.go b/bytea_array.go index fc07d103..369d6e08 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type ByteaArray struct { diff --git a/cidr_array.go b/cidr_array.go index 62b0ca65..94c07679 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -6,7 +6,7 @@ import ( "net" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type CIDRArray struct { diff --git a/circle.go b/circle.go index a3bb56f1..66dec132 100644 --- a/circle.go +++ b/circle.go @@ -9,7 +9,7 @@ import ( "strings" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Circle struct { diff --git a/convert.go b/convert.go index 5dfb738e..98999d45 100644 --- a/convert.go +++ b/convert.go @@ -5,7 +5,7 @@ import ( "reflect" "time" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) const maxUint = ^uint(0) diff --git a/database_sql.go b/database_sql.go index 969536dd..f54a750d 100644 --- a/database_sql.go +++ b/database_sql.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { diff --git a/date.go b/date.go index 3f8d188a..08ba8c08 100644 --- a/date.go +++ b/date.go @@ -6,7 +6,7 @@ import ( "time" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Date struct { diff --git a/date_array.go b/date_array.go index 6d6c0899..05070360 100644 --- a/date_array.go +++ b/date_array.go @@ -6,7 +6,7 @@ import ( "time" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type DateArray struct { diff --git a/daterange.go b/daterange.go index d10d34c0..40997bd9 100644 --- a/daterange.go +++ b/daterange.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Daterange struct { diff --git a/enum_array.go b/enum_array.go index 5de2badf..504d513c 100644 --- a/enum_array.go +++ b/enum_array.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type EnumArray struct { diff --git a/ext/satori-uuid/uuid.go b/ext/satori-uuid/uuid.go index 8713b4d6..2aebfc47 100644 --- a/ext/satori-uuid/uuid.go +++ b/ext/satori-uuid/uuid.go @@ -3,7 +3,7 @@ package uuid import ( "database/sql/driver" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" "github.com/jackc/pgx/v4/pgtype" uuid "github.com/satori/go.uuid" diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index 0b63999b..54612db9 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "strconv" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" "github.com/jackc/pgx/v4/pgtype" "github.com/shopspring/decimal" diff --git a/float4.go b/float4.go index c4feb0a7..0947f36a 100644 --- a/float4.go +++ b/float4.go @@ -7,7 +7,7 @@ import ( "strconv" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Float4 struct { diff --git a/float4_array.go b/float4_array.go index b14161e8..ef134407 100644 --- a/float4_array.go +++ b/float4_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Float4Array struct { diff --git a/float8.go b/float8.go index 63944d45..87cf6adb 100644 --- a/float8.go +++ b/float8.go @@ -7,7 +7,7 @@ import ( "strconv" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Float8 struct { diff --git a/float8_array.go b/float8_array.go index 60e87236..ba63449c 100644 --- a/float8_array.go +++ b/float8_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Float8Array struct { diff --git a/hstore.go b/hstore.go index 8a84fe2a..522813ff 100644 --- a/hstore.go +++ b/hstore.go @@ -8,7 +8,7 @@ import ( "unicode" "unicode/utf8" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" "github.com/jackc/pgio" ) diff --git a/hstore_array.go b/hstore_array.go index 19d07686..1bdac816 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type HstoreArray struct { diff --git a/inet.go b/inet.go index dfdd8868..0fb1c418 100644 --- a/inet.go +++ b/inet.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "net" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) // Network address family is dependent on server socket.h value for AF_INET. diff --git a/inet_array.go b/inet_array.go index 51ad7988..b31d3588 100644 --- a/inet_array.go +++ b/inet_array.go @@ -6,7 +6,7 @@ import ( "net" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type InetArray struct { diff --git a/int2.go b/int2.go index 72110684..bbf2952f 100644 --- a/int2.go +++ b/int2.go @@ -7,7 +7,7 @@ import ( "strconv" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Int2 struct { diff --git a/int2_array.go b/int2_array.go index e3b9f64b..afb39513 100644 --- a/int2_array.go +++ b/int2_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Int2Array struct { diff --git a/int4.go b/int4.go index 9ad878c4..cc34ce0a 100644 --- a/int4.go +++ b/int4.go @@ -8,7 +8,7 @@ import ( "strconv" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Int4 struct { diff --git a/int4_array.go b/int4_array.go index ad75c4b5..bd0babb9 100644 --- a/int4_array.go +++ b/int4_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Int4Array struct { diff --git a/int4range.go b/int4range.go index 67bbfcd2..03970ae6 100644 --- a/int4range.go +++ b/int4range.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Int4range struct { diff --git a/int8.go b/int8.go index 39b8a0a8..153f1f7d 100644 --- a/int8.go +++ b/int8.go @@ -8,7 +8,7 @@ import ( "strconv" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Int8 struct { diff --git a/int8_array.go b/int8_array.go index ae8d8e0f..392fd47e 100644 --- a/int8_array.go +++ b/int8_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Int8Array struct { diff --git a/int8range.go b/int8range.go index 25839a7b..0e0f1cdb 100644 --- a/int8range.go +++ b/int8range.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Int8range struct { diff --git a/interval.go b/interval.go index 9172e14a..a7edca83 100644 --- a/interval.go +++ b/interval.go @@ -9,7 +9,7 @@ import ( "time" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) const ( diff --git a/json.go b/json.go index 377a1546..49ff7a6c 100644 --- a/json.go +++ b/json.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/json" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type JSON struct { diff --git a/jsonb.go b/jsonb.go index c315c588..065e4e21 100644 --- a/jsonb.go +++ b/jsonb.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type JSONB JSON diff --git a/line.go b/line.go index 6ac4ac2a..617ee456 100644 --- a/line.go +++ b/line.go @@ -9,7 +9,7 @@ import ( "strings" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Line struct { diff --git a/lseg.go b/lseg.go index c0e77799..b8d6e322 100644 --- a/lseg.go +++ b/lseg.go @@ -9,7 +9,7 @@ import ( "strings" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Lseg struct { diff --git a/macaddr.go b/macaddr.go index 6854400b..25ffc48e 100644 --- a/macaddr.go +++ b/macaddr.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "net" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Macaddr struct { diff --git a/macaddr_array.go b/macaddr_array.go index 2d0439e9..0b791104 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -6,7 +6,7 @@ import ( "net" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type MacaddrArray struct { diff --git a/numeric.go b/numeric.go index 887ad1f8..bbd7667a 100644 --- a/numeric.go +++ b/numeric.go @@ -9,7 +9,7 @@ import ( "strings" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) // PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 diff --git a/numeric_array.go b/numeric_array.go index ec892cc8..1e8c5cda 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type NumericArray struct { diff --git a/numrange.go b/numrange.go index ff9d5372..f3e25109 100644 --- a/numrange.go +++ b/numrange.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Numrange struct { diff --git a/oid.go b/oid.go index 2afc60f8..593a5261 100644 --- a/oid.go +++ b/oid.go @@ -6,7 +6,7 @@ import ( "strconv" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) // OID (Object Identifier Type) is, according to diff --git a/path.go b/path.go index c1b72322..a4c6af77 100644 --- a/path.go +++ b/path.go @@ -9,7 +9,7 @@ import ( "strings" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Path struct { diff --git a/pgtype.go b/pgtype.go index 4faf23e1..cea4e1cd 100644 --- a/pgtype.go +++ b/pgtype.go @@ -4,7 +4,7 @@ import ( "database/sql" "reflect" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) // PostgreSQL oids for common types diff --git a/pguint32.go b/pguint32.go index 37178b5c..21da9664 100644 --- a/pguint32.go +++ b/pguint32.go @@ -7,7 +7,7 @@ import ( "strconv" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) // pguint32 is the core type that is used to implement PostgreSQL types such as diff --git a/point.go b/point.go index fefe5d1f..89f2359b 100644 --- a/point.go +++ b/point.go @@ -9,7 +9,7 @@ import ( "strings" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Vec2 struct { diff --git a/polygon.go b/polygon.go index 904e86e1..e739c71b 100644 --- a/polygon.go +++ b/polygon.go @@ -9,7 +9,7 @@ import ( "strings" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Polygon struct { diff --git a/qchar.go b/qchar.go index 064dab1e..5e77dc38 100644 --- a/qchar.go +++ b/qchar.go @@ -4,7 +4,7 @@ import ( "math" "strconv" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) // QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C diff --git a/range.go b/range.go index 54fc6ca0..35b80ced 100644 --- a/range.go +++ b/range.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/binary" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type BoundType byte diff --git a/record.go b/record.go index 315deda5..60733016 100644 --- a/record.go +++ b/record.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "reflect" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) // Record is the generic PostgreSQL record type such as is created with the diff --git a/text.go b/text.go index 648bbd58..4d4e6bb4 100644 --- a/text.go +++ b/text.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/json" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Text struct { diff --git a/text_array.go b/text_array.go index 1556fec8..b590972e 100644 --- a/text_array.go +++ b/text_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type TextArray struct { diff --git a/tid.go b/tid.go index e859865b..ff788b84 100644 --- a/tid.go +++ b/tid.go @@ -8,7 +8,7 @@ import ( "strings" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) // TID is PostgreSQL's Tuple Identifier type. diff --git a/timestamp.go b/timestamp.go index 93383e35..40dfdac8 100644 --- a/timestamp.go +++ b/timestamp.go @@ -6,7 +6,7 @@ import ( "time" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) const pgTimestampFormat = "2006-01-02 15:04:05.999999999" diff --git a/timestamp_array.go b/timestamp_array.go index 1fd1eefe..95f76639 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -6,7 +6,7 @@ import ( "time" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type TimestampArray struct { diff --git a/timestamptz.go b/timestamptz.go index c2c91c29..752c1818 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -6,7 +6,7 @@ import ( "time" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" diff --git a/timestamptz_array.go b/timestamptz_array.go index b87238ae..7fe60d50 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -6,7 +6,7 @@ import ( "time" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type TimestamptzArray struct { diff --git a/tsrange.go b/tsrange.go index d771a761..54cc863f 100644 --- a/tsrange.go +++ b/tsrange.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Tsrange struct { diff --git a/tstzrange.go b/tstzrange.go index 9a8c782e..1cf2859d 100644 --- a/tstzrange.go +++ b/tstzrange.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Tstzrange struct { diff --git a/uuid.go b/uuid.go index 5e1eead5..d3e68f5c 100644 --- a/uuid.go +++ b/uuid.go @@ -5,7 +5,7 @@ import ( "encoding/hex" "fmt" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type UUID struct { diff --git a/uuid_array.go b/uuid_array.go index fac838af..1d28ee59 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type UUIDArray struct { diff --git a/varbit.go b/varbit.go index 2c25b1fb..fe4db33d 100644 --- a/varbit.go +++ b/varbit.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type Varbit struct { diff --git a/varchar_array.go b/varchar_array.go index d2359d03..6aa92337 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "github.com/jackc/pgio" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type VarcharArray struct { From 4ed0de4755e042908ec6b12c68c0f900b7078726 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 19:14:08 -0500 Subject: [PATCH 0272/1158] Splitting pgtype into own repo --- aclitem_array_test.go | 4 +-- aclitem_test.go | 4 +-- array_test.go | 2 +- bit_test.go | 4 +-- bool_array_test.go | 4 +-- bool_test.go | 4 +-- box_test.go | 4 +-- bpchar_array_test.go | 4 +-- bpchar_test.go | 4 +-- bytea_array_test.go | 4 +-- bytea_test.go | 4 +-- cid_test.go | 4 +-- cidr_array_test.go | 4 +-- circle_test.go | 4 +-- date_array_test.go | 4 +-- date_test.go | 4 +-- daterange_test.go | 4 +-- enum_array_test.go | 4 +-- ext/satori-uuid/uuid.go | 2 +- ext/satori-uuid/uuid_test.go | 6 ++-- ext/shopspring-numeric/decimal.go | 2 +- ext/shopspring-numeric/decimal_test.go | 6 ++-- float4_array_test.go | 4 +-- float4_test.go | 4 +-- float8_array_test.go | 4 +-- float8_test.go | 4 +-- go.mod | 11 ++++++ go.sum | 49 ++++++++++++++++++++++++++ hstore_array_test.go | 4 +-- hstore_test.go | 4 +-- inet_array_test.go | 4 +-- inet_test.go | 4 +-- int2_array_test.go | 4 +-- int2_test.go | 4 +-- int4_array_test.go | 4 +-- int4_test.go | 4 +-- int4range_test.go | 4 +-- int8_array_test.go | 4 +-- int8_test.go | 4 +-- int8range_test.go | 4 +-- interval_test.go | 4 +-- json_test.go | 4 +-- jsonb_test.go | 4 +-- line_test.go | 4 +-- lseg_test.go | 4 +-- macaddr_array_test.go | 4 +-- macaddr_test.go | 4 +-- name_test.go | 4 +-- numeric_array_test.go | 4 +-- numeric_test.go | 4 +-- numrange_test.go | 4 +-- oid_value_test.go | 4 +-- path_test.go | 4 +-- point_test.go | 4 +-- polygon_test.go | 4 +-- qchar_test.go | 4 +-- record_test.go | 4 +-- testutil/testutil.go | 2 +- text_array_test.go | 4 +-- text_test.go | 4 +-- tid_test.go | 4 +-- timestamp_array_test.go | 4 +-- timestamp_test.go | 4 +-- timestamptz_array_test.go | 4 +-- timestamptz_test.go | 4 +-- tsrange_test.go | 4 +-- tstzrange_test.go | 4 +-- uuid_array_test.go | 4 +-- uuid_test.go | 4 +-- varbit_test.go | 4 +-- varchar_array_test.go | 4 +-- xid_test.go | 4 +-- 72 files changed, 198 insertions(+), 138 deletions(-) create mode 100644 go.mod create mode 100644 go.sum diff --git a/aclitem_array_test.go b/aclitem_array_test.go index 5f16ab28..dafd13b0 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestACLItemArrayTranscode(t *testing.T) { diff --git a/aclitem_test.go b/aclitem_test.go index 92dfc7a5..480c457c 100644 --- a/aclitem_test.go +++ b/aclitem_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestACLItemTranscode(t *testing.T) { diff --git a/array_test.go b/array_test.go index d17d753c..486171b8 100644 --- a/array_test.go +++ b/array_test.go @@ -4,7 +4,7 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgtype" ) func TestParseUntypedTextArray(t *testing.T) { diff --git a/bit_test.go b/bit_test.go index 05729323..2e9c9b6e 100644 --- a/bit_test.go +++ b/bit_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestBitTranscode(t *testing.T) { diff --git a/bool_array_test.go b/bool_array_test.go index 6d2d7c06..bef94622 100644 --- a/bool_array_test.go +++ b/bool_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestBoolArrayTranscode(t *testing.T) { diff --git a/bool_test.go b/bool_test.go index 5228e280..64b4064d 100644 --- a/bool_test.go +++ b/bool_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestBoolTranscode(t *testing.T) { diff --git a/box_test.go b/box_test.go index aad10262..643c74ec 100644 --- a/box_test.go +++ b/box_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestBoxTranscode(t *testing.T) { diff --git a/bpchar_array_test.go b/bpchar_array_test.go index 820dfa5b..af6bf09a 100644 --- a/bpchar_array_test.go +++ b/bpchar_array_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestBPCharArrayTranscode(t *testing.T) { diff --git a/bpchar_test.go b/bpchar_test.go index e8981e52..7b8c1da3 100644 --- a/bpchar_test.go +++ b/bpchar_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestChar3Transcode(t *testing.T) { diff --git a/bytea_array_test.go b/bytea_array_test.go index 00dc0a1f..a4eb2d91 100644 --- a/bytea_array_test.go +++ b/bytea_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestByteaArrayTranscode(t *testing.T) { diff --git a/bytea_test.go b/bytea_test.go index 75b55de4..c8c49ff7 100644 --- a/bytea_test.go +++ b/bytea_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestByteaTranscode(t *testing.T) { diff --git a/cid_test.go b/cid_test.go index 588e6c66..50e50cd8 100644 --- a/cid_test.go +++ b/cid_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestCIDTranscode(t *testing.T) { diff --git a/cidr_array_test.go b/cidr_array_test.go index 71125bdb..421aec4e 100644 --- a/cidr_array_test.go +++ b/cidr_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestCIDRArrayTranscode(t *testing.T) { diff --git a/circle_test.go b/circle_test.go index 82598620..ba4f408b 100644 --- a/circle_test.go +++ b/circle_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestCircleTranscode(t *testing.T) { diff --git a/date_array_test.go b/date_array_test.go index 24a8282c..9f4a96a9 100644 --- a/date_array_test.go +++ b/date_array_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestDateArrayTranscode(t *testing.T) { diff --git a/date_test.go b/date_test.go index ac7aadfe..bcdbbf20 100644 --- a/date_test.go +++ b/date_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestDateTranscode(t *testing.T) { diff --git a/daterange_test.go b/daterange_test.go index 4d3119ee..4118cffa 100644 --- a/daterange_test.go +++ b/daterange_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestDaterangeTranscode(t *testing.T) { diff --git a/enum_array_test.go b/enum_array_test.go index dbe09751..406c6b47 100644 --- a/enum_array_test.go +++ b/enum_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestEnumArrayTranscode(t *testing.T) { diff --git a/ext/satori-uuid/uuid.go b/ext/satori-uuid/uuid.go index 2aebfc47..01adea23 100644 --- a/ext/satori-uuid/uuid.go +++ b/ext/satori-uuid/uuid.go @@ -5,7 +5,7 @@ import ( errors "golang.org/x/xerrors" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgtype" uuid "github.com/satori/go.uuid" ) diff --git a/ext/satori-uuid/uuid_test.go b/ext/satori-uuid/uuid_test.go index 7a770b84..247470a3 100644 --- a/ext/satori-uuid/uuid_test.go +++ b/ext/satori-uuid/uuid_test.go @@ -4,9 +4,9 @@ import ( "bytes" "testing" - "github.com/jackc/pgx/v4/pgtype" - satori "github.com/jackc/pgx/v4/pgtype/ext/satori-uuid" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + satori "github.com/jackc/pgtype/ext/satori-uuid" + "github.com/jackc/pgtype/testutil" ) func TestUUIDTranscode(t *testing.T) { diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index 54612db9..d8f176a8 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -6,7 +6,7 @@ import ( errors "golang.org/x/xerrors" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgtype" "github.com/shopspring/decimal" ) diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index 2af39e1d..0b256b37 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -7,9 +7,9 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - shopspring "github.com/jackc/pgx/v4/pgtype/ext/shopspring-numeric" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + shopspring "github.com/jackc/pgtype/ext/shopspring-numeric" + "github.com/jackc/pgtype/testutil" "github.com/shopspring/decimal" ) diff --git a/float4_array_test.go b/float4_array_test.go index 24d544b6..658b3381 100644 --- a/float4_array_test.go +++ b/float4_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestFloat4ArrayTranscode(t *testing.T) { diff --git a/float4_test.go b/float4_test.go index 4779b357..d2524cda 100644 --- a/float4_test.go +++ b/float4_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestFloat4Transcode(t *testing.T) { diff --git a/float8_array_test.go b/float8_array_test.go index b3e7a197..2e29a19f 100644 --- a/float8_array_test.go +++ b/float8_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestFloat8ArrayTranscode(t *testing.T) { diff --git a/float8_test.go b/float8_test.go index 15092916..6bc7c652 100644 --- a/float8_test.go +++ b/float8_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestFloat8Transcode(t *testing.T) { diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..8412ceea --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/jackc/pgtype + +go 1.12 + +require ( + github.com/jackc/pgio v1.0.0 + github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 + github.com/lib/pq v1.1.0 + github.com/satori/go.uuid v1.2.0 + golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..ff91dc33 --- /dev/null +++ b/go.sum @@ -0,0 +1,49 @@ +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3 h1:ZFYpB74Kq8xE9gmfxCmXD6QxZ27ja+j3HwGFc+YurhQ= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 h1:ylEAOd688Duev/fxTmGdupsbyZfxNMdngIG14DoBKTM= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/hstore_array_test.go b/hstore_array_test.go index bc45cbdf..47835605 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -5,9 +5,9 @@ import ( "reflect" "testing" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestHstoreArrayTranscode(t *testing.T) { diff --git a/hstore_test.go b/hstore_test.go index 71fd2355..ccd476dc 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestHstoreTranscode(t *testing.T) { diff --git a/inet_array_test.go b/inet_array_test.go index 4e93d0f5..6737aac0 100644 --- a/inet_array_test.go +++ b/inet_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestInetArrayTranscode(t *testing.T) { diff --git a/inet_test.go b/inet_test.go index ee93873b..8257a63d 100644 --- a/inet_test.go +++ b/inet_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestInetTranscode(t *testing.T) { diff --git a/int2_array_test.go b/int2_array_test.go index fb4f0d60..810d5a7e 100644 --- a/int2_array_test.go +++ b/int2_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestInt2ArrayTranscode(t *testing.T) { diff --git a/int2_test.go b/int2_test.go index ff4732f7..cf8acd30 100644 --- a/int2_test.go +++ b/int2_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestInt2Transcode(t *testing.T) { diff --git a/int4_array_test.go b/int4_array_test.go index 06772cf6..a0b8058f 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestInt4ArrayTranscode(t *testing.T) { diff --git a/int4_test.go b/int4_test.go index 6b23c5a9..52bf9f0c 100644 --- a/int4_test.go +++ b/int4_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestInt4Transcode(t *testing.T) { diff --git a/int4range_test.go b/int4range_test.go index 95d448f0..43626189 100644 --- a/int4range_test.go +++ b/int4range_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestInt4rangeTranscode(t *testing.T) { diff --git a/int8_array_test.go b/int8_array_test.go index c2d914ab..f4ed76e0 100644 --- a/int8_array_test.go +++ b/int8_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestInt8ArrayTranscode(t *testing.T) { diff --git a/int8_test.go b/int8_test.go index a5f80f42..63dd6f3e 100644 --- a/int8_test.go +++ b/int8_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestInt8Transcode(t *testing.T) { diff --git a/int8range_test.go b/int8range_test.go index 01af48bb..99d4e8a3 100644 --- a/int8range_test.go +++ b/int8range_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestInt8rangeTranscode(t *testing.T) { diff --git a/interval_test.go b/interval_test.go index 7cafb0ae..6a4787e0 100644 --- a/interval_test.go +++ b/interval_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestIntervalTranscode(t *testing.T) { diff --git a/json_test.go b/json_test.go index bb0f1b20..918b33d5 100644 --- a/json_test.go +++ b/json_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestJSONTranscode(t *testing.T) { diff --git a/jsonb_test.go b/jsonb_test.go index 73656c76..e7ce7203 100644 --- a/jsonb_test.go +++ b/jsonb_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestJSONBTranscode(t *testing.T) { diff --git a/line_test.go b/line_test.go index 5f0a58a3..6a560dec 100644 --- a/line_test.go +++ b/line_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestLineTranscode(t *testing.T) { diff --git a/lseg_test.go b/lseg_test.go index 100bdf0f..b75297cc 100644 --- a/lseg_test.go +++ b/lseg_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestLsegTranscode(t *testing.T) { diff --git a/macaddr_array_test.go b/macaddr_array_test.go index cf07ebf6..d2b0a73b 100644 --- a/macaddr_array_test.go +++ b/macaddr_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestMacaddrArrayTranscode(t *testing.T) { diff --git a/macaddr_test.go b/macaddr_test.go index a08671c0..364a8914 100644 --- a/macaddr_test.go +++ b/macaddr_test.go @@ -6,8 +6,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestMacaddrTranscode(t *testing.T) { diff --git a/name_test.go b/name_test.go index 75d7b95a..75329b01 100644 --- a/name_test.go +++ b/name_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestNameTranscode(t *testing.T) { diff --git a/numeric_array_test.go b/numeric_array_test.go index b17a6461..9d608dea 100644 --- a/numeric_array_test.go +++ b/numeric_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestNumericArrayTranscode(t *testing.T) { diff --git a/numeric_test.go b/numeric_test.go index b723cc56..046c2f94 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -6,8 +6,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) // For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) diff --git a/numrange_test.go b/numrange_test.go index 610447fe..0bbb26f0 100644 --- a/numrange_test.go +++ b/numrange_test.go @@ -4,8 +4,8 @@ import ( "math/big" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestNumrangeTranscode(t *testing.T) { diff --git a/oid_value_test.go b/oid_value_test.go index 462a5a28..69742dd7 100644 --- a/oid_value_test.go +++ b/oid_value_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestOIDValueTranscode(t *testing.T) { diff --git a/path_test.go b/path_test.go index 16e781f5..969a89ec 100644 --- a/path_test.go +++ b/path_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestPathTranscode(t *testing.T) { diff --git a/point_test.go b/point_test.go index 017bfc03..0d191b5e 100644 --- a/point_test.go +++ b/point_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestPointTranscode(t *testing.T) { diff --git a/polygon_test.go b/polygon_test.go index 3bafebfc..f8b02ca2 100644 --- a/polygon_test.go +++ b/polygon_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestPolygonTranscode(t *testing.T) { diff --git a/qchar_test.go b/qchar_test.go index 3b50bb3e..4b60339c 100644 --- a/qchar_test.go +++ b/qchar_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestQCharTranscode(t *testing.T) { diff --git a/record_test.go b/record_test.go index 5de8af31..fbf36f5c 100644 --- a/record_test.go +++ b/record_test.go @@ -6,9 +6,9 @@ import ( "reflect" "testing" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestRecordTranscode(t *testing.T) { diff --git a/testutil/testutil.go b/testutil/testutil.go index 121eb754..66deff39 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -8,8 +8,8 @@ import ( "reflect" "testing" + "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgtype" _ "github.com/jackc/pgx/v4/stdlib" _ "github.com/lib/pq" ) diff --git a/text_array_test.go b/text_array_test.go index b03312d9..a29ce617 100644 --- a/text_array_test.go +++ b/text_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestTextArrayTranscode(t *testing.T) { diff --git a/text_test.go b/text_test.go index 53f4bd7e..f7286995 100644 --- a/text_test.go +++ b/text_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestTextTranscode(t *testing.T) { diff --git a/tid_test.go b/tid_test.go index cd753ab4..773bd96f 100644 --- a/tid_test.go +++ b/tid_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestTIDTranscode(t *testing.T) { diff --git a/timestamp_array_test.go b/timestamp_array_test.go index 002d1ca4..d7632fa3 100644 --- a/timestamp_array_test.go +++ b/timestamp_array_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestTimestampArrayTranscode(t *testing.T) { diff --git a/timestamp_test.go b/timestamp_test.go index 732f3cc2..eec0a52e 100644 --- a/timestamp_test.go +++ b/timestamp_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestTimestampTranscode(t *testing.T) { diff --git a/timestamptz_array_test.go b/timestamptz_array_test.go index ac9975f0..8a4cfd1d 100644 --- a/timestamptz_array_test.go +++ b/timestamptz_array_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestTimestamptzArrayTranscode(t *testing.T) { diff --git a/timestamptz_test.go b/timestamptz_test.go index f522117b..f6aec068 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestTimestamptzTranscode(t *testing.T) { diff --git a/tsrange_test.go b/tsrange_test.go index 6215e318..1be0c7d2 100644 --- a/tsrange_test.go +++ b/tsrange_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestTsrangeTranscode(t *testing.T) { diff --git a/tstzrange_test.go b/tstzrange_test.go index ddaf798b..b3d3ff6c 100644 --- a/tstzrange_test.go +++ b/tstzrange_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestTstzrangeTranscode(t *testing.T) { diff --git a/uuid_array_test.go b/uuid_array_test.go index 6ec6acfb..d5446920 100644 --- a/uuid_array_test.go +++ b/uuid_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestUUIDArrayTranscode(t *testing.T) { diff --git a/uuid_test.go b/uuid_test.go index 9d95c10c..49190168 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -4,8 +4,8 @@ import ( "bytes" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestUUIDTranscode(t *testing.T) { diff --git a/varbit_test.go b/varbit_test.go index 8ea282eb..3c5aea1e 100644 --- a/varbit_test.go +++ b/varbit_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestVarbitTranscode(t *testing.T) { diff --git a/varchar_array_test.go b/varchar_array_test.go index b836664f..9ad80862 100644 --- a/varchar_array_test.go +++ b/varchar_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestVarcharArrayTranscode(t *testing.T) { diff --git a/xid_test.go b/xid_test.go index 34801e1f..563ce96e 100644 --- a/xid_test.go +++ b/xid_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" ) func TestXIDTranscode(t *testing.T) { From 99fd636b8efaa82b65161f8003c209067e630169 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 19:20:51 -0500 Subject: [PATCH 0273/1158] Finish mod changes for split --- go.mod | 7 ++++++- go.sum | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 8412ceea..00679e12 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,13 @@ go 1.12 require ( github.com/jackc/pgio v1.0.0 - github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 + github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912 + github.com/kr/pretty v0.1.0 // indirect github.com/lib/pq v1.1.0 github.com/satori/go.uuid v1.2.0 + github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 + go.uber.org/atomic v1.3.2 // indirect + go.uber.org/multierr v1.1.0 // indirect golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/go.sum b/go.sum index ff91dc33..ecd3007e 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,11 @@ github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 h1:ylEAOd688Duev/fxTmGdupsbyZfxNMdngIG14DoBKTM= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912 h1:YuOWGsSK5L4Fz81Olx5TNlZftmDuNrfv4ip0Yos77Tw= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -29,6 +32,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= From 23a91ebc909de0d768ce3b5965a603dec0725286 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 24 Apr 2019 16:08:12 -0500 Subject: [PATCH 0274/1158] auth_scram.go file comment should not be part of docs --- auth_scram.go | 1 + 1 file changed, 1 insertion(+) diff --git a/auth_scram.go b/auth_scram.go index 5baa680b..d102d305 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -9,6 +9,7 @@ // https://github.com/lib/pq/pull/608 // https://github.com/lib/pq/pull/788 // https://github.com/lib/pq/pull/833 + package pgconn import ( From 1e3961bd0ea4d624dc181734894db99b9e5946f4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 24 Apr 2019 16:49:52 -0500 Subject: [PATCH 0275/1158] Fix flickering test --- pgconn_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index b7cb4036..dcbbfc89 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1289,7 +1289,12 @@ func TestConnCancelRequest(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(5)") + multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") + + // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a + // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a + // few milliseconds. + time.Sleep(50 * time.Millisecond) err = pgConn.CancelRequest(context.Background()) require.NoError(t, err) From 4acc0f54c6ba565535ed46a21891b9c9f377be11 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 May 2019 14:07:55 -0500 Subject: [PATCH 0276/1158] Import fixes from pgx/pgproto3 Import and adapt commit: fbb8cce --- backend.go | 3 --- copy_fail.go | 19 +++++++++---------- copy_out_response.go | 2 ++ 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/backend.go b/backend.go index 7f11bc7f..1121c550 100644 --- a/backend.go +++ b/backend.go @@ -14,7 +14,6 @@ type Backend struct { // Frontend message flyweights bind Bind _close Close - copyFail CopyFail describe Describe execute Execute flush Flush @@ -81,8 +80,6 @@ func (b *Backend) Receive() (FrontendMessage, error) { msg = &b.describe case 'E': msg = &b.execute - case 'f': - msg = &b.copyFail case 'H': msg = &b.flush case 'P': diff --git a/copy_fail.go b/copy_fail.go index e086207a..eadffa9c 100644 --- a/copy_fail.go +++ b/copy_fail.go @@ -8,11 +8,10 @@ import ( ) type CopyFail struct { - Error string + Message string } -func (*CopyFail) Frontend() {} -func (*CopyFail) Backend() {} +func (*CopyFail) Backend() {} func (dst *CopyFail) Decode(src []byte) error { idx := bytes.IndexByte(src, 0) @@ -20,17 +19,17 @@ func (dst *CopyFail) Decode(src []byte) error { return &invalidMessageFormatErr{messageType: "CopyFail"} } - dst.Error = string(src[:idx]) + dst.Message = string(src[:idx]) return nil } func (src *CopyFail) Encode(dst []byte) []byte { - dst = append(dst, 'C') + dst = append(dst, 'f') sp := len(dst) dst = pgio.AppendInt32(dst, -1) - dst = append(dst, src.Error...) + dst = append(dst, src.Message...) dst = append(dst, 0) pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) @@ -40,10 +39,10 @@ func (src *CopyFail) Encode(dst []byte) []byte { func (src *CopyFail) MarshalJSON() ([]byte, error) { return json.Marshal(struct { - Type string - Error string + Type string + Message string }{ - Type: "CopyFail", - Error: src.Error, + Type: "CopyFail", + Message: src.Message, }) } diff --git a/copy_out_response.go b/copy_out_response.go index 01a64228..eb6fb50e 100644 --- a/copy_out_response.go +++ b/copy_out_response.go @@ -44,6 +44,8 @@ func (src *CopyOutResponse) Encode(dst []byte) []byte { sp := len(dst) dst = pgio.AppendInt32(dst, -1) + dst = append(dst, src.OverallFormat) + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) From 1baf0ef57ec8643d0417d5b2b909ba17c214d125 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 May 2019 18:05:06 -0500 Subject: [PATCH 0277/1158] Refactor context handling into ctxwatch package --- benchmark_test.go | 16 +++ chan_to_set_deadline.go | 51 -------- go.mod | 1 - go.sum | 1 + helper_test.go | 4 +- internal/ctxwatch/context_watcher.go | 64 ++++++++++ internal/ctxwatch/context_watcher_test.go | 139 ++++++++++++++++++++++ pgconn.go | 65 ++++++---- 8 files changed, 261 insertions(+), 80 deletions(-) delete mode 100644 chan_to_set_deadline.go create mode 100644 internal/ctxwatch/context_watcher.go create mode 100644 internal/ctxwatch/context_watcher_test.go diff --git a/benchmark_test.go b/benchmark_test.go index 000dfd1b..073281aa 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -206,3 +206,19 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { } } } + +// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) { +// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) +// require.Nil(b, err) +// defer closeConn(b, conn) + +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() + +// b.ResetTimer() + +// for i := 0; i < b.N; i++ { +// conn.ChanToSetDeadline().Watch(ctx) +// conn.ChanToSetDeadline().Ignore() +// } +// } diff --git a/chan_to_set_deadline.go b/chan_to_set_deadline.go deleted file mode 100644 index 04bb8fde..00000000 --- a/chan_to_set_deadline.go +++ /dev/null @@ -1,51 +0,0 @@ -package pgconn - -import ( - "time" -) - -var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) - -type setDeadliner interface { - SetDeadline(time.Time) error -} - -type chanToSetDeadline struct { - cleanupChan chan struct{} - conn setDeadliner - deadlineWasSet bool - cleanupComplete bool -} - -func (this *chanToSetDeadline) start(doneChan <-chan struct{}, conn setDeadliner) { - if this.cleanupChan == nil { - this.cleanupChan = make(chan struct{}) - } - this.conn = conn - this.deadlineWasSet = false - this.cleanupComplete = false - - if doneChan != nil { - go func() { - select { - case <-doneChan: - conn.SetDeadline(deadlineTime) - this.deadlineWasSet = true - <-this.cleanupChan - case <-this.cleanupChan: - } - }() - } else { - this.cleanupComplete = true - } -} - -func (this *chanToSetDeadline) cleanup() { - if !this.cleanupComplete { - this.cleanupChan <- struct{}{} - if this.deadlineWasSet { - this.conn.SetDeadline(time.Time{}) - } - this.cleanupComplete = true - } -} diff --git a/go.mod b/go.mod index acbee593..4ad3564a 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db - github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index 9160f187..9e2398cb 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,7 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/helper_test.go b/helper_test.go index 5d44f3b8..1a3ca75e 100644 --- a/helper_test.go +++ b/helper_test.go @@ -12,9 +12,9 @@ import ( ) func closeConn(t testing.TB, conn *pgconn.PgConn) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - require.Nil(t, conn.Close(ctx)) + require.NoError(t, conn.Close(ctx)) } // Do a simple query to ensure the connection is still usable diff --git a/internal/ctxwatch/context_watcher.go b/internal/ctxwatch/context_watcher.go new file mode 100644 index 00000000..391f0b79 --- /dev/null +++ b/internal/ctxwatch/context_watcher.go @@ -0,0 +1,64 @@ +package ctxwatch + +import ( + "context" +) + +// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a +// time. +type ContextWatcher struct { + onCancel func() + onUnwatchAfterCancel func() + unwatchChan chan struct{} + watchInProgress bool + onCancelWasCalled bool +} + +// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. +// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and +// onCancel called. +func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher { + cw := &ContextWatcher{ + onCancel: onCancel, + onUnwatchAfterCancel: onUnwatchAfterCancel, + unwatchChan: make(chan struct{}), + } + + return cw +} + +// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called. +func (cw *ContextWatcher) Watch(ctx context.Context) { + if cw.watchInProgress { + panic("Watch already in progress") + } + + cw.onCancelWasCalled = false + + if ctx.Done() != nil { + cw.watchInProgress = true + go func() { + select { + case <-ctx.Done(): + cw.onCancel() + cw.onCancelWasCalled = true + <-cw.unwatchChan + case <-cw.unwatchChan: + } + }() + } else { + cw.watchInProgress = false + } +} + +// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was +// called then onUnwatchAfterCancel will also be called. +func (cw *ContextWatcher) Unwatch() { + if cw.watchInProgress { + cw.unwatchChan <- struct{}{} + if cw.onCancelWasCalled { + cw.onUnwatchAfterCancel() + } + cw.watchInProgress = false + } +} diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go new file mode 100644 index 00000000..0b491bf8 --- /dev/null +++ b/internal/ctxwatch/context_watcher_test.go @@ -0,0 +1,139 @@ +package ctxwatch_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/jackc/pgconn/internal/ctxwatch" + "github.com/stretchr/testify/require" +) + +func TestContextWatcherContextCancelled(t *testing.T) { + canceledChan := make(chan struct{}) + cleanupCalled := false + cw := ctxwatch.NewContextWatcher(func() { + canceledChan <- struct{}{} + }, func() { + cleanupCalled = true + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + + select { + case <-canceledChan: + case <-time.NewTimer(time.Second).C: + t.Fatal("Timed out waiting for cancel func to be called") + } + + cw.Unwatch() + + require.True(t, cleanupCalled, "Cleanup func was not called") +} + +func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() { + t.Error("cancel func should not have been called") + }, func() { + t.Error("cleanup func should not have been called") + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cw.Unwatch() + cancel() +} + +func TestContextWatcherMultipleWatchPanics(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") +} + +func TestContextWatcherStress(t *testing.T) { + var cancelFuncCalls int64 + var cleanupFuncCalls int64 + + cw := ctxwatch.NewContextWatcher(func() { + atomic.AddInt64(&cancelFuncCalls, 1) + }, func() { + atomic.AddInt64(&cleanupFuncCalls, 1) + }) + + cycleCount := 100000 + + for i := 0; i < cycleCount; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + if i%2 == 0 { + cancel() + } + + // Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix. + if i%3 == 0 { + time.Sleep(time.Nanosecond) + } + + cw.Unwatch() + if i%2 == 1 { + cancel() + } + } + + actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls) + actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls) + + if actualCancelFuncCalls == 0 { + t.Fatal("actualCancelFuncCalls == 0") + } + + maxCancelFuncCalls := int64(cycleCount) / 2 + if actualCancelFuncCalls > maxCancelFuncCalls { + t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls) + } + + if actualCancelFuncCalls != actualCleanupFuncCalls { + t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls) + } +} + +func BenchmarkContextWatcherUncancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + for i := 0; i < b.N; i++ { + cw.Watch(context.Background()) + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancelled(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < b.N; i++ { + cw.Watch(ctx) + cw.Unwatch() + } +} diff --git a/pgconn.go b/pgconn.go index a4402a7d..aad5fafd 100644 --- a/pgconn.go +++ b/pgconn.go @@ -13,7 +13,9 @@ import ( "strconv" "strings" "sync" + "time" + "github.com/jackc/pgconn/internal/ctxwatch" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" @@ -21,6 +23,7 @@ import ( const ( connStatusUninitialized = iota + connStatusConnecting connStatusClosed connStatusIdle connStatusBusy @@ -71,10 +74,10 @@ type PgConn struct { bufferingReceiveErr error // Reusable / preallocated resources - wbuf []byte // write buffer - resultReader ResultReader - multiResultReader MultiResultReader - doneChanToDeadline chanToSetDeadline + wbuf []byte // write buffer + resultReader ResultReader + multiResultReader MultiResultReader + contextWatcher *ctxwatch.ContextWatcher } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -149,6 +152,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } + pgConn.status = connStatusConnecting + pgConn.contextWatcher = ctxwatch.NewContextWatcher( + func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { pgConn.conn.SetDeadline(time.Time{}) }, + ) + pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn) if err != nil { return nil, err @@ -355,8 +364,8 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { @@ -377,6 +386,7 @@ func (pgConn *PgConn) hardClose() error { return nil } pgConn.status = connStatusClosed + return pgConn.conn.Close() } @@ -453,8 +463,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) @@ -543,9 +553,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { } defer cancelConn.Close() - var doneChanToDeadline chanToSetDeadline - doneChanToDeadline.start(ctx.Done(), cancelConn) - defer doneChanToDeadline.cleanup() + contextWatcher := ctxwatch.NewContextWatcher( + func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { cancelConn.SetDeadline(time.Time{}) }, + ) + contextWatcher.Watch(ctx) + defer contextWatcher.Unwatch() buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) @@ -579,8 +592,8 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() for { msg, err := pgConn.ReceiveMessage() @@ -622,7 +635,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { return multiResult default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + pgConn.contextWatcher.Watch(ctx) buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) @@ -630,7 +643,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Unwatch() multiResult.closed = true if n == 0 { err = linkErrors(err, ErrNoBytesSent) @@ -732,7 +745,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + pgConn.contextWatcher.Watch(ctx) return result } @@ -749,7 +762,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result err = linkErrors(err, ErrNoBytesSent) } result.concludeCommand(nil, linkErrors(ctx.Err(), err)) - pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() } @@ -767,8 +780,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -828,8 +841,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -962,7 +975,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) msg, err := mrr.pgConn.ReceiveMessage() if err != nil { - mrr.pgConn.doneChanToDeadline.cleanup() + mrr.pgConn.contextWatcher.Unwatch() mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.closed = true mrr.pgConn.hardClose() @@ -971,7 +984,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - mrr.pgConn.doneChanToDeadline.cleanup() + mrr.pgConn.contextWatcher.Unwatch() mrr.closed = true mrr.pgConn.unlock() case *pgproto3.ErrorResponse: @@ -1129,7 +1142,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { switch msg.(type) { case *pgproto3.ReadyForQuery: - rr.pgConn.doneChanToDeadline.cleanup() + rr.pgConn.contextWatcher.Unwatch() rr.pgConn.unlock() return rr.commandTag, rr.err } @@ -1148,7 +1161,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error if err != nil { rr.concludeCommand(nil, err) - rr.pgConn.doneChanToDeadline.cleanup() + rr.pgConn.contextWatcher.Unwatch() rr.closed = true if rr.multiResultReader == nil { rr.pgConn.hardClose() @@ -1223,7 +1236,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR return multiResult default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + pgConn.contextWatcher.Watch(ctx) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) From d30cf1c19f3a13beb275eb8a517d7f54d5e185bf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 May 2019 15:15:40 -0500 Subject: [PATCH 0278/1158] Adjust buffer size for CopyFrom --- pgconn.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index aad5fafd..bbabb0dd 100644 --- a/pgconn.go +++ b/pgconn.go @@ -879,8 +879,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Send copy data - buf = make([]byte, 0, 20000) - // buf = make([]byte, 0, 65536) + buf = make([]byte, 0, 65536) buf = append(buf, 'd') sp := len(buf) var readErr error From a340d5f15f5d75eb0cc1f42fdd9996f666dc3224 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 17 May 2019 13:27:11 -0500 Subject: [PATCH 0279/1158] CopyFail should be frontend message --- backend.go | 3 +++ copy_fail.go | 2 +- frontend.go | 3 --- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/backend.go b/backend.go index 1121c550..7f11bc7f 100644 --- a/backend.go +++ b/backend.go @@ -14,6 +14,7 @@ type Backend struct { // Frontend message flyweights bind Bind _close Close + copyFail CopyFail describe Describe execute Execute flush Flush @@ -80,6 +81,8 @@ func (b *Backend) Receive() (FrontendMessage, error) { msg = &b.describe case 'E': msg = &b.execute + case 'f': + msg = &b.copyFail case 'H': msg = &b.flush case 'P': diff --git a/copy_fail.go b/copy_fail.go index eadffa9c..2f228a82 100644 --- a/copy_fail.go +++ b/copy_fail.go @@ -11,7 +11,7 @@ type CopyFail struct { Message string } -func (*CopyFail) Backend() {} +func (*CopyFail) Frontend() {} func (dst *CopyFail) Decode(src []byte) error { idx := bytes.IndexByte(src, 0) diff --git a/frontend.go b/frontend.go index ce94f49f..6fa03bce 100644 --- a/frontend.go +++ b/frontend.go @@ -22,7 +22,6 @@ type Frontend struct { copyInResponse CopyInResponse copyOutResponse CopyOutResponse copyDone CopyDone - copyFail CopyFail dataRow DataRow emptyQueryResponse EmptyQueryResponse errorResponse ErrorResponse @@ -82,8 +81,6 @@ func (b *Frontend) Receive() (BackendMessage, error) { msg = &b.dataRow case 'E': msg = &b.errorResponse - case 'f': - msg = &b.copyFail case 'G': msg = &b.copyInResponse case 'H': From 3294a8cf1f2701b7bcc229597e7e4081b5d49532 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 20 May 2019 16:26:58 -0500 Subject: [PATCH 0280/1158] Allow empty hstore keys See pgx commit: 56f4f0b9d319a910016ce044a53f52fcf986ddc6 --- hstore.go | 10 +++------- hstore_test.go | 1 + 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/hstore.go b/hstore.go index 522813ff..56af38ee 100644 --- a/hstore.go +++ b/hstore.go @@ -297,13 +297,9 @@ func parseHstore(s string) (k []string, v []Text, err error) { case hsKey: switch r { case '"': //End of the key - if buf.Len() == 0 { - err = errors.New("Empty Key is invalid") - } else { - keys = append(keys, buf.String()) - buf = bytes.Buffer{} - state = hsSep - } + keys = append(keys, buf.String()) + buf = bytes.Buffer{} + state = hsSep case '\\': //Potential escaped character n, end := p.Consume() switch { diff --git a/hstore_test.go b/hstore_test.go index ccd476dc..ba6c9373 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -19,6 +19,7 @@ func TestHstoreTranscode(t *testing.T) { &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"": text("bar")}, Status: pgtype.Present}, &pgtype.Hstore{Status: pgtype.Null}, } From de87e8be96e1ee042303fe7116d02692155e7504 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 27 May 2019 12:50:27 -0500 Subject: [PATCH 0281/1158] Fix: Use fallback config TLS config --- pgconn.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index bbabb0dd..c51742ae 100644 --- a/pgconn.go +++ b/pgconn.go @@ -145,8 +145,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.parameterStatuses = make(map[string]string) - if config.TLSConfig != nil { - if err := pgConn.startTLS(config.TLSConfig); err != nil { + if fallbackConfig.TLSConfig != nil { + if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { pgConn.conn.Close() return nil, err } From 71ec1f78211346069f77cf843fca96d5e62ba90c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 28 May 2019 06:54:20 -0500 Subject: [PATCH 0282/1158] Update xerrors package --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 4ad3564a..9401dce8 100644 --- a/go.mod +++ b/go.mod @@ -9,5 +9,5 @@ require ( github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 - golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 + golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 ) diff --git a/go.sum b/go.sum index 9e2398cb..1b6862a0 100644 --- a/go.sum +++ b/go.sum @@ -23,3 +23,5 @@ golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From a97dd2f9f6d06658a4d189720e2d3c8e0bf51f69 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Jun 2019 09:59:04 -0500 Subject: [PATCH 0283/1158] Update test envvar and docs --- .travis.yml | 2 +- README.md | 28 +++++++++++++ benchmark_test.go | 10 ++--- pgconn_stress_test.go | 2 +- pgconn_test.go | 94 +++++++++++++++++++++---------------------- 5 files changed, 82 insertions(+), 54 deletions(-) diff --git a/.travis.yml b/.travis.yml index 50e81eb5..e5ed43a8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,7 +11,7 @@ before_install: env: global: - GO111MODULE=on - - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test + - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_TLS_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require diff --git a/README.md b/README.md index 8a881009..05cfedf1 100644 --- a/README.md +++ b/README.md @@ -6,3 +6,31 @@ Package pgconn is a low-level PostgreSQL database driver. It is intended to serve as the foundation for the next generation of https://github.com/jackc/pgx. + +## Testing + +pgconn tests need a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING` +environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*` +environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify +environment variable handling. + +### Example Test Environment + +Connect to your PostgreSQL server and run: + +``` +create database pgx_test; +``` + +Now you can run the tests: + +``` +PGX_TEST_CONN_STRING="host=/var/run/postgresql database=pgx_test" go test ./... +``` + +### Connection and Authentication Tests + +There are multiple connection types and means of authentication that pgconn supports. These tests are optional. They +will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being +skipped. Typical developers will not need to enable these tests. See travis.yml for example setup if you need change +authentication code. diff --git a/benchmark_test.go b/benchmark_test.go index 073281aa..51e11e24 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -38,7 +38,7 @@ func BenchmarkConnect(b *testing.B) { } func BenchmarkExec(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.Nil(b, err) defer closeConn(b, conn) @@ -82,7 +82,7 @@ func BenchmarkExec(b *testing.B) { } func BenchmarkExecPossibleToCancel(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.Nil(b, err) defer closeConn(b, conn) @@ -129,7 +129,7 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { } func BenchmarkExecPrepared(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.Nil(b, err) defer closeConn(b, conn) @@ -167,7 +167,7 @@ func BenchmarkExecPrepared(b *testing.B) { } func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.Nil(b, err) defer closeConn(b, conn) @@ -208,7 +208,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { } // func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) { -// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) +// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) // require.Nil(b, err) // defer closeConn(b, conn) diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index 7288c9b4..356b529a 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -14,7 +14,7 @@ import ( ) func TestConnStress(t *testing.T) { - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) diff --git a/pgconn_test.go b/pgconn_test.go index dcbbfc89..310b387b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -112,7 +112,7 @@ func TestConnectWithConnectionRefused(t *testing.T) { func TestConnectCustomDialer(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) dialed := false @@ -130,7 +130,7 @@ func TestConnectCustomDialer(t *testing.T) { func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) config.RuntimeParams = map[string]string{ @@ -156,7 +156,7 @@ func TestConnectWithRuntimeParams(t *testing.T) { func TestConnectWithFallback(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) // Prepend current primary config to fallbacks @@ -189,7 +189,7 @@ func TestConnectWithFallback(t *testing.T) { func TestConnectWithAfterConnectFunc(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) dialCount := 0 @@ -228,7 +228,7 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite @@ -243,7 +243,7 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -257,7 +257,7 @@ func TestConnPrepareSyntaxError(t *testing.T) { func TestConnPrepareContextPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -275,7 +275,7 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { func TestConnExec(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -294,7 +294,7 @@ func TestConnExec(t *testing.T) { func TestConnExecEmpty(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -315,7 +315,7 @@ func TestConnExecEmpty(t *testing.T) { func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -340,7 +340,7 @@ func TestConnExecMultipleQueries(t *testing.T) { func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -362,7 +362,7 @@ func TestConnExecMultipleQueriesError(t *testing.T) { func TestConnExecContextCanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -380,7 +380,7 @@ func TestConnExecContextCanceled(t *testing.T) { func TestConnExecContextPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -397,7 +397,7 @@ func TestConnExecContextPrecanceled(t *testing.T) { func TestConnExecParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -418,7 +418,7 @@ func TestConnExecParams(t *testing.T) { func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -441,7 +441,7 @@ func TestConnExecParamsMaxNumberOfParams(t *testing.T) { func TestConnExecParamsTooManyParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -464,7 +464,7 @@ func TestConnExecParamsTooManyParams(t *testing.T) { func TestConnExecParamsCanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -486,7 +486,7 @@ func TestConnExecParamsCanceled(t *testing.T) { func TestConnExecParamsPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -503,7 +503,7 @@ func TestConnExecParamsPrecanceled(t *testing.T) { func TestConnExecPrepared(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -530,7 +530,7 @@ func TestConnExecPrepared(t *testing.T) { func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -559,7 +559,7 @@ func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { func TestConnExecPreparedTooManyParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -588,7 +588,7 @@ func TestConnExecPreparedTooManyParams(t *testing.T) { func TestConnExecPreparedCanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -612,7 +612,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { func TestConnExecPreparedPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -632,7 +632,7 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { func TestConnExecBatch(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -664,7 +664,7 @@ func TestConnExecBatch(t *testing.T) { func TestConnExecBatchPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -697,7 +697,7 @@ func TestConnExecBatchHuge(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -725,7 +725,7 @@ func TestConnExecBatchHuge(t *testing.T) { func TestConnExecBatchImplicitTransaction(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -748,7 +748,7 @@ func TestConnExecBatchImplicitTransaction(t *testing.T) { func TestConnLocking(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -795,7 +795,7 @@ func TestCommandTag(t *testing.T) { func TestConnOnNotice(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) var msg string @@ -821,7 +821,7 @@ end$$;`) func TestConnOnNotification(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) var msg string @@ -853,7 +853,7 @@ func TestConnOnNotification(t *testing.T) { func TestConnWaitForNotification(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) var msg string @@ -885,7 +885,7 @@ func TestConnWaitForNotification(t *testing.T) { func TestConnWaitForNotificationPrecanceled(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) pgConn, err := pgconn.ConnectConfig(context.Background(), config) @@ -903,7 +903,7 @@ func TestConnWaitForNotificationPrecanceled(t *testing.T) { func TestConnWaitForNotificationTimeout(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) pgConn, err := pgconn.ConnectConfig(context.Background(), config) @@ -921,7 +921,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { func TestConnCopyToSmall(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -959,7 +959,7 @@ func TestConnCopyToSmall(t *testing.T) { func TestConnCopyToLarge(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -997,7 +997,7 @@ func TestConnCopyToLarge(t *testing.T) { func TestConnCopyToQueryError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1014,7 +1014,7 @@ func TestConnCopyToQueryError(t *testing.T) { func TestConnCopyToCanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1032,7 +1032,7 @@ func TestConnCopyToCanceled(t *testing.T) { func TestConnCopyToPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1052,7 +1052,7 @@ func TestConnCopyToPrecanceled(t *testing.T) { func TestConnCopyFrom(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1088,7 +1088,7 @@ func TestConnCopyFrom(t *testing.T) { func TestConnCopyFromCanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1123,7 +1123,7 @@ func TestConnCopyFromCanceled(t *testing.T) { func TestConnCopyFromPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1160,7 +1160,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) { func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1217,7 +1217,7 @@ func TestConnCopyFromGzipReader(t *testing.T) { func TestConnCopyFromQuerySyntaxError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1240,7 +1240,7 @@ func TestConnCopyFromQuerySyntaxError(t *testing.T) { func TestConnCopyFromQueryNoTableError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1257,7 +1257,7 @@ func TestConnCopyFromQueryNoTableError(t *testing.T) { func TestConnEscapeString(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1285,7 +1285,7 @@ func TestConnEscapeString(t *testing.T) { func TestConnCancelRequest(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1310,7 +1310,7 @@ func TestConnCancelRequest(t *testing.T) { } func Example() { - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { log.Fatalln(err) } From 18e7e777be8bcae5c1bc24d4fb86f5d497012575 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jun 2019 10:26:26 -0500 Subject: [PATCH 0284/1158] Import PortalSuspended from pgx v3 0ab6f80f9929384a8cf6cfc299b43233534eb705 --- frontend.go | 3 +++ portal_suspended.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 portal_suspended.go diff --git a/frontend.go b/frontend.go index 6fa03bce..a0a5b493 100644 --- a/frontend.go +++ b/frontend.go @@ -34,6 +34,7 @@ type Frontend struct { parseComplete ParseComplete readyForQuery ReadyForQuery rowDescription RowDescription + portalSuspended PortalSuspended bodyLen int msgType byte @@ -95,6 +96,8 @@ func (b *Frontend) Receive() (BackendMessage, error) { msg = &b.noticeResponse case 'R': msg = &b.authentication + case 's': + msg = &b.portalSuspended case 'S': msg = &b.parameterStatus case 't': diff --git a/portal_suspended.go b/portal_suspended.go new file mode 100644 index 00000000..dc81b027 --- /dev/null +++ b/portal_suspended.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type PortalSuspended struct{} + +func (*PortalSuspended) Backend() {} + +func (dst *PortalSuspended) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "PortalSuspended", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *PortalSuspended) Encode(dst []byte) []byte { + return append(dst, 's', 0, 0, 0, 4) +} + +func (src *PortalSuspended) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "PortalSuspended", + }) +} From 4e0ed911f557f4ba29347b1a9ccb9ba1a3f2f693 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jun 2019 11:45:47 -0500 Subject: [PATCH 0285/1158] Import Fix for -0 numeric From pgx: d678216f468d1fe4dc28649feacd4b30a176769e --- numeric.go | 2 +- numeric_array_test.go | 15 +++++++++++++++ numeric_test.go | 3 +++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/numeric.go b/numeric.go index bbd7667a..45854e70 100644 --- a/numeric.go +++ b/numeric.go @@ -322,7 +322,7 @@ func parseNumericString(str string) (n *big.Int, exp int32, err error) { if len(parts) > 1 { exp = int32(-len(parts[1])) } else { - for len(digits) > 1 && digits[len(digits)-1] == '0' { + for len(digits) > 1 && digits[len(digits)-1] == '0' && digits[len(digits)-2] != '-' { digits = digits[:len(digits)-1] exp++ } diff --git a/numeric_array_test.go b/numeric_array_test.go index 9d608dea..eafd31be 100644 --- a/numeric_array_test.go +++ b/numeric_array_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "math" "math/big" "reflect" "testing" @@ -65,6 +66,13 @@ func TestNumericArraySet(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, + { + source: []float32{float32(math.Copysign(0, -1))}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(0), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, { source: []float64{1}, result: pgtype.NumericArray{ @@ -72,6 +80,13 @@ func TestNumericArraySet(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, + { + source: []float64{math.Copysign(0, -1)}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(0), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, { source: (([]float32)(nil)), result: pgtype.NumericArray{Status: pgtype.Null}, diff --git a/numeric_test.go b/numeric_test.go index 046c2f94..b925be83 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "math" "math/big" "math/rand" "reflect" @@ -188,7 +189,9 @@ func TestNumericSet(t *testing.T) { result *pgtype.Numeric }{ {source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float32(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Present}}, {source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float64(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Present}}, {source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, {source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, {source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, From ecdcf4a36773147f7fb46573cbb1ee7a6130ba0f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jun 2019 18:06:29 -0500 Subject: [PATCH 0286/1158] Rename Option to Config --- chunkreader.go | 25 +++++++++++++------------ chunkreader_test.go | 6 +++--- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/chunkreader.go b/chunkreader.go index fedd41e5..43363e42 100644 --- a/chunkreader.go +++ b/chunkreader.go @@ -12,31 +12,32 @@ type ChunkReader struct { buf []byte rp, wp int // buf read position and write position - options Options + config Config } -type Options struct { +// Config contains configuration parameters for ChunkReader. +type Config struct { MinBufLen int // Minimum buffer length } func NewChunkReader(r io.Reader) *ChunkReader { - cr, err := NewChunkReaderEx(r, Options{}) + cr, err := NewChunkReaderEx(r, Config{}) if err != nil { - panic("default options can't be bad") + panic("default config can't be bad") } return cr } -func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { - if options.MinBufLen == 0 { - options.MinBufLen = 4096 +func NewChunkReaderEx(r io.Reader, config Config) (*ChunkReader, error) { + if config.MinBufLen == 0 { + config.MinBufLen = 4096 } return &ChunkReader{ - r: r, - buf: make([]byte, options.MinBufLen), - options: options, + r: r, + buf: make([]byte, config.MinBufLen), + config: config, }, nil } @@ -78,8 +79,8 @@ func (r *ChunkReader) appendAtLeast(fillLen int) error { } func (r *ChunkReader) newBuf(size int) []byte { - if size < r.options.MinBufLen { - size = r.options.MinBufLen + if size < r.config.MinBufLen { + size = r.config.MinBufLen } return make([]byte, size) } diff --git a/chunkreader_test.go b/chunkreader_test.go index 3be07e3c..66515e87 100644 --- a/chunkreader_test.go +++ b/chunkreader_test.go @@ -7,7 +7,7 @@ import ( func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) + r, err := NewChunkReaderEx(server, Config{MinBufLen: 4}) if err != nil { t.Fatal(err) } @@ -44,7 +44,7 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) + r, err := NewChunkReaderEx(server, Config{MinBufLen: 4}) if err != nil { t.Fatal(err) } @@ -66,7 +66,7 @@ func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { func TestChunkReaderDoesNotReuseBuf(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) + r, err := NewChunkReaderEx(server, Config{MinBufLen: 4}) if err != nil { t.Fatal(err) } From 4e6b8011b67f7ee75e0cd8710295835b6a0c0abd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jun 2019 18:10:49 -0500 Subject: [PATCH 0287/1158] Shorten constructor function names --- chunkreader.go | 8 +++++--- chunkreader_test.go | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/chunkreader.go b/chunkreader.go index 43363e42..b47747f2 100644 --- a/chunkreader.go +++ b/chunkreader.go @@ -20,8 +20,9 @@ type Config struct { MinBufLen int // Minimum buffer length } -func NewChunkReader(r io.Reader) *ChunkReader { - cr, err := NewChunkReaderEx(r, Config{}) +// New creates and returns a new ChunkReader for r with default configuration. +func New(r io.Reader) *ChunkReader { + cr, err := NewConfig(r, Config{}) if err != nil { panic("default config can't be bad") } @@ -29,7 +30,8 @@ func NewChunkReader(r io.Reader) *ChunkReader { return cr } -func NewChunkReaderEx(r io.Reader, config Config) (*ChunkReader, error) { +// NewConfig creates and a new ChunkReader for r configured by config. +func NewConfig(r io.Reader, config Config) (*ChunkReader, error) { if config.MinBufLen == 0 { config.MinBufLen = 4096 } diff --git a/chunkreader_test.go b/chunkreader_test.go index 66515e87..67a20af2 100644 --- a/chunkreader_test.go +++ b/chunkreader_test.go @@ -7,7 +7,7 @@ import ( func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Config{MinBufLen: 4}) + r, err := NewConfig(server, Config{MinBufLen: 4}) if err != nil { t.Fatal(err) } @@ -44,7 +44,7 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Config{MinBufLen: 4}) + r, err := NewConfig(server, Config{MinBufLen: 4}) if err != nil { t.Fatal(err) } @@ -66,7 +66,7 @@ func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { func TestChunkReaderDoesNotReuseBuf(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Config{MinBufLen: 4}) + r, err := NewConfig(server, Config{MinBufLen: 4}) if err != nil { t.Fatal(err) } From 21088f2cb5965119897433279957e2ad2b301ddd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jun 2019 18:29:13 -0500 Subject: [PATCH 0288/1158] Improve documentation --- README.md | 2 +- chunkreader.go | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index bcc9ac6b..01209bfa 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,6 @@ # chunkreader -Package chunkreader provides an opinionated, efficient buffered reader. +Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations. Extracted from original implementation in https://github.com/jackc/pgx. diff --git a/chunkreader.go b/chunkreader.go index b47747f2..36304fd5 100644 --- a/chunkreader.go +++ b/chunkreader.go @@ -1,11 +1,17 @@ -// Package chunkreader provides an opinionated, efficient buffered reader. +// Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations. package chunkreader import ( "io" ) -// ChunkReader is a io.Reader wrapper that minimizes reads and memory allocations. +// ChunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and +// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually +// requested. The memory returned via Next is owned by the caller. This avoids the need for an additional copy. +// +// The downside of this approach is that a large buffer can be pinned in memory even if only a small slice is +// referenced. For example, an entire 4096 byte block could be pinned in memory by even a 1 byte slice. In these rare +// cases it would be advantageous to copy the bytes to another slice. type ChunkReader struct { r io.Reader @@ -43,8 +49,8 @@ func NewConfig(r io.Reader, config Config) (*ChunkReader, error) { }, nil } -// Next returns buf filled with the next n bytes. If an error occurs, buf will -// be nil. +// Next returns buf filled with the next n bytes. The caller gains ownership of buf. It is not necessary to make a copy +// of buf. If an error occurs, buf will be nil. func (r *ChunkReader) Next(n int) (buf []byte, err error) { // n bytes already in buf if (r.wp - r.rp) >= n { From 2c463c0e7d0d0876517f087ce2cce66a46182141 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jun 2019 18:32:30 -0500 Subject: [PATCH 0289/1158] Release v2 --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index b1ed8c92..a1384b40 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/jackc/chunkreader +module github.com/jackc/chunkreader/v2 go 1.12 From bf3a27ae3f6edf9a05411024c4c789110c5b4997 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jun 2019 18:34:35 -0500 Subject: [PATCH 0290/1158] Update to github.com/jackc/chunkreader/v2 --- chunkreader.go | 4 ++-- go.mod | 1 + go.sum | 2 ++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/chunkreader.go b/chunkreader.go index 0acfd4bf..2eb278ea 100644 --- a/chunkreader.go +++ b/chunkreader.go @@ -3,7 +3,7 @@ package pgproto3 import ( "io" - "github.com/jackc/chunkreader" + "github.com/jackc/chunkreader/v2" ) // ChunkReader is an interface to decouple github.com/jackc/chunkreader from this package. @@ -14,5 +14,5 @@ type ChunkReader interface { } func NewChunkReader(r io.Reader) ChunkReader { - return chunkreader.NewChunkReader(r) + return chunkreader.New(r) } diff --git a/go.mod b/go.mod index 37cc5114..800f6043 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.12 require ( github.com/jackc/chunkreader v1.0.0 + github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/pkg/errors v0.8.1 ) diff --git a/go.sum b/go.sum index 887dd869..5dd456ad 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= From 432c2951c711384a18995a40538a7450dcd0c1e5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jun 2019 19:38:34 -0500 Subject: [PATCH 0291/1158] Add a lot of documentation --- authentication.go | 9 +++++++++ backend.go | 6 ++++++ backend_key_data.go | 5 +++++ bind.go | 5 +++++ bind_complete.go | 5 +++++ chunkreader.go | 1 + close.go | 5 +++++ close_complete.go | 5 +++++ command_complete.go | 5 +++++ copy_both_response.go | 5 +++++ copy_data.go | 9 ++++++++- copy_done.go | 5 +++++ copy_fail.go | 5 +++++ copy_in_response.go | 5 +++++ copy_out_response.go | 4 ++++ data_row.go | 5 +++++ describe.go | 5 +++++ doc.go | 2 ++ empty_query_response.go | 5 +++++ error_response.go | 4 ++++ execute.go | 5 +++++ flush.go | 5 +++++ frontend.go | 4 ++++ function_call_response.go | 5 +++++ no_data.go | 5 +++++ notice_response.go | 4 ++++ notification_response.go | 5 +++++ parameter_description.go | 5 +++++ parameter_status.go | 5 +++++ parse.go | 5 +++++ parse_complete.go | 5 +++++ password_message.go | 5 +++++ portal_suspended.go | 5 +++++ query.go | 5 +++++ ready_for_query.go | 5 +++++ row_description.go | 5 +++++ sasl_initial_response.go | 5 +++++ sasl_response.go | 5 +++++ startup_message.go | 5 +++++ sync.go | 5 +++++ terminate.go | 5 +++++ 41 files changed, 202 insertions(+), 1 deletion(-) diff --git a/authentication.go b/authentication.go index 2078c87c..bc654c4f 100644 --- a/authentication.go +++ b/authentication.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" ) +// Authentication message type constants. const ( AuthTypeOk = 0 AuthTypeCleartextPassword = 3 @@ -17,6 +18,10 @@ const ( AuthTypeSASLFinal = 12 ) +// Authentication is a message sent from the backend during the authentication process. +// +// There are multiple authentication messages that each begin with 'R'. This structure represents all such +// authentication messages. type Authentication struct { Type uint32 @@ -30,8 +35,11 @@ type Authentication struct { SASLData []byte } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*Authentication) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Authentication) Decode(src []byte) error { *dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])} @@ -58,6 +66,7 @@ func (dst *Authentication) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *Authentication) Encode(dst []byte) []byte { dst = append(dst, 'R') sp := len(dst) diff --git a/backend.go b/backend.go index 7f11bc7f..2e2f5eea 100644 --- a/backend.go +++ b/backend.go @@ -7,6 +7,7 @@ import ( "github.com/pkg/errors" ) +// Backend acts as a server for the PostgreSQL wire protocol version 3. type Backend struct { cr ChunkReader w io.Writer @@ -30,15 +31,19 @@ type Backend struct { partialMsg bool } +// NewBackend creates a new Backend. func NewBackend(cr ChunkReader, w io.Writer) (*Backend, error) { return &Backend{cr: cr, w: w}, nil } +// Send sends a message to the frontend. func (b *Backend) Send(msg BackendMessage) error { _, err := b.w.Write(msg.Encode(nil)) return err } +// ReceiveStartupMessage receives the initial startup message. This method is used of the normal Receive method +// because StartupMessage and SSLRequest are "special" and do not include the message type as the first byte. func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { buf, err := b.cr.Next(4) if err != nil { @@ -59,6 +64,7 @@ func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { return &b.startupMessage, nil } +// Receive receives a message from the frontend. func (b *Backend) Receive() (FrontendMessage, error) { if !b.partialMsg { header, err := b.cr.Next(5) diff --git a/backend_key_data.go b/backend_key_data.go index 0396379b..b775d689 100644 --- a/backend_key_data.go +++ b/backend_key_data.go @@ -12,8 +12,11 @@ type BackendKeyData struct { SecretKey uint32 } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*BackendKeyData) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *BackendKeyData) Decode(src []byte) error { if len(src) != 8 { return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} @@ -25,6 +28,7 @@ func (dst *BackendKeyData) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *BackendKeyData) Encode(dst []byte) []byte { dst = append(dst, 'K') dst = pgio.AppendUint32(dst, 12) @@ -33,6 +37,7 @@ func (src *BackendKeyData) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *BackendKeyData) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/bind.go b/bind.go index 459e5ff2..67d20b5d 100644 --- a/bind.go +++ b/bind.go @@ -17,8 +17,11 @@ type Bind struct { ResultFormatCodes []int16 } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Bind) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Bind) Decode(src []byte) error { *dst = Bind{} @@ -103,6 +106,7 @@ func (dst *Bind) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *Bind) Encode(dst []byte) []byte { dst = append(dst, 'B') sp := len(dst) @@ -139,6 +143,7 @@ func (src *Bind) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *Bind) MarshalJSON() ([]byte, error) { formattedParameters := make([]map[string]string, len(src.Parameters)) for i, p := range src.Parameters { diff --git a/bind_complete.go b/bind_complete.go index 60360519..fc9d317a 100644 --- a/bind_complete.go +++ b/bind_complete.go @@ -6,8 +6,11 @@ import ( type BindComplete struct{} +// Backend identifies this message as sendable by the PostgreSQL backend. func (*BindComplete) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *BindComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} @@ -16,10 +19,12 @@ func (dst *BindComplete) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *BindComplete) Encode(dst []byte) []byte { return append(dst, '2', 0, 0, 0, 4) } +// MarshalJSON implements encoding/json.Marshaler. func (src *BindComplete) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/chunkreader.go b/chunkreader.go index 2eb278ea..92206f35 100644 --- a/chunkreader.go +++ b/chunkreader.go @@ -13,6 +13,7 @@ type ChunkReader interface { Next(n int) (buf []byte, err error) } +// NewChunkReader creates and returns a new default ChunkReader. func NewChunkReader(r io.Reader) ChunkReader { return chunkreader.New(r) } diff --git a/close.go b/close.go index 4e497549..349a319d 100644 --- a/close.go +++ b/close.go @@ -12,8 +12,11 @@ type Close struct { Name string } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Close) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Close) Decode(src []byte) error { if len(src) < 2 { return &invalidMessageFormatErr{messageType: "Close"} @@ -32,6 +35,7 @@ func (dst *Close) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *Close) Encode(dst []byte) []byte { dst = append(dst, 'C') sp := len(dst) @@ -46,6 +50,7 @@ func (src *Close) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *Close) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/close_complete.go b/close_complete.go index db793c94..b4982207 100644 --- a/close_complete.go +++ b/close_complete.go @@ -6,8 +6,11 @@ import ( type CloseComplete struct{} +// Backend identifies this message as sendable by the PostgreSQL backend. func (*CloseComplete) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CloseComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} @@ -16,10 +19,12 @@ func (dst *CloseComplete) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *CloseComplete) Encode(dst []byte) []byte { return append(dst, '3', 0, 0, 0, 4) } +// MarshalJSON implements encoding/json.Marshaler. func (src *CloseComplete) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/command_complete.go b/command_complete.go index adcc56c8..87fcddf6 100644 --- a/command_complete.go +++ b/command_complete.go @@ -11,8 +11,11 @@ type CommandComplete struct { CommandTag []byte } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*CommandComplete) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CommandComplete) Decode(src []byte) error { idx := bytes.IndexByte(src, 0) if idx != len(src)-1 { @@ -24,6 +27,7 @@ func (dst *CommandComplete) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *CommandComplete) Encode(dst []byte) []byte { dst = append(dst, 'C') sp := len(dst) @@ -37,6 +41,7 @@ func (src *CommandComplete) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *CommandComplete) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/copy_both_response.go b/copy_both_response.go index aa59d52a..b037a197 100644 --- a/copy_both_response.go +++ b/copy_both_response.go @@ -13,8 +13,11 @@ type CopyBothResponse struct { ColumnFormatCodes []uint16 } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*CopyBothResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CopyBothResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -39,6 +42,7 @@ func (dst *CopyBothResponse) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *CopyBothResponse) Encode(dst []byte) []byte { dst = append(dst, 'W') sp := len(dst) @@ -54,6 +58,7 @@ func (src *CopyBothResponse) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *CopyBothResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/copy_data.go b/copy_data.go index 490d3d80..317710ac 100644 --- a/copy_data.go +++ b/copy_data.go @@ -11,14 +11,20 @@ type CopyData struct { Data []byte } -func (*CopyData) Backend() {} +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CopyData) Backend() {} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*CopyData) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CopyData) Decode(src []byte) error { dst.Data = src return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *CopyData) Encode(dst []byte) []byte { dst = append(dst, 'd') dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) @@ -26,6 +32,7 @@ func (src *CopyData) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *CopyData) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/copy_done.go b/copy_done.go index 92481908..7612350a 100644 --- a/copy_done.go +++ b/copy_done.go @@ -7,8 +7,11 @@ import ( type CopyDone struct { } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*CopyDone) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CopyDone) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)} @@ -17,10 +20,12 @@ func (dst *CopyDone) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *CopyDone) Encode(dst []byte) []byte { return append(dst, 'c', 0, 0, 0, 4) } +// MarshalJSON implements encoding/json.Marshaler. func (src *CopyDone) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/copy_fail.go b/copy_fail.go index 2f228a82..b12d7ba0 100644 --- a/copy_fail.go +++ b/copy_fail.go @@ -11,8 +11,11 @@ type CopyFail struct { Message string } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*CopyFail) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CopyFail) Decode(src []byte) error { idx := bytes.IndexByte(src, 0) if idx != len(src)-1 { @@ -24,6 +27,7 @@ func (dst *CopyFail) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *CopyFail) Encode(dst []byte) []byte { dst = append(dst, 'f') sp := len(dst) @@ -37,6 +41,7 @@ func (src *CopyFail) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *CopyFail) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/copy_in_response.go b/copy_in_response.go index 3ddeeb40..d28baa33 100644 --- a/copy_in_response.go +++ b/copy_in_response.go @@ -13,8 +13,11 @@ type CopyInResponse struct { ColumnFormatCodes []uint16 } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*CopyInResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CopyInResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -39,6 +42,7 @@ func (dst *CopyInResponse) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *CopyInResponse) Encode(dst []byte) []byte { dst = append(dst, 'G') sp := len(dst) @@ -54,6 +58,7 @@ func (src *CopyInResponse) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *CopyInResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/copy_out_response.go b/copy_out_response.go index eb6fb50e..1d3c2364 100644 --- a/copy_out_response.go +++ b/copy_out_response.go @@ -15,6 +15,8 @@ type CopyOutResponse struct { func (*CopyOutResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CopyOutResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -39,6 +41,7 @@ func (dst *CopyOutResponse) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *CopyOutResponse) Encode(dst []byte) []byte { dst = append(dst, 'H') sp := len(dst) @@ -56,6 +59,7 @@ func (src *CopyOutResponse) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *CopyOutResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/data_row.go b/data_row.go index 0da18b06..9d6a9f1f 100644 --- a/data_row.go +++ b/data_row.go @@ -12,8 +12,11 @@ type DataRow struct { Values [][]byte } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*DataRow) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *DataRow) Decode(src []byte) error { if len(src) < 2 { return &invalidMessageFormatErr{messageType: "DataRow"} @@ -59,6 +62,7 @@ func (dst *DataRow) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *DataRow) Encode(dst []byte) []byte { dst = append(dst, 'D') sp := len(dst) @@ -80,6 +84,7 @@ func (src *DataRow) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *DataRow) MarshalJSON() ([]byte, error) { formattedValues := make([]map[string]string, len(src.Values)) for i, v := range src.Values { diff --git a/describe.go b/describe.go index 86016ebc..d3fb5b09 100644 --- a/describe.go +++ b/describe.go @@ -12,8 +12,11 @@ type Describe struct { Name string } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Describe) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Describe) Decode(src []byte) error { if len(src) < 2 { return &invalidMessageFormatErr{messageType: "Describe"} @@ -32,6 +35,7 @@ func (dst *Describe) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *Describe) Encode(dst []byte) []byte { dst = append(dst, 'D') sp := len(dst) @@ -46,6 +50,7 @@ func (src *Describe) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *Describe) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/doc.go b/doc.go index 75340210..8226dc98 100644 --- a/doc.go +++ b/doc.go @@ -1,2 +1,4 @@ // Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3. +// +// See https://www.postgresql.org/docs/current/protocol-message-formats.html for meanings of the different messages. package pgproto3 diff --git a/empty_query_response.go b/empty_query_response.go index d283b06d..1bec52e2 100644 --- a/empty_query_response.go +++ b/empty_query_response.go @@ -6,8 +6,11 @@ import ( type EmptyQueryResponse struct{} +// Backend identifies this message as sendable by the PostgreSQL backend. func (*EmptyQueryResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *EmptyQueryResponse) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} @@ -16,10 +19,12 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *EmptyQueryResponse) Encode(dst []byte) []byte { return append(dst, 'I', 0, 0, 0, 4) } +// MarshalJSON implements encoding/json.Marshaler. func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/error_response.go b/error_response.go index 987fe38a..d444798b 100644 --- a/error_response.go +++ b/error_response.go @@ -28,8 +28,11 @@ type ErrorResponse struct { UnknownFields map[byte]string } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*ErrorResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *ErrorResponse) Decode(src []byte) error { *dst = ErrorResponse{} @@ -103,6 +106,7 @@ func (dst *ErrorResponse) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *ErrorResponse) Encode(dst []byte) []byte { return append(dst, src.marshalBinary('E')...) } diff --git a/execute.go b/execute.go index 71713f49..32269857 100644 --- a/execute.go +++ b/execute.go @@ -13,8 +13,11 @@ type Execute struct { MaxRows uint32 } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Execute) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Execute) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -32,6 +35,7 @@ func (dst *Execute) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *Execute) Encode(dst []byte) []byte { dst = append(dst, 'E') sp := len(dst) @@ -47,6 +51,7 @@ func (src *Execute) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *Execute) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/flush.go b/flush.go index 7fd5e987..e7bc7e43 100644 --- a/flush.go +++ b/flush.go @@ -6,8 +6,11 @@ import ( type Flush struct{} +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Flush) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Flush) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)} @@ -16,10 +19,12 @@ func (dst *Flush) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *Flush) Encode(dst []byte) []byte { return append(dst, 'H', 0, 0, 0, 4) } +// MarshalJSON implements encoding/json.Marshaler. func (src *Flush) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/frontend.go b/frontend.go index a0a5b493..bcf796e1 100644 --- a/frontend.go +++ b/frontend.go @@ -7,6 +7,7 @@ import ( "github.com/pkg/errors" ) +// Frontend acts as a client for the PostgreSQL wire protocol version 3. type Frontend struct { cr ChunkReader w io.Writer @@ -41,15 +42,18 @@ type Frontend struct { partialMsg bool } +// NewFrontend creates a new Frontend. func NewFrontend(cr ChunkReader, w io.Writer) (*Frontend, error) { return &Frontend{cr: cr, w: w}, nil } +// Send sends a message to the backend. func (b *Frontend) Send(msg FrontendMessage) error { _, err := b.w.Write(msg.Encode(nil)) return err } +// Receive receives a message from the backend. func (b *Frontend) Receive() (BackendMessage, error) { if !b.partialMsg { header, err := b.cr.Next(5) diff --git a/function_call_response.go b/function_call_response.go index f14f8452..72bb907c 100644 --- a/function_call_response.go +++ b/function_call_response.go @@ -12,8 +12,11 @@ type FunctionCallResponse struct { Result []byte } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*FunctionCallResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *FunctionCallResponse) Decode(src []byte) error { if len(src) < 4 { return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} @@ -35,6 +38,7 @@ func (dst *FunctionCallResponse) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *FunctionCallResponse) Encode(dst []byte) []byte { dst = append(dst, 'V') sp := len(dst) @@ -52,6 +56,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) { var formattedValue map[string]string var hasNonPrintable bool diff --git a/no_data.go b/no_data.go index 1fb47c2a..172d0dc1 100644 --- a/no_data.go +++ b/no_data.go @@ -6,8 +6,11 @@ import ( type NoData struct{} +// Backend identifies this message as sendable by the PostgreSQL backend. func (*NoData) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *NoData) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} @@ -16,10 +19,12 @@ func (dst *NoData) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *NoData) Encode(dst []byte) []byte { return append(dst, 'n', 0, 0, 0, 4) } +// MarshalJSON implements encoding/json.Marshaler. func (src *NoData) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/notice_response.go b/notice_response.go index e4595aa5..4ac28a79 100644 --- a/notice_response.go +++ b/notice_response.go @@ -2,12 +2,16 @@ package pgproto3 type NoticeResponse ErrorResponse +// Backend identifies this message as sendable by the PostgreSQL backend. func (*NoticeResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *NoticeResponse) Decode(src []byte) error { return (*ErrorResponse)(dst).Decode(src) } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *NoticeResponse) Encode(dst []byte) []byte { return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) } diff --git a/notification_response.go b/notification_response.go index 2b32b10c..33170f66 100644 --- a/notification_response.go +++ b/notification_response.go @@ -14,8 +14,11 @@ type NotificationResponse struct { Payload string } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*NotificationResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *NotificationResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -37,6 +40,7 @@ func (dst *NotificationResponse) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *NotificationResponse) Encode(dst []byte) []byte { dst = append(dst, 'A') sp := len(dst) @@ -52,6 +56,7 @@ func (src *NotificationResponse) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *NotificationResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/parameter_description.go b/parameter_description.go index 9d964129..a43e802e 100644 --- a/parameter_description.go +++ b/parameter_description.go @@ -12,8 +12,11 @@ type ParameterDescription struct { ParameterOIDs []uint32 } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*ParameterDescription) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *ParameterDescription) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -35,6 +38,7 @@ func (dst *ParameterDescription) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *ParameterDescription) Encode(dst []byte) []byte { dst = append(dst, 't') sp := len(dst) @@ -50,6 +54,7 @@ func (src *ParameterDescription) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *ParameterDescription) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/parameter_status.go b/parameter_status.go index d370a4c1..4385fe99 100644 --- a/parameter_status.go +++ b/parameter_status.go @@ -12,8 +12,11 @@ type ParameterStatus struct { Value string } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*ParameterStatus) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *ParameterStatus) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -33,6 +36,7 @@ func (dst *ParameterStatus) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *ParameterStatus) Encode(dst []byte) []byte { dst = append(dst, 'S') sp := len(dst) @@ -48,6 +52,7 @@ func (src *ParameterStatus) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (ps *ParameterStatus) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/parse.go b/parse.go index 6f17175b..d0bbf865 100644 --- a/parse.go +++ b/parse.go @@ -14,8 +14,11 @@ type Parse struct { ParameterOIDs []uint32 } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Parse) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Parse) Decode(src []byte) error { *dst = Parse{} @@ -48,6 +51,7 @@ func (dst *Parse) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *Parse) Encode(dst []byte) []byte { dst = append(dst, 'P') sp := len(dst) @@ -68,6 +72,7 @@ func (src *Parse) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *Parse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/parse_complete.go b/parse_complete.go index 462a89ba..c2d3a34d 100644 --- a/parse_complete.go +++ b/parse_complete.go @@ -6,8 +6,11 @@ import ( type ParseComplete struct{} +// Backend identifies this message as sendable by the PostgreSQL backend. func (*ParseComplete) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *ParseComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} @@ -16,10 +19,12 @@ func (dst *ParseComplete) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *ParseComplete) Encode(dst []byte) []byte { return append(dst, '1', 0, 0, 0, 4) } +// MarshalJSON implements encoding/json.Marshaler. func (src *ParseComplete) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/password_message.go b/password_message.go index 30377cbe..b01316e9 100644 --- a/password_message.go +++ b/password_message.go @@ -11,8 +11,11 @@ type PasswordMessage struct { Password string } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*PasswordMessage) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *PasswordMessage) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -25,6 +28,7 @@ func (dst *PasswordMessage) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *PasswordMessage) Encode(dst []byte) []byte { dst = append(dst, 'p') dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) @@ -35,6 +39,7 @@ func (src *PasswordMessage) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *PasswordMessage) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/portal_suspended.go b/portal_suspended.go index dc81b027..5603d95e 100644 --- a/portal_suspended.go +++ b/portal_suspended.go @@ -6,8 +6,11 @@ import ( type PortalSuspended struct{} +// Backend identifies this message as sendable by the PostgreSQL backend. func (*PortalSuspended) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *PortalSuspended) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "PortalSuspended", expectedLen: 0, actualLen: len(src)} @@ -16,10 +19,12 @@ func (dst *PortalSuspended) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *PortalSuspended) Encode(dst []byte) []byte { return append(dst, 's', 0, 0, 0, 4) } +// MarshalJSON implements encoding/json.Marshaler. func (src *PortalSuspended) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/query.go b/query.go index 16228cb4..17377dfb 100644 --- a/query.go +++ b/query.go @@ -11,8 +11,11 @@ type Query struct { String string } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Query) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Query) Decode(src []byte) error { i := bytes.IndexByte(src, 0) if i != len(src)-1 { @@ -24,6 +27,7 @@ func (dst *Query) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *Query) Encode(dst []byte) []byte { dst = append(dst, 'Q') dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) @@ -34,6 +38,7 @@ func (src *Query) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *Query) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/ready_for_query.go b/ready_for_query.go index 63b902bd..65f7d8c1 100644 --- a/ready_for_query.go +++ b/ready_for_query.go @@ -8,8 +8,11 @@ type ReadyForQuery struct { TxStatus byte } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*ReadyForQuery) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *ReadyForQuery) Decode(src []byte) error { if len(src) != 1 { return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} @@ -20,10 +23,12 @@ func (dst *ReadyForQuery) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *ReadyForQuery) Encode(dst []byte) []byte { return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) } +// MarshalJSON implements encoding/json.Marshaler. func (src *ReadyForQuery) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/row_description.go b/row_description.go index 1b1734dc..87479188 100644 --- a/row_description.go +++ b/row_description.go @@ -27,8 +27,11 @@ type RowDescription struct { Fields []FieldDescription } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*RowDescription) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *RowDescription) Decode(src []byte) error { if len(src) < 2 { @@ -74,6 +77,7 @@ func (dst *RowDescription) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *RowDescription) Encode(dst []byte) []byte { dst = append(dst, 'T') sp := len(dst) @@ -97,6 +101,7 @@ func (src *RowDescription) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *RowDescription) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/sasl_initial_response.go b/sasl_initial_response.go index 63766131..b9459e16 100644 --- a/sasl_initial_response.go +++ b/sasl_initial_response.go @@ -14,8 +14,11 @@ type SASLInitialResponse struct { Data []byte } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*SASLInitialResponse) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *SASLInitialResponse) Decode(src []byte) error { *dst = SASLInitialResponse{} @@ -35,6 +38,7 @@ func (dst *SASLInitialResponse) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *SASLInitialResponse) Encode(dst []byte) []byte { dst = append(dst, 'p') sp := len(dst) @@ -51,6 +55,7 @@ func (src *SASLInitialResponse) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *SASLInitialResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/sasl_response.go b/sasl_response.go index 1e8d3bd3..ef893437 100644 --- a/sasl_response.go +++ b/sasl_response.go @@ -11,13 +11,17 @@ type SASLResponse struct { Data []byte } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*SASLResponse) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *SASLResponse) Decode(src []byte) error { *dst = SASLResponse{Data: src} return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *SASLResponse) Encode(dst []byte) []byte { dst = append(dst, 'p') dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) @@ -27,6 +31,7 @@ func (src *SASLResponse) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *SASLResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/startup_message.go b/startup_message.go index 93a3d992..d9f04d17 100644 --- a/startup_message.go +++ b/startup_message.go @@ -19,8 +19,11 @@ type StartupMessage struct { Parameters map[string]string } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*StartupMessage) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *StartupMessage) Decode(src []byte) error { if len(src) < 4 { return errors.Errorf("startup message too short") @@ -66,6 +69,7 @@ func (dst *StartupMessage) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *StartupMessage) Encode(dst []byte) []byte { sp := len(dst) dst = pgio.AppendInt32(dst, -1) @@ -84,6 +88,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. func (src *StartupMessage) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/sync.go b/sync.go index 85f4749a..a058e8c9 100644 --- a/sync.go +++ b/sync.go @@ -6,8 +6,11 @@ import ( type Sync struct{} +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Sync) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Sync) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)} @@ -16,10 +19,12 @@ func (dst *Sync) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *Sync) Encode(dst []byte) []byte { return append(dst, 'S', 0, 0, 0, 4) } +// MarshalJSON implements encoding/json.Marshaler. func (src *Sync) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/terminate.go b/terminate.go index 0a3310da..6c9d5b1a 100644 --- a/terminate.go +++ b/terminate.go @@ -6,8 +6,11 @@ import ( type Terminate struct{} +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Terminate) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Terminate) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)} @@ -16,10 +19,12 @@ func (dst *Terminate) Decode(src []byte) error { return nil } +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. func (src *Terminate) Encode(dst []byte) []byte { return append(dst, 'X', 0, 0, 0, 4) } +// MarshalJSON implements encoding/json.Marshaler. func (src *Terminate) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string From 529805557f0334621e115ca0dabf6dcf9b5a38bb Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 22 Jun 2019 10:41:01 +0300 Subject: [PATCH 0292/1158] Fix linters notifications Signed-off-by: Artemiy Ryabinkov --- auth_scram.go | 2 +- benchmark_test.go | 15 ++++++++------- internal/ctxwatch/context_watcher_test.go | 3 +++ pgconn.go | 18 +++++++++--------- pgconn_test.go | 11 ++++++----- 5 files changed, 27 insertions(+), 22 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index d102d305..bdaf3e92 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -249,7 +249,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte { func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { serverKey := computeHMAC(saltedPassword, []byte("Server Key")) - serverSignature := computeHMAC(serverKey[:], authMessage) + serverSignature := computeHMAC(serverKey, authMessage) buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) base64.StdEncoding.Encode(buf, serverSignature) return buf diff --git a/benchmark_test.go b/benchmark_test.go index 51e11e24..8067c985 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -20,6 +20,7 @@ func BenchmarkConnect(b *testing.B) { } for _, bm := range benchmarks { + bm := bm b.Run(bm.name, func(b *testing.B) { connString := os.Getenv(bm.env) if connString == "" { @@ -54,12 +55,12 @@ func BenchmarkExec(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -101,12 +102,12 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -145,12 +146,12 @@ func BenchmarkExecPrepared(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -191,7 +192,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go index 0b491bf8..a1b3c863 100644 --- a/internal/ctxwatch/context_watcher_test.go +++ b/internal/ctxwatch/context_watcher_test.go @@ -87,6 +87,9 @@ func TestContextWatcherStress(t *testing.T) { if i%2 == 1 { cancel() } + + // To avoid context leak + cancel() } actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls) diff --git a/pgconn.go b/pgconn.go index c51742ae..9e4f6253 100644 --- a/pgconn.go +++ b/pgconn.go @@ -241,16 +241,16 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return nil } -func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { +func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { switch msg.Type { case pgproto3.AuthTypeOk: case pgproto3.AuthTypeCleartextPassword: - err = c.txPasswordMessage(c.Config.Password) + err = pgConn.txPasswordMessage(pgConn.Config.Password) case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:])) - err = c.txPasswordMessage(digestedPassword) + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) case pgproto3.AuthTypeSASL: - err = c.scramAuth(msg.SASLAuthMechanisms) + err = pgConn.scramAuth(msg.SASLAuthMechanisms) default: err = errors.New("Received unknown authentication message") } @@ -514,11 +514,11 @@ readloop: func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ - Severity: string(msg.Severity), + Severity: msg.Severity, Code: string(msg.Code), Message: string(msg.Message), Detail: string(msg.Detail), - Hint: string(msg.Hint), + Hint: msg.Hint, Position: msg.Position, InternalPosition: msg.InternalPosition, InternalQuery: string(msg.InternalQuery), @@ -527,7 +527,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { TableName: string(msg.TableName), ColumnName: string(msg.ColumnName), DataTypeName: string(msg.DataTypeName), - ConstraintName: string(msg.ConstraintName), + ConstraintName: msg.ConstraintName, File: string(msg.File), Line: msg.Line, Routine: string(msg.Routine), @@ -919,7 +919,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co copyDone := &pgproto3.CopyDone{} buf = copyDone.Encode(buf) } else { - copyFail := &pgproto3.CopyFail{Error: readErr.Error()} + copyFail := &pgproto3.CopyFail{Message: readErr.Error()} buf = copyFail.Encode(buf) } _, err = pgConn.conn.Write(buf) diff --git a/pgconn_test.go b/pgconn_test.go index 310b387b..4389fe99 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -37,6 +37,7 @@ func TestConnect(t *testing.T) { } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { connString := os.Getenv(tt.env) if connString == "" { @@ -194,13 +195,13 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { dialCount := 0 config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount += 1 + dialCount++ return net.Dial(network, address) } acceptConnCount := 0 config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount += 1 + acceptConnCount++ if acceptConnCount < 2 { return errors.New("reject first conn") } @@ -302,7 +303,7 @@ func TestConnExecEmpty(t *testing.T) { resultCount := 0 for multiResult.NextResult() { - resultCount += 1 + resultCount++ multiResult.ResultReader().Close() } assert.Equal(t, 0, resultCount) @@ -753,12 +754,12 @@ func TestConnLocking(t *testing.T) { defer closeConn(t, pgConn) mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) - results, err = mrr.ReadAll() + results, err := mrr.ReadAll() assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err) From 54ce9c6bb807f53394115ce7849f9e083aea095a Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 22 Jun 2019 14:35:17 +0300 Subject: [PATCH 0293/1158] Update pgproto3 dependency Signed-off-by: Artemiy Ryabinkov --- .gitignore | 1 + go.mod | 2 +- go.sum | 4 ++++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 7a6353d6..6eb9d442 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .envrc +vendor/ \ No newline at end of file diff --git a/go.mod b/go.mod index 9401dce8..b1c84049 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.12 require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db + github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index 1b6862a0..50dfc2fd 100644 --- a/go.sum +++ b/go.sum @@ -2,12 +2,16 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= From 07904bd774d7009cd55030606fe5d30e1329c6c0 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 22 Jun 2019 20:09:55 +0300 Subject: [PATCH 0294/1158] Remove unnecassary ctx cancel Signed-off-by: Artemiy Ryabinkov --- internal/ctxwatch/context_watcher_test.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go index a1b3c863..0b491bf8 100644 --- a/internal/ctxwatch/context_watcher_test.go +++ b/internal/ctxwatch/context_watcher_test.go @@ -87,9 +87,6 @@ func TestContextWatcherStress(t *testing.T) { if i%2 == 1 { cancel() } - - // To avoid context leak - cancel() } actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls) From d2440c7fe62ef4c392bbae34eef01e4a6865ed03 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jun 2019 16:54:10 -0500 Subject: [PATCH 0295/1158] Improve documentation --- README.md | 24 ++++++++++++++++++++++-- config.go | 7 +++++++ pgconn.go | 2 +- 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 05cfedf1..9e35a0f5 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,29 @@ # pgconn -Package pgconn is a low-level PostgreSQL database driver. +Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level is the C library libpq. +It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx. +Applications should handle normal queries with a higher level library and only use pgconn directly when required for +low-level access to PostgreSQL functionality. -It is intended to serve as the foundation for the next generation of https://github.com/jackc/pgx. +## Example Usage + +```go +pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) +if err != nil { + log.Fatalln("pgconn failed to connect:", err) +} +defer pgConn.Close() + +result := pgConn.ExecParams(context.Background(), "select email from users where id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +for result.NextRow() { + fmt.Println("User 123 has email:", string(result.Values()[0])) +} +_, err := result.Close() +if err != nil { + log.Fatalln("failed reading result:", err) +}) +``` ## Testing diff --git a/config.go b/config.go index c751cc0d..98755b1f 100644 --- a/config.go +++ b/config.go @@ -121,6 +121,13 @@ func NetworkAddress(host string, port uint16) (network, address string) { // security guarantees than it would with libpq. Do not rely on this behavior as it // may be possible to match libpq in the future. If you need full security use // "verify-full". +// +// Other known differences with libpq: +// +// If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first. +// +// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn +// does not. func ParseConfig(connString string) (*Config, error) { settings := defaultSettings() addEnvSettings(settings) diff --git a/pgconn.go b/pgconn.go index 9e4f6253..3deb8563 100644 --- a/pgconn.go +++ b/pgconn.go @@ -390,7 +390,7 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } -// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of +// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect death of // underlying connection. func (pgConn *PgConn) IsAlive() bool { return pgConn.status >= connStatusIdle From 59941377c8ff1467e4805f20e7aef29201e72e2c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Jul 2019 09:52:22 -0500 Subject: [PATCH 0296/1158] Rename Config.AfterConnectFunc to AfterConnect No need to include the type in the name. --- config.go | 6 +++--- config_test.go | 18 +++++++++--------- pgconn.go | 6 +++--- pgconn_test.go | 6 +++--- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/config.go b/config.go index 98755b1f..533791c2 100644 --- a/config.go +++ b/config.go @@ -36,10 +36,10 @@ type Config struct { Fallbacks []*FallbackConfig - // AfterConnectFunc is called after successful connection. It can be used to set up the connection or to validate that + // AfterConnect is called after successful connection. It can be used to set up the connection or to validate that // server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This // allows implementing high availability behavior such as libpq does with target_session_attrs. - AfterConnectFunc AfterConnectFunc + AfterConnect AfterConnectFunc // OnNotice is a callback function called when a notice response is received. OnNotice NoticeHandler @@ -245,7 +245,7 @@ func ParseConfig(connString string) (*Config, error) { } if settings["target_session_attrs"] == "read-write" { - config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite + config.AfterConnect = AfterConnectTargetSessionAttrsReadWrite } else if settings["target_session_attrs"] != "any" { return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) } diff --git a/config_test.go b/config_test.go index ce6f3957..b222d8cc 100644 --- a/config_test.go +++ b/config_test.go @@ -378,14 +378,14 @@ func TestParseConfig(t *testing.T) { name: "target_session_attrs", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - AfterConnectFunc: pgconn.AfterConnectTargetSessionAttrsReadWrite, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + AfterConnect: pgconn.AfterConnectTargetSessionAttrsReadWrite, }, }, } @@ -416,7 +416,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) // Can't test function equality, so just test that they are set or not. - assert.Equalf(t, expected.AfterConnectFunc == nil, actual.AfterConnectFunc == nil, "%s - AfterConnectFunc", testName) + assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { if expected.TLSConfig != nil { diff --git a/pgconn.go b/pgconn.go index 3deb8563..2db35587 100644 --- a/pgconn.go +++ b/pgconn.go @@ -201,11 +201,11 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle - if config.AfterConnectFunc != nil { - err := config.AfterConnectFunc(ctx, pgConn) + if config.AfterConnect != nil { + err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("AfterConnectFunc: %v", err) + return nil, errors.Errorf("AfterConnect: %v", err) } } return pgConn, nil diff --git a/pgconn_test.go b/pgconn_test.go index 4389fe99..028d5e94 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -187,7 +187,7 @@ func TestConnectWithFallback(t *testing.T) { closeConn(t, conn) } -func TestConnectWithAfterConnectFunc(t *testing.T) { +func TestConnectWithAfterConnect(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) @@ -200,7 +200,7 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { } acceptConnCount := 0 - config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error { + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { acceptConnCount++ if acceptConnCount < 2 { return errors.New("reject first conn") @@ -232,7 +232,7 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite + config.AfterConnect = pgconn.AfterConnectTargetSessionAttrsReadWrite config.RuntimeParams["default_transaction_read_only"] = "on" conn, err := pgconn.ConnectConfig(context.Background(), config) From 3dec1848118789c4430914ca04d2f6fd0542c3d9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Jul 2019 10:22:09 -0500 Subject: [PATCH 0297/1158] Split ValidateConnect from AfterConnect This avoids the foot-gun of ParseConfig setting AfterConnect because of target_session_attrs and the user inadvertently overriding it with an AfterConnect designed to setup the connection. Now target_session_attrs will be handled with ValidateConnect. --- config.go | 18 ++++++++++++------ config_test.go | 17 +++++++++-------- pgconn.go | 22 +++++++++++++++++----- pgconn_test.go | 29 +++++++++++++++++++++++++---- 4 files changed, 63 insertions(+), 23 deletions(-) diff --git a/config.go b/config.go index 533791c2..9b74945e 100644 --- a/config.go +++ b/config.go @@ -22,6 +22,7 @@ import ( ) type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error +type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { @@ -36,9 +37,14 @@ type Config struct { Fallbacks []*FallbackConfig - // AfterConnect is called after successful connection. It can be used to set up the connection or to validate that - // server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This - // allows implementing high availability behavior such as libpq does with target_session_attrs. + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. + // It can be used validate that server is acceptable. If this returns an error the connection is closed and the next + // fallback config is tried. This allows implementing high availability behavior such as libpq does with + // target_session_attrs. + ValidateConnect ValidateConnectFunc + + // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables + // or prepare statements). If this returns an error the connection attempt fails. AfterConnect AfterConnectFunc // OnNotice is a callback function called when a notice response is received. @@ -245,7 +251,7 @@ func ParseConfig(connString string) (*Config, error) { } if settings["target_session_attrs"] == "read-write" { - config.AfterConnect = AfterConnectTargetSessionAttrsReadWrite + config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite } else if settings["target_session_attrs"] != "any" { return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) } @@ -481,9 +487,9 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { return d.DialContext, nil } -// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible +// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-write. -func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { +func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() if result.Err != nil { return result.Err diff --git a/config_test.go b/config_test.go index b222d8cc..23d86529 100644 --- a/config_test.go +++ b/config_test.go @@ -378,14 +378,14 @@ func TestParseConfig(t *testing.T) { name: "target_session_attrs", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - AfterConnect: pgconn.AfterConnectTargetSessionAttrsReadWrite, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite, }, }, } @@ -416,6 +416,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { diff --git a/pgconn.go b/pgconn.go index 2db35587..6e1fb7e3 100644 --- a/pgconn.go +++ b/pgconn.go @@ -122,13 +122,25 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err for _, fc := range fallbackConfigs { pgConn, err = connect(ctx, config, fc) if err == nil { - return pgConn, nil + break } else if err, ok := err.(*PgError); ok { return nil, err } } - return nil, err + if err != nil { + return nil, err + } + + if config.AfterConnect != nil { + err := config.AfterConnect(ctx, pgConn) + if err != nil { + pgConn.conn.Close() + return nil, errors.Errorf("AfterConnect: %v", err) + } + } + + return pgConn, nil } func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { @@ -201,11 +213,11 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle - if config.AfterConnect != nil { - err := config.AfterConnect(ctx, pgConn) + if config.ValidateConnect != nil { + err := config.ValidateConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("AfterConnect: %v", err) + return nil, errors.Errorf("ValidateConnect: %v", err) } } return pgConn, nil diff --git a/pgconn_test.go b/pgconn_test.go index 028d5e94..feb78641 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -187,7 +187,7 @@ func TestConnectWithFallback(t *testing.T) { closeConn(t, conn) } -func TestConnectWithAfterConnect(t *testing.T) { +func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) @@ -200,7 +200,7 @@ func TestConnectWithAfterConnect(t *testing.T) { } acceptConnCount := 0 - config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { acceptConnCount++ if acceptConnCount < 2 { return errors.New("reject first conn") @@ -226,13 +226,13 @@ func TestConnectWithAfterConnect(t *testing.T) { assert.True(t, acceptConnCount > 1) } -func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { +func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - config.AfterConnect = pgconn.AfterConnectTargetSessionAttrsReadWrite + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite config.RuntimeParams["default_transaction_read_only"] = "on" conn, err := pgconn.ConnectConfig(context.Background(), config) @@ -241,6 +241,27 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } } +func TestConnectWithAfterConnect(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + + results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) + + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) +} + func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() From fa7e06489bda50794a89e7a6e60446c4cc1c2ba5 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Fri, 26 Jul 2019 11:14:07 +0300 Subject: [PATCH 0298/1158] Add MinReadBufferSize option to Config Signed-off-by: Artemiy Ryabinkov --- config.go | 3 +++ pgconn.go | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 9b74945e..bbd458e3 100644 --- a/config.go +++ b/config.go @@ -37,6 +37,9 @@ type Config struct { Fallbacks []*FallbackConfig + // MinReadBufferSize used to configure size of connection read buffer. + MinReadBufferSize int + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used validate that server is acceptable. If this returns an error the connection is closed and the next // fallback config is tried. This allows implementing high availability behavior such as libpq does with diff --git a/pgconn.go b/pgconn.go index 6e1fb7e3..5077ccae 100644 --- a/pgconn.go +++ b/pgconn.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "github.com/jackc/chunkreader/v2" "github.com/jackc/pgconn/internal/ctxwatch" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" @@ -170,7 +171,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) - pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn) + cr, err := chunkreader.NewConfig(pgConn.conn, chunkreader.Config{MinBufLen: config.MinReadBufferSize}) + if err != nil { + return nil, err + } + + pgConn.Frontend, err = pgproto3.NewFrontend(cr, pgConn.conn) if err != nil { return nil, err } From f0b479097a4868d74e83c938131f5a24d25c49e8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 6 Aug 2019 17:07:11 -0500 Subject: [PATCH 0299/1158] Fix missing deferred constraint violations in certain conditions See https://github.com/jackc/pgx/issues/570. --- pgconn.go | 5 ++- pgconn_test.go | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 6e1fb7e3..3157f17e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1151,7 +1151,10 @@ func (rr *ResultReader) Close() (CommandTag, error) { return nil, rr.err } - switch msg.(type) { + switch msg := msg.(type) { + // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. + case *pgproto3.ErrorResponse: + rr.err = errorResponseToPgError(msg) case *pgproto3.ReadyForQuery: rr.pgConn.contextWatcher.Unwatch() rr.pgConn.unlock() diff --git a/pgconn_test.go b/pgconn_test.go index feb78641..1b90b9d2 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -381,6 +381,34 @@ func TestConnExecMultipleQueriesError(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) + + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecContextCanceled(t *testing.T) { t.Parallel() @@ -437,6 +465,33 @@ func TestConnExecParams(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecParamsDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() @@ -683,6 +738,36 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } +func TestConnExecBatchDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecBatchPrecanceled(t *testing.T) { t.Parallel() From 0a99b543c007eab4dd3eb284e0206eb7d8144346 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 8 Aug 2019 11:46:25 +0300 Subject: [PATCH 0300/1158] Add BuildFrontendFunc in Config Signed-off-by: Artemiy Ryabinkov --- config.go | 30 +++++++++++++++++++----------- go.sum | 4 ---- pgconn.go | 32 +++++++++++++++++--------------- 3 files changed, 36 insertions(+), 30 deletions(-) diff --git a/config.go b/config.go index bbd458e3..be8bdab4 100644 --- a/config.go +++ b/config.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io" "io/ioutil" "math" "net" @@ -18,6 +19,7 @@ import ( "time" "github.com/jackc/pgpassfile" + "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" ) @@ -26,20 +28,18 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 - Database string - User string - Password string - TLSConfig *tls.Config // nil disables TLS - DialFunc DialFunc // e.g. net.Dialer.DialContext - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + DialFunc DialFunc // e.g. net.Dialer.DialContext + BuildFrontendFunc BuildFrontendFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) Fallbacks []*FallbackConfig - // MinReadBufferSize used to configure size of connection read buffer. - MinReadBufferSize int - // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used validate that server is acceptable. If this returns an error the connection is closed and the next // fallback config is tried. This allows implementing high availability behavior such as libpq does with @@ -476,6 +476,14 @@ func makeDefaultDialer() *net.Dialer { return &net.Dialer{KeepAlive: 5 * time.Minute} } +func makeDefaultBuildFrontendFunc() BuildFrontendFunc { + return func(r io.Reader) Frontend { + frontend, _ := pgproto3.NewFrontend(pgproto3.NewChunkReader(r), nil) + + return frontend + } +} + func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { timeout, err := strconv.ParseInt(s, 10, 64) if err != nil { diff --git a/go.sum b/go.sum index 50dfc2fd..0e853203 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,6 @@ github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= @@ -25,7 +23,5 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqY golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pgconn.go b/pgconn.go index 5077ccae..e7833c1f 100644 --- a/pgconn.go +++ b/pgconn.go @@ -15,7 +15,6 @@ import ( "sync" "time" - "github.com/jackc/chunkreader/v2" "github.com/jackc/pgconn/internal/ctxwatch" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" @@ -41,9 +40,12 @@ type Notification struct { Payload string } -// DialFunc is a function that can be used to connect to a PostgreSQL server +// DialFunc is a function that can be used to connect to a PostgreSQL server. type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) +// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. +type BuildFrontendFunc func(r io.Reader) Frontend + // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY @@ -56,6 +58,11 @@ type NoticeHandler func(*PgConn, *Notice) // notice event. type NotificationHandler func(*PgConn, *Notification) +// Frontend used to receive messages from backend. +type Frontend interface { + Receive() (pgproto3.BackendMessage, error) +} + // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection @@ -63,7 +70,7 @@ type PgConn struct { secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server TxStatus byte - Frontend *pgproto3.Frontend + frontend Frontend Config *Config @@ -106,6 +113,9 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err if config.DialFunc == nil { config.DialFunc = makeDefaultDialer().DialContext } + if config.BuildFrontendFunc == nil { + config.BuildFrontendFunc = makeDefaultBuildFrontendFunc() + } if config.RuntimeParams == nil { config.RuntimeParams = make(map[string]string) } @@ -171,15 +181,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) - cr, err := chunkreader.NewConfig(pgConn.conn, chunkreader.Config{MinBufLen: config.MinReadBufferSize}) - if err != nil { - return nil, err - } - - pgConn.Frontend, err = pgproto3.NewFrontend(cr, pgConn.conn) - if err != nil { - return nil, err - } + pgConn.frontend = config.BuildFrontendFunc(pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, @@ -298,7 +300,7 @@ func (pgConn *PgConn) signalMessage() chan struct{} { ch := make(chan struct{}) go func() { - pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive() + pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() pgConn.bufferingReceiveMux.Unlock() close(ch) }() @@ -318,10 +320,10 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { // If a timeout error happened in the background try the read again. if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - msg, err = pgConn.Frontend.Receive() + msg, err = pgConn.frontend.Receive() } } else { - msg, err = pgConn.Frontend.Receive() + msg, err = pgConn.frontend.Receive() } if err != nil { From dbb7aa8fd51b866cf601df8daf11306a9bb7c707 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 8 Aug 2019 12:52:04 +0300 Subject: [PATCH 0301/1158] Add GOPROXY to travis builds to mitigate problems with github and etc Signed-off-by: Artemiy Ryabinkov --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index e5ed43a8..1687adad 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,6 +11,7 @@ before_install: env: global: - GO111MODULE=on + - GOPROXY=https://proxy.golang.org - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test From c9660e30c8b4f7903eaa7814789656ea79b6d173 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 8 Aug 2019 13:12:27 +0300 Subject: [PATCH 0302/1158] Use go mod download to install deps on travis-ci. Add cache for travis-ci. Signed-off-by: Artemiy Ryabinkov --- .travis.yml | 12 ++++++++++-- travis/install.bash | 14 -------------- 2 files changed, 10 insertions(+), 16 deletions(-) delete mode 100755 travis/install.bash diff --git a/.travis.yml b/.travis.yml index 1687adad..2c547abf 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,6 +4,9 @@ go: - 1.x - tip +git: + depth: 1 + # Derived from https://github.com/lib/pq/blob/master/.travis.yml before_install: - ./travis/before_install.bash @@ -12,6 +15,7 @@ env: global: - GO111MODULE=on - GOPROXY=https://proxy.golang.org + - GOFLAGS=-mod=readonly - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test @@ -26,11 +30,15 @@ env: - PGVERSION=9.4 - PGVERSION=9.3 +cache: + directories: + - $HOME/.cache/go-build + - $HOME/gopath/pkg/mod + before_script: - ./travis/before_script.bash -install: - - ./travis/install.bash +install: go mod download script: - ./travis/script.bash diff --git a/travis/install.bash b/travis/install.bash deleted file mode 100755 index 63ba875d..00000000 --- a/travis/install.bash +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env bash -set -eux - -go get -u github.com/cockroachdb/apd -go get -u github.com/shopspring/decimal -go get -u gopkg.in/inconshreveable/log15.v2 -go get -u github.com/jackc/fake -go get -u github.com/lib/pq -go get -u github.com/hashicorp/go-version -go get -u github.com/satori/go.uuid -go get -u github.com/sirupsen/logrus -go get -u github.com/pkg/errors -go get -u go.uber.org/zap -go get -u github.com/rs/zerolog From f76af93c210584fb9c059bff875060e720c6004d Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 8 Aug 2019 13:41:51 +0300 Subject: [PATCH 0303/1158] Increase buffer size to 8KB Signed-off-by: Artemiy Ryabinkov --- chunkreader.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chunkreader.go b/chunkreader.go index 36304fd5..c45b75aa 100644 --- a/chunkreader.go +++ b/chunkreader.go @@ -39,7 +39,7 @@ func New(r io.Reader) *ChunkReader { // NewConfig creates and a new ChunkReader for r configured by config. func NewConfig(r io.Reader, config Config) (*ChunkReader, error) { if config.MinBufLen == 0 { - config.MinBufLen = 4096 + config.MinBufLen = 8192 } return &ChunkReader{ From e204afcc8c18b630476abe8e28032fe2b5762825 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 8 Aug 2019 13:43:26 +0300 Subject: [PATCH 0304/1158] Add explanation for default buffer size Signed-off-by: Artemiy Ryabinkov --- chunkreader.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/chunkreader.go b/chunkreader.go index c45b75aa..afea1c52 100644 --- a/chunkreader.go +++ b/chunkreader.go @@ -39,6 +39,10 @@ func New(r io.Reader) *ChunkReader { // NewConfig creates and a new ChunkReader for r configured by config. func NewConfig(r io.Reader, config Config) (*ChunkReader, error) { if config.MinBufLen == 0 { + // By historical reasons Postgres currently has 8KB send buffer inside, + // so here we want to have at least the same size buffer. + // @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134 + // @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru config.MinBufLen = 8192 } From bcc139a365fd09c93c5bfe50fef33e852796c57e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 17 Aug 2019 13:30:41 -0500 Subject: [PATCH 0305/1158] Port fc020c24ac9590f6547f8ad1d291fc75b4873a84 from pgx v3 commit fc020c24ac9590f6547f8ad1d291fc75b4873a84 Author: Nicholas Wilson Date: Wed Jul 24 12:32:18 2019 +0100 Add support for pgtype.UUID to write into any [16]byte type --- convert.go | 29 +++++++++++++++++++++++++++++ uuid.go | 2 +- uuid_test.go | 21 +++++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/convert.go b/convert.go index 98999d45..2fd840fc 100644 --- a/convert.go +++ b/convert.go @@ -163,6 +163,27 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { return time.Time{}, false } +// underlyingUUIDType gets the underlying type that can be converted to [16]byte +func underlyingUUIDType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return time.Time{}, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + uuidType := reflect.TypeOf([16]byte{}) + if refVal.Type().ConvertibleTo(uuidType) { + return refVal.Convert(uuidType).Interface(), true + } + + return nil, false +} + // underlyingSliceType gets the underlying slice type func underlyingSliceType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) @@ -401,6 +422,14 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { } } + if dstVal.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + baseArrayType := reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType)) + nextDst := dstPtr.Convert(baseArrayType) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + } + return nil, false } diff --git a/uuid.go b/uuid.go index d3e68f5c..5dd10d89 100644 --- a/uuid.go +++ b/uuid.go @@ -39,7 +39,7 @@ func (dst *UUID) Set(src interface{}) error { } *dst = UUID{Bytes: uuid, Status: Present} default: - if originalSrc, ok := underlyingPtrType(src); ok { + if originalSrc, ok := underlyingUUIDType(src); ok { return dst.Set(originalSrc) } return errors.Errorf("cannot convert %v to UUID", value) diff --git a/uuid_test.go b/uuid_test.go index 49190168..f0480f9a 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -15,6 +15,8 @@ func TestUUIDTranscode(t *testing.T) { }) } +type SomeUUIDType [16]byte + func TestUUIDSet(t *testing.T) { successfulTests := []struct { source interface{} @@ -32,6 +34,10 @@ func TestUUIDSet(t *testing.T) { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, + { + source: SomeUUIDType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, { source: ([]byte)(nil), result: pgtype.UUID{Status: pgtype.Null}, @@ -86,6 +92,21 @@ func TestUUIDAssignTo(t *testing.T) { } } + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst SomeUUIDType + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + { src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst string From 9010c554edc5b0d65eb6cb48d735a06edc65d4c4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 17 Aug 2019 13:33:34 -0500 Subject: [PATCH 0306/1158] Port 251e6b7730c7b31b600e6fe06162e541f3032604 from pgx v3 commit 251e6b7730c7b31b600e6fe06162e541f3032604 Author: Nicholas Wilson Date: Wed Jul 24 12:32:43 2019 +0100 Tidying: make underlyingTimeType consistent with other underlyingFooType The first return value is ignored when returning false - so there's no point returning an empty time.Time when it can be nil. --- convert.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/convert.go b/convert.go index 2fd840fc..cc5c10ab 100644 --- a/convert.go +++ b/convert.go @@ -149,7 +149,7 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { switch refVal.Kind() { case reflect.Ptr: if refVal.IsNil() { - return time.Time{}, false + return nil, false } convVal := refVal.Elem().Interface() return convVal, true @@ -160,7 +160,7 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { return refVal.Convert(timeType).Interface(), true } - return time.Time{}, false + return nil, false } // underlyingUUIDType gets the underlying type that can be converted to [16]byte From d364370a31359546fb19828f737073b19a56f812 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 20 Aug 2019 14:11:16 -0500 Subject: [PATCH 0307/1158] Add SendBytes and ReceiveMessage --- auth_scram.go | 2 +- pgconn.go | 77 +++++++++++++++++++++++++++++++++++++++++++------- pgconn_test.go | 40 ++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 11 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index bdaf3e92..4409a080 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -74,7 +74,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { - msg, err := c.ReceiveMessage() + msg, err := c.receiveMessage() if err != nil { return nil, err } diff --git a/pgconn.go b/pgconn.go index 63e19ed1..abbc2d10 100644 --- a/pgconn.go +++ b/pgconn.go @@ -204,7 +204,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.conn.Close() return nil, err @@ -308,7 +308,64 @@ func (pgConn *PgConn) signalMessage() chan struct{} { return ch } -func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { +// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as +// error to call SendBytes while reading the result of a query. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { + if err := pgConn.lock(); err != nil { + return linkErrors(err, ErrNoBytesSent) + } + defer pgConn.unlock() + + select { + case <-ctx.Done(): + return linkErrors(ctx.Err(), ErrNoBytesSent) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.hardClose() + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } + return linkErrors(ctx.Err(), err) + } + + return nil +} + +// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the +// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages +// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger +// the OnNotification callback. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { + if err := pgConn.lock(); err != nil { + return nil, linkErrors(err, ErrNoBytesSent) + } + defer pgConn.unlock() + + select { + case <-ctx.Done(): + return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + msg, err := pgConn.receiveMessage() + return msg, err +} + +// receiveMessage receives a message without setting up context cancellation +func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { var msg pgproto3.BackendMessage var err error if pgConn.bufferingReceive { @@ -506,7 +563,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ readloop: for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -616,7 +673,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { defer pgConn.contextWatcher.Unwatch() for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { return linkErrors(ctx.Err(), err) } @@ -821,7 +878,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm var commandTag CommandTag var pgErr error for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -882,7 +939,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co var pgErr error pendingCopyInResponse := true for pendingCopyInResponse { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -920,7 +977,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co select { case <-signalMessageChan: - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -950,7 +1007,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // Read results for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -991,7 +1048,7 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { } func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { - msg, err := mrr.pgConn.ReceiveMessage() + msg, err := mrr.pgConn.receiveMessage() if err != nil { mrr.pgConn.contextWatcher.Unwatch() @@ -1176,7 +1233,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { if rr.multiResultReader == nil { - msg, err = rr.pgConn.ReceiveMessage() + msg, err = rr.pgConn.receiveMessage() } else { msg, err = rr.multiResultReader.receiveMessage() } diff --git a/pgconn_test.go b/pgconn_test.go index 1b90b9d2..f385bc19 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/jackc/pgconn" + "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" "github.com/stretchr/testify/assert" @@ -1416,6 +1417,45 @@ func TestConnCancelRequest(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnSendBytesAndReceiveMessage(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + queryMsg := pgproto3.Query{String: "select 42"} + buf := queryMsg.Encode(nil) + + err = pgConn.SendBytes(ctx, buf) + require.NoError(t, err) + + msg, err := pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok := msg.(*pgproto3.RowDescription) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.DataRow) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.CommandComplete) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.ReadyForQuery) + require.True(t, ok) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { From 11255efe7af4e7c2ab77e863f245f42f4ca6b4c5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 20 Aug 2019 15:49:57 -0500 Subject: [PATCH 0308/1158] Make ErrorResponseToPgError public --- pgconn.go | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/pgconn.go b/pgconn.go index abbc2d10..e51d40e8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -233,7 +233,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.conn.Close() - return nil, errorResponseToPgError(msg) + return nil, ErrorResponseToPgError(msg) default: pgConn.conn.Close() return nil, errors.New("unexpected message") @@ -400,7 +400,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { pgConn.hardClose() - return nil, errorResponseToPgError(msg) + return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: if pgConn.Config.OnNotice != nil { @@ -577,7 +577,7 @@ readloop: psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: - parseErr = errorResponseToPgError(msg) + parseErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: break readloop } @@ -589,7 +589,8 @@ readloop: return psd, nil } -func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { +// ErrorResponseToPgError converts a wire protocol error message to a *PgError. +func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ Severity: msg.Severity, Code: string(msg.Code), @@ -612,7 +613,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { } func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { - pgerr := errorResponseToPgError((*pgproto3.ErrorResponse)(msg)) + pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg)) return (*Notice)(pgerr) } @@ -898,7 +899,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) } } } @@ -949,7 +950,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case *pgproto3.CopyInResponse: pendingCopyInResponse = false case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: return commandTag, pgErr } @@ -985,7 +986,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co switch msg := msg.(type) { case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) } default: } @@ -1019,7 +1020,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) } } } @@ -1064,7 +1065,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) mrr.closed = true mrr.pgConn.unlock() case *pgproto3.ErrorResponse: - mrr.err = errorResponseToPgError(msg) + mrr.err = ErrorResponseToPgError(msg) } return msg, nil @@ -1219,7 +1220,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { switch msg := msg.(type) { // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. case *pgproto3.ErrorResponse: - rr.err = errorResponseToPgError(msg) + rr.err = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: rr.pgConn.contextWatcher.Unwatch() rr.pgConn.unlock() @@ -1255,7 +1256,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.CommandComplete: rr.concludeCommand(CommandTag(msg.CommandTag), nil) case *pgproto3.ErrorResponse: - rr.concludeCommand(nil, errorResponseToPgError(msg)) + rr.concludeCommand(nil, ErrorResponseToPgError(msg)) } return msg, nil From 4cf1c4481746f931c736085893a4a97ba0056644 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 22 Aug 2019 18:20:36 -0500 Subject: [PATCH 0309/1158] Fix unknown OID scanning into string and []byte --- go.mod | 1 + go.sum | 3 +++ pgtype.go | 22 +++++++++++++++++++++- pgtype_test.go | 37 +++++++++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 00679e12..075b4ee9 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/lib/pq v1.1.0 github.com/satori/go.uuid v1.2.0 github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 + github.com/stretchr/testify v1.3.0 go.uber.org/atomic v1.3.2 // indirect go.uber.org/multierr v1.1.0 // indirect golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 diff --git a/go.sum b/go.sum index ecd3007e..919c31c3 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= @@ -28,6 +29,7 @@ github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= @@ -38,6 +40,7 @@ github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMB github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= diff --git a/pgtype.go b/pgtype.go index cea4e1cd..94e8bcbc 100644 --- a/pgtype.go +++ b/pgtype.go @@ -342,7 +342,27 @@ func (ci *ConnInfo) Scan(oid OID, formatCode int16, buf []byte, dest interface{} return value.AssignTo(dest) } } - return errors.Errorf("unknown oid: %v", oid) + + return scanUnknownType(oid, formatCode, buf, dest) +} + +func scanUnknownType(oid OID, formatCode int16, buf []byte, dest interface{}) error { + switch dest := dest.(type) { + case *string: + if formatCode == BinaryFormatCode { + return errors.Errorf("unknown oid %d in binary format cannot be scanned into %t", oid, dest) + } + *dest = string(buf) + return nil + case *[]byte: + *dest = buf + return nil + default: + if nextDst, retry := GetAssignToDstType(dest); retry { + return scanUnknownType(oid, formatCode, buf, nextDst) + } + return errors.Errorf("unknown oid %d cannot be scanned into %t", oid, dest) + } } var nameValues map[string]Value diff --git a/pgtype_test.go b/pgtype_test.go index 400c0591..53580d18 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -4,8 +4,11 @@ import ( "net" "testing" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" _ "github.com/jackc/pgx/v4/stdlib" _ "github.com/lib/pq" + "github.com/stretchr/testify/assert" ) // Test for renamed types @@ -37,3 +40,37 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { return addr } + +func TestConnInfoScanUnknownOID(t *testing.T) { + unknownOID := pgtype.OID(999999) + srcBuf := []byte("foo") + ci := pgtype.NewConnInfo() + + var s string + err := ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &s) + assert.NoError(t, err) + assert.Equal(t, "foo", s) + + var rs _string + err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rs) + assert.NoError(t, err) + assert.Equal(t, "foo", string(rs)) + + var b []byte + err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &b) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), b) + + err = ci.Scan(unknownOID, pgx.BinaryFormatCode, srcBuf, &b) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), b) + + var rb _byteSlice + err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rb) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), []byte(rb)) + + err = ci.Scan(unknownOID, pgx.BinaryFormatCode, srcBuf, &b) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), []byte(rb)) +} From 1558987979c58286747e7c90ab181adc1560f027 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 22 Aug 2019 20:11:27 -0500 Subject: [PATCH 0310/1158] ReceiveMessage returns context error instead of io error on cancel --- pgconn.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pgconn.go b/pgconn.go index e51d40e8..5d84871b 100644 --- a/pgconn.go +++ b/pgconn.go @@ -361,6 +361,9 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa defer pgConn.contextWatcher.Unwatch() msg, err := pgConn.receiveMessage() + if err != nil { + err = linkErrors(ctx.Err(), err) + } return msg, err } From 760dd75542eb13b37333e0e134b3463efade7cb4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 09:28:44 -0500 Subject: [PATCH 0311/1158] Require Config to be created by ParseConfig --- config.go | 15 ++++++++++----- pgconn.go | 19 ++++++------------- pgconn_test.go | 8 ++++++++ 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/config.go b/config.go index be8bdab4..a861ff5f 100644 --- a/config.go +++ b/config.go @@ -26,7 +26,8 @@ import ( type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error -// Config is the settings used to establish a connection to a PostgreSQL server. +// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and +// then it can be modified. A manually initialized Config will cause ConnectConfig to panic. type Config struct { Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) Port uint16 @@ -55,6 +56,8 @@ type Config struct { // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler + + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a @@ -157,10 +160,12 @@ func ParseConfig(connString string) (*Config, error) { } config := &Config{ - Database: settings["database"], - User: settings["user"], - Password: settings["password"], - RuntimeParams: make(map[string]string), + createdByParseConfig: true, + Database: settings["database"], + User: settings["user"], + Password: settings["password"], + RuntimeParams: make(map[string]string), + BuildFrontendFunc: makeDefaultBuildFrontendFunc(), } if connectTimeout, present := settings["connect_timeout"]; present { diff --git a/pgconn.go b/pgconn.go index 5d84871b..b0e4cfd2 100644 --- a/pgconn.go +++ b/pgconn.go @@ -99,25 +99,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { return ConnectConfig(ctx, config) } -// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt. +// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with +// ParseConfig. ctx can be used to cancel a connect attempt. // // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // authentication error will terminate the chain of attempts (like libpq: // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, // if all attempts fail the last error is returned. func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { - // For convenience set a few defaults if not already set. This makes it simpler to directly construct a config. - if config.Port == 0 { - config.Port = 5432 - } - if config.DialFunc == nil { - config.DialFunc = makeDefaultDialer().DialContext - } - if config.BuildFrontendFunc == nil { - config.BuildFrontendFunc = makeDefaultBuildFrontendFunc() - } - if config.RuntimeParams == nil { - config.RuntimeParams = make(map[string]string) + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") } // Simplify usage by treating primary config and fallbacks the same. diff --git a/pgconn_test.go b/pgconn_test.go index f385bc19..1cd74024 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -263,6 +263,14 @@ func TestConnectWithAfterConnect(t *testing.T) { assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) } +func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { + t.Parallel() + + config := &pgconn.Config{} + + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) +} + func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() From ab885b375b90c76db7e4a980c1974c31595d13ce Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 13:49:12 -0500 Subject: [PATCH 0312/1158] OID type should only be used for scanning and encoding values It was a mistake to use it in other contexts. This made interop difficult between pacakges that depended on pgtype such as pgx and packages that did not like pgconn and pgproto3. In particular this was awkward for prepared statements. Because pgx depends on pgtype and the tests for pgtype depend on pgx this change will require a couple back and forth commits to get the go.mod dependecies correct. --- go.mod | 7 ++++--- go.sum | 31 +++++++++++++++++++++++++++++++ hstore_array_test.go | 10 +++++----- pgtype.go | 16 ++++++++-------- pgtype_test.go | 2 +- record.go | 2 +- record_test.go | 5 ++--- testutil/testutil.go | 6 +++--- 8 files changed, 55 insertions(+), 24 deletions(-) diff --git a/go.mod b/go.mod index 075b4ee9..b3221838 100644 --- a/go.mod +++ b/go.mod @@ -9,9 +9,10 @@ require ( github.com/lib/pq v1.1.0 github.com/satori/go.uuid v1.2.0 github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 - github.com/stretchr/testify v1.3.0 - go.uber.org/atomic v1.3.2 // indirect + github.com/stretchr/testify v1.4.0 go.uber.org/multierr v1.1.0 // indirect - golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 + golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) + +replace github.com/jackc/pgx/v4 => ../pgx diff --git a/go.sum b/go.sum index 919c31c3..162c454f 100644 --- a/go.sum +++ b/go.sum @@ -2,10 +2,15 @@ github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMe github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3 h1:ZFYpB74Kq8xE9gmfxCmXD6QxZ27ja+j3HwGFc+YurhQ= github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb h1:d6GP9szHvXVopAOAnZ7WhRnF3Xdxrylmm/9jnfmW4Ag= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -14,6 +19,8 @@ github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 h1:ylEAOd688Duev/fxTmGdupsbyZfxNMdngIG14DoBKTM= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= @@ -27,6 +34,9 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -42,15 +52,36 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/hstore_array_test.go b/hstore_array_test.go index 47835605..ea8f03b0 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -14,14 +14,14 @@ func TestHstoreArrayTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) - var hstoreOID pgtype.OID + var hstoreOID uint32 err := conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='hstore';").Scan(&hstoreOID) if err != nil { t.Fatalf("did not find hstore OID, %v", err) } conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.Hstore{}, Name: "hstore", OID: hstoreOID}) - var hstoreArrayOID pgtype.OID + var hstoreArrayOID uint32 err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='_hstore';").Scan(&hstoreArrayOID) if err != nil { t.Fatalf("did not find _hstore OID, %v", err) @@ -70,7 +70,7 @@ func TestHstoreArrayTranscode(t *testing.T) { Status: pgtype.Present, } - ps, err := conn.Prepare(context.Background(), "test", "select $1::hstore[]") + _, err = conn.Prepare(context.Background(), "test", "select $1::hstore[]") if err != nil { t.Fatal(err) } @@ -84,7 +84,7 @@ func TestHstoreArrayTranscode(t *testing.T) { } for _, fc := range formats { - ps.FieldDescriptions[0].FormatCode = fc.formatCode + queryResultFormats := pgx.QueryResultFormats{fc.formatCode} vEncoder := testutil.ForceEncoder(src, fc.formatCode) if vEncoder == nil { t.Logf("%#v does not implement %v", src, fc.name) @@ -92,7 +92,7 @@ func TestHstoreArrayTranscode(t *testing.T) { } var result pgtype.HstoreArray - err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(&result) + err := conn.QueryRow(context.Background(), "test", queryResultFormats, vEncoder).Scan(&result) if err != nil { t.Errorf("%v: %v", fc.name, err) continue diff --git a/pgtype.go b/pgtype.go index 94e8bcbc..6e187ae4 100644 --- a/pgtype.go +++ b/pgtype.go @@ -163,18 +163,18 @@ var errBadStatus = errors.New("invalid status") type DataType struct { Value Value Name string - OID OID + OID uint32 } type ConnInfo struct { - oidToDataType map[OID]*DataType + oidToDataType map[uint32]*DataType nameToDataType map[string]*DataType reflectTypeToDataType map[reflect.Type]*DataType } func NewConnInfo() *ConnInfo { ci := &ConnInfo{ - oidToDataType: make(map[OID]*DataType, 128), + oidToDataType: make(map[uint32]*DataType, 128), nameToDataType: make(map[string]*DataType, 128), reflectTypeToDataType: make(map[reflect.Type]*DataType, 128), } @@ -246,7 +246,7 @@ func NewConnInfo() *ConnInfo { return ci } -func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) { +func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) { for name, oid := range nameOIDs { var value Value if t, ok := nameValues[name]; ok { @@ -264,7 +264,7 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t } -func (ci *ConnInfo) DataTypeForOID(oid OID) (*DataType, bool) { +func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) { dt, ok := ci.oidToDataType[oid] return dt, ok } @@ -282,7 +282,7 @@ func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { // DeepCopy makes a deep copy of the ConnInfo. func (ci *ConnInfo) DeepCopy() *ConnInfo { ci2 := &ConnInfo{ - oidToDataType: make(map[OID]*DataType, len(ci.oidToDataType)), + oidToDataType: make(map[uint32]*DataType, len(ci.oidToDataType)), nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)), reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)), } @@ -298,7 +298,7 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { return ci2 } -func (ci *ConnInfo) Scan(oid OID, formatCode int16, buf []byte, dest interface{}) error { +func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interface{}) error { if dest, ok := dest.(BinaryDecoder); ok && formatCode == BinaryFormatCode { return dest.DecodeBinary(ci, buf) } @@ -346,7 +346,7 @@ func (ci *ConnInfo) Scan(oid OID, formatCode int16, buf []byte, dest interface{} return scanUnknownType(oid, formatCode, buf, dest) } -func scanUnknownType(oid OID, formatCode int16, buf []byte, dest interface{}) error { +func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) error { switch dest := dest.(type) { case *string: if formatCode == BinaryFormatCode { diff --git a/pgtype_test.go b/pgtype_test.go index 53580d18..8771b77f 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -42,7 +42,7 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { } func TestConnInfoScanUnknownOID(t *testing.T) { - unknownOID := pgtype.OID(999999) + unknownOID := uint32(999999) srcBuf := []byte("foo") ci := pgtype.NewConnInfo() diff --git a/record.go b/record.go index 60733016..28f4a182 100644 --- a/record.go +++ b/record.go @@ -91,7 +91,7 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { if len(src[rp:]) < 8 { return errors.Errorf("Record incomplete %v", src) } - fieldOID := OID(binary.BigEndian.Uint32(src[rp:])) + fieldOID := binary.BigEndian.Uint32(src[rp:]) rp += 4 fieldLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) diff --git a/record_test.go b/record_test.go index fbf36f5c..71a2f702 100644 --- a/record_test.go +++ b/record_test.go @@ -85,14 +85,13 @@ func TestRecordTranscode(t *testing.T) { for i, tt := range tests { psName := fmt.Sprintf("test%d", i) - ps, err := conn.Prepare(context.Background(), psName, tt.sql) + _, err := conn.Prepare(context.Background(), psName, tt.sql) if err != nil { t.Fatal(err) } - ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode var result pgtype.Record - if err := conn.QueryRow(context.Background(), psName).Scan(&result); err != nil { + if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { t.Errorf("%d: %v", i, err) continue } diff --git a/testutil/testutil.go b/testutil/testutil.go index 66deff39..068b7c59 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -225,12 +225,12 @@ func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFun for i, tt := range tests { for _, fc := range formats { psName := fmt.Sprintf("test%d", i) - ps, err := conn.Prepare(context.Background(), psName, tt.SQL) + _, err := conn.Prepare(context.Background(), psName, tt.SQL) if err != nil { t.Fatal(err) } - ps.FieldDescriptions[0].FormatCode = fc.formatCode + queryResultFormats := pgx.QueryResultFormats{fc.formatCode} if ForceEncoder(tt.Value, fc.formatCode) == nil { t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name) continue @@ -243,7 +243,7 @@ func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFun } result := reflect.New(reflect.TypeOf(derefV)) - err = conn.QueryRow(context.Background(), psName).Scan(result.Interface()) + err = conn.QueryRow(context.Background(), psName, queryResultFormats).Scan(result.Interface()) if err != nil { t.Errorf("%v %d: %v", fc.name, i, err) } From 7d83f9ba53a650e360dd59f9bdcab4bdd8d2014d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 13:59:25 -0500 Subject: [PATCH 0313/1158] Update pgx for tests Finish previous go mod dependency bounce. --- go.mod | 9 ++------- go.sum | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index b3221838..dbe9f53c 100644 --- a/go.mod +++ b/go.mod @@ -4,15 +4,10 @@ go 1.12 require ( github.com/jackc/pgio v1.0.0 - github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912 - github.com/kr/pretty v0.1.0 // indirect - github.com/lib/pq v1.1.0 + github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186 + github.com/lib/pq v1.2.0 github.com/satori/go.uuid v1.2.0 github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 github.com/stretchr/testify v1.4.0 - go.uber.org/multierr v1.1.0 // indirect golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 - gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) - -replace github.com/jackc/pgx/v4 => ../pgx diff --git a/go.sum b/go.sum index 162c454f..f9a56ffd 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,9 @@ +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -22,18 +27,28 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 h1:ylEAOd688Duev/fxTmGdupsbyZfxNMdngIG14DoBKTM= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912 h1:YuOWGsSK5L4Fz81Olx5TNlZftmDuNrfv4ip0Yos77Tw= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186 h1:ZQM8qLT/E/CGD6XX0E6q9FAwxJYmWpJufzmLMaFuzgQ= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= @@ -41,46 +56,60 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= From e540a0576006af74ed45bea905dbb4d8a5e320bc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 14:16:38 -0500 Subject: [PATCH 0314/1158] Fix typo in docs --- doc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc.go b/doc.go index d36eb0fd..cde58cd8 100644 --- a/doc.go +++ b/doc.go @@ -15,7 +15,7 @@ reads all rows into memory. Executing Multiple Queries in a Single Round Trip -Exec and ExecBatch can execute multiple queries in a single round trip. The return readers that iterate over each query +Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query result. The ReadAll method reads all query results into memory. Context Support From e6bd7390678ab23b1fded5035d8364e6fa704f28 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 16:02:27 -0500 Subject: [PATCH 0315/1158] Add pscache package --- pscache/lrucache.go | 111 ++++++++++++++++++++++++++++++++++++++ pscache/lrucache_test.go | 113 +++++++++++++++++++++++++++++++++++++++ pscache/pscache.go | 52 ++++++++++++++++++ 3 files changed, 276 insertions(+) create mode 100644 pscache/lrucache.go create mode 100644 pscache/lrucache_test.go create mode 100644 pscache/pscache.go diff --git a/pscache/lrucache.go b/pscache/lrucache.go new file mode 100644 index 00000000..d5d6062f --- /dev/null +++ b/pscache/lrucache.go @@ -0,0 +1,111 @@ +package pscache + +import ( + "container/list" + "context" + "fmt" + "sync/atomic" + + "github.com/jackc/pgconn" +) + +var lruCacheCount uint64 + +// LRUCache implements cache with a Least Recently Used (LRU) cache. +type LRUCache struct { + conn *pgconn.PgConn + mode int + cap int + prepareCount int + m map[string]*list.Element + l *list.List + psNamePrefix string +} + +// NewLRUCache creates a new LRUCache. mode is either PrepareMode or DescribeMode. cap is the maximum size of the cache. +func NewLRUCache(conn *pgconn.PgConn, mode int, cap int) *LRUCache { + mustBeValidMode(mode) + mustBeValidCap(cap) + + n := atomic.AddUint64(&lruCacheCount, 1) + + return &LRUCache{ + conn: conn, + mode: mode, + cap: cap, + m: make(map[string]*list.Element), + l: list.New(), + psNamePrefix: fmt.Sprintf("lrupsc_%d", n), + } +} + +// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. +func (c *LRUCache) Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { + if el, ok := c.m[sql]; ok { + c.l.MoveToFront(el) + return el.Value.(*pgconn.PreparedStatementDescription), nil + } + + if c.l.Len() == c.cap { + err := c.removeOldest(ctx) + if err != nil { + return nil, err + } + } + + psd, err := c.prepare(ctx, sql) + if err != nil { + return nil, err + } + + el := c.l.PushFront(psd) + c.m[sql] = el + + return psd, nil +} + +// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. +func (c *LRUCache) Clear(ctx context.Context) error { + for c.l.Len() > 0 { + err := c.removeOldest(ctx) + if err != nil { + return err + } + } + + return nil +} + +// Len returns the number of cached prepared statement descriptions. +func (c *LRUCache) Len() int { + return c.l.Len() +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *LRUCache) Cap() int { + return c.cap +} + +// Mode returns the mode of the cache (PrepareMode or DescribeMode) +func (c *LRUCache) Mode() int { + return c.mode +} + +func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { + var name string + if c.mode == PrepareMode { + name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) + c.prepareCount += 1 + } + + return c.conn.Prepare(ctx, name, sql, nil) +} + +func (c *LRUCache) removeOldest(ctx context.Context) error { + oldest := c.l.Back() + c.l.Remove(oldest) + if c.mode == PrepareMode { + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.PreparedStatementDescription).Name)).Close() + } + return nil +} diff --git a/pscache/lrucache_test.go b/pscache/lrucache_test.go new file mode 100644 index 00000000..bf2fcbe0 --- /dev/null +++ b/pscache/lrucache_test.go @@ -0,0 +1,113 @@ +package pscache_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgconn/pscache" + + "github.com/stretchr/testify/require" +) + +func TestLRUCachePrepareMode(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := pscache.NewLRUCache(conn, pscache.PrepareMode, 2) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 2, cache.Cap()) + require.EqualValues(t, pscache.PrepareMode, cache.Mode()) + + psd, err := cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 3") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + require.EqualValues(t, 0, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) +} + +func TestLRUCacheDescribeMode(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := pscache.NewLRUCache(conn, pscache.DescribeMode, 2) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 2, cache.Cap()) + require.EqualValues(t, pscache.DescribeMode, cache.Mode()) + + psd, err := cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 3") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + require.EqualValues(t, 0, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) +} + +func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string { + result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + var statements []string + for _, r := range result.Rows { + statements = append(statements, string(r[0])) + } + return statements +} diff --git a/pscache/pscache.go b/pscache/pscache.go new file mode 100644 index 00000000..bfd51e81 --- /dev/null +++ b/pscache/pscache.go @@ -0,0 +1,52 @@ +// Package pscache is a cache that can be used to implement lazy, automatic prepared statements. +package pscache + +import ( + "context" + + "github.com/jackc/pgconn" +) + +const ( + PrepareMode = iota // Cache should prepare named statements. + DescribeMode // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. +) + +// Cache prepares and caches prepared statement descriptions. +type Cache interface { + // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. + Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) + + // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. + Clear(ctx context.Context) error + + // Len returns the number of cached prepared statement descriptions. + Len() int + + // Cap returns the maximum number of cached prepared statement descriptions. + Cap() int + + // Mode returns the mode of the cache (PrepareMode or DescribeMode) + Mode() int +} + +// New returns the preferred cache implementation for mode and cap. mode is either PrepareMode or DescribeMode. cap is +// the maximum size of the cache. +func New(conn *pgconn.PgConn, mode int, cap int) Cache { + mustBeValidMode(mode) + mustBeValidCap(cap) + + return NewLRUCache(conn, mode, cap) +} + +func mustBeValidMode(mode int) { + if mode != PrepareMode && mode != DescribeMode { + panic("mode must be PrepareMode or DescribeMode") + } +} + +func mustBeValidCap(cap int) { + if cap < 1 { + panic("cache must have cap of >= 1") + } +} From 797a44bf048f27e5db5c79dcbf7e406969ca6904 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 16:18:01 -0500 Subject: [PATCH 0316/1158] Rename BuildFrontendFunc to BuildFrontend For consistency with other functions supplied in Config. --- config.go | 20 ++++++++++---------- pgconn.go | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/config.go b/config.go index a861ff5f..b5c119f5 100644 --- a/config.go +++ b/config.go @@ -29,15 +29,15 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and // then it can be modified. A manually initialized Config will cause ConnectConfig to panic. type Config struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 - Database string - User string - Password string - TLSConfig *tls.Config // nil disables TLS - DialFunc DialFunc // e.g. net.Dialer.DialContext - BuildFrontendFunc BuildFrontendFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + DialFunc DialFunc // e.g. net.Dialer.DialContext + BuildFrontend BuildFrontendFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) Fallbacks []*FallbackConfig @@ -165,7 +165,7 @@ func ParseConfig(connString string) (*Config, error) { User: settings["user"], Password: settings["password"], RuntimeParams: make(map[string]string), - BuildFrontendFunc: makeDefaultBuildFrontendFunc(), + BuildFrontend: makeDefaultBuildFrontendFunc(), } if connectTimeout, present := settings["connect_timeout"]; present { diff --git a/pgconn.go b/pgconn.go index b0e4cfd2..fe2f304e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -174,7 +174,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) - pgConn.frontend = config.BuildFrontendFunc(pgConn.conn) + pgConn.frontend = config.BuildFrontend(pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, From 2209d2e36aea43ee17610489a2644af2212a4bc3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 16:27:54 -0500 Subject: [PATCH 0317/1158] Rename mode constants --- pscache/lrucache.go | 8 ++++---- pscache/lrucache_test.go | 12 ++++++------ pscache/pscache.go | 12 ++++++------ 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pscache/lrucache.go b/pscache/lrucache.go index d5d6062f..cdcec63c 100644 --- a/pscache/lrucache.go +++ b/pscache/lrucache.go @@ -22,7 +22,7 @@ type LRUCache struct { psNamePrefix string } -// NewLRUCache creates a new LRUCache. mode is either PrepareMode or DescribeMode. cap is the maximum size of the cache. +// NewLRUCache creates a new LRUCache. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. func NewLRUCache(conn *pgconn.PgConn, mode int, cap int) *LRUCache { mustBeValidMode(mode) mustBeValidCap(cap) @@ -86,14 +86,14 @@ func (c *LRUCache) Cap() int { return c.cap } -// Mode returns the mode of the cache (PrepareMode or DescribeMode) +// Mode returns the mode of the cache (ModePrepare or ModeDescribe) func (c *LRUCache) Mode() int { return c.mode } func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { var name string - if c.mode == PrepareMode { + if c.mode == ModePrepare { name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) c.prepareCount += 1 } @@ -104,7 +104,7 @@ func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedSta func (c *LRUCache) removeOldest(ctx context.Context) error { oldest := c.l.Back() c.l.Remove(oldest) - if c.mode == PrepareMode { + if c.mode == ModePrepare { return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.PreparedStatementDescription).Name)).Close() } return nil diff --git a/pscache/lrucache_test.go b/pscache/lrucache_test.go index bf2fcbe0..a5d413e3 100644 --- a/pscache/lrucache_test.go +++ b/pscache/lrucache_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestLRUCachePrepareMode(t *testing.T) { +func TestLRUCacheModePrepare(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -22,10 +22,10 @@ func TestLRUCachePrepareMode(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := pscache.NewLRUCache(conn, pscache.PrepareMode, 2) + cache := pscache.NewLRUCache(conn, pscache.ModePrepare, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, pscache.PrepareMode, cache.Mode()) + require.EqualValues(t, pscache.ModePrepare, cache.Mode()) psd, err := cache.Get(ctx, "select 1") require.NoError(t, err) @@ -57,7 +57,7 @@ func TestLRUCachePrepareMode(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } -func TestLRUCacheDescribeMode(t *testing.T) { +func TestLRUCacheModeDescribe(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -67,10 +67,10 @@ func TestLRUCacheDescribeMode(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := pscache.NewLRUCache(conn, pscache.DescribeMode, 2) + cache := pscache.NewLRUCache(conn, pscache.ModeDescribe, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, pscache.DescribeMode, cache.Mode()) + require.EqualValues(t, pscache.ModeDescribe, cache.Mode()) psd, err := cache.Get(ctx, "select 1") require.NoError(t, err) diff --git a/pscache/pscache.go b/pscache/pscache.go index bfd51e81..4f8cf723 100644 --- a/pscache/pscache.go +++ b/pscache/pscache.go @@ -8,8 +8,8 @@ import ( ) const ( - PrepareMode = iota // Cache should prepare named statements. - DescribeMode // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. + ModePrepare = iota // Cache should prepare named statements. + ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. ) // Cache prepares and caches prepared statement descriptions. @@ -26,11 +26,11 @@ type Cache interface { // Cap returns the maximum number of cached prepared statement descriptions. Cap() int - // Mode returns the mode of the cache (PrepareMode or DescribeMode) + // Mode returns the mode of the cache (ModePrepare or ModeDescribe) Mode() int } -// New returns the preferred cache implementation for mode and cap. mode is either PrepareMode or DescribeMode. cap is +// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is // the maximum size of the cache. func New(conn *pgconn.PgConn, mode int, cap int) Cache { mustBeValidMode(mode) @@ -40,8 +40,8 @@ func New(conn *pgconn.PgConn, mode int, cap int) Cache { } func mustBeValidMode(mode int) { - if mode != PrepareMode && mode != DescribeMode { - panic("mode must be PrepareMode or DescribeMode") + if mode != ModePrepare && mode != ModeDescribe { + panic("mode must be ModePrepare or ModeDescribe") } } From beba629bb5d526f8d7de6ec8754090d39b476757 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 17:18:29 -0500 Subject: [PATCH 0318/1158] Fix result reader returned by locked conn --- pgconn.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pgconn.go b/pgconn.go index fe2f304e..797080bd 100644 --- a/pgconn.go +++ b/pgconn.go @@ -791,19 +791,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { - if err := pgConn.lock(); err != nil { - return &ResultReader{ - closed: true, - err: linkErrors(err, ErrNoBytesSent), - } - } - pgConn.resultReader = ResultReader{ pgConn: pgConn, ctx: ctx, } result := &pgConn.resultReader + if err := pgConn.lock(); err != nil { + result.concludeCommand(nil, linkErrors(err, ErrNoBytesSent)) + result.closed = true + return result + } + if len(paramValues) > math.MaxUint16 { result.concludeCommand(nil, errors.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true From bcd6b9244ab8fc80e85b75b604bf214f82345e59 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 19:46:14 -0500 Subject: [PATCH 0319/1158] Rename pscache to stmtcache --- {pscache => stmtcache}/lrucache.go | 2 +- {pscache => stmtcache}/lrucache_test.go | 12 ++++++------ pscache/pscache.go => stmtcache/stmtcache.go | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) rename {pscache => stmtcache}/lrucache.go (99%) rename {pscache => stmtcache}/lrucache_test.go (90%) rename pscache/pscache.go => stmtcache/stmtcache.go (92%) diff --git a/pscache/lrucache.go b/stmtcache/lrucache.go similarity index 99% rename from pscache/lrucache.go rename to stmtcache/lrucache.go index cdcec63c..9c4d046d 100644 --- a/pscache/lrucache.go +++ b/stmtcache/lrucache.go @@ -1,4 +1,4 @@ -package pscache +package stmtcache import ( "container/list" diff --git a/pscache/lrucache_test.go b/stmtcache/lrucache_test.go similarity index 90% rename from pscache/lrucache_test.go rename to stmtcache/lrucache_test.go index a5d413e3..ed8ebdc3 100644 --- a/pscache/lrucache_test.go +++ b/stmtcache/lrucache_test.go @@ -1,4 +1,4 @@ -package pscache_test +package stmtcache_test import ( "context" @@ -7,7 +7,7 @@ import ( "time" "github.com/jackc/pgconn" - "github.com/jackc/pgconn/pscache" + "github.com/jackc/pgconn/stmtcache" "github.com/stretchr/testify/require" ) @@ -22,10 +22,10 @@ func TestLRUCacheModePrepare(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := pscache.NewLRUCache(conn, pscache.ModePrepare, 2) + cache := stmtcache.NewLRUCache(conn, stmtcache.ModePrepare, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, pscache.ModePrepare, cache.Mode()) + require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) psd, err := cache.Get(ctx, "select 1") require.NoError(t, err) @@ -67,10 +67,10 @@ func TestLRUCacheModeDescribe(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := pscache.NewLRUCache(conn, pscache.ModeDescribe, 2) + cache := stmtcache.NewLRUCache(conn, stmtcache.ModeDescribe, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, pscache.ModeDescribe, cache.Mode()) + require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode()) psd, err := cache.Get(ctx, "select 1") require.NoError(t, err) diff --git a/pscache/pscache.go b/stmtcache/stmtcache.go similarity index 92% rename from pscache/pscache.go rename to stmtcache/stmtcache.go index 4f8cf723..d70f277b 100644 --- a/pscache/pscache.go +++ b/stmtcache/stmtcache.go @@ -1,5 +1,5 @@ -// Package pscache is a cache that can be used to implement lazy, automatic prepared statements. -package pscache +// Package stmtcache is a cache that can be used to implement lazy prepared statements. +package stmtcache import ( "context" From 78abbdf1d7eef6b2aa78831c31141057876537f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 19:48:43 -0500 Subject: [PATCH 0320/1158] Rename LRUCache to LRU --- stmtcache/{lrucache.go => lru.go} | 28 ++++++++++----------- stmtcache/{lrucache_test.go => lru_test.go} | 8 +++--- stmtcache/stmtcache.go | 2 +- 3 files changed, 19 insertions(+), 19 deletions(-) rename stmtcache/{lrucache.go => lru.go} (70%) rename stmtcache/{lrucache_test.go => lru_test.go} (93%) diff --git a/stmtcache/lrucache.go b/stmtcache/lru.go similarity index 70% rename from stmtcache/lrucache.go rename to stmtcache/lru.go index 9c4d046d..432a70b4 100644 --- a/stmtcache/lrucache.go +++ b/stmtcache/lru.go @@ -9,10 +9,10 @@ import ( "github.com/jackc/pgconn" ) -var lruCacheCount uint64 +var lruCount uint64 -// LRUCache implements cache with a Least Recently Used (LRU) cache. -type LRUCache struct { +// LRU implements Cache with a Least Recently Used (LRU) cache. +type LRU struct { conn *pgconn.PgConn mode int cap int @@ -22,14 +22,14 @@ type LRUCache struct { psNamePrefix string } -// NewLRUCache creates a new LRUCache. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. -func NewLRUCache(conn *pgconn.PgConn, mode int, cap int) *LRUCache { +// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. +func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { mustBeValidMode(mode) mustBeValidCap(cap) - n := atomic.AddUint64(&lruCacheCount, 1) + n := atomic.AddUint64(&lruCount, 1) - return &LRUCache{ + return &LRU{ conn: conn, mode: mode, cap: cap, @@ -40,7 +40,7 @@ func NewLRUCache(conn *pgconn.PgConn, mode int, cap int) *LRUCache { } // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. -func (c *LRUCache) Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { +func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) return el.Value.(*pgconn.PreparedStatementDescription), nil @@ -65,7 +65,7 @@ func (c *LRUCache) Get(ctx context.Context, sql string) (*pgconn.PreparedStateme } // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. -func (c *LRUCache) Clear(ctx context.Context) error { +func (c *LRU) Clear(ctx context.Context) error { for c.l.Len() > 0 { err := c.removeOldest(ctx) if err != nil { @@ -77,21 +77,21 @@ func (c *LRUCache) Clear(ctx context.Context) error { } // Len returns the number of cached prepared statement descriptions. -func (c *LRUCache) Len() int { +func (c *LRU) Len() int { return c.l.Len() } // Cap returns the maximum number of cached prepared statement descriptions. -func (c *LRUCache) Cap() int { +func (c *LRU) Cap() int { return c.cap } // Mode returns the mode of the cache (ModePrepare or ModeDescribe) -func (c *LRUCache) Mode() int { +func (c *LRU) Mode() int { return c.mode } -func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { +func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { var name string if c.mode == ModePrepare { name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) @@ -101,7 +101,7 @@ func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedSta return c.conn.Prepare(ctx, name, sql, nil) } -func (c *LRUCache) removeOldest(ctx context.Context) error { +func (c *LRU) removeOldest(ctx context.Context) error { oldest := c.l.Back() c.l.Remove(oldest) if c.mode == ModePrepare { diff --git a/stmtcache/lrucache_test.go b/stmtcache/lru_test.go similarity index 93% rename from stmtcache/lrucache_test.go rename to stmtcache/lru_test.go index ed8ebdc3..b518364e 100644 --- a/stmtcache/lrucache_test.go +++ b/stmtcache/lru_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestLRUCacheModePrepare(t *testing.T) { +func TestLRUModePrepare(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -22,7 +22,7 @@ func TestLRUCacheModePrepare(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := stmtcache.NewLRUCache(conn, stmtcache.ModePrepare, 2) + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) @@ -57,7 +57,7 @@ func TestLRUCacheModePrepare(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } -func TestLRUCacheModeDescribe(t *testing.T) { +func TestLRUModeDescribe(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -67,7 +67,7 @@ func TestLRUCacheModeDescribe(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := stmtcache.NewLRUCache(conn, stmtcache.ModeDescribe, 2) + cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode()) diff --git a/stmtcache/stmtcache.go b/stmtcache/stmtcache.go index d70f277b..9bedf549 100644 --- a/stmtcache/stmtcache.go +++ b/stmtcache/stmtcache.go @@ -36,7 +36,7 @@ func New(conn *pgconn.PgConn, mode int, cap int) Cache { mustBeValidMode(mode) mustBeValidCap(cap) - return NewLRUCache(conn, mode, cap) + return NewLRU(conn, mode, cap) } func mustBeValidMode(mode int) { From da9fc85c4404a53f910e2f8210be5add1bc50454 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 20:39:01 -0500 Subject: [PATCH 0321/1158] Rename PreparedStatementDescription to StatementDescription PreparedStatementDescription was too long. It also no longer entirely represents its purpose now that it is also intended for use with described statements. --- pgconn.go | 9 +++++---- stmtcache/lru.go | 8 ++++---- stmtcache/stmtcache.go | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pgconn.go b/pgconn.go index 797080bd..8f3291f1 100644 --- a/pgconn.go +++ b/pgconn.go @@ -517,15 +517,16 @@ func (ct CommandTag) String() string { return string(ct) } -type PreparedStatementDescription struct { +type StatementDescription struct { Name string SQL string ParamOIDs []uint32 Fields []pgproto3.FieldDescription } -// Prepare creates a prepared statement. -func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { +// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This +// allows Prepare to also to describe statements without creating a server-side prepared statement. +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { if err := pgConn.lock(); err != nil { return nil, linkErrors(err, ErrNoBytesSent) } @@ -553,7 +554,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ return nil, linkErrors(ctx.Err(), err) } - psd := &PreparedStatementDescription{Name: name, SQL: sql} + psd := &StatementDescription{Name: name, SQL: sql} var parseErr error diff --git a/stmtcache/lru.go b/stmtcache/lru.go index 432a70b4..fff4d0b7 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -40,10 +40,10 @@ func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { } // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. -func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { +func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) - return el.Value.(*pgconn.PreparedStatementDescription), nil + return el.Value.(*pgconn.StatementDescription), nil } if c.l.Len() == c.cap { @@ -91,7 +91,7 @@ func (c *LRU) Mode() int { return c.mode } -func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { +func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { var name string if c.mode == ModePrepare { name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) @@ -105,7 +105,7 @@ func (c *LRU) removeOldest(ctx context.Context) error { oldest := c.l.Back() c.l.Remove(oldest) if c.mode == ModePrepare { - return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.PreparedStatementDescription).Name)).Close() + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.StatementDescription).Name)).Close() } return nil } diff --git a/stmtcache/stmtcache.go b/stmtcache/stmtcache.go index 9bedf549..96215799 100644 --- a/stmtcache/stmtcache.go +++ b/stmtcache/stmtcache.go @@ -15,7 +15,7 @@ const ( // Cache prepares and caches prepared statement descriptions. type Cache interface { // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. - Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) + Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. Clear(ctx context.Context) error From 6feea0c1c57d8ec5ff0cd806354437ed03b415f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 23:43:26 -0500 Subject: [PATCH 0322/1158] Replace IsAlive with IsClosed IsAlive is ambiguous because the connection may be dead and we do not know it. It implies the possibility of a ping. IsClosed is clearer -- it does not promise the connection is alive only that it hasn't been closed. fixes #2 --- pgconn.go | 7 +++---- pgconn_test.go | 10 +++++----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pgconn.go b/pgconn.go index 8f3291f1..153829ca 100644 --- a/pgconn.go +++ b/pgconn.go @@ -463,10 +463,9 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } -// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect death of -// underlying connection. -func (pgConn *PgConn) IsAlive() bool { - return pgConn.status >= connStatusIdle +// IsClosed reports if the connection has been closed. +func (pgConn *PgConn) IsClosed() bool { + return pgConn.status < connStatusIdle } // lock locks the connection. It panics if the connection is already locked or is closed. diff --git a/pgconn_test.go b/pgconn_test.go index 1cd74024..64628262 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -433,7 +433,7 @@ func TestConnExecContextCanceled(t *testing.T) { } err = multiResult.Close() assert.Equal(t, context.DeadlineExceeded, err) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnExecContextPrecanceled(t *testing.T) { @@ -566,7 +566,7 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, context.DeadlineExceeded, err) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnExecParamsPrecanceled(t *testing.T) { @@ -692,7 +692,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, context.DeadlineExceeded, err) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnExecPreparedPrecanceled(t *testing.T) { @@ -1142,7 +1142,7 @@ func TestConnCopyToCanceled(t *testing.T) { assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.Equal(t, pgconn.CommandTag(nil), res) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnCopyToPrecanceled(t *testing.T) { @@ -1233,7 +1233,7 @@ func TestConnCopyFromCanceled(t *testing.T) { assert.Equal(t, int64(0), ct.RowsAffected()) assert.True(t, errors.Is(err, context.DeadlineExceeded)) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnCopyFromPrecanceled(t *testing.T) { From 595d09d6f1bfba423db8d00f61efebf0aaa6a85a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 23:57:24 -0500 Subject: [PATCH 0323/1158] Build fully operational Frontend --- config.go | 4 ++-- pgconn.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/config.go b/config.go index b5c119f5..e078a061 100644 --- a/config.go +++ b/config.go @@ -482,8 +482,8 @@ func makeDefaultDialer() *net.Dialer { } func makeDefaultBuildFrontendFunc() BuildFrontendFunc { - return func(r io.Reader) Frontend { - frontend, _ := pgproto3.NewFrontend(pgproto3.NewChunkReader(r), nil) + return func(r io.Reader, w io.Writer) Frontend { + frontend, _ := pgproto3.NewFrontend(pgproto3.NewChunkReader(r), w) return frontend } diff --git a/pgconn.go b/pgconn.go index 153829ca..7d301af2 100644 --- a/pgconn.go +++ b/pgconn.go @@ -44,7 +44,7 @@ type Notification struct { type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. -type BuildFrontendFunc func(r io.Reader) Frontend +type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin @@ -174,7 +174,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) - pgConn.frontend = config.BuildFrontend(pgConn.conn) + pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, From e6cf51b304f1d6961663ede4ba89be363fc54237 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 25 Aug 2019 00:22:32 -0500 Subject: [PATCH 0324/1158] Expose min_read_buffer_size config param --- config.go | 24 +++++++++++++++++++++--- config_test.go | 12 ++++++++++++ go.mod | 1 + 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/config.go b/config.go index e078a061..cb153c77 100644 --- a/config.go +++ b/config.go @@ -18,6 +18,7 @@ import ( "strings" "time" + "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" @@ -140,6 +141,11 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn // does not. +// +// In addition, ParseConfig accepts the following options: +// +// min_read_buffer_size +// The minimum size of the internal read buffer. Default 8192. func ParseConfig(connString string) (*Config, error) { settings := defaultSettings() addEnvSettings(settings) @@ -159,13 +165,18 @@ func ParseConfig(connString string) (*Config, error) { } } + minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) + if err != nil { + return nil, errors.Errorf("cannot parse min_read_buffer_size: %w", err) + } + config := &Config{ createdByParseConfig: true, Database: settings["database"], User: settings["user"], Password: settings["password"], RuntimeParams: make(map[string]string), - BuildFrontend: makeDefaultBuildFrontendFunc(), + BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)), } if connectTimeout, present := settings["connect_timeout"]; present { @@ -192,6 +203,7 @@ func ParseConfig(connString string) (*Config, error) { "sslcert": struct{}{}, "sslrootcert": struct{}{}, "target_session_attrs": struct{}{}, + "min_read_buffer_size": struct{}{}, } for k, v := range settings { @@ -284,6 +296,8 @@ func defaultSettings() map[string]string { settings["target_session_attrs"] = "any" + settings["min_read_buffer_size"] = "8192" + return settings } @@ -481,9 +495,13 @@ func makeDefaultDialer() *net.Dialer { return &net.Dialer{KeepAlive: 5 * time.Minute} } -func makeDefaultBuildFrontendFunc() BuildFrontendFunc { +func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { return func(r io.Reader, w io.Writer) Frontend { - frontend, _ := pgproto3.NewFrontend(pgproto3.NewChunkReader(r), w) + cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen}) + if err != nil { + panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err)) + } + frontend, _ := pgproto3.NewFrontend(cr, w) return frontend } diff --git a/config_test.go b/config_test.go index 23d86529..af42094d 100644 --- a/config_test.go +++ b/config_test.go @@ -561,3 +561,15 @@ func TestParseConfigReadsPgPassfile(t *testing.T) { assertConfigsEqual(t, expected, actual, "passfile") } + +func TestParseConfigExtractsMinReadBufferSize(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig("min_read_buffer_size=0") + require.NoError(t, err) + _, present := config.RuntimeParams["min_read_buffer_size"] + require.False(t, present) + + // The buffer size is internal so there isn't much that can be done to test it other than see that the runtime param + // was removed. +} diff --git a/go.mod b/go.mod index b1c84049..cbeef02a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/jackc/pgconn go 1.12 require ( + github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 From b1e25e4ea49c995a914679a7e85d47b4101f2ed9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 25 Aug 2019 00:32:11 -0500 Subject: [PATCH 0325/1158] Add format code helpers to ConnInfo --- pgtype.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/pgtype.go b/pgtype.go index 6e187ae4..3391a04f 100644 --- a/pgtype.go +++ b/pgtype.go @@ -170,6 +170,8 @@ type ConnInfo struct { oidToDataType map[uint32]*DataType nameToDataType map[string]*DataType reflectTypeToDataType map[reflect.Type]*DataType + oidToParamFormatCode map[uint32]int16 + oidToResultFormatCode map[uint32]int16 } func NewConnInfo() *ConnInfo { @@ -177,6 +179,8 @@ func NewConnInfo() *ConnInfo { oidToDataType: make(map[uint32]*DataType, 128), nameToDataType: make(map[string]*DataType, 128), reflectTypeToDataType: make(map[reflect.Type]*DataType, 128), + oidToParamFormatCode: make(map[uint32]int16, 128), + oidToResultFormatCode: make(map[uint32]int16, 128), } ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID}) @@ -262,6 +266,22 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { ci.oidToDataType[t.OID] = &t ci.nameToDataType[t.Name] = &t ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t + + { + var formatCode int16 + if _, ok := t.Value.(BinaryEncoder); ok { + formatCode = BinaryFormatCode + } + ci.oidToParamFormatCode[t.OID] = formatCode + } + + { + var formatCode int16 + if _, ok := t.Value.(BinaryDecoder); ok { + formatCode = BinaryFormatCode + } + ci.oidToResultFormatCode[t.OID] = formatCode + } } func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) { @@ -279,6 +299,22 @@ func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { return dt, ok } +func (ci *ConnInfo) ParamFormatCodeForOID(oid uint32) int16 { + fc, ok := ci.oidToParamFormatCode[oid] + if ok { + return fc + } + return TextFormatCode +} + +func (ci *ConnInfo) ResultFormatCodeForOID(oid uint32) int16 { + fc, ok := ci.oidToResultFormatCode[oid] + if ok { + return fc + } + return TextFormatCode +} + // DeepCopy makes a deep copy of the ConnInfo. func (ci *ConnInfo) DeepCopy() *ConnInfo { ci2 := &ConnInfo{ From 138254da5b02b80a548f7858f01636f9a426b918 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 27 Aug 2019 18:01:59 -0500 Subject: [PATCH 0326/1158] Refactor errors - Use strongly typed errors internally - SafeToRetry(error) streamlines retry logic over ErrNoBytesSent - Timeout(error) removes the need to choose between returning a context and an i/o error --- config.go | 14 ++--- errors.go | 156 ++++++++++++++++++++++++++++++++++++------------- pgconn.go | 125 +++++++++++++++++---------------------- pgconn_test.go | 41 ++++++------- 4 files changed, 195 insertions(+), 141 deletions(-) diff --git a/config.go b/config.go index cb153c77..d24d0202 100644 --- a/config.go +++ b/config.go @@ -155,19 +155,19 @@ func ParseConfig(connString string) (*Config, error) { if strings.HasPrefix(connString, "postgres://") { err := addURLSettings(settings, connString) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} } } else { err := addDSNSettings(settings, connString) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} } } } minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) if err != nil { - return nil, errors.Errorf("cannot parse min_read_buffer_size: %w", err) + return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err} } config := &Config{ @@ -182,7 +182,7 @@ func ParseConfig(connString string) (*Config, error) { if connectTimeout, present := settings["connect_timeout"]; present { dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} } config.DialFunc = dialFunc } else { @@ -228,7 +228,7 @@ func ParseConfig(connString string) (*Config, error) { port, err := parsePort(portStr) if err != nil { - return nil, errors.Errorf("invalid port: %w", err) + return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} } var tlsConfigs []*tls.Config @@ -240,7 +240,7 @@ func ParseConfig(connString string) (*Config, error) { var err error tlsConfigs, err = configTLS(settings) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} } } @@ -273,7 +273,7 @@ func ParseConfig(connString string) (*Config, error) { if settings["target_session_attrs"] == "read-write" { config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite } else if settings["target_session_attrs"] != "any" { - return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) + return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])} } return config, nil diff --git a/errors.go b/errors.go index 4f8af407..a088dcdd 100644 --- a/errors.go +++ b/errors.go @@ -2,22 +2,31 @@ package pgconn import ( "context" + "fmt" "net" + "strings" errors "golang.org/x/xerrors" ) -// ErrTLSRefused occurs when the connection attempt requires TLS and the -// PostgreSQL server refuses to use TLS -var ErrTLSRefused = errors.New("server refused TLS connection") +// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. +func SafeToRetry(err error) bool { + if e, ok := err.(interface{ SafeToRetry() bool }); ok { + return e.SafeToRetry() + } + return false +} -// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another -// action is attempted. -var ErrConnBusy = errors.New("conn is busy") +// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a +// context.Canceled, context.Canceled or an implementer of net.Error where Timeout() is true. +func Timeout(err error) bool { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } -// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used -// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error. -var ErrNoBytesSent = errors.New("no bytes sent to server") + var netErr net.Error + return errors.As(err, &netErr) && netErr.Timeout() +} // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for @@ -46,44 +55,107 @@ func (pe *PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } -// linkedError connects two errors as if err wrapped next. -type linkedError struct { - err error - next error +type connectError struct { + config *Config + msg string + err error } -func (le *linkedError) Error() string { - return le.err.Error() -} - -func (le *linkedError) Is(target error) bool { - return errors.Is(le.err, target) -} - -func (le *linkedError) As(target interface{}) bool { - return errors.As(le.err, target) -} - -func (le *linkedError) Unwrap() error { - return le.next -} - -// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == -// true. Otherwise returns err. -func preferContextOverNetTimeoutError(ctx context.Context, err error) error { - if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { - return ctx.Err() +func (e *connectError) Error() string { + sb := &strings.Builder{} + fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg) + if e.err != nil { + fmt.Fprintf(sb, " (%s)", e.err.Error()) } - return err + return sb.String() } -// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned. -func linkErrors(outer, inner error) error { - if outer == nil { - return inner +func (e *connectError) Unwrap() error { + return e.err +} + +type connLockError struct { + status string +} + +func (e *connLockError) SafeToRetry() bool { + return true // a lock failure by definition happens before the connection is used. +} + +func (e *connLockError) Error() string { + return e.status +} + +type parseConfigError struct { + connString string + msg string + err error +} + +func (e *parseConfigError) Error() string { + if e.err == nil { + return fmt.Sprintf("cannot parse `%s`: %s", e.connString, e.msg) } - if inner == nil { - return outer + return fmt.Sprintf("cannot parse `%s`: %s (%s)", e.connString, e.msg, e.err.Error()) +} + +func (e *parseConfigError) Unwrap() error { + return e.err +} + +type pgconnError struct { + msg string + err error + safeToRetry bool +} + +func (e *pgconnError) Error() string { + if e.msg == "" { + return e.err.Error() } - return &linkedError{err: outer, next: inner} + if e.err == nil { + return e.msg + } + return fmt.Sprintf("%s: %s", e.msg, e.err.Error()) +} + +func (e *pgconnError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *pgconnError) Unwrap() error { + return e.err +} + +type contextAlreadyDoneError struct { + err error +} + +func (e *contextAlreadyDoneError) Error() string { + return fmt.Sprintf("context already done: %s", e.err.Error()) +} + +func (e *contextAlreadyDoneError) SafeToRetry() bool { + return true +} + +func (e *contextAlreadyDoneError) Unwrap() error { + return e.err +} + +type writeError struct { + err error + safeToRetry bool +} + +func (e *writeError) Error() string { + return fmt.Sprintf("write failed: %s", e.err.Error()) +} + +func (e *writeError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *writeError) Unwrap() error { + return e.err } diff --git a/pgconn.go b/pgconn.go index 7d301af2..347acf80 100644 --- a/pgconn.go +++ b/pgconn.go @@ -128,19 +128,19 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err if err == nil { break } else if err, ok := err.(*PgError); ok { - return nil, err + return nil, &connectError{config: config, msg: "server error", err: err} } } if err != nil { - return nil, err + return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError } if config.AfterConnect != nil { err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("AfterConnect: %v", err) + return nil, &connectError{config: config, msg: "AfterConnect error", err: err} } } @@ -156,7 +156,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { - return nil, err + return nil, &connectError{config: config, msg: "dial error", err: err} } pgConn.parameterStatuses = make(map[string]string) @@ -164,7 +164,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if fallbackConfig.TLSConfig != nil { if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "tls error", err: err} } } @@ -193,14 +193,17 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "failed to write startup message", err: err} } for { msg, err := pgConn.receiveMessage() if err != nil { pgConn.conn.Close() - return nil, err + if err, ok := err.(*PgError); ok { + return nil, err + } + return nil, &connectError{config: config, msg: "failed to receive message", err: err} } switch msg := msg.(type) { @@ -210,7 +213,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "failed handle authentication message", err: err} } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle @@ -218,7 +221,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err := config.ValidateConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("ValidateConnect: %v", err) + return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} } } return pgConn, nil @@ -229,7 +232,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, ErrorResponseToPgError(msg) default: pgConn.conn.Close() - return nil, errors.New("unexpected message") + return nil, &connectError{config: config, msg: "received unexpected message", err: err} } } } @@ -246,7 +249,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { } if response[0] != 'S' { - return ErrTLSRefused + return errors.New("server refused TLS connection") } pgConn.conn = tls.Client(pgConn.conn, tlsConfig) @@ -308,13 +311,13 @@ func (pgConn *PgConn) signalMessage() chan struct{} { // See https://www.postgresql.org/docs/current/protocol.html. func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { if err := pgConn.lock(); err != nil { - return linkErrors(err, ErrNoBytesSent) + return err } defer pgConn.unlock() select { case <-ctx.Done(): - return linkErrors(ctx.Err(), ErrNoBytesSent) + return &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -323,10 +326,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return linkErrors(ctx.Err(), err) + return &writeError{err: err, safeToRetry: n == 0} } return nil @@ -341,13 +341,13 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { // See https://www.postgresql.org/docs/current/protocol.html. func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -355,7 +355,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa msg, err := pgConn.receiveMessage() if err != nil { - err = linkErrors(ctx.Err(), err) + err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true} } return msg, err } @@ -442,12 +442,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error { _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { - return linkErrors(ctx.Err(), err) + return err } _, err = pgConn.conn.Read(make([]byte, 1)) if err != io.EOF { - return linkErrors(ctx.Err(), err) + return err } return pgConn.conn.Close() @@ -468,15 +468,15 @@ func (pgConn *PgConn) IsClosed() bool { return pgConn.status < connStatusIdle } -// lock locks the connection. It panics if the connection is already locked or is closed. +// lock locks the connection. func (pgConn *PgConn) lock() error { switch pgConn.status { case connStatusBusy: - return ErrConnBusy // This only should be possible in case of an application bug. + return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug. case connStatusClosed: - return errors.New("conn closed") + return &connLockError{status: "conn closed"} case connStatusUninitialized: - return errors.New("conn uninitialized") + return &connLockError{status: "conn uninitialized"} } pgConn.status = connStatusBusy return nil @@ -527,13 +527,13 @@ type StatementDescription struct { // allows Prepare to also to describe statements without creating a server-side prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -547,10 +547,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0} } psd := &StatementDescription{Name: name, SQL: sql} @@ -562,7 +559,7 @@ readloop: msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -641,12 +638,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) _, err = cancelConn.Write(buf) if err != nil { - return linkErrors(ctx.Err(), err) + return err } _, err = cancelConn.Read(buf) if err != io.EOF { - return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err)) + return err } return nil @@ -672,7 +669,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { for { msg, err := pgConn.receiveMessage() if err != nil { - return linkErrors(ctx.Err(), err) + return err } switch msg.(type) { @@ -691,7 +688,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: linkErrors(err, ErrNoBytesSent), + err: err, } } @@ -704,7 +701,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} pgConn.unlock() return multiResult default: @@ -719,10 +716,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.hardClose() pgConn.contextWatcher.Unwatch() multiResult.closed = true - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - multiResult.err = linkErrors(ctx.Err(), err) + multiResult.err = &writeError{err: err, safeToRetry: n == 0} pgConn.unlock() return multiResult } @@ -798,7 +792,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by result := &pgConn.resultReader if err := pgConn.lock(); err != nil { - result.concludeCommand(nil, linkErrors(err, ErrNoBytesSent)) + result.concludeCommand(nil, err) result.closed = true return result } @@ -812,7 +806,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by select { case <-ctx.Done(): - result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent)) + result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) result.closed = true pgConn.unlock() return result @@ -831,10 +825,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - result.concludeCommand(nil, linkErrors(ctx.Err(), err)) + result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() @@ -844,13 +835,13 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } select { case <-ctx.Done(): pgConn.unlock() - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -864,10 +855,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm if err != nil { pgConn.hardClose() pgConn.unlock() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &writeError{err: err, safeToRetry: n == 0} } // Read results @@ -877,7 +865,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -905,13 +893,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -924,10 +912,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &writeError{err: err, safeToRetry: n == 0} } // Read until copy in response or error. @@ -938,7 +923,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -967,7 +952,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } } @@ -976,7 +961,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -998,7 +983,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } // Read results @@ -1006,7 +991,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -1048,7 +1033,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) if err != nil { mrr.pgConn.contextWatcher.Unwatch() - mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) + mrr.err = err mrr.closed = true mrr.pgConn.hardClose() return nil, mrr.err @@ -1263,7 +1248,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { } rr.commandTag = commandTag - rr.err = preferContextOverNetTimeoutError(rr.ctx, err) + rr.err = err rr.fieldDescriptions = nil rr.rowValues = nil rr.commandConcluded = true @@ -1293,7 +1278,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: linkErrors(err, ErrNoBytesSent), + err: err, } } @@ -1306,7 +1291,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} pgConn.unlock() return multiResult default: diff --git a/pgconn_test.go b/pgconn_test.go index 64628262..3fbdf8df 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -86,14 +86,11 @@ func TestConnectInvalidUser(t *testing.T) { config.User = "pgxinvalidusertest" - conn, err := pgconn.ConnectConfig(context.Background(), config) - if err == nil { - conn.Close(context.Background()) - t.Fatal("expected err but got none") - } - pgErr, ok := err.(*pgconn.PgError) + _, err = pgconn.ConnectConfig(context.Background(), config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) if !ok { - t.Fatalf("Expected to receive a PgError, instead received: %v", err) + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) } if pgErr.Code != "28000" && pgErr.Code != "28P01" { t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) @@ -298,7 +295,7 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { assert.Nil(t, psd) assert.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -432,7 +429,7 @@ func TestConnExecContextCanceled(t *testing.T) { for multiResult.NextResult() { } err = multiResult.Close() - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) } @@ -448,7 +445,7 @@ func TestConnExecContextPrecanceled(t *testing.T) { _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() assert.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -564,7 +561,7 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, 0, rowCount) commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) } @@ -581,7 +578,7 @@ func TestConnExecParamsPrecanceled(t *testing.T) { result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() require.Error(t, result.Err) assert.True(t, errors.Is(result.Err, context.Canceled)) - assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(result.Err)) ensureConnValid(t, pgConn) } @@ -691,7 +688,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.Equal(t, 0, rowCount) commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) } @@ -710,7 +707,7 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() require.Error(t, result.Err) assert.True(t, errors.Is(result.Err, context.Canceled)) - assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(result.Err)) ensureConnValid(t, pgConn) } @@ -798,7 +795,7 @@ func TestConnExecBatchPrecanceled(t *testing.T) { _, err = pgConn.ExecBatch(ctx, batch).ReadAll() require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -871,8 +868,8 @@ func TestConnLocking(t *testing.T) { mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) - assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) results, err := mrr.ReadAll() assert.NoError(t, err) @@ -1029,7 +1026,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) err = pgConn.WaitForNotification(ctx) cancel() - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.True(t, pgconn.Timeout(err)) ensureConnValid(t, pgConn) } @@ -1139,7 +1136,7 @@ func TestConnCopyToCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.Error(t, err) assert.Equal(t, pgconn.CommandTag(nil), res) assert.True(t, pgConn.IsClosed()) @@ -1159,7 +1156,7 @@ func TestConnCopyToPrecanceled(t *testing.T) { res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) assert.Equal(t, pgconn.CommandTag(nil), res) ensureConnValid(t, pgConn) @@ -1231,7 +1228,7 @@ func TestConnCopyFromCanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") cancel() assert.Equal(t, int64(0), ct.RowsAffected()) - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.Error(t, err) assert.True(t, pgConn.IsClosed()) } @@ -1267,7 +1264,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) assert.Equal(t, pgconn.CommandTag(nil), ct) ensureConnValid(t, pgConn) From 66aaed7c9eb0751b2936dbdbf278963dda8804fd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 27 Aug 2019 18:11:50 -0500 Subject: [PATCH 0327/1158] Remove public fields from PgConn - Access TxStatus via method - Make Config private fixes #7 --- auth_scram.go | 2 +- pgconn.go | 27 ++++++++++++++++----------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index 4409a080..6d6d0651 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -31,7 +31,7 @@ const clientNonceLen = 18 // Perform SCRAM authentication. func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { - sc, err := newScramClient(serverAuthMechanisms, c.Config.Password) + sc, err := newScramClient(serverAuthMechanisms, c.config.Password) if err != nil { return err } diff --git a/pgconn.go b/pgconn.go index 347acf80..1e3f9515 100644 --- a/pgconn.go +++ b/pgconn.go @@ -69,10 +69,10 @@ type PgConn struct { pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server - TxStatus byte + txStatus byte frontend Frontend - Config *Config + config *Config status byte // One of connStatus* constants @@ -149,7 +149,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) - pgConn.Config = config + pgConn.config = config pgConn.wbuf = make([]byte, 0, 1024) var err error @@ -261,9 +261,9 @@ func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error switch msg.Type { case pgproto3.AuthTypeOk: case pgproto3.AuthTypeCleartextPassword: - err = pgConn.txPasswordMessage(pgConn.Config.Password) + err = pgConn.txPasswordMessage(pgConn.config.Password) case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:])) + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) err = pgConn.txPasswordMessage(digestedPassword) case pgproto3.AuthTypeSASL: err = pgConn.scramAuth(msg.SASLAuthMechanisms) @@ -390,7 +390,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - pgConn.TxStatus = msg.TxStatus + pgConn.txStatus = msg.TxStatus case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: @@ -399,12 +399,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: - if pgConn.Config.OnNotice != nil { - pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg)) + if pgConn.config.OnNotice != nil { + pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg)) } case *pgproto3.NotificationResponse: - if pgConn.Config.OnNotification != nil { - pgConn.Config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) + if pgConn.config.OnNotification != nil { + pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) } } @@ -421,6 +421,11 @@ func (pgConn *PgConn) PID() uint32 { return pgConn.pid } +// TxStatus returns the current TxStatus as reported by the server. +func (pgConn *PgConn) TxStatus() byte { + return pgConn.txStatus +} + // SecretKey returns the backend secret key used to send a cancel query message to the server. func (pgConn *PgConn) SecretKey() uint32 { return pgConn.secretKey @@ -618,7 +623,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } From a8802b16cc593842f5c69b0f7cfb0de11d5cd3a8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 27 Aug 2019 20:46:16 -0500 Subject: [PATCH 0328/1158] Value, EncodeBinary, EncodeText, and MarshalJSON on T instead of *T Methods defined on T are also available on *T. This change makes Value consistent with database/sql Value implementations. It also makes Value, EncodeBinary, and EncodeText more convenient to use because you can pass T or *T as an argument to a query. The MarshalJSON change is even more significant because without it json.Marshal would generate the "%v" format instead of the implemented MarshalJSON. Thought this technically changes the interface, because *T will be automatically dereferenced as needed it shouldn't be a breaking change. See: https://github.com/jackc/pgx/issues/538 for initial discussion. --- aclitem.go | 4 ++-- aclitem_array.go | 4 ++-- array.go | 2 +- bit.go | 8 ++++---- bool.go | 6 +++--- bool_array.go | 6 +++--- box.go | 6 +++--- bpchar.go | 16 ++++++++-------- bpchar_array.go | 6 +++--- bytea.go | 6 +++--- bytea_array.go | 6 +++--- cid.go | 12 ++++++------ cidr.go | 8 ++++---- cidr_array.go | 6 +++--- circle.go | 6 +++--- date.go | 6 +++--- date_array.go | 6 +++--- enum_array.go | 4 ++-- ext/satori-uuid/uuid.go | 6 +++--- ext/shopspring-numeric/decimal.go | 6 +++--- float4.go | 6 +++--- float4_array.go | 6 +++--- float8.go | 6 +++--- float8_array.go | 6 +++--- generic_binary.go | 8 ++++---- generic_text.go | 8 ++++---- hstore.go | 6 +++--- hstore_array.go | 6 +++--- inet.go | 6 +++--- inet_array.go | 6 +++--- int2.go | 8 ++++---- int2_array.go | 6 +++--- int4.go | 8 ++++---- int4_array.go | 6 +++--- int8.go | 8 ++++---- int8_array.go | 6 +++--- interval.go | 6 +++--- json.go | 6 +++--- jsonb.go | 10 +++++----- line.go | 6 +++--- lseg.go | 6 +++--- macaddr.go | 6 +++--- macaddr_array.go | 6 +++--- name.go | 12 ++++++------ numeric.go | 6 +++--- numeric_array.go | 6 +++--- oid_value.go | 12 ++++++------ path.go | 6 +++--- pguint32.go | 6 +++--- point.go | 6 +++--- polygon.go | 6 +++--- qchar.go | 2 +- text.go | 8 ++++---- text_array.go | 6 +++--- tid.go | 6 +++--- timestamp.go | 6 +++--- timestamp_array.go | 6 +++--- timestamptz.go | 6 +++--- timestamptz_array.go | 6 +++--- unknown.go | 4 ++-- uuid.go | 6 +++--- uuid_array.go | 6 +++--- varbit.go | 6 +++--- varchar.go | 16 ++++++++-------- varchar_array.go | 6 +++--- xid.go | 12 ++++++------ 66 files changed, 222 insertions(+), 222 deletions(-) diff --git a/aclitem.go b/aclitem.go index c801eb83..123e86b6 100644 --- a/aclitem.go +++ b/aclitem.go @@ -84,7 +84,7 @@ func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (src *ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -115,7 +115,7 @@ func (dst *ACLItem) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *ACLItem) Value() (driver.Value, error) { +func (src ACLItem) Value() (driver.Value, error) { switch src.Status { case Present: return src.String, nil diff --git a/aclitem_array.go b/aclitem_array.go index c8421153..e8142091 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -124,7 +124,7 @@ func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (src *ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -200,7 +200,7 @@ func (dst *ACLItemArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *ACLItemArray) Value() (driver.Value, error) { +func (src ACLItemArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/array.go b/array.go index 69456782..bd3a993b 100644 --- a/array.go +++ b/array.go @@ -60,7 +60,7 @@ func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { return rp, nil } -func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { +func (src ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { buf = pgio.AppendInt32(buf, int32(len(src.Dimensions))) var containsNull int32 diff --git a/bit.go b/bit.go index f892cee5..4f40a532 100644 --- a/bit.go +++ b/bit.go @@ -22,8 +22,8 @@ func (dst *Bit) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Varbit)(dst).DecodeBinary(ci, src) } -func (src *Bit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Varbit)(src).EncodeBinary(ci, buf) +func (src Bit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Varbit)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. @@ -32,6 +32,6 @@ func (dst *Bit) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Bit) Value() (driver.Value, error) { - return (*Varbit)(src).Value() +func (src Bit) Value() (driver.Value, error) { + return (Varbit)(src).Value() } diff --git a/bool.go b/bool.go index f622061b..ad55dce4 100644 --- a/bool.go +++ b/bool.go @@ -96,7 +96,7 @@ func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -113,7 +113,7 @@ func (src *Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -153,7 +153,7 @@ func (dst *Bool) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Bool) Value() (driver.Value, error) { +func (src Bool) Value() (driver.Value, error) { switch src.Status { case Present: return src.Bool, nil diff --git a/bool_array.go b/bool_array.go index 3dde8dc0..ba453254 100644 --- a/bool_array.go +++ b/bool_array.go @@ -168,7 +168,7 @@ func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +225,7 @@ func (src *BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +288,7 @@ func (dst *BoolArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *BoolArray) Value() (driver.Value, error) { +func (src BoolArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/box.go b/box.go index ce5300e5..9baabf6b 100644 --- a/box.go +++ b/box.go @@ -108,7 +108,7 @@ func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -125,7 +125,7 @@ func (src *Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -161,6 +161,6 @@ func (dst *Box) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Box) Value() (driver.Value, error) { +func (src Box) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/bpchar.go b/bpchar.go index 21263184..1a85fa0d 100644 --- a/bpchar.go +++ b/bpchar.go @@ -41,12 +41,12 @@ func (dst *BPChar) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } -func (src *BPChar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeText(ci, buf) +func (src BPChar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeText(ci, buf) } -func (src *BPChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeBinary(ci, buf) +func (src BPChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. @@ -55,12 +55,12 @@ func (dst *BPChar) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *BPChar) Value() (driver.Value, error) { - return (*Text)(src).Value() +func (src BPChar) Value() (driver.Value, error) { + return (Text)(src).Value() } -func (src *BPChar) MarshalJSON() ([]byte, error) { - return (*Text)(src).MarshalJSON() +func (src BPChar) MarshalJSON() ([]byte, error) { + return (Text)(src).MarshalJSON() } func (dst *BPChar) UnmarshalJSON(b []byte) error { diff --git a/bpchar_array.go b/bpchar_array.go index 547b4e80..da601d0d 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -168,7 +168,7 @@ func (dst *BPCharArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +225,7 @@ func (src *BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +288,7 @@ func (dst *BPCharArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *BPCharArray) Value() (driver.Value, error) { +func (src BPCharArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/bytea.go b/bytea.go index e6c28dc7..c6e79cdf 100644 --- a/bytea.go +++ b/bytea.go @@ -100,7 +100,7 @@ func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -113,7 +113,7 @@ func (src *Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Bytea) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Bytea) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -145,7 +145,7 @@ func (dst *Bytea) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Bytea) Value() (driver.Value, error) { +func (src Bytea) Value() (driver.Value, error) { switch src.Status { case Present: return src.Bytes, nil diff --git a/bytea_array.go b/bytea_array.go index 369d6e08..1c2f6548 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -168,7 +168,7 @@ func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +225,7 @@ func (src *ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +288,7 @@ func (dst *ByteaArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *ByteaArray) Value() (driver.Value, error) { +func (src ByteaArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/cid.go b/cid.go index 0ed54f44..d27982bd 100644 --- a/cid.go +++ b/cid.go @@ -42,12 +42,12 @@ func (dst *CID) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *CID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*pguint32)(src).EncodeText(ci, buf) +func (src CID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeText(ci, buf) } -func (src *CID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*pguint32)(src).EncodeBinary(ci, buf) +func (src CID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. @@ -56,6 +56,6 @@ func (dst *CID) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *CID) Value() (driver.Value, error) { - return (*pguint32)(src).Value() +func (src CID) Value() (driver.Value, error) { + return (pguint32)(src).Value() } diff --git a/cidr.go b/cidr.go index 519b9cae..9e13a97e 100644 --- a/cidr.go +++ b/cidr.go @@ -22,10 +22,10 @@ func (dst *CIDR) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Inet)(dst).DecodeBinary(ci, src) } -func (src *CIDR) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Inet)(src).EncodeText(ci, buf) +func (src CIDR) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Inet)(src).EncodeText(ci, buf) } -func (src *CIDR) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Inet)(src).EncodeBinary(ci, buf) +func (src CIDR) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Inet)(src).EncodeBinary(ci, buf) } diff --git a/cidr_array.go b/cidr_array.go index 94c07679..234c6aff 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -197,7 +197,7 @@ func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -254,7 +254,7 @@ func (src *CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -317,7 +317,7 @@ func (dst *CIDRArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *CIDRArray) Value() (driver.Value, error) { +func (src CIDRArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/circle.go b/circle.go index 66dec132..9644345c 100644 --- a/circle.go +++ b/circle.go @@ -95,7 +95,7 @@ func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -112,7 +112,7 @@ func (src *Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -146,6 +146,6 @@ func (dst *Circle) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Circle) Value() (driver.Value, error) { +func (src Circle) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/date.go b/date.go index 08ba8c08..8e35b22a 100644 --- a/date.go +++ b/date.go @@ -125,7 +125,7 @@ func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Date) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Date) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -147,7 +147,7 @@ func (src *Date) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, s...), nil } -func (src *Date) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Date) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -195,7 +195,7 @@ func (dst *Date) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Date) Value() (driver.Value, error) { +func (src Date) Value() (driver.Value, error) { switch src.Status { case Present: if src.InfinityModifier != None { diff --git a/date_array.go b/date_array.go index 05070360..69fc3e5e 100644 --- a/date_array.go +++ b/date_array.go @@ -169,7 +169,7 @@ func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -226,7 +226,7 @@ func (src *DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -289,7 +289,7 @@ func (dst *DateArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *DateArray) Value() (driver.Value, error) { +func (src DateArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/enum_array.go b/enum_array.go index 504d513c..f4609169 100644 --- a/enum_array.go +++ b/enum_array.go @@ -124,7 +124,7 @@ func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (src *EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -200,7 +200,7 @@ func (dst *EnumArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *EnumArray) Value() (driver.Value, error) { +func (src EnumArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/ext/satori-uuid/uuid.go b/ext/satori-uuid/uuid.go index 01adea23..9b958b58 100644 --- a/ext/satori-uuid/uuid.go +++ b/ext/satori-uuid/uuid.go @@ -117,7 +117,7 @@ func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return nil } -func (src *UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { +func (src UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: return nil, nil @@ -128,7 +128,7 @@ func (src *UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { return append(buf, src.UUID.String()...), nil } -func (src *UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { +func (src UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: return nil, nil @@ -157,6 +157,6 @@ func (dst *UUID) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *UUID) Value() (driver.Value, error) { +func (src UUID) Value() (driver.Value, error) { return pgtype.EncodeValueText(src) } diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index d8f176a8..c035b15b 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -257,7 +257,7 @@ func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return nil } -func (src *Numeric) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { +func (src Numeric) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: return nil, nil @@ -268,7 +268,7 @@ func (src *Numeric) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) return append(buf, src.Decimal.String()...), nil } -func (src *Numeric) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { +func (src Numeric) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: return nil, nil @@ -306,7 +306,7 @@ func (dst *Numeric) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Numeric) Value() (driver.Value, error) { +func (src Numeric) Value() (driver.Value, error) { switch src.Status { case pgtype.Present: return src.Decimal.Value() diff --git a/float4.go b/float4.go index 0947f36a..3f701dc5 100644 --- a/float4.go +++ b/float4.go @@ -138,7 +138,7 @@ func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Float4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -150,7 +150,7 @@ func (src *Float4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Float4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Float4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -185,7 +185,7 @@ func (dst *Float4) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Float4) Value() (driver.Value, error) { +func (src Float4) Value() (driver.Value, error) { switch src.Status { case Present: return float64(src.Float), nil diff --git a/float4_array.go b/float4_array.go index ef134407..80aff879 100644 --- a/float4_array.go +++ b/float4_array.go @@ -168,7 +168,7 @@ func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +225,7 @@ func (src *Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +288,7 @@ func (dst *Float4Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Float4Array) Value() (driver.Value, error) { +func (src Float4Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/float8.go b/float8.go index 87cf6adb..9c6847c3 100644 --- a/float8.go +++ b/float8.go @@ -128,7 +128,7 @@ func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Float8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -140,7 +140,7 @@ func (src *Float8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Float8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Float8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -175,7 +175,7 @@ func (dst *Float8) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Float8) Value() (driver.Value, error) { +func (src Float8) Value() (driver.Value, error) { switch src.Status { case Present: return src.Float, nil diff --git a/float8_array.go b/float8_array.go index ba63449c..3999cf7d 100644 --- a/float8_array.go +++ b/float8_array.go @@ -168,7 +168,7 @@ func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +225,7 @@ func (src *Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +288,7 @@ func (dst *Float8Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Float8Array) Value() (driver.Value, error) { +func (src Float8Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/generic_binary.go b/generic_binary.go index 2596ecae..5689523e 100644 --- a/generic_binary.go +++ b/generic_binary.go @@ -24,8 +24,8 @@ func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Bytea)(dst).DecodeBinary(ci, src) } -func (src *GenericBinary) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Bytea)(src).EncodeBinary(ci, buf) +func (src GenericBinary) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Bytea)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. @@ -34,6 +34,6 @@ func (dst *GenericBinary) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *GenericBinary) Value() (driver.Value, error) { - return (*Bytea)(src).Value() +func (src GenericBinary) Value() (driver.Value, error) { + return (Bytea)(src).Value() } diff --git a/generic_text.go b/generic_text.go index 0e3db9de..d8890f48 100644 --- a/generic_text.go +++ b/generic_text.go @@ -24,8 +24,8 @@ func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeText(ci, src) } -func (src *GenericText) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeText(ci, buf) +func (src GenericText) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeText(ci, buf) } // Scan implements the database/sql Scanner interface. @@ -34,6 +34,6 @@ func (dst *GenericText) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *GenericText) Value() (driver.Value, error) { - return (*Text)(src).Value() +func (src GenericText) Value() (driver.Value, error) { + return (Text)(src).Value() } diff --git a/hstore.go b/hstore.go index 56af38ee..45b165af 100644 --- a/hstore.go +++ b/hstore.go @@ -151,7 +151,7 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -186,7 +186,7 @@ func (src *Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -426,6 +426,6 @@ func (dst *Hstore) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Hstore) Value() (driver.Value, error) { +func (src Hstore) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/hstore_array.go b/hstore_array.go index 1bdac816..8269fb40 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -168,7 +168,7 @@ func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +225,7 @@ func (src *HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +288,7 @@ func (dst *HstoreArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *HstoreArray) Value() (driver.Value, error) { +func (src HstoreArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/inet.go b/inet.go index 0fb1c418..3c2eda9b 100644 --- a/inet.go +++ b/inet.go @@ -148,7 +148,7 @@ func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -160,7 +160,7 @@ func (src *Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } // EncodeBinary encodes src into w. -func (src *Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -211,6 +211,6 @@ func (dst *Inet) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Inet) Value() (driver.Value, error) { +func (src Inet) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/inet_array.go b/inet_array.go index b31d3588..a6fd419e 100644 --- a/inet_array.go +++ b/inet_array.go @@ -197,7 +197,7 @@ func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -254,7 +254,7 @@ func (src *InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -317,7 +317,7 @@ func (dst *InetArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *InetArray) Value() (driver.Value, error) { +func (src InetArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/int2.go b/int2.go index bbf2952f..f3e01308 100644 --- a/int2.go +++ b/int2.go @@ -133,7 +133,7 @@ func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -144,7 +144,7 @@ func (src *Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil } -func (src *Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -184,7 +184,7 @@ func (dst *Int2) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Int2) Value() (driver.Value, error) { +func (src Int2) Value() (driver.Value, error) { switch src.Status { case Present: return int64(src.Int), nil @@ -195,7 +195,7 @@ func (src *Int2) Value() (driver.Value, error) { } } -func (src *Int2) MarshalJSON() ([]byte, error) { +func (src Int2) MarshalJSON() ([]byte, error) { switch src.Status { case Present: return []byte(strconv.FormatInt(int64(src.Int), 10)), nil diff --git a/int2_array.go b/int2_array.go index afb39513..beea543f 100644 --- a/int2_array.go +++ b/int2_array.go @@ -196,7 +196,7 @@ func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -253,7 +253,7 @@ func (src *Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -316,7 +316,7 @@ func (dst *Int2Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Int2Array) Value() (driver.Value, error) { +func (src Int2Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/int4.go b/int4.go index cc34ce0a..da39b7f0 100644 --- a/int4.go +++ b/int4.go @@ -125,7 +125,7 @@ func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -136,7 +136,7 @@ func (src *Int4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil } -func (src *Int4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -176,7 +176,7 @@ func (dst *Int4) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Int4) Value() (driver.Value, error) { +func (src Int4) Value() (driver.Value, error) { switch src.Status { case Present: return int64(src.Int), nil @@ -187,7 +187,7 @@ func (src *Int4) Value() (driver.Value, error) { } } -func (src *Int4) MarshalJSON() ([]byte, error) { +func (src Int4) MarshalJSON() ([]byte, error) { switch src.Status { case Present: return []byte(strconv.FormatInt(int64(src.Int), 10)), nil diff --git a/int4_array.go b/int4_array.go index bd0babb9..83ee4c26 100644 --- a/int4_array.go +++ b/int4_array.go @@ -215,7 +215,7 @@ func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -272,7 +272,7 @@ func (src *Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -335,7 +335,7 @@ func (dst *Int4Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Int4Array) Value() (driver.Value, error) { +func (src Int4Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/int8.go b/int8.go index 153f1f7d..7f410b15 100644 --- a/int8.go +++ b/int8.go @@ -117,7 +117,7 @@ func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -128,7 +128,7 @@ func (src *Int8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, strconv.FormatInt(src.Int, 10)...), nil } -func (src *Int8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -162,7 +162,7 @@ func (dst *Int8) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Int8) Value() (driver.Value, error) { +func (src Int8) Value() (driver.Value, error) { switch src.Status { case Present: return int64(src.Int), nil @@ -173,7 +173,7 @@ func (src *Int8) Value() (driver.Value, error) { } } -func (src *Int8) MarshalJSON() ([]byte, error) { +func (src Int8) MarshalJSON() ([]byte, error) { switch src.Status { case Present: return []byte(strconv.FormatInt(src.Int, 10)), nil diff --git a/int8_array.go b/int8_array.go index 392fd47e..f118bc83 100644 --- a/int8_array.go +++ b/int8_array.go @@ -196,7 +196,7 @@ func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -253,7 +253,7 @@ func (src *Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -316,7 +316,7 @@ func (dst *Int8Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Int8Array) Value() (driver.Value, error) { +func (src Int8Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/interval.go b/interval.go index a7edca83..bb19f956 100644 --- a/interval.go +++ b/interval.go @@ -179,7 +179,7 @@ func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -213,7 +213,7 @@ func (src *Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } // EncodeBinary encodes src into w. -func (src *Interval) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Interval) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -246,6 +246,6 @@ func (dst *Interval) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Interval) Value() (driver.Value, error) { +func (src Interval) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/json.go b/json.go index 49ff7a6c..592dfa31 100644 --- a/json.go +++ b/json.go @@ -120,7 +120,7 @@ func (dst *JSON) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (src *JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -131,7 +131,7 @@ func (src *JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, src.Bytes...), nil } -func (src *JSON) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src JSON) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return src.EncodeText(ci, buf) } @@ -155,7 +155,7 @@ func (dst *JSON) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *JSON) Value() (driver.Value, error) { +func (src JSON) Value() (driver.Value, error) { switch src.Status { case Present: return src.Bytes, nil diff --git a/jsonb.go b/jsonb.go index 065e4e21..c70be144 100644 --- a/jsonb.go +++ b/jsonb.go @@ -43,11 +43,11 @@ func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { } -func (src *JSONB) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*JSON)(src).EncodeText(ci, buf) +func (src JSONB) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (JSON)(src).EncodeText(ci, buf) } -func (src *JSONB) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src JSONB) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -65,6 +65,6 @@ func (dst *JSONB) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *JSONB) Value() (driver.Value, error) { - return (*JSON)(src).Value() +func (src JSONB) Value() (driver.Value, error) { + return (JSON)(src).Value() } diff --git a/line.go b/line.go index 617ee456..61477ad9 100644 --- a/line.go +++ b/line.go @@ -93,7 +93,7 @@ func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -110,7 +110,7 @@ func (src *Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Line) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Line) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -144,6 +144,6 @@ func (dst *Line) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Line) Value() (driver.Value, error) { +func (src Line) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/lseg.go b/lseg.go index b8d6e322..822b7bf4 100644 --- a/lseg.go +++ b/lseg.go @@ -108,7 +108,7 @@ func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -126,7 +126,7 @@ func (src *Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Lseg) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Lseg) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -161,6 +161,6 @@ func (dst *Lseg) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Lseg) Value() (driver.Value, error) { +func (src Lseg) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/macaddr.go b/macaddr.go index 25ffc48e..29c60440 100644 --- a/macaddr.go +++ b/macaddr.go @@ -107,7 +107,7 @@ func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Macaddr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Macaddr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -119,7 +119,7 @@ func (src *Macaddr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } // EncodeBinary encodes src into w. -func (src *Macaddr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Macaddr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -150,6 +150,6 @@ func (dst *Macaddr) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Macaddr) Value() (driver.Value, error) { +func (src Macaddr) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/macaddr_array.go b/macaddr_array.go index 0b791104..7c62da2b 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -169,7 +169,7 @@ func (dst *MacaddrArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -226,7 +226,7 @@ func (src *MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *MacaddrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src MacaddrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -289,7 +289,7 @@ func (dst *MacaddrArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *MacaddrArray) Value() (driver.Value, error) { +func (src MacaddrArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/name.go b/name.go index af064a82..753a074a 100644 --- a/name.go +++ b/name.go @@ -39,12 +39,12 @@ func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } -func (src *Name) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeText(ci, buf) +func (src Name) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeText(ci, buf) } -func (src *Name) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeBinary(ci, buf) +func (src Name) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. @@ -53,6 +53,6 @@ func (dst *Name) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Name) Value() (driver.Value, error) { - return (*Text)(src).Value() +func (src Name) Value() (driver.Value, error) { + return (Text)(src).Value() } diff --git a/numeric.go b/numeric.go index 45854e70..554fb582 100644 --- a/numeric.go +++ b/numeric.go @@ -455,7 +455,7 @@ func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { return accum, rp, digits } -func (src *Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -469,7 +469,7 @@ func (src *Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -580,7 +580,7 @@ func (dst *Numeric) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Numeric) Value() (driver.Value, error) { +func (src Numeric) Value() (driver.Value, error) { switch src.Status { case Present: buf, err := src.EncodeText(nil, nil) diff --git a/numeric_array.go b/numeric_array.go index 1e8c5cda..8757b14d 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -252,7 +252,7 @@ func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -309,7 +309,7 @@ func (src *NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -372,7 +372,7 @@ func (dst *NumericArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *NumericArray) Value() (driver.Value, error) { +func (src NumericArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/oid_value.go b/oid_value.go index 7eae4bf1..619681a5 100644 --- a/oid_value.go +++ b/oid_value.go @@ -36,12 +36,12 @@ func (dst *OIDValue) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *OIDValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*pguint32)(src).EncodeText(ci, buf) +func (src OIDValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeText(ci, buf) } -func (src *OIDValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*pguint32)(src).EncodeBinary(ci, buf) +func (src OIDValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. @@ -50,6 +50,6 @@ func (dst *OIDValue) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *OIDValue) Value() (driver.Value, error) { - return (*pguint32)(src).Value() +func (src OIDValue) Value() (driver.Value, error) { + return (pguint32)(src).Value() } diff --git a/path.go b/path.go index a4c6af77..484c9174 100644 --- a/path.go +++ b/path.go @@ -116,7 +116,7 @@ func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Path) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Path) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -147,7 +147,7 @@ func (src *Path) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, endByte), nil } -func (src *Path) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Path) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -191,6 +191,6 @@ func (dst *Path) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Path) Value() (driver.Value, error) { +func (src Path) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/pguint32.go b/pguint32.go index 21da9664..546d6f8f 100644 --- a/pguint32.go +++ b/pguint32.go @@ -102,7 +102,7 @@ func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *pguint32) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src pguint32) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -113,7 +113,7 @@ func (src *pguint32) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, strconv.FormatUint(uint64(src.Uint), 10)...), nil } -func (src *pguint32) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src pguint32) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -150,7 +150,7 @@ func (dst *pguint32) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *pguint32) Value() (driver.Value, error) { +func (src pguint32) Value() (driver.Value, error) { switch src.Status { case Present: return int64(src.Uint), nil diff --git a/point.go b/point.go index 89f2359b..bb7daa24 100644 --- a/point.go +++ b/point.go @@ -90,7 +90,7 @@ func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Point) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Point) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -104,7 +104,7 @@ func (src *Point) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { )...), nil } -func (src *Point) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Point) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -137,6 +137,6 @@ func (dst *Point) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Point) Value() (driver.Value, error) { +func (src Point) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/polygon.go b/polygon.go index e739c71b..7805604b 100644 --- a/polygon.go +++ b/polygon.go @@ -111,7 +111,7 @@ func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Polygon) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Polygon) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -134,7 +134,7 @@ func (src *Polygon) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, ')'), nil } -func (src *Polygon) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Polygon) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -172,6 +172,6 @@ func (dst *Polygon) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Polygon) Value() (driver.Value, error) { +func (src Polygon) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/qchar.go b/qchar.go index 5e77dc38..8a316d9b 100644 --- a/qchar.go +++ b/qchar.go @@ -134,7 +134,7 @@ func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *QChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src QChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil diff --git a/text.go b/text.go index 4d4e6bb4..d13a9ba4 100644 --- a/text.go +++ b/text.go @@ -92,7 +92,7 @@ func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (src *Text) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Text) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -103,7 +103,7 @@ func (src *Text) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, src.String...), nil } -func (src *Text) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Text) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return src.EncodeText(ci, buf) } @@ -127,7 +127,7 @@ func (dst *Text) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Text) Value() (driver.Value, error) { +func (src Text) Value() (driver.Value, error) { switch src.Status { case Present: return src.String, nil @@ -138,7 +138,7 @@ func (src *Text) Value() (driver.Value, error) { } } -func (src *Text) MarshalJSON() ([]byte, error) { +func (src Text) MarshalJSON() ([]byte, error) { switch src.Status { case Present: return json.Marshal(src.String) diff --git a/text_array.go b/text_array.go index b590972e..fca36ec8 100644 --- a/text_array.go +++ b/text_array.go @@ -168,7 +168,7 @@ func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +225,7 @@ func (src *TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +288,7 @@ func (dst *TextArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *TextArray) Value() (driver.Value, error) { +func (src TextArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/tid.go b/tid.go index ff788b84..08f5c047 100644 --- a/tid.go +++ b/tid.go @@ -94,7 +94,7 @@ func (dst *TID) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -106,7 +106,7 @@ func (src *TID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *TID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -139,6 +139,6 @@ func (dst *TID) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *TID) Value() (driver.Value, error) { +func (src TID) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/timestamp.go b/timestamp.go index 40dfdac8..01c38a0a 100644 --- a/timestamp.go +++ b/timestamp.go @@ -136,7 +136,7 @@ func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { // EncodeText writes the text encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src *Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -163,7 +163,7 @@ func (src *Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src *Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -211,7 +211,7 @@ func (dst *Timestamp) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Timestamp) Value() (driver.Value, error) { +func (src Timestamp) Value() (driver.Value, error) { switch src.Status { case Present: if src.InfinityModifier != None { diff --git a/timestamp_array.go b/timestamp_array.go index 95f76639..204b22eb 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -169,7 +169,7 @@ func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -226,7 +226,7 @@ func (src *TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) return buf, nil } -func (src *TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -289,7 +289,7 @@ func (dst *TimestampArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *TimestampArray) Value() (driver.Value, error) { +func (src TimestampArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/timestamptz.go b/timestamptz.go index 752c1818..9af39b16 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -140,7 +140,7 @@ func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Timestamptz) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Timestamptz) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -162,7 +162,7 @@ func (src *Timestamptz) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, s...), nil } -func (src *Timestamptz) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Timestamptz) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -207,7 +207,7 @@ func (dst *Timestamptz) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Timestamptz) Value() (driver.Value, error) { +func (src Timestamptz) Value() (driver.Value, error) { switch src.Status { case Present: if src.InfinityModifier != None { diff --git a/timestamptz_array.go b/timestamptz_array.go index 7fe60d50..9bef64c6 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -169,7 +169,7 @@ func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -226,7 +226,7 @@ func (src *TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error return buf, nil } -func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -289,7 +289,7 @@ func (dst *TimestamptzArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *TimestamptzArray) Value() (driver.Value, error) { +func (src TimestamptzArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/unknown.go b/unknown.go index 567831d7..2dca0f87 100644 --- a/unknown.go +++ b/unknown.go @@ -39,6 +39,6 @@ func (dst *Unknown) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Unknown) Value() (driver.Value, error) { - return (*Text)(src).Value() +func (src Unknown) Value() (driver.Value, error) { + return (Text)(src).Value() } diff --git a/uuid.go b/uuid.go index 5dd10d89..ba999a06 100644 --- a/uuid.go +++ b/uuid.go @@ -139,7 +139,7 @@ func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -150,7 +150,7 @@ func (src *UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, encodeUUID(src.Bytes)...), nil } -func (src *UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -181,6 +181,6 @@ func (dst *UUID) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *UUID) Value() (driver.Value, error) { +func (src UUID) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/uuid_array.go b/uuid_array.go index 1d28ee59..c3f18882 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -224,7 +224,7 @@ func (dst *UUIDArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -281,7 +281,7 @@ func (src *UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -344,7 +344,7 @@ func (dst *UUIDArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *UUIDArray) Value() (driver.Value, error) { +func (src UUIDArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/varbit.go b/varbit.go index fe4db33d..019fff8a 100644 --- a/varbit.go +++ b/varbit.go @@ -75,7 +75,7 @@ func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Varbit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Varbit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -96,7 +96,7 @@ func (src *Varbit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Varbit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Varbit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -128,6 +128,6 @@ func (dst *Varbit) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Varbit) Value() (driver.Value, error) { +func (src Varbit) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/varchar.go b/varchar.go index 6be1a035..58de1097 100644 --- a/varchar.go +++ b/varchar.go @@ -31,12 +31,12 @@ func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } -func (src *Varchar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeText(ci, buf) +func (src Varchar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeText(ci, buf) } -func (src *Varchar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeBinary(ci, buf) +func (src Varchar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. @@ -45,12 +45,12 @@ func (dst *Varchar) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Varchar) Value() (driver.Value, error) { - return (*Text)(src).Value() +func (src Varchar) Value() (driver.Value, error) { + return (Text)(src).Value() } -func (src *Varchar) MarshalJSON() ([]byte, error) { - return (*Text)(src).MarshalJSON() +func (src Varchar) MarshalJSON() ([]byte, error) { + return (Text)(src).MarshalJSON() } func (dst *Varchar) UnmarshalJSON(b []byte) error { diff --git a/varchar_array.go b/varchar_array.go index 6aa92337..1e60c344 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -168,7 +168,7 @@ func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +225,7 @@ func (src *VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +288,7 @@ func (dst *VarcharArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *VarcharArray) Value() (driver.Value, error) { +func (src VarcharArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/xid.go b/xid.go index f66f5367..80ebf0e0 100644 --- a/xid.go +++ b/xid.go @@ -45,12 +45,12 @@ func (dst *XID) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *XID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*pguint32)(src).EncodeText(ci, buf) +func (src XID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeText(ci, buf) } -func (src *XID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*pguint32)(src).EncodeBinary(ci, buf) +func (src XID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. @@ -59,6 +59,6 @@ func (dst *XID) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *XID) Value() (driver.Value, error) { - return (*pguint32)(src).Value() +func (src XID) Value() (driver.Value, error) { + return (pguint32)(src).Value() } From 76538434cf13478c1e52d991f5d707430bdb6828 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 27 Aug 2019 21:13:45 -0500 Subject: [PATCH 0329/1158] MarshalJSON should be defined on T not *T Otherwise "%v" format would be used by json.Marshal(T). --- backend_key_data.go | 2 +- bind.go | 2 +- bind_complete.go | 2 +- close.go | 2 +- close_complete.go | 2 +- command_complete.go | 2 +- copy_both_response.go | 2 +- copy_data.go | 2 +- copy_done.go | 2 +- copy_fail.go | 2 +- copy_in_response.go | 2 +- copy_out_response.go | 2 +- data_row.go | 2 +- describe.go | 2 +- empty_query_response.go | 2 +- execute.go | 2 +- flush.go | 2 +- function_call_response.go | 2 +- no_data.go | 2 +- notification_response.go | 2 +- parameter_description.go | 2 +- parameter_status.go | 2 +- parse.go | 2 +- parse_complete.go | 2 +- password_message.go | 2 +- portal_suspended.go | 2 +- query.go | 2 +- ready_for_query.go | 2 +- row_description.go | 2 +- sasl_initial_response.go | 2 +- sasl_response.go | 2 +- startup_message.go | 2 +- sync.go | 2 +- terminate.go | 2 +- 34 files changed, 34 insertions(+), 34 deletions(-) diff --git a/backend_key_data.go b/backend_key_data.go index b775d689..ca20dd25 100644 --- a/backend_key_data.go +++ b/backend_key_data.go @@ -38,7 +38,7 @@ func (src *BackendKeyData) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *BackendKeyData) MarshalJSON() ([]byte, error) { +func (src BackendKeyData) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ProcessID uint32 diff --git a/bind.go b/bind.go index 67d20b5d..65b4c1d6 100644 --- a/bind.go +++ b/bind.go @@ -144,7 +144,7 @@ func (src *Bind) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *Bind) MarshalJSON() ([]byte, error) { +func (src Bind) MarshalJSON() ([]byte, error) { formattedParameters := make([]map[string]string, len(src.Parameters)) for i, p := range src.Parameters { if p == nil { diff --git a/bind_complete.go b/bind_complete.go index fc9d317a..3be256c8 100644 --- a/bind_complete.go +++ b/bind_complete.go @@ -25,7 +25,7 @@ func (src *BindComplete) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *BindComplete) MarshalJSON() ([]byte, error) { +func (src BindComplete) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/close.go b/close.go index 349a319d..38296909 100644 --- a/close.go +++ b/close.go @@ -51,7 +51,7 @@ func (src *Close) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *Close) MarshalJSON() ([]byte, error) { +func (src Close) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ObjectType string diff --git a/close_complete.go b/close_complete.go index b4982207..1d7b8f08 100644 --- a/close_complete.go +++ b/close_complete.go @@ -25,7 +25,7 @@ func (src *CloseComplete) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *CloseComplete) MarshalJSON() ([]byte, error) { +func (src CloseComplete) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/command_complete.go b/command_complete.go index 87fcddf6..b5106fda 100644 --- a/command_complete.go +++ b/command_complete.go @@ -42,7 +42,7 @@ func (src *CommandComplete) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *CommandComplete) MarshalJSON() ([]byte, error) { +func (src CommandComplete) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string CommandTag string diff --git a/copy_both_response.go b/copy_both_response.go index b037a197..2d58f820 100644 --- a/copy_both_response.go +++ b/copy_both_response.go @@ -59,7 +59,7 @@ func (src *CopyBothResponse) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *CopyBothResponse) MarshalJSON() ([]byte, error) { +func (src CopyBothResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ColumnFormatCodes []uint16 diff --git a/copy_data.go b/copy_data.go index 317710ac..7d6002fe 100644 --- a/copy_data.go +++ b/copy_data.go @@ -33,7 +33,7 @@ func (src *CopyData) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *CopyData) MarshalJSON() ([]byte, error) { +func (src CopyData) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Data string diff --git a/copy_done.go b/copy_done.go index 7612350a..d8b6e5d7 100644 --- a/copy_done.go +++ b/copy_done.go @@ -26,7 +26,7 @@ func (src *CopyDone) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *CopyDone) MarshalJSON() ([]byte, error) { +func (src CopyDone) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/copy_fail.go b/copy_fail.go index b12d7ba0..78ff0b30 100644 --- a/copy_fail.go +++ b/copy_fail.go @@ -42,7 +42,7 @@ func (src *CopyFail) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *CopyFail) MarshalJSON() ([]byte, error) { +func (src CopyFail) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Message string diff --git a/copy_in_response.go b/copy_in_response.go index d28baa33..4439a032 100644 --- a/copy_in_response.go +++ b/copy_in_response.go @@ -59,7 +59,7 @@ func (src *CopyInResponse) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *CopyInResponse) MarshalJSON() ([]byte, error) { +func (src CopyInResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ColumnFormatCodes []uint16 diff --git a/copy_out_response.go b/copy_out_response.go index 1d3c2364..8538dfc7 100644 --- a/copy_out_response.go +++ b/copy_out_response.go @@ -60,7 +60,7 @@ func (src *CopyOutResponse) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *CopyOutResponse) MarshalJSON() ([]byte, error) { +func (src CopyOutResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ColumnFormatCodes []uint16 diff --git a/data_row.go b/data_row.go index 9d6a9f1f..d908e7b2 100644 --- a/data_row.go +++ b/data_row.go @@ -85,7 +85,7 @@ func (src *DataRow) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *DataRow) MarshalJSON() ([]byte, error) { +func (src DataRow) MarshalJSON() ([]byte, error) { formattedValues := make([]map[string]string, len(src.Values)) for i, v := range src.Values { if v == nil { diff --git a/describe.go b/describe.go index d3fb5b09..308f582e 100644 --- a/describe.go +++ b/describe.go @@ -51,7 +51,7 @@ func (src *Describe) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *Describe) MarshalJSON() ([]byte, error) { +func (src Describe) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ObjectType string diff --git a/empty_query_response.go b/empty_query_response.go index 1bec52e2..2b85e744 100644 --- a/empty_query_response.go +++ b/empty_query_response.go @@ -25,7 +25,7 @@ func (src *EmptyQueryResponse) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) { +func (src EmptyQueryResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/execute.go b/execute.go index 32269857..8bae6133 100644 --- a/execute.go +++ b/execute.go @@ -52,7 +52,7 @@ func (src *Execute) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *Execute) MarshalJSON() ([]byte, error) { +func (src Execute) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Portal string diff --git a/flush.go b/flush.go index e7bc7e43..2725f689 100644 --- a/flush.go +++ b/flush.go @@ -25,7 +25,7 @@ func (src *Flush) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *Flush) MarshalJSON() ([]byte, error) { +func (src Flush) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/function_call_response.go b/function_call_response.go index 72bb907c..5cc2d4d2 100644 --- a/function_call_response.go +++ b/function_call_response.go @@ -57,7 +57,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) { +func (src FunctionCallResponse) MarshalJSON() ([]byte, error) { var formattedValue map[string]string var hasNonPrintable bool for _, b := range src.Result { diff --git a/no_data.go b/no_data.go index 172d0dc1..d8f85d38 100644 --- a/no_data.go +++ b/no_data.go @@ -25,7 +25,7 @@ func (src *NoData) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *NoData) MarshalJSON() ([]byte, error) { +func (src NoData) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/notification_response.go b/notification_response.go index 33170f66..cd83c5ba 100644 --- a/notification_response.go +++ b/notification_response.go @@ -57,7 +57,7 @@ func (src *NotificationResponse) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *NotificationResponse) MarshalJSON() ([]byte, error) { +func (src NotificationResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string PID uint32 diff --git a/parameter_description.go b/parameter_description.go index a43e802e..e28965c8 100644 --- a/parameter_description.go +++ b/parameter_description.go @@ -55,7 +55,7 @@ func (src *ParameterDescription) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *ParameterDescription) MarshalJSON() ([]byte, error) { +func (src ParameterDescription) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ParameterOIDs []uint32 diff --git a/parameter_status.go b/parameter_status.go index 4385fe99..c4021d92 100644 --- a/parameter_status.go +++ b/parameter_status.go @@ -53,7 +53,7 @@ func (src *ParameterStatus) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (ps *ParameterStatus) MarshalJSON() ([]byte, error) { +func (ps ParameterStatus) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Name string diff --git a/parse.go b/parse.go index d0bbf865..723885d4 100644 --- a/parse.go +++ b/parse.go @@ -73,7 +73,7 @@ func (src *Parse) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *Parse) MarshalJSON() ([]byte, error) { +func (src Parse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Name string diff --git a/parse_complete.go b/parse_complete.go index c2d3a34d..92c9498b 100644 --- a/parse_complete.go +++ b/parse_complete.go @@ -25,7 +25,7 @@ func (src *ParseComplete) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *ParseComplete) MarshalJSON() ([]byte, error) { +func (src ParseComplete) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/password_message.go b/password_message.go index b01316e9..4b68b31a 100644 --- a/password_message.go +++ b/password_message.go @@ -40,7 +40,7 @@ func (src *PasswordMessage) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *PasswordMessage) MarshalJSON() ([]byte, error) { +func (src PasswordMessage) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Password string diff --git a/portal_suspended.go b/portal_suspended.go index 5603d95e..1a9e7bfb 100644 --- a/portal_suspended.go +++ b/portal_suspended.go @@ -25,7 +25,7 @@ func (src *PortalSuspended) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *PortalSuspended) MarshalJSON() ([]byte, error) { +func (src PortalSuspended) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/query.go b/query.go index 17377dfb..41c93b4a 100644 --- a/query.go +++ b/query.go @@ -39,7 +39,7 @@ func (src *Query) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *Query) MarshalJSON() ([]byte, error) { +func (src Query) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string String string diff --git a/ready_for_query.go b/ready_for_query.go index 65f7d8c1..879afe39 100644 --- a/ready_for_query.go +++ b/ready_for_query.go @@ -29,7 +29,7 @@ func (src *ReadyForQuery) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *ReadyForQuery) MarshalJSON() ([]byte, error) { +func (src ReadyForQuery) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string TxStatus string diff --git a/row_description.go b/row_description.go index 87479188..2745fa43 100644 --- a/row_description.go +++ b/row_description.go @@ -102,7 +102,7 @@ func (src *RowDescription) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *RowDescription) MarshalJSON() ([]byte, error) { +func (src RowDescription) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Fields []FieldDescription diff --git a/sasl_initial_response.go b/sasl_initial_response.go index b9459e16..0bf8a9e5 100644 --- a/sasl_initial_response.go +++ b/sasl_initial_response.go @@ -56,7 +56,7 @@ func (src *SASLInitialResponse) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *SASLInitialResponse) MarshalJSON() ([]byte, error) { +func (src SASLInitialResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string AuthMechanism string diff --git a/sasl_response.go b/sasl_response.go index ef893437..21be6d75 100644 --- a/sasl_response.go +++ b/sasl_response.go @@ -32,7 +32,7 @@ func (src *SASLResponse) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *SASLResponse) MarshalJSON() ([]byte, error) { +func (src SASLResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Data string diff --git a/startup_message.go b/startup_message.go index d9f04d17..0c5c961d 100644 --- a/startup_message.go +++ b/startup_message.go @@ -89,7 +89,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *StartupMessage) MarshalJSON() ([]byte, error) { +func (src StartupMessage) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ProtocolVersion uint32 diff --git a/sync.go b/sync.go index a058e8c9..5db8e07a 100644 --- a/sync.go +++ b/sync.go @@ -25,7 +25,7 @@ func (src *Sync) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *Sync) MarshalJSON() ([]byte, error) { +func (src Sync) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/terminate.go b/terminate.go index 6c9d5b1a..135191ea 100644 --- a/terminate.go +++ b/terminate.go @@ -25,7 +25,7 @@ func (src *Terminate) Encode(dst []byte) []byte { } // MarshalJSON implements encoding/json.Marshaler. -func (src *Terminate) MarshalJSON() ([]byte, error) { +func (src Terminate) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ From 1ba5dcbe01a089d1f2bd3822e7bc7a0c913c44cc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 11:48:01 -0500 Subject: [PATCH 0330/1158] Support SSLRequest and CancelRequest --- backend.go | 38 +++++++++++++++++++++++------- cancel_request.go | 58 ++++++++++++++++++++++++++++++++++++++++++++++ ssl_request.go | 49 +++++++++++++++++++++++++++++++++++++++ startup_message.go | 9 +------ 4 files changed, 138 insertions(+), 16 deletions(-) create mode 100644 cancel_request.go create mode 100644 ssl_request.go diff --git a/backend.go b/backend.go index 2e2f5eea..be8f3bdb 100644 --- a/backend.go +++ b/backend.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "fmt" "io" "github.com/pkg/errors" @@ -14,6 +15,7 @@ type Backend struct { // Frontend message flyweights bind Bind + cancelRequest CancelRequest _close Close copyFail CopyFail describe Describe @@ -22,6 +24,7 @@ type Backend struct { parse Parse passwordMessage PasswordMessage query Query + sslRequest SSLRequest startupMessage StartupMessage sync Sync terminate Terminate @@ -42,9 +45,10 @@ func (b *Backend) Send(msg BackendMessage) error { return err } -// ReceiveStartupMessage receives the initial startup message. This method is used of the normal Receive method -// because StartupMessage and SSLRequest are "special" and do not include the message type as the first byte. -func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { +// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method +// because the initial connection message is "special" and does not include the message type as the first byte. This +// will return either a StartupMessage, SSLRequest, or CancelRequest. +func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { buf, err := b.cr.Next(4) if err != nil { return nil, err @@ -56,12 +60,30 @@ func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { return nil, err } - err = b.startupMessage.Decode(buf) - if err != nil { - return nil, err - } + code := binary.BigEndian.Uint32(buf) - return &b.startupMessage, nil + switch code { + case ProtocolVersionNumber: + err = b.startupMessage.Decode(buf) + if err != nil { + return nil, err + } + return &b.startupMessage, nil + case sslRequestNumber: + err = b.sslRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.sslRequest, nil + case cancelRequestCode: + err = b.cancelRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.cancelRequest, nil + default: + return nil, fmt.Errorf("unknown startup message code: %d", code) + } } // Receive receives a message from the frontend. diff --git a/cancel_request.go b/cancel_request.go new file mode 100644 index 00000000..ec1d8606 --- /dev/null +++ b/cancel_request.go @@ -0,0 +1,58 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgio" + "github.com/pkg/errors" +) + +const cancelRequestCode = 80877102 + +type CancelRequest struct { + ProcessID uint32 + SecretKey uint32 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CancelRequest) Frontend() {} + +func (dst *CancelRequest) Decode(src []byte) error { + if len(src) != 12 { + return errors.Errorf("bad cancel request size") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != cancelRequestCode { + return errors.Errorf("bad cancel request code") + } + + dst.ProcessID = binary.BigEndian.Uint32(src[4:]) + dst.SecretKey = binary.BigEndian.Uint32(src[8:]) + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *CancelRequest) Encode(dst []byte) []byte { + dst = pgio.AppendInt32(dst, 16) + dst = pgio.AppendInt32(dst, cancelRequestCode) + dst = pgio.AppendUint32(dst, src.ProcessID) + dst = pgio.AppendUint32(dst, src.SecretKey) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CancelRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProcessID uint32 + SecretKey uint32 + }{ + Type: "CancelRequest", + ProcessID: src.ProcessID, + SecretKey: src.SecretKey, + }) +} diff --git a/ssl_request.go b/ssl_request.go new file mode 100644 index 00000000..2f4b378a --- /dev/null +++ b/ssl_request.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgio" + "github.com/pkg/errors" +) + +const sslRequestNumber = 80877103 + +type SSLRequest struct { +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*SSLRequest) Frontend() {} + +func (dst *SSLRequest) Decode(src []byte) error { + if len(src) < 4 { + return errors.Errorf("ssl request too short") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != sslRequestNumber { + return errors.Errorf("bad ssl request code") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *SSLRequest) Encode(dst []byte) []byte { + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendInt32(dst, sslRequestNumber) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src SSLRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "SSLRequest", + }) +} diff --git a/startup_message.go b/startup_message.go index 0c5c961d..5be42500 100644 --- a/startup_message.go +++ b/startup_message.go @@ -9,10 +9,7 @@ import ( "github.com/pkg/errors" ) -const ( - ProtocolVersionNumber = 196608 // 3.0 - sslRequestNumber = 80877103 -) +const ProtocolVersionNumber = 196608 // 3.0 type StartupMessage struct { ProtocolVersion uint32 @@ -32,10 +29,6 @@ func (dst *StartupMessage) Decode(src []byte) error { dst.ProtocolVersion = binary.BigEndian.Uint32(src) rp := 4 - if dst.ProtocolVersion == sslRequestNumber { - return errors.Errorf("can't handle ssl connection request") - } - if dst.ProtocolVersion != ProtocolVersionNumber { return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) } From 2bc8f2e6afa5b2c92854cfff6e025fcadcd07fb1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 11:53:26 -0500 Subject: [PATCH 0331/1158] Remove pkg/errors package --- authentication.go | 4 ++-- backend.go | 4 +--- cancel_request.go | 6 +++--- frontend.go | 5 ++--- frontend_test.go | 3 +-- go.mod | 2 -- go.sum | 4 ---- ssl_request.go | 6 +++--- startup_message.go | 9 +++++---- 9 files changed, 17 insertions(+), 26 deletions(-) diff --git a/authentication.go b/authentication.go index bc654c4f..5ff05d96 100644 --- a/authentication.go +++ b/authentication.go @@ -3,9 +3,9 @@ package pgproto3 import ( "bytes" "encoding/binary" + "fmt" "github.com/jackc/pgio" - "github.com/pkg/errors" ) // Authentication message type constants. @@ -60,7 +60,7 @@ func (dst *Authentication) Decode(src []byte) error { case AuthTypeSASLContinue, AuthTypeSASLFinal: dst.SASLData = src[4:] default: - return errors.Errorf("unknown authentication type: %d", dst.Type) + return fmt.Errorf("unknown authentication type: %d", dst.Type) } return nil diff --git a/backend.go b/backend.go index be8f3bdb..176a25ba 100644 --- a/backend.go +++ b/backend.go @@ -4,8 +4,6 @@ import ( "encoding/binary" "fmt" "io" - - "github.com/pkg/errors" ) // Backend acts as a server for the PostgreSQL wire protocol version 3. @@ -124,7 +122,7 @@ func (b *Backend) Receive() (FrontendMessage, error) { case 'X': msg = &b.terminate default: - return nil, errors.Errorf("unknown message type: %c", b.msgType) + return nil, fmt.Errorf("unknown message type: %c", b.msgType) } msgBody, err := b.cr.Next(b.bodyLen) diff --git a/cancel_request.go b/cancel_request.go index ec1d8606..942e404b 100644 --- a/cancel_request.go +++ b/cancel_request.go @@ -3,9 +3,9 @@ package pgproto3 import ( "encoding/binary" "encoding/json" + "errors" "github.com/jackc/pgio" - "github.com/pkg/errors" ) const cancelRequestCode = 80877102 @@ -20,13 +20,13 @@ func (*CancelRequest) Frontend() {} func (dst *CancelRequest) Decode(src []byte) error { if len(src) != 12 { - return errors.Errorf("bad cancel request size") + return errors.New("bad cancel request size") } requestCode := binary.BigEndian.Uint32(src) if requestCode != cancelRequestCode { - return errors.Errorf("bad cancel request code") + return errors.New("bad cancel request code") } dst.ProcessID = binary.BigEndian.Uint32(src[4:]) diff --git a/frontend.go b/frontend.go index bcf796e1..f6ebaf43 100644 --- a/frontend.go +++ b/frontend.go @@ -2,9 +2,8 @@ package pgproto3 import ( "encoding/binary" + "fmt" "io" - - "github.com/pkg/errors" ) // Frontend acts as a client for the PostgreSQL wire protocol version 3. @@ -115,7 +114,7 @@ func (b *Frontend) Receive() (BackendMessage, error) { case 'Z': msg = &b.readyForQuery default: - return nil, errors.Errorf("unknown message type: %c", b.msgType) + return nil, fmt.Errorf("unknown message type: %c", b.msgType) } msgBody, err := b.cr.Next(b.bodyLen) diff --git a/frontend_test.go b/frontend_test.go index 2d5c8de7..1d6a07ae 100644 --- a/frontend_test.go +++ b/frontend_test.go @@ -1,10 +1,9 @@ package pgproto3_test import ( + "errors" "testing" - "github.com/pkg/errors" - "github.com/jackc/pgproto3/v2" ) diff --git a/go.mod b/go.mod index 800f6043..4821676a 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,6 @@ module github.com/jackc/pgproto3/v2 go 1.12 require ( - github.com/jackc/chunkreader v1.0.0 github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 - github.com/pkg/errors v0.8.1 ) diff --git a/go.sum b/go.sum index 5dd456ad..36160794 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,4 @@ -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= -github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/ssl_request.go b/ssl_request.go index 2f4b378a..96ce489e 100644 --- a/ssl_request.go +++ b/ssl_request.go @@ -3,9 +3,9 @@ package pgproto3 import ( "encoding/binary" "encoding/json" + "errors" "github.com/jackc/pgio" - "github.com/pkg/errors" ) const sslRequestNumber = 80877103 @@ -18,13 +18,13 @@ func (*SSLRequest) Frontend() {} func (dst *SSLRequest) Decode(src []byte) error { if len(src) < 4 { - return errors.Errorf("ssl request too short") + return errors.New("ssl request too short") } requestCode := binary.BigEndian.Uint32(src) if requestCode != sslRequestNumber { - return errors.Errorf("bad ssl request code") + return errors.New("bad ssl request code") } return nil diff --git a/startup_message.go b/startup_message.go index 5be42500..5f1cd24f 100644 --- a/startup_message.go +++ b/startup_message.go @@ -4,9 +4,10 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "fmt" "github.com/jackc/pgio" - "github.com/pkg/errors" ) const ProtocolVersionNumber = 196608 // 3.0 @@ -23,14 +24,14 @@ func (*StartupMessage) Frontend() {} // type identifier and 4 byte message length. func (dst *StartupMessage) Decode(src []byte) error { if len(src) < 4 { - return errors.Errorf("startup message too short") + return errors.New("startup message too short") } dst.ProtocolVersion = binary.BigEndian.Uint32(src) rp := 4 if dst.ProtocolVersion != ProtocolVersionNumber { - return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) + return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) } dst.Parameters = make(map[string]string) @@ -53,7 +54,7 @@ func (dst *StartupMessage) Decode(src []byte) error { if len(src[rp:]) == 1 { if src[rp] != 0 { - return errors.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) + return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) } break } From 6bba3c4810ce93171830696896238f19911b7ca3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 11:55:02 -0500 Subject: [PATCH 0332/1158] Update pgproto3 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index cbeef02a..b54607b6 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 + github.com/jackc/pgproto3/v2 v2.0.0-rc2 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index 0e853203..d7a6d087 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc2 h1:u+jUsxBxiLY2C6mhr8cZhSy71n/y8Id2STOzJ7bl2Mg= +github.com/jackc/pgproto3/v2 v2.0.0-rc2/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= From 439ea11d4737a38ab70bfadd8218d15ff7967909 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 14:49:55 -0500 Subject: [PATCH 0333/1158] NewFrontend and NewBackend cannot fail --- backend.go | 4 ++-- backend_test.go | 5 +---- frontend.go | 4 ++-- frontend_test.go | 5 +---- 4 files changed, 6 insertions(+), 12 deletions(-) diff --git a/backend.go b/backend.go index 176a25ba..5741647f 100644 --- a/backend.go +++ b/backend.go @@ -33,8 +33,8 @@ type Backend struct { } // NewBackend creates a new Backend. -func NewBackend(cr ChunkReader, w io.Writer) (*Backend, error) { - return &Backend{cr: cr, w: w}, nil +func NewBackend(cr ChunkReader, w io.Writer) *Backend { + return &Backend{cr: cr, w: w} } // Send sends a message to the frontend. diff --git a/backend_test.go b/backend_test.go index cd8788b8..43a3f76c 100644 --- a/backend_test.go +++ b/backend_test.go @@ -12,10 +12,7 @@ func TestBackendReceiveInterrupted(t *testing.T) { server := &interruptReader{} server.push([]byte{'Q', 0, 0, 0, 6}) - backend, err := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) - if err != nil { - t.Fatal(err) - } + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) msg, err := backend.Receive() if err == nil { diff --git a/frontend.go b/frontend.go index f6ebaf43..a67b6670 100644 --- a/frontend.go +++ b/frontend.go @@ -42,8 +42,8 @@ type Frontend struct { } // NewFrontend creates a new Frontend. -func NewFrontend(cr ChunkReader, w io.Writer) (*Frontend, error) { - return &Frontend{cr: cr, w: w}, nil +func NewFrontend(cr ChunkReader, w io.Writer) *Frontend { + return &Frontend{cr: cr, w: w} } // Send sends a message to the backend. diff --git a/frontend_test.go b/frontend_test.go index 1d6a07ae..9b63aa00 100644 --- a/frontend_test.go +++ b/frontend_test.go @@ -36,10 +36,7 @@ func TestFrontendReceiveInterrupted(t *testing.T) { server := &interruptReader{} server.push([]byte{'Z', 0, 0, 0, 5}) - frontend, err := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil) - if err != nil { - t.Fatal(err) - } + frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil) msg, err := frontend.Receive() if err == nil { From 0d1ceed7a6902fbf9ae6c0c43f00174400547d5d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 15:43:07 -0500 Subject: [PATCH 0334/1158] Refactor authentication message handling --- authentication.go | 93 ------------ authentication_cleartext_password.go | 39 +++++ authentication_md5_password.go | 43 ++++++ authentication_ok.go | 39 +++++ authentication_sasl.go | 60 ++++++++ authentication_sasl_continue.go | 49 ++++++ authentication_sasl_final.go | 49 ++++++ frontend.go | 214 ++++++++++++++++----------- 8 files changed, 408 insertions(+), 178 deletions(-) delete mode 100644 authentication.go create mode 100644 authentication_cleartext_password.go create mode 100644 authentication_md5_password.go create mode 100644 authentication_ok.go create mode 100644 authentication_sasl.go create mode 100644 authentication_sasl_continue.go create mode 100644 authentication_sasl_final.go diff --git a/authentication.go b/authentication.go deleted file mode 100644 index 5ff05d96..00000000 --- a/authentication.go +++ /dev/null @@ -1,93 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/binary" - "fmt" - - "github.com/jackc/pgio" -) - -// Authentication message type constants. -const ( - AuthTypeOk = 0 - AuthTypeCleartextPassword = 3 - AuthTypeMD5Password = 5 - AuthTypeSASL = 10 - AuthTypeSASLContinue = 11 - AuthTypeSASLFinal = 12 -) - -// Authentication is a message sent from the backend during the authentication process. -// -// There are multiple authentication messages that each begin with 'R'. This structure represents all such -// authentication messages. -type Authentication struct { - Type uint32 - - // MD5Password fields - Salt [4]byte - - // SASL fields - SASLAuthMechanisms []string - - // SASLContinue and SASLFinal data - SASLData []byte -} - -// Backend identifies this message as sendable by the PostgreSQL backend. -func (*Authentication) Backend() {} - -// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message -// type identifier and 4 byte message length. -func (dst *Authentication) Decode(src []byte) error { - *dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])} - - switch dst.Type { - case AuthTypeOk: - case AuthTypeCleartextPassword: - case AuthTypeMD5Password: - copy(dst.Salt[:], src[4:8]) - case AuthTypeSASL: - authMechanisms := src[4:] - for len(authMechanisms) > 1 { - idx := bytes.IndexByte(authMechanisms, 0) - if idx > 0 { - dst.SASLAuthMechanisms = append(dst.SASLAuthMechanisms, string(authMechanisms[:idx])) - authMechanisms = authMechanisms[idx+1:] - } - } - case AuthTypeSASLContinue, AuthTypeSASLFinal: - dst.SASLData = src[4:] - default: - return fmt.Errorf("unknown authentication type: %d", dst.Type) - } - - return nil -} - -// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Authentication) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - dst = pgio.AppendUint32(dst, src.Type) - - switch src.Type { - case AuthTypeMD5Password: - dst = append(dst, src.Salt[:]...) - case AuthTypeSASL: - for _, s := range src.SASLAuthMechanisms { - dst = append(dst, []byte(s)...) - dst = append(dst, 0) - } - dst = append(dst, 0) - case AuthTypeSASLContinue: - dst = pgio.AppendInt32(dst, int32(len(src.SASLData))) - dst = append(dst, src.SASLData...) - } - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} diff --git a/authentication_cleartext_password.go b/authentication_cleartext_password.go new file mode 100644 index 00000000..dd82c7a7 --- /dev/null +++ b/authentication_cleartext_password.go @@ -0,0 +1,39 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required. +type AuthenticationCleartextPassword struct { +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationCleartextPassword) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { + if len(src) != 4 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeCleartextPassword { + return errors.New("bad auth type") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) + return dst +} diff --git a/authentication_md5_password.go b/authentication_md5_password.go new file mode 100644 index 00000000..4680db5a --- /dev/null +++ b/authentication_md5_password.go @@ -0,0 +1,43 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required. +type AuthenticationMD5Password struct { + Salt [4]byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationMD5Password) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationMD5Password) Decode(src []byte) error { + if len(src) != 8 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeMD5Password { + return errors.New("bad auth type") + } + + copy(dst.Salt[:], src[4:8]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 12) + dst = pgio.AppendUint32(dst, AuthTypeOk) + dst = append(dst, src.Salt[:]...) + return dst +} diff --git a/authentication_ok.go b/authentication_ok.go new file mode 100644 index 00000000..7b13c6e0 --- /dev/null +++ b/authentication_ok.go @@ -0,0 +1,39 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationOk is a message sent from the backend indicating that authentication was successful. +type AuthenticationOk struct { +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationOk) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationOk) Decode(src []byte) error { + if len(src) != 4 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeOk { + return errors.New("bad auth type") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationOk) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendUint32(dst, AuthTypeOk) + return dst +} diff --git a/authentication_sasl.go b/authentication_sasl.go new file mode 100644 index 00000000..c57ae32d --- /dev/null +++ b/authentication_sasl.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required. +type AuthenticationSASL struct { + AuthMechanisms []string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASL) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASL) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASL { + return errors.New("bad auth type") + } + + authMechanisms := src[4:] + for len(authMechanisms) > 1 { + idx := bytes.IndexByte(authMechanisms, 0) + if idx > 0 { + dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx])) + authMechanisms = authMechanisms[idx+1:] + } + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASL) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, AuthTypeSASL) + + for _, s := range src.AuthMechanisms { + dst = append(dst, []byte(s)...) + dst = append(dst, 0) + } + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} diff --git a/authentication_sasl_continue.go b/authentication_sasl_continue.go new file mode 100644 index 00000000..a393ae10 --- /dev/null +++ b/authentication_sasl_continue.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge. +type AuthenticationSASLContinue struct { + Data []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASLContinue) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASLContinue) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASLContinue { + return errors.New("bad auth type") + } + + dst.Data = src[4:] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, AuthTypeSASLContinue) + + dst = pgio.AppendInt32(dst, int32(len(src.Data))) + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} diff --git a/authentication_sasl_final.go b/authentication_sasl_final.go new file mode 100644 index 00000000..b8f89d59 --- /dev/null +++ b/authentication_sasl_final.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed. +type AuthenticationSASLFinal struct { + Data []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASLFinal) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASLFinal) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASLFinal { + return errors.New("bad auth type") + } + + dst.Data = src[4:] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, AuthTypeSASLFinal) + + dst = pgio.AppendInt32(dst, int32(len(src.Data))) + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} diff --git a/frontend.go b/frontend.go index a67b6670..0826685b 100644 --- a/frontend.go +++ b/frontend.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "errors" "fmt" "io" ) @@ -12,29 +13,34 @@ type Frontend struct { w io.Writer // Backend message flyweights - authentication Authentication - backendKeyData BackendKeyData - bindComplete BindComplete - closeComplete CloseComplete - commandComplete CommandComplete - copyBothResponse CopyBothResponse - copyData CopyData - copyInResponse CopyInResponse - copyOutResponse CopyOutResponse - copyDone CopyDone - dataRow DataRow - emptyQueryResponse EmptyQueryResponse - errorResponse ErrorResponse - functionCallResponse FunctionCallResponse - noData NoData - noticeResponse NoticeResponse - notificationResponse NotificationResponse - parameterDescription ParameterDescription - parameterStatus ParameterStatus - parseComplete ParseComplete - readyForQuery ReadyForQuery - rowDescription RowDescription - portalSuspended PortalSuspended + authenticationOk AuthenticationOk + authenticationCleartextPassword AuthenticationCleartextPassword + authenticationMD5Password AuthenticationMD5Password + authenticationSASL AuthenticationSASL + authenticationSASLContinue AuthenticationSASLContinue + authenticationSASLFinal AuthenticationSASLFinal + backendKeyData BackendKeyData + bindComplete BindComplete + closeComplete CloseComplete + commandComplete CommandComplete + copyBothResponse CopyBothResponse + copyData CopyData + copyInResponse CopyInResponse + copyOutResponse CopyOutResponse + copyDone CopyDone + dataRow DataRow + emptyQueryResponse EmptyQueryResponse + errorResponse ErrorResponse + functionCallResponse FunctionCallResponse + noData NoData + noticeResponse NoticeResponse + notificationResponse NotificationResponse + parameterDescription ParameterDescription + parameterStatus ParameterStatus + parseComplete ParseComplete + readyForQuery ReadyForQuery + rowDescription RowDescription + portalSuspended PortalSuspended bodyLen int msgType byte @@ -47,83 +53,121 @@ func NewFrontend(cr ChunkReader, w io.Writer) *Frontend { } // Send sends a message to the backend. -func (b *Frontend) Send(msg FrontendMessage) error { - _, err := b.w.Write(msg.Encode(nil)) +func (f *Frontend) Send(msg FrontendMessage) error { + _, err := f.w.Write(msg.Encode(nil)) return err } // Receive receives a message from the backend. -func (b *Frontend) Receive() (BackendMessage, error) { - if !b.partialMsg { - header, err := b.cr.Next(5) +func (f *Frontend) Receive() (BackendMessage, error) { + if !f.partialMsg { + header, err := f.cr.Next(5) if err != nil { return nil, err } - b.msgType = header[0] - b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 - b.partialMsg = true + f.msgType = header[0] + f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + f.partialMsg = true } - var msg BackendMessage - switch b.msgType { - case '1': - msg = &b.parseComplete - case '2': - msg = &b.bindComplete - case '3': - msg = &b.closeComplete - case 'A': - msg = &b.notificationResponse - case 'c': - msg = &b.copyDone - case 'C': - msg = &b.commandComplete - case 'd': - msg = &b.copyData - case 'D': - msg = &b.dataRow - case 'E': - msg = &b.errorResponse - case 'G': - msg = &b.copyInResponse - case 'H': - msg = &b.copyOutResponse - case 'I': - msg = &b.emptyQueryResponse - case 'K': - msg = &b.backendKeyData - case 'n': - msg = &b.noData - case 'N': - msg = &b.noticeResponse - case 'R': - msg = &b.authentication - case 's': - msg = &b.portalSuspended - case 'S': - msg = &b.parameterStatus - case 't': - msg = &b.parameterDescription - case 'T': - msg = &b.rowDescription - case 'V': - msg = &b.functionCallResponse - case 'W': - msg = &b.copyBothResponse - case 'Z': - msg = &b.readyForQuery - default: - return nil, fmt.Errorf("unknown message type: %c", b.msgType) - } - - msgBody, err := b.cr.Next(b.bodyLen) + msgBody, err := f.cr.Next(f.bodyLen) if err != nil { return nil, err } - b.partialMsg = false + f.partialMsg = false + + var msg BackendMessage + switch f.msgType { + case '1': + msg = &f.parseComplete + case '2': + msg = &f.bindComplete + case '3': + msg = &f.closeComplete + case 'A': + msg = &f.notificationResponse + case 'c': + msg = &f.copyDone + case 'C': + msg = &f.commandComplete + case 'd': + msg = &f.copyData + case 'D': + msg = &f.dataRow + case 'E': + msg = &f.errorResponse + case 'G': + msg = &f.copyInResponse + case 'H': + msg = &f.copyOutResponse + case 'I': + msg = &f.emptyQueryResponse + case 'K': + msg = &f.backendKeyData + case 'n': + msg = &f.noData + case 'N': + msg = &f.noticeResponse + case 'R': + var err error + msg, err = f.findAuthenticationMessageType(msgBody) + if err != nil { + return nil, err + } + case 's': + msg = &f.portalSuspended + case 'S': + msg = &f.parameterStatus + case 't': + msg = &f.parameterDescription + case 'T': + msg = &f.rowDescription + case 'V': + msg = &f.functionCallResponse + case 'W': + msg = &f.copyBothResponse + case 'Z': + msg = &f.readyForQuery + default: + return nil, fmt.Errorf("unknown message type: %c", f.msgType) + } err = msg.Decode(msgBody) return msg, err } + +// Authentication message type constants. +const ( + AuthTypeOk = 0 + AuthTypeCleartextPassword = 3 + AuthTypeMD5Password = 5 + AuthTypeSASL = 10 + AuthTypeSASLContinue = 11 + AuthTypeSASLFinal = 12 +) + +func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) { + if len(src) < 4 { + return nil, errors.New("authentication message too short") + } + authType := binary.BigEndian.Uint32(src[:4]) + + switch authType { + case AuthTypeOk: + return &f.authenticationOk, nil + case AuthTypeCleartextPassword: + return &f.authenticationCleartextPassword, nil + case AuthTypeMD5Password: + return &f.authenticationMD5Password, nil + case AuthTypeSASL: + return &f.authenticationSASL, nil + case AuthTypeSASLContinue: + return &f.authenticationSASLContinue, nil + case AuthTypeSASLFinal: + return &f.authenticationSASLFinal, nil + default: + return nil, fmt.Errorf("unknown authentication type: %d", authType) + } +} From 2fabfa3c18b7bcb4f204c365f2f0d2e09d4564eb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 15:44:54 -0500 Subject: [PATCH 0335/1158] Update to newest pgproto3 --- auth_scram.go | 34 ++++++++++++++++++++++------------ config.go | 2 +- go.mod | 2 +- go.sum | 8 ++------ pgconn.go | 40 ++++++++++++++++++++-------------------- 5 files changed, 46 insertions(+), 40 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index 6d6d0651..665fc2c2 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -47,11 +47,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } // Receive server-first-message payload in a AuthenticationSASLContinue. - authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue) + saslContinue, err := c.rxSASLContinue() if err != nil { return err } - err = sc.recvServerFirstMessage(authMsg.SASLData) + err = sc.recvServerFirstMessage(saslContinue.Data) if err != nil { return err } @@ -66,27 +66,37 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } // Receive server-final-message payload in a AuthenticationSASLFinal. - authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal) + saslFinal, err := c.rxSASLFinal() if err != nil { return err } - return sc.recvServerFinalMessage(authMsg.SASLData) + return sc.recvServerFinalMessage(saslFinal.Data) } -func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { +func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { msg, err := c.receiveMessage() if err != nil { return nil, err } - authMsg, ok := msg.(*pgproto3.Authentication) - if !ok { - return nil, errors.New("unexpected message type") - } - if authMsg.Type != typ { - return nil, errors.New("unexpected auth type") + saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue) + if ok { + return saslContinue, nil } - return authMsg, nil + return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message") +} + +func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal) + if ok { + return saslFinal, nil + } + + return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message") } type scramClient struct { diff --git a/config.go b/config.go index d24d0202..d1267621 100644 --- a/config.go +++ b/config.go @@ -501,7 +501,7 @@ func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { if err != nil { panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err)) } - frontend, _ := pgproto3.NewFrontend(cr, w) + frontend := pgproto3.NewFrontend(cr, w) return frontend } diff --git a/go.mod b/go.mod index b54607b6..6e270cd6 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-rc2 + github.com/jackc/pgproto3/v2 v2.0.0-rc3 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index d7a6d087..ed8eb401 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,13 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= -github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3/v2 v2.0.0-rc2 h1:u+jUsxBxiLY2C6mhr8cZhSy71n/y8Id2STOzJ7bl2Mg= -github.com/jackc/pgproto3/v2 v2.0.0-rc2/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/jackc/pgproto3/v2 v2.0.0-rc3 h1:EHkgVE6iDyI7HZDfMPaZ2Xjdf7C29DikR6o39WVO61c= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/pgconn.go b/pgconn.go index 1e3f9515..d51eb76a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -210,11 +210,28 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.BackendKeyData: pgConn.pid = msg.ProcessID pgConn.secretKey = msg.SecretKey - case *pgproto3.Authentication: - if err = pgConn.rxAuthenticationX(msg); err != nil { + + case *pgproto3.AuthenticationOk: + case *pgproto3.AuthenticationCleartextPassword: + err = pgConn.txPasswordMessage(pgConn.config.Password) + if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed handle authentication message", err: err} + return nil, &connectError{config: config, msg: "failed to write password message", err: err} } + case *pgproto3.AuthenticationMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write password message", err: err} + } + case *pgproto3.AuthenticationSASL: + err = pgConn.scramAuth(msg.AuthMechanisms) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed SASL auth", err: err} + } + case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle if config.ValidateConnect != nil { @@ -257,23 +274,6 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return nil } -func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { - switch msg.Type { - case pgproto3.AuthTypeOk: - case pgproto3.AuthTypeCleartextPassword: - err = pgConn.txPasswordMessage(pgConn.config.Password) - case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) - err = pgConn.txPasswordMessage(digestedPassword) - case pgproto3.AuthTypeSASL: - err = pgConn.scramAuth(msg.SASLAuthMechanisms) - default: - err = errors.New("Received unknown authentication message") - } - - return -} - func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) From 4c03ce451f299969fd0fff6212e1f1e696b0f523 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 16:00:41 -0500 Subject: [PATCH 0336/1158] Add MarshalJSON for FieldDescription --- row_description.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/row_description.go b/row_description.go index 2745fa43..d9b8c7c9 100644 --- a/row_description.go +++ b/row_description.go @@ -23,6 +23,27 @@ type FieldDescription struct { Format int16 } +// MarshalJSON implements encoding/json.Marshaler. +func (fd FieldDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 + }{ + Name: string(fd.Name), + TableOID: fd.TableOID, + TableAttributeNumber: fd.TableAttributeNumber, + DataTypeOID: fd.DataTypeOID, + DataTypeSize: fd.DataTypeSize, + TypeModifier: fd.TypeModifier, + Format: fd.Format, + }) +} + type RowDescription struct { Fields []FieldDescription } From 2f6b8f3f5665228c0800e66b05e797ef119f3ef2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 17:01:54 -0500 Subject: [PATCH 0337/1158] Fix context timeout on connect --- go.mod | 11 ++++--- go.sum | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++ pgconn.go | 3 ++ pgconn_test.go | 62 +++++++++++++++++++++++++++++++++++ 4 files changed, 159 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 6e270cd6..11692c10 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,11 @@ go 1.12 require ( github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 + github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-rc3 - github.com/stretchr/testify v1.3.0 - golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a - golang.org/x/text v0.3.0 - golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 + github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 + github.com/stretchr/testify v1.4.0 + golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 + golang.org/x/text v0.3.2 + golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 ) diff --git a/go.sum b/go.sum index ed8eb401..c1b3d405 100644 --- a/go.sum +++ b/go.sum @@ -1,23 +1,111 @@ +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3 h1:EHkgVE6iDyI7HZDfMPaZ2Xjdf7C29DikR6o39WVO61c= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/pgconn.go b/pgconn.go index d51eb76a..5c01d1dc 100644 --- a/pgconn.go +++ b/pgconn.go @@ -174,6 +174,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ diff --git a/pgconn_test.go b/pgconn_test.go index 3fbdf8df..4a67a2e0 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/jackc/pgconn" + "github.com/jackc/pgmock" "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" @@ -73,6 +74,67 @@ func TestConnectTLS(t *testing.T) { closeConn(t, conn) } +type pgmockWaitStep time.Duration + +func (s pgmockWaitStep) Step(*pgproto3.Backend) error { + time.Sleep(time.Duration(s)) + return nil +} + +func TestConnectWithContextThatTimesOut(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: []pgmock.Step{ + pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + pgmock.SendMessage(&pgproto3.AuthenticationOk{}), + pgmockWaitStep(time.Millisecond * 500), + pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + }, + } + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + tooLate := time.Now().Add(time.Millisecond * 500) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + _, err = pgconn.Connect(ctx, connStr) + require.True(t, pgconn.Timeout(err), err) + require.True(t, time.Now().Before(tooLate)) +} + func TestConnectInvalidUser(t *testing.T) { t.Parallel() From 80f2cbce255e0b0dc138baae0c50a7ce755579b4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Sep 2019 11:37:43 -0500 Subject: [PATCH 0338/1158] Add pgfortune example --- README.md | 4 ++ example/pgfortune/README.md | 53 ++++++++++++++++++ example/pgfortune/main.go | 48 +++++++++++++++++ example/pgfortune/server.go | 104 ++++++++++++++++++++++++++++++++++++ 4 files changed, 209 insertions(+) create mode 100644 example/pgfortune/README.md create mode 100644 example/pgfortune/main.go create mode 100644 example/pgfortune/server.go diff --git a/README.md b/README.md index 3491780e..565b3efd 100644 --- a/README.md +++ b/README.md @@ -5,4 +5,8 @@ Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3. +pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more. + +See example/pgfortune for a playful example of a fake PostgreSQL server. + Extracted from original implementation in https://github.com/jackc/pgx. diff --git a/example/pgfortune/README.md b/example/pgfortune/README.md new file mode 100644 index 00000000..c181c38a --- /dev/null +++ b/example/pgfortune/README.md @@ -0,0 +1,53 @@ +# pgfortune + +pgfortune is a mock PostgreSQL server that responds to every query with a fortune. + +## Installation + +Install `fortune` and `cowsay`. They should be available in any Unix package manager (apt, yum, brew, etc.) + +``` +go get -u github.com/jackc/pgproto3/example/pgfortune +``` + +## Usage + +``` +$ pgfortune +``` + +By default pgfortune listens on 127.0.0.1:15432 and responds to queries with `fortune | cowsay -f elephant`. These are +configurable with the `listen` and `response-command` arguments respectively. + +While `pgfortune` is running connect to it with `psql`. + +``` +$ psql -h 127.0.0.1 -p 15432 +Timing is on. +Null display is "∅". +Line style is unicode. +psql (11.5, server 0.0.0) +Type "help" for help. + +jack@127.0.0.1:15432 jack=# select foo; + fortune +───────────────────────────────────────────── + _________________________________________ ↵ + / Ships are safe in harbor, but they were \↵ + \ never meant to stay there. /↵ + ----------------------------------------- ↵ + \ /\ ___ /\ ↵ + \ // \/ \/ \\ ↵ + (( O O )) ↵ + \\ / \ // ↵ + \/ | | \/ ↵ + | | | | ↵ + | | | | ↵ + | o | ↵ + | | | | ↵ + |m| |m| ↵ + +(1 row) + +Time: 28.161 ms +``` diff --git a/example/pgfortune/main.go b/example/pgfortune/main.go new file mode 100644 index 00000000..45970eb3 --- /dev/null +++ b/example/pgfortune/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net" + "os" + "os/exec" +) + +var options struct { + listenAddress string + responseCommand string +} + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "usage: %s [options]\n", os.Args[0]) + flag.PrintDefaults() + } + + flag.StringVar(&options.listenAddress, "listen", "127.0.0.1:15432", "Listen address") + flag.StringVar(&options.responseCommand, "response-command", "fortune | cowsay -f elephant", "Command to execute to generate query response") + flag.Parse() + + ln, err := net.Listen("tcp", options.listenAddress) + if err != nil { + log.Fatal(err) + } + + for { + conn, err := ln.Accept() + if err != nil { + log.Fatal(err) + } + + b := NewPgFortuneBackend(conn, func() ([]byte, error) { + return exec.Command("sh", "-c", options.responseCommand).CombinedOutput() + }) + go func() { + err := b.Run() + if err != nil { + log.Println(err) + } + }() + } +} diff --git a/example/pgfortune/server.go b/example/pgfortune/server.go new file mode 100644 index 00000000..777192a6 --- /dev/null +++ b/example/pgfortune/server.go @@ -0,0 +1,104 @@ +package main + +import ( + "fmt" + "net" + + "github.com/jackc/pgproto3/v2" +) + +type PgFortuneBackend struct { + backend *pgproto3.Backend + conn net.Conn + responder func() ([]byte, error) +} + +func NewPgFortuneBackend(conn net.Conn, responder func() ([]byte, error)) *PgFortuneBackend { + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn) + + connHandler := &PgFortuneBackend{ + backend: backend, + conn: conn, + responder: responder, + } + + return connHandler +} + +func (p *PgFortuneBackend) Run() error { + defer p.Close() + + err := p.handleStartup() + if err != nil { + return err + } + + for { + msg, err := p.backend.Receive() + if err != nil { + return fmt.Errorf("error receiving message: %w", err) + } + + switch msg.(type) { + case *pgproto3.Query: + response, err := p.responder() + if err != nil { + return fmt.Errorf("error generating query response: %w", err) + } + + buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + { + Name: []byte("fortune"), + TableOID: 0, + TableAttributeNumber: 0, + DataTypeOID: 25, + DataTypeSize: -1, + TypeModifier: -1, + Format: 0, + }, + }}).Encode(nil) + buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf) + buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf) + buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf) + _, err = p.conn.Write(buf) + if err != nil { + return fmt.Errorf("error writing query response: %w", err) + } + case *pgproto3.Terminate: + return nil + default: + return fmt.Errorf("received message other than Query from client: %#v", msg) + } + } +} + +func (p *PgFortuneBackend) handleStartup() error { + startupMessage, err := p.backend.ReceiveStartupMessage() + if err != nil { + return fmt.Errorf("error receiving startup message: %w", err) + } + + switch startupMessage.(type) { + case *pgproto3.StartupMessage: + buf := (&pgproto3.AuthenticationOk{}).Encode(nil) + buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf) + _, err = p.conn.Write(buf) + if err != nil { + return fmt.Errorf("error sending ready for query: %w", err) + } + case *pgproto3.SSLRequest: + _, err = p.conn.Write([]byte("N")) + if err != nil { + return fmt.Errorf("error sending deny SSL request: %w", err) + } + return p.handleStartup() + default: + return fmt.Errorf("unknown startup message: %#v", startupMessage) + } + + return nil +} + +func (p *PgFortuneBackend) Close() error { + return p.conn.Close() +} From eca1e51822f3ebd8f48f62029d3e32f931d32c32 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Sep 2019 11:41:31 -0500 Subject: [PATCH 0339/1158] Add more pgfortune output --- example/pgfortune/main.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/example/pgfortune/main.go b/example/pgfortune/main.go index 45970eb3..0c25510b 100644 --- a/example/pgfortune/main.go +++ b/example/pgfortune/main.go @@ -28,12 +28,14 @@ func main() { if err != nil { log.Fatal(err) } + log.Println("Listening on", ln.Addr()) for { conn, err := ln.Accept() if err != nil { log.Fatal(err) } + log.Println("Accepted connection from", conn.RemoteAddr()) b := NewPgFortuneBackend(conn, func() ([]byte, error) { return exec.Command("sh", "-c", options.responseCommand).CombinedOutput() @@ -43,6 +45,7 @@ func main() { if err != nil { log.Println(err) } + log.Println("Closed connection from", conn.RemoteAddr()) }() } } From a90ef7ed5b85de58f81c707bbd3d75720b47e6c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Geisend=C3=B6rfer?= Date: Sun, 8 Sep 2019 17:29:06 +0200 Subject: [PATCH 0340/1158] fix: AuthenticationMD5Password AuthType --- authentication_md5_password.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/authentication_md5_password.go b/authentication_md5_password.go index 4680db5a..d505d264 100644 --- a/authentication_md5_password.go +++ b/authentication_md5_password.go @@ -37,7 +37,7 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error { func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { dst = append(dst, 'R') dst = pgio.AppendInt32(dst, 12) - dst = pgio.AppendUint32(dst, AuthTypeOk) + dst = pgio.AppendUint32(dst, AuthTypeMD5Password) dst = append(dst, src.Salt[:]...) return dst } From a8362ef96d23eb9e53a9eb57bb12889f8cbaa1c2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 10 Sep 2019 17:14:04 -0500 Subject: [PATCH 0341/1158] Parse postgresql:// protocol --- config.go | 2 +- config_test.go | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index d1267621..2ec6ae3f 100644 --- a/config.go +++ b/config.go @@ -152,7 +152,7 @@ func ParseConfig(connString string) (*Config, error) { if connString != "" { // connString may be a database URL or a DSN - if strings.HasPrefix(connString, "postgres://") { + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { err := addURLSettings(settings, connString) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} diff --git a/config_test.go b/config_test.go index af42094d..090302a2 100644 --- a/config_test.go +++ b/config_test.go @@ -214,6 +214,18 @@ func TestParseConfig(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + name: "database url postgresql protocol", + connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "DSN everything", connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema", From f8be2b60ce34bf79b747009b9cc7fb718b918734 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 10 Sep 2019 17:25:25 -0500 Subject: [PATCH 0342/1158] go.sum changes --- go.sum | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/go.sum b/go.sum index c1b3d405..d0a917fc 100644 --- a/go.sum +++ b/go.sum @@ -2,11 +2,11 @@ github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMe github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= @@ -19,10 +19,10 @@ github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye47 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= -github.com/jackc/pgproto3/v2 v2.0.0-rc3 h1:EHkgVE6iDyI7HZDfMPaZ2Xjdf7C29DikR6o39WVO61c= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= @@ -48,7 +48,6 @@ github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -63,7 +62,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= @@ -74,7 +72,6 @@ go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/ go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -86,12 +83,10 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -99,7 +94,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From b2ca5d8f521597a28e8dc0703b9b2a8c72d9866a Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Fri, 13 Sep 2019 17:26:09 +0300 Subject: [PATCH 0343/1158] validate all addresses resolved from hostname Signed-off-by: Artemiy Ryabinkov --- config.go | 9 ++++++++- pgconn.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 2ec6ae3f..57e65e13 100644 --- a/config.go +++ b/config.go @@ -37,6 +37,7 @@ type Config struct { Password string TLSConfig *tls.Config // nil disables TLS DialFunc DialFunc // e.g. net.Dialer.DialContext + LookupFunc LookupFunc // e.g. net.Resolver.LookupHost BuildFrontend BuildFrontendFunc RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) @@ -77,7 +78,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) } else { network = "tcp" - address = fmt.Sprintf("%s:%d", host, port) + address = net.JoinHostPort(host, strconv.Itoa(int(port))) } return network, address } @@ -190,6 +191,8 @@ func ParseConfig(connString string) (*Config, error) { config.DialFunc = defaultDialer.DialContext } + config.LookupFunc = makeDefaultResolver().LookupHost + notRuntimeParams := map[string]struct{}{ "host": struct{}{}, "port": struct{}{}, @@ -495,6 +498,10 @@ func makeDefaultDialer() *net.Dialer { return &net.Dialer{KeepAlive: 5 * time.Minute} } +func makeDefaultResolver() *net.Resolver { + return net.DefaultResolver +} + func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { return func(r io.Reader, w io.Writer) Frontend { cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen}) diff --git a/pgconn.go b/pgconn.go index 5c01d1dc..db2ebe73 100644 --- a/pgconn.go +++ b/pgconn.go @@ -43,6 +43,9 @@ type Notification struct { // DialFunc is a function that can be used to connect to a PostgreSQL server. type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) +// LookupFunc is a function that can be used to lookup IPs addrs from host. +type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) + // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend @@ -123,6 +126,15 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err } fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) + if err != nil { + return nil, &connectError{config: config, msg: "hostname resolving error", err: err} + } + + if len(fallbackConfigs) == 0 { + return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} + } + for _, fc := range fallbackConfigs { pgConn, err = connect(ctx, config, fc) if err == nil { @@ -147,6 +159,27 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err return pgConn, nil } +func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { + var configs []*FallbackConfig + + for _, fb := range fallbacks { + ips, err := lookupFn(ctx, fb.Host) + if err != nil { + return nil, err + } + + for _, ip := range ips { + configs = append(configs, &FallbackConfig{ + Host: ip, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + } + } + + return configs, nil +} + func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config From e538885fa71f92c5974c28edb49db682a0194a33 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Fri, 13 Sep 2019 17:52:01 +0300 Subject: [PATCH 0344/1158] skip resolve for unix sockets Signed-off-by: Artemiy Ryabinkov --- pgconn.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pgconn.go b/pgconn.go index db2ebe73..25f4f4d5 100644 --- a/pgconn.go +++ b/pgconn.go @@ -163,6 +163,17 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba var configs []*FallbackConfig for _, fb := range fallbacks { + // skip resolve for unix sockets + if strings.HasPrefix(fb.Host, "/") { + configs = append(configs, &FallbackConfig{ + Host: fb.Host, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + + continue + } + ips, err := lookupFn(ctx, fb.Host) if err != nil { return nil, err From 17d3d592e980720a8baa9a98e91a3de9fec06af7 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 14 Sep 2019 19:11:26 +0300 Subject: [PATCH 0345/1158] add test for custom lookup func Signed-off-by: Artemiy Ryabinkov --- pgconn_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index 4a67a2e0..36499b68 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -188,6 +188,24 @@ func TestConnectCustomDialer(t *testing.T) { closeConn(t, conn) } +func TestConnectCustomLookup(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + looked := false + config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + looked = true + return net.LookupHost(host) + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) +} + func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() From 99f22ac8e4c9c142d9541ab648274e7663357fab Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 14 Sep 2019 18:37:33 -0500 Subject: [PATCH 0346/1158] Port DSN parser from pgx v3 Original implementation: 2d9d8dc52ac211c6191c08e050c03588aa633038 by Joshua Barone . Also changed DSN tests to use "dbname" as key rather than "database" as that is what the PostgreSQL documentation specifies. "database" still actually works but it should not be encouraged as it is non-standard. --- .travis.yml | 8 ++--- README.md | 2 +- config.go | 61 ++++++++++++++++++++++++++++++++++--- config_test.go | 82 +++++++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 139 insertions(+), 14 deletions(-) diff --git a/.travis.yml b/.travis.yml index 2c547abf..abff8515 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,15 +17,15 @@ env: - GOPROXY=https://proxy.golang.org - GOFLAGS=-mod=readonly - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" + - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_TLS_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - PGX_TEST_MD5_PASSWORD_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_PLAIN_PASSWORD_CONN_STRING=postgres://pgx_pw:secret@127.0.0.1/pgx_test matrix: - - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx database=pgx_test" - - PGVERSION=10 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret database=pgx_test" - - PGVERSION=9.6 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret database=pgx_test" + - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx dbname=pgx_test" + - PGVERSION=10 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret dbname=pgx_test" + - PGVERSION=9.6 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret dbname=pgx_test" - PGVERSION=9.5 - PGVERSION=9.4 - PGVERSION=9.3 diff --git a/README.md b/README.md index 9e35a0f5..aa980b6d 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ create database pgx_test; Now you can run the tests: ``` -PGX_TEST_CONN_STRING="host=/var/run/postgresql database=pgx_test" go test ./... +PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./... ``` ### Connection and Authentication Tests diff --git a/config.go b/config.go index 2ec6ae3f..6eb0065a 100644 --- a/config.go +++ b/config.go @@ -13,7 +13,6 @@ import ( "os" "os/user" "path/filepath" - "regexp" "strconv" "strings" "time" @@ -389,13 +388,65 @@ func addURLSettings(settings map[string]string, connString string) error { return nil } -var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) +var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} func addDSNSettings(settings map[string]string, s string) error { - m := dsnRegexp.FindAllStringSubmatch(s, -1) + nameMap := map[string]string{ + "dbname": "database", + } - for _, b := range m { - settings[b[1]] = b[2] + for len(s) > 0 { + var key, val string + eqIdx := strings.IndexRune(s, '=') + if eqIdx < 0 { + return errors.New("invalid dsn") + } + + key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") + s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") + if s[0] != '\'' { + end := 0 + for ; end < len(s); end++ { + if asciiSpace[s[end]] == 1 { + break + } + if s[end] == '\\' { + end++ + } + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } else { // quoted string + s = s[1:] + end := 0 + for ; end < len(s); end++ { + if s[end] == '\'' { + break + } + if s[end] == '\\' { + end++ + } + } + if end == len(s) { + return errors.New("unterminated quoted string in connection info string") + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } + + if k, ok := nameMap[key]; ok { + key = k + } + + settings[key] = val } return nil diff --git a/config_test.go b/config_test.go index 090302a2..9eb5df2f 100644 --- a/config_test.go +++ b/config_test.go @@ -228,7 +228,7 @@ func TestParseConfig(t *testing.T) { }, { name: "DSN everything", - connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema", + connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema", config: &pgconn.Config{ User: "jack", Password: "secret", @@ -242,6 +242,80 @@ func TestParseConfig(t *testing.T) { }, }, }, + { + name: "DSN with escaped single quote", + connString: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack's", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with escaped backslash", + connString: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "sooper\\secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with single quoted values", + connString: "user='jack' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with single quoted value with escaped single quote", + connString: "user='jack\\'s' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack's", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with empty single quoted value", + connString: "user='jack' password='' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with space between key and value", + connString: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "URL multiple hosts", connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable", @@ -294,7 +368,7 @@ func TestParseConfig(t *testing.T) { }, { name: "DSN multiple hosts one port", - connString: "user=jack password=secret host=foo,bar,baz port=5432 database=mydb sslmode=disable", + connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable", config: &pgconn.Config{ User: "jack", Password: "secret", @@ -319,7 +393,7 @@ func TestParseConfig(t *testing.T) { }, { name: "DSN multiple hosts multiple ports", - connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 database=mydb sslmode=disable", + connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 dbname=mydb sslmode=disable", config: &pgconn.Config{ User: "jack", Password: "secret", @@ -344,7 +418,7 @@ func TestParseConfig(t *testing.T) { }, { name: "multiple hosts and fallback tsl", - connString: "user=jack password=secret host=foo,bar,baz database=mydb sslmode=prefer", + connString: "user=jack password=secret host=foo,bar,baz dbname=mydb sslmode=prefer", config: &pgconn.Config{ User: "jack", Password: "secret", From cf8fe4a477596f78341eafa0dca2f378719cbb4a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 14 Sep 2019 19:57:32 -0500 Subject: [PATCH 0347/1158] uuid extension switched to gofrs from satori Do not encourage library use that has serious outstanding bug: https://github.com/satori/go.uuid/issues/73 --- ext/{satori-uuid => gofrs-uuid}/uuid.go | 2 +- ext/{satori-uuid => gofrs-uuid}/uuid_test.go | 22 ++++++++++---------- go.mod | 2 +- go.sum | 2 ++ 4 files changed, 15 insertions(+), 13 deletions(-) rename ext/{satori-uuid => gofrs-uuid}/uuid.go (99%) rename ext/{satori-uuid => gofrs-uuid}/uuid_test.go (63%) diff --git a/ext/satori-uuid/uuid.go b/ext/gofrs-uuid/uuid.go similarity index 99% rename from ext/satori-uuid/uuid.go rename to ext/gofrs-uuid/uuid.go index 9b958b58..9b95a225 100644 --- a/ext/satori-uuid/uuid.go +++ b/ext/gofrs-uuid/uuid.go @@ -5,8 +5,8 @@ import ( errors "golang.org/x/xerrors" + "github.com/gofrs/uuid" "github.com/jackc/pgtype" - uuid "github.com/satori/go.uuid" ) var errUndefined = errors.New("cannot encode status undefined") diff --git a/ext/satori-uuid/uuid_test.go b/ext/gofrs-uuid/uuid_test.go similarity index 63% rename from ext/satori-uuid/uuid_test.go rename to ext/gofrs-uuid/uuid_test.go index 247470a3..124720b8 100644 --- a/ext/satori-uuid/uuid_test.go +++ b/ext/gofrs-uuid/uuid_test.go @@ -5,38 +5,38 @@ import ( "testing" "github.com/jackc/pgtype" - satori "github.com/jackc/pgtype/ext/satori-uuid" + gofrs "github.com/jackc/pgtype/ext/gofrs-uuid" "github.com/jackc/pgtype/testutil" ) func TestUUIDTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - &satori.UUID{Status: pgtype.Null}, + &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &gofrs.UUID{Status: pgtype.Null}, }) } func TestUUIDSet(t *testing.T) { successfulTests := []struct { source interface{} - result satori.UUID + result gofrs.UUID }{ { source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, { source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, } for i, tt := range successfulTests { - var r satori.UUID + var r gofrs.UUID err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -50,7 +50,7 @@ func TestUUIDSet(t *testing.T) { func TestUUIDAssignTo(t *testing.T) { { - src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst [16]byte expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -65,7 +65,7 @@ func TestUUIDAssignTo(t *testing.T) { } { - src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst []byte expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -80,7 +80,7 @@ func TestUUIDAssignTo(t *testing.T) { } { - src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst string expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" diff --git a/go.mod b/go.mod index dbe9f53c..9f47a705 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,10 @@ module github.com/jackc/pgtype go 1.12 require ( + github.com/gofrs/uuid v3.2.0+incompatible github.com/jackc/pgio v1.0.0 github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186 github.com/lib/pq v1.2.0 - github.com/satori/go.uuid v1.2.0 github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 github.com/stretchr/testify v1.4.0 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 diff --git a/go.sum b/go.sum index f9a56ffd..275e7fe1 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= From bbc7f67a6f5907a413ff3106ebf6c54d1f09101a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 14 Sep 2019 20:22:50 -0500 Subject: [PATCH 0348/1158] Update to pgproto3 v2.0.0 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 11692c10..4a188cce 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 + github.com/jackc/pgproto3/v2 v2.0.0 github.com/stretchr/testify v1.4.0 golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 golang.org/x/text v0.3.2 diff --git a/go.sum b/go.sum index d0a917fc..51c55d12 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,8 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0 h1:FApgMJ/GtaXfI0s8Lvd0kaLaRwMOhs4VH92pwkwQQvU= +github.com/jackc/pgproto3/v2 v2.0.0/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= From f517670ba59ed8443facc7ff16cb74aa30682158 Mon Sep 17 00:00:00 2001 From: Andrew Huang Date: Wed, 18 Sep 2019 13:51:01 -0700 Subject: [PATCH 0349/1158] Add tstzrange data type --- tstzrange_array.go | 301 +++++++++++++++++++++++++++++++++++++++++++++ typed_array_gen.sh | 1 + 2 files changed, 302 insertions(+) create mode 100644 tstzrange_array.go diff --git a/tstzrange_array.go b/tstzrange_array.go new file mode 100644 index 00000000..8180e4c2 --- /dev/null +++ b/tstzrange_array.go @@ -0,0 +1,301 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgio" + errors "golang.org/x/xerrors" +) + +type TstzrangeArray struct { + Elements []Tstzrange + Dimensions []ArrayDimension + Status Status +} + +func (dst *TstzrangeArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TstzrangeArray{Status: Null} + return nil + } + + switch value := src.(type) { + + case []Tstzrange: + if value == nil { + *dst = TstzrangeArray{Status: Null} + } else if len(value) == 0 { + *dst = TstzrangeArray{Status: Present} + } else { + elements := make([]Tstzrange, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TstzrangeArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to TstzrangeArray", value) + } + + return nil +} + +func (dst *TstzrangeArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *TstzrangeArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]Tstzrange: + *v = make([]Tstzrange, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return errors.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *TstzrangeArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TstzrangeArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Tstzrange + + if len(uta.Elements) > 0 { + elements = make([]Tstzrange, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Tstzrange + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TstzrangeArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TstzrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TstzrangeArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TstzrangeArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Tstzrange, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TstzrangeArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *TstzrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *TstzrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("tstzrange"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "tstzrange") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TstzrangeArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TstzrangeArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 911fa392..76c174ef 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -4,6 +4,7 @@ erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64, erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstz_range_array.go erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go From f5eead90fca09203d8af956fea01861884ed9a8a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 19 Sep 2019 21:04:14 -0500 Subject: [PATCH 0350/1158] Fix statement cache reuse bug --- stmtcache/lru.go | 4 +++- stmtcache/lru_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/stmtcache/lru.go b/stmtcache/lru.go index fff4d0b7..d82ced19 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -104,8 +104,10 @@ func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescrip func (c *LRU) removeOldest(ctx context.Context) error { oldest := c.l.Back() c.l.Remove(oldest) + psd := oldest.Value.(*pgconn.StatementDescription) + delete(c.m, psd.SQL) if c.mode == ModePrepare { - return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.StatementDescription).Name)).Close() + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() } return nil } diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index b518364e..d2902dbb 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -2,6 +2,8 @@ package stmtcache_test import ( "context" + "fmt" + "math/rand" "os" "testing" "time" @@ -57,6 +59,30 @@ func TestLRUModePrepare(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } +func TestLRUModePrepareStress(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 8, cache.Cap()) + require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) + + for i := 0; i < 1000; i++ { + psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50))) + require.NoError(t, err) + require.NotNil(t, psd) + result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read() + require.NoError(t, result.Err) + } +} + func TestLRUModeDescribe(t *testing.T) { t.Parallel() From d6b0287fcda8ef85425ef39a43e0e10921877449 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 19 Sep 2019 21:41:20 -0500 Subject: [PATCH 0351/1158] Release v1.0.1 --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..5384b031 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,3 @@ +# 1.0.1 (September 19, 2019) + +* Fix statement cache not properly cleaning discarded statements From 52ae698572731734c6d8b988b86b0a4083f0b6c3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 19 Sep 2019 21:43:18 -0500 Subject: [PATCH 0352/1158] Fix daterange oid --- pgtype.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgtype.go b/pgtype.go index 3391a04f..05a4fb05 100644 --- a/pgtype.go +++ b/pgtype.go @@ -66,7 +66,7 @@ const ( UUIDOID = 2950 UUIDArrayOID = 2951 JSONBOID = 3802 - DaterangeOID = 3812 + DaterangeOID = 3912 Int4rangeOID = 3904 NumrangeOID = 3906 TsrangeOID = 3908 From 9dc453458c0eddf886c1b7b9d1c920e4e8e439eb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 19 Sep 2019 21:57:09 -0500 Subject: [PATCH 0353/1158] Release v1.0.1 --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..20605e2a --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,3 @@ +# 1.0.1 (September 19, 2019) + +* Fix daterange OID From eb20ab82192c3f4b02ed74ec0dec0e069e58f0cd Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Mon, 16 Sep 2019 15:22:33 -0400 Subject: [PATCH 0354/1158] Added a license -- fixes #3 --- LICENSE | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..dd9e7be9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2013 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. From 90d22fb483f81f749cf24e5a6700402615658603 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 26 Sep 2019 21:08:20 -0500 Subject: [PATCH 0355/1158] Add basic README.md --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 00000000..6848acc5 --- /dev/null +++ b/README.md @@ -0,0 +1,7 @@ +[![](https://godoc.org/github.com/jackc/pgtype?status.svg)](https://godoc.org/github.com/jackc/pgtype) + +# pgtype + +pgtype implements Go types for over 70 PostgreSQL types. pgtype is the type system underlying the +https://github.com/jackc/pgx PostgreSQL driver. These types support the binary format for enhanced performance with pgx. +They also support the database/sql `Scan` and `Value` interfaces and can be used with https://github.com/lib/pq. From fa5c331c789ed1902deaa2f58562669c3e4342d0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 26 Sep 2019 21:12:32 -0500 Subject: [PATCH 0356/1158] Add text format support to bit fixes #7 --- bit.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/bit.go b/bit.go index 4f40a532..925cfe7c 100644 --- a/bit.go +++ b/bit.go @@ -26,6 +26,14 @@ func (src Bit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return (Varbit)(src).EncodeBinary(ci, buf) } +func (dst *Bit) DecodeText(ci *ConnInfo, src []byte) error { + return (*Varbit)(dst).DecodeText(ci, src) +} + +func (src Bit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Varbit)(src).EncodeText(ci, buf) +} + // Scan implements the database/sql Scanner interface. func (dst *Bit) Scan(src interface{}) error { return (*Varbit)(dst).Scan(src) From 6c195c17b2af217104e20bb20da66c91d2b2f8f1 Mon Sep 17 00:00:00 2001 From: Francis Chuang <2263040+F21@users.noreply.github.com> Date: Thu, 3 Oct 2019 09:49:12 +1000 Subject: [PATCH 0357/1158] Fix minor errors and reword some sentences for readability --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index aa980b6d..5d14e914 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ # pgconn -Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level is the C library libpq. +Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq. It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx. Applications should handle normal queries with a higher level library and only use pgconn directly when required for low-level access to PostgreSQL functionality. @@ -17,7 +17,7 @@ if err != nil { } defer pgConn.Close() -result := pgConn.ExecParams(context.Background(), "select email from users where id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) for result.NextRow() { fmt.Println("User 123 has email:", string(result.Values()[0])) } @@ -29,7 +29,7 @@ if err != nil { ## Testing -pgconn tests need a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING` +The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING` environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*` environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify environment variable handling. @@ -44,13 +44,13 @@ create database pgx_test; Now you can run the tests: -``` +```bash PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./... ``` ### Connection and Authentication Tests -There are multiple connection types and means of authentication that pgconn supports. These tests are optional. They +Pgconn supports multiple connection types and means of authentication. These tests are optional. They will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being -skipped. Typical developers will not need to enable these tests. See travis.yml for example setup if you need change +skipped. Most developers will not need to enable these tests. See `travis.yml` for an example set up if you need change authentication code. From fcfd7d09a9079edbce62cf83d5d184e8b2dbc33e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Oct 2019 10:21:33 -0500 Subject: [PATCH 0358/1158] Add PgConn.IsBusy() method --- pgconn.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pgconn.go b/pgconn.go index 25f4f4d5..e3f3aaff 100644 --- a/pgconn.go +++ b/pgconn.go @@ -520,6 +520,11 @@ func (pgConn *PgConn) IsClosed() bool { return pgConn.status < connStatusIdle } +// IsBusy reports if the connection is busy. +func (pgConn *PgConn) IsBusy() bool { + return pgConn.status == connStatusBusy +} + // lock locks the connection. func (pgConn *PgConn) lock() error { switch pgConn.status { From 4df62cf3d029efb55dc1cc8d31144e9ed2d80d44 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Oct 2019 11:23:48 -0500 Subject: [PATCH 0359/1158] Release v1.1.0 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5384b031..92497f47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.1.0 (October 12, 2019) + +* Add PgConn.IsBusy() method + # 1.0.1 (September 19, 2019) * Fix statement cache not properly cleaning discarded statements From 81b6ad72f6dedf2162a06cdb3543de33b28ec2ff Mon Sep 17 00:00:00 2001 From: Skip Gibson Date: Wed, 16 Oct 2019 10:01:16 +0100 Subject: [PATCH 0360/1158] config: fix ValidateConnect comment --- config.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/config.go b/config.go index f41c38b9..628deed8 100644 --- a/config.go +++ b/config.go @@ -43,9 +43,8 @@ type Config struct { Fallbacks []*FallbackConfig // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. - // It can be used validate that server is acceptable. If this returns an error the connection is closed and the next - // fallback config is tried. This allows implementing high availability behavior such as libpq does with - // target_session_attrs. + // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next + // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. ValidateConnect ValidateConnectFunc // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables From f395b32fa66e1b729466b844db9453efc9b9e944 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Oct 2019 11:43:24 -0500 Subject: [PATCH 0361/1158] Added failing test for pointer to custom type --- pgtype_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/pgtype_test.go b/pgtype_test.go index 8771b77f..9602f419 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "bytes" "net" "testing" @@ -9,6 +10,8 @@ import ( _ "github.com/jackc/pgx/v4/stdlib" _ "github.com/lib/pq" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + errors "golang.org/x/xerrors" ) // Test for renamed types @@ -41,7 +44,7 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { return addr } -func TestConnInfoScanUnknownOID(t *testing.T) { +func TestConnInfoScanUnknownOIDToStringsAndBytes(t *testing.T) { unknownOID := uint32(999999) srcBuf := []byte("foo") ci := pgtype.NewConnInfo() @@ -74,3 +77,53 @@ func TestConnInfoScanUnknownOID(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []byte("foo"), []byte(rb)) } + +type pgCustomType struct { + a string + b string +} + +func (ct *pgCustomType) DecodeText(ci *pgtype.ConnInfo, buf []byte) error { + // This is not a complete parser for the text format of composite types. This is just for test purposes. + if buf == nil { + return errors.New("cannot parse null") + } + + if len(buf) < 2 { + return errors.New("invalid text format") + } + + parts := bytes.Split(buf[1:len(buf)-1], []byte(",")) + if len(parts) != 2 { + return errors.New("wrong number of parts") + } + + ct.a = string(parts[0]) + ct.b = string(parts[1]) + + return nil +} + +func TestConnInfoScanUnknownOIDToCustomType(t *testing.T) { + unknownOID := uint32(999999) + ci := pgtype.NewConnInfo() + + var ct pgCustomType + err := ci.Scan(unknownOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct) + assert.NoError(t, err) + assert.Equal(t, "foo", ct.a) + assert.Equal(t, "bar", ct.b) + + // Scan value into pointer to custom type + var pCt *pgCustomType + err = ci.Scan(unknownOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt) + assert.NoError(t, err) + require.NotNil(t, pCt) + assert.Equal(t, "foo", pCt.a) + assert.Equal(t, "bar", pCt.b) + + // Scan null into pointer to custom type + err = ci.Scan(unknownOID, pgx.TextFormatCode, nil, &pCt) + assert.NoError(t, err) + assert.Nil(t, pCt) +} From af517d68fc1775f22fa81e9d9852185d955e8e35 Mon Sep 17 00:00:00 2001 From: jaltavilla Date: Mon, 21 Oct 2019 17:21:42 -0400 Subject: [PATCH 0362/1158] Scan into nullable custom types (pointers to pointers). --- pgtype.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pgtype.go b/pgtype.go index 05a4fb05..058aa5c6 100644 --- a/pgtype.go +++ b/pgtype.go @@ -379,6 +379,24 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac } } + // We might be given a pointer to something that implements the decoder interface(s), + // even though the pointer itself doesn't. + refVal := reflect.ValueOf(dest) + if refVal.Kind() == reflect.Ptr && refVal.Type().Elem().Kind() == reflect.Ptr { + // If the database returned NULL, then we set dest as nil to indicate that. + if buf == nil { + nilPtr := reflect.Zero(refVal.Type().Elem()) + refVal.Elem().Set(nilPtr) + return nil + } + + // We need to allocate an element, and set the destination to it + // Then we can retry as that element. + elemPtr := reflect.New(refVal.Type().Elem().Elem()) + refVal.Elem().Set(elemPtr) + return ci.Scan(oid, formatCode, buf, elemPtr.Interface()) + } + return scanUnknownType(oid, formatCode, buf, dest) } From f711de35917e8c398c066bab98a80d1b6c97ab7f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Oct 2019 20:45:14 -0500 Subject: [PATCH 0363/1158] Release 1.0.2 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 20605e2a..bd83dd63 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.0.2 (October 22, 2019) + +* Fix scan into null into pointer to pointer implementing Decode* interface. (Jeremy Altavilla) + # 1.0.1 (September 19, 2019) * Fix daterange OID From 0079108e29e16232f4f091461f8f59d1d826d01a Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Fri, 8 Nov 2019 14:59:19 -0500 Subject: [PATCH 0364/1158] Fixes #11 -- support initializing Array types from a slice of the value --- aclitem_array.go | 16 ++++++++++++++-- bool_array.go | 18 +++++++++++++++--- bpchar_array.go | 18 +++++++++++++++--- bytea_array.go | 18 +++++++++++++++--- cidr_array.go | 18 +++++++++++++++--- date_array.go | 18 +++++++++++++++--- enum_array.go | 16 ++++++++++++++-- float4_array.go | 18 +++++++++++++++--- float8_array.go | 18 +++++++++++++++--- hstore_array.go | 18 +++++++++++++++--- inet_array.go | 18 +++++++++++++++--- int2_array.go | 18 +++++++++++++++--- int4_array.go | 37 +++++++++++++++---------------------- int8_array.go | 18 +++++++++++++++--- macaddr_array.go | 18 +++++++++++++++--- numeric_array.go | 18 +++++++++++++++--- text_array.go | 18 +++++++++++++++--- timestamp_array.go | 18 +++++++++++++++--- timestamptz_array.go | 18 +++++++++++++++--- typed_array.go.erb | 14 ++++++++++++++ uuid_array.go | 18 +++++++++++++++--- varchar_array.go | 18 +++++++++++++++--- 22 files changed, 327 insertions(+), 80 deletions(-) diff --git a/aclitem_array.go b/aclitem_array.go index e8142091..e41edaea 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -40,6 +40,18 @@ func (dst *ACLItemArray) Set(src interface{}) error { } } + case []ACLItem: + if value == nil { + *dst = ACLItemArray{Status: Null} + } else if len(value) == 0 { + *dst = ACLItemArray{Status: Present} + } else { + *dst = ACLItemArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -124,7 +136,7 @@ func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (src ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -200,7 +212,7 @@ func (dst *ACLItemArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src ACLItemArray) Value() (driver.Value, error) { +func (src *ACLItemArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/bool_array.go b/bool_array.go index ba453254..89fac9ec 100644 --- a/bool_array.go +++ b/bool_array.go @@ -42,6 +42,18 @@ func (dst *BoolArray) Set(src interface{}) error { } } + case []Bool: + if value == nil { + *dst = BoolArray{Status: Null} + } else if len(value) == 0 { + *dst = BoolArray{Status: Present} + } else { + *dst = BoolArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -168,7 +180,7 @@ func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +237,7 @@ func (src BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +300,7 @@ func (dst *BoolArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src BoolArray) Value() (driver.Value, error) { +func (src *BoolArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/bpchar_array.go b/bpchar_array.go index da601d0d..d974df16 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -42,6 +42,18 @@ func (dst *BPCharArray) Set(src interface{}) error { } } + case []BPChar: + if value == nil { + *dst = BPCharArray{Status: Null} + } else if len(value) == 0 { + *dst = BPCharArray{Status: Present} + } else { + *dst = BPCharArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -168,7 +180,7 @@ func (dst *BPCharArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +237,7 @@ func (src BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +300,7 @@ func (dst *BPCharArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src BPCharArray) Value() (driver.Value, error) { +func (src *BPCharArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/bytea_array.go b/bytea_array.go index 1c2f6548..a8a67368 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -42,6 +42,18 @@ func (dst *ByteaArray) Set(src interface{}) error { } } + case []Bytea: + if value == nil { + *dst = ByteaArray{Status: Null} + } else if len(value) == 0 { + *dst = ByteaArray{Status: Present} + } else { + *dst = ByteaArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -168,7 +180,7 @@ func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +237,7 @@ func (src ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +300,7 @@ func (dst *ByteaArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src ByteaArray) Value() (driver.Value, error) { +func (src *ByteaArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/cidr_array.go b/cidr_array.go index 234c6aff..bddf74ec 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -62,6 +62,18 @@ func (dst *CIDRArray) Set(src interface{}) error { } } + case []CIDR: + if value == nil { + *dst = CIDRArray{Status: Null} + } else if len(value) == 0 { + *dst = CIDRArray{Status: Present} + } else { + *dst = CIDRArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -197,7 +209,7 @@ func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -254,7 +266,7 @@ func (src CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -317,7 +329,7 @@ func (dst *CIDRArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src CIDRArray) Value() (driver.Value, error) { +func (src *CIDRArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/date_array.go b/date_array.go index 69fc3e5e..95f52ac0 100644 --- a/date_array.go +++ b/date_array.go @@ -43,6 +43,18 @@ func (dst *DateArray) Set(src interface{}) error { } } + case []Date: + if value == nil { + *dst = DateArray{Status: Null} + } else if len(value) == 0 { + *dst = DateArray{Status: Present} + } else { + *dst = DateArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -169,7 +181,7 @@ func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -226,7 +238,7 @@ func (src DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -289,7 +301,7 @@ func (dst *DateArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src DateArray) Value() (driver.Value, error) { +func (src *DateArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/enum_array.go b/enum_array.go index f4609169..f32be61c 100644 --- a/enum_array.go +++ b/enum_array.go @@ -40,6 +40,18 @@ func (dst *EnumArray) Set(src interface{}) error { } } + case []GenericText: + if value == nil { + *dst = EnumArray{Status: Null} + } else if len(value) == 0 { + *dst = EnumArray{Status: Present} + } else { + *dst = EnumArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -124,7 +136,7 @@ func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (src EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -200,7 +212,7 @@ func (dst *EnumArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src EnumArray) Value() (driver.Value, error) { +func (src *EnumArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/float4_array.go b/float4_array.go index 80aff879..a21e0a1f 100644 --- a/float4_array.go +++ b/float4_array.go @@ -42,6 +42,18 @@ func (dst *Float4Array) Set(src interface{}) error { } } + case []Float4: + if value == nil { + *dst = Float4Array{Status: Null} + } else if len(value) == 0 { + *dst = Float4Array{Status: Present} + } else { + *dst = Float4Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -168,7 +180,7 @@ func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +237,7 @@ func (src Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +300,7 @@ func (dst *Float4Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Float4Array) Value() (driver.Value, error) { +func (src *Float4Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/float8_array.go b/float8_array.go index 3999cf7d..6a44339a 100644 --- a/float8_array.go +++ b/float8_array.go @@ -42,6 +42,18 @@ func (dst *Float8Array) Set(src interface{}) error { } } + case []Float8: + if value == nil { + *dst = Float8Array{Status: Null} + } else if len(value) == 0 { + *dst = Float8Array{Status: Present} + } else { + *dst = Float8Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -168,7 +180,7 @@ func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +237,7 @@ func (src Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +300,7 @@ func (dst *Float8Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Float8Array) Value() (driver.Value, error) { +func (src *Float8Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/hstore_array.go b/hstore_array.go index 8269fb40..a0a2b3a9 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -42,6 +42,18 @@ func (dst *HstoreArray) Set(src interface{}) error { } } + case []Hstore: + if value == nil { + *dst = HstoreArray{Status: Null} + } else if len(value) == 0 { + *dst = HstoreArray{Status: Present} + } else { + *dst = HstoreArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -168,7 +180,7 @@ func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +237,7 @@ func (src HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +300,7 @@ func (dst *HstoreArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src HstoreArray) Value() (driver.Value, error) { +func (src *HstoreArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/inet_array.go b/inet_array.go index a6fd419e..d754fab3 100644 --- a/inet_array.go +++ b/inet_array.go @@ -62,6 +62,18 @@ func (dst *InetArray) Set(src interface{}) error { } } + case []Inet: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + *dst = InetArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -197,7 +209,7 @@ func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -254,7 +266,7 @@ func (src InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -317,7 +329,7 @@ func (dst *InetArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src InetArray) Value() (driver.Value, error) { +func (src *InetArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/int2_array.go b/int2_array.go index beea543f..59c05de3 100644 --- a/int2_array.go +++ b/int2_array.go @@ -61,6 +61,18 @@ func (dst *Int2Array) Set(src interface{}) error { } } + case []Int2: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + *dst = Int2Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -196,7 +208,7 @@ func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -253,7 +265,7 @@ func (src Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -316,7 +328,7 @@ func (dst *Int2Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Int2Array) Value() (driver.Value, error) { +func (src *Int2Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/int4_array.go b/int4_array.go index 83ee4c26..08040955 100644 --- a/int4_array.go +++ b/int4_array.go @@ -23,25 +23,6 @@ func (dst *Int4Array) Set(src interface{}) error { switch value := src.(type) { - case []int: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - case []int32: if value == nil { *dst = Int4Array{Status: Null} @@ -80,6 +61,18 @@ func (dst *Int4Array) Set(src interface{}) error { } } + case []Int4: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + *dst = Int4Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -215,7 +208,7 @@ func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -272,7 +265,7 @@ func (src Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -335,7 +328,7 @@ func (dst *Int4Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Int4Array) Value() (driver.Value, error) { +func (src *Int4Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/int8_array.go b/int8_array.go index f118bc83..8cb446eb 100644 --- a/int8_array.go +++ b/int8_array.go @@ -61,6 +61,18 @@ func (dst *Int8Array) Set(src interface{}) error { } } + case []Int8: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + *dst = Int8Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -196,7 +208,7 @@ func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -253,7 +265,7 @@ func (src Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -316,7 +328,7 @@ func (dst *Int8Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Int8Array) Value() (driver.Value, error) { +func (src *Int8Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/macaddr_array.go b/macaddr_array.go index 7c62da2b..88bc44fd 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -43,6 +43,18 @@ func (dst *MacaddrArray) Set(src interface{}) error { } } + case []Macaddr: + if value == nil { + *dst = MacaddrArray{Status: Null} + } else if len(value) == 0 { + *dst = MacaddrArray{Status: Present} + } else { + *dst = MacaddrArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -169,7 +181,7 @@ func (dst *MacaddrArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -226,7 +238,7 @@ func (src MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src MacaddrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *MacaddrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -289,7 +301,7 @@ func (dst *MacaddrArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src MacaddrArray) Value() (driver.Value, error) { +func (src *MacaddrArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/numeric_array.go b/numeric_array.go index 8757b14d..cbd2e93f 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -99,6 +99,18 @@ func (dst *NumericArray) Set(src interface{}) error { } } + case []Numeric: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + *dst = NumericArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -252,7 +264,7 @@ func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -309,7 +321,7 @@ func (src NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -372,7 +384,7 @@ func (dst *NumericArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src NumericArray) Value() (driver.Value, error) { +func (src *NumericArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/text_array.go b/text_array.go index fca36ec8..d6aa3cfb 100644 --- a/text_array.go +++ b/text_array.go @@ -42,6 +42,18 @@ func (dst *TextArray) Set(src interface{}) error { } } + case []Text: + if value == nil { + *dst = TextArray{Status: Null} + } else if len(value) == 0 { + *dst = TextArray{Status: Present} + } else { + *dst = TextArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -168,7 +180,7 @@ func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +237,7 @@ func (src TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +300,7 @@ func (dst *TextArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src TextArray) Value() (driver.Value, error) { +func (src *TextArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/timestamp_array.go b/timestamp_array.go index 204b22eb..18d54b38 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -43,6 +43,18 @@ func (dst *TimestampArray) Set(src interface{}) error { } } + case []Timestamp: + if value == nil { + *dst = TimestampArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestampArray{Status: Present} + } else { + *dst = TimestampArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -169,7 +181,7 @@ func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -226,7 +238,7 @@ func (src TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -289,7 +301,7 @@ func (dst *TimestampArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src TimestampArray) Value() (driver.Value, error) { +func (src *TimestampArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/timestamptz_array.go b/timestamptz_array.go index 9bef64c6..98593305 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -43,6 +43,18 @@ func (dst *TimestamptzArray) Set(src interface{}) error { } } + case []Timestamptz: + if value == nil { + *dst = TimestamptzArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestamptzArray{Status: Present} + } else { + *dst = TimestamptzArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -169,7 +181,7 @@ func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -226,7 +238,7 @@ func (src TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) return buf, nil } -func (src TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -289,7 +301,7 @@ func (dst *TimestamptzArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src TimestamptzArray) Value() (driver.Value, error) { +func (src *TimestamptzArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/typed_array.go.erb b/typed_array.go.erb index 3ee637aa..2279380b 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -23,6 +23,7 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { switch value := src.(type) { <% go_array_types.split(",").each do |t| %> + <% if t != pgtype_element_type %> case <%= t %>: if value == nil { *dst = <%= pgtype_array_type %>{Status: Null} @@ -42,6 +43,19 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { } } <% end %> + <% end %> + case []<%= pgtype_element_type %>: + if value == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + } else if len(value) == 0 { + *dst = <%= pgtype_array_type %>{Status: Present} + } else { + *dst = <%= pgtype_array_type %>{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status : Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) diff --git a/uuid_array.go b/uuid_array.go index c3f18882..25bf21a8 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -80,6 +80,18 @@ func (dst *UUIDArray) Set(src interface{}) error { } } + case []UUID: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + *dst = UUIDArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -224,7 +236,7 @@ func (dst *UUIDArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -281,7 +293,7 @@ func (src UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -344,7 +356,7 @@ func (dst *UUIDArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src UUIDArray) Value() (driver.Value, error) { +func (src *UUIDArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/varchar_array.go b/varchar_array.go index 1e60c344..aa505404 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -42,6 +42,18 @@ func (dst *VarcharArray) Set(src interface{}) error { } } + case []Varchar: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + *dst = VarcharArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -168,7 +180,7 @@ func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +237,7 @@ func (src VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +300,7 @@ func (dst *VarcharArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src VarcharArray) Value() (driver.Value, error) { +func (src *VarcharArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err From be36a7e14b3e4f6938baa727139d6fa95f6ad1fe Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 14 Nov 2019 20:40:41 -0600 Subject: [PATCH 0365/1158] Fix test and avoid change to array signatures typed_array.go.erb was not updated back in a8802b16cc593842f5c69b0f7cfb0de11d5cd3a8 when Value, EncodeBinary, EncodeText, and MarshalJSON were changed to be defined on T instead of *T. This has been corrected. --- aclitem_array.go | 4 ++-- bool_array.go | 6 +++--- bpchar_array.go | 6 +++--- bytea_array.go | 6 +++--- cidr_array.go | 6 +++--- date_array.go | 6 +++--- enum_array.go | 4 ++-- float4_array.go | 6 +++--- float8_array.go | 6 +++--- hstore_array.go | 6 +++--- inet_array.go | 6 +++--- int2_array.go | 6 +++--- int4_array.go | 34 +++++++++++++++++++++++++++++++--- int8_array.go | 6 +++--- macaddr_array.go | 6 +++--- numeric_array.go | 6 +++--- text_array.go | 6 +++--- timestamp_array.go | 6 +++--- timestamptz_array.go | 6 +++--- tstzrange_array.go | 17 +++++------------ typed_array.go.erb | 8 ++++---- typed_array_gen.sh | 4 ++-- uuid_array.go | 6 +++--- varchar_array.go | 6 +++--- 24 files changed, 100 insertions(+), 79 deletions(-) diff --git a/aclitem_array.go b/aclitem_array.go index e41edaea..7b2e4dbc 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -136,7 +136,7 @@ func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (src *ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -212,7 +212,7 @@ func (dst *ACLItemArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *ACLItemArray) Value() (driver.Value, error) { +func (src ACLItemArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/bool_array.go b/bool_array.go index 89fac9ec..3dbb4ca0 100644 --- a/bool_array.go +++ b/bool_array.go @@ -180,7 +180,7 @@ func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -237,7 +237,7 @@ func (src *BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -300,7 +300,7 @@ func (dst *BoolArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *BoolArray) Value() (driver.Value, error) { +func (src BoolArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/bpchar_array.go b/bpchar_array.go index d974df16..b60ccc91 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -180,7 +180,7 @@ func (dst *BPCharArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -237,7 +237,7 @@ func (src *BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -300,7 +300,7 @@ func (dst *BPCharArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *BPCharArray) Value() (driver.Value, error) { +func (src BPCharArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/bytea_array.go b/bytea_array.go index a8a67368..fbebff24 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -180,7 +180,7 @@ func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -237,7 +237,7 @@ func (src *ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -300,7 +300,7 @@ func (dst *ByteaArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *ByteaArray) Value() (driver.Value, error) { +func (src ByteaArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/cidr_array.go b/cidr_array.go index bddf74ec..dbc71bb5 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -209,7 +209,7 @@ func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -266,7 +266,7 @@ func (src *CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -329,7 +329,7 @@ func (dst *CIDRArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *CIDRArray) Value() (driver.Value, error) { +func (src CIDRArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/date_array.go b/date_array.go index 95f52ac0..c97e83ee 100644 --- a/date_array.go +++ b/date_array.go @@ -181,7 +181,7 @@ func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -238,7 +238,7 @@ func (src *DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -301,7 +301,7 @@ func (dst *DateArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *DateArray) Value() (driver.Value, error) { +func (src DateArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/enum_array.go b/enum_array.go index f32be61c..3e07eae9 100644 --- a/enum_array.go +++ b/enum_array.go @@ -136,7 +136,7 @@ func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (src *EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -212,7 +212,7 @@ func (dst *EnumArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *EnumArray) Value() (driver.Value, error) { +func (src EnumArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/float4_array.go b/float4_array.go index a21e0a1f..07fac71a 100644 --- a/float4_array.go +++ b/float4_array.go @@ -180,7 +180,7 @@ func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -237,7 +237,7 @@ func (src *Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -300,7 +300,7 @@ func (dst *Float4Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Float4Array) Value() (driver.Value, error) { +func (src Float4Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/float8_array.go b/float8_array.go index 6a44339a..2f65c736 100644 --- a/float8_array.go +++ b/float8_array.go @@ -180,7 +180,7 @@ func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -237,7 +237,7 @@ func (src *Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -300,7 +300,7 @@ func (dst *Float8Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Float8Array) Value() (driver.Value, error) { +func (src Float8Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/hstore_array.go b/hstore_array.go index a0a2b3a9..06a11c02 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -180,7 +180,7 @@ func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -237,7 +237,7 @@ func (src *HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -300,7 +300,7 @@ func (dst *HstoreArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *HstoreArray) Value() (driver.Value, error) { +func (src HstoreArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/inet_array.go b/inet_array.go index d754fab3..88181739 100644 --- a/inet_array.go +++ b/inet_array.go @@ -209,7 +209,7 @@ func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -266,7 +266,7 @@ func (src *InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -329,7 +329,7 @@ func (dst *InetArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *InetArray) Value() (driver.Value, error) { +func (src InetArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/int2_array.go b/int2_array.go index 59c05de3..27892b15 100644 --- a/int2_array.go +++ b/int2_array.go @@ -208,7 +208,7 @@ func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -265,7 +265,7 @@ func (src *Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -328,7 +328,7 @@ func (dst *Int2Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Int2Array) Value() (driver.Value, error) { +func (src Int2Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/int4_array.go b/int4_array.go index 08040955..e3819562 100644 --- a/int4_array.go +++ b/int4_array.go @@ -61,6 +61,25 @@ func (dst *Int4Array) Set(src interface{}) error { } } + case []int: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Int4: if value == nil { *dst = Int4Array{Status: Null} @@ -117,6 +136,15 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } return nil + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) @@ -208,7 +236,7 @@ func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -265,7 +293,7 @@ func (src *Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -328,7 +356,7 @@ func (dst *Int4Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Int4Array) Value() (driver.Value, error) { +func (src Int4Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/int8_array.go b/int8_array.go index 8cb446eb..a31a474a 100644 --- a/int8_array.go +++ b/int8_array.go @@ -208,7 +208,7 @@ func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -265,7 +265,7 @@ func (src *Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -328,7 +328,7 @@ func (dst *Int8Array) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Int8Array) Value() (driver.Value, error) { +func (src Int8Array) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/macaddr_array.go b/macaddr_array.go index 88bc44fd..8382ea45 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -181,7 +181,7 @@ func (dst *MacaddrArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -238,7 +238,7 @@ func (src *MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *MacaddrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src MacaddrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -301,7 +301,7 @@ func (dst *MacaddrArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *MacaddrArray) Value() (driver.Value, error) { +func (src MacaddrArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/numeric_array.go b/numeric_array.go index cbd2e93f..432cd96f 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -264,7 +264,7 @@ func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -321,7 +321,7 @@ func (src *NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -384,7 +384,7 @@ func (dst *NumericArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *NumericArray) Value() (driver.Value, error) { +func (src NumericArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/text_array.go b/text_array.go index d6aa3cfb..653e41fc 100644 --- a/text_array.go +++ b/text_array.go @@ -180,7 +180,7 @@ func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -237,7 +237,7 @@ func (src *TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -300,7 +300,7 @@ func (dst *TextArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *TextArray) Value() (driver.Value, error) { +func (src TextArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/timestamp_array.go b/timestamp_array.go index 18d54b38..072e01ac 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -181,7 +181,7 @@ func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -238,7 +238,7 @@ func (src *TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) return buf, nil } -func (src *TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -301,7 +301,7 @@ func (dst *TimestampArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *TimestampArray) Value() (driver.Value, error) { +func (src TimestampArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/timestamptz_array.go b/timestamptz_array.go index 98593305..9d0677c8 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -181,7 +181,7 @@ func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -238,7 +238,7 @@ func (src *TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error return buf, nil } -func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -301,7 +301,7 @@ func (dst *TimestamptzArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *TimestamptzArray) Value() (driver.Value, error) { +func (src TimestamptzArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/tstzrange_array.go b/tstzrange_array.go index 8180e4c2..f7c0121d 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -29,19 +29,12 @@ func (dst *TstzrangeArray) Set(src interface{}) error { } else if len(value) == 0 { *dst = TstzrangeArray{Status: Present} } else { - elements := make([]Tstzrange, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } *dst = TstzrangeArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, Status: Present, } } - default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -168,7 +161,7 @@ func (dst *TstzrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TstzrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TstzrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -225,7 +218,7 @@ func (src *TstzrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) return buf, nil } -func (src *TstzrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src TstzrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -288,7 +281,7 @@ func (dst *TstzrangeArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *TstzrangeArray) Value() (driver.Value, error) { +func (src TstzrangeArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/typed_array.go.erb b/typed_array.go.erb index 2279380b..72c0c381 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -23,7 +23,7 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { switch value := src.(type) { <% go_array_types.split(",").each do |t| %> - <% if t != pgtype_element_type %> + <% if t != "[]#{pgtype_element_type}" %> case <%= t %>: if value == nil { *dst = <%= pgtype_array_type %>{Status: Null} @@ -184,7 +184,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) erro } <% end %> -func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -242,7 +242,7 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byt } <% if binary_format == "true" %> - func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + func (src <%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -306,7 +306,7 @@ func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *<%= pgtype_array_type %>) Value() (driver.Value, error) { +func (src <%= pgtype_array_type %>) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 76c174ef..6eca219d 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,10 +1,10 @@ erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32,[]int element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstz_range_array.go +erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go diff --git a/uuid_array.go b/uuid_array.go index 25bf21a8..7c324e53 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -236,7 +236,7 @@ func (dst *UUIDArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -293,7 +293,7 @@ func (src *UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -356,7 +356,7 @@ func (dst *UUIDArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *UUIDArray) Value() (driver.Value, error) { +func (src UUIDArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/varchar_array.go b/varchar_array.go index aa505404..ac9af519 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -180,7 +180,7 @@ func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -237,7 +237,7 @@ func (src *VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -300,7 +300,7 @@ func (dst *VarcharArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *VarcharArray) Value() (driver.Value, error) { +func (src VarcharArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err From 7e1301257e86a0fb94c0b16892a5cba76d0d12e0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Nov 2019 11:10:32 -0600 Subject: [PATCH 0366/1158] Release 1.0.3 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bd83dd63..7db5c1a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.0.3 (November 16, 2019) + +* Support initializing Array types from a slice of the value (Alex Gaynor) + # 1.0.2 (October 22, 2019) * Fix scan into null into pointer to pointer implementing Decode* interface. (Jeremy Altavilla) From eb81d2926b3b0519cd1fe945c2795cceeabe236c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 18 Nov 2019 07:24:05 -0600 Subject: [PATCH 0367/1158] Ignore errors sending Terminate message while closing connection This mimics the behavior of libpq PGfinish. refs #637 --- pgconn.go | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/pgconn.go b/pgconn.go index e3f3aaff..210d9979 100644 --- a/pgconn.go +++ b/pgconn.go @@ -492,15 +492,13 @@ func (pgConn *PgConn) Close(ctx context.Context) error { pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() - _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) - if err != nil { - return err - } - - _, err = pgConn.conn.Read(make([]byte, 1)) - if err != io.EOF { - return err - } + // Ignore any errors sending Terminate message and waiting for server to close connection. + // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully + // ignores errors. + // + // See https://github.com/jackc/pgx/issues/637 + pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) + pgConn.conn.Read(make([]byte, 1)) return pgConn.conn.Close() } From 32350bd1dc3288aa22f271f87065741da0a1bdb8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 18 Nov 2019 07:28:47 -0600 Subject: [PATCH 0368/1158] TestConnectCustomLookup must test with TCP connection Test (correctly) fails if run on a Unix domain socket. --- pgconn_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index 36499b68..6f330efb 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -191,7 +191,12 @@ func TestConnectCustomDialer(t *testing.T) { func TestConnectCustomLookup(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) require.NoError(t, err) looked := false From 01ae643a487a6f2fdeaa1458654a5e2aa885c6f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Philippe=20Qu=C3=A9m=C3=A9ner?= Date: Tue, 26 Nov 2019 17:11:54 +0100 Subject: [PATCH 0369/1158] feat: make conversion between numeric values and arrays less strict closes https://github.com/jackc/pgx/issues/642 --- int2_array.go | 140 +++++++++++++++++++++++++++++++++++++++++++++ int4_array.go | 56 ++++++++++++++++++ typed_array_gen.sh | 4 +- 3 files changed, 198 insertions(+), 2 deletions(-) mode change 100644 => 100755 typed_array_gen.sh diff --git a/int2_array.go b/int2_array.go index 27892b15..6e08325c 100644 --- a/int2_array.go +++ b/int2_array.go @@ -61,6 +61,101 @@ func (dst *Int2Array) Set(src interface{}) error { } } + case []int32: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint32: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int64: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Int2: if value == nil { *dst = Int2Array{Status: Null} @@ -117,6 +212,51 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } return nil + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/int4_array.go b/int4_array.go index e3819562..993cdae9 100644 --- a/int4_array.go +++ b/int4_array.go @@ -61,6 +61,44 @@ func (dst *Int4Array) Set(src interface{}) error { } } + case []int64: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int: if value == nil { *dst = Int4Array{Status: Null} @@ -136,6 +174,24 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } return nil + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]int: *v = make([]int, len(src.Elements)) for i := range src.Elements { diff --git a/typed_array_gen.sh b/typed_array_gen.sh old mode 100644 new mode 100755 index 6eca219d..9fc01c2c --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,5 +1,5 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32,[]int element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16,[]int32,[]uint32,[]int64,[]uint64,[]int element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32,[]int64,[]uint64,[]int element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go From 9ff83bc41ca7e2bc8e7dfbf6500e51bd0dea27f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Philippe=20Qu=C3=A9m=C3=A9ner?= Date: Tue, 26 Nov 2019 17:31:13 +0100 Subject: [PATCH 0370/1158] feat: add tests for less stricter numeric conversion --- int2_array_test.go | 35 +++++++++++++++++++++++++++++++++++ int4_array_test.go | 14 ++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/int2_array_test.go b/int2_array_test.go index 810d5a7e..22f71745 100644 --- a/int2_array_test.go +++ b/int2_array_test.go @@ -57,6 +57,20 @@ func TestInt2ArraySet(t *testing.T) { source interface{} result pgtype.Int2Array }{ + { + source: []int64{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int32{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, { source: []int16{1}, result: pgtype.Int2Array{ @@ -64,6 +78,27 @@ func TestInt2ArraySet(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, + { + source: []int{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint64{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint32{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, { source: []uint16{1}, result: pgtype.Int2Array{ diff --git a/int4_array_test.go b/int4_array_test.go index a0b8058f..820b6670 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -59,6 +59,13 @@ func TestInt4ArraySet(t *testing.T) { result pgtype.Int4Array expectedError bool }{ + { + source: []int64{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, { source: []int32{1}, result: pgtype.Int4Array{ @@ -77,6 +84,13 @@ func TestInt4ArraySet(t *testing.T) { source: []int{1, math.MaxInt32 + 1, 2}, expectedError: true, }, + { + source: []uint64{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, { source: []uint32{1}, result: pgtype.Int4Array{ From 038f263a44bace8358b68846bd7a15a4a9cdd66a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 27 Nov 2019 20:23:43 -0600 Subject: [PATCH 0371/1158] Add remaining int array conversions --- int2_array.go | 28 ++++++++ int4_array.go | 84 +++++++++++++++++++++++ int4_array_test.go | 14 ++++ int8_array.go | 168 +++++++++++++++++++++++++++++++++++++++++++++ int8_array_test.go | 42 ++++++++++++ typed_array_gen.sh | 6 +- 6 files changed, 339 insertions(+), 3 deletions(-) diff --git a/int2_array.go b/int2_array.go index 6e08325c..3f6bdb87 100644 --- a/int2_array.go +++ b/int2_array.go @@ -156,6 +156,25 @@ func (dst *Int2Array) Set(src interface{}) error { } } + case []uint: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Int2: if value == nil { *dst = Int2Array{Status: Null} @@ -257,6 +276,15 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } return nil + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/int4_array.go b/int4_array.go index 993cdae9..f3e87b00 100644 --- a/int4_array.go +++ b/int4_array.go @@ -23,6 +23,44 @@ func (dst *Int4Array) Set(src interface{}) error { switch value := src.(type) { + case []int16: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint16: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int32: if value == nil { *dst = Int4Array{Status: Null} @@ -118,6 +156,25 @@ func (dst *Int4Array) Set(src interface{}) error { } } + case []uint: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Int4: if value == nil { *dst = Int4Array{Status: Null} @@ -156,6 +213,24 @@ func (src *Int4Array) AssignTo(dst interface{}) error { case Present: switch v := dst.(type) { + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]int32: *v = make([]int32, len(src.Elements)) for i := range src.Elements { @@ -201,6 +276,15 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } return nil + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/int4_array_test.go b/int4_array_test.go index 820b6670..c839c1c9 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -73,6 +73,13 @@ func TestInt4ArraySet(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, + { + source: []int16{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, { source: []int{1}, result: pgtype.Int4Array{ @@ -98,6 +105,13 @@ func TestInt4ArraySet(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, + { + source: []uint16{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, { source: (([]int32)(nil)), result: pgtype.Int4Array{Status: pgtype.Null}, diff --git a/int8_array.go b/int8_array.go index a31a474a..a6798173 100644 --- a/int8_array.go +++ b/int8_array.go @@ -23,6 +23,82 @@ func (dst *Int8Array) Set(src interface{}) error { switch value := src.(type) { + case []int16: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint16: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int32: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint32: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int64: if value == nil { *dst = Int8Array{Status: Null} @@ -61,6 +137,44 @@ func (dst *Int8Array) Set(src interface{}) error { } } + case []int: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Int8: if value == nil { *dst = Int8Array{Status: Null} @@ -99,6 +213,42 @@ func (src *Int8Array) AssignTo(dst interface{}) error { case Present: switch v := dst.(type) { + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]int64: *v = make([]int64, len(src.Elements)) for i := range src.Elements { @@ -117,6 +267,24 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } return nil + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/int8_array_test.go b/int8_array_test.go index f4ed76e0..e9e7acfb 100644 --- a/int8_array_test.go +++ b/int8_array_test.go @@ -64,6 +64,27 @@ func TestInt8ArraySet(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, + { + source: []int32{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int16{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, { source: []uint64{1}, result: pgtype.Int8Array{ @@ -71,6 +92,27 @@ func TestInt8ArraySet(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, + { + source: []uint32{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint16{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, { source: (([]int64)(nil)), result: pgtype.Int8Array{Status: pgtype.Null}, diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 9fc01c2c..6fd49264 100755 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,6 +1,6 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16,[]int32,[]uint32,[]int64,[]uint64,[]int element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32,[]int64,[]uint64,[]int element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16,[]int32,[]uint32,[]int64,[]uint64,[]int,[]uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]uint16,[]int32,[]uint32,[]int64,[]uint64,[]int,[]uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]uint16,[]int32,[]uint32,[]int64,[]uint64,[]int,[]uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go From e6b823d64953284cfbc874de358700902129cdae Mon Sep 17 00:00:00 2001 From: Yuli Khodorkovskiy Date: Tue, 17 Dec 2019 20:03:55 -0500 Subject: [PATCH 0372/1158] Add missing GSSEncRequest --- backend.go | 9 ++++++++- gss_enc_request.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 gss_enc_request.go diff --git a/backend.go b/backend.go index 5741647f..cd7e8ce2 100644 --- a/backend.go +++ b/backend.go @@ -19,6 +19,7 @@ type Backend struct { describe Describe execute Execute flush Flush + gssEncRequest GSSEncRequest parse Parse passwordMessage PasswordMessage query Query @@ -45,7 +46,7 @@ func (b *Backend) Send(msg BackendMessage) error { // ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method // because the initial connection message is "special" and does not include the message type as the first byte. This -// will return either a StartupMessage, SSLRequest, or CancelRequest. +// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest. func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { buf, err := b.cr.Next(4) if err != nil { @@ -79,6 +80,12 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { return nil, err } return &b.cancelRequest, nil + case gssEncReqNumber: + err = b.gssEncRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.gssEncRequest, nil default: return nil, fmt.Errorf("unknown startup message code: %d", code) } diff --git a/gss_enc_request.go b/gss_enc_request.go new file mode 100644 index 00000000..cf405a3e --- /dev/null +++ b/gss_enc_request.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +const gssEncReqNumber = 80877104 + +type GSSEncRequest struct { +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*GSSEncRequest) Frontend() {} + +func (dst *GSSEncRequest) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("gss encoding request too short") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != gssEncReqNumber { + return errors.New("bad gss encoding request code") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *GSSEncRequest) Encode(dst []byte) []byte { + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendInt32(dst, gssEncReqNumber) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src GSSEncRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "GSSEncRequest", + }) +} From 1c20e7d36efaac58ececd18182685740c199182a Mon Sep 17 00:00:00 2001 From: Yuli Khodorkovskiy Date: Fri, 6 Dec 2019 10:02:41 -0500 Subject: [PATCH 0373/1158] Fix malformed SASL messages Per the PG documentation [0], an AuthenticationSASLContinue message has: AuthenticationSASLContinue (B) Byte1('R') Identifies the message as an authentication request. Int32 Length of message contents in bytes, including self. Int32(11) Specifies that this message contains a SASL challenge. Byten SASL data, specific to the SASL mechanism being used. The current implementation was mistakenly adding the lengh of msg bytes in between the Int32(11) and Byten. There was a similar issue for AuthenticationSASLFinal. [0] https://www.postgresql.org/docs/current/protocol-message-formats.html --- authentication_sasl_continue.go | 1 - authentication_sasl_final.go | 1 - 2 files changed, 2 deletions(-) diff --git a/authentication_sasl_continue.go b/authentication_sasl_continue.go index a393ae10..1b918a6e 100644 --- a/authentication_sasl_continue.go +++ b/authentication_sasl_continue.go @@ -40,7 +40,6 @@ func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { dst = pgio.AppendInt32(dst, -1) dst = pgio.AppendUint32(dst, AuthTypeSASLContinue) - dst = pgio.AppendInt32(dst, int32(len(src.Data))) dst = append(dst, src.Data...) pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) diff --git a/authentication_sasl_final.go b/authentication_sasl_final.go index b8f89d59..11d35660 100644 --- a/authentication_sasl_final.go +++ b/authentication_sasl_final.go @@ -40,7 +40,6 @@ func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { dst = pgio.AppendInt32(dst, -1) dst = pgio.AppendUint32(dst, AuthTypeSASLFinal) - dst = pgio.AppendInt32(dst, int32(len(src.Data))) dst = append(dst, src.Data...) pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) From c7502af68bb37f6d0191fc462d33f79ed1cfc45b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 19 Dec 2019 21:34:33 -0600 Subject: [PATCH 0374/1158] Add PostgreSQL time type support fixes #15 --- pgtype.go | 2 + time.go | 219 +++++++++++++++++++++++++++++++++++++++++++++++++++ time_test.go | 128 ++++++++++++++++++++++++++++++ 3 files changed, 349 insertions(+) create mode 100644 time.go create mode 100644 time_test.go diff --git a/pgtype.go b/pgtype.go index 058aa5c6..1109d0d8 100644 --- a/pgtype.go +++ b/pgtype.go @@ -52,6 +52,7 @@ const ( BPCharOID = 1042 VarcharOID = 1043 DateOID = 1082 + TimeOID = 1083 TimestampOID = 1114 TimestampArrayOID = 1115 DateArrayOID = 1182 @@ -237,6 +238,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID}) ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID}) + ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID}) ci.RegisterDataType(DataType{Value: &Timestamp{}, Name: "timestamp", OID: TimestampOID}) ci.RegisterDataType(DataType{Value: &Timestamptz{}, Name: "timestamptz", OID: TimestamptzOID}) ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) diff --git a/time.go b/time.go new file mode 100644 index 00000000..3bf91b10 --- /dev/null +++ b/time.go @@ -0,0 +1,219 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "time" + + "github.com/jackc/pgio" + errors "golang.org/x/xerrors" +) + +// Time represents the PostgreSQL time type. The PostgreSQL time is a time of day without time zone. +// +// Time is represented as the number of microseconds since midnight in the same way that PostgreSQL does. Other time +// and date types in pgtype can use time.Time as the underlying representation. However, pgtype.Time type cannot due +// to needing to handle 24:00:00. time.Time converts that to 00:00:00 on the following day. +type Time struct { + Microseconds int64 // Number of microseconds since midnight + Status Status +} + +// Set converts src into a Time and stores in dst. +func (dst *Time) Set(src interface{}) error { + if src == nil { + *dst = Time{Status: Null} + return nil + } + + switch value := src.(type) { + case time.Time: + usec := int64(value.Hour())*microsecondsPerHour + + int64(value.Minute())*microsecondsPerMinute + + int64(value.Second())*microsecondsPerSecond + + int64(value.Nanosecond())/1000 + *dst = Time{Microseconds: usec, Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Time", value) + } + + return nil +} + +func (dst *Time) Get() interface{} { + switch dst.Status { + case Present: + return dst.Microseconds + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Time) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + // 24:00:00 is max allowed time in PostgreSQL, but time.Time will normalize that to 00:00:00 the next day. + var maxRepresentableByTime int64 = 24*60*60*1000000 - 1 + if src.Microseconds > maxRepresentableByTime { + return errors.Errorf("%d microseconds cannot be represented as time.Time", src.Microseconds) + } + + usec := src.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + ns := usec * 1000 + *v = time.Date(2000, 1, 1, int(hours), int(minutes), int(seconds), int(ns), time.UTC) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return errors.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %#v into %T", src, dst) +} + +// DecodeText decodes from src into dst. +func (dst *Time) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Time{Status: Null} + return nil + } + + s := string(src) + + if len(s) < 8 { + return errors.Errorf("cannot decode %v into Time", s) + } + + hours, err := strconv.ParseInt(s[0:2], 10, 64) + if err != nil { + return errors.Errorf("cannot decode %v into Time", s) + } + usec := hours * microsecondsPerHour + + minutes, err := strconv.ParseInt(s[3:5], 10, 64) + if err != nil { + return errors.Errorf("cannot decode %v into Time", s) + } + usec += minutes * microsecondsPerMinute + + seconds, err := strconv.ParseInt(s[6:8], 10, 64) + if err != nil { + return errors.Errorf("cannot decode %v into Time", s) + } + usec += seconds * microsecondsPerSecond + + if len(s) > 9 { + fraction := s[9:] + n, err := strconv.ParseInt(fraction, 10, 64) + if err != nil { + return errors.Errorf("cannot decode %v into Time", s) + } + + for i := len(fraction); i < 6; i++ { + n *= 10 + } + + usec += n + } + + *dst = Time{Microseconds: usec, Status: Present} + + return nil +} + +// DecodeBinary decodes from src into dst. +func (dst *Time) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Time{Status: Null} + return nil + } + + if len(src) != 8 { + return errors.Errorf("invalid length for time: %v", len(src)) + } + + usec := int64(binary.BigEndian.Uint64(src)) + *dst = Time{Microseconds: usec, Status: Present} + + return nil +} + +// EncodeText writes the text encoding of src into w. +func (src Time) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + usec := src.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + + s := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, usec) + + return append(buf, s...), nil +} + +// EncodeBinary writes the binary encoding of src into w. If src.Time is not in +// the UTC time zone it returns an error. +func (src Time) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return pgio.AppendInt64(buf, src.Microseconds), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Time) Scan(src interface{}) error { + if src == nil { + *dst = Time{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + case time.Time: + return dst.Set(src) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Time) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/time_test.go b/time_test.go new file mode 100644 index 00000000..bf6365ef --- /dev/null +++ b/time_test.go @@ -0,0 +1,128 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestTimeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "time", []interface{}{ + &pgtype.Time{Microseconds: 0, Status: pgtype.Present}, + &pgtype.Time{Microseconds: 1, Status: pgtype.Present}, + &pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}, + &pgtype.Time{Status: pgtype.Null}, + }) +} + +// Test for transcoding 24:00:00 separately as github.com/lib/pq doesn't seem to support it. +func TestTimeTranscode24HH(t *testing.T) { + pgTypeName := "time" + values := []interface{}{ + &pgtype.Time{Microseconds: 86400000000, Status: pgtype.Present}, + } + + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) +} + +func TestTimeSet(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Time + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 0, Status: pgtype.Present}}, + {source: time.Date(1900, 1, 1, 1, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 3600000000, Status: pgtype.Present}}, + {source: time.Date(1900, 1, 1, 0, 1, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 60000000, Status: pgtype.Present}}, + {source: time.Date(1900, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Time{Microseconds: 1000000, Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 1, time.UTC), result: pgtype.Time{Microseconds: 0, Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 1000, time.UTC), result: pgtype.Time{Microseconds: 1, Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 23, 59, 59, 999999999, time.UTC), result: pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}}, + {source: time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local), result: pgtype.Time{Microseconds: 2, Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 3000, time.UTC)), result: pgtype.Time{Microseconds: 3, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Time + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimeAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Time + dst interface{} + expected interface{} + }{ + {src: pgtype.Time{Microseconds: 0, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 3600000000, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 60000000, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 1, 0, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 1000000, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 1, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 1000, time.UTC)}, + {src: pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 23, 59, 59, 999999000, time.UTC)}, + {src: pgtype.Time{Microseconds: 0, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Time + dst interface{} + expected interface{} + }{ + {src: pgtype.Time{Microseconds: 0, Status: pgtype.Present}, dst: &ptim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Time + dst interface{} + }{ + {src: pgtype.Time{Microseconds: 86400000000, Status: pgtype.Present}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} From bd0ce203e9563e2b966b5f796bab2cf9f555bc2b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 Dec 2019 10:31:27 -0600 Subject: [PATCH 0375/1158] CopyFrom not table test was failing with syntax error --- pgconn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index 6f330efb..6b57dd09 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1444,7 +1444,7 @@ func TestConnCopyFromQueryNoTableError(t *testing.T) { srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout") require.Error(t, err) assert.IsType(t, &pgconn.PgError{}, err) assert.Equal(t, int64(0), res.RowsAffected()) From dd53b7488d920c44204098385e460cf708626a42 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 Dec 2019 10:06:24 -0600 Subject: [PATCH 0376/1158] Restart signalMessage when receiving non-error message in CopyFrom fixes #21 --- pgconn.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgconn.go b/pgconn.go index 210d9979..4c75d367 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1022,6 +1022,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co switch msg := msg.(type) { case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) + default: + signalMessageChan = pgConn.signalMessage() } default: } From 18d1ed5ee5619f59c6b5e670e2ffa36f1ffe95fd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 Dec 2019 14:37:09 -0600 Subject: [PATCH 0377/1158] Remove PostgreSQL 9.3 from Travis build matrix PostgreSQL 9.3 is EOL so it doesn't make sense for pgconn to specifically support. There are no known incompatibilities but it will not longer be tested. --- .travis.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index abff8515..50f13881 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,7 +28,6 @@ env: - PGVERSION=9.6 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret dbname=pgx_test" - PGVERSION=9.5 - PGVERSION=9.4 - - PGVERSION=9.3 cache: directories: From 5fc867a833afcd0d51c7d05bbff19e16e2adb34d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 Dec 2019 14:40:30 -0600 Subject: [PATCH 0378/1158] Remove unused travis environment variable --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 50f13881..c1688000 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,8 +24,8 @@ env: - PGX_TEST_PLAIN_PASSWORD_CONN_STRING=postgres://pgx_pw:secret@127.0.0.1/pgx_test matrix: - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx dbname=pgx_test" - - PGVERSION=10 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret dbname=pgx_test" - - PGVERSION=9.6 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret dbname=pgx_test" + - PGVERSION=10 + - PGVERSION=9.6 - PGVERSION=9.5 - PGVERSION=9.4 From 3e503b7b1a3beb8466a0ef126b87c209c5dc91e6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 Dec 2019 14:41:09 -0600 Subject: [PATCH 0379/1158] Add PostgreSQL 11 and 12 to the Travis build matrix --- .travis.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.travis.yml b/.travis.yml index c1688000..87a0c058 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,6 +24,8 @@ env: - PGX_TEST_PLAIN_PASSWORD_CONN_STRING=postgres://pgx_pw:secret@127.0.0.1/pgx_test matrix: - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx dbname=pgx_test" + - PGVERSION=12 + - PGVERSION=11 - PGVERSION=10 - PGVERSION=9.6 - PGVERSION=9.5 From 89416dd80542cc62f45af214ca0722c32e6624ca Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 13:09:50 +0200 Subject: [PATCH 0380/1158] Enable passing nil context --- .gitignore | 3 +- doc.go | 3 + pgconn.go | 187 +++++++++++++++++++++++++++++++---------------------- 3 files changed, 116 insertions(+), 77 deletions(-) diff --git a/.gitignore b/.gitignore index 6eb9d442..e980f555 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .envrc -vendor/ \ No newline at end of file +vendor/ +.vscode diff --git a/doc.go b/doc.go index cde58cd8..12ed6630 100644 --- a/doc.go +++ b/doc.go @@ -23,6 +23,9 @@ Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the method immediately returns. In most circumstances, this will close the underlying connection. +A nil context can be passed for convenience. This has the same effect as passing context.Background() with an additional +slight performance increase, if you don't need the operation to be cancellable. + The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. */ diff --git a/pgconn.go b/pgconn.go index 4c75d367..3b90b802 100644 --- a/pgconn.go +++ b/pgconn.go @@ -116,6 +116,10 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err panic("config must be created by ParseConfig") } + if ctx == nil { + ctx = context.Background() + } + // Simplify usage by treating primary config and fallbacks the same. fallbackConfigs := []*FallbackConfig{ { @@ -362,13 +366,15 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { } defer pgConn.unlock() - select { - case <-ctx.Done(): - return &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() n, err := pgConn.conn.Write(buf) if err != nil { @@ -392,13 +398,15 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() msg, err := pgConn.receiveMessage() if err != nil { @@ -489,8 +497,10 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() + if ctx != nil { + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } // Ignore any errors sending Terminate message and waiting for server to close connection. // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully @@ -586,13 +596,15 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) @@ -673,18 +685,24 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + _ctx := ctx + if _ctx == nil { + _ctx = context.Background() + } + cancelConn, err := pgConn.config.DialFunc(_ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } defer cancelConn.Close() - contextWatcher := ctxwatch.NewContextWatcher( - func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { cancelConn.SetDeadline(time.Time{}) }, - ) - contextWatcher.Watch(ctx) - defer contextWatcher.Unwatch() + if ctx != nil { + contextWatcher := ctxwatch.NewContextWatcher( + func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { cancelConn.SetDeadline(time.Time{}) }, + ) + contextWatcher.Watch(ctx) + defer contextWatcher.Unwatch() + } buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) @@ -712,14 +730,16 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { } defer pgConn.unlock() - select { - case <-ctx.Done(): - return ctx.Err() - default: - } + if ctx != nil { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } for { msg, err := pgConn.receiveMessage() @@ -752,16 +772,19 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { ctx: ctx, } multiResult := &pgConn.multiResultReader - - select { - case <-ctx.Done(): - multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} - pgConn.unlock() - return multiResult - default: + if ctx != nil { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } else { + pgConn.multiResultReader.ctx = context.Background() } - pgConn.contextWatcher.Watch(ctx) buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) @@ -808,7 +831,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(ctx, buf, result) + pgConn.execExtendedSuffix(buf, result) return result } @@ -834,7 +857,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf := pgConn.wbuf buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(ctx, buf, result) + pgConn.execExtendedSuffix(buf, result) return result } @@ -845,6 +868,9 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by ctx: ctx, } result := &pgConn.resultReader + if ctx == nil { + pgConn.resultReader.ctx = context.Background() + } if err := pgConn.lock(); err != nil { result.concludeCommand(nil, err) @@ -859,20 +885,22 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } - select { - case <-ctx.Done(): - result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) - result.closed = true - pgConn.unlock() - return result - default: + if ctx != nil { + select { + case <-ctx.Done(): + result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) + result.closed = true + pgConn.unlock() + return result + default: + } + pgConn.contextWatcher.Watch(ctx) } - pgConn.contextWatcher.Watch(ctx) return result } -func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) { +func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) @@ -893,14 +921,16 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, err } - select { - case <-ctx.Done(): - pgConn.unlock() - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + pgConn.unlock() + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -952,13 +982,15 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -1344,16 +1376,19 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR ctx: ctx, } multiResult := &pgConn.multiResultReader - - select { - case <-ctx.Done(): - multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} - pgConn.unlock() - return multiResult - default: + if ctx != nil { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } else { + pgConn.multiResultReader.ctx = context.Background() } - pgConn.contextWatcher.Watch(ctx) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) From 719623452110bc4bce0e2358db9d3df658777eeb Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 13:10:04 +0200 Subject: [PATCH 0381/1158] Benchmark nil context execution --- benchmark_test.go | 156 +++++++++++++++++++++++++++------------------- 1 file changed, 93 insertions(+), 63 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 8067c985..1914e07a 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -14,9 +14,14 @@ func BenchmarkConnect(b *testing.B) { benchmarks := []struct { name string env string + ctx context.Context }{ - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, - {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + // The first benchmark in the list sometimes executes faster, no matter how + // you reorder it. Nil context is still faster on average. + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING", context.Background()}, + {"TCP", "PGX_TEST_TCP_CONN_STRING", context.Background()}, + {"Unix socket nil context", "PGX_TEST_UNIX_SOCKET_CONN_STRING", nil}, + {"TCP nil context", "PGX_TEST_TCP_CONN_STRING", nil}, } for _, bm := range benchmarks { @@ -28,10 +33,10 @@ func BenchmarkConnect(b *testing.B) { } for i := 0; i < b.N; i++ { - conn, err := pgconn.Connect(context.Background(), connString) + conn, err := pgconn.Connect(bm.ctx, connString) require.Nil(b, err) - err = conn.Close(context.Background()) + err = conn.Close(bm.ctx) require.Nil(b, err) } }) @@ -39,46 +44,58 @@ func BenchmarkConnect(b *testing.B) { } func BenchmarkExec(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.Nil(b, err) - defer closeConn(b, conn) - expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + benchmarks := []struct { + name string + ctx context.Context + }{ + {"background context", context.Background()}, + {"nil context", nil}, + } - b.ResetTimer() + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) - for i := 0; i < b.N; i++ { - mrr := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + b.ResetTimer() - for mrr.NextResult() { - rr := mrr.ResultReader() + for i := 0; i < b.N; i++ { + mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") - rowCount := 0 - for rr.NextRow() { - rowCount++ - if len(rr.Values()) != len(expectedValues) { - b.Fatalf("unexpected number of values: %d", len(rr.Values())) - } - for i := range rr.Values() { - if !bytes.Equal(rr.Values()[i], expectedValues[i]) { - b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) } } - } - _, err = rr.Close() - if err != nil { - b.Fatal(err) + err := mrr.Close() + if err != nil { + b.Fatal(err) + } } - if rowCount != 1 { - b.Fatalf("unexpected rowCount: %d", rowCount) - } - } - - err := mrr.Close() - if err != nil { - b.Fatal(err) - } + }) } } @@ -130,40 +147,53 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { } func BenchmarkExecPrepared(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.Nil(b, err) - defer closeConn(b, conn) - - _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) - require.Nil(b, err) - expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} - b.ResetTimer() + benchmarks := []struct { + name string + ctx context.Context + }{ + {"background context", context.Background()}, + {"nil context", nil}, + } - for i := 0; i < b.N; i++ { - rr := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil) + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) - rowCount := 0 - for rr.NextRow() { - rowCount++ - if len(rr.Values()) != len(expectedValues) { - b.Fatalf("unexpected number of values: %d", len(rr.Values())) - } - for i := range rr.Values() { - if !bytes.Equal(rr.Values()[i], expectedValues[i]) { - b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + _, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + require.Nil(b, err) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) } } - } - _, err = rr.Close() - - if err != nil { - b.Fatal(err) - } - if rowCount != 1 { - b.Fatalf("unexpected rowCount: %d", rowCount) - } + }) } } From 4d345164f1027d985717335e841868f60ca69ac2 Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 14:36:38 +0200 Subject: [PATCH 0382/1158] Branch tests for nil context --- README.md | 4 +- helper_test.go | 22 + pgconn_test.go | 1500 +++++++++++++++++++++++++----------------------- 3 files changed, 818 insertions(+), 708 deletions(-) diff --git a/README.md b/README.md index 5d14e914..ddbfeaf3 100644 --- a/README.md +++ b/README.md @@ -11,13 +11,13 @@ low-level access to PostgreSQL functionality. ## Example Usage ```go -pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) +pgConn, err := pgconn.Connect(nil, os.Getenv("DATABASE_URL")) if err != nil { log.Fatalln("pgconn failed to connect:", err) } defer pgConn.Close() -result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +result := pgConn.ExecParams(nil, "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) for result.NextRow() { fmt.Println("User 123 has email:", string(result.Values()[0])) } diff --git a/helper_test.go b/helper_test.go index 1a3ca75e..1cb05fd2 100644 --- a/helper_test.go +++ b/helper_test.go @@ -29,3 +29,25 @@ func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { assert.Equal(t, "2", string(result.Rows[1][0])) assert.Equal(t, "3", string(result.Rows[2][0])) } + +// Run subtest both with a context.Background() and nil context +func splitOnContext(t *testing.T, test func(t *testing.T, ctx context.Context)) { + t.Helper() + + cases := [...]struct { + name string + ctx context.Context + }{ + {"background context", context.Background()}, + {"nil context", nil}, + } + + for i := range cases { + c := cases[i] + t.Run(c.name, func(t *testing.T) { + t.Helper() + t.Parallel() + test(t, c.ctx) + }) + } +} diff --git a/pgconn_test.go b/pgconn_test.go index 6b57dd09..30d20229 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -27,31 +27,33 @@ import ( ) func TestConnect(t *testing.T) { - tests := []struct { - name string - env string - }{ - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, - {"TCP", "PGX_TEST_TCP_CONN_STRING"}, - {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, - {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, - {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - connString := os.Getenv(tt.env) - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", tt.env) - } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } - conn, err := pgconn.Connect(context.Background(), connString) - require.NoError(t, err) + conn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) - closeConn(t, conn) - }) - } + closeConn(t, conn) + }) + } + }) } // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure @@ -59,19 +61,21 @@ func TestConnect(t *testing.T) { func TestConnectTLS(t *testing.T) { t.Parallel() - connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } - conn, err := pgconn.Connect(context.Background(), connString) - require.NoError(t, err) + conn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) - if _, ok := conn.Conn().(*tls.Conn); !ok { - t.Error("not a TLS connection") - } + if _, ok := conn.Conn().(*tls.Conn); !ok { + t.Error("not a TLS connection") + } - closeConn(t, conn) + closeConn(t, conn) + }) } type pgmockWaitStep time.Duration @@ -138,233 +142,259 @@ func TestConnectWithContextThatTimesOut(t *testing.T) { func TestConnectInvalidUser(t *testing.T) { t.Parallel() - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - config.User = "pgxinvalidusertest" + config.User = "pgxinvalidusertest" - _, err = pgconn.ConnectConfig(context.Background(), config) - require.Error(t, err) - pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) - if !ok { - t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) - } - if pgErr.Code != "28000" && pgErr.Code != "28P01" { - t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) - } + _, err = pgconn.ConnectConfig(ctx, config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) + if !ok { + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } + }) } func TestConnectWithConnectionRefused(t *testing.T) { t.Parallel() - // Presumably nothing is listening on 127.0.0.1:1 - conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") - if err == nil { - conn.Close(context.Background()) - t.Fatal("Expected error establishing connection to bad port") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(ctx, "host=127.0.0.1 port=1") + if err == nil { + conn.Close(ctx) + t.Fatal("Expected error establishing connection to bad port") + } + }) } func TestConnectCustomDialer(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialed := false - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialed = true - return net.Dial(network, address) - } + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - require.True(t, dialed) - closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + require.True(t, dialed) + closeConn(t, conn) + }) } func TestConnectCustomLookup(t *testing.T) { t.Parallel() - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - looked := false - config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { - looked = true - return net.LookupHost(host) - } + looked := false + config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + looked = true + return net.LookupHost(host) + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - require.True(t, looked) - closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) + }) } func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.RuntimeParams = map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - } + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, conn) - result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "pgxtest", string(result.Rows[0][0])) + result := conn.ExecParams(ctx, "show application_name", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "myschema", string(result.Rows[0][0])) + result = conn.ExecParams(ctx, "show search_path", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) + }) } func TestConnectWithFallback(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - // Prepend current primary config to fallbacks - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - // Make primary config bad - config.Host = "localhost" - config.Port = 1 // presumably nothing listening here + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here - // Prepend bad first fallback - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "localhost", - Port: 1, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + closeConn(t, conn) + }) } func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialCount := 0 - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount++ - return net.Dial(network, address) - } - - acceptConnCount := 0 - config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount++ - if acceptConnCount < 2 { - return errors.New("reject first conn") + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount++ + return net.Dial(network, address) } - return nil - } - // Append current primary config to fallbacks - config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, + acceptConnCount := 0 + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + acceptConnCount++ + if acceptConnCount < 2 { + return errors.New("reject first conn") + } + return nil + } + + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) }) - - // Repeat fallbacks - config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) - - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - closeConn(t, conn) - - assert.True(t, dialCount > 1) - assert.True(t, acceptConnCount > 1) } func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite - config.RuntimeParams["default_transaction_read_only"] = "on" + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" - conn, err := pgconn.ConnectConfig(context.Background(), config) - if !assert.NotNil(t, err) { - conn.Close(context.Background()) - } + conn, err := pgconn.ConnectConfig(ctx, config) + if !assert.NotNil(t, err) { + conn.Close(ctx) + } + }) } func TestConnectWithAfterConnect(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() - return err - } + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) - results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() - require.NoError(t, err) - defer closeConn(t, conn) + results, err := conn.Exec(ctx, "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) - assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) + }) } func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { t.Parallel() - config := &pgconn.Config{} + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config := &pgconn.Config{} - require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) + require.PanicsWithValue( + t, + "config must be created by ParseConfig", + func() { pgconn.ConnectConfig(ctx, config) }, + ) + }) } func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) - require.Nil(t, psd) - require.NotNil(t, err) + psd, err := pgConn.Prepare(ctx, "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnPrepareContextPrecanceled(t *testing.T) { @@ -388,116 +418,126 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { func TestConnExec(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecEmpty(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), ";") + multiResult := pgConn.Exec(ctx, ";") - resultCount := 0 - for multiResult.NextResult() { - resultCount++ - multiResult.ResultReader().Close() - } - assert.Equal(t, 0, resultCount) - err = multiResult.Close() - assert.NoError(t, err) + resultCount := 0 + for multiResult.NextResult() { + resultCount++ + multiResult.ResultReader().Close() + } + assert.Equal(t, 0, resultCount) + err = multiResult.Close() + assert.NoError(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(ctx, "select 'Hello, world'; select 1").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 2) + assert.Len(t, results, 2) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - assert.Nil(t, results[1].Err) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - assert.Len(t, results[1].Rows, 1) - assert.Equal(t, "1", string(results[1].Rows[0][0])) + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() - require.NotNil(t, err) - if pgErr, ok := err.(*pgconn.PgError); ok { - assert.Equal(t, "22012", pgErr.Code) - } else { - t.Errorf("unexpected error: %v", err) - } + results, err := pgConn.Exec(ctx, "select 1; select 1/0; select 1").ReadAll() + require.NotNil(t, err) + if pgErr, ok := err.(*pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } - assert.Len(t, results, 1) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "1", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecDeferredError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) - _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() - require.NotNil(t, err) + _, err = pgConn.Exec(ctx, `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecContextCanceled(t *testing.T) { @@ -538,95 +578,103 @@ func TestConnExecContextPrecanceled(t *testing.T) { func TestConnExecParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsDeferredError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) - result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() - require.NotNil(t, result.Err) - var pgErr *pgconn.PgError - require.True(t, errors.As(result.Err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + result := pgConn.ExecParams(ctx, `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsTooManyParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsCanceled(t *testing.T) { @@ -671,86 +719,92 @@ func TestConnExecParamsPrecanceled(t *testing.T) { func TestConnExecPrepared(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, 1) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecPrepared(ctx, "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecPreparedTooManyParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecPreparedCanceled(t *testing.T) { @@ -800,63 +854,67 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { func TestConnExecBatch(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.NoError(t, err) + _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil) + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) - batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) - results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, 3) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(ctx, batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, 3) - require.Len(t, results[0].Rows, 1) - require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - require.Len(t, results[1].Rows, 1) - require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - require.Len(t, results[2].Rows, 1) - require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) + }) } func TestConnExecBatchDeferredError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) - _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecBatchPrecanceled(t *testing.T) { @@ -895,76 +953,82 @@ func TestConnExecBatchHuge(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - queryCount := 100000 - args := make([]string, queryCount) + queryCount := 100000 + args := make([]string, queryCount) - for i := range args { - args[i] = strconv.Itoa(i) - batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) - } + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } - results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, queryCount) + results, err := pgConn.ExecBatch(ctx, batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) - for i := range args { - require.Len(t, results[i].Rows, 1) - require.Equal(t, args[i], string(results[i].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) - } + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + } + }) } func TestConnExecBatchImplicitTransaction(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "create temporary table t(id int)").ReadAll() + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) - batch.ExecParams("select 1/0", nil, nil, nil, nil) - _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.Error(t, err) + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.Error(t, err) - result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() - require.Equal(t, "0", string(result.Rows[0][0])) + result := pgConn.ExecParams(ctx, "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) + }) } func TestConnLocking(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() - assert.Error(t, err) - assert.Equal(t, "conn busy", err.Error()) - assert.True(t, pgconn.SafeToRetry(err)) + mrr := pgConn.Exec(ctx, "select 'Hello, world'") + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) - results, err := mrr.ReadAll() - assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + results, err := mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestCommandTag(t *testing.T) { @@ -993,91 +1057,97 @@ func TestCommandTag(t *testing.T) { func TestConnOnNotice(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { - msg = notice.Message - } + var msg string + config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { + msg = notice.Message + } - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), `do $$ -begin - raise notice 'hello, world'; -end$$;`) - err = multiResult.Close() - require.NoError(t, err) - assert.Equal(t, "hello, world", msg) + multiResult := pgConn.Exec(ctx, `do $$ + begin + raise notice 'hello, world'; + end$$;`) + err = multiResult.Close() + require.NoError(t, err) + assert.Equal(t, "hello, world", msg) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnOnNotification(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "select 1").ReadAll() + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnWaitForNotification(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - err = pgConn.WaitForNotification(context.Background()) - require.NoError(t, err) + err = pgConn.WaitForNotification(ctx) + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnWaitForNotificationPrecanceled(t *testing.T) { @@ -1119,94 +1189,100 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { func TestConnCopyToSmall(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() + require.NoError(t, err) - inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + - "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.NoError(t, err) + res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") + require.NoError(t, err) - assert.Equal(t, int64(2), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) + assert.Equal(t, int64(2), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnCopyToLarge(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json, - h bytea - )`).ReadAll() - require.NoError(t, err) - - inputBytes := make([]byte, 0) - - for i := 0; i < 1000; i++ { - _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) - } + defer closeConn(t, pgConn) - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`).ReadAll() + require.NoError(t, err) - res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.NoError(t, err) + inputBytes := make([]byte, 0) - assert.Equal(t, int64(1000), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) + for i := 0; i < 1000; i++ { + _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() + require.NoError(t, err) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } - ensureConnValid(t, pgConn) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") + require.NoError(t, err) + + assert.Equal(t, int64(1000), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) + }) } func TestConnCopyToQueryError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - outputWriter := bytes.NewBuffer(make([]byte, 0)) + outputWriter := bytes.NewBuffer(make([]byte, 0)) - res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyTo(ctx, outputWriter, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnCopyToCanceled(t *testing.T) { @@ -1250,37 +1326,39 @@ func TestConnCopyToPrecanceled(t *testing.T) { func TestConnCopyFrom(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) - - srcBuf := &bytes.Buffer{} - - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - } + defer closeConn(t, pgConn) - ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) + srcBuf := &bytes.Buffer{} - assert.Equal(t, inputRows, result.Rows) + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } - ensureConnValid(t, pgConn) + ct, err := pgConn.CopyFrom(ctx, srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) + }) } func TestConnCopyFromCanceled(t *testing.T) { @@ -1358,153 +1436,163 @@ func TestConnCopyFromPrecanceled(t *testing.T) { func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) - - f, err := ioutil.TempFile("", "*") - require.NoError(t, err) - - gw := gzip.NewWriter(f) - - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - } + defer closeConn(t, pgConn) - err = gw.Close() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - _, err = f.Seek(0, 0) - require.NoError(t, err) + f, err := ioutil.TempFile("", "*") + require.NoError(t, err) - gr, err := gzip.NewReader(f) - require.NoError(t, err) + gw := gzip.NewWriter(f) - ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } - err = gr.Close() - require.NoError(t, err) + err = gw.Close() + require.NoError(t, err) - err = f.Close() - require.NoError(t, err) + _, err = f.Seek(0, 0) + require.NoError(t, err) - err = os.Remove(f.Name()) - require.NoError(t, err) + gr, err := gzip.NewReader(f) + require.NoError(t, err) - result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) + ct, err := pgConn.CopyFrom(ctx, gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - assert.Equal(t, inputRows, result.Rows) + err = gr.Close() + require.NoError(t, err) - ensureConnValid(t, pgConn) + err = f.Close() + require.NoError(t, err) + + err = os.Remove(f.Name()) + require.NoError(t, err) + + result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) + }) } func TestConnCopyFromQuerySyntaxError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(ctx, srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnCopyFromQueryNoTableError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(ctx, srcBuf, "copy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnEscapeString(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - tests := []struct { - in string - out string - }{ - {in: "", out: ""}, - {in: "42", out: "42"}, - {in: "'", out: "''"}, - {in: "hi'there", out: "hi''there"}, - {in: "'hi there'", out: "''hi there''"}, - } - - for i, tt := range tests { - value, err := pgConn.EscapeString(tt.in) - if assert.NoErrorf(t, err, "%d.", i) { - assert.Equalf(t, tt.out, value, "%d.", i) + tests := []struct { + in string + out string + }{ + {in: "", out: ""}, + {in: "42", out: "42"}, + {in: "'", out: "''"}, + {in: "hi'there", out: "hi''there"}, + {in: "'hi there'", out: "''hi there''"}, } - } - ensureConnValid(t, pgConn) + for i, tt := range tests { + value, err := pgConn.EscapeString(tt.in) + if assert.NoErrorf(t, err, "%d.", i) { + assert.Equalf(t, tt.out, value, "%d.", i) + } + } + + ensureConnValid(t, pgConn) + }) } func TestConnCancelRequest(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(2)") - // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a - // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a - // few milliseconds. - time.Sleep(50 * time.Millisecond) + // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a + // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a + // few milliseconds. + time.Sleep(50 * time.Millisecond) - err = pgConn.CancelRequest(context.Background()) - require.NoError(t, err) + err = pgConn.CancelRequest(ctx) + require.NoError(t, err) - for multiResult.NextResult() { - } - err = multiResult.Close() + for multiResult.NextResult() { + } + err = multiResult.Close() - require.IsType(t, &pgconn.PgError{}, err) - require.Equal(t, "57014", err.(*pgconn.PgError).Code) + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnSendBytesAndReceiveMessage(t *testing.T) { @@ -1547,13 +1635,13 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) { } func Example() { - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(nil, os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { log.Fatalln(err) } - defer pgConn.Close(context.Background()) + defer pgConn.Close(nil) - result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() + result := pgConn.ExecParams(nil, "select generate_series(1,3)", nil, nil, nil, nil).Read() if result.Err != nil { log.Fatalln(result.Err) } From 93722181071cd124ad5bb67122d33b31d4ada632 Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 19:34:56 +0200 Subject: [PATCH 0383/1158] Don't synchronize with context.Background() --- benchmark_test.go | 12 +++++++---- doc.go | 4 +--- pgconn.go | 52 +++++++++++++++++++++++++++++++++-------------- 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 1914e07a..4cce5a97 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -18,8 +18,10 @@ func BenchmarkConnect(b *testing.B) { }{ // The first benchmark in the list sometimes executes faster, no matter how // you reorder it. Nil context is still faster on average. - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING", context.Background()}, - {"TCP", "PGX_TEST_TCP_CONN_STRING", context.Background()}, + // + // Using and empty context other than context.Background() to compare. + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING", context.TODO()}, + {"TCP", "PGX_TEST_TCP_CONN_STRING", context.TODO()}, {"Unix socket nil context", "PGX_TEST_UNIX_SOCKET_CONN_STRING", nil}, {"TCP nil context", "PGX_TEST_TCP_CONN_STRING", nil}, } @@ -49,7 +51,8 @@ func BenchmarkExec(b *testing.B) { name string ctx context.Context }{ - {"background context", context.Background()}, + // Using and empty context other than context.Background() to compare. + {"empty context", context.TODO()}, {"nil context", nil}, } @@ -153,7 +156,8 @@ func BenchmarkExecPrepared(b *testing.B) { name string ctx context.Context }{ - {"background context", context.Background()}, + // Using and empty context other than context.Background() to compare. + {"empty context", context.TODO()}, {"nil context", nil}, } diff --git a/doc.go b/doc.go index 12ed6630..25382c68 100644 --- a/doc.go +++ b/doc.go @@ -22,9 +22,7 @@ Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the method immediately returns. In most circumstances, this will close the underlying connection. - -A nil context can be passed for convenience. This has the same effect as passing context.Background() with an additional -slight performance increase, if you don't need the operation to be cancellable. +A nil context can be passed for convenience. This has the same effect as passing context.Background(). The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. diff --git a/pgconn.go b/pgconn.go index 3b90b802..b8ea9df7 100644 --- a/pgconn.go +++ b/pgconn.go @@ -366,7 +366,9 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return &contextAlreadyDoneError{err: ctx.Err()} @@ -398,7 +400,9 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -497,7 +501,9 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } @@ -596,7 +602,9 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -695,7 +703,9 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { } defer cancelConn.Close() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: contextWatcher := ctxwatch.NewContextWatcher( func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { cancelConn.SetDeadline(time.Time{}) }, @@ -730,7 +740,9 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return ctx.Err() @@ -772,7 +784,11 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { ctx: ctx, } multiResult := &pgConn.multiResultReader - if ctx != nil { + switch ctx { + case nil: + pgConn.multiResultReader.ctx = context.Background() + case context.Background(): + default: select { case <-ctx.Done(): multiResult.closed = true @@ -782,8 +798,6 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { default: } pgConn.contextWatcher.Watch(ctx) - } else { - pgConn.multiResultReader.ctx = context.Background() } buf := pgConn.wbuf @@ -885,7 +899,9 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) @@ -921,7 +937,9 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, err } - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): pgConn.unlock() @@ -982,7 +1000,9 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -1376,7 +1396,11 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR ctx: ctx, } multiResult := &pgConn.multiResultReader - if ctx != nil { + switch ctx { + case nil: + pgConn.multiResultReader.ctx = context.Background() + case context.Background(): + default: select { case <-ctx.Done(): multiResult.closed = true @@ -1386,8 +1410,6 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR default: } pgConn.contextWatcher.Watch(ctx) - } else { - pgConn.multiResultReader.ctx = context.Background() } batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) From 98b3c57584a2bde785c3f706afcd3d371d6faec3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 8 Jan 2020 10:02:32 -0600 Subject: [PATCH 0384/1158] Try to cancel any in-progress query when a conn is closed by ctx cancel See https://github.com/jackc/pgx/issues/659 --- pgconn.go | 58 +++++++++++++++++++++++++++++++------------------- pgconn_test.go | 43 +++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 22 deletions(-) diff --git a/pgconn.go b/pgconn.go index 4c75d367..70c33c4f 100644 --- a/pgconn.go +++ b/pgconn.go @@ -372,7 +372,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return &writeError{err: err, safeToRetry: n == 0} } @@ -429,7 +429,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { if err != nil { // Close on anything other than timeout error - everything else is fatal if err, ok := err.(net.Error); !(ok && err.Timeout()) { - pgConn.hardClose() + pgConn.ayncClose() } return nil, err @@ -442,7 +442,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { - pgConn.hardClose() + pgConn.ayncClose() return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: @@ -503,14 +503,28 @@ func (pgConn *PgConn) Close(ctx context.Context) error { return pgConn.conn.Close() } -// hardClose closes the underlying connection without sending the exit message. -func (pgConn *PgConn) hardClose() error { +// ayncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying +// connection. +func (pgConn *PgConn) ayncClose() { if pgConn.status == connStatusClosed { - return nil + return } pgConn.status = connStatusClosed - return pgConn.conn.Close() + go func() { + defer pgConn.conn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + + pgConn.CancelRequest(ctx) + + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) + pgConn.conn.Read(make([]byte, 1)) + }() } // IsClosed reports if the connection has been closed. @@ -601,7 +615,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0} } @@ -613,7 +627,7 @@ readloop: for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } @@ -768,7 +782,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() pgConn.contextWatcher.Unwatch() multiResult.closed = true multiResult.err = &writeError{err: err, safeToRetry: n == 0} @@ -879,7 +893,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) pgConn.contextWatcher.Unwatch() result.closed = true @@ -908,7 +922,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() pgConn.unlock() return nil, &writeError{err: err, safeToRetry: n == 0} } @@ -919,7 +933,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } @@ -928,7 +942,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case *pgproto3.CopyData: _, err := w.Write(msg.Data) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } case *pgproto3.ReadyForQuery: @@ -966,7 +980,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, &writeError{err: err, safeToRetry: n == 0} } @@ -977,7 +991,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for pendingCopyInResponse { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } @@ -1006,7 +1020,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } } @@ -1015,7 +1029,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case <-signalMessageChan: msg, err := pgConn.receiveMessage() if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } @@ -1039,7 +1053,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } _, err = pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } @@ -1047,7 +1061,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } @@ -1092,7 +1106,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) mrr.pgConn.contextWatcher.Unwatch() mrr.err = err mrr.closed = true - mrr.pgConn.hardClose() + mrr.pgConn.ayncClose() return nil, mrr.err } @@ -1281,7 +1295,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error rr.pgConn.contextWatcher.Unwatch() rr.closed = true if rr.multiResultReader == nil { - rr.pgConn.hardClose() + rr.pgConn.ayncClose() } return nil, rr.err diff --git a/pgconn_test.go b/pgconn_test.go index 6b57dd09..7ae6fdc5 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1507,6 +1507,49 @@ func TestConnCancelRequest(t *testing.T) { ensureConnValid(t, pgConn) } +// https://github.com/jackc/pgx/issues/659 +func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pid := pgConn.PID() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(30)") + + for multiResult.NextResult() { + } + err = multiResult.Close() + assert.True(t, pgconn.Timeout(err)) + assert.True(t, pgConn.IsClosed()) + + otherConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, otherConn) + + ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + for { + result := otherConn.ExecParams(ctx, + `select 1 from pg_stat_activity where pid=$1`, + [][]byte{[]byte(strconv.FormatInt(int64(pid), 10))}, + nil, + nil, + nil, + ).Read() + require.NoError(t, result.Err) + + if len(result.Rows) == 0 { + break + } + } +} + func TestConnSendBytesAndReceiveMessage(t *testing.T) { t.Parallel() From 9decdbc2ec3357581cd2911141b3ead982f5026f Mon Sep 17 00:00:00 2001 From: bakape Date: Sat, 11 Jan 2020 16:53:50 +0200 Subject: [PATCH 0385/1158] Revert nil context support --- README.md | 4 +- benchmark_test.go | 25 +- doc.go | 1 - helper_test.go | 22 - pgconn.go | 62 +- pgconn_test.go | 1500 +++++++++++++++++++++------------------------ 6 files changed, 731 insertions(+), 883 deletions(-) diff --git a/README.md b/README.md index ddbfeaf3..5d14e914 100644 --- a/README.md +++ b/README.md @@ -11,13 +11,13 @@ low-level access to PostgreSQL functionality. ## Example Usage ```go -pgConn, err := pgconn.Connect(nil, os.Getenv("DATABASE_URL")) +pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { log.Fatalln("pgconn failed to connect:", err) } defer pgConn.Close() -result := pgConn.ExecParams(nil, "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) for result.NextRow() { fmt.Println("User 123 has email:", string(result.Values()[0])) } diff --git a/benchmark_test.go b/benchmark_test.go index 4cce5a97..3295a90f 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -14,16 +14,9 @@ func BenchmarkConnect(b *testing.B) { benchmarks := []struct { name string env string - ctx context.Context }{ - // The first benchmark in the list sometimes executes faster, no matter how - // you reorder it. Nil context is still faster on average. - // - // Using and empty context other than context.Background() to compare. - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING", context.TODO()}, - {"TCP", "PGX_TEST_TCP_CONN_STRING", context.TODO()}, - {"Unix socket nil context", "PGX_TEST_UNIX_SOCKET_CONN_STRING", nil}, - {"TCP nil context", "PGX_TEST_TCP_CONN_STRING", nil}, + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, } for _, bm := range benchmarks { @@ -35,10 +28,10 @@ func BenchmarkConnect(b *testing.B) { } for i := 0; i < b.N; i++ { - conn, err := pgconn.Connect(bm.ctx, connString) + conn, err := pgconn.Connect(context.Background(), connString) require.Nil(b, err) - err = conn.Close(bm.ctx) + err = conn.Close(context.Background()) require.Nil(b, err) } }) @@ -51,9 +44,10 @@ func BenchmarkExec(b *testing.B) { name string ctx context.Context }{ - // Using and empty context other than context.Background() to compare. + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, {"empty context", context.TODO()}, - {"nil context", nil}, } for _, bm := range benchmarks { @@ -156,9 +150,10 @@ func BenchmarkExecPrepared(b *testing.B) { name string ctx context.Context }{ - // Using and empty context other than context.Background() to compare. + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, {"empty context", context.TODO()}, - {"nil context", nil}, } for _, bm := range benchmarks { diff --git a/doc.go b/doc.go index 25382c68..cde58cd8 100644 --- a/doc.go +++ b/doc.go @@ -22,7 +22,6 @@ Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the method immediately returns. In most circumstances, this will close the underlying connection. -A nil context can be passed for convenience. This has the same effect as passing context.Background(). The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. diff --git a/helper_test.go b/helper_test.go index 1cb05fd2..1a3ca75e 100644 --- a/helper_test.go +++ b/helper_test.go @@ -29,25 +29,3 @@ func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { assert.Equal(t, "2", string(result.Rows[1][0])) assert.Equal(t, "3", string(result.Rows[2][0])) } - -// Run subtest both with a context.Background() and nil context -func splitOnContext(t *testing.T, test func(t *testing.T, ctx context.Context)) { - t.Helper() - - cases := [...]struct { - name string - ctx context.Context - }{ - {"background context", context.Background()}, - {"nil context", nil}, - } - - for i := range cases { - c := cases[i] - t.Run(c.name, func(t *testing.T) { - t.Helper() - t.Parallel() - test(t, c.ctx) - }) - } -} diff --git a/pgconn.go b/pgconn.go index b8ea9df7..9763b319 100644 --- a/pgconn.go +++ b/pgconn.go @@ -116,10 +116,6 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err panic("config must be created by ParseConfig") } - if ctx == nil { - ctx = context.Background() - } - // Simplify usage by treating primary config and fallbacks the same. fallbackConfigs := []*FallbackConfig{ { @@ -366,9 +362,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return &contextAlreadyDoneError{err: ctx.Err()} @@ -400,9 +394,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -501,9 +493,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } @@ -602,9 +592,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -693,19 +681,13 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - _ctx := ctx - if _ctx == nil { - _ctx = context.Background() - } - cancelConn, err := pgConn.config.DialFunc(_ctx, serverAddr.Network(), serverAddr.String()) + cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } defer cancelConn.Close() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { contextWatcher := ctxwatch.NewContextWatcher( func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { cancelConn.SetDeadline(time.Time{}) }, @@ -740,9 +722,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return ctx.Err() @@ -784,11 +764,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { ctx: ctx, } multiResult := &pgConn.multiResultReader - switch ctx { - case nil: - pgConn.multiResultReader.ctx = context.Background() - case context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): multiResult.closed = true @@ -882,9 +858,6 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by ctx: ctx, } result := &pgConn.resultReader - if ctx == nil { - pgConn.resultReader.ctx = context.Background() - } if err := pgConn.lock(); err != nil { result.concludeCommand(nil, err) @@ -899,9 +872,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) @@ -937,9 +908,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, err } - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): pgConn.unlock() @@ -1000,9 +969,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -1396,11 +1363,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR ctx: ctx, } multiResult := &pgConn.multiResultReader - switch ctx { - case nil: - pgConn.multiResultReader.ctx = context.Background() - case context.Background(): - default: + + if ctx != context.Background() { select { case <-ctx.Done(): multiResult.closed = true diff --git a/pgconn_test.go b/pgconn_test.go index 30d20229..6b57dd09 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -27,33 +27,31 @@ import ( ) func TestConnect(t *testing.T) { - splitOnContext(t, func(t *testing.T, ctx context.Context) { - tests := []struct { - name string - env string - }{ - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, - {"TCP", "PGX_TEST_TCP_CONN_STRING"}, - {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, - {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, - {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, - } + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - connString := os.Getenv(tt.env) - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", tt.env) - } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } - conn, err := pgconn.Connect(ctx, connString) - require.NoError(t, err) + conn, err := pgconn.Connect(context.Background(), connString) + require.NoError(t, err) - closeConn(t, conn) - }) - } - }) + closeConn(t, conn) + }) + } } // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure @@ -61,21 +59,19 @@ func TestConnect(t *testing.T) { func TestConnectTLS(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") - } + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } - conn, err := pgconn.Connect(ctx, connString) - require.NoError(t, err) + conn, err := pgconn.Connect(context.Background(), connString) + require.NoError(t, err) - if _, ok := conn.Conn().(*tls.Conn); !ok { - t.Error("not a TLS connection") - } + if _, ok := conn.Conn().(*tls.Conn); !ok { + t.Error("not a TLS connection") + } - closeConn(t, conn) - }) + closeConn(t, conn) } type pgmockWaitStep time.Duration @@ -142,259 +138,233 @@ func TestConnectWithContextThatTimesOut(t *testing.T) { func TestConnectInvalidUser(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - config.User = "pgxinvalidusertest" + config.User = "pgxinvalidusertest" - _, err = pgconn.ConnectConfig(ctx, config) - require.Error(t, err) - pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) - if !ok { - t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) - } - if pgErr.Code != "28000" && pgErr.Code != "28P01" { - t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) - } - }) + _, err = pgconn.ConnectConfig(context.Background(), config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) + if !ok { + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } } func TestConnectWithConnectionRefused(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - // Presumably nothing is listening on 127.0.0.1:1 - conn, err := pgconn.Connect(ctx, "host=127.0.0.1 port=1") - if err == nil { - conn.Close(ctx) - t.Fatal("Expected error establishing connection to bad port") - } - }) + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") + if err == nil { + conn.Close(context.Background()) + t.Fatal("Expected error establishing connection to bad port") + } } func TestConnectCustomDialer(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialed := false - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialed = true - return net.Dial(network, address) - } + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - require.True(t, dialed) - closeConn(t, conn) - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, dialed) + closeConn(t, conn) } func TestConnectCustomLookup(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - looked := false - config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { - looked = true - return net.LookupHost(host) - } + looked := false + config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + looked = true + return net.LookupHost(host) + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - require.True(t, looked) - closeConn(t, conn) - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) } func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.RuntimeParams = map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - } + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, conn) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, conn) - result := conn.ExecParams(ctx, "show application_name", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "pgxtest", string(result.Rows[0][0])) + result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result = conn.ExecParams(ctx, "show search_path", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "myschema", string(result.Rows[0][0])) - }) + result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) } func TestConnectWithFallback(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - // Prepend current primary config to fallbacks - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - // Make primary config bad - config.Host = "localhost" - config.Port = 1 // presumably nothing listening here + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here - // Prepend bad first fallback - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "localhost", - Port: 1, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - closeConn(t, conn) - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) } func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialCount := 0 - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount++ - return net.Dial(network, address) + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount++ + return net.Dial(network, address) + } + + acceptConnCount := 0 + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + acceptConnCount++ + if acceptConnCount < 2 { + return errors.New("reject first conn") } + return nil + } - acceptConnCount := 0 - config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount++ - if acceptConnCount < 2 { - return errors.New("reject first conn") - } - return nil - } - - // Append current primary config to fallbacks - config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }) - - // Repeat fallbacks - config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) - - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - closeConn(t, conn) - - assert.True(t, dialCount > 1) - assert.True(t, acceptConnCount > 1) + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) } func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite - config.RuntimeParams["default_transaction_read_only"] = "on" + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" - conn, err := pgconn.ConnectConfig(ctx, config) - if !assert.NotNil(t, err) { - conn.Close(ctx) - } - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + if !assert.NotNil(t, err) { + conn.Close(context.Background()) + } } func TestConnectWithAfterConnect(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() - return err - } + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) - results, err := conn.Exec(ctx, "show search_path;").ReadAll() - require.NoError(t, err) - defer closeConn(t, conn) + results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) - assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) - }) + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) } func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config := &pgconn.Config{} + config := &pgconn.Config{} - require.PanicsWithValue( - t, - "config must be created by ParseConfig", - func() { pgconn.ConnectConfig(ctx, config) }, - ) - }) + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) } func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(ctx, "ps1", "SYNTAX ERROR", nil) - require.Nil(t, psd) - require.NotNil(t, err) + psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnPrepareContextPrecanceled(t *testing.T) { @@ -418,126 +388,116 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { func TestConnExec(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecEmpty(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(ctx, ";") + multiResult := pgConn.Exec(context.Background(), ";") - resultCount := 0 - for multiResult.NextResult() { - resultCount++ - multiResult.ResultReader().Close() - } - assert.Equal(t, 0, resultCount) - err = multiResult.Close() - assert.NoError(t, err) + resultCount := 0 + for multiResult.NextResult() { + resultCount++ + multiResult.ResultReader().Close() + } + assert.Equal(t, 0, resultCount) + err = multiResult.Close() + assert.NoError(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(ctx, "select 'Hello, world'; select 1").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 2) + assert.Len(t, results, 2) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - assert.Nil(t, results[1].Err) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - assert.Len(t, results[1].Rows, 1) - assert.Equal(t, "1", string(results[1].Rows[0][0])) + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(ctx, "select 1; select 1/0; select 1").ReadAll() - require.NotNil(t, err) - if pgErr, ok := err.(*pgconn.PgError); ok { - assert.Equal(t, "22012", pgErr.Code) - } else { - t.Errorf("unexpected error: %v", err) - } + results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() + require.NotNil(t, err) + if pgErr, ok := err.(*pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } - assert.Len(t, results, 1) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "1", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecDeferredError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(ctx, setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) - _, err = pgConn.Exec(ctx, `update t set n=n+1 where id='b' returning *`).ReadAll() - require.NotNil(t, err) + _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecContextCanceled(t *testing.T) { @@ -578,103 +538,95 @@ func TestConnExecContextPrecanceled(t *testing.T) { func TestConnExecParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsDeferredError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(ctx, setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) - result := pgConn.ExecParams(ctx, `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() - require.NotNil(t, result.Err) - var pgErr *pgconn.PgError - require.True(t, errors.As(result.Err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsTooManyParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsCanceled(t *testing.T) { @@ -719,92 +671,86 @@ func TestConnExecParamsPrecanceled(t *testing.T) { func TestConnExecPrepared(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text", nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, 1) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(ctx, "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecPreparedTooManyParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecPreparedCanceled(t *testing.T) { @@ -854,67 +800,63 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { func TestConnExecBatch(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil) - require.NoError(t, err) + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) - batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) - results, err := pgConn.ExecBatch(ctx, batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, 3) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, 3) - require.Len(t, results[0].Rows, 1) - require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - require.Len(t, results[1].Rows, 1) - require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - require.Len(t, results[2].Rows, 1) - require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) - }) + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } func TestConnExecBatchDeferredError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(ctx, setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) - _, err = pgConn.ExecBatch(ctx, batch).ReadAll() - require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecBatchPrecanceled(t *testing.T) { @@ -953,82 +895,76 @@ func TestConnExecBatchHuge(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - queryCount := 100000 - args := make([]string, queryCount) + queryCount := 100000 + args := make([]string, queryCount) - for i := range args { - args[i] = strconv.Itoa(i) - batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) - } + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } - results, err := pgConn.ExecBatch(ctx, batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, queryCount) + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) - for i := range args { - require.Len(t, results[i].Rows, 1) - require.Equal(t, args[i], string(results[i].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) - } - }) + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + } } func TestConnExecBatchImplicitTransaction(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, "create temporary table t(id int)").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) - batch.ExecParams("select 1/0", nil, nil, nil, nil) - _, err = pgConn.ExecBatch(ctx, batch).ReadAll() - require.Error(t, err) + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.Error(t, err) - result := pgConn.ExecParams(ctx, "select count(*) from t", nil, nil, nil, nil).Read() - require.Equal(t, "0", string(result.Rows[0][0])) - }) + result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) } func TestConnLocking(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - mrr := pgConn.Exec(ctx, "select 'Hello, world'") - _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() - assert.Error(t, err) - assert.Equal(t, "conn busy", err.Error()) - assert.True(t, pgconn.SafeToRetry(err)) + mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") + _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) - results, err := mrr.ReadAll() - assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + results, err := mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestCommandTag(t *testing.T) { @@ -1057,97 +993,91 @@ func TestCommandTag(t *testing.T) { func TestConnOnNotice(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { - msg = notice.Message - } + var msg string + config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { + msg = notice.Message + } - pgConn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(ctx, `do $$ - begin - raise notice 'hello, world'; - end$$;`) - err = multiResult.Close() - require.NoError(t, err) - assert.Equal(t, "hello, world", msg) + multiResult := pgConn.Exec(context.Background(), `do $$ +begin + raise notice 'hello, world'; +end$$;`) + err = multiResult.Close() + require.NoError(t, err) + assert.Equal(t, "hello, world", msg) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnOnNotification(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(ctx, "select 1").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnWaitForNotification(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - err = pgConn.WaitForNotification(ctx) - require.NoError(t, err) + err = pgConn.WaitForNotification(context.Background()) + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnWaitForNotificationPrecanceled(t *testing.T) { @@ -1189,100 +1119,94 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { func TestConnCopyToSmall(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(ctx, `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() + require.NoError(t, err) - inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + - "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") - require.NoError(t, err) + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.NoError(t, err) - assert.Equal(t, int64(2), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) + assert.Equal(t, int64(2), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyToLarge(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`).ReadAll() + require.NoError(t, err) + + inputBytes := make([]byte, 0) + + for i := 0; i < 1000; i++ { + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() require.NoError(t, err) - defer closeConn(t, pgConn) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json, - h bytea - )`).ReadAll() - require.NoError(t, err) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - inputBytes := make([]byte, 0) + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.NoError(t, err) - for i := 0; i < 1000; i++ { - _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() - require.NoError(t, err) - inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) - } + assert.Equal(t, int64(1000), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - - res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") - require.NoError(t, err) - - assert.Equal(t, int64(1000), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyToQueryError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - outputWriter := bytes.NewBuffer(make([]byte, 0)) + outputWriter := bytes.NewBuffer(make([]byte, 0)) - res, err := pgConn.CopyTo(ctx, outputWriter, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyToCanceled(t *testing.T) { @@ -1326,39 +1250,37 @@ func TestConnCopyToPrecanceled(t *testing.T) { func TestConnCopyFrom(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) require.NoError(t, err) - defer closeConn(t, pgConn) + } - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - srcBuf := &bytes.Buffer{} + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) - require.NoError(t, err) - } + assert.Equal(t, inputRows, result.Rows) - ct, err := pgConn.CopyFrom(ctx, srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - - result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - assert.Equal(t, inputRows, result.Rows) - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyFromCanceled(t *testing.T) { @@ -1436,163 +1358,153 @@ func TestConnCopyFromPrecanceled(t *testing.T) { func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + f, err := ioutil.TempFile("", "*") + require.NoError(t, err) + + gw := gzip.NewWriter(f) + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) require.NoError(t, err) - defer closeConn(t, pgConn) + } - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + err = gw.Close() + require.NoError(t, err) - f, err := ioutil.TempFile("", "*") - require.NoError(t, err) + _, err = f.Seek(0, 0) + require.NoError(t, err) - gw := gzip.NewWriter(f) + gr, err := gzip.NewReader(f) + require.NoError(t, err) - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) - require.NoError(t, err) - } + ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - err = gw.Close() - require.NoError(t, err) + err = gr.Close() + require.NoError(t, err) - _, err = f.Seek(0, 0) - require.NoError(t, err) + err = f.Close() + require.NoError(t, err) - gr, err := gzip.NewReader(f) - require.NoError(t, err) + err = os.Remove(f.Name()) + require.NoError(t, err) - ct, err := pgConn.CopyFrom(ctx, gr, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) - err = gr.Close() - require.NoError(t, err) + assert.Equal(t, inputRows, result.Rows) - err = f.Close() - require.NoError(t, err) - - err = os.Remove(f.Name()) - require.NoError(t, err) - - result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - assert.Equal(t, inputRows, result.Rows) - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyFromQuerySyntaxError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(ctx, srcBuf, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyFromQueryNoTableError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(ctx, srcBuf, "copy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnEscapeString(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - tests := []struct { - in string - out string - }{ - {in: "", out: ""}, - {in: "42", out: "42"}, - {in: "'", out: "''"}, - {in: "hi'there", out: "hi''there"}, - {in: "'hi there'", out: "''hi there''"}, + tests := []struct { + in string + out string + }{ + {in: "", out: ""}, + {in: "42", out: "42"}, + {in: "'", out: "''"}, + {in: "hi'there", out: "hi''there"}, + {in: "'hi there'", out: "''hi there''"}, + } + + for i, tt := range tests { + value, err := pgConn.EscapeString(tt.in) + if assert.NoErrorf(t, err, "%d.", i) { + assert.Equalf(t, tt.out, value, "%d.", i) } + } - for i, tt := range tests { - value, err := pgConn.EscapeString(tt.in) - if assert.NoErrorf(t, err, "%d.", i) { - assert.Equalf(t, tt.out, value, "%d.", i) - } - } - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCancelRequest(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(2)") + multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") - // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a - // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a - // few milliseconds. - time.Sleep(50 * time.Millisecond) + // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a + // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a + // few milliseconds. + time.Sleep(50 * time.Millisecond) - err = pgConn.CancelRequest(ctx) - require.NoError(t, err) + err = pgConn.CancelRequest(context.Background()) + require.NoError(t, err) - for multiResult.NextResult() { - } - err = multiResult.Close() + for multiResult.NextResult() { + } + err = multiResult.Close() - require.IsType(t, &pgconn.PgError{}, err) - require.Equal(t, "57014", err.(*pgconn.PgError).Code) + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnSendBytesAndReceiveMessage(t *testing.T) { @@ -1635,13 +1547,13 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) { } func Example() { - pgConn, err := pgconn.Connect(nil, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { log.Fatalln(err) } - defer pgConn.Close(nil) + defer pgConn.Close(context.Background()) - result := pgConn.ExecParams(nil, "select generate_series(1,3)", nil, nil, nil, nil).Read() + result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() if result.Err != nil { log.Fatalln(result.Err) } From b6669ae6dda06f53fe221f80507123d967f7f099 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Jan 2020 18:23:41 -0600 Subject: [PATCH 0386/1158] Add PgError.SQLState method fixes #15 --- errors.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/errors.go b/errors.go index a088dcdd..7a21af98 100644 --- a/errors.go +++ b/errors.go @@ -55,6 +55,11 @@ func (pe *PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } +// SQLState returns the SQLState of the error. +func (pe *PgError) SQLState() string { + return pe.Code +} + type connectError struct { config *Config msg string From fd2093cef8e97839e11bb13bb4a2c1b805ae62f5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Jan 2020 18:42:31 -0600 Subject: [PATCH 0387/1158] Add statement type convenience methods to CommandTag and optimize Added convenient way to check whether a statement was a select, insert, update, or delete. These methods do not allocate. RowsAffected now does not allocate even when a large number of rows are affected. It also is multiple times faster, though the absolute change is inconsequential. --- benchmark_test.go | 68 +++++++++++++++++++++++++++++++++++++++++++++++ pgconn.go | 64 +++++++++++++++++++++++++++++++++++++++++--- pgconn_test.go | 25 ++++++++++++----- 3 files changed, 146 insertions(+), 11 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 3295a90f..ced785b6 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "os" + "strings" "testing" "github.com/jackc/pgconn" @@ -252,3 +253,70 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { // conn.ChanToSetDeadline().Ignore() // } // } + +func BenchmarkCommandTagRowsAffected(b *testing.B) { + benchmarks := []struct { + commandTag string + rowsAffected int64 + }{ + {"UPDATE 1", 1}, + {"UPDATE 123456789", 123456789}, + {"INSERT 0 1", 1}, + {"INSERT 0 123456789", 123456789}, + } + + for _, bm := range benchmarks { + ct := pgconn.CommandTag(bm.commandTag) + b.Run(bm.commandTag, func(b *testing.B) { + var n int64 + for i := 0; i < b.N; i++ { + n = ct.RowsAffected() + } + if n != bm.rowsAffected { + b.Errorf("expected %d got %d", bm.rowsAffected, n) + } + }) + } +} + +func BenchmarkCommandTagTypeFromString(b *testing.B) { + ct := pgconn.CommandTag("UPDATE 1") + + var update bool + for i := 0; i < b.N; i++ { + update = strings.HasPrefix(ct.String(), "UPDATE") + } + if !update { + b.Error("expected update") + } +} + +func BenchmarkCommandTagInsert(b *testing.B) { + benchmarks := []struct { + commandTag string + is bool + }{ + {"INSERT 1", true}, + {"INSERT 1234567890", true}, + {"UPDATE 1", false}, + {"UPDATE 1234567890", false}, + {"DELETE 1", false}, + {"DELETE 1234567890", false}, + {"SELECT 1", false}, + {"SELECT 1234567890", false}, + {"UNKNOWN 1234567890", false}, + } + + for _, bm := range benchmarks { + ct := pgconn.CommandTag(bm.commandTag) + b.Run(bm.commandTag, func(b *testing.B) { + var is bool + for i := 0; i < b.N; i++ { + is = ct.Insert() + } + if is != bm.is { + b.Errorf("expected %v got %v", bm.is, is) + } + }) + } +} diff --git a/pgconn.go b/pgconn.go index c46dc6a6..dce4bfb5 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1,7 +1,6 @@ package pgconn import ( - "bytes" "context" "crypto/md5" "crypto/tls" @@ -10,7 +9,6 @@ import ( "io" "math" "net" - "strconv" "strings" "sync" "time" @@ -579,11 +577,25 @@ type CommandTag []byte // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. func (ct CommandTag) RowsAffected() int64 { - idx := bytes.LastIndexByte([]byte(ct), ' ') + // Find last non-digit + idx := -1 + for i := len(ct) - 1; i >= 0; i-- { + if ct[i] >= '0' && ct[i] <= '9' { + idx = i + } else { + break + } + } + if idx == -1 { return 0 } - n, _ := strconv.ParseInt(string([]byte(ct)[idx+1:]), 10, 64) + + var n int64 + for _, b := range ct[idx:] { + n = n*10 + int64(b-'0') + } + return n } @@ -591,6 +603,50 @@ func (ct CommandTag) String() string { return string(ct) } +// Insert is true if the command tag starts with "INSERT". +func (ct CommandTag) Insert() bool { + return len(ct) >= 6 && + ct[0] == 'I' && + ct[1] == 'N' && + ct[2] == 'S' && + ct[3] == 'E' && + ct[4] == 'R' && + ct[5] == 'T' +} + +// Update is true if the command tag starts with "UPDATE". +func (ct CommandTag) Update() bool { + return len(ct) >= 6 && + ct[0] == 'U' && + ct[1] == 'P' && + ct[2] == 'D' && + ct[3] == 'A' && + ct[4] == 'T' && + ct[5] == 'E' +} + +// Delete is true if the command tag starts with "DELETE". +func (ct CommandTag) Delete() bool { + return len(ct) >= 6 && + ct[0] == 'D' && + ct[1] == 'E' && + ct[2] == 'L' && + ct[3] == 'E' && + ct[4] == 'T' && + ct[5] == 'E' +} + +// Select is true if the command tag starts with "SELECT". +func (ct CommandTag) Select() bool { + return len(ct) >= 6 && + ct[0] == 'S' && + ct[1] == 'E' && + ct[2] == 'L' && + ct[3] == 'E' && + ct[4] == 'C' && + ct[5] == 'T' +} + type StatementDescription struct { Name string SQL string diff --git a/pgconn_test.go b/pgconn_test.go index 7ae6fdc5..2c303d81 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -973,20 +973,31 @@ func TestCommandTag(t *testing.T) { var tests = []struct { commandTag pgconn.CommandTag rowsAffected int64 + isInsert bool + isUpdate bool + isDelete bool + isSelect bool }{ - {commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5}, - {commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0}, - {commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1}, - {commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0}, - {commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1}, + {commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5, isInsert: true}, + {commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0, isUpdate: true}, + {commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1, isUpdate: true}, + {commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0, isDelete: true}, + {commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1, isDelete: true}, + {commandTag: pgconn.CommandTag("DELETE 1234567890"), rowsAffected: 1234567890, isDelete: true}, + {commandTag: pgconn.CommandTag("SELECT 1"), rowsAffected: 1, isSelect: true}, + {commandTag: pgconn.CommandTag("SELECT 99999999999"), rowsAffected: 99999999999, isSelect: true}, {commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0}, {commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0}, {commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0}, } for i, tt := range tests { - actual := tt.commandTag.RowsAffected() - assert.Equalf(t, tt.rowsAffected, actual, "%d. %v", i, tt.commandTag) + ct := tt.commandTag + assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag) } } From a48e9bf63c413ae0498d16e339f94c2884fa988e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Jan 2020 19:07:39 -0600 Subject: [PATCH 0388/1158] Update changelog --- CHANGELOG.md | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 92497f47..1debb10b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,25 @@ +# 1.2.0 (January 11, 2020) + +## Features + +* Add Insert(), Update(), Delete(), and Select() statement type query methods to CommandTag. +* Add PgError.SQLState method. This could be used for compatibility with other drivers and databases. + +## Performance + +* Improve performance when context.Background() is used. (bakape) +* CommandTag.RowsAffected is faster and does not allocate. + +## Fixes + +* Try to cancel any in-progress query when a conn is closed by ctx cancel. +* Handle NoticeResponse during CopyFrom. +* Ignore errors sending Terminate message while closing connection. This mimics the behavior of libpq PGfinish. + # 1.1.0 (October 12, 2019) -* Add PgConn.IsBusy() method +* Add PgConn.IsBusy() method. # 1.0.1 (September 19, 2019) -* Fix statement cache not properly cleaning discarded statements +* Fix statement cache not properly cleaning discarded statements. From 186f4b3539e5b358d3d237ad2e1ab267e6470a30 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Jan 2020 19:15:23 -0600 Subject: [PATCH 0389/1158] Update changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7db5c1a2..8c76d496 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 1.1.0 (January 11, 2020) + +* Add PostgreSQL time type support +* Add more automatic conversions of integer arrays of different types (Jean-Philippe Quéméner) + # 1.0.3 (November 16, 2019) * Support initializing Array types from a slice of the value (Alex Gaynor) From 0df97353b8acf7c2751f7812f27d99a6974e596c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 12 Jan 2020 16:27:46 -0600 Subject: [PATCH 0390/1158] Fix racy usage of pgConn.contextWatcher in ayncClose --- pgconn.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pgconn.go b/pgconn.go index dce4bfb5..3a6c598a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -518,13 +518,14 @@ func (pgConn *PgConn) ayncClose() { go func() { defer pgConn.conn.Close() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + deadline := time.Now().Add(time.Second * 15) + + ctx, cancel := context.WithDeadline(context.Background(), deadline) defer cancel() pgConn.CancelRequest(ctx) - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() + pgConn.conn.SetDeadline(deadline) pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) pgConn.conn.Read(make([]byte, 1)) From 2582879459d09494565cda6b2fe91d7660623122 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 12 Jan 2020 16:28:56 -0600 Subject: [PATCH 0391/1158] Fix typo - rename ayncClose to asyncClose --- pgconn.go | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/pgconn.go b/pgconn.go index 3a6c598a..89d7bb45 100644 --- a/pgconn.go +++ b/pgconn.go @@ -372,7 +372,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return &writeError{err: err, safeToRetry: n == 0} } @@ -431,7 +431,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { if err != nil { // Close on anything other than timeout error - everything else is fatal if err, ok := err.(net.Error); !(ok && err.Timeout()) { - pgConn.ayncClose() + pgConn.asyncClose() } return nil, err @@ -444,7 +444,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { - pgConn.ayncClose() + pgConn.asyncClose() return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: @@ -507,9 +507,9 @@ func (pgConn *PgConn) Close(ctx context.Context) error { return pgConn.conn.Close() } -// ayncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying +// asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying // connection. -func (pgConn *PgConn) ayncClose() { +func (pgConn *PgConn) asyncClose() { if pgConn.status == connStatusClosed { return } @@ -680,7 +680,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0} } @@ -692,7 +692,7 @@ readloop: for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } @@ -852,7 +852,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() pgConn.contextWatcher.Unwatch() multiResult.closed = true multiResult.err = &writeError{err: err, safeToRetry: n == 0} @@ -965,7 +965,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) pgConn.contextWatcher.Unwatch() result.closed = true @@ -996,7 +996,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() pgConn.unlock() return nil, &writeError{err: err, safeToRetry: n == 0} } @@ -1007,7 +1007,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } @@ -1016,7 +1016,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case *pgproto3.CopyData: _, err := w.Write(msg.Data) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } case *pgproto3.ReadyForQuery: @@ -1056,7 +1056,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, &writeError{err: err, safeToRetry: n == 0} } @@ -1067,7 +1067,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for pendingCopyInResponse { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } @@ -1096,7 +1096,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } } @@ -1105,7 +1105,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case <-signalMessageChan: msg, err := pgConn.receiveMessage() if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } @@ -1129,7 +1129,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } _, err = pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } @@ -1137,7 +1137,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } @@ -1182,7 +1182,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) mrr.pgConn.contextWatcher.Unwatch() mrr.err = err mrr.closed = true - mrr.pgConn.ayncClose() + mrr.pgConn.asyncClose() return nil, mrr.err } @@ -1371,7 +1371,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error rr.pgConn.contextWatcher.Unwatch() rr.closed = true if rr.multiResultReader == nil { - rr.pgConn.ayncClose() + rr.pgConn.asyncClose() } return nil, rr.err From e7dd01e064b5caf31bd290db23fadd13e60f8cd8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 13 Jan 2020 08:48:32 -0600 Subject: [PATCH 0392/1158] Update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1debb10b..c79d4f0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.2.1 (January 13, 2020) + +* Fix data race in context cancellation introduced in v1.2.0. + # 1.2.0 (January 11, 2020) ## Features From 595780be0f9f581451a23a5151b77f782202ad72 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 17 Jan 2020 16:55:05 -0600 Subject: [PATCH 0393/1158] Map io.EOF errors to io.ErrUnexpectedEOF io.EOF is never expected during valid usage. In addition, database/sql uses io.EOF as a sentinal value that all rows from a query have been received. See https://github.com/jackc/pgx/issues/662. --- frontend.go | 11 +++++++++-- frontend_test.go | 26 ++++++++++++++++++++++++-- go.mod | 1 + go.sum | 10 ++++++++++ 4 files changed, 44 insertions(+), 4 deletions(-) diff --git a/frontend.go b/frontend.go index 0826685b..3298d7e6 100644 --- a/frontend.go +++ b/frontend.go @@ -58,12 +58,19 @@ func (f *Frontend) Send(msg FrontendMessage) error { return err } +func translateEOFtoErrUnexpectedEOF(err error) error { + if err == io.EOF { + return io.ErrUnexpectedEOF + } + return err +} + // Receive receives a message from the backend. func (f *Frontend) Receive() (BackendMessage, error) { if !f.partialMsg { header, err := f.cr.Next(5) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } f.msgType = header[0] @@ -73,7 +80,7 @@ func (f *Frontend) Receive() (BackendMessage, error) { msgBody, err := f.cr.Next(f.bodyLen) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } f.partialMsg = false diff --git a/frontend_test.go b/frontend_test.go index 9b63aa00..002da759 100644 --- a/frontend_test.go +++ b/frontend_test.go @@ -1,10 +1,11 @@ package pgproto3_test import ( - "errors" + "io" "testing" "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/assert" ) type interruptReader struct { @@ -13,7 +14,7 @@ type interruptReader struct { func (ir *interruptReader) Read(p []byte) (n int, err error) { if len(ir.chunks) == 0 { - return 0, errors.New("no data") + return 0, io.EOF } n = copy(p, ir.chunks[0]) @@ -56,3 +57,24 @@ func TestFrontendReceiveInterrupted(t *testing.T) { t.Fatalf("unexpected msg: %v", msg) } } + +func TestFrontendReceiveUnexpectedEOF(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Z', 0, 0, 0, 5}) + + frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil) + + msg, err := frontend.Receive() + if err == nil { + t.Fatal("expected err") + } + if msg != nil { + t.Fatalf("did not expect msg, but %v", msg) + } + + msg, err = frontend.Receive() + assert.Nil(t, msg) + assert.Equal(t, io.ErrUnexpectedEOF, err) +} diff --git a/go.mod b/go.mod index 4821676a..36041a94 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,5 @@ go 1.12 require ( github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 + github.com/stretchr/testify v1.4.0 ) diff --git a/go.sum b/go.sum index 36160794..dd9cd044 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,14 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= From 8be01d690fed6a2bd6d1cad7819c4fe00cb3611e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 17 Jan 2020 17:38:07 -0600 Subject: [PATCH 0394/1158] Make Host comment more precise --- config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.go b/config.go index 628deed8..9876ac94 100644 --- a/config.go +++ b/config.go @@ -29,7 +29,7 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and // then it can be modified. A manually initialized Config will cause ConnectConfig to panic. type Config struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) Port uint16 Database string User string From 59525245114b2a264f25fbeeddda947a64e2c61e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 17 Jan 2020 17:38:44 -0600 Subject: [PATCH 0395/1158] Add Hijack and Construct fixes #9 --- pgconn.go | 69 +++++++++++++++++++++++++++++++++++++++++++++++++- pgconn_test.go | 26 +++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 89d7bb45..44a08cc8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -27,6 +27,8 @@ const ( connStatusBusy ) +const wbufLen = 1024 + // Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from // LISTEN/NOTIFY notification. type Notice PgError @@ -192,7 +194,7 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config - pgConn.wbuf = make([]byte, 0, 1024) + pgConn.wbuf = make([]byte, 0, wbufLen) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -1481,3 +1483,68 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) { return strings.Replace(s, "'", "''", -1), nil } + +// HijackedConn is the result of hijacking a connection. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +type HijackedConn struct { + Conn net.Conn // the underlying TCP or unix domain socket connection + PID uint32 // backend pid + SecretKey uint32 // key to use to send a cancel query message to the server + ParameterStatuses map[string]string // parameters that have been reported by the server + TxStatus byte + Frontend Frontend + Config *Config +} + +// Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking. +// Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the +// raw connection after that (e.g. a load balancer or proxy). +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func (pgConn *PgConn) Hijack() (*HijackedConn, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + pgConn.status = connStatusClosed + + return &HijackedConn{ + Conn: pgConn.conn, + PID: pgConn.pid, + SecretKey: pgConn.secretKey, + ParameterStatuses: pgConn.parameterStatuses, + TxStatus: pgConn.txStatus, + Frontend: pgConn.frontend, + Config: pgConn.config, + }, nil +} + +// Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of +// PgConn.Hijack. The connection must be in an idle state. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func Construct(hc *HijackedConn) (*PgConn, error) { + pgConn := &PgConn{ + conn: hc.Conn, + pid: hc.PID, + secretKey: hc.SecretKey, + parameterStatuses: hc.ParameterStatuses, + txStatus: hc.TxStatus, + frontend: hc.Frontend, + config: hc.Config, + + status: connStatusIdle, + + wbuf: make([]byte, 0, wbufLen), + } + + pgConn.contextWatcher = ctxwatch.NewContextWatcher( + func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { pgConn.conn.SetDeadline(time.Time{}) }, + ) + + return pgConn, nil +} diff --git a/pgconn_test.go b/pgconn_test.go index 2c303d81..34982bb7 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1600,6 +1600,32 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) { ensureConnValid(t, pgConn) } +func TestHijackAndConstruct(t *testing.T) { + t.Parallel() + + origConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + hc, err := origConn.Hijack() + require.NoError(t, err) + + newConn, err := pgconn.Construct(hc) + require.NoError(t, err) + + defer closeConn(t, newConn) + + results, err := newConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, newConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { From a4375eb53f25d9dc139319d01d1921b2927179f9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 17 Jan 2020 17:42:20 -0600 Subject: [PATCH 0396/1158] Add test that Hijack'ed conn is no longer usable. --- pgconn_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index 34982bb7..c37a2fb2 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1609,6 +1609,9 @@ func TestHijackAndConstruct(t *testing.T) { hc, err := origConn.Hijack() require.NoError(t, err) + _, err = origConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + require.Error(t, err) + newConn, err := pgconn.Construct(hc) require.NoError(t, err) From f909a64ff567aec10157dc6a7efb9e5c9365aac6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 23 Jan 2020 20:55:52 -0600 Subject: [PATCH 0397/1158] Update pgproto3 to v2.0.1 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 4a188cce..59e7e98e 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0 + github.com/jackc/pgproto3/v2 v2.0.1 github.com/stretchr/testify v1.4.0 golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 golang.org/x/text v0.3.2 diff --git a/go.sum b/go.sum index 51c55d12..0c7fc9f1 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0 h1:FApgMJ/GtaXfI0s8Lvd0kaLaRwMOhs4VH92pwkwQQvU= github.com/jackc/pgproto3/v2 v2.0.0/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M= +github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= From 6124b07bb1380523c4d6e01db8b55546e6c61136 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 23 Jan 2020 20:57:13 -0600 Subject: [PATCH 0398/1158] Update changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c79d4f0b..26e9c8c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 1.3.0 (January 23, 2020) + +* Add Hijack and Construct. +* Update pgproto3 to v2.0.1. + # 1.2.1 (January 13, 2020) * Fix data race in context cancellation introduced in v1.2.0. From 0bbaad1348a924d630c9ed4c68d6d999f94021da Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 24 Jan 2020 11:23:28 -0600 Subject: [PATCH 0399/1158] Add zeronull package for easier NULL <-> zero conversion --- testutil/testutil.go | 150 +++++++++++++++++++++++++++++++++++ zeronull/doc.go | 22 +++++ zeronull/int2.go | 90 +++++++++++++++++++++ zeronull/int2_test.go | 23 ++++++ zeronull/int4.go | 90 +++++++++++++++++++++ zeronull/int4_test.go | 23 ++++++ zeronull/int8.go | 90 +++++++++++++++++++++ zeronull/int8_test.go | 23 ++++++ zeronull/text.go | 90 +++++++++++++++++++++ zeronull/text_test.go | 23 ++++++ zeronull/timestamp.go | 91 +++++++++++++++++++++ zeronull/timestamp_test.go | 29 +++++++ zeronull/timestamptz.go | 91 +++++++++++++++++++++ zeronull/timestamptz_test.go | 29 +++++++ zeronull/uuid.go | 90 +++++++++++++++++++++ zeronull/uuid_test.go | 23 ++++++ 16 files changed, 977 insertions(+) create mode 100644 zeronull/doc.go create mode 100644 zeronull/int2.go create mode 100644 zeronull/int2_test.go create mode 100644 zeronull/int4.go create mode 100644 zeronull/int4_test.go create mode 100644 zeronull/int8.go create mode 100644 zeronull/int8_test.go create mode 100644 zeronull/text.go create mode 100644 zeronull/text_test.go create mode 100644 zeronull/timestamp.go create mode 100644 zeronull/timestamp_test.go create mode 100644 zeronull/timestamptz.go create mode 100644 zeronull/timestamptz_test.go create mode 100644 zeronull/uuid.go create mode 100644 zeronull/uuid_test.go diff --git a/testutil/testutil.go b/testutil/testutil.go index 068b7c59..e7b64b58 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -284,3 +284,153 @@ func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, t } } } + +func TestGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { + TestPgxGoZeroToNullConversion(t, pgTypeName, zero) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLGoZeroToNullConversion(t, driverName, pgTypeName, zero) + } +} + +func TestNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) { + TestPgxNullToGoZeroConversion(t, pgTypeName, zero) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLNullToGoZeroConversion(t, driverName, pgTypeName, zero) + } +} + +func TestPgxGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { + conn := MustConnectPgx(t) + defer MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s is null", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, paramFormat := range formats { + vEncoder := ForceEncoder(zero, paramFormat.formatCode) + if vEncoder == nil { + t.Logf("Skipping Param %s: %#v does not implement %v for encoding", paramFormat.name, zero, paramFormat.name) + continue + } + + var result bool + err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(&result) + if err != nil { + t.Errorf("Param %s: %v", paramFormat.name, err) + } + + if !result { + t.Errorf("Param %s: did not convert zero to null", paramFormat.name) + } + } +} + +func TestPgxNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) { + conn := MustConnectPgx(t) + defer MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select null::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, resultFormat := range formats { + + switch resultFormat.formatCode { + case pgx.TextFormatCode: + if _, ok := zero.(pgtype.TextEncoder); !ok { + t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name) + continue + } + case pgx.BinaryFormatCode: + if _, ok := zero.(pgtype.BinaryEncoder); !ok { + t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name) + continue + } + } + + // Derefence value if it is a pointer + derefZero := zero + refVal := reflect.ValueOf(zero) + if refVal.Kind() == reflect.Ptr { + derefZero = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefZero)) + + err := conn.QueryRow(context.Background(), "test").Scan(result.Interface()) + if err != nil { + t.Errorf("Result %s: %v", resultFormat.name, err) + } + + if !reflect.DeepEqual(result.Elem().Interface(), derefZero) { + t.Errorf("Result %s: did not convert null to zero", resultFormat.name) + } + } +} + +func TestDatabaseSQLGoZeroToNullConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s is null", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + var result bool + err = ps.QueryRow(zero).Scan(&result) + if err != nil { + t.Errorf("%v %v", driverName, err) + } + + if !result { + t.Errorf("%v: did not convert zero to null", driverName) + } +} + +func TestDatabaseSQLNullToGoZeroConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select null::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + // Derefence value if it is a pointer + derefZero := zero + refVal := reflect.ValueOf(zero) + if refVal.Kind() == reflect.Ptr { + derefZero = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefZero)) + + err = ps.QueryRow().Scan(result.Interface()) + if err != nil { + t.Errorf("%v %v", driverName, err) + } + + if !reflect.DeepEqual(result.Elem().Interface(), derefZero) { + t.Errorf("%s: did not convert null to zero", driverName) + } +} diff --git a/zeronull/doc.go b/zeronull/doc.go new file mode 100644 index 00000000..8db3507c --- /dev/null +++ b/zeronull/doc.go @@ -0,0 +1,22 @@ +// Package zeronull contains types that automatically convert between database NULLs and Go zero values. +/* +Sometimes the distinction between a zero value and a NULL value is not useful at the application level. For example, +in PostgreSQL an empty string may be stored as NULL. There is usually no application level distinction between an +empty string and a NULL string. Package zeronull implements types that seemlessly convert between PostgreSQL NULL and +the zero value. + +It is recommended to convert types at usage time rather than instantiate these types directly. In the example below, +middlename would be stored as a NULL. + + firstname := "John" + middlename := "" + lastname := "Smith" + _, err := conn.Exec( + ctx, + "insert into people(firstname, middlename, lastname) values($1, $2, $3)", + zeronull.Text(firstname), + zeronull.Text(middlename), + zeronull.Text(lastname), + ) +*/ +package zeronull diff --git a/zeronull/int2.go b/zeronull/int2.go new file mode 100644 index 00000000..a528642f --- /dev/null +++ b/zeronull/int2.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Int2 int16 + +func (dst *Int2) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int2 + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Int2(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (dst *Int2) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int2 + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Int2(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (src Int2) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int2{ + Int: int16(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Int2) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int2{ + Int: int16(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int2) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int2 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int2(nullable.Int) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/int2_test.go b/zeronull/int2_test.go new file mode 100644 index 00000000..2dcb4e79 --- /dev/null +++ b/zeronull/int2_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestInt2Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ + (zeronull.Int2)(1), + (zeronull.Int2)(0), + }) +} + +func TestInt2ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int2", (zeronull.Int2)(0)) +} + +func TestInt2ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int2", (zeronull.Int2)(0)) +} diff --git a/zeronull/int4.go b/zeronull/int4.go new file mode 100644 index 00000000..c539e43a --- /dev/null +++ b/zeronull/int4.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Int4 int32 + +func (dst *Int4) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int4 + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Int4(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (dst *Int4) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int4 + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Int4(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (src Int4) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int4{ + Int: int32(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Int4) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int4{ + Int: int32(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int4 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int4(nullable.Int) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/int4_test.go b/zeronull/int4_test.go new file mode 100644 index 00000000..309e4125 --- /dev/null +++ b/zeronull/int4_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestInt4Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ + (zeronull.Int4)(1), + (zeronull.Int4)(0), + }) +} + +func TestInt4ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int4", (zeronull.Int4)(0)) +} + +func TestInt4ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int4", (zeronull.Int4)(0)) +} diff --git a/zeronull/int8.go b/zeronull/int8.go new file mode 100644 index 00000000..19774645 --- /dev/null +++ b/zeronull/int8.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Int8 int64 + +func (dst *Int8) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int8 + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Int8(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (dst *Int8) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int8 + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Int8(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (src Int8) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int8{ + Int: int64(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Int8) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int8{ + Int: int64(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int8 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int8(nullable.Int) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/int8_test.go b/zeronull/int8_test.go new file mode 100644 index 00000000..ae80bc0a --- /dev/null +++ b/zeronull/int8_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestInt8Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ + (zeronull.Int8)(1), + (zeronull.Int8)(0), + }) +} + +func TestInt8ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int8", (zeronull.Int8)(0)) +} + +func TestInt8ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int8", (zeronull.Int8)(0)) +} diff --git a/zeronull/text.go b/zeronull/text.go new file mode 100644 index 00000000..8e79fc6a --- /dev/null +++ b/zeronull/text.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Text string + +func (dst *Text) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Text + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Text(nullable.String) + } else { + *dst = Text("") + } + + return nil +} + +func (dst *Text) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Text + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Text(nullable.String) + } else { + *dst = Text("") + } + + return nil +} + +func (src Text) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == Text("") { + return nil, nil + } + + nullable := pgtype.Text{ + String: string(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Text) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == Text("") { + return nil, nil + } + + nullable := pgtype.Text{ + String: string(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Text) Scan(src interface{}) error { + if src == nil { + *dst = Text("") + return nil + } + + var nullable pgtype.Text + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Text(nullable.String) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Text) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/text_test.go b/zeronull/text_test.go new file mode 100644 index 00000000..f08a0d2a --- /dev/null +++ b/zeronull/text_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestTextTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "text", []interface{}{ + (zeronull.Text)("foo"), + (zeronull.Text)(""), + }) +} + +func TestTextConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "text", (zeronull.Text)("")) +} + +func TestTextConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "text", (zeronull.Text)("")) +} diff --git a/zeronull/timestamp.go b/zeronull/timestamp.go new file mode 100644 index 00000000..a94c67cc --- /dev/null +++ b/zeronull/timestamp.go @@ -0,0 +1,91 @@ +package zeronull + +import ( + "database/sql/driver" + "time" + + "github.com/jackc/pgtype" +) + +type Timestamp time.Time + +func (dst *Timestamp) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Timestamp + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Timestamp(nullable.Time) + } else { + *dst = Timestamp{} + } + + return nil +} + +func (dst *Timestamp) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Timestamp + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Timestamp(nullable.Time) + } else { + *dst = Timestamp{} + } + + return nil +} + +func (src Timestamp) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == Timestamp{}) { + return nil, nil + } + + nullable := pgtype.Timestamp{ + Time: time.Time(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Timestamp) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == Timestamp{}) { + return nil, nil + } + + nullable := pgtype.Timestamp{ + Time: time.Time(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamp) Scan(src interface{}) error { + if src == nil { + *dst = Timestamp{} + return nil + } + + var nullable pgtype.Timestamp + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Timestamp(nullable.Time) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamp) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/timestamp_test.go b/zeronull/timestamp_test.go new file mode 100644 index 00000000..ec96ff07 --- /dev/null +++ b/zeronull/timestamp_test.go @@ -0,0 +1,29 @@ +package zeronull_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestTimestampTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ + (zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + (zeronull.Timestamp)(time.Time{}), + }, func(a, b interface{}) bool { + at := a.(zeronull.Timestamp) + bt := b.(zeronull.Timestamp) + + return time.Time(at).Equal(time.Time(bt)) + }) +} + +func TestTimestampConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "timestamp", (zeronull.Timestamp)(time.Time{})) +} + +func TestTimestampConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "timestamp", (zeronull.Timestamp)(time.Time{})) +} diff --git a/zeronull/timestamptz.go b/zeronull/timestamptz.go new file mode 100644 index 00000000..c641ca10 --- /dev/null +++ b/zeronull/timestamptz.go @@ -0,0 +1,91 @@ +package zeronull + +import ( + "database/sql/driver" + "time" + + "github.com/jackc/pgtype" +) + +type Timestamptz time.Time + +func (dst *Timestamptz) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Timestamptz + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Timestamptz(nullable.Time) + } else { + *dst = Timestamptz{} + } + + return nil +} + +func (dst *Timestamptz) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Timestamptz + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Timestamptz(nullable.Time) + } else { + *dst = Timestamptz{} + } + + return nil +} + +func (src Timestamptz) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == Timestamptz{}) { + return nil, nil + } + + nullable := pgtype.Timestamptz{ + Time: time.Time(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Timestamptz) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == Timestamptz{}) { + return nil, nil + } + + nullable := pgtype.Timestamptz{ + Time: time.Time(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamptz) Scan(src interface{}) error { + if src == nil { + *dst = Timestamptz{} + return nil + } + + var nullable pgtype.Timestamptz + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Timestamptz(nullable.Time) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamptz) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/timestamptz_test.go b/zeronull/timestamptz_test.go new file mode 100644 index 00000000..3a401c49 --- /dev/null +++ b/zeronull/timestamptz_test.go @@ -0,0 +1,29 @@ +package zeronull_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestTimestamptzTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ + (zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + (zeronull.Timestamptz)(time.Time{}), + }, func(a, b interface{}) bool { + at := a.(zeronull.Timestamptz) + bt := b.(zeronull.Timestamptz) + + return time.Time(at).Equal(time.Time(bt)) + }) +} + +func TestTimestamptzConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "timestamptz", (zeronull.Timestamptz)(time.Time{})) +} + +func TestTimestamptzConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "timestamptz", (zeronull.Timestamptz)(time.Time{})) +} diff --git a/zeronull/uuid.go b/zeronull/uuid.go new file mode 100644 index 00000000..18fc667e --- /dev/null +++ b/zeronull/uuid.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type UUID [16]byte + +func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.UUID + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = UUID(nullable.Bytes) + } else { + *dst = UUID{} + } + + return nil +} + +func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.UUID + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = UUID(nullable.Bytes) + } else { + *dst = UUID{} + } + + return nil +} + +func (src UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == UUID{}) { + return nil, nil + } + + nullable := pgtype.UUID{ + Bytes: [16]byte(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == UUID{}) { + return nil, nil + } + + nullable := pgtype.UUID{ + Bytes: [16]byte(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *UUID) Scan(src interface{}) error { + if src == nil { + *dst = UUID{} + return nil + } + + var nullable pgtype.UUID + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = UUID(nullable.Bytes) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src UUID) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/uuid_test.go b/zeronull/uuid_test.go new file mode 100644 index 00000000..162bdf1f --- /dev/null +++ b/zeronull/uuid_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestUUIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + (*zeronull.UUID)(&[16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), + (*zeronull.UUID)(&[16]byte{}), + }) +} + +func TestUUIDConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "uuid", (*zeronull.UUID)(&[16]byte{})) +} + +func TestUUIDConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "uuid", (*zeronull.UUID)(&[16]byte{})) +} From b01b35f466d926e7b372659a2a5291f722c59168 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 24 Jan 2020 14:58:59 -0600 Subject: [PATCH 0400/1158] Fix typo in docs --- zeronull/doc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeronull/doc.go b/zeronull/doc.go index 8db3507c..78a52307 100644 --- a/zeronull/doc.go +++ b/zeronull/doc.go @@ -2,7 +2,7 @@ /* Sometimes the distinction between a zero value and a NULL value is not useful at the application level. For example, in PostgreSQL an empty string may be stored as NULL. There is usually no application level distinction between an -empty string and a NULL string. Package zeronull implements types that seemlessly convert between PostgreSQL NULL and +empty string and a NULL string. Package zeronull implements types that seamlessly convert between PostgreSQL NULL and the zero value. It is recommended to convert types at usage time rather than instantiate these types directly. In the example below, From cf87e347920d0b4a33ef489e951e0d7d211f9d52 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 24 Jan 2020 17:07:41 -0600 Subject: [PATCH 0401/1158] Add JSON to shopspring-numeric extension --- ext/shopspring-numeric/decimal.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index c035b15b..259fa54d 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -11,6 +11,7 @@ import ( ) var errUndefined = errors.New("cannot encode status undefined") +var errBadStatus = errors.New("invalid status") type Numeric struct { Decimal decimal.Decimal @@ -316,3 +317,32 @@ func (src Numeric) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Numeric) MarshalJSON() ([]byte, error) { + switch src.Status { + case pgtype.Present: + return src.Decimal.MarshalJSON() + case pgtype.Null: + return []byte("null"), nil + case pgtype.Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} + +func (dst *Numeric) UnmarshalJSON(b []byte) error { + d := decimal.NullDecimal{} + err := d.UnmarshalJSON(b) + if err != nil { + return err + } + + status := pgtype.Null + if d.Valid { + status = pgtype.Present + } + *dst = Numeric{Decimal: d.Decimal, Status: status} + + return nil +} From 06942241c4591e2fe7cad95007232da89ba5ec18 Mon Sep 17 00:00:00 2001 From: Jeffrey Stiles Date: Fri, 24 Jan 2020 16:38:15 -0800 Subject: [PATCH 0402/1158] Support Null Status in UnmarshalJSON --- int4.go | 8 ++++++-- int4_test.go | 21 +++++++++++++++++++++ int8.go | 8 ++++++-- int8_test.go | 21 +++++++++++++++++++++ text.go | 8 ++++++-- text_test.go | 21 +++++++++++++++++++++ 6 files changed, 81 insertions(+), 6 deletions(-) diff --git a/int4.go b/int4.go index da39b7f0..2075b375 100644 --- a/int4.go +++ b/int4.go @@ -201,13 +201,17 @@ func (src Int4) MarshalJSON() ([]byte, error) { } func (dst *Int4) UnmarshalJSON(b []byte) error { - var n int32 + var n *int32 err := json.Unmarshal(b, &n) if err != nil { return err } - *dst = Int4{Int: n, Status: Present} + if n == nil { + *dst = Int4{Status: Null} + } else { + *dst = Int4{Int: *n, Status: Present} + } return nil } diff --git a/int4_test.go b/int4_test.go index 52bf9f0c..77fba8a5 100644 --- a/int4_test.go +++ b/int4_test.go @@ -141,3 +141,24 @@ func TestInt4AssignTo(t *testing.T) { } } } + +func TestInt4UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int4 + }{ + {source: "null", result: pgtype.Int4{Int: 0, Status: pgtype.Null}}, + {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Int4 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/int8.go b/int8.go index 7f410b15..01a694e4 100644 --- a/int8.go +++ b/int8.go @@ -187,13 +187,17 @@ func (src Int8) MarshalJSON() ([]byte, error) { } func (dst *Int8) UnmarshalJSON(b []byte) error { - var n int64 + var n *int64 err := json.Unmarshal(b, &n) if err != nil { return err } - *dst = Int8{Int: n, Status: Present} + if n == nil { + *dst = Int8{Status: Null} + } else { + *dst = Int8{Int: *n, Status: Present} + } return nil } diff --git a/int8_test.go b/int8_test.go index 63dd6f3e..73600eda 100644 --- a/int8_test.go +++ b/int8_test.go @@ -142,3 +142,24 @@ func TestInt8AssignTo(t *testing.T) { } } } + +func TestInt8UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int8 + }{ + {source: "null", result: pgtype.Int8{Int: 0, Status: pgtype.Null}}, + {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Int8 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/text.go b/text.go index d13a9ba4..cdd993db 100644 --- a/text.go +++ b/text.go @@ -152,13 +152,17 @@ func (src Text) MarshalJSON() ([]byte, error) { } func (dst *Text) UnmarshalJSON(b []byte) error { - var s string + var s *string err := json.Unmarshal(b, &s) if err != nil { return err } - *dst = Text{String: s, Status: Present} + if s == nil { + *dst = Text{Status: Null} + } else { + *dst = Text{String: *s, Status: Present} + } return nil } diff --git a/text_test.go b/text_test.go index f7286995..3bacba68 100644 --- a/text_test.go +++ b/text_test.go @@ -121,3 +121,24 @@ func TestTextAssignTo(t *testing.T) { } } } + +func TestTextUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Text + }{ + {source: "null", result: pgtype.Text{String: "", Status: pgtype.Null}}, + {source: "\"a\"", result: pgtype.Text{String: "a", Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Text + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} From 139342081ef84e9ca6933f2faa19e20059ad61a3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jan 2020 20:32:42 -0600 Subject: [PATCH 0403/1158] Fix CopyFrom deadlock when multiple NoticeResponse received during copy fixes #21 --- pgconn.go | 51 ++++++++++++++++++++++++++++++++++---------------- pgconn_test.go | 40 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 16 deletions(-) diff --git a/pgconn.go b/pgconn.go index 44a08cc8..271e6628 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1084,26 +1084,44 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Send copy data - buf = make([]byte, 0, 65536) - buf = append(buf, 'd') - sp := len(buf) - var readErr error + abortCopyChan := make(chan struct{}) + copyErrChan := make(chan error) signalMessageChan := pgConn.signalMessage() - for readErr == nil && pgErr == nil { - var n int - n, readErr = r.Read(buf[5:cap(buf)]) - if n > 0 { - buf = buf[0 : n+5] - pgio.SetInt32(buf[sp:], int32(n+4)) - _, err = pgConn.conn.Write(buf) - if err != nil { - pgConn.asyncClose() - return nil, err + go func() { + buf := make([]byte, 0, 65536) + buf = append(buf, 'd') + sp := len(buf) + + for { + n, readErr := r.Read(buf[5:cap(buf)]) + if n > 0 { + buf = buf[0 : n+5] + pgio.SetInt32(buf[sp:], int32(n+4)) + + _, writeErr := pgConn.conn.Write(buf) + if writeErr != nil { + copyErrChan <- writeErr + return + } + } + if readErr != nil { + copyErrChan <- readErr + return + } + + select { + case <-abortCopyChan: + return + default: } } + }() + var copyErr error + for copyErr == nil && pgErr == nil { select { + case copyErr = <-copyErrChan: case <-signalMessageChan: msg, err := pgConn.receiveMessage() if err != nil { @@ -1120,13 +1138,14 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co default: } } + close(abortCopyChan) buf = buf[:0] - if readErr == io.EOF || pgErr != nil { + if copyErr == io.EOF || pgErr != nil { copyDone := &pgproto3.CopyDone{} buf = copyDone.Encode(buf) } else { - copyFail := &pgproto3.CopyFail{Message: readErr.Error()} + copyFail := &pgproto3.CopyFail{Message: copyErr.Error()} buf = copyFail.Encode(buf) } _, err = pgConn.conn.Write(buf) diff --git a/pgconn_test.go b/pgconn_test.go index c37a2fb2..19ad3a0a 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1463,6 +1463,46 @@ func TestConnCopyFromQueryNoTableError(t *testing.T) { ensureConnValid(t, pgConn) } +// https://github.com/jackc/pgconn/issues/21 +func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) { + t.Parallel() + + ctx := context.Background() + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, `create temporary table sentences( + t text, + ts tsvector + )`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$ + begin + new.ts := to_tsvector(new.t); + return new; + end + $$ language plpgsql;`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll() + require.NoError(t, err) + + longString := make([]byte, 10001) + for i := range longString { + longString[i] = 'x' + } + + buf := &bytes.Buffer{} + for i := 0; i < 1000; i++ { + buf.Write([]byte(fmt.Sprintf("%s\n", string(longString)))) + } + + _, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) +} + func TestConnEscapeString(t *testing.T) { t.Parallel() From 67f2418279fabea76c16c3b613b9893a3b86e7d8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jan 2020 20:39:18 -0600 Subject: [PATCH 0404/1158] Make copyErrChan buffered so goroutine can always terminate It is possible the goroutine that is reading from copyErrChan will not read in case of error. --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 271e6628..e34b4cfe 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1085,7 +1085,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // Send copy data abortCopyChan := make(chan struct{}) - copyErrChan := make(chan error) + copyErrChan := make(chan error, 1) signalMessageChan := pgConn.signalMessage() go func() { From c9abb86f21f0b89b909e9d112829e21daf3c06d8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jan 2020 20:40:21 -0600 Subject: [PATCH 0405/1158] Ensure write failure in CopyFrom closes connection --- pgconn.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pgconn.go b/pgconn.go index e34b4cfe..f56575ca 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1101,6 +1101,9 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, writeErr := pgConn.conn.Write(buf) if writeErr != nil { + // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. + pgConn.conn.Close() + copyErrChan <- writeErr return } From 5f363cb1f02554168b67e4d3e5dbeece248464e0 Mon Sep 17 00:00:00 2001 From: Jeffrey Stiles Date: Mon, 27 Jan 2020 16:19:43 -0800 Subject: [PATCH 0406/1158] Add JSON marshalling for Bool, Date, JSON/B, Timestamptz --- bool.go | 34 +++++++++++++++++++++++++++ bool_test.go | 43 ++++++++++++++++++++++++++++++++++ date.go | 56 ++++++++++++++++++++++++++++++++++++++++++++ date_test.go | 49 ++++++++++++++++++++++++++++++++++++++ int4_test.go | 20 ++++++++++++++++ int8_test.go | 20 ++++++++++++++++ json.go | 23 ++++++++++++++++++ json_test.go | 41 ++++++++++++++++++++++++++++++++ jsonb.go | 8 +++++++ text_test.go | 20 ++++++++++++++++ timestamptz.go | 57 +++++++++++++++++++++++++++++++++++++++++++++ timestamptz_test.go | 47 +++++++++++++++++++++++++++++++++++++ 12 files changed, 418 insertions(+) diff --git a/bool.go b/bool.go index ad55dce4..db02f663 100644 --- a/bool.go +++ b/bool.go @@ -2,6 +2,7 @@ package pgtype import ( "database/sql/driver" + "encoding/json" "strconv" errors "golang.org/x/xerrors" @@ -163,3 +164,36 @@ func (src Bool) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Bool) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + if src.Bool { + return []byte("true"), nil + } else { + return []byte("false"), nil + } + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} + +func (dst *Bool) UnmarshalJSON(b []byte) error { + var v *bool + err := json.Unmarshal(b, &v) + if err != nil { + return err + } + + if v == nil { + *dst = Bool{Status: Null} + } else { + *dst = Bool{Bool: *v, Status: Present} + } + + return nil +} diff --git a/bool_test.go b/bool_test.go index 64b4064d..8e7a5220 100644 --- a/bool_test.go +++ b/bool_test.go @@ -95,3 +95,46 @@ func TestBoolAssignTo(t *testing.T) { } } } + +func TestBoolMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Bool + result string + }{ + {source: pgtype.Bool{Status: pgtype.Null}, result: "null"}, + {source: pgtype.Bool{Bool: true, Status: pgtype.Present}, result: "true"}, + {source: pgtype.Bool{Bool: false, Status: pgtype.Present}, result: "false"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestBoolUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Bool + }{ + {source: "null", result: pgtype.Bool{Status: pgtype.Null}}, + {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Bool + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/date.go b/date.go index 8e35b22a..eaf95dde 100644 --- a/date.go +++ b/date.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "time" "github.com/jackc/pgio" @@ -208,3 +209,58 @@ func (src Date) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Date) MarshalJSON() ([]byte, error) { + switch src.Status { + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + if src.Status != Present { + return nil, errBadStatus + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (dst *Date) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Date{Status: Null} + return nil + } + + switch *s { + case "infinity": + *dst = Date{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Date{Status: Present, InfinityModifier: -Infinity} + default: + t, err := time.ParseInLocation("2006-01-02", *s, time.UTC) + if err != nil { + return err + } + + *dst = Date{Time: t, Status: Present} + } + + return nil +} diff --git a/date_test.go b/date_test.go index bcdbbf20..0b77898b 100644 --- a/date_test.go +++ b/date_test.go @@ -116,3 +116,52 @@ func TestDateAssignTo(t *testing.T) { } } } + +func TestDateMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Date + result string + }{ + {source: pgtype.Date{Status: pgtype.Null}, result: "null"}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29\""}, + {source: pgtype.Date{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, result: "\"infinity\""}, + {source: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, result: "\"-infinity\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestDateUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Date + }{ + {source: "null", result: pgtype.Date{Status: pgtype.Null}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, + {source: "\"infinity\"", result: pgtype.Date{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, + {source: "\"-infinity\"", result: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Date + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r.Time.Year() != tt.result.Time.Year() || r.Time.Month() != tt.result.Time.Month() || r.Time.Day() != tt.result.Time.Day() || r.Status != tt.result.Status || r.InfinityModifier != tt.result.InfinityModifier { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/int4_test.go b/int4_test.go index 77fba8a5..c679de74 100644 --- a/int4_test.go +++ b/int4_test.go @@ -142,6 +142,26 @@ func TestInt4AssignTo(t *testing.T) { } } +func TestInt4MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int4 + result string + }{ + {source: pgtype.Int4{Int: 0, Status: pgtype.Null}, result: "null"}, + {source: pgtype.Int4{Int: 1, Status: pgtype.Present}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + func TestInt4UnmarshalJSON(t *testing.T) { successfulTests := []struct { source string diff --git a/int8_test.go b/int8_test.go index 73600eda..fb6f581b 100644 --- a/int8_test.go +++ b/int8_test.go @@ -143,6 +143,26 @@ func TestInt8AssignTo(t *testing.T) { } } +func TestInt8MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int8 + result string + }{ + {source: pgtype.Int8{Int: 0, Status: pgtype.Null}, result: "null"}, + {source: pgtype.Int8{Int: 1, Status: pgtype.Present}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + func TestInt8UnmarshalJSON(t *testing.T) { successfulTests := []struct { source string diff --git a/json.go b/json.go index 592dfa31..58a5b093 100644 --- a/json.go +++ b/json.go @@ -165,3 +165,26 @@ func (src JSON) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src JSON) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} + +func (dst *JSON) UnmarshalJSON(b []byte) error { + if b == nil || string(b) == "null" { + *dst = JSON{Status: Null} + } else { + *dst = JSON{Bytes: b, Status: Present} + } + return nil + +} diff --git a/json_test.go b/json_test.go index 918b33d5..bbd3959e 100644 --- a/json_test.go +++ b/json_test.go @@ -134,3 +134,44 @@ func TestJSONAssignTo(t *testing.T) { } } } + +func TestJSONMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.JSON + result string + }{ + {source: pgtype.JSON{Status: pgtype.Null}, result: "null"}, + {source: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Status: pgtype.Present}, result: "{\"a\": 1}"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestJSONUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.JSON + }{ + {source: "null", result: pgtype.JSON{Status: pgtype.Null}}, + {source: "{\"a\": 1}", result: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.JSON + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r.Bytes) != string(tt.result.Bytes) || r.Status != tt.result.Status { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/jsonb.go b/jsonb.go index c70be144..43e23fb2 100644 --- a/jsonb.go +++ b/jsonb.go @@ -68,3 +68,11 @@ func (dst *JSONB) Scan(src interface{}) error { func (src JSONB) Value() (driver.Value, error) { return (JSON)(src).Value() } + +func (src JSONB) MarshalJSON() ([]byte, error) { + return (JSON)(src).MarshalJSON() +} + +func (dst *JSONB) UnmarshalJSON(b []byte) error { + return (*JSON)(dst).UnmarshalJSON(b) +} diff --git a/text_test.go b/text_test.go index 3bacba68..cca3a05d 100644 --- a/text_test.go +++ b/text_test.go @@ -122,6 +122,26 @@ func TestTextAssignTo(t *testing.T) { } } +func TestTextMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Text + result string + }{ + {source: pgtype.Text{String: "", Status: pgtype.Null}, result: "null"}, + {source: pgtype.Text{String: "a", Status: pgtype.Present}, result: "\"a\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + func TestTextUnmarshalJSON(t *testing.T) { successfulTests := []struct { source string diff --git a/timestamptz.go b/timestamptz.go index 9af39b16..7ed86eb8 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "time" "github.com/jackc/pgio" @@ -220,3 +221,59 @@ func (src Timestamptz) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Timestamptz) MarshalJSON() ([]byte, error) { + switch src.Status { + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + if src.Status != Present { + return nil, errBadStatus + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format(time.RFC3339Nano) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (dst *Timestamptz) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + switch *s { + case "infinity": + *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz + tim, err := time.Parse(time.RFC3339Nano, *s) + if err != nil { + return err + } + + *dst = Timestamptz{Time: tim, Status: Present} + } + + return nil +} diff --git a/timestamptz_test.go b/timestamptz_test.go index f6aec068..a020b1ec 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -120,3 +120,50 @@ func TestTimestamptzAssignTo(t *testing.T) { } } } + +func TestTimestamptzMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Timestamptz + result string + }{ + {source: pgtype.Timestamptz{Status: pgtype.Null}, result: "null"}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29T10:05:45-06:00\""}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29T10:05:45.555-06:00\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, result: "\"infinity\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, result: "\"-infinity\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestTimestamptzUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Timestamptz + }{ + {source: "null", result: pgtype.Timestamptz{Status: pgtype.Null}}, + {source: "\"2012-03-29T10:05:45-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, + {source: "\"2012-03-29T10:05:45.555-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, + {source: "\"infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, + {source: "\"-infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Timestamptz + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !r.Time.Equal(tt.result.Time) || r.Status != tt.result.Status || r.InfinityModifier != tt.result.InfinityModifier { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} From 406afa0eb7f8a23c96e0c6ec7bb56cbce3fc1ca4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 5 Feb 2020 11:06:09 -0600 Subject: [PATCH 0407/1158] Release v1.3.1 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 26e9c8c7..5a9ca414 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.3.1 (February 5, 2020) + +* Fix CopyFrom deadlock when multiple NoticeResponse received during copy + # 1.3.0 (January 23, 2020) * Add Hijack and Construct. From 282b7936a2cd6528fd3c4cdab4232a514ca54adb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 5 Feb 2020 11:10:17 -0600 Subject: [PATCH 0408/1158] Release 1.2.0 --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c76d496..f12c5027 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 1.2.0 (February 5, 2020) + +* Add zeronull package for easier NULL <-> zero conversion +* Add JSON marshalling for shopspring-numeric extension +* Add JSON marshalling for Bool, Date, JSON/B, Timestamptz (Jeffrey Stiles) +* Fix null status in UnmarshalJSON for some types (Jeffrey Stiles) + # 1.1.0 (January 11, 2020) * Add PostgreSQL time type support From 06c4e181b1abf6d6d531b3da38b40f8a1932d21b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 5 Feb 2020 11:49:40 -0600 Subject: [PATCH 0409/1158] go mod tidy --- go.sum | 2 -- 1 file changed, 2 deletions(-) diff --git a/go.sum b/go.sum index 0c7fc9f1..c23d4412 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,6 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.0 h1:FApgMJ/GtaXfI0s8Lvd0kaLaRwMOhs4VH92pwkwQQvU= -github.com/jackc/pgproto3/v2 v2.0.0/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M= github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= From ac364e7a4366fc363b67cbfc06edf41594d9d8cc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 7 Feb 2020 15:40:50 -0600 Subject: [PATCH 0410/1158] Use writeError for Write error --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index f56575ca..751d8fc0 100644 --- a/pgconn.go +++ b/pgconn.go @@ -683,7 +683,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ n, err := pgConn.conn.Write(buf) if err != nil { pgConn.asyncClose() - return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0} + return nil, &writeError{err: err, safeToRetry: n == 0} } psd := &StatementDescription{Name: name, SQL: sql} From 6db848c6fca46bd3c67b1a66b5f764fbb16807ba Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Feb 2020 17:56:59 -0600 Subject: [PATCH 0411/1158] Update chunkreader to v2.0.1 --- CHANGELOG.md | 4 ++++ go.mod | 2 +- go.sum | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a9ca414..eb099dc2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.3.2 (February 14, 2020) + +* Update chunkreader to v2.0.1 for optimized default buffer size. + # 1.3.1 (February 5, 2020) * Fix CopyFrom deadlock when multiple NoticeResponse received during copy diff --git a/go.mod b/go.mod index 59e7e98e..37590559 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/jackc/pgconn go 1.12 require ( - github.com/jackc/chunkreader/v2 v2.0.0 + github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 diff --git a/go.sum b/go.sum index c23d4412..28f094e7 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZb github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= From f3816bd1c068f931d2077b02816102451749168d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 19 Feb 2020 10:48:09 -0600 Subject: [PATCH 0412/1158] Get implemented on T instead of *T Methods defined on T are also available on *T. Thought this technically changes the interface, because *T will be automatically dereferenced as needed it shouldn't be a breaking change. See a8802b16cc593842f5c69b0f7cfb0de11d5cd3a8 for similar change. --- aclitem.go | 2 +- aclitem_array.go | 2 +- bit.go | 4 ++-- bool.go | 2 +- bool_array.go | 2 +- box.go | 2 +- bpchar.go | 4 ++-- bpchar_array.go | 2 +- bytea.go | 2 +- bytea_array.go | 2 +- cid.go | 4 ++-- cidr.go | 4 ++-- cidr_array.go | 2 +- circle.go | 2 +- date.go | 2 +- date_array.go | 2 +- daterange.go | 2 +- enum_array.go | 2 +- ext/gofrs-uuid/uuid.go | 2 +- ext/shopspring-numeric/decimal.go | 2 +- float4.go | 2 +- float4_array.go | 2 +- float8.go | 2 +- float8_array.go | 2 +- generic_binary.go | 4 ++-- generic_text.go | 4 ++-- hstore.go | 2 +- hstore_array.go | 2 +- inet.go | 2 +- inet_array.go | 2 +- int2.go | 2 +- int2_array.go | 2 +- int4.go | 2 +- int4_array.go | 2 +- int4range.go | 2 +- int8.go | 2 +- int8_array.go | 2 +- int8range.go | 2 +- interval.go | 2 +- json.go | 2 +- jsonb.go | 4 ++-- line.go | 2 +- lseg.go | 2 +- macaddr.go | 2 +- macaddr_array.go | 2 +- name.go | 4 ++-- numeric.go | 2 +- numeric_array.go | 2 +- numrange.go | 2 +- oid_value.go | 4 ++-- path.go | 2 +- pguint32.go | 2 +- point.go | 2 +- polygon.go | 2 +- qchar.go | 2 +- record.go | 2 +- text.go | 2 +- text_array.go | 2 +- tid.go | 2 +- time.go | 2 +- timestamp.go | 2 +- timestamp_array.go | 2 +- timestamptz.go | 2 +- timestamptz_array.go | 2 +- tsrange.go | 2 +- tstzrange.go | 2 +- tstzrange_array.go | 2 +- typed_array.go.erb | 2 +- typed_range.go.erb | 2 +- unknown.go | 4 ++-- uuid.go | 2 +- uuid_array.go | 2 +- varbit.go | 2 +- varchar.go | 4 ++-- varchar_array.go | 2 +- xid.go | 4 ++-- 76 files changed, 88 insertions(+), 88 deletions(-) diff --git a/aclitem.go b/aclitem.go index 123e86b6..36df71bc 100644 --- a/aclitem.go +++ b/aclitem.go @@ -43,7 +43,7 @@ func (dst *ACLItem) Set(src interface{}) error { return nil } -func (dst *ACLItem) Get() interface{} { +func (dst ACLItem) Get() interface{} { switch dst.Status { case Present: return dst.String diff --git a/aclitem_array.go b/aclitem_array.go index 7b2e4dbc..bfb069fd 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -62,7 +62,7 @@ func (dst *ACLItemArray) Set(src interface{}) error { return nil } -func (dst *ACLItemArray) Get() interface{} { +func (dst ACLItemArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/bit.go b/bit.go index 925cfe7c..c1709e6b 100644 --- a/bit.go +++ b/bit.go @@ -10,8 +10,8 @@ func (dst *Bit) Set(src interface{}) error { return (*Varbit)(dst).Set(src) } -func (dst *Bit) Get() interface{} { - return (*Varbit)(dst).Get() +func (dst Bit) Get() interface{} { + return (Varbit)(dst).Get() } func (src *Bit) AssignTo(dst interface{}) error { diff --git a/bool.go b/bool.go index db02f663..898197f7 100644 --- a/bool.go +++ b/bool.go @@ -38,7 +38,7 @@ func (dst *Bool) Set(src interface{}) error { return nil } -func (dst *Bool) Get() interface{} { +func (dst Bool) Get() interface{} { switch dst.Status { case Present: return dst.Bool diff --git a/bool_array.go b/bool_array.go index 3dbb4ca0..44500b79 100644 --- a/bool_array.go +++ b/bool_array.go @@ -64,7 +64,7 @@ func (dst *BoolArray) Set(src interface{}) error { return nil } -func (dst *BoolArray) Get() interface{} { +func (dst BoolArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/box.go b/box.go index 9baabf6b..75d50f98 100644 --- a/box.go +++ b/box.go @@ -21,7 +21,7 @@ func (dst *Box) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Box", src) } -func (dst *Box) Get() interface{} { +func (dst Box) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/bpchar.go b/bpchar.go index 1a85fa0d..f82e3724 100644 --- a/bpchar.go +++ b/bpchar.go @@ -14,8 +14,8 @@ func (dst *BPChar) Set(src interface{}) error { } // Get returns underlying value -func (dst *BPChar) Get() interface{} { - return (*Text)(dst).Get() +func (dst BPChar) Get() interface{} { + return (Text)(dst).Get() } // AssignTo assigns from src to dst. diff --git a/bpchar_array.go b/bpchar_array.go index b60ccc91..50168f6e 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -64,7 +64,7 @@ func (dst *BPCharArray) Set(src interface{}) error { return nil } -func (dst *BPCharArray) Get() interface{} { +func (dst BPCharArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/bytea.go b/bytea.go index c6e79cdf..507498cb 100644 --- a/bytea.go +++ b/bytea.go @@ -35,7 +35,7 @@ func (dst *Bytea) Set(src interface{}) error { return nil } -func (dst *Bytea) Get() interface{} { +func (dst Bytea) Get() interface{} { switch dst.Status { case Present: return dst.Bytes diff --git a/bytea_array.go b/bytea_array.go index fbebff24..d0b4b367 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -64,7 +64,7 @@ func (dst *ByteaArray) Set(src interface{}) error { return nil } -func (dst *ByteaArray) Get() interface{} { +func (dst ByteaArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/cid.go b/cid.go index d27982bd..b944748c 100644 --- a/cid.go +++ b/cid.go @@ -24,8 +24,8 @@ func (dst *CID) Set(src interface{}) error { return (*pguint32)(dst).Set(src) } -func (dst *CID) Get() interface{} { - return (*pguint32)(dst).Get() +func (dst CID) Get() interface{} { + return (pguint32)(dst).Get() } // AssignTo assigns from src to dst. Note that as CID is not a general number diff --git a/cidr.go b/cidr.go index 9e13a97e..2241ca1c 100644 --- a/cidr.go +++ b/cidr.go @@ -6,8 +6,8 @@ func (dst *CIDR) Set(src interface{}) error { return (*Inet)(dst).Set(src) } -func (dst *CIDR) Get() interface{} { - return (*Inet)(dst).Get() +func (dst CIDR) Get() interface{} { + return (Inet)(dst).Get() } func (src *CIDR) AssignTo(dst interface{}) error { diff --git a/cidr_array.go b/cidr_array.go index dbc71bb5..b6334f74 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -84,7 +84,7 @@ func (dst *CIDRArray) Set(src interface{}) error { return nil } -func (dst *CIDRArray) Get() interface{} { +func (dst CIDRArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/circle.go b/circle.go index 9644345c..d3f8b38a 100644 --- a/circle.go +++ b/circle.go @@ -22,7 +22,7 @@ func (dst *Circle) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Circle", src) } -func (dst *Circle) Get() interface{} { +func (dst Circle) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/date.go b/date.go index eaf95dde..9804672b 100644 --- a/date.go +++ b/date.go @@ -40,7 +40,7 @@ func (dst *Date) Set(src interface{}) error { return nil } -func (dst *Date) Get() interface{} { +func (dst Date) Get() interface{} { switch dst.Status { case Present: if dst.InfinityModifier != None { diff --git a/date_array.go b/date_array.go index c97e83ee..ce6b9550 100644 --- a/date_array.go +++ b/date_array.go @@ -65,7 +65,7 @@ func (dst *DateArray) Set(src interface{}) error { return nil } -func (dst *DateArray) Get() interface{} { +func (dst DateArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/daterange.go b/daterange.go index 40997bd9..78e7b813 100644 --- a/daterange.go +++ b/daterange.go @@ -19,7 +19,7 @@ func (dst *Daterange) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Daterange", src) } -func (dst *Daterange) Get() interface{} { +func (dst Daterange) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/enum_array.go b/enum_array.go index 3e07eae9..8220d425 100644 --- a/enum_array.go +++ b/enum_array.go @@ -62,7 +62,7 @@ func (dst *EnumArray) Set(src interface{}) error { return nil } -func (dst *EnumArray) Get() interface{} { +func (dst EnumArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/ext/gofrs-uuid/uuid.go b/ext/gofrs-uuid/uuid.go index 9b95a225..c1179ae2 100644 --- a/ext/gofrs-uuid/uuid.go +++ b/ext/gofrs-uuid/uuid.go @@ -47,7 +47,7 @@ func (dst *UUID) Set(src interface{}) error { return nil } -func (dst *UUID) Get() interface{} { +func (dst UUID) Get() interface{} { switch dst.Status { case pgtype.Present: return dst.UUID diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index 259fa54d..9fc8b515 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -89,7 +89,7 @@ func (dst *Numeric) Set(src interface{}) error { return nil } -func (dst *Numeric) Get() interface{} { +func (dst Numeric) Get() interface{} { switch dst.Status { case pgtype.Present: return dst.Decimal diff --git a/float4.go b/float4.go index 3f701dc5..cef14274 100644 --- a/float4.go +++ b/float4.go @@ -92,7 +92,7 @@ func (dst *Float4) Set(src interface{}) error { return nil } -func (dst *Float4) Get() interface{} { +func (dst Float4) Get() interface{} { switch dst.Status { case Present: return dst.Float diff --git a/float4_array.go b/float4_array.go index 07fac71a..4dcdef43 100644 --- a/float4_array.go +++ b/float4_array.go @@ -64,7 +64,7 @@ func (dst *Float4Array) Set(src interface{}) error { return nil } -func (dst *Float4Array) Get() interface{} { +func (dst Float4Array) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/float8.go b/float8.go index 9c6847c3..13d6b326 100644 --- a/float8.go +++ b/float8.go @@ -82,7 +82,7 @@ func (dst *Float8) Set(src interface{}) error { return nil } -func (dst *Float8) Get() interface{} { +func (dst Float8) Get() interface{} { switch dst.Status { case Present: return dst.Float diff --git a/float8_array.go b/float8_array.go index 2f65c736..be3d1d20 100644 --- a/float8_array.go +++ b/float8_array.go @@ -64,7 +64,7 @@ func (dst *Float8Array) Set(src interface{}) error { return nil } -func (dst *Float8Array) Get() interface{} { +func (dst Float8Array) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/generic_binary.go b/generic_binary.go index 5689523e..76a1d351 100644 --- a/generic_binary.go +++ b/generic_binary.go @@ -12,8 +12,8 @@ func (dst *GenericBinary) Set(src interface{}) error { return (*Bytea)(dst).Set(src) } -func (dst *GenericBinary) Get() interface{} { - return (*Bytea)(dst).Get() +func (dst GenericBinary) Get() interface{} { + return (Bytea)(dst).Get() } func (src *GenericBinary) AssignTo(dst interface{}) error { diff --git a/generic_text.go b/generic_text.go index d8890f48..dbf5b47e 100644 --- a/generic_text.go +++ b/generic_text.go @@ -12,8 +12,8 @@ func (dst *GenericText) Set(src interface{}) error { return (*Text)(dst).Set(src) } -func (dst *GenericText) Get() interface{} { - return (*Text)(dst).Get() +func (dst GenericText) Get() interface{} { + return (Text)(dst).Get() } func (src *GenericText) AssignTo(dst interface{}) error { diff --git a/hstore.go b/hstore.go index 45b165af..fcfd8f9a 100644 --- a/hstore.go +++ b/hstore.go @@ -40,7 +40,7 @@ func (dst *Hstore) Set(src interface{}) error { return nil } -func (dst *Hstore) Get() interface{} { +func (dst Hstore) Get() interface{} { switch dst.Status { case Present: return dst.Map diff --git a/hstore_array.go b/hstore_array.go index 06a11c02..3ab264f9 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -64,7 +64,7 @@ func (dst *HstoreArray) Set(src interface{}) error { return nil } -func (dst *HstoreArray) Get() interface{} { +func (dst HstoreArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/inet.go b/inet.go index 3c2eda9b..b7bbd9c4 100644 --- a/inet.go +++ b/inet.go @@ -52,7 +52,7 @@ func (dst *Inet) Set(src interface{}) error { return nil } -func (dst *Inet) Get() interface{} { +func (dst Inet) Get() interface{} { switch dst.Status { case Present: return dst.IPNet diff --git a/inet_array.go b/inet_array.go index 88181739..58cd656b 100644 --- a/inet_array.go +++ b/inet_array.go @@ -84,7 +84,7 @@ func (dst *InetArray) Set(src interface{}) error { return nil } -func (dst *InetArray) Get() interface{} { +func (dst InetArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/int2.go b/int2.go index f3e01308..7ed76803 100644 --- a/int2.go +++ b/int2.go @@ -88,7 +88,7 @@ func (dst *Int2) Set(src interface{}) error { return nil } -func (dst *Int2) Get() interface{} { +func (dst Int2) Get() interface{} { switch dst.Status { case Present: return dst.Int diff --git a/int2_array.go b/int2_array.go index 3f6bdb87..1ef24c63 100644 --- a/int2_array.go +++ b/int2_array.go @@ -197,7 +197,7 @@ func (dst *Int2Array) Set(src interface{}) error { return nil } -func (dst *Int2Array) Get() interface{} { +func (dst Int2Array) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/int4.go b/int4.go index 2075b375..efe3916e 100644 --- a/int4.go +++ b/int4.go @@ -80,7 +80,7 @@ func (dst *Int4) Set(src interface{}) error { return nil } -func (dst *Int4) Get() interface{} { +func (dst Int4) Get() interface{} { switch dst.Status { case Present: return dst.Int diff --git a/int4_array.go b/int4_array.go index f3e87b00..61112f8d 100644 --- a/int4_array.go +++ b/int4_array.go @@ -197,7 +197,7 @@ func (dst *Int4Array) Set(src interface{}) error { return nil } -func (dst *Int4Array) Get() interface{} { +func (dst Int4Array) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/int4range.go b/int4range.go index 03970ae6..6638e9c1 100644 --- a/int4range.go +++ b/int4range.go @@ -19,7 +19,7 @@ func (dst *Int4range) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Int4range", src) } -func (dst *Int4range) Get() interface{} { +func (dst Int4range) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/int8.go b/int8.go index 01a694e4..526cde94 100644 --- a/int8.go +++ b/int8.go @@ -71,7 +71,7 @@ func (dst *Int8) Set(src interface{}) error { return nil } -func (dst *Int8) Get() interface{} { +func (dst Int8) Get() interface{} { switch dst.Status { case Present: return dst.Int diff --git a/int8_array.go b/int8_array.go index a6798173..985b47b8 100644 --- a/int8_array.go +++ b/int8_array.go @@ -197,7 +197,7 @@ func (dst *Int8Array) Set(src interface{}) error { return nil } -func (dst *Int8Array) Get() interface{} { +func (dst Int8Array) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/int8range.go b/int8range.go index 0e0f1cdb..88027974 100644 --- a/int8range.go +++ b/int8range.go @@ -19,7 +19,7 @@ func (dst *Int8range) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Int8range", src) } -func (dst *Int8range) Get() interface{} { +func (dst Int8range) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/interval.go b/interval.go index bb19f956..0afd1cbd 100644 --- a/interval.go +++ b/interval.go @@ -44,7 +44,7 @@ func (dst *Interval) Set(src interface{}) error { return nil } -func (dst *Interval) Get() interface{} { +func (dst Interval) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/json.go b/json.go index 58a5b093..1b99c5c2 100644 --- a/json.go +++ b/json.go @@ -53,7 +53,7 @@ func (dst *JSON) Set(src interface{}) error { return nil } -func (dst *JSON) Get() interface{} { +func (dst JSON) Get() interface{} { switch dst.Status { case Present: var i interface{} diff --git a/jsonb.go b/jsonb.go index 43e23fb2..984c0973 100644 --- a/jsonb.go +++ b/jsonb.go @@ -12,8 +12,8 @@ func (dst *JSONB) Set(src interface{}) error { return (*JSON)(dst).Set(src) } -func (dst *JSONB) Get() interface{} { - return (*JSON)(dst).Get() +func (dst JSONB) Get() interface{} { + return (JSON)(dst).Get() } func (src *JSONB) AssignTo(dst interface{}) error { diff --git a/line.go b/line.go index 61477ad9..737f5d86 100644 --- a/line.go +++ b/line.go @@ -21,7 +21,7 @@ func (dst *Line) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Line", src) } -func (dst *Line) Get() interface{} { +func (dst Line) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/lseg.go b/lseg.go index 822b7bf4..a16dcea3 100644 --- a/lseg.go +++ b/lseg.go @@ -21,7 +21,7 @@ func (dst *Lseg) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Lseg", src) } -func (dst *Lseg) Get() interface{} { +func (dst Lseg) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/macaddr.go b/macaddr.go index 29c60440..55dec4f2 100644 --- a/macaddr.go +++ b/macaddr.go @@ -39,7 +39,7 @@ func (dst *Macaddr) Set(src interface{}) error { return nil } -func (dst *Macaddr) Get() interface{} { +func (dst Macaddr) Get() interface{} { switch dst.Status { case Present: return dst.Addr diff --git a/macaddr_array.go b/macaddr_array.go index 8382ea45..b4d42d61 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -65,7 +65,7 @@ func (dst *MacaddrArray) Set(src interface{}) error { return nil } -func (dst *MacaddrArray) Get() interface{} { +func (dst MacaddrArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/name.go b/name.go index 753a074a..7ce8d25e 100644 --- a/name.go +++ b/name.go @@ -23,8 +23,8 @@ func (dst *Name) Set(src interface{}) error { return (*Text)(dst).Set(src) } -func (dst *Name) Get() interface{} { - return (*Text)(dst).Get() +func (dst Name) Get() interface{} { + return (Text)(dst).Get() } func (src *Name) AssignTo(dst interface{}) error { diff --git a/numeric.go b/numeric.go index 554fb582..100a7e9c 100644 --- a/numeric.go +++ b/numeric.go @@ -104,7 +104,7 @@ func (dst *Numeric) Set(src interface{}) error { return nil } -func (dst *Numeric) Get() interface{} { +func (dst Numeric) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/numeric_array.go b/numeric_array.go index 432cd96f..224306c1 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -121,7 +121,7 @@ func (dst *NumericArray) Set(src interface{}) error { return nil } -func (dst *NumericArray) Get() interface{} { +func (dst NumericArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/numrange.go b/numrange.go index f3e25109..64b7fbc3 100644 --- a/numrange.go +++ b/numrange.go @@ -19,7 +19,7 @@ func (dst *Numrange) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Numrange", src) } -func (dst *Numrange) Get() interface{} { +func (dst Numrange) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/oid_value.go b/oid_value.go index 619681a5..5dc9136c 100644 --- a/oid_value.go +++ b/oid_value.go @@ -18,8 +18,8 @@ func (dst *OIDValue) Set(src interface{}) error { return (*pguint32)(dst).Set(src) } -func (dst *OIDValue) Get() interface{} { - return (*pguint32)(dst).Get() +func (dst OIDValue) Get() interface{} { + return (pguint32)(dst).Get() } // AssignTo assigns from src to dst. Note that as OIDValue is not a general number diff --git a/path.go b/path.go index 484c9174..c5031330 100644 --- a/path.go +++ b/path.go @@ -22,7 +22,7 @@ func (dst *Path) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Path", src) } -func (dst *Path) Get() interface{} { +func (dst Path) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/pguint32.go b/pguint32.go index 546d6f8f..a245d2c9 100644 --- a/pguint32.go +++ b/pguint32.go @@ -39,7 +39,7 @@ func (dst *pguint32) Set(src interface{}) error { return nil } -func (dst *pguint32) Get() interface{} { +func (dst pguint32) Get() interface{} { switch dst.Status { case Present: return dst.Uint diff --git a/point.go b/point.go index bb7daa24..87993656 100644 --- a/point.go +++ b/point.go @@ -26,7 +26,7 @@ func (dst *Point) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Point", src) } -func (dst *Point) Get() interface{} { +func (dst Point) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/polygon.go b/polygon.go index 7805604b..653b04c1 100644 --- a/polygon.go +++ b/polygon.go @@ -21,7 +21,7 @@ func (dst *Polygon) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Polygon", src) } -func (dst *Polygon) Get() interface{} { +func (dst Polygon) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/qchar.go b/qchar.go index 8a316d9b..d6577248 100644 --- a/qchar.go +++ b/qchar.go @@ -105,7 +105,7 @@ func (dst *QChar) Set(src interface{}) error { return nil } -func (dst *QChar) Get() interface{} { +func (dst QChar) Get() interface{} { switch dst.Status { case Present: return dst.Int diff --git a/record.go b/record.go index 28f4a182..aecc978b 100644 --- a/record.go +++ b/record.go @@ -33,7 +33,7 @@ func (dst *Record) Set(src interface{}) error { return nil } -func (dst *Record) Get() interface{} { +func (dst Record) Get() interface{} { switch dst.Status { case Present: return dst.Fields diff --git a/text.go b/text.go index cdd993db..bd5f0689 100644 --- a/text.go +++ b/text.go @@ -43,7 +43,7 @@ func (dst *Text) Set(src interface{}) error { return nil } -func (dst *Text) Get() interface{} { +func (dst Text) Get() interface{} { switch dst.Status { case Present: return dst.String diff --git a/text_array.go b/text_array.go index 653e41fc..9b5fcec6 100644 --- a/text_array.go +++ b/text_array.go @@ -64,7 +64,7 @@ func (dst *TextArray) Set(src interface{}) error { return nil } -func (dst *TextArray) Get() interface{} { +func (dst TextArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/tid.go b/tid.go index 08f5c047..98b95e2a 100644 --- a/tid.go +++ b/tid.go @@ -32,7 +32,7 @@ func (dst *TID) Set(src interface{}) error { return errors.Errorf("cannot convert %v to TID", src) } -func (dst *TID) Get() interface{} { +func (dst TID) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/time.go b/time.go index 3bf91b10..60043fcd 100644 --- a/time.go +++ b/time.go @@ -45,7 +45,7 @@ func (dst *Time) Set(src interface{}) error { return nil } -func (dst *Time) Get() interface{} { +func (dst Time) Get() interface{} { switch dst.Status { case Present: return dst.Microseconds diff --git a/timestamp.go b/timestamp.go index 01c38a0a..feb88873 100644 --- a/timestamp.go +++ b/timestamp.go @@ -43,7 +43,7 @@ func (dst *Timestamp) Set(src interface{}) error { return nil } -func (dst *Timestamp) Get() interface{} { +func (dst Timestamp) Get() interface{} { switch dst.Status { case Present: if dst.InfinityModifier != None { diff --git a/timestamp_array.go b/timestamp_array.go index 072e01ac..063d339b 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -65,7 +65,7 @@ func (dst *TimestampArray) Set(src interface{}) error { return nil } -func (dst *TimestampArray) Get() interface{} { +func (dst TimestampArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/timestamptz.go b/timestamptz.go index 7ed86eb8..3d3e7143 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -45,7 +45,7 @@ func (dst *Timestamptz) Set(src interface{}) error { return nil } -func (dst *Timestamptz) Get() interface{} { +func (dst Timestamptz) Get() interface{} { switch dst.Status { case Present: if dst.InfinityModifier != None { diff --git a/timestamptz_array.go b/timestamptz_array.go index 9d0677c8..4924498d 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -65,7 +65,7 @@ func (dst *TimestamptzArray) Set(src interface{}) error { return nil } -func (dst *TimestamptzArray) Get() interface{} { +func (dst TimestamptzArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/tsrange.go b/tsrange.go index 54cc863f..68fa6d73 100644 --- a/tsrange.go +++ b/tsrange.go @@ -19,7 +19,7 @@ func (dst *Tsrange) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Tsrange", src) } -func (dst *Tsrange) Get() interface{} { +func (dst Tsrange) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/tstzrange.go b/tstzrange.go index 1cf2859d..8441275f 100644 --- a/tstzrange.go +++ b/tstzrange.go @@ -19,7 +19,7 @@ func (dst *Tstzrange) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Tstzrange", src) } -func (dst *Tstzrange) Get() interface{} { +func (dst Tstzrange) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/tstzrange_array.go b/tstzrange_array.go index f7c0121d..cf407253 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -45,7 +45,7 @@ func (dst *TstzrangeArray) Set(src interface{}) error { return nil } -func (dst *TstzrangeArray) Get() interface{} { +func (dst TstzrangeArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/typed_array.go.erb b/typed_array.go.erb index 72c0c381..494bd534 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -66,7 +66,7 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { return nil } -func (dst *<%= pgtype_array_type %>) Get() interface{} { +func (dst <%= pgtype_array_type %>) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/typed_range.go.erb b/typed_range.go.erb index 035a71af..9846e5dd 100644 --- a/typed_range.go.erb +++ b/typed_range.go.erb @@ -21,7 +21,7 @@ func (dst *<%= range_type %>) Set(src interface{}) error { return errors.Errorf("cannot convert %v to <%= range_type %>", src) } -func (dst *<%= range_type %>) Get() interface{} { +func (dst <%= range_type %>) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/unknown.go b/unknown.go index 2dca0f87..c591b708 100644 --- a/unknown.go +++ b/unknown.go @@ -15,8 +15,8 @@ func (dst *Unknown) Set(src interface{}) error { return (*Text)(dst).Set(src) } -func (dst *Unknown) Get() interface{} { - return (*Text)(dst).Get() +func (dst Unknown) Get() interface{} { + return (Text)(dst).Get() } // AssignTo assigns from src to dst. Note that as Unknown is not a general number diff --git a/uuid.go b/uuid.go index ba999a06..70a6b7fa 100644 --- a/uuid.go +++ b/uuid.go @@ -48,7 +48,7 @@ func (dst *UUID) Set(src interface{}) error { return nil } -func (dst *UUID) Get() interface{} { +func (dst UUID) Get() interface{} { switch dst.Status { case Present: return dst.Bytes diff --git a/uuid_array.go b/uuid_array.go index 7c324e53..27dcd259 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -102,7 +102,7 @@ func (dst *UUIDArray) Set(src interface{}) error { return nil } -func (dst *UUIDArray) Get() interface{} { +func (dst UUIDArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/varbit.go b/varbit.go index 019fff8a..7461bab3 100644 --- a/varbit.go +++ b/varbit.go @@ -18,7 +18,7 @@ func (dst *Varbit) Set(src interface{}) error { return errors.Errorf("cannot convert %v to Varbit", src) } -func (dst *Varbit) Get() interface{} { +func (dst Varbit) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/varchar.go b/varchar.go index 58de1097..e4fa6869 100644 --- a/varchar.go +++ b/varchar.go @@ -13,8 +13,8 @@ func (dst *Varchar) Set(src interface{}) error { return (*Text)(dst).Set(src) } -func (dst *Varchar) Get() interface{} { - return (*Text)(dst).Get() +func (dst Varchar) Get() interface{} { + return (Text)(dst).Get() } // AssignTo assigns from src to dst. Note that as Varchar is not a general number diff --git a/varchar_array.go b/varchar_array.go index ac9af519..7f476285 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -64,7 +64,7 @@ func (dst *VarcharArray) Set(src interface{}) error { return nil } -func (dst *VarcharArray) Get() interface{} { +func (dst VarcharArray) Get() interface{} { switch dst.Status { case Present: return dst diff --git a/xid.go b/xid.go index 80ebf0e0..f6d6b22d 100644 --- a/xid.go +++ b/xid.go @@ -27,8 +27,8 @@ func (dst *XID) Set(src interface{}) error { return (*pguint32)(dst).Set(src) } -func (dst *XID) Get() interface{} { - return (*pguint32)(dst).Get() +func (dst XID) Get() interface{} { + return (pguint32)(dst).Get() } // AssignTo assigns from src to dst. Note that as XID is not a general number From 666bd514e2b2f43a39d1ebc56825d8748a3fdc31 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 19 Feb 2020 10:50:58 -0600 Subject: [PATCH 0413/1158] Add standard nil test to gofrs-uuid.UUID.Set --- ext/gofrs-uuid/uuid.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ext/gofrs-uuid/uuid.go b/ext/gofrs-uuid/uuid.go index c1179ae2..a358fead 100644 --- a/ext/gofrs-uuid/uuid.go +++ b/ext/gofrs-uuid/uuid.go @@ -17,6 +17,11 @@ type UUID struct { } func (dst *UUID) Set(src interface{}) error { + if src == nil { + *dst = UUID{Status: pgtype.Null} + return nil + } + switch value := src.(type) { case uuid.UUID: *dst = UUID{UUID: value, Status: pgtype.Present} From 55a56add235573da4358e3fd52ad4a35b30c92eb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 19 Feb 2020 11:58:49 -0600 Subject: [PATCH 0414/1158] Set will call Get on src if possible --- aclitem.go | 12 ++++++++++++ aclitem_array.go | 7 +++++++ bool.go | 7 +++++++ bool_array.go | 7 +++++++ bpchar_array.go | 7 +++++++ bytea.go | 7 +++++++ bytea_array.go | 7 +++++++ cidr_array.go | 7 +++++++ date.go | 7 +++++++ date_array.go | 7 +++++++ enum_array.go | 7 +++++++ ext/gofrs-uuid/uuid.go | 7 +++++++ ext/gofrs-uuid/uuid_test.go | 4 ++++ ext/shopspring-numeric/decimal.go | 7 +++++++ float4.go | 7 +++++++ float4_array.go | 7 +++++++ float8.go | 7 +++++++ float8_array.go | 7 +++++++ go.mod | 1 + go.sum | 2 ++ hstore.go | 7 +++++++ hstore_array.go | 7 +++++++ inet.go | 7 +++++++ inet_array.go | 7 +++++++ int2.go | 7 +++++++ int2_array.go | 7 +++++++ int4.go | 7 +++++++ int4_array.go | 7 +++++++ int8.go | 7 +++++++ int8_array.go | 7 +++++++ interval.go | 7 +++++++ json.go | 7 +++++++ macaddr.go | 7 +++++++ macaddr_array.go | 7 +++++++ numeric.go | 7 +++++++ numeric_array.go | 7 +++++++ qchar.go | 7 +++++++ record.go | 7 +++++++ text.go | 7 +++++++ text_array.go | 7 +++++++ time.go | 7 +++++++ timestamp.go | 7 +++++++ timestamp_array.go | 7 +++++++ timestamptz.go | 7 +++++++ timestamptz_array.go | 7 +++++++ tstzrange_array.go | 7 +++++++ typed_array.go.erb | 7 +++++++ uuid.go | 7 +++++++ uuid_array.go | 7 +++++++ varchar_array.go | 7 +++++++ 50 files changed, 341 insertions(+) diff --git a/aclitem.go b/aclitem.go index 36df71bc..d2fe7529 100644 --- a/aclitem.go +++ b/aclitem.go @@ -24,6 +24,18 @@ type ACLItem struct { } func (dst *ACLItem) Set(src interface{}) error { + if src == nil { + *dst = ACLItem{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case string: *dst = ACLItem{String: value, Status: Present} diff --git a/aclitem_array.go b/aclitem_array.go index bfb069fd..1d3de130 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -19,6 +19,13 @@ func (dst *ACLItemArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []string: diff --git a/bool.go b/bool.go index 898197f7..8b03a1af 100644 --- a/bool.go +++ b/bool.go @@ -19,6 +19,13 @@ func (dst *Bool) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case bool: *dst = Bool{Bool: value, Status: Present} diff --git a/bool_array.go b/bool_array.go index 44500b79..c1af1e1f 100644 --- a/bool_array.go +++ b/bool_array.go @@ -21,6 +21,13 @@ func (dst *BoolArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []bool: diff --git a/bpchar_array.go b/bpchar_array.go index 50168f6e..b6eeabd7 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -21,6 +21,13 @@ func (dst *BPCharArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []string: diff --git a/bytea.go b/bytea.go index 507498cb..b9e4d15a 100644 --- a/bytea.go +++ b/bytea.go @@ -18,6 +18,13 @@ func (dst *Bytea) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []byte: if value != nil { diff --git a/bytea_array.go b/bytea_array.go index d0b4b367..6a45e4da 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -21,6 +21,13 @@ func (dst *ByteaArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case [][]byte: diff --git a/cidr_array.go b/cidr_array.go index b6334f74..4f3097a0 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -22,6 +22,13 @@ func (dst *CIDRArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []*net.IPNet: diff --git a/date.go b/date.go index 9804672b..10e41fe7 100644 --- a/date.go +++ b/date.go @@ -27,6 +27,13 @@ func (dst *Date) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case time.Time: *dst = Date{Time: value, Status: Present} diff --git a/date_array.go b/date_array.go index ce6b9550..644e78fe 100644 --- a/date_array.go +++ b/date_array.go @@ -22,6 +22,13 @@ func (dst *DateArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []time.Time: diff --git a/enum_array.go b/enum_array.go index 8220d425..a31916dc 100644 --- a/enum_array.go +++ b/enum_array.go @@ -19,6 +19,13 @@ func (dst *EnumArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []string: diff --git a/ext/gofrs-uuid/uuid.go b/ext/gofrs-uuid/uuid.go index a358fead..b0413cae 100644 --- a/ext/gofrs-uuid/uuid.go +++ b/ext/gofrs-uuid/uuid.go @@ -22,6 +22,13 @@ func (dst *UUID) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case uuid.UUID: *dst = UUID{UUID: value, Status: pgtype.Present} diff --git a/ext/gofrs-uuid/uuid_test.go b/ext/gofrs-uuid/uuid_test.go index 124720b8..56814524 100644 --- a/ext/gofrs-uuid/uuid_test.go +++ b/ext/gofrs-uuid/uuid_test.go @@ -21,6 +21,10 @@ func TestUUIDSet(t *testing.T) { source interface{} result gofrs.UUID }{ + { + source: &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, { source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index 9fc8b515..70906806 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -24,6 +24,13 @@ func (dst *Numeric) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case decimal.Decimal: *dst = Numeric{Decimal: value, Status: pgtype.Present} diff --git a/float4.go b/float4.go index cef14274..e33dfc75 100644 --- a/float4.go +++ b/float4.go @@ -21,6 +21,13 @@ func (dst *Float4) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case float32: *dst = Float4{Float: value, Status: Present} diff --git a/float4_array.go b/float4_array.go index 4dcdef43..ccd718a1 100644 --- a/float4_array.go +++ b/float4_array.go @@ -21,6 +21,13 @@ func (dst *Float4Array) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []float32: diff --git a/float8.go b/float8.go index 13d6b326..41d0fe70 100644 --- a/float8.go +++ b/float8.go @@ -21,6 +21,13 @@ func (dst *Float8) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case float32: *dst = Float8{Float: float64(value), Status: Present} diff --git a/float8_array.go b/float8_array.go index be3d1d20..740e8558 100644 --- a/float8_array.go +++ b/float8_array.go @@ -21,6 +21,13 @@ func (dst *Float8Array) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []float64: diff --git a/go.mod b/go.mod index 9f47a705..920f4b21 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.12 require ( github.com/gofrs/uuid v3.2.0+incompatible github.com/jackc/pgio v1.0.0 + github.com/jackc/pgx v3.6.2+incompatible github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186 github.com/lib/pq v1.2.0 github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 diff --git a/go.sum b/go.sum index 275e7fe1..7a3fc71e 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4 github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgx v3.6.2+incompatible h1:2zP5OD7kiyR3xzRYMhOcXVvkDZsImVXfj+yIyTQf3/o= +github.com/jackc/pgx v3.6.2+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 h1:ylEAOd688Duev/fxTmGdupsbyZfxNMdngIG14DoBKTM= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912 h1:YuOWGsSK5L4Fz81Olx5TNlZftmDuNrfv4ip0Yos77Tw= diff --git a/hstore.go b/hstore.go index fcfd8f9a..3fe50ae5 100644 --- a/hstore.go +++ b/hstore.go @@ -26,6 +26,13 @@ func (dst *Hstore) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case map[string]string: m := make(map[string]Text, len(value)) diff --git a/hstore_array.go b/hstore_array.go index 3ab264f9..54909e42 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -21,6 +21,13 @@ func (dst *HstoreArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []map[string]string: diff --git a/inet.go b/inet.go index b7bbd9c4..7ab78bdf 100644 --- a/inet.go +++ b/inet.go @@ -27,6 +27,13 @@ func (dst *Inet) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case net.IPNet: *dst = Inet{IPNet: &value, Status: Present} diff --git a/inet_array.go b/inet_array.go index 58cd656b..a663d51d 100644 --- a/inet_array.go +++ b/inet_array.go @@ -22,6 +22,13 @@ func (dst *InetArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []*net.IPNet: diff --git a/int2.go b/int2.go index 7ed76803..54bab272 100644 --- a/int2.go +++ b/int2.go @@ -21,6 +21,13 @@ func (dst *Int2) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case int8: *dst = Int2{Int: int16(value), Status: Present} diff --git a/int2_array.go b/int2_array.go index 1ef24c63..98552171 100644 --- a/int2_array.go +++ b/int2_array.go @@ -21,6 +21,13 @@ func (dst *Int2Array) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []int16: diff --git a/int4.go b/int4.go index efe3916e..66fe9155 100644 --- a/int4.go +++ b/int4.go @@ -22,6 +22,13 @@ func (dst *Int4) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case int8: *dst = Int4{Int: int32(value), Status: Present} diff --git a/int4_array.go b/int4_array.go index 61112f8d..a52ab437 100644 --- a/int4_array.go +++ b/int4_array.go @@ -21,6 +21,13 @@ func (dst *Int4Array) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []int16: diff --git a/int8.go b/int8.go index 526cde94..fd721142 100644 --- a/int8.go +++ b/int8.go @@ -22,6 +22,13 @@ func (dst *Int8) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case int8: *dst = Int8{Int: int64(value), Status: Present} diff --git a/int8_array.go b/int8_array.go index 985b47b8..f6d577f0 100644 --- a/int8_array.go +++ b/int8_array.go @@ -21,6 +21,13 @@ func (dst *Int8Array) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []int16: diff --git a/interval.go b/interval.go index 0afd1cbd..3a91c595 100644 --- a/interval.go +++ b/interval.go @@ -31,6 +31,13 @@ func (dst *Interval) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case time.Duration: *dst = Interval{Microseconds: int64(value) / 1000, Status: Present} diff --git a/json.go b/json.go index 1b99c5c2..c642c727 100644 --- a/json.go +++ b/json.go @@ -18,6 +18,13 @@ func (dst *JSON) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case string: *dst = JSON{Bytes: []byte(value), Status: Present} diff --git a/macaddr.go b/macaddr.go index 55dec4f2..af0901b0 100644 --- a/macaddr.go +++ b/macaddr.go @@ -18,6 +18,13 @@ func (dst *Macaddr) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case net.HardwareAddr: addr := make(net.HardwareAddr, len(value)) diff --git a/macaddr_array.go b/macaddr_array.go index b4d42d61..97b13537 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -22,6 +22,13 @@ func (dst *MacaddrArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []net.HardwareAddr: diff --git a/numeric.go b/numeric.go index 100a7e9c..e6c58391 100644 --- a/numeric.go +++ b/numeric.go @@ -55,6 +55,13 @@ func (dst *Numeric) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case float32: num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) diff --git a/numeric_array.go b/numeric_array.go index 224306c1..3cec9fea 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -21,6 +21,13 @@ func (dst *NumericArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []float32: diff --git a/qchar.go b/qchar.go index d6577248..93964058 100644 --- a/qchar.go +++ b/qchar.go @@ -29,6 +29,13 @@ func (dst *QChar) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case int8: *dst = QChar{Int: value, Status: Present} diff --git a/record.go b/record.go index aecc978b..5c9d7a02 100644 --- a/record.go +++ b/record.go @@ -23,6 +23,13 @@ func (dst *Record) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []Value: *dst = Record{Fields: value, Status: Present} diff --git a/text.go b/text.go index bd5f0689..1f5d2a37 100644 --- a/text.go +++ b/text.go @@ -18,6 +18,13 @@ func (dst *Text) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case string: *dst = Text{String: value, Status: Present} diff --git a/text_array.go b/text_array.go index 9b5fcec6..2130af84 100644 --- a/text_array.go +++ b/text_array.go @@ -21,6 +21,13 @@ func (dst *TextArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []string: diff --git a/time.go b/time.go index 60043fcd..16a2a393 100644 --- a/time.go +++ b/time.go @@ -28,6 +28,13 @@ func (dst *Time) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case time.Time: usec := int64(value.Hour())*microsecondsPerHour + diff --git a/timestamp.go b/timestamp.go index feb88873..35ac5143 100644 --- a/timestamp.go +++ b/timestamp.go @@ -30,6 +30,13 @@ func (dst *Timestamp) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case time.Time: *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} diff --git a/timestamp_array.go b/timestamp_array.go index 063d339b..49ac98fd 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -22,6 +22,13 @@ func (dst *TimestampArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []time.Time: diff --git a/timestamptz.go b/timestamptz.go index 3d3e7143..d390d266 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -32,6 +32,13 @@ func (dst *Timestamptz) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case time.Time: *dst = Timestamptz{Time: value, Status: Present} diff --git a/timestamptz_array.go b/timestamptz_array.go index 4924498d..2e26692b 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -22,6 +22,13 @@ func (dst *TimestamptzArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []time.Time: diff --git a/tstzrange_array.go b/tstzrange_array.go index cf407253..2c365645 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -21,6 +21,13 @@ func (dst *TstzrangeArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []Tstzrange: diff --git a/typed_array.go.erb b/typed_array.go.erb index 494bd534..d8ae97dd 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -21,6 +21,13 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { <% go_array_types.split(",").each do |t| %> <% if t != "[]#{pgtype_element_type}" %> diff --git a/uuid.go b/uuid.go index 70a6b7fa..bdbe17e4 100644 --- a/uuid.go +++ b/uuid.go @@ -19,6 +19,13 @@ func (dst *UUID) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case [16]byte: *dst = UUID{Bytes: value, Status: Present} diff --git a/uuid_array.go b/uuid_array.go index 27dcd259..4cd65017 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -21,6 +21,13 @@ func (dst *UUIDArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case [][16]byte: diff --git a/varchar_array.go b/varchar_array.go index 7f476285..b13f29ce 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -21,6 +21,13 @@ func (dst *VarcharArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []string: From 8117205a7549722353a70cfbc6046c4081f9c0c8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 3 Mar 2020 15:25:57 -0600 Subject: [PATCH 0415/1158] Range types Set method supports its own type, string, and nil Previously Set would always return an error when called on a range type. Now it will accept an instance of itself, a pointer to an instance of itself, a string, or nil. Strings are parsed with the same logic as DecodeText. --- daterange.go | 19 ++++++++++++- daterange_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++ int4range.go | 19 ++++++++++++- int8range.go | 19 ++++++++++++- numrange.go | 19 ++++++++++++- tsrange.go | 19 ++++++++++++- tstzrange.go | 19 ++++++++++++- typed_range.go.erb | 19 ++++++++++++- 8 files changed, 192 insertions(+), 7 deletions(-) diff --git a/daterange.go b/daterange.go index 78e7b813..7b9af795 100644 --- a/daterange.go +++ b/daterange.go @@ -16,7 +16,24 @@ type Daterange struct { } func (dst *Daterange) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Daterange", src) + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + switch value := src.(type) { + case Daterange: + *dst = value + case *Daterange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return errors.Errorf("cannot convert %v to Daterange", src) + } + + return nil } func (dst Daterange) Get() interface{} { diff --git a/daterange_test.go b/daterange_test.go index 4118cffa..54d51e2d 100644 --- a/daterange_test.go +++ b/daterange_test.go @@ -65,3 +65,69 @@ func TestDaterangeNormalize(t *testing.T) { a.Upper.InfinityModifier == b.Upper.InfinityModifier }) } + +func TestDaterangeSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Daterange + }{ + { + source: nil, + result: pgtype.Daterange{Status: pgtype.Null}, + }, + { + source: &pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + result: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + { + source: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + result: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + { + source: "[1990-12-31,2028-01-01)", + result: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Daterange + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/int4range.go b/int4range.go index 6638e9c1..442f2501 100644 --- a/int4range.go +++ b/int4range.go @@ -16,7 +16,24 @@ type Int4range struct { } func (dst *Int4range) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Int4range", src) + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + switch value := src.(type) { + case Int4range: + *dst = value + case *Int4range: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return errors.Errorf("cannot convert %v to Int4range", src) + } + + return nil } func (dst Int4range) Get() interface{} { diff --git a/int8range.go b/int8range.go index 88027974..92fcb136 100644 --- a/int8range.go +++ b/int8range.go @@ -16,7 +16,24 @@ type Int8range struct { } func (dst *Int8range) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Int8range", src) + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + switch value := src.(type) { + case Int8range: + *dst = value + case *Int8range: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return errors.Errorf("cannot convert %v to Int8range", src) + } + + return nil } func (dst Int8range) Get() interface{} { diff --git a/numrange.go b/numrange.go index 64b7fbc3..40467686 100644 --- a/numrange.go +++ b/numrange.go @@ -16,7 +16,24 @@ type Numrange struct { } func (dst *Numrange) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Numrange", src) + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + switch value := src.(type) { + case Numrange: + *dst = value + case *Numrange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return errors.Errorf("cannot convert %v to Numrange", src) + } + + return nil } func (dst Numrange) Get() interface{} { diff --git a/tsrange.go b/tsrange.go index 68fa6d73..6ca12aed 100644 --- a/tsrange.go +++ b/tsrange.go @@ -16,7 +16,24 @@ type Tsrange struct { } func (dst *Tsrange) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Tsrange", src) + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + switch value := src.(type) { + case Tsrange: + *dst = value + case *Tsrange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return errors.Errorf("cannot convert %v to Tsrange", src) + } + + return nil } func (dst Tsrange) Get() interface{} { diff --git a/tstzrange.go b/tstzrange.go index 8441275f..1b05c3ea 100644 --- a/tstzrange.go +++ b/tstzrange.go @@ -16,7 +16,24 @@ type Tstzrange struct { } func (dst *Tstzrange) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Tstzrange", src) + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + switch value := src.(type) { + case Tstzrange: + *dst = value + case *Tstzrange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return errors.Errorf("cannot convert %v to Tstzrange", src) + } + + return nil } func (dst Tstzrange) Get() interface{} { diff --git a/typed_range.go.erb b/typed_range.go.erb index 9846e5dd..e21b6cda 100644 --- a/typed_range.go.erb +++ b/typed_range.go.erb @@ -18,7 +18,24 @@ type <%= range_type %> struct { } func (dst *<%= range_type %>) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to <%= range_type %>", src) + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + switch value := src.(type) { + case <%= range_type %>: + *dst = value + case *<%= range_type %>: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return errors.Errorf("cannot convert %v to <%= range_type %>", src) + } + + return nil } func (dst <%= range_type %>) Get() interface{} { From 911e727d78134c87d39b064aaee2bfda30f7afde Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Mar 2020 10:55:28 -0600 Subject: [PATCH 0416/1158] ExecParams and ExecPrepared handle empty query An empty query does not return CommandComplete. Instead it returns EmptyQueryResponse. --- pgconn.go | 2 ++ pgconn_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/pgconn.go b/pgconn.go index 751d8fc0..6155281d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1406,6 +1406,8 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error rr.fieldDescriptions = msg.Fields case *pgproto3.CommandComplete: rr.concludeCommand(CommandTag(msg.CommandTag), nil) + case *pgproto3.EmptyQueryResponse: + rr.concludeCommand(nil, nil) case *pgproto3.ErrorResponse: rr.concludeCommand(nil, ErrorResponseToPgError(msg)) } diff --git a/pgconn_test.go b/pgconn_test.go index 19ad3a0a..17b40343 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -668,6 +668,24 @@ func TestConnExecParamsPrecanceled(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecParamsEmptySQL(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read() + assert.Nil(t, result.CommandTag) + assert.Len(t, result.Rows, 0) + assert.NoError(t, result.Err) + + ensureConnValid(t, pgConn) +} + func TestConnExecPrepared(t *testing.T) { t.Parallel() @@ -797,6 +815,27 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecPreparedEmptySQL(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(ctx, "ps1", "", nil) + require.NoError(t, err) + + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() + assert.Nil(t, result.CommandTag) + assert.Len(t, result.Rows, 0) + assert.NoError(t, result.Err) + + ensureConnValid(t, pgConn) +} + func TestConnExecBatch(t *testing.T) { t.Parallel() From cfbd2519e3a9dd64906a0888c38ee05d78e19889 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Mar 2020 13:17:39 -0600 Subject: [PATCH 0417/1158] Add PGSERVICE and PGSERVICEFILE support --- config.go | 93 +++++++++++++++++++++++++++++++++++++++++------- config_test.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 3 +- go.sum | 4 +++ 4 files changed, 182 insertions(+), 13 deletions(-) diff --git a/config.go b/config.go index 9876ac94..19521a8f 100644 --- a/config.go +++ b/config.go @@ -20,6 +20,7 @@ import ( "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgservicefile" errors "golang.org/x/xerrors" ) @@ -108,6 +109,8 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGUSER // PGPASSWORD // PGPASSFILE +// PGSERVICE +// PGSERVICEFILE // PGSSLMODE // PGSSLCERT // PGSSLKEY @@ -145,25 +148,40 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // min_read_buffer_size // The minimum size of the internal read buffer. Default 8192. +// servicefile +// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a +// part of the connection string. func ParseConfig(connString string) (*Config, error) { - settings := defaultSettings() - addEnvSettings(settings) + defaultSettings := defaultSettings() + envSettings := parseEnvSettings() + connStringSettings := make(map[string]string) if connString != "" { + var err error // connString may be a database URL or a DSN if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { - err := addURLSettings(settings, connString) + connStringSettings, err = parseURLSettings(connString) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} } } else { - err := addDSNSettings(settings, connString) + connStringSettings, err = parseDSNSettings(connString) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} } } } + settings := mergeSettings(defaultSettings, envSettings, connStringSettings) + if service, present := settings["service"]; present { + serviceSettings, err := parseServiceSettings(settings["servicefile"], service) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} + } + + settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) + } + minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) if err != nil { return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err} @@ -205,6 +223,8 @@ func ParseConfig(connString string) (*Config, error) { "sslrootcert": struct{}{}, "target_session_attrs": struct{}{}, "min_read_buffer_size": struct{}{}, + "service": struct{}{}, + "servicefile": struct{}{}, } for k, v := range settings { @@ -293,6 +313,7 @@ func defaultSettings() map[string]string { if err == nil { settings["user"] = user.Username settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") } settings["target_session_attrs"] = "any" @@ -321,7 +342,21 @@ func defaultHost() string { return "localhost" } -func addEnvSettings(settings map[string]string) { +func mergeSettings(settingSets ...map[string]string) map[string]string { + settings := make(map[string]string) + + for _, s2 := range settingSets { + for k, v := range s2 { + settings[k] = v + } + } + + return settings +} + +func parseEnvSettings() map[string]string { + settings := make(map[string]string) + nameMap := map[string]string{ "PGHOST": "host", "PGPORT": "port", @@ -336,6 +371,8 @@ func addEnvSettings(settings map[string]string) { "PGSSLCERT": "sslcert", "PGSSLROOTCERT": "sslrootcert", "PGTARGETSESSIONATTRS": "target_session_attrs", + "PGSERVICE": "service", + "PGSERVICEFILE": "servicefile", } for envname, realname := range nameMap { @@ -344,12 +381,16 @@ func addEnvSettings(settings map[string]string) { settings[realname] = value } } + + return settings } -func addURLSettings(settings map[string]string, connString string) error { +func parseURLSettings(connString string) (map[string]string, error) { + settings := make(map[string]string) + url, err := url.Parse(connString) if err != nil { - return err + return nil, err } if url.User != nil { @@ -387,12 +428,14 @@ func addURLSettings(settings map[string]string, connString string) error { settings[k] = v[0] } - return nil + return settings, nil } var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} -func addDSNSettings(settings map[string]string, s string) error { +func parseDSNSettings(s string) (map[string]string, error) { + settings := make(map[string]string) + nameMap := map[string]string{ "dbname": "database", } @@ -401,7 +444,7 @@ func addDSNSettings(settings map[string]string, s string) error { var key, val string eqIdx := strings.IndexRune(s, '=') if eqIdx < 0 { - return errors.New("invalid dsn") + return nil, errors.New("invalid dsn") } key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") @@ -434,7 +477,7 @@ func addDSNSettings(settings map[string]string, s string) error { } } if end == len(s) { - return errors.New("unterminated quoted string in connection info string") + return nil, errors.New("unterminated quoted string in connection info string") } val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) if end == len(s) { @@ -451,7 +494,33 @@ func addDSNSettings(settings map[string]string, s string) error { settings[key] = val } - return nil + return settings, nil +} + +func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { + servicefile, err := pgservicefile.ReadServicefile(servicefilePath) + if err != nil { + fmt.Errorf("failed to read service file: %v", servicefile) + } + + service, err := servicefile.GetService(serviceName) + if err != nil { + fmt.Errorf("unable to find service: %v", servicefile) + } + + nameMap := map[string]string{ + "dbname": "database", + } + + settings := make(map[string]string, len(service.Settings)) + for k, v := range service.Settings { + if k2, present := nameMap[k]; present { + k = k2 + } + settings[k] = v + } + + return settings, nil } type pgTLSArgs struct { diff --git a/config_test.go b/config_test.go index 9eb5df2f..0819740f 100644 --- a/config_test.go +++ b/config_test.go @@ -648,6 +648,101 @@ func TestParseConfigReadsPgPassfile(t *testing.T) { assertConfigsEqual(t, expected, actual, "passfile") } +func TestParseConfigReadsPgServiceFile(t *testing.T) { + t.Parallel() + + tf, err := ioutil.TempFile("", "") + require.NoError(t, err) + + defer tf.Close() + defer os.Remove(tf.Name()) + + _, err = tf.Write([]byte(` +[abc] +host=abc.example.com +port=9999 +dbname=abcdb +user=abcuser + +[def] +host = def.example.com +dbname = defdb +user = defuser +application_name = spaced string +`)) + require.NoError(t, err) + + tests := []struct { + name string + connString string + config *pgconn.Config + }{ + { + name: "abc", + connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "abc"), + config: &pgconn.Config{ + Host: "abc.example.com", + Database: "abcdb", + User: "abcuser", + Port: 9999, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "abc.example.com", + Port: 9999, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "def", + connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "def"), + config: &pgconn.Config{ + Host: "def.example.com", + Port: 5432, + Database: "defdb", + User: "defuser", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{"application_name": "spaced string"}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "def.example.com", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "conn string has precedence", + connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tf.Name(), "abc"), + config: &pgconn.Config{ + Host: "other.example.com", + Database: "abcdb", + User: "abcuser", + Port: 7777, + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + } + + for i, tt := range tests { + config, err := pgconn.ParseConfig(tt.connString) + if !assert.NoErrorf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + func TestParseConfigExtractsMinReadBufferSize(t *testing.T) { t.Parallel() diff --git a/go.mod b/go.mod index 37590559..b306e1e4 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,8 @@ require ( github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.1 - github.com/stretchr/testify v1.4.0 + github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 + github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 golang.org/x/text v0.3.2 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 diff --git a/go.sum b/go.sum index 28f094e7..13f276b2 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M= github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= +github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= @@ -69,6 +71,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= From ccf634cf2e2816d97bdc40644bf47b8dd3e5cd97 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Mar 2020 13:21:51 -0600 Subject: [PATCH 0418/1158] Release 1.4.0 --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb099dc2..e5b11b7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 1.4.0 (March 7, 2020) + +* Fix ExecParams and ExecPrepared handling of empty query. +* Support reading config from PostgreSQL service files. + # 1.3.2 (February 14, 2020) * Update chunkreader to v2.0.1 for optimized default buffer size. From 9e700ff067212a8c0d4a2020825a219f045b7571 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 9 Mar 2020 10:40:40 -0500 Subject: [PATCH 0419/1158] Date.Set parses string --- date.go | 2 ++ date_test.go | 1 + 2 files changed, 3 insertions(+) diff --git a/date.go b/date.go index 10e41fe7..37fb8302 100644 --- a/date.go +++ b/date.go @@ -37,6 +37,8 @@ func (dst *Date) Set(src interface{}) error { switch value := src.(type) { case time.Time: *dst = Date{Time: value, Status: Present} + case string: + return dst.DecodeText(nil, []byte(value)) default: if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) diff --git a/date_test.go b/date_test.go index 0b77898b..5c38e7a3 100644 --- a/date_test.go +++ b/date_test.go @@ -42,6 +42,7 @@ func TestDateSet(t *testing.T) { {source: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: "1999-12-31", result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, } for i, tt := range successfulTests { From 4ed48d05d2378ae2828b6f767544aad61aa99a9c Mon Sep 17 00:00:00 2001 From: Greg Curtis Date: Tue, 17 Mar 2020 23:30:56 -0700 Subject: [PATCH 0420/1158] Implement "verify-ca" SSL mode ParseConfig currently treats the libpq "verify-ca" SSL mode as "verify-full". This is okay from a security standpoint because "verify-full" performs certificate verification and hostname verification, whereas "verify-ca" only performs certificate verification. The downside to this approach is that checking the hostname is unnecessary when the server's certificate has been signed by a private CA. It can also cause the SSL handshake to fail when connecting to an instance by IP. For example, a Google Cloud SQL instance typically doesn't have a hostname and uses its own private CA to sign its server and client certs. This change uses the tls.Config.VerifyPeerCertificate function to perform certificate verification without checking the hostname when the "verify-ca" SSL mode is set. This brings pgconn's behavior closer to that of libpq. See https://github.com/golang/go/issues/21971#issuecomment-332693931 and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate for more details on how this is implemented. --- config.go | 41 +++++++++++++++++++++++++++++++++++------ config_test.go | 4 +++- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/config.go b/config.go index 19521a8f..70e6073a 100644 --- a/config.go +++ b/config.go @@ -132,11 +132,6 @@ func NetworkAddress(host string, port uint16) (network, address string) { // See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of // security each sslmode provides. // -// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger -// security guarantees than it would with libpq. Do not rely on this behavior as it -// may be possible to match libpq in the future. If you need full security use -// "verify-full". -// // Other known differences with libpq: // // If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first. @@ -554,7 +549,41 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { tlsConfig.InsecureSkipVerify = true case "require": tlsConfig.InsecureSkipVerify = sslrootcert == "" - case "verify-ca", "verify-full": + case "verify-ca": + // Don't perform the default certificate verification because it + // will verify the hostname. Instead, verify the server's + // certificate chain ourselves in VerifyPeerCertificate and + // ignore the server name. This emulates libpq's verify-ca + // behavior. + // + // See https://github.com/golang/go/issues/21971#issuecomment-332693931 + // and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate + // for more info. + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error { + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return errors.New("failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + // Leave DNSName empty to skip hostname verification. + opts := x509.VerifyOptions{ + Roots: tlsConfig.RootCAs, + Intermediates: x509.NewCertPool(), + } + // Skip the first cert because it's the leaf. All others + // are intermediates. + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := certs[0].Verify(opts) + return err + } + case "verify-full": tlsConfig.ServerName = host default: return nil, errors.New("sslmode is invalid") diff --git a/config_test.go b/config_test.go index 0819740f..b6068cc8 100644 --- a/config_test.go +++ b/config_test.go @@ -132,7 +132,9 @@ func TestParseConfig(t *testing.T) { Host: "localhost", Port: 5432, Database: "mydb", - TLSConfig: &tls.Config{ServerName: "localhost"}, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, RuntimeParams: map[string]string{}, }, }, From 43bf7131808dc1772f232f35074000e8dca59e5e Mon Sep 17 00:00:00 2001 From: Robert Welin Date: Fri, 27 Mar 2020 13:20:04 +0000 Subject: [PATCH 0421/1158] Use correct format verb for unknown type error --- pgtype.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgtype.go b/pgtype.go index 058aa5c6..d29a0210 100644 --- a/pgtype.go +++ b/pgtype.go @@ -404,7 +404,7 @@ func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) switch dest := dest.(type) { case *string: if formatCode == BinaryFormatCode { - return errors.Errorf("unknown oid %d in binary format cannot be scanned into %t", oid, dest) + return errors.Errorf("unknown oid %d in binary format cannot be scanned into %T", oid, dest) } *dest = string(buf) return nil @@ -415,7 +415,7 @@ func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) if nextDst, retry := GetAssignToDstType(dest); retry { return scanUnknownType(oid, formatCode, buf, nextDst) } - return errors.Errorf("unknown oid %d cannot be scanned into %t", oid, dest) + return errors.Errorf("unknown oid %d cannot be scanned into %T", oid, dest) } } From 523cdad66f0568602919000c3ef92b0746cc2d03 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 27 Mar 2020 15:59:54 -0500 Subject: [PATCH 0422/1158] Truncate nanoseconds in EncodeText for Timestamptz and Timestamp PostgreSQL has microsecond precision. If more than this precision is supplied in the text format it is rounded. This was inconsistent with the binary format. See https://github.com/jackc/pgx/issues/699 for original issue. --- timestamp.go | 2 +- timestamp_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ timestamptz.go | 2 +- timestamptz_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 2 deletions(-) diff --git a/timestamp.go b/timestamp.go index 35ac5143..de059f7e 100644 --- a/timestamp.go +++ b/timestamp.go @@ -158,7 +158,7 @@ func (src Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.InfinityModifier { case None: - s = src.Time.Format(pgTimestampFormat) + s = src.Time.Truncate(time.Microsecond).Format(pgTimestampFormat) case Infinity: s = "infinity" case NegativeInfinity: diff --git a/timestamp_test.go b/timestamp_test.go index eec0a52e..2fdc7171 100644 --- a/timestamp_test.go +++ b/timestamp_test.go @@ -32,6 +32,51 @@ func TestTimestampTranscode(t *testing.T) { }) } +func TestTimestampNanosecondsTruncated(t *testing.T) { + tests := []struct { + input time.Time + expected time.Time + }{ + {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.UTC), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC)}, + {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.UTC), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC)}, + } + for i, tt := range tests { + { + ts := pgtype.Timestamp{Time: tt.input, Status: pgtype.Present} + buf, err := ts.EncodeText(nil, nil) + if err != nil { + t.Errorf("%d. EncodeText failed - %v", i, err) + } + + ts.DecodeText(nil, buf) + if err != nil { + t.Errorf("%d. DecodeText failed - %v", i, err) + } + + if !(ts.Status == pgtype.Present && ts.Time.Equal(tt.expected)) { + t.Errorf("%d. EncodeText did not truncate nanoseconds", i) + } + } + + { + ts := pgtype.Timestamp{Time: tt.input, Status: pgtype.Present} + buf, err := ts.EncodeBinary(nil, nil) + if err != nil { + t.Errorf("%d. EncodeBinary failed - %v", i, err) + } + + ts.DecodeBinary(nil, buf) + if err != nil { + t.Errorf("%d. DecodeBinary failed - %v", i, err) + } + + if !(ts.Status == pgtype.Present && ts.Time.Equal(tt.expected)) { + t.Errorf("%d. EncodeBinary did not truncate nanoseconds", i) + } + } + } +} + func TestTimestampSet(t *testing.T) { type _time time.Time diff --git a/timestamptz.go b/timestamptz.go index d390d266..100f44a5 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -160,7 +160,7 @@ func (src Timestamptz) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.InfinityModifier { case None: - s = src.Time.UTC().Format(pgTimestamptzSecondFormat) + s = src.Time.UTC().Truncate(time.Microsecond).Format(pgTimestamptzSecondFormat) case Infinity: s = "infinity" case NegativeInfinity: diff --git a/timestamptz_test.go b/timestamptz_test.go index a020b1ec..a088fc08 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -32,6 +32,51 @@ func TestTimestamptzTranscode(t *testing.T) { }) } +func TestTimestamptzNanosecondsTruncated(t *testing.T) { + tests := []struct { + input time.Time + expected time.Time + }{ + {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.Local), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local)}, + {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.Local), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local)}, + } + for i, tt := range tests { + { + tstz := pgtype.Timestamptz{Time: tt.input, Status: pgtype.Present} + buf, err := tstz.EncodeText(nil, nil) + if err != nil { + t.Errorf("%d. EncodeText failed - %v", i, err) + } + + tstz.DecodeText(nil, buf) + if err != nil { + t.Errorf("%d. DecodeText failed - %v", i, err) + } + + if !(tstz.Status == pgtype.Present && tstz.Time.Equal(tt.expected)) { + t.Errorf("%d. EncodeText did not truncate nanoseconds", i) + } + } + + { + tstz := pgtype.Timestamptz{Time: tt.input, Status: pgtype.Present} + buf, err := tstz.EncodeBinary(nil, nil) + if err != nil { + t.Errorf("%d. EncodeBinary failed - %v", i, err) + } + + tstz.DecodeBinary(nil, buf) + if err != nil { + t.Errorf("%d. DecodeBinary failed - %v", i, err) + } + + if !(tstz.Status == pgtype.Present && tstz.Time.Equal(tt.expected)) { + t.Errorf("%d. EncodeBinary did not truncate nanoseconds", i) + } + } + } +} + func TestTimestamptzSet(t *testing.T) { type _time time.Time From 11d9f4e54fb9a1534259b4b9375bcaa392f30425 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 30 Mar 2020 11:09:29 -0500 Subject: [PATCH 0423/1158] Update golang.org/x/crypto for security fix --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index b306e1e4..4dc095ca 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/jackc/pgproto3/v2 v2.0.1 github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 github.com/stretchr/testify v1.5.1 - golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 + golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 golang.org/x/text v0.3.2 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 ) diff --git a/go.sum b/go.sum index 13f276b2..23fb8b32 100644 --- a/go.sum +++ b/go.sum @@ -83,6 +83,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 h1:3zb4D3T4G8jdExgVU/95+vQXfpEPiMdCaZgmGVxjNHM= +golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= From e4f3224f4c6d615b7199c9a606c4e3385efd1f21 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 30 Mar 2020 11:15:08 -0500 Subject: [PATCH 0424/1158] Update changelog for v1.5.0 --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e5b11b7c..c4c3b2d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 1.5.0 (March 30, 2020) + +* Update golang.org/x/crypto for security fix +* Implement "verify-ca" SSL mode (Greg Curtis) + # 1.4.0 (March 7, 2020) * Fix ExecParams and ExecPrepared handling of empty query. From b26cd223783bf410ec9cd789ae872cef56610d93 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 30 Mar 2020 11:18:27 -0500 Subject: [PATCH 0425/1158] Update changelog for v1.3.0 --- CHANGELOG.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f12c5027..560abff3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,12 @@ +# 1.3.0 (March 30, 2020) + +* Get implemented on T instead of *T +* Set will call Get on src if possible +* Range types Set method supports its own type, string, and nil +* Date.Set parses string +* Fix correct format verb for unknown type error (Robert Welin) +* Truncate nanoseconds in EncodeText for Timestamptz and Timestamp + # 1.2.0 (February 5, 2020) * Add zeronull package for easier NULL <-> zero conversion From ef5f8b54af5333208a671a3e2b4c82f1c1dd7bfa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 30 Mar 2020 11:30:37 -0500 Subject: [PATCH 0426/1158] Update dependencies --- go.mod | 11 +++++------ go.sum | 32 ++++++++++++++++++++++++++++++++ hstore_array_test.go | 4 ++-- jsonb_test.go | 2 +- line_test.go | 2 +- 5 files changed, 41 insertions(+), 10 deletions(-) diff --git a/go.mod b/go.mod index 920f4b21..35991562 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,9 @@ go 1.12 require ( github.com/gofrs/uuid v3.2.0+incompatible github.com/jackc/pgio v1.0.0 - github.com/jackc/pgx v3.6.2+incompatible - github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186 - github.com/lib/pq v1.2.0 - github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 - github.com/stretchr/testify v1.4.0 - golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 + github.com/jackc/pgx/v4 v4.5.0 + github.com/lib/pq v1.3.0 + github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc + github.com/stretchr/testify v1.5.1 + golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 ) diff --git a/go.sum b/go.sum index 7a3fc71e..65468df6 100644 --- a/go.sum +++ b/go.sum @@ -8,18 +8,25 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v1.2.0 h1:coDhrjgyJaglxSjxuJdqQSSdUpG3w6p1OwN2od6frBU= github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3 h1:ZFYpB74Kq8xE9gmfxCmXD6QxZ27ja+j3HwGFc+YurhQ= github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb h1:d6GP9szHvXVopAOAnZ7WhRnF3Xdxrylmm/9jnfmW4Ag= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.4.0 h1:E82UBzFyD752mvI+4RIl1WSxfO2ug64T+sLjvDBWTpA= +github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= @@ -28,8 +35,16 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaK github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M= +github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= +github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= github.com/jackc/pgx v3.6.2+incompatible h1:2zP5OD7kiyR3xzRYMhOcXVvkDZsImVXfj+yIyTQf3/o= github.com/jackc/pgx v3.6.2+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 h1:ylEAOd688Duev/fxTmGdupsbyZfxNMdngIG14DoBKTM= @@ -38,8 +53,11 @@ github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912 h1:YuOWGsSK5L4Fz81Olx github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186 h1:ZQM8qLT/E/CGD6XX0E6q9FAwxJYmWpJufzmLMaFuzgQ= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/pgx/v4 v4.5.0 h1:mN7Z3n0uqPe29+tA4yLWyZNceYKgRvUWNk8qW+D066E= +github.com/jackc/pgx/v4 v4.5.0/go.mod h1:EpAKPLdnTorwmPUUsqrPxy5fphV18j9q3wrfRXgo+kA= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= @@ -53,9 +71,14 @@ github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU= +github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -67,6 +90,8 @@ github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc h1:jUIKcSPO9MoMJBbEoyE/RJoE8vz7Mb8AjvifMMwSyvY= +github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -77,6 +102,8 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= @@ -88,6 +115,8 @@ golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcN golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 h1:0hQKqeLdqlt5iIwVOBErRisrHJAN57yOiPRQItI20fU= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -100,6 +129,7 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= @@ -112,6 +142,8 @@ golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/hstore_array_test.go b/hstore_array_test.go index ea8f03b0..32b91840 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -19,14 +19,14 @@ func TestHstoreArrayTranscode(t *testing.T) { if err != nil { t.Fatalf("did not find hstore OID, %v", err) } - conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.Hstore{}, Name: "hstore", OID: hstoreOID}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Hstore{}, Name: "hstore", OID: hstoreOID}) var hstoreArrayOID uint32 err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='_hstore';").Scan(&hstoreArrayOID) if err != nil { t.Fatalf("did not find _hstore OID, %v", err) } - conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.HstoreArray{}, Name: "_hstore", OID: hstoreArrayOID}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.HstoreArray{}, Name: "_hstore", OID: hstoreArrayOID}) text := func(s string) pgtype.Text { return pgtype.Text{String: s, Status: pgtype.Present} diff --git a/jsonb_test.go b/jsonb_test.go index e7ce7203..9ce80d42 100644 --- a/jsonb_test.go +++ b/jsonb_test.go @@ -12,7 +12,7 @@ import ( func TestJSONBTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) - if _, ok := conn.ConnInfo.DataTypeForName("jsonb"); !ok { + if _, ok := conn.ConnInfo().DataTypeForName("jsonb"); !ok { t.Skip("Skipping due to no jsonb type") } diff --git a/line_test.go b/line_test.go index 6a560dec..f697ac43 100644 --- a/line_test.go +++ b/line_test.go @@ -10,7 +10,7 @@ import ( func TestLineTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) - if _, ok := conn.ConnInfo.DataTypeForName("line"); !ok { + if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { t.Skip("Skipping due to no line type") } From 9016875caee4d838233b5af0ca2c34a2732faa16 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 2 Apr 2020 14:01:16 -0500 Subject: [PATCH 0427/1158] Add JSON support to ext/gofrs-uuid --- ext/gofrs-uuid/uuid.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/ext/gofrs-uuid/uuid.go b/ext/gofrs-uuid/uuid.go index b0413cae..fec912bc 100644 --- a/ext/gofrs-uuid/uuid.go +++ b/ext/gofrs-uuid/uuid.go @@ -10,6 +10,7 @@ import ( ) var errUndefined = errors.New("cannot encode status undefined") +var errBadStatus = errors.New("invalid status") type UUID struct { UUID uuid.UUID @@ -172,3 +173,32 @@ func (dst *UUID) Scan(src interface{}) error { func (src UUID) Value() (driver.Value, error) { return pgtype.EncodeValueText(src) } + +func (src UUID) MarshalJSON() ([]byte, error) { + switch src.Status { + case pgtype.Present: + return []byte(`"` + src.UUID.String() + `"`), nil + case pgtype.Null: + return []byte("null"), nil + case pgtype.Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} + +func (dst *UUID) UnmarshalJSON(b []byte) error { + u := uuid.NullUUID{} + err := u.UnmarshalJSON(b) + if err != nil { + return err + } + + status := pgtype.Null + if u.Valid { + status = pgtype.Present + } + *dst = UUID{UUID: u.UUID, Status: status} + + return nil +} From 1fcc71410c85e07399da17cb214459af0b0b6dc8 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Bronisz Date: Mon, 6 Apr 2020 19:45:07 +0200 Subject: [PATCH 0428/1158] Clean go.sum file to remove old version of pgx v3 --- go.sum | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/go.sum b/go.sum index 65468df6..5e75654d 100644 --- a/go.sum +++ b/go.sum @@ -8,7 +8,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gofrs/uuid v1.2.0 h1:coDhrjgyJaglxSjxuJdqQSSdUpG3w6p1OwN2od6frBU= github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= @@ -26,6 +25,7 @@ github.com/jackc/pgconn v1.4.0 h1:E82UBzFyD752mvI+4RIl1WSxfO2ug64T+sLjvDBWTpA= github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= @@ -45,8 +45,6 @@ github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01C github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= -github.com/jackc/pgx v3.6.2+incompatible h1:2zP5OD7kiyR3xzRYMhOcXVvkDZsImVXfj+yIyTQf3/o= -github.com/jackc/pgx v3.6.2+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 h1:ylEAOd688Duev/fxTmGdupsbyZfxNMdngIG14DoBKTM= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912 h1:YuOWGsSK5L4Fz81Olx5TNlZftmDuNrfv4ip0Yos77Tw= From 84aee0ab4443115da0c34114c300a50a410e5402 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=20Jedin=C3=BD?= Date: Wed, 8 Apr 2020 00:08:53 +0200 Subject: [PATCH 0429/1158] Fix behavior of sslmode=require with sslrootcert present According to PostgreSQL documentation the behavior should be the same as that of verify-ca sslmode https://www.postgresql.org/docs/12/libpq-ssl.html --- config.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 70e6073a..06184b02 100644 --- a/config.go +++ b/config.go @@ -548,7 +548,17 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { case "allow", "prefer": tlsConfig.InsecureSkipVerify = true case "require": - tlsConfig.InsecureSkipVerify = sslrootcert == "" + // According to PostgreSQL documentation, if a root CA file exists, + // the behavior of sslmode=require should be the same as that of verify-ca + // + // See https://www.postgresql.org/docs/12/libpq-ssl.html + if sslrootcert != "" { + goto nextCase + } + tlsConfig.InsecureSkipVerify = true + break + nextCase: + fallthrough case "verify-ca": // Don't perform the default certificate verification because it // will verify the hostname. Instead, verify the server's From 5d2be99c254e76f7dfb8b481db1791dd613b5d4c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 Apr 2020 19:38:21 -0500 Subject: [PATCH 0430/1158] Fix panic when closing conn during cancellable query fixes #29 --- internal/ctxwatch/context_watcher_test.go | 11 +++++++++++ pgconn.go | 7 +++++++ pgconn_test.go | 13 +++++++++++++ 3 files changed, 31 insertions(+) diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go index 0b491bf8..6348b729 100644 --- a/internal/ctxwatch/context_watcher_test.go +++ b/internal/ctxwatch/context_watcher_test.go @@ -59,6 +59,17 @@ func TestContextWatcherMultipleWatchPanics(t *testing.T) { require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") } +func TestContextWatcherUnwatchIsAlwaysSafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + cw.Unwatch() // unwatch when not / never watching + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + cw.Unwatch() + cw.Unwatch() // double unwatch +} + func TestContextWatcherStress(t *testing.T) { var cancelFuncCalls int64 var cleanupFuncCalls int64 diff --git a/pgconn.go b/pgconn.go index 6155281d..d5a424ac 100644 --- a/pgconn.go +++ b/pgconn.go @@ -494,6 +494,13 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() if ctx != context.Background() { + // Close may be called while a cancellable query is in progress. This will most often be triggered by panic when + // a defer closes the connection (possibly indirectly via a transaction or a connection pool). Unwatch to end any + // previous watch. It is safe to Unwatch regardless of whether a watch is already is progress. + // + // See https://github.com/jackc/pgconn/issues/29 + pgConn.contextWatcher.Unwatch() + pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } diff --git a/pgconn_test.go b/pgconn_test.go index 17b40343..e29a36b2 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1708,6 +1708,19 @@ func TestHijackAndConstruct(t *testing.T) { ensureConnValid(t, newConn) } +func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + ctx, _ := context.WithCancel(context.Background()) + pgConn.Exec(ctx, "select n from generate_series(1,10) n") + + closeCtx, _ := context.WithCancel(context.Background()) + pgConn.Close(closeCtx) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { From 087df120bbdc0e79b94e174062509a1dc14e90a9 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Sat, 11 Apr 2020 10:38:23 +0100 Subject: [PATCH 0431/1158] Refactor lowlevel record field iteration --- record.go | 89 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 63 insertions(+), 26 deletions(-) diff --git a/record.go b/record.go index 5c9d7a02..76da5ad0 100644 --- a/record.go +++ b/record.go @@ -78,57 +78,94 @@ func (src *Record) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +type fieldIter struct { + rp int + fieldCount int + src []byte +} + +func newFieldIterator(src []byte) (fieldIter, error) { + rp := 0 + if len(src[rp:]) < 4 { + return fieldIter{}, errors.Errorf("Record incomplete %v", src) + } + + fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + return fieldIter{ + rp: rp, + fieldCount: fieldCount, + src: src, + }, nil +} + +func (fi *fieldIter) next() (fieldOID uint32, buf []byte, eof bool, err error) { + if fi.rp == len(fi.src) { + eof = true + return + } + + if len(fi.src[fi.rp:]) < 8 { + err = errors.Errorf("Record incomplete %v", fi.src) + return + } + fieldOID = binary.BigEndian.Uint32(fi.src[fi.rp:]) + fi.rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(fi.src[fi.rp:]))) + fi.rp += 4 + + if fieldLen >= 0 { + if len(fi.src[fi.rp:]) < fieldLen { + err = errors.Errorf("Record incomplete rp=%d src=%v", fi.rp, fi.src) + return + } + buf = fi.src[fi.rp : fi.rp+fieldLen] + fi.rp += fieldLen + } + + return +} + func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Record{Status: Null} return nil } - rp := 0 - - if len(src[rp:]) < 4 { - return errors.Errorf("Record incomplete %v", src) + fieldIter, err := newFieldIterator(src) + if err != nil { + return err } - fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - fields := make([]Value, fieldCount) + fields := make([]Value, fieldIter.fieldCount) + fieldOID, fieldBytes, eof, err := fieldIter.next() - for i := 0; i < fieldCount; i++ { - if len(src[rp:]) < 8 { - return errors.Errorf("Record incomplete %v", src) + for i := 0; !eof; i++ { + if err != nil { + return err } - fieldOID := binary.BigEndian.Uint32(src[rp:]) - rp += 4 - - fieldLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var binaryDecoder BinaryDecoder + if dt, ok := ci.DataTypeForOID(fieldOID); ok { binaryDecoder, _ = dt.Value.(BinaryDecoder) - } - if binaryDecoder == nil { + } else { return errors.Errorf("unknown oid while decoding record: %v", fieldOID) } - var fieldBytes []byte - if fieldLen >= 0 { - if len(src[rp:]) < fieldLen { - return errors.Errorf("Record incomplete %v", src) - } - fieldBytes = src[rp : rp+fieldLen] - rp += fieldLen + if binaryDecoder == nil { + return errors.Errorf("no binary decoder registered for: %v", fieldOID) } // Duplicate struct to scan into binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder) - if err := binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { return err } fields[i] = binaryDecoder.(Value) + fieldOID, fieldBytes, eof, err = fieldIter.next() } *dst = Record{Fields: fields, Status: Present} From 9a869c8359bd2845aaa54ee5db56c930a348a063 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Sat, 11 Apr 2020 11:08:53 +0100 Subject: [PATCH 0432/1158] Refactor record field binary decoder preparation --- record.go | 36 +++++++++++++++++++++++------------- record_test.go | 46 +++++++++++++++++++++++++++++++++++++--------- 2 files changed, 60 insertions(+), 22 deletions(-) diff --git a/record.go b/record.go index 76da5ad0..08603140 100644 --- a/record.go +++ b/record.go @@ -128,6 +128,25 @@ func (fi *fieldIter) next() (fieldOID uint32, buf []byte, eof bool, err error) { return } +func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDecoder, error) { + var binaryDecoder BinaryDecoder + + if dt, ok := ci.DataTypeForOID(fieldOID); ok { + binaryDecoder, _ = dt.Value.(BinaryDecoder) + } else { + return nil, errors.Errorf("unknown oid while decoding record: %v", fieldOID) + } + + if binaryDecoder == nil { + return nil, errors.Errorf("no binary decoder registered for: %v", fieldOID) + } + + // Duplicate struct to scan into + binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder) + *v = binaryDecoder.(Value) + return binaryDecoder, nil +} + func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Record{Status: Null} @@ -146,25 +165,16 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { if err != nil { return err } - var binaryDecoder BinaryDecoder - if dt, ok := ci.DataTypeForOID(fieldOID); ok { - binaryDecoder, _ = dt.Value.(BinaryDecoder) - } else { - return errors.Errorf("unknown oid while decoding record: %v", fieldOID) + binaryDecoder, err := prepareNewBinaryDecoder(ci, fieldOID, &fields[i]) + if err != nil { + return err } - if binaryDecoder == nil { - return errors.Errorf("no binary decoder registered for: %v", fieldOID) - } - - // Duplicate struct to scan into - binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder) - if err := binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { return err } - fields[i] = binaryDecoder.(Value) fieldOID, fieldBytes, eof, err = fieldIter.next() } diff --git a/record_test.go b/record_test.go index 71a2f702..c8d9097d 100644 --- a/record_test.go +++ b/record_test.go @@ -83,22 +83,50 @@ func TestRecordTranscode(t *testing.T) { }, } - for i, tt := range tests { + for i := 0; i < len(tests); i++ { + tt := tests[i] psName := fmt.Sprintf("test%d", i) _, err := conn.Prepare(context.Background(), psName, tt.sql) if err != nil { t.Fatal(err) } - var result pgtype.Record - if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { - t.Errorf("%d: %v", i, err) - continue - } + t.Run(fmt.Sprintf("scan %d", i), func(t *testing.T) { + var result pgtype.Record + if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { + t.Errorf("%v", err) + return + } + + if !reflect.DeepEqual(tt.expected, result) { + t.Errorf("expected %#v, got %#v", tt.expected, result) + } + }) + + t.Run(fmt.Sprintf("scan MatchFields %d", i), func(t *testing.T) { + tt.expected.MatchFields = true + + fieldsCopy := make([]pgtype.Value, len(tt.expected.Fields)) + reflect.Copy(reflect.ValueOf(fieldsCopy), reflect.ValueOf(tt.expected.Fields)) + + if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&tt.expected); err != nil { + t.Errorf("%d: %v", i, err) + return + } + + if !reflect.DeepEqual(tt.expected.Fields, fieldsCopy) { + t.Errorf("Matching scan succeeded, but modified predefined fields. %d: expected %#v, got %#v", i, tt.expected.Fields, fieldsCopy) + } + + // borrow fields from a neighbor test, this makes scan always fail + tt.expected.Fields = tests[(i+1)%len(tests)].expected.Fields + reflect.Copy(reflect.ValueOf(fieldsCopy), reflect.ValueOf(tt.expected.Fields)) + if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&tt.expected); err == nil { + t.Errorf("Matching scan didn't fail, despite fields not mathchin query result. %d: %v", i, err) + return + } + }) - if !reflect.DeepEqual(tt.expected, result) { - t.Errorf("%d: expected %#v, got %#v", i, tt.expected, result) - } } } From ff95f82f7057c17f1bc55e01c6a2a04da70b1f4f Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Sat, 11 Apr 2020 12:20:43 +0100 Subject: [PATCH 0433/1158] Add ScanRowValue helper function ScanRowValue is useful when reading ROW() values with known field types as well as composite types. It accepts pgtype.Value arguments, where ROW() fields are written to on successfull scan. --- convert.go | 38 +++++++++ record_test.go | 207 ++++++++++++++++++++++++++----------------------- 2 files changed, 150 insertions(+), 95 deletions(-) diff --git a/convert.go b/convert.go index cc5c10ab..a0c38c5b 100644 --- a/convert.go +++ b/convert.go @@ -433,6 +433,44 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { return nil, false } +// ScanRowValue assigns ROW()'s fields to destination Values. +// Argument types are checked and error is returned if SQL field value +// can't be assigned to corresponding destionation Value without loss +// of information. Number of fields have to match number of destination values. +// +// Values must implement BinaryDecoder interface otherwise error is returned. +// ScanRowValue takes ownership of src, caller MUST not use it after call +func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error { + fieldIter, err := newFieldIterator(src) + if err != nil { + return err + } + + if len(dst) != fieldIter.fieldCount { + return errors.Errorf("can't scan row value, number of fields don't match: row fields count=%d desired fields count=%d", fieldIter.fieldCount, len(dst)) + } + + _, fieldBytes, eof, err := fieldIter.next() + for i := 0; !eof; i++ { + if err != nil { + return err + } + + binaryDecoder, ok := dst[i].(BinaryDecoder) + if !ok { + return errors.Errorf("record field doesn't implement binary decoding: %s", reflect.TypeOf(dst[i]).Name()) + } + + if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + _, fieldBytes, eof, err = fieldIter.next() + } + + return nil +} + func init() { kindTypes = map[reflect.Kind]reflect.Type{ reflect.Bool: reflect.TypeOf(false), diff --git a/record_test.go b/record_test.go index c8d9097d..af2105c7 100644 --- a/record_test.go +++ b/record_test.go @@ -11,87 +11,128 @@ import ( "github.com/jackc/pgx/v4" ) +var recordTests = []struct { + sql string + expected pgtype.Record +}{ + { + sql: `select row()`, + expected: pgtype.Record{ + Fields: []pgtype.Value{}, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row(100.0::float4, 1.09::float4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Float4{Float: 100, Status: pgtype.Present}, + &pgtype.Float4{Float: 1.09, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row(null)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Unknown{Status: pgtype.Null}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select null::record`, + expected: pgtype.Record{ + Status: pgtype.Null, + }, + }, +} + +// row values are binary compatible with records, so we test our helper +// routines here +func TestScanRowValue(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for i := 0; i < len(recordTests); i++ { + tt := recordTests[i] + psName := fmt.Sprintf("test%d", i) + _, err := conn.Prepare(context.Background(), psName, tt.sql) + if err != nil { + t.Fatal(err) + } + t.Run(tt.sql, func(t *testing.T) { + desc := append([]pgtype.Value(nil), tt.expected.Fields...) + + var raw pgtype.GenericBinary + + if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&raw); err != nil { + t.Error(err) + return + } + + if raw.Status == pgtype.Null { + // ScanRowValue deals with complete rows only, NULL values (but NOT null fields) + // should be handled by the calling code + return + } + + if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err != nil { + t.Error(err) + } + + // borrow fields from a neighbor test, this makes scan always fail + desc = append([]pgtype.Value(nil), recordTests[(i+1)%len(recordTests)].expected.Fields...) + if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err == nil { + t.Error("Matching scan didn't fail, despite fields not mathching query result") + } + }) + } +} + func TestRecordTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) - tests := []struct { - sql string - expected pgtype.Record - }{ - { - sql: `select row()`, - expected: pgtype.Record{ - Fields: []pgtype.Value{}, - Status: pgtype.Present, - }, - }, - { - sql: `select row('foo'::text, 42::int4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select row(100.0::float4, 1.09::float4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Float4{Float: 100, Status: pgtype.Present}, - &pgtype.Float4{Float: 1.09, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select row(null)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Unknown{Status: pgtype.Null}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select null::record`, - expected: pgtype.Record{ - Status: pgtype.Null, - }, - }, - } - - for i := 0; i < len(tests); i++ { - tt := tests[i] + for i, tt := range recordTests { psName := fmt.Sprintf("test%d", i) _, err := conn.Prepare(context.Background(), psName, tt.sql) if err != nil { t.Fatal(err) } - t.Run(fmt.Sprintf("scan %d", i), func(t *testing.T) { + t.Run(tt.sql, func(t *testing.T) { var result pgtype.Record if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { t.Errorf("%v", err) @@ -103,30 +144,6 @@ func TestRecordTranscode(t *testing.T) { } }) - t.Run(fmt.Sprintf("scan MatchFields %d", i), func(t *testing.T) { - tt.expected.MatchFields = true - - fieldsCopy := make([]pgtype.Value, len(tt.expected.Fields)) - reflect.Copy(reflect.ValueOf(fieldsCopy), reflect.ValueOf(tt.expected.Fields)) - - if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&tt.expected); err != nil { - t.Errorf("%d: %v", i, err) - return - } - - if !reflect.DeepEqual(tt.expected.Fields, fieldsCopy) { - t.Errorf("Matching scan succeeded, but modified predefined fields. %d: expected %#v, got %#v", i, tt.expected.Fields, fieldsCopy) - } - - // borrow fields from a neighbor test, this makes scan always fail - tt.expected.Fields = tests[(i+1)%len(tests)].expected.Fields - reflect.Copy(reflect.ValueOf(fieldsCopy), reflect.ValueOf(tt.expected.Fields)) - if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&tt.expected); err == nil { - t.Errorf("Matching scan didn't fail, despite fields not mathchin query result. %d: %v", i, err) - return - } - }) - } } From 71ed747f3a7786d1f1c1dc336d4c661d08da6e6a Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Sun, 12 Apr 2020 15:52:37 +0100 Subject: [PATCH 0434/1158] Add example of CompositeType handling with ScanRowValue helper --- composite_test.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 composite_test.go diff --git a/composite_test.go b/composite_test.go new file mode 100644 index 00000000..d51cb579 --- /dev/null +++ b/composite_test.go @@ -0,0 +1,77 @@ +package pgtype_test + +import ( + "context" + "fmt" + "os" + + "github.com/jackc/pgtype" + pgx "github.com/jackc/pgx/v4" + errors "golang.org/x/xerrors" +) + +type MyType struct { + a int32 // NULL will cause decoding error + b *string // there can be NULL in this position in SQL +} + +func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs") + } + + a := pgtype.Int4{} + b := pgtype.Text{} + + if err := pgtype.ScanRowValue(ci, src, &a, &b); err != nil { + return err + } + + // type compatibility is checked by AssignTo + // only lossless assignments will succeed + if err := a.AssignTo(&dst.a); err != nil { + return err + } + + // AssignTo also deals with null value handling + if err := b.AssignTo(&dst.b); err != nil { + return err + } + + return nil +} + +func Example_compositeTypes() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + panic(err) + } + defer conn.Close(context.Background()) + _, err = conn.Exec(context.Background(), `drop type if exists mytype; + +create type mytype as ( + a int4, + b text +);`) + if err != nil { + panic(err) + } + defer conn.Exec(context.Background(), "drop type mytype") + + var result *MyType + if err = conn.QueryRow(context.Background(), "select (1,'foo')::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { + panic(err) + } + + fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b) + + // Because we scan into &*MyType, NULLs are handled generically by assigning nil to result + if err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { + panic(err) + } + + fmt.Printf("Second row: %v\n", result) + // Output: + // First row: a=1 b=foo + // Second row: +} From 368295d3ee4d8f08f30d5f8cb1841461cd4f14a6 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Sun, 12 Apr 2020 18:40:52 +0100 Subject: [PATCH 0435/1158] Create ROW helper for adhoc decoding of records --- composite_test.go | 13 +++++++++++++ convert.go | 28 ++++++++++++++++++++++++++++ pgtype.go | 9 +++++++++ 3 files changed, 50 insertions(+) diff --git a/composite_test.go b/composite_test.go index d51cb579..ffa7d479 100644 --- a/composite_test.go +++ b/composite_test.go @@ -71,7 +71,20 @@ create type mytype as ( } fmt.Printf("Second row: %v\n", result) + + // Adhoc rows can be decoded inplace without boilerplate (works with composite types too) + var isNull bool + var a int + var b *string + + if err = conn.QueryRow(context.Background(), "select (2, 'bar')::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(pgtype.ROW(&isNull, &a, &b)); err != nil { + panic(err) + } + + fmt.Printf("Adhoc: isNull=%v a=%d b=%s", isNull, a, *b) + // Output: // First row: a=1 b=foo // Second row: + // Adhoc: isNull=false a=2 b=bar } diff --git a/convert.go b/convert.go index a0c38c5b..8157358b 100644 --- a/convert.go +++ b/convert.go @@ -471,6 +471,34 @@ func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error { return nil } +// ROW allows deconstructing row values (records and composite types) into +// fields directly without creating your own type and implementing decoder interfaces +func ROW(isNull *bool, fields ...interface{}) BinaryDecoderFunc { + return func(ci *ConnInfo, src []byte) error { + var record Record + if err := record.DecodeBinary(ci, src); err != nil { + return err + } + + if record.Status == Null { + *isNull = true + return nil + } + + if len(record.Fields) != len(fields) { + return errors.Errorf("can't scan row value, number of fields don't match: row fields count=%d desired fields count=%d", len(record.Fields), len(fields)) + } + + for i, f := range record.Fields { + if err := f.AssignTo(fields[i]); err != nil { + return err + } + } + + return nil + } +} + func init() { kindTypes = map[reflect.Kind]reflect.Type{ reflect.Bool: reflect.TypeOf(false), diff --git a/pgtype.go b/pgtype.go index 914e02d2..1749c8c2 100644 --- a/pgtype.go +++ b/pgtype.go @@ -158,6 +158,15 @@ type TextEncoder interface { EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) } +//The BinaryDecoderFunc type is an adapter to allow the use of ordinary functions as BinaryDecoder types. +// If f is a function with the appropriate signature, BinaryDecoderFunc(f) is a BinaryDecoder that calls f. +type BinaryDecoderFunc func(ci *ConnInfo, src []byte) error + +// DecodeBinary calls f(ci, src) +func (f BinaryDecoderFunc) DecodeBinary(ci *ConnInfo, src []byte) error { + return f(ci, src) +} + var errUndefined = errors.New("cannot encode status undefined") var errBadStatus = errors.New("invalid status") From 8ae83b19f7d6a2a27ac3b1a2664dc3c61a90cf46 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Sun, 12 Apr 2020 22:33:33 +0100 Subject: [PATCH 0436/1158] Add EncodeRow helpers Also extend example to show how EncodeRow can be used to create binary encoders for composite type --- composite_test.go | 47 +++++++++++++++++++++++++++++++++-------------- convert.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 14 deletions(-) diff --git a/composite_test.go b/composite_test.go index ffa7d479..d0c48f6e 100644 --- a/composite_test.go +++ b/composite_test.go @@ -41,11 +41,32 @@ func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return nil } -func Example_compositeTypes() { - conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) +func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { + a := pgtype.Int4{src.a, pgtype.Present} + var b pgtype.Text + if src.b != nil { + b = pgtype.Text{*src.b, pgtype.Present} + } else { + b = pgtype.Text{Status: pgtype.Null} + } + + return pgtype.EncodeRow(ci, buf, &a, &b) +} + +func ptrS(s string) *string { + return &s +} + +func E(err error) { if err != nil { panic(err) } +} + +func Example_compositeTypes() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + E(err) + defer conn.Close(context.Background()) _, err = conn.Exec(context.Background(), `drop type if exists mytype; @@ -53,22 +74,21 @@ create type mytype as ( a int4, b text );`) - if err != nil { - panic(err) - } + E(err) defer conn.Exec(context.Background(), "drop type mytype") var result *MyType - if err = conn.QueryRow(context.Background(), "select (1,'foo')::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { - panic(err) - } + + // Demonstrates both passing and reading back composite values + err = conn.QueryRow(context.Background(), "select $1::mytype", + pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}).Scan(&result) + E(err) fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b) // Because we scan into &*MyType, NULLs are handled generically by assigning nil to result - if err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { - panic(err) - } + err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result) + E(err) fmt.Printf("Second row: %v\n", result) @@ -77,9 +97,8 @@ create type mytype as ( var a int var b *string - if err = conn.QueryRow(context.Background(), "select (2, 'bar')::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(pgtype.ROW(&isNull, &a, &b)); err != nil { - panic(err) - } + err = conn.QueryRow(context.Background(), "select (2, 'bar')::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(pgtype.ROW(&isNull, &a, &b)) + E(err) fmt.Printf("Adhoc: isNull=%v a=%d b=%s", isNull, a, *b) diff --git a/convert.go b/convert.go index 8157358b..d22a714f 100644 --- a/convert.go +++ b/convert.go @@ -5,6 +5,7 @@ import ( "reflect" "time" + "github.com/jackc/pgio" errors "golang.org/x/xerrors" ) @@ -471,6 +472,38 @@ func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error { return nil } +// EncodeRow builds a binary representation of row values (row(), composite types) +func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) { + fieldBytes := make([]byte, 0, 128) + + newBuf = pgio.AppendUint32(buf, uint32(len(fields))) + for _, f := range fields { + dt, ok := ci.DataTypeForValue(f) + if !ok { + return nil, errors.Errorf("Unknown OID for %s", f) + } + newBuf = pgio.AppendUint32(newBuf, dt.OID) + + if f.Get() != nil { + binaryEncoder, ok := f.(BinaryEncoder) + if !ok { + return nil, errors.Errorf("record field doesn't implement binary encoding: %s", reflect.TypeOf(f).Name()) + } + fieldBytes, err = binaryEncoder.EncodeBinary(ci, fieldBytes[:0]) + if err != nil { + return nil, err + } + + newBuf = pgio.AppendUint32(newBuf, uint32(len(fieldBytes))) + newBuf = append(newBuf, fieldBytes...) + } else { + newBuf = pgio.AppendInt32(newBuf, int32(-1)) + } + + } + return +} + // ROW allows deconstructing row values (records and composite types) into // fields directly without creating your own type and implementing decoder interfaces func ROW(isNull *bool, fields ...interface{}) BinaryDecoderFunc { From 3ce29f9e055b46203c43c51744f536888942c018 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Mon, 13 Apr 2020 01:52:06 +0100 Subject: [PATCH 0437/1158] Add Composite type for inplace row() values handling Composite() function returns a private type, which should be registered with ConnInfo.RegisterDataType for the composite type's OID. All subsequent interaction with Composite types is to be done via Row(...) function. Function return value can be either passed as a query argument to build SQL composite value out of individual fields or passed to Scan to read SQL composite value back. When passed to Scan, Row() should have first argument of type *bool to flag NULL values returned from query. --- composite.go | 128 ++++++++++++++++++++++++++++++++++++++++++++++ composite_test.go | 16 ++++-- convert.go | 28 ---------- pgtype.go | 9 ++++ 4 files changed, 150 insertions(+), 31 deletions(-) create mode 100644 composite.go diff --git a/composite.go b/composite.go new file mode 100644 index 00000000..1caa24d6 --- /dev/null +++ b/composite.go @@ -0,0 +1,128 @@ +package pgtype + +import ( + errors "golang.org/x/xerrors" +) + +type composite struct { + fields []Value + status Status +} + +// helper struct to act both as a scanning target and query argument +type rowValue struct { + args []interface{} +} + +// Row helper function builds a value which can be both used to +// "assemble" composite quiery arguments and to scan results back. +// +// When passed as an argument to query, values from Row args will +// be assigned to corresponding fields in a composite type and a single +// composite type will be passed to the PostgreSQL. Composite type need +// to be registered in ConnInfo first. This is required so that pgx +// can know which SQL types to use when constructing SQL composite argument +// +// When passed to Scan individual fields from composite query result +// are assigned to corresponding Row arguments. First argument MUST +// be of type *bool to flag when NULL value received. So total number +// of Row arguments, when passed to Scan should be number of composite +// fields you expect to read + 1 +func Row(fields ...interface{}) rowValue { + return rowValue{fields} +} + +// Composite types is meant to be passed to ConnInfo.RegisterDataType only, +// so it is made private on purpose. Once registered, it allows Row +// function to correctly pass query arguments. +func Composite(fields ...Value) *composite { + return &composite{fields, Undefined} +} + +func (src composite) Get() interface{} { + switch src.status { + case Present: + return src + case Null: + return nil + default: + return src.status + } +} + +// Set is called internally when passing query arguments. +// Only valid src is a result of pgtype.Row() or nil +func (dst *composite) Set(src interface{}) error { + if src == nil { + *dst = composite{status: Null} + return nil + } + + switch value := src.(type) { + case rowValue: + if len(value.args) != len(dst.fields) { + return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields)) + } + for i, v := range value.args { + if err := dst.fields[i].Set(v); err != nil { + return err + } + } + dst.status = Present + default: + return errors.Errorf("Use pgtype.Row() as query parameter") + } + + return nil +} + +// AssignTo is never called on composite value directly, it is here +// to satisfy Valuer interface +func (src composite) AssignTo(dst interface{}) error { + return errors.New("BUG: should never be called, because pgtype.composite doesn't support decoding") +} + +func (src composite) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { + return EncodeRow(ci, buf, src.fields...) +} + +// DecodeBinary here is just to make pgx use binary result format by default. +// Users should be using Row function or their own types to scan composites +func (src composite) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { + return errors.New("Pass pgtype.Row() to Scan to deconstruct Composite") +} + +// DecodeBinary is called when pgtype.Row() is passed to Scan() to +// deconstruct composite value +func (r rowValue) DecodeBinary(ci *ConnInfo, src []byte) error { + if len(r.args) == 0 { + return errors.New("pgtype.Row must have 'isNull *bool' as a first argument when used in Scan") + } + + isNull, ok := r.args[0].(*bool) + if !ok { + return errors.New("pgtype.Row must have 'isNull *bool' as a first argument when used in Scan") + } + args := r.args[1:] + + var record Record + if err := record.DecodeBinary(ci, src); err != nil { + return err + } + + if record.Status == Null { + *isNull = true + return nil + } + + if len(record.Fields) != len(args) { + return errors.Errorf("SQL composite can't be read, 'pgtype.Row' has wrong field cout. %d != %d", len(record.Fields), len(args)) + } + + for i, f := range record.Fields { + if err := f.AssignTo(args[i]); err != nil { + return err + } + } + return nil +} diff --git a/composite_test.go b/composite_test.go index d0c48f6e..b38cdd45 100644 --- a/composite_test.go +++ b/composite_test.go @@ -81,7 +81,8 @@ create type mytype as ( // Demonstrates both passing and reading back composite values err = conn.QueryRow(context.Background(), "select $1::mytype", - pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}).Scan(&result) + pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}). + Scan(&result) E(err) fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b) @@ -92,12 +93,21 @@ create type mytype as ( fmt.Printf("Second row: %v\n", result) - // Adhoc rows can be decoded inplace without boilerplate (works with composite types too) + //WIP + q, err := conn.Prepare(context.Background(), "z", "select $1::mytype") + E(err) + conn.ConnInfo().RegisterDataType(pgtype.DataType{pgtype.Composite(&pgtype.Int4{}, &pgtype.Text{}), "mytype", q.ParamOIDs[0]}) + + // Adhoc rows can be decoded inplace without boilerplate + // Composite types can be encoded/decoded inplace + var isNull bool var a int var b *string - err = conn.QueryRow(context.Background(), "select (2, 'bar')::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(pgtype.ROW(&isNull, &a, &b)) + err = conn.QueryRow(context.Background(), "select row(($1::mytype).a, ($1).b)", + pgx.QueryResultFormats{pgx.BinaryFormatCode}, pgtype.Row(2, "bar")). + Scan(pgtype.Row(&isNull, &a, &b)) E(err) fmt.Printf("Adhoc: isNull=%v a=%d b=%s", isNull, a, *b) diff --git a/convert.go b/convert.go index d22a714f..134e123d 100644 --- a/convert.go +++ b/convert.go @@ -504,34 +504,6 @@ func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err er return } -// ROW allows deconstructing row values (records and composite types) into -// fields directly without creating your own type and implementing decoder interfaces -func ROW(isNull *bool, fields ...interface{}) BinaryDecoderFunc { - return func(ci *ConnInfo, src []byte) error { - var record Record - if err := record.DecodeBinary(ci, src); err != nil { - return err - } - - if record.Status == Null { - *isNull = true - return nil - } - - if len(record.Fields) != len(fields) { - return errors.Errorf("can't scan row value, number of fields don't match: row fields count=%d desired fields count=%d", len(record.Fields), len(fields)) - } - - for i, f := range record.Fields { - if err := f.AssignTo(fields[i]); err != nil { - return err - } - } - - return nil - } -} - func init() { kindTypes = map[reflect.Kind]reflect.Type{ reflect.Bool: reflect.TypeOf(false), diff --git a/pgtype.go b/pgtype.go index 1749c8c2..e86255f4 100644 --- a/pgtype.go +++ b/pgtype.go @@ -167,6 +167,15 @@ func (f BinaryDecoderFunc) DecodeBinary(ci *ConnInfo, src []byte) error { return f(ci, src) } +//The BinaryEncoderFunc type is an adapter to allow the use of ordinary functions as BinaryDecoder types. +// If f is a function with the appropriate signature, BinaryEncoderFunc(f) is a BinaryDecoder that calls f. +type BinaryEncoderFunc func(ci *ConnInfo, buf []byte) ([]byte, error) + +// EncodeBinary calls f(ci, buf) +func (f BinaryEncoderFunc) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { + return f(ci, buf) +} + var errUndefined = errors.New("cannot encode status undefined") var errBadStatus = errors.New("invalid status") From a6747b513f7e839171908923e91ad8c13ca8c51d Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Mon, 13 Apr 2020 17:44:02 +0100 Subject: [PATCH 0438/1158] Split composite examples --- composite_test.go | 99 ++++++++------------------------------ custom_composite_test.go | 101 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 80 deletions(-) create mode 100644 custom_composite_test.go diff --git a/composite_test.go b/composite_test.go index b38cdd45..3e63151c 100644 --- a/composite_test.go +++ b/composite_test.go @@ -7,63 +7,11 @@ import ( "github.com/jackc/pgtype" pgx "github.com/jackc/pgx/v4" - errors "golang.org/x/xerrors" ) -type MyType struct { - a int32 // NULL will cause decoding error - b *string // there can be NULL in this position in SQL -} - -func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs") - } - - a := pgtype.Int4{} - b := pgtype.Text{} - - if err := pgtype.ScanRowValue(ci, src, &a, &b); err != nil { - return err - } - - // type compatibility is checked by AssignTo - // only lossless assignments will succeed - if err := a.AssignTo(&dst.a); err != nil { - return err - } - - // AssignTo also deals with null value handling - if err := b.AssignTo(&dst.b); err != nil { - return err - } - - return nil -} - -func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { - a := pgtype.Int4{src.a, pgtype.Present} - var b pgtype.Text - if src.b != nil { - b = pgtype.Text{*src.b, pgtype.Present} - } else { - b = pgtype.Text{Status: pgtype.Null} - } - - return pgtype.EncodeRow(ci, buf, &a, &b) -} - -func ptrS(s string) *string { - return &s -} - -func E(err error) { - if err != nil { - panic(err) - } -} - -func Example_compositeTypes() { +//ExampleComposite demonstrates use of Row() function to pass and receive +// back composite types without creating boilderplate custom types. +func Example_composite() { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) E(err) @@ -77,43 +25,34 @@ create type mytype as ( E(err) defer conn.Exec(context.Background(), "drop type mytype") - var result *MyType - - // Demonstrates both passing and reading back composite values - err = conn.QueryRow(context.Background(), "select $1::mytype", - pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}). - Scan(&result) - E(err) - - fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b) - - // Because we scan into &*MyType, NULLs are handled generically by assigning nil to result - err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result) - E(err) - - fmt.Printf("Second row: %v\n", result) - //WIP q, err := conn.Prepare(context.Background(), "z", "select $1::mytype") E(err) conn.ConnInfo().RegisterDataType(pgtype.DataType{pgtype.Composite(&pgtype.Int4{}, &pgtype.Text{}), "mytype", q.ParamOIDs[0]}) - // Adhoc rows can be decoded inplace without boilerplate - // Composite types can be encoded/decoded inplace - var isNull bool var a int var b *string - err = conn.QueryRow(context.Background(), "select row(($1::mytype).a, ($1).b)", - pgx.QueryResultFormats{pgx.BinaryFormatCode}, pgtype.Row(2, "bar")). + err = conn.QueryRow(context.Background(), "select $1::mytype", + pgtype.Row(2, "bar")). Scan(pgtype.Row(&isNull, &a, &b)) E(err) - fmt.Printf("Adhoc: isNull=%v a=%d b=%s", isNull, a, *b) + fmt.Printf("First: isNull=%v a=%d b=%s\n", isNull, a, *b) + + err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan(pgtype.Row(&isNull, &a, &b)) + E(err) + + fmt.Printf("Second: isNull=%v a=%d b=%v\n", isNull, a, b) + + err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(pgtype.Row(&isNull, &a, &b)) + E(err) + + fmt.Printf("Third: isNull=%v\n", isNull) // Output: - // First row: a=1 b=foo - // Second row: - // Adhoc: isNull=false a=2 b=bar + // First: isNull=false a=2 b=bar + // Second: isNull=false a=1 b= + // Third: isNull=true } diff --git a/custom_composite_test.go b/custom_composite_test.go new file mode 100644 index 00000000..61ea91c5 --- /dev/null +++ b/custom_composite_test.go @@ -0,0 +1,101 @@ +package pgtype_test + +import ( + "context" + "fmt" + "os" + + "github.com/jackc/pgtype" + pgx "github.com/jackc/pgx/v4" + errors "golang.org/x/xerrors" +) + +type MyType struct { + a int32 // NULL will cause decoding error + b *string // there can be NULL in this position in SQL +} + +func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs") + } + + a := pgtype.Int4{} + b := pgtype.Text{} + + if err := pgtype.ScanRowValue(ci, src, &a, &b); err != nil { + return err + } + + // type compatibility is checked by AssignTo + // only lossless assignments will succeed + if err := a.AssignTo(&dst.a); err != nil { + return err + } + + // AssignTo also deals with null value handling + if err := b.AssignTo(&dst.b); err != nil { + return err + } + + return nil +} + +func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { + a := pgtype.Int4{src.a, pgtype.Present} + var b pgtype.Text + if src.b != nil { + b = pgtype.Text{*src.b, pgtype.Present} + } else { + b = pgtype.Text{Status: pgtype.Null} + } + + return pgtype.EncodeRow(ci, buf, &a, &b) +} + +func ptrS(s string) *string { + return &s +} + +func E(err error) { + if err != nil { + panic(err) + } +} + +// ExampleCustomCompositeTypes demonstrates how support for custom types mappable to SQL +// composites can be added. +func Example_customCompositeTypes() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + E(err) + + defer conn.Close(context.Background()) + _, err = conn.Exec(context.Background(), `drop type if exists mytype; + +create type mytype as ( + a int4, + b text +);`) + E(err) + defer conn.Exec(context.Background(), "drop type mytype") + + var result *MyType + + // Demonstrates both passing and reading back composite values + err = conn.QueryRow(context.Background(), "select $1::mytype", + pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}). + Scan(&result) + E(err) + + fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b) + + // Because we scan into &*MyType, NULLs are handled generically by assigning nil to result + err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result) + E(err) + + fmt.Printf("Second row: %v\n", result) + + // Output: + // First row: a=1 b=foo + // Second row: +} From 2e13f2fe7691a7c99f55a85fdb2e8934da7a9582 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Thu, 16 Apr 2020 20:59:07 +0100 Subject: [PATCH 0439/1158] Move lowlevel binary routines into own package --- binary/record.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++ convert.go | 22 ++++++-------- record.go | 61 ++++--------------------------------- 3 files changed, 93 insertions(+), 68 deletions(-) create mode 100644 binary/record.go diff --git a/binary/record.go b/binary/record.go new file mode 100644 index 00000000..72b688a8 --- /dev/null +++ b/binary/record.go @@ -0,0 +1,78 @@ +package binary + +import ( + "encoding/binary" + + "github.com/jackc/pgio" + errors "golang.org/x/xerrors" +) + +type RecordFieldIter struct { + rp int + src []byte +} + +// NewRecordFieldIterator creates iterator over binary representation +// of record, aka ROW(), aka Composite +func NewRecordFieldIterator(src []byte) (RecordFieldIter, int, error) { + rp := 0 + if len(src[rp:]) < 4 { + return RecordFieldIter{}, 0, errors.Errorf("Record incomplete %v", src) + } + + fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + return RecordFieldIter{ + rp: rp, + src: src, + }, fieldCount, nil +} + +// Next returns next field decoded from record. eof is returned if no +// more fields left to decode. +func (fi *RecordFieldIter) Next() (fieldOID uint32, buf []byte, eof bool, err error) { + if fi.rp == len(fi.src) { + eof = true + return + } + + if len(fi.src[fi.rp:]) < 8 { + err = errors.Errorf("Record incomplete %v", fi.src) + return + } + fieldOID = binary.BigEndian.Uint32(fi.src[fi.rp:]) + fi.rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(fi.src[fi.rp:]))) + fi.rp += 4 + + if fieldLen >= 0 { + if len(fi.src[fi.rp:]) < fieldLen { + err = errors.Errorf("Record incomplete rp=%d src=%v", fi.rp, fi.src) + return + } + buf = fi.src[fi.rp : fi.rp+fieldLen] + fi.rp += fieldLen + } + + return +} + +// RecordStart adds record header to the buf +func RecordStart(buf []byte, fieldCount int) []byte { + return pgio.AppendUint32(buf, uint32(fieldCount)) +} + +// RecordAdd adds record field to the buf +func RecordAdd(buf []byte, oid uint32, fieldBytes []byte) []byte { + buf = pgio.AppendUint32(buf, oid) + buf = pgio.AppendUint32(buf, uint32(len(fieldBytes))) + buf = append(buf, fieldBytes...) + return buf +} + +// RecordAddNull adds null value as a field to the buf +func RecordAddNull(buf []byte, oid uint32) []byte { + return pgio.AppendInt32(buf, int32(-1)) +} diff --git a/convert.go b/convert.go index 134e123d..6d5ea0c9 100644 --- a/convert.go +++ b/convert.go @@ -5,7 +5,7 @@ import ( "reflect" "time" - "github.com/jackc/pgio" + "github.com/jackc/pgtype/binary" errors "golang.org/x/xerrors" ) @@ -442,16 +442,16 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { // Values must implement BinaryDecoder interface otherwise error is returned. // ScanRowValue takes ownership of src, caller MUST not use it after call func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error { - fieldIter, err := newFieldIterator(src) + fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) if err != nil { return err } - if len(dst) != fieldIter.fieldCount { - return errors.Errorf("can't scan row value, number of fields don't match: row fields count=%d desired fields count=%d", fieldIter.fieldCount, len(dst)) + if len(dst) != fieldCount { + return errors.Errorf("can't scan row value, number of fields don't match: row fields count=%d desired fields count=%d", fieldCount, len(dst)) } - _, fieldBytes, eof, err := fieldIter.next() + _, fieldBytes, eof, err := fieldIter.Next() for i := 0; !eof; i++ { if err != nil { return err @@ -466,7 +466,7 @@ func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error { return err } - _, fieldBytes, eof, err = fieldIter.next() + _, fieldBytes, eof, err = fieldIter.Next() } return nil @@ -476,14 +476,12 @@ func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error { func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) { fieldBytes := make([]byte, 0, 128) - newBuf = pgio.AppendUint32(buf, uint32(len(fields))) + newBuf = binary.RecordStart(buf, len(fields)) for _, f := range fields { dt, ok := ci.DataTypeForValue(f) if !ok { return nil, errors.Errorf("Unknown OID for %s", f) } - newBuf = pgio.AppendUint32(newBuf, dt.OID) - if f.Get() != nil { binaryEncoder, ok := f.(BinaryEncoder) if !ok { @@ -493,11 +491,9 @@ func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err er if err != nil { return nil, err } - - newBuf = pgio.AppendUint32(newBuf, uint32(len(fieldBytes))) - newBuf = append(newBuf, fieldBytes...) + newBuf = binary.RecordAdd(newBuf, dt.OID, fieldBytes) } else { - newBuf = pgio.AppendInt32(newBuf, int32(-1)) + newBuf = binary.RecordAddNull(newBuf, dt.OID) } } diff --git a/record.go b/record.go index 08603140..4e39f92a 100644 --- a/record.go +++ b/record.go @@ -1,9 +1,10 @@ package pgtype import ( - "encoding/binary" "reflect" + "github.com/jackc/pgtype/binary" + errors "golang.org/x/xerrors" ) @@ -78,56 +79,6 @@ func (src *Record) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } -type fieldIter struct { - rp int - fieldCount int - src []byte -} - -func newFieldIterator(src []byte) (fieldIter, error) { - rp := 0 - if len(src[rp:]) < 4 { - return fieldIter{}, errors.Errorf("Record incomplete %v", src) - } - - fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - - return fieldIter{ - rp: rp, - fieldCount: fieldCount, - src: src, - }, nil -} - -func (fi *fieldIter) next() (fieldOID uint32, buf []byte, eof bool, err error) { - if fi.rp == len(fi.src) { - eof = true - return - } - - if len(fi.src[fi.rp:]) < 8 { - err = errors.Errorf("Record incomplete %v", fi.src) - return - } - fieldOID = binary.BigEndian.Uint32(fi.src[fi.rp:]) - fi.rp += 4 - - fieldLen := int(int32(binary.BigEndian.Uint32(fi.src[fi.rp:]))) - fi.rp += 4 - - if fieldLen >= 0 { - if len(fi.src[fi.rp:]) < fieldLen { - err = errors.Errorf("Record incomplete rp=%d src=%v", fi.rp, fi.src) - return - } - buf = fi.src[fi.rp : fi.rp+fieldLen] - fi.rp += fieldLen - } - - return -} - func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDecoder, error) { var binaryDecoder BinaryDecoder @@ -153,13 +104,13 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } - fieldIter, err := newFieldIterator(src) + fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) if err != nil { return err } - fields := make([]Value, fieldIter.fieldCount) - fieldOID, fieldBytes, eof, err := fieldIter.next() + fields := make([]Value, fieldCount) + fieldOID, fieldBytes, eof, err := fieldIter.Next() for i := 0; !eof; i++ { if err != nil { @@ -175,7 +126,7 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { return err } - fieldOID, fieldBytes, eof, err = fieldIter.next() + fieldOID, fieldBytes, eof, err = fieldIter.Next() } *dst = Record{Fields: fields, Status: Present} From 54a03cb143744322b83bfc5ba36bc77cf93644a6 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Thu, 16 Apr 2020 22:24:40 +0100 Subject: [PATCH 0440/1158] Add benchmark for various composite encoder implementations ``` BenchmarkBinaryEncodingManual-12 824053234 28.9 ns/op 0 B/op 0 allocs/op BenchmarkBinaryEncodingHelper-12 76815436 314 ns/op 192 B/op 5 allocs/op BenchmarkBinaryEncodingRow-12 65302958 364 ns/op 192 B/op 5 allocs/op ``` --- composite_bench_test.go | 70 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 composite_bench_test.go diff --git a/composite_bench_test.go b/composite_bench_test.go new file mode 100644 index 00000000..a1eba72b --- /dev/null +++ b/composite_bench_test.go @@ -0,0 +1,70 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/binary" +) + +type MyCompositeRaw struct { + a int32 + b *string +} + +func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { + a := pgtype.Int4{src.a, pgtype.Present} + + fieldBytes := make([]byte, 0, 64) + fieldBytes, _ = a.EncodeBinary(ci, fieldBytes[:0]) + + newBuf = binary.RecordStart(buf, 2) + newBuf = binary.RecordAdd(newBuf, pgtype.Int4OID, fieldBytes) + + if src.b != nil { + fieldBytes, _ = pgtype.Text{*src.b, pgtype.Present}.EncodeBinary(ci, fieldBytes[:0]) + newBuf = binary.RecordAdd(newBuf, pgtype.TextOID, fieldBytes) + } else { + newBuf = binary.RecordAddNull(newBuf, pgtype.TextOID) + } + return +} + +var x []byte + +func BenchmarkBinaryEncodingManual(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + + b.ResetTimer() + for n := 0; n < b.N; n++ { + v := MyCompositeRaw{4, ptrS("ABCDEFG")} + buf, _ = v.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +func BenchmarkBinaryEncodingHelper(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + + b.ResetTimer() + for n := 0; n < b.N; n++ { + v := MyType{4, ptrS("ABCDEFG")} + buf, _ = v.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +func BenchmarkBinaryEncodingRow(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + + b.ResetTimer() + for n := 0; n < b.N; n++ { + c := pgtype.Composite(&pgtype.Int4{}, &pgtype.Text{}) + c.Set(pgtype.Row(2, "bar")) + buf, _ = c.EncodeBinary(ci, buf[:0]) + } + x = buf +} From b88a3e07653f3db164be10edf86edd1497bd56e7 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Sat, 18 Apr 2020 14:08:28 +0100 Subject: [PATCH 0441/1158] Tighten ScanRowValue input types ScanRowValue needs not Value, but BinaryEncoder --- convert.go | 20 ++++++++------------ record_test.go | 10 ++++++++-- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/convert.go b/convert.go index 6d5ea0c9..45f117bc 100644 --- a/convert.go +++ b/convert.go @@ -434,14 +434,15 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { return nil, false } -// ScanRowValue assigns ROW()'s fields to destination Values. -// Argument types are checked and error is returned if SQL field value -// can't be assigned to corresponding destionation Value without loss -// of information. Number of fields have to match number of destination values. +// ScanRowValue decodes ROW()'s and composite type +// from src argument using provided decoders. Decoders should match +// order and count of fields of record being decoded. +// +// In practice you can pass pgtype.Value types as decoders, as +// most of them implement BinaryDecoder interface. // -// Values must implement BinaryDecoder interface otherwise error is returned. // ScanRowValue takes ownership of src, caller MUST not use it after call -func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error { +func ScanRowValue(ci *ConnInfo, src []byte, dst ...BinaryDecoder) error { fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) if err != nil { return err @@ -457,12 +458,7 @@ func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error { return err } - binaryDecoder, ok := dst[i].(BinaryDecoder) - if !ok { - return errors.Errorf("record field doesn't implement binary decoding: %s", reflect.TypeOf(dst[i]).Name()) - } - - if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + if err = dst[i].DecodeBinary(ci, fieldBytes); err != nil { return err } diff --git a/record_test.go b/record_test.go index af2105c7..9516612e 100644 --- a/record_test.go +++ b/record_test.go @@ -93,7 +93,10 @@ func TestScanRowValue(t *testing.T) { t.Fatal(err) } t.Run(tt.sql, func(t *testing.T) { - desc := append([]pgtype.Value(nil), tt.expected.Fields...) + desc := []pgtype.BinaryDecoder{} + for _, f := range tt.expected.Fields { + desc = append(desc, f.(pgtype.BinaryDecoder)) + } var raw pgtype.GenericBinary @@ -113,7 +116,10 @@ func TestScanRowValue(t *testing.T) { } // borrow fields from a neighbor test, this makes scan always fail - desc = append([]pgtype.Value(nil), recordTests[(i+1)%len(recordTests)].expected.Fields...) + desc = desc[:0] + for _, f := range recordTests[(i+1)%len(recordTests)].expected.Fields { + desc = append(desc, f.(pgtype.BinaryDecoder)) + } if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err == nil { t.Error("Matching scan didn't fail, despite fields not mathching query result") } From 53e0f25a4e17a0bd0ad643e92d1f62e172fe6921 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Sat, 18 Apr 2020 19:29:08 +0000 Subject: [PATCH 0442/1158] Make ScanRowValue error message clearer --- convert.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert.go b/convert.go index 45f117bc..8008d677 100644 --- a/convert.go +++ b/convert.go @@ -449,7 +449,7 @@ func ScanRowValue(ci *ConnInfo, src []byte, dst ...BinaryDecoder) error { } if len(dst) != fieldCount { - return errors.Errorf("can't scan row value, number of fields don't match: row fields count=%d desired fields count=%d", fieldCount, len(dst)) + return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", fieldCount, len(dst)) } _, fieldBytes, eof, err := fieldIter.Next() From 72680d61f8072c85cb6e03ef51ac1be204736fc3 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Sun, 19 Apr 2020 11:30:21 +0000 Subject: [PATCH 0443/1158] Move value createion outside of encoding benchmark --- composite_bench_test.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/composite_bench_test.go b/composite_bench_test.go index a1eba72b..154b2e26 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -35,10 +35,10 @@ var x []byte func BenchmarkBinaryEncodingManual(b *testing.B) { buf := make([]byte, 0, 128) ci := pgtype.NewConnInfo() + v := MyCompositeRaw{4, ptrS("ABCDEFG")} b.ResetTimer() for n := 0; n < b.N; n++ { - v := MyCompositeRaw{4, ptrS("ABCDEFG")} buf, _ = v.EncodeBinary(ci, buf[:0]) } x = buf @@ -47,10 +47,10 @@ func BenchmarkBinaryEncodingManual(b *testing.B) { func BenchmarkBinaryEncodingHelper(b *testing.B) { buf := make([]byte, 0, 128) ci := pgtype.NewConnInfo() + v := MyType{4, ptrS("ABCDEFG")} b.ResetTimer() for n := 0; n < b.N; n++ { - v := MyType{4, ptrS("ABCDEFG")} buf, _ = v.EncodeBinary(ci, buf[:0]) } x = buf @@ -59,11 +59,13 @@ func BenchmarkBinaryEncodingHelper(b *testing.B) { func BenchmarkBinaryEncodingRow(b *testing.B) { buf := make([]byte, 0, 128) ci := pgtype.NewConnInfo() + f1 := 2 + f2 := ptrS("bar") b.ResetTimer() for n := 0; n < b.N; n++ { c := pgtype.Composite(&pgtype.Int4{}, &pgtype.Text{}) - c.Set(pgtype.Row(2, "bar")) + c.Set(pgtype.Row(f1, f2)) buf, _ = c.EncodeBinary(ci, buf[:0]) } x = buf From 04ff904ff59c7cdbb8bd3b7189c1b90bc02d3958 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Sun, 19 Apr 2020 15:46:10 +0000 Subject: [PATCH 0444/1158] Add binary decoding benchmarks ``` BenchmarkBinaryDecodingManual-4 10479085 106 ns/op 40 B/op 2 allocs/op BenchmarkBinaryDecodingHelpers-4 4485451 263 ns/op 64 B/op 4 allocs/op BenchmarkBinaryDecodingRow-4 1999726 587 ns/op 96 B/op 5 allocs/op ``` --- composite_bench_test.go | 89 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/composite_bench_test.go b/composite_bench_test.go index 154b2e26..30e48ae7 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -5,6 +5,7 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgtype/binary" + errors "golang.org/x/xerrors" ) type MyCompositeRaw struct { @@ -30,6 +31,45 @@ func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf return } +func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + a := pgtype.Int4{} + b := pgtype.Text{} + + fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) + if err != nil { + return err + } + + if 2 != fieldCount { + return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=2", fieldCount) + } + + _, fieldBytes, eof, err := fieldIter.Next() + if eof || err != nil { + return errors.New("Bad record") + } + if err = a.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + _, fieldBytes, eof, err = fieldIter.Next() + if eof || err != nil { + return errors.New("Bad record") + } + if err = b.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + dst.a = a.Int + if b.Status == pgtype.Present { + dst.b = &b.String + } else { + dst.b = nil + } + + return nil +} + var x []byte func BenchmarkBinaryEncodingManual(b *testing.B) { @@ -70,3 +110,52 @@ func BenchmarkBinaryEncodingRow(b *testing.B) { } x = buf } + +var dstRaw MyCompositeRaw + +func BenchmarkBinaryDecodingManual(b *testing.B) { + ci := pgtype.NewConnInfo() + buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) + dst := MyCompositeRaw{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := dst.DecodeBinary(ci, buf) + E(err) + } + dstRaw = dst +} + +var dstMyType MyType + +func BenchmarkBinaryDecodingHelpers(b *testing.B) { + ci := pgtype.NewConnInfo() + buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) + dst := MyType{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := dst.DecodeBinary(ci, buf) + E(err) + } + dstMyType = dst +} + +var gf1 int +var gf2 *string + +func BenchmarkBinaryDecodingRow(b *testing.B) { + ci := pgtype.NewConnInfo() + buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) + var isNull bool + var f1 int + var f2 *string + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := pgtype.Row(&isNull, &f1, &f2).DecodeBinary(ci, buf) + E(err) + } + gf1 = f1 + gf2 = f2 +} From e283f322e1082cff623f19ac046fe1b5ee2b81ed Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Mon, 20 Apr 2020 22:38:20 +0000 Subject: [PATCH 0445/1158] Composite().Row() helper for working with composites without registration --- composite.go | 18 ++++++++++++++++++ composite_bench_test.go | 12 ++++++++++++ 2 files changed, 30 insertions(+) diff --git a/composite.go b/composite.go index 1caa24d6..d9f47d92 100644 --- a/composite.go +++ b/composite.go @@ -92,6 +92,24 @@ func (src composite) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { return errors.New("Pass pgtype.Row() to Scan to deconstruct Composite") } +// Row method creates composite BinaryEncoder. It's main purpose +// is to build composite query argument inplace without registering +// pgtype.Composite in ConnInfo first +func (src composite) Row(values ...interface{}) BinaryEncoderFunc { + return func(ci *ConnInfo, buf []byte) ([]byte, error) { + if len(values) != len(src.fields) { + return nil, errors.Errorf("Number of fields don't match. Composite has %d fields", len(src.fields)) + } + for i, v := range values { + if err := src.fields[i].Set(v); err != nil { + return nil, err + } + } + src.status = Present + return src.EncodeBinary(ci, buf) + } +} + // DecodeBinary is called when pgtype.Row() is passed to Scan() to // deconstruct composite value func (r rowValue) DecodeBinary(ci *ConnInfo, src []byte) error { diff --git a/composite_bench_test.go b/composite_bench_test.go index 30e48ae7..67dcf1fd 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -110,6 +110,18 @@ func BenchmarkBinaryEncodingRow(b *testing.B) { } x = buf } +func BenchmarkBinaryEncodingRowInplace(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + f1 := 2 + f2 := ptrS("bar") + + b.ResetTimer() + for n := 0; n < b.N; n++ { + buf, _ = pgtype.Composite(&pgtype.Int4{}, &pgtype.Text{}).Row(f1, f2).EncodeBinary(ci, buf[:0]) + } + x = buf +} var dstRaw MyCompositeRaw From 5f0d5f42557769b5794e256eaf52566d10602b66 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Mon, 27 Apr 2020 00:40:29 +0100 Subject: [PATCH 0446/1158] Remove pgtype.Row(), introduce Composite.Scan() pgtype.Row() was optimized for a single line use without much ceremony at a cost of OID registration, which is cumbersome. In practice it so much incovnenience to create new Composite just before making a query. So now there is just a Composite type and 2 helper methods: - SetFields sets composite fields to values passed. This assignment fails if types passed are not assignable to Values pgtype is made of. - Scan acts exactly like query.Scan, but for a composite value. Passed values are set to values from SQL composite. --- composite.go | 195 +++++++++++++++++++++------------------- composite_bench_test.go | 22 ++--- composite_test.go | 17 ++-- 3 files changed, 114 insertions(+), 120 deletions(-) diff --git a/composite.go b/composite.go index d9f47d92..61034262 100644 --- a/composite.go +++ b/composite.go @@ -1,146 +1,153 @@ package pgtype import ( + "github.com/jackc/pgtype/binary" errors "golang.org/x/xerrors" ) -type composite struct { +type Composite struct { fields []Value - status Status + Status Status } -// helper struct to act both as a scanning target and query argument -type rowValue struct { - args []interface{} +// NewComposite creates a Composite object, which acts as a "schema" for +// SQL composite values. +// To pass Composite as SQL parameter first set it's fields, either by +// passing initialized Value{} instances to NewComposite or by calling +// SetFields method +// To read composite fields back pass result of Scan() method +// to query Scan function. +func NewComposite(fields ...Value) *Composite { + return &Composite{fields, Present} } -// Row helper function builds a value which can be both used to -// "assemble" composite quiery arguments and to scan results back. -// -// When passed as an argument to query, values from Row args will -// be assigned to corresponding fields in a composite type and a single -// composite type will be passed to the PostgreSQL. Composite type need -// to be registered in ConnInfo first. This is required so that pgx -// can know which SQL types to use when constructing SQL composite argument -// -// When passed to Scan individual fields from composite query result -// are assigned to corresponding Row arguments. First argument MUST -// be of type *bool to flag when NULL value received. So total number -// of Row arguments, when passed to Scan should be number of composite -// fields you expect to read + 1 -func Row(fields ...interface{}) rowValue { - return rowValue{fields} -} - -// Composite types is meant to be passed to ConnInfo.RegisterDataType only, -// so it is made private on purpose. Once registered, it allows Row -// function to correctly pass query arguments. -func Composite(fields ...Value) *composite { - return &composite{fields, Undefined} -} - -func (src composite) Get() interface{} { - switch src.status { +func (src Composite) Get() interface{} { + switch src.Status { case Present: return src case Null: return nil default: - return src.status + return src.Status } } // Set is called internally when passing query arguments. -// Only valid src is a result of pgtype.Row() or nil -func (dst *composite) Set(src interface{}) error { +func (dst *Composite) Set(src interface{}) error { if src == nil { - *dst = composite{status: Null} + *dst = Composite{Status: Null} return nil } switch value := src.(type) { - case rowValue: - if len(value.args) != len(dst.fields) { + case []Value: + if len(value) != len(dst.fields) { return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields)) } - for i, v := range value.args { + for i, v := range value { if err := dst.fields[i].Set(v); err != nil { return err } } - dst.status = Present + dst.Status = Present default: - return errors.Errorf("Use pgtype.Row() as query parameter") + return errors.Errorf("Can not convert %v to Composite", src) } return nil } -// AssignTo is never called on composite value directly, it is here -// to satisfy Valuer interface -func (src composite) AssignTo(dst interface{}) error { - return errors.New("BUG: should never be called, because pgtype.composite doesn't support decoding") +// AssignTo should never be called on composite value directly +func (src Composite) AssignTo(dst interface{}) error { + return errors.New("Pass Composite.Scan() to deconstruct composite") } -func (src composite) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { +func (src Composite) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } return EncodeRow(ci, buf, src.fields...) } -// DecodeBinary here is just to make pgx use binary result format by default. -// Users should be using Row function or their own types to scan composites -func (src composite) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { - return errors.New("Pass pgtype.Row() to Scan to deconstruct Composite") -} - -// Row method creates composite BinaryEncoder. It's main purpose -// is to build composite query argument inplace without registering -// pgtype.Composite in ConnInfo first -func (src composite) Row(values ...interface{}) BinaryEncoderFunc { - return func(ci *ConnInfo, buf []byte) ([]byte, error) { - if len(values) != len(src.fields) { - return nil, errors.Errorf("Number of fields don't match. Composite has %d fields", len(src.fields)) - } - for i, v := range values { - if err := src.fields[i].Set(v); err != nil { - return nil, err - } - } - src.status = Present - return src.EncodeBinary(ci, buf) - } -} - -// DecodeBinary is called when pgtype.Row() is passed to Scan() to -// deconstruct composite value -func (r rowValue) DecodeBinary(ci *ConnInfo, src []byte) error { - if len(r.args) == 0 { - return errors.New("pgtype.Row must have 'isNull *bool' as a first argument when used in Scan") - } - - isNull, ok := r.args[0].(*bool) - if !ok { - return errors.New("pgtype.Row must have 'isNull *bool' as a first argument when used in Scan") - } - args := r.args[1:] - - var record Record - if err := record.DecodeBinary(ci, src); err != nil { - return err - } - - if record.Status == Null { - *isNull = true +// DecodeBinary implements BinaryDecoder interface. +// Opposite to Record, fields in a composite act as a "schema" +// and decoding fails if SQL value can't be assigned due to +// type mismatch +func (dst *Composite) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { + if buf == nil { + dst.Status = Null return nil } - if len(record.Fields) != len(args) { - return errors.Errorf("SQL composite can't be read, 'pgtype.Row' has wrong field cout. %d != %d", len(record.Fields), len(args)) + fieldIter, fieldCount, err := binary.NewRecordFieldIterator(buf) + if err != nil { + return err + } else if len(dst.fields) != fieldCount { + return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(dst.fields), fieldCount) } - for i, f := range record.Fields { - if err := f.AssignTo(args[i]); err != nil { + _, fieldBytes, eof, err := fieldIter.Next() + + for i := 0; !eof; i++ { + if err != nil { + return err + } + + binaryDecoder, ok := dst.fields[i].(BinaryDecoder) + if !ok { + return errors.New("Composite field doesn't support binary protocol") + } + + if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + _, fieldBytes, eof, err = fieldIter.Next() + } + dst.Status = Present + + return nil +} + +// Scan is a helper function to perform "nested" scan of +// a composite value when scanning a query result row. +// isNull is set if scanned value is NULL +// Rest of arguments are set in the order of fields in the composite +// +// Use of Scan method doesn't modify original composite +func (src Composite) Scan(isNull *bool, dst ...interface{}) BinaryDecoderFunc { + return func(ci *ConnInfo, buf []byte) error { + if err := src.DecodeBinary(ci, buf); err != nil { + return err + } + + if src.Status == Null { + *isNull = true + return nil + } + + for i, f := range src.fields { + if err := f.AssignTo(dst[i]); err != nil { + return err + } + } + return nil + } +} + +// SetFields sets Composite's fields to corresponding values +func (dst *Composite) SetFields(values ...interface{}) error { + if len(values) != len(dst.fields) { + return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields)) + } + for i, v := range values { + if err := dst.fields[i].Set(v); err != nil { return err } } + dst.Status = Present return nil } diff --git a/composite_bench_test.go b/composite_bench_test.go index 67dcf1fd..323c3179 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -101,27 +101,15 @@ func BenchmarkBinaryEncodingRow(b *testing.B) { ci := pgtype.NewConnInfo() f1 := 2 f2 := ptrS("bar") + c := pgtype.NewComposite(&pgtype.Int4{}, &pgtype.Text{}) b.ResetTimer() for n := 0; n < b.N; n++ { - c := pgtype.Composite(&pgtype.Int4{}, &pgtype.Text{}) - c.Set(pgtype.Row(f1, f2)) + c.SetFields(f1, f2) buf, _ = c.EncodeBinary(ci, buf[:0]) } x = buf } -func BenchmarkBinaryEncodingRowInplace(b *testing.B) { - buf := make([]byte, 0, 128) - ci := pgtype.NewConnInfo() - f1 := 2 - f2 := ptrS("bar") - - b.ResetTimer() - for n := 0; n < b.N; n++ { - buf, _ = pgtype.Composite(&pgtype.Int4{}, &pgtype.Text{}).Row(f1, f2).EncodeBinary(ci, buf[:0]) - } - x = buf -} var dstRaw MyCompositeRaw @@ -156,16 +144,18 @@ func BenchmarkBinaryDecodingHelpers(b *testing.B) { var gf1 int var gf2 *string -func BenchmarkBinaryDecodingRow(b *testing.B) { +func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { ci := pgtype.NewConnInfo() buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) var isNull bool var f1 int var f2 *string + c := pgtype.NewComposite(&pgtype.Int4{}, &pgtype.Text{}) + b.ResetTimer() for n := 0; n < b.N; n++ { - err := pgtype.Row(&isNull, &f1, &f2).DecodeBinary(ci, buf) + err := c.Scan(&isNull, &f1, &f2).DecodeBinary(ci, buf) E(err) } gf1 = f1 diff --git a/composite_test.go b/composite_test.go index 3e63151c..666de054 100644 --- a/composite_test.go +++ b/composite_test.go @@ -25,28 +25,25 @@ create type mytype as ( E(err) defer conn.Exec(context.Background(), "drop type mytype") - //WIP - q, err := conn.Prepare(context.Background(), "z", "select $1::mytype") - E(err) - conn.ConnInfo().RegisterDataType(pgtype.DataType{pgtype.Composite(&pgtype.Int4{}, &pgtype.Text{}), "mytype", q.ParamOIDs[0]}) - var isNull bool var a int var b *string - err = conn.QueryRow(context.Background(), "select $1::mytype", - pgtype.Row(2, "bar")). - Scan(pgtype.Row(&isNull, &a, &b)) + c := pgtype.NewComposite(&pgtype.Int4{}, &pgtype.Text{}) + c.SetFields(2, "bar") + + err = conn.QueryRow(context.Background(), "select $1::mytype", c). + Scan(c.Scan(&isNull, &a, &b)) E(err) fmt.Printf("First: isNull=%v a=%d b=%s\n", isNull, a, *b) - err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan(pgtype.Row(&isNull, &a, &b)) + err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan(c.Scan(&isNull, &a, &b)) E(err) fmt.Printf("Second: isNull=%v a=%d b=%v\n", isNull, a, b) - err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(pgtype.Row(&isNull, &a, &b)) + err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(c.Scan(&isNull, &a, &b)) E(err) fmt.Printf("Third: isNull=%v\n", isNull) From 8f3f335b0f54dcf33019f333c9bcf7d78e2fb0ba Mon Sep 17 00:00:00 2001 From: Tobias Salzmann <796084+Eun@users.noreply.github.com> Date: Thu, 30 Apr 2020 11:22:43 +0200 Subject: [PATCH 0447/1158] concludeCommand should not throw away fieldDescriptions --- pgconn.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index d5a424ac..e518744a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1412,7 +1412,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.RowDescription: rr.fieldDescriptions = msg.Fields case *pgproto3.CommandComplete: - rr.concludeCommand(CommandTag(msg.CommandTag), nil) + rr.concludeCommand(CommandTa/g(msg.CommandTag), nil) case *pgproto3.EmptyQueryResponse: rr.concludeCommand(nil, nil) case *pgproto3.ErrorResponse: @@ -1429,7 +1429,6 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { rr.commandTag = commandTag rr.err = err - rr.fieldDescriptions = nil rr.rowValues = nil rr.commandConcluded = true } From 8d9293e1e7bebc0adf7bbca40fdf5579bfa8b5e9 Mon Sep 17 00:00:00 2001 From: Tobias Salzmann <796084+Eun@users.noreply.github.com> Date: Thu, 30 Apr 2020 11:27:01 +0200 Subject: [PATCH 0448/1158] Update pgconn.go --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index e518744a..4ff3c706 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1412,7 +1412,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.RowDescription: rr.fieldDescriptions = msg.Fields case *pgproto3.CommandComplete: - rr.concludeCommand(CommandTa/g(msg.CommandTag), nil) + rr.concludeCommand(CommandTag(msg.CommandTag), nil) case *pgproto3.EmptyQueryResponse: rr.concludeCommand(nil, nil) case *pgproto3.ErrorResponse: From 700df0d05a8316577c60d3120b6f4f41895fc522 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Fri, 1 May 2020 23:35:58 +0100 Subject: [PATCH 0449/1158] Request binary format in Composite tests --- .vscode/settings.json | 6 ++++++ composite_test.go | 8 +++++--- 2 files changed, 11 insertions(+), 3 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..a32b4d68 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "go.inferGopath": false, + "go.testEnvVars": { + "PGX_TEST_DATABASE": "user=postgres database=pgx_test host=127.0.0.1" + }, +} \ No newline at end of file diff --git a/composite_test.go b/composite_test.go index 666de054..ac0eb4d0 100644 --- a/composite_test.go +++ b/composite_test.go @@ -25,6 +25,8 @@ create type mytype as ( E(err) defer conn.Exec(context.Background(), "drop type mytype") + qrf := pgx.QueryResultFormats{pgx.BinaryFormatCode} + var isNull bool var a int var b *string @@ -32,18 +34,18 @@ create type mytype as ( c := pgtype.NewComposite(&pgtype.Int4{}, &pgtype.Text{}) c.SetFields(2, "bar") - err = conn.QueryRow(context.Background(), "select $1::mytype", c). + err = conn.QueryRow(context.Background(), "select $1::mytype", qrf, c). Scan(c.Scan(&isNull, &a, &b)) E(err) fmt.Printf("First: isNull=%v a=%d b=%s\n", isNull, a, *b) - err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan(c.Scan(&isNull, &a, &b)) + err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype", qrf).Scan(c.Scan(&isNull, &a, &b)) E(err) fmt.Printf("Second: isNull=%v a=%d b=%v\n", isNull, a, b) - err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(c.Scan(&isNull, &a, &b)) + err = conn.QueryRow(context.Background(), "select NULL::mytype", qrf).Scan(c.Scan(&isNull, &a, &b)) E(err) fmt.Printf("Third: isNull=%v\n", isNull) From 63c5d350a366a7d538ae5815b352f828134636d8 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Sat, 2 May 2020 10:54:19 +0100 Subject: [PATCH 0450/1158] Add JSON benchmarks --- composite_bench_test.go | 51 +++++++++++++++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/composite_bench_test.go b/composite_bench_test.go index 323c3179..429ce9b3 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -9,12 +9,12 @@ import ( ) type MyCompositeRaw struct { - a int32 - b *string + A int32 + B *string } func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { - a := pgtype.Int4{src.a, pgtype.Present} + a := pgtype.Int4{src.A, pgtype.Present} fieldBytes := make([]byte, 0, 64) fieldBytes, _ = a.EncodeBinary(ci, fieldBytes[:0]) @@ -22,8 +22,8 @@ func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf newBuf = binary.RecordStart(buf, 2) newBuf = binary.RecordAdd(newBuf, pgtype.Int4OID, fieldBytes) - if src.b != nil { - fieldBytes, _ = pgtype.Text{*src.b, pgtype.Present}.EncodeBinary(ci, fieldBytes[:0]) + if src.B != nil { + fieldBytes, _ = pgtype.Text{*src.B, pgtype.Present}.EncodeBinary(ci, fieldBytes[:0]) newBuf = binary.RecordAdd(newBuf, pgtype.TextOID, fieldBytes) } else { newBuf = binary.RecordAddNull(newBuf, pgtype.TextOID) @@ -60,11 +60,11 @@ func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - dst.a = a.Int + dst.A = a.Int if b.Status == pgtype.Present { - dst.b = &b.String + dst.B = &b.String } else { - dst.b = nil + dst.B = nil } return nil @@ -96,7 +96,7 @@ func BenchmarkBinaryEncodingHelper(b *testing.B) { x = buf } -func BenchmarkBinaryEncodingRow(b *testing.B) { +func BenchmarkBinaryEncodingComposite(b *testing.B) { buf := make([]byte, 0, 128) ci := pgtype.NewConnInfo() f1 := 2 @@ -111,6 +111,20 @@ func BenchmarkBinaryEncodingRow(b *testing.B) { x = buf } +func BenchmarkBinaryEncodingJSON(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + v := MyCompositeRaw{4, ptrS("ABCDEFG")} + j := pgtype.JSON{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + j.Set(v) + buf, _ = j.EncodeBinary(ci, buf[:0]) + } + x = buf +} + var dstRaw MyCompositeRaw func BenchmarkBinaryDecodingManual(b *testing.B) { @@ -161,3 +175,22 @@ func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { gf1 = f1 gf2 = f2 } + +func BenchmarkBinaryDecodingJSON(b *testing.B) { + ci := pgtype.NewConnInfo() + j := pgtype.JSON{} + j.Set(MyCompositeRaw{4, ptrS("ABCDEFG")}) + buf, _ := j.EncodeBinary(ci, nil) + + j = pgtype.JSON{} + dst := MyCompositeRaw{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := j.DecodeBinary(ci, buf) + E(err) + err = j.AssignTo(&dst) + E(err) + } + dstRaw = dst +} From 391e1ef2ced76042fae145dc82285d98dd85d2c1 Mon Sep 17 00:00:00 2001 From: georgysavva Date: Sat, 2 May 2020 16:35:22 +0300 Subject: [PATCH 0451/1158] Parse connect timeout setting into Config. Restrict context timeout via Config.ConnectTimeout on .Connect() call. --- config.go | 41 ++++++++++-------- config_test.go | 57 +++++++++++++------------ pgconn.go | 5 +++ pgconn_test.go | 113 ++++++++++++++++++++++++++++++------------------- 4 files changed, 129 insertions(+), 87 deletions(-) diff --git a/config.go b/config.go index 06184b02..4f23f7c2 100644 --- a/config.go +++ b/config.go @@ -30,16 +30,17 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and // then it can be modified. A manually initialized Config will cause ConnectConfig to panic. type Config struct { - Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) - Port uint16 - Database string - User string - Password string - TLSConfig *tls.Config // nil disables TLS - DialFunc DialFunc // e.g. net.Dialer.DialContext - LookupFunc LookupFunc // e.g. net.Resolver.LookupHost - BuildFrontend BuildFrontendFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + ConnectTimeout time.Duration + DialFunc DialFunc // e.g. net.Dialer.DialContext + LookupFunc LookupFunc // e.g. net.Resolver.LookupHost + BuildFrontend BuildFrontendFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) Fallbacks []*FallbackConfig @@ -191,12 +192,13 @@ func ParseConfig(connString string) (*Config, error) { BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)), } - if connectTimeout, present := settings["connect_timeout"]; present { - dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) + if connectTimeoutSetting, present := settings["connect_timeout"]; present { + connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) if err != nil { return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} } - config.DialFunc = dialFunc + config.ConnectTimeout = connectTimeout + config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) } else { defaultDialer := makeDefaultDialer() config.DialFunc = defaultDialer.DialContext @@ -672,18 +674,21 @@ func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { } } -func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { +func parseConnectTimeoutSetting(s string) (time.Duration, error) { timeout, err := strconv.ParseInt(s, 10, 64) if err != nil { - return nil, err + return 0, err } if timeout < 0 { - return nil, errors.New("negative timeout") + return 0, errors.New("negative timeout") } + return time.Duration(timeout) * time.Second, nil +} +func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { d := makeDefaultDialer() - d.Timeout = time.Duration(timeout) * time.Second - return d.DialContext, nil + d.Timeout = timeout + return d.DialContext } // ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible diff --git a/config_test.go b/config_test.go index b6068cc8..35f6899e 100644 --- a/config_test.go +++ b/config_test.go @@ -7,6 +7,7 @@ import ( "os" "os/user" "testing" + "time" "github.com/jackc/pgconn" "github.com/stretchr/testify/assert" @@ -127,11 +128,11 @@ func TestParseConfig(t *testing.T) { name: "sslmode verify-ca", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, @@ -153,14 +154,15 @@ func TestParseConfig(t *testing.T) { }, { name: "database url everything", - connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + ConnectTimeout: 5 * time.Second, RuntimeParams: map[string]string{ "application_name": "pgxtest", "search_path": "myschema", @@ -230,14 +232,15 @@ func TestParseConfig(t *testing.T) { }, { name: "DSN everything", - connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema", + connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + ConnectTimeout: 5 * time.Second, RuntimeParams: map[string]string{ "application_name": "pgxtest", "search_path": "myschema", @@ -501,6 +504,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) assert.Equalf(t, expected.User, actual.User, "%s - User", testName) assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) + assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) // Can't test function equality, so just test that they are set or not. @@ -590,13 +594,14 @@ func TestParseConfigEnvLibpq(t *testing.T) { "PGAPPNAME": "pgxtest", }, config: &pgconn.Config{ - Host: "123.123.123.123", - Port: 7777, - Database: "foo", - User: "bar", - Password: "baz", - TLSConfig: nil, - RuntimeParams: map[string]string{"application_name": "pgxtest"}, + Host: "123.123.123.123", + Port: 7777, + Database: "foo", + User: "bar", + Password: "baz", + ConnectTimeout: 10 * time.Second, + TLSConfig: nil, + RuntimeParams: map[string]string{"application_name": "pgxtest"}, }, }, } diff --git a/pgconn.go b/pgconn.go index d5a424ac..932984c8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -116,6 +116,11 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err panic("config must be created by ParseConfig") } + if config.ConnectTimeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, config.ConnectTimeout) + defer cancel() + } // Simplify usage by treating primary config and fallbacks the same. fallbackConfigs := []*FallbackConfig{ { diff --git a/pgconn_test.go b/pgconn_test.go index e29a36b2..2f7974ea 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -6,6 +6,7 @@ import ( "context" "crypto/tls" "fmt" + "github.com/jackc/pgmock" "io" "io/ioutil" "log" @@ -18,7 +19,6 @@ import ( "time" "github.com/jackc/pgconn" - "github.com/jackc/pgmock" "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" @@ -81,58 +81,85 @@ func (s pgmockWaitStep) Step(*pgproto3.Backend) error { return nil } -func TestConnectWithContextThatTimesOut(t *testing.T) { +func TestConnectTimeout(t *testing.T) { t.Parallel() - - script := &pgmock.Script{ - Steps: []pgmock.Step{ - pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), - pgmock.SendMessage(&pgproto3.AuthenticationOk{}), - pgmockWaitStep(time.Millisecond * 500), - pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), - pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + tests := []struct { + name string + connect func(connStr string) error + }{ + { + name: "via context that times out", + connect: func(connStr string) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + _, err := pgconn.Connect(ctx, connStr) + return err + }, + }, + { + name: "via config ConnectTimeout", + connect: func(connStr string) error { + conf, err := pgconn.ParseConfig(connStr) + require.NoError(t, err) + conf.ConnectTimeout = time.Microsecond * 50 + _, err = pgconn.ConnectConfig(context.Background(), conf) + return err + }, }, } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + script := &pgmock.Script{ + Steps: []pgmock.Step{ + pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + pgmock.SendMessage(&pgproto3.AuthenticationOk{}), + pgmockWaitStep(time.Millisecond * 500), + pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + }, + } - ln, err := net.Listen("tcp", "127.0.0.1:") - require.NoError(t, err) - defer ln.Close() + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() - serverErrChan := make(chan error, 1) - go func() { - defer close(serverErrChan) + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) - conn, err := ln.Accept() - if err != nil { - serverErrChan <- err - return - } - defer conn.Close() + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() - err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450)) - if err != nil { - serverErrChan <- err - return - } + err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450)) + if err != nil { + serverErrChan <- err + return + } - err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) - if err != nil { - serverErrChan <- err - return - } - }() + err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + if err != nil { + serverErrChan <- err + return + } + }() - parts := strings.Split(ln.Addr().String(), ":") - host := parts[0] - port := parts[1] - connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) - tooLate := time.Now().Add(time.Millisecond * 500) + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + tooLate := time.Now().Add(time.Millisecond * 500) - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) - defer cancel() - _, err = pgconn.Connect(ctx, connStr) - require.True(t, pgconn.Timeout(err), err) - require.True(t, time.Now().Before(tooLate)) + err = tt.connect(connStr) + require.True(t, pgconn.Timeout(err), err) + require.True(t, time.Now().Before(tooLate)) + }) + } } func TestConnectInvalidUser(t *testing.T) { From 2d5a17beab6e8f40c60b56efe0a92e9528f2a424 Mon Sep 17 00:00:00 2001 From: georgysavva Date: Sat, 2 May 2020 16:39:51 +0300 Subject: [PATCH 0452/1158] Add comment. --- pgconn.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pgconn.go b/pgconn.go index 932984c8..69f42621 100644 --- a/pgconn.go +++ b/pgconn.go @@ -116,6 +116,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err panic("config must be created by ParseConfig") } + // ConnectTimeout restricts the whole connection process. if config.ConnectTimeout != 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, config.ConnectTimeout) From 01a7510ae90d37ffbee1438612f340e1f988bb17 Mon Sep 17 00:00:00 2001 From: georgysavva Date: Sat, 2 May 2020 16:43:02 +0300 Subject: [PATCH 0453/1158] Reformat imports --- pgconn_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index 2f7974ea..9a75dede 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -6,7 +6,6 @@ import ( "context" "crypto/tls" "fmt" - "github.com/jackc/pgmock" "io" "io/ioutil" "log" @@ -18,6 +17,8 @@ import ( "testing" "time" + "github.com/jackc/pgmock" + "github.com/jackc/pgconn" "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" From e6c6de9494c92cea58338aa6f3aa5bae7b7492dc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 May 2020 11:34:14 -0500 Subject: [PATCH 0454/1158] Improved ext/shopspring-numeric binary decoding performance Before: BenchmarkDecode/Zero-Binary-16 3944304 292 ns/op 128 B/op 7 allocs/op BenchmarkDecode/Small-Binary-16 2034132 585 ns/op 184 B/op 13 allocs/op BenchmarkDecode/Medium-Binary-16 1747191 690 ns/op 192 B/op 12 allocs/op BenchmarkDecode/Large-Binary-16 1334006 899 ns/op 304 B/op 14 allocs/op BenchmarkDecode/Huge-Binary-16 702382 1590 ns/op 584 B/op 18 allocs/op After: BenchmarkDecode/Zero-Binary-16 14592645 80.1 ns/op 64 B/op 2 allocs/op BenchmarkDecode/Small-Binary-16 5729318 212 ns/op 104 B/op 7 allocs/op BenchmarkDecode/Medium-Binary-16 4930009 241 ns/op 88 B/op 5 allocs/op BenchmarkDecode/Large-Binary-16 3369573 344 ns/op 144 B/op 7 allocs/op BenchmarkDecode/Huge-Binary-16 2587156 453 ns/op 216 B/op 9 allocs/op --- ext/shopspring-numeric/decimal.go | 12 +------ ext/shopspring-numeric/decimal_test.go | 44 ++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index 70906806..148589a4 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -250,17 +250,7 @@ func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - buf, err := num.EncodeText(ci, nil) - if err != nil { - return err - } - - dec, err := decimal.NewFromString(string(buf)) - if err != nil { - return err - } - - *dst = Numeric{Decimal: dec, Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.NewFromBigInt(num.Int, num.Exp), Status: pgtype.Present} return nil } diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index 0b256b37..bf34e0dd 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -11,6 +11,7 @@ import ( shopspring "github.com/jackc/pgtype/ext/shopspring-numeric" "github.com/jackc/pgtype/testutil" "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" ) func mustParseDecimal(t *testing.T, src string) decimal.Decimal { @@ -284,3 +285,46 @@ func TestNumericAssignTo(t *testing.T) { } } } + +func BenchmarkDecode(b *testing.B) { + benchmarks := []struct { + name string + numberStr string + }{ + {"Zero", "0"}, + {"Small", "12345"}, + {"Medium", "12345.12345"}, + {"Large", "123457890.1234567890"}, + {"Huge", "123457890123457890123457890.1234567890123457890123457890"}, + } + + for _, bm := range benchmarks { + src := &shopspring.Numeric{} + err := src.Set(bm.numberStr) + require.NoError(b, err) + textFormat, err := src.EncodeText(nil, nil) + require.NoError(b, err) + binaryFormat, err := src.EncodeBinary(nil, nil) + require.NoError(b, err) + + b.Run(fmt.Sprintf("%s-Text", bm.name), func(b *testing.B) { + dst := &shopspring.Numeric{} + for i := 0; i < b.N; i++ { + err := dst.DecodeText(nil, textFormat) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run(fmt.Sprintf("%s-Binary", bm.name), func(b *testing.B) { + dst := &shopspring.Numeric{} + for i := 0; i < b.N; i++ { + err := dst.DecodeBinary(nil, binaryFormat) + if err != nil { + b.Fatal(err) + } + } + }) + } +} From a4dd4af7568f2601f86ae3d6c09614b6056c378d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 May 2020 17:31:34 -0500 Subject: [PATCH 0455/1158] Add benchmarks for scan into native type vs decoder --- pgtype_test.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/pgtype_test.go b/pgtype_test.go index 9602f419..dee5377d 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -127,3 +127,37 @@ func TestConnInfoScanUnknownOIDToCustomType(t *testing.T) { assert.NoError(t, err) assert.Nil(t, pCt) } + +func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v pgtype.Int4 + + for i := 0; i < b.N; i++ { + v = pgtype.Int4{} + err := ci.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != (pgtype.Int4{Int: 42, Status: pgtype.Present}) { + b.Fatal("scan failed due to bad value") + } + } +} + +func BenchmarkConnInfoScanInt4IntoGoInt32(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v int32 + + for i := 0; i < b.N; i++ { + v = 0 + err := ci.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != 42 { + b.Fatal("scan failed due to bad value") + } + } +} From 6357d3b3f3522cb2e39d2bae8cf41a3ae31c1a34 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 May 2020 17:31:53 -0500 Subject: [PATCH 0456/1158] Avoid extra type assertion on native type Scan path Before: BenchmarkConnInfoScanInt4IntoBinaryDecoder-16 89744814 12.5 ns/op 0 B/op 0 allocs/op BenchmarkConnInfoScanInt4IntoGoInt32-16 27688370 41.1 ns/op 0 B/op 0 allocs/op After: BenchmarkConnInfoScanInt4IntoBinaryDecoder-16 88181061 12.4 ns/op 0 B/op 0 allocs/op BenchmarkConnInfoScanInt4IntoGoInt32-16 30402768 36.8 ns/op 0 B/op 0 allocs/op --- pgtype.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/pgtype.go b/pgtype.go index 914e02d2..f7dc1379 100644 --- a/pgtype.go +++ b/pgtype.go @@ -337,12 +337,17 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { } func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interface{}) error { - if dest, ok := dest.(BinaryDecoder); ok && formatCode == BinaryFormatCode { - return dest.DecodeBinary(ci, buf) - } - - if dest, ok := dest.(TextDecoder); ok && formatCode == TextFormatCode { - return dest.DecodeText(ci, buf) + switch formatCode { + case BinaryFormatCode: + if dest, ok := dest.(BinaryDecoder); ok { + return dest.DecodeBinary(ci, buf) + } + case TextFormatCode: + if dest, ok := dest.(TextDecoder); ok { + return dest.DecodeText(ci, buf) + } + default: + return errors.Errorf("unknown format code: %v", formatCode) } if dt, ok := ci.DataTypeForOID(oid); ok { @@ -366,8 +371,6 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac } else { return errors.Errorf("%T is not a pgtype.BinaryDecoder", value) } - default: - return errors.Errorf("unknown format code: %v", formatCode) } if scanner, ok := dest.(sql.Scanner); ok { From 18c64dceeee5aa96300d71c5260ba08bbdef9643 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 May 2020 20:18:51 -0500 Subject: [PATCH 0457/1158] ConnInfo Scan optimizes common native types This comes at a small expense to scanning into a type that implements TextDecoder or BinaryDecoder but I think it is a good trade. Before: BenchmarkConnInfoScanInt4IntoBinaryDecoder-16 88181061 12.4 ns/op 0 B/op 0 allocs/op BenchmarkConnInfoScanInt4IntoGoInt32-16 30402768 36.8 ns/op 0 B/op 0 allocs/op After: BenchmarkConnInfoScanInt4IntoBinaryDecoder-16 79859755 14.6 ns/op 0 B/op 0 allocs/op BenchmarkConnInfoScanInt4IntoGoInt32-16 38969991 30.0 ns/op 0 B/op 0 allocs/op --- pgtype.go | 61 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/pgtype.go b/pgtype.go index f7dc1379..f6c354ef 100644 --- a/pgtype.go +++ b/pgtype.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql" "reflect" + "time" errors "golang.org/x/xerrors" ) @@ -337,17 +338,39 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { } func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interface{}) error { - switch formatCode { - case BinaryFormatCode: - if dest, ok := dest.(BinaryDecoder); ok { - return dest.DecodeBinary(ci, buf) + isFastType := false + switch dest.(type) { + case *int16: + isFastType = true + case *int32: + isFastType = true + case *int64: + isFastType = true + case *float32: + isFastType = true + case *float64: + isFastType = true + case *string: + isFastType = true + case *time.Time: + isFastType = true + case *[]byte: + isFastType = true + } + + if !isFastType { + switch formatCode { + case BinaryFormatCode: + if dest, ok := dest.(BinaryDecoder); ok { + return dest.DecodeBinary(ci, buf) + } + case TextFormatCode: + if dest, ok := dest.(TextDecoder); ok { + return dest.DecodeText(ci, buf) + } + default: + return errors.Errorf("unknown format code: %v", formatCode) } - case TextFormatCode: - if dest, ok := dest.(TextDecoder); ok { - return dest.DecodeText(ci, buf) - } - default: - return errors.Errorf("unknown format code: %v", formatCode) } if dt, ok := ci.DataTypeForOID(oid); ok { @@ -371,17 +394,21 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac } else { return errors.Errorf("%T is not a pgtype.BinaryDecoder", value) } + default: + return errors.Errorf("unknown format code: %v", formatCode) } - if scanner, ok := dest.(sql.Scanner); ok { - sqlSrc, err := DatabaseSQLValue(ci, value) - if err != nil { - return err + if !isFastType { + if scanner, ok := dest.(sql.Scanner); ok { + sqlSrc, err := DatabaseSQLValue(ci, value) + if err != nil { + return err + } + return scanner.Scan(sqlSrc) } - return scanner.Scan(sqlSrc) - } else { - return value.AssignTo(dest) } + + return value.AssignTo(dest) } // We might be given a pointer to something that implements the decoder interface(s), From ab5e59782619eaf662b07c33d257c4c136ad5034 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 May 2020 20:30:58 -0500 Subject: [PATCH 0458/1158] Avoid type assertion in Scan Before: BenchmarkConnInfoScanInt4IntoBinaryDecoder-16 79859755 14.6 ns/op 0 B/op 0 allocs/op BenchmarkConnInfoScanInt4IntoGoInt32-16 38969991 30.0 ns/op 0 B/op 0 allocs/op After: BenchmarkConnInfoScanInt4IntoBinaryDecoder-16 458046958 13.3 ns/op 0 B/op 0 allocs/op BenchmarkConnInfoScanInt4IntoGoInt32-16 275791776 20.6 ns/op 0 B/op 0 allocs/op --- pgtype.go | 45 ++++++++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/pgtype.go b/pgtype.go index f6c354ef..bb0a99af 100644 --- a/pgtype.go +++ b/pgtype.go @@ -164,8 +164,12 @@ var errBadStatus = errors.New("invalid status") type DataType struct { Value Value - Name string - OID uint32 + + textDecoder TextDecoder + binaryDecoder BinaryDecoder + + Name string + OID uint32 } type ConnInfo struct { @@ -285,6 +289,14 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { } ci.oidToResultFormatCode[t.OID] = formatCode } + + if d, ok := t.Value.(TextDecoder); ok { + t.textDecoder = d + } + + if d, ok := t.Value.(BinaryDecoder); ok { + t.binaryDecoder = d + } } func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) { @@ -374,25 +386,24 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac } if dt, ok := ci.DataTypeForOID(oid); ok { - value := dt.Value switch formatCode { - case TextFormatCode: - if textDecoder, ok := value.(TextDecoder); ok { - err := textDecoder.DecodeText(ci, buf) - if err != nil { - return err - } - } else { - return errors.Errorf("%T is not a pgtype.TextDecoder", value) - } case BinaryFormatCode: - if binaryDecoder, ok := value.(BinaryDecoder); ok { - err := binaryDecoder.DecodeBinary(ci, buf) + if dt.binaryDecoder != nil { + err := dt.binaryDecoder.DecodeBinary(ci, buf) if err != nil { return err } } else { - return errors.Errorf("%T is not a pgtype.BinaryDecoder", value) + return errors.Errorf("%T is not a pgtype.BinaryDecoder", dt.Value) + } + case TextFormatCode: + if dt.textDecoder != nil { + err := dt.textDecoder.DecodeText(ci, buf) + if err != nil { + return err + } + } else { + return errors.Errorf("%T is not a pgtype.TextDecoder", dt.Value) } default: return errors.Errorf("unknown format code: %v", formatCode) @@ -400,7 +411,7 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac if !isFastType { if scanner, ok := dest.(sql.Scanner); ok { - sqlSrc, err := DatabaseSQLValue(ci, value) + sqlSrc, err := DatabaseSQLValue(ci, dt.Value) if err != nil { return err } @@ -408,7 +419,7 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac } } - return value.AssignTo(dest) + return dt.Value.AssignTo(dest) } // We might be given a pointer to something that implements the decoder interface(s), From 3b7c47a2a7dac37cc998979d45fa13774c1e38e5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 5 May 2020 13:23:14 -0500 Subject: [PATCH 0459/1158] Add EnumType --- enum_type.go | 163 ++++++++++++++++++++++++++++++++++++++++++++++ enum_type_test.go | 148 +++++++++++++++++++++++++++++++++++++++++ pgtype.go | 41 +++++++++++- 3 files changed, 350 insertions(+), 2 deletions(-) create mode 100644 enum_type.go create mode 100644 enum_type_test.go diff --git a/enum_type.go b/enum_type.go new file mode 100644 index 00000000..44095cc7 --- /dev/null +++ b/enum_type.go @@ -0,0 +1,163 @@ +package pgtype + +import errors "golang.org/x/xerrors" + +// EnumType represents an enum type. In the normal pgtype model a Go type maps to a PostgreSQL type and an instance +// of a Go type maps to a PostgreSQL value of that type. EnumType is different in that an instance of EnumType +// represents a PostgreSQL type. The zero value is not usable -- NewEnumType must be used as a constructor. In general, +// an EnumType should not be used to represent a value. It should only be used as an encoder and decoder internal to +// ConnInfo. +type EnumType struct { + String string + Status Status + + pgTypeName string // PostgreSQL type name + members []string // enum members + membersMap map[string]string // map to quickly lookup member and reuse string instead of allocating +} + +// NewEnumType initializes a new EnumType. It retains a read-only reference to members. members must not be changed. +func NewEnumType(pgTypeName string, members []string) *EnumType { + et := &EnumType{pgTypeName: pgTypeName, members: members} + et.membersMap = make(map[string]string, len(members)) + for _, m := range members { + et.membersMap[m] = m + } + return et +} + +func (et *EnumType) CloneTypeValue() Value { + return &EnumType{ + String: et.String, + Status: et.Status, + + pgTypeName: et.pgTypeName, + members: et.members, + membersMap: et.membersMap, + } +} + +func (et *EnumType) PgTypeName() string { + return et.pgTypeName +} + +func (et *EnumType) Members() []string { + return et.members +} + +// Set assigns src to dst. Set purposely does not check that src is a member. This allows continued error free +// operation in the event the PostgreSQL enum type is modified during a connection. +func (dst *EnumType) Set(src interface{}) error { + if src == nil { + dst.Status = Null + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case string: + dst.String = value + dst.Status = Present + case *string: + if value == nil { + dst.Status = Null + } else { + dst.String = *value + dst.Status = Present + } + case []byte: + if value == nil { + dst.Status = Null + } else { + dst.String = string(value) + dst.Status = Present + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to enum %s", value, dst.pgTypeName) + } + + return nil +} + +func (dst EnumType) Get() interface{} { + switch dst.Status { + case Present: + return dst.String + case Null: + return nil + default: + return dst.Status + } +} + +func (src *EnumType) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *string: + *v = src.String + return nil + case *[]byte: + *v = make([]byte, len(src.String)) + copy(*v, src.String) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return errors.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *EnumType) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + dst.Status = Null + return nil + } + + // Lookup the string in membersMap to avoid an allocation. + if s, found := dst.membersMap[string(src)]; found { + dst.String = s + } else { + // If an enum type is modified after the initial connection it is possible to receive an unexpected value. + // Gracefully handle this situation. Purposely NOT modifying members and membersMap to allow for sharing members + // and membersMap between connections. + dst.String = string(src) + } + dst.Status = Present + + return nil +} + +func (dst *EnumType) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) +} + +func (src EnumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.String...), nil +} + +func (src EnumType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) +} diff --git a/enum_type_test.go b/enum_type_test.go new file mode 100644 index 00000000..4dd88f2a --- /dev/null +++ b/enum_type_test.go @@ -0,0 +1,148 @@ +package pgtype_test + +import ( + "bytes" + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupEnum(t *testing.T, conn *pgx.Conn) *pgtype.EnumType { + _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") + require.NoError(t, err) + + _, err = conn.Exec(context.Background(), "create type pgtype_enum_color as enum ('blue', 'green', 'purple');") + require.NoError(t, err) + + var oid uint32 + err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "pgtype_enum_color").Scan(&oid) + require.NoError(t, err) + + et := pgtype.NewEnumType("pgtype_enum_color", []string{"blue", "green", "purple"}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "pgtype_enum_color", OID: oid}) + + return et +} + +func cleanupEnum(t *testing.T, conn *pgx.Conn) { + _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") + require.NoError(t, err) +} + +func TestEnumTypeTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + setupEnum(t, conn) + defer cleanupEnum(t, conn) + + var dst string + err := conn.QueryRow(context.Background(), "select $1::pgtype_enum_color", "blue").Scan(&dst) + require.NoError(t, err) + require.EqualValues(t, "blue", dst) +} + +func TestEnumTypeSet(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + enumType := setupEnum(t, conn) + defer cleanupEnum(t, conn) + + successfulTests := []struct { + source interface{} + result interface{} + }{ + {source: "blue", result: "blue"}, + {source: _string("green"), result: "green"}, + {source: (*string)(nil), result: nil}, + } + + for i, tt := range successfulTests { + err := enumType.Set(tt.source) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.result, enumType.Get(), "%d", i) + } +} + +func TestEnumTypeAssignTo(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + enumType := setupEnum(t, conn) + defer cleanupEnum(t, conn) + + { + var s string + + err := enumType.Set("blue") + require.NoError(t, err) + + err = enumType.AssignTo(&s) + require.NoError(t, err) + + assert.EqualValues(t, "blue", s) + } + + { + var ps *string + + err := enumType.Set("blue") + require.NoError(t, err) + + err = enumType.AssignTo(&ps) + require.NoError(t, err) + + assert.EqualValues(t, "blue", *ps) + } + + { + var ps *string + + err := enumType.Set(nil) + require.NoError(t, err) + + err = enumType.AssignTo(&ps) + require.NoError(t, err) + + assert.EqualValues(t, (*string)(nil), ps) + } + + var buf []byte + bytesTests := []struct { + src interface{} + dst *[]byte + expected []byte + }{ + {src: "blue", dst: &buf, expected: []byte("blue")}, + {src: nil, dst: &buf, expected: nil}, + } + + for i, tt := range bytesTests { + err := enumType.Set(tt.src) + require.NoError(t, err, "%d", i) + + err = enumType.AssignTo(tt.dst) + require.NoError(t, err, "%d", i) + + if bytes.Compare(*tt.dst, tt.expected) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) + } + } + + { + var s string + + err := enumType.Set(nil) + require.NoError(t, err) + + err = enumType.AssignTo(&s) + require.Error(t, err) + } + +} diff --git a/pgtype.go b/pgtype.go index bb0a99af..997899d8 100644 --- a/pgtype.go +++ b/pgtype.go @@ -125,6 +125,22 @@ type Value interface { AssignTo(dst interface{}) error } +// TypeValue represents values where instances represent a type. In the normal pgtype model a Go type maps to a +// PostgreSQL type and an instance of a Go type maps to a PostgreSQL value of that type. Implementors of TypeValue +// are different in that an instance represents a PostgreSQL type. This can be useful for representing types such +// as enums, composites, and arrays. +// +// In general, instances of TypeValue should not be used to directly represent a value. It should only be used as an +// encoder and decoder internal to ConnInfo. +type TypeValue interface { + // CloneTypeValue duplicates a TypeValue including references to internal type information. e.g. the list of members + // in an EnumType. + CloneTypeValue() Value + + // PgTypeName returns the PostgreSQL name of this type. + PgTypeName() string +} + type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the // original SQL value is NULL. BinaryDecoder takes ownership of src. The @@ -270,9 +286,16 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) { } func (ci *ConnInfo) RegisterDataType(t DataType) { + tv, _ := t.Value.(TypeValue) + if tv != nil { + t.Value = tv.CloneTypeValue() + } + ci.oidToDataType[t.OID] = &t ci.nameToDataType[t.Name] = &t - ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t + if tv == nil { + ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t + } { var formatCode int16 @@ -310,6 +333,11 @@ func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { } func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { + if tv, ok := v.(TypeValue); ok { + dt, ok := ci.nameToDataType[tv.PgTypeName()] + return dt, ok + } + dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()] return dt, ok } @@ -336,11 +364,20 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { oidToDataType: make(map[uint32]*DataType, len(ci.oidToDataType)), nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)), reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)), + oidToParamFormatCode: make(map[uint32]int16, len(ci.oidToParamFormatCode)), + oidToResultFormatCode: make(map[uint32]int16, len(ci.oidToResultFormatCode)), } for _, dt := range ci.oidToDataType { + var value Value + if tv, ok := dt.Value.(TypeValue); ok { + value = tv.CloneTypeValue() + } else { + value = reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value) + } + ci2.RegisterDataType(DataType{ - Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value), + Value: value, Name: dt.Name, OID: dt.OID, }) From 4d2b5a18c4de39f44ed1829cf663748f7c30e5cf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 6 May 2020 09:51:41 -0500 Subject: [PATCH 0460/1158] Clarify Value.Get() documentation Specifying behavior for Status Null and Undefined is incorrect because a Value is not required to have a Status. In addition, standard behavior is to return nil, not pgtype.Null when the Status is pgtype.Null. --- pgtype.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pgtype.go b/pgtype.go index 997899d8..c002150c 100644 --- a/pgtype.go +++ b/pgtype.go @@ -115,8 +115,7 @@ type Value interface { // Set converts and assigns src to itself. Set(src interface{}) error - // Get returns the simplest representation of Value. If the Value is Null or - // Undefined that is the return value. If no simpler representation is + // Get returns the simplest representation of Value. If no simpler representation is // possible, then Get() returns Value. Get() interface{} From 2938981516bba5e0586953f9480b4b9a0a07429f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 6 May 2020 10:30:43 -0500 Subject: [PATCH 0461/1158] Make EnumType implementation private --- enum_type.go | 95 +++++++++++++++++++++++++---------------------- enum_type_test.go | 2 +- 2 files changed, 51 insertions(+), 46 deletions(-) diff --git a/enum_type.go b/enum_type.go index 44095cc7..eecd0237 100644 --- a/enum_type.go +++ b/enum_type.go @@ -2,14 +2,19 @@ package pgtype import errors "golang.org/x/xerrors" -// EnumType represents an enum type. In the normal pgtype model a Go type maps to a PostgreSQL type and an instance -// of a Go type maps to a PostgreSQL value of that type. EnumType is different in that an instance of EnumType -// represents a PostgreSQL type. The zero value is not usable -- NewEnumType must be used as a constructor. In general, -// an EnumType should not be used to represent a value. It should only be used as an encoder and decoder internal to -// ConnInfo. -type EnumType struct { - String string - Status Status +// EnumType represents a enum type. While it implements Value, this is only in service of its type conversion duties +// when registered as a data type in a ConnType. It should not be used directly as a Value. +type EnumType interface { + Value + TypeValue + + // Members returns possible members of this enumeration. The returned slice must not be modified. + Members() []string +} + +type enumType struct { + value string + status Status pgTypeName string // PostgreSQL type name members []string // enum members @@ -17,8 +22,8 @@ type EnumType struct { } // NewEnumType initializes a new EnumType. It retains a read-only reference to members. members must not be changed. -func NewEnumType(pgTypeName string, members []string) *EnumType { - et := &EnumType{pgTypeName: pgTypeName, members: members} +func NewEnumType(pgTypeName string, members []string) EnumType { + et := &enumType{pgTypeName: pgTypeName, members: members} et.membersMap = make(map[string]string, len(members)) for _, m := range members { et.membersMap[m] = m @@ -26,10 +31,10 @@ func NewEnumType(pgTypeName string, members []string) *EnumType { return et } -func (et *EnumType) CloneTypeValue() Value { - return &EnumType{ - String: et.String, - Status: et.Status, +func (et *enumType) CloneTypeValue() Value { + return &enumType{ + value: et.value, + status: et.status, pgTypeName: et.pgTypeName, members: et.members, @@ -37,19 +42,19 @@ func (et *EnumType) CloneTypeValue() Value { } } -func (et *EnumType) PgTypeName() string { +func (et *enumType) PgTypeName() string { return et.pgTypeName } -func (et *EnumType) Members() []string { +func (et *enumType) Members() []string { return et.members } // Set assigns src to dst. Set purposely does not check that src is a member. This allows continued error free // operation in the event the PostgreSQL enum type is modified during a connection. -func (dst *EnumType) Set(src interface{}) error { +func (dst *enumType) Set(src interface{}) error { if src == nil { - dst.Status = Null + dst.status = Null return nil } @@ -62,21 +67,21 @@ func (dst *EnumType) Set(src interface{}) error { switch value := src.(type) { case string: - dst.String = value - dst.Status = Present + dst.value = value + dst.status = Present case *string: if value == nil { - dst.Status = Null + dst.status = Null } else { - dst.String = *value - dst.Status = Present + dst.value = *value + dst.status = Present } case []byte: if value == nil { - dst.Status = Null + dst.status = Null } else { - dst.String = string(value) - dst.Status = Present + dst.value = string(value) + dst.status = Present } default: if originalSrc, ok := underlyingStringType(src); ok { @@ -88,27 +93,27 @@ func (dst *EnumType) Set(src interface{}) error { return nil } -func (dst EnumType) Get() interface{} { - switch dst.Status { +func (dst enumType) Get() interface{} { + switch dst.status { case Present: - return dst.String + return dst.value case Null: return nil default: - return dst.Status + return dst.status } } -func (src *EnumType) AssignTo(dst interface{}) error { - switch src.Status { +func (src *enumType) AssignTo(dst interface{}) error { + switch src.status { case Present: switch v := dst.(type) { case *string: - *v = src.String + *v = src.value return nil case *[]byte: - *v = make([]byte, len(src.String)) - copy(*v, src.String) + *v = make([]byte, len(src.value)) + copy(*v, src.value) return nil default: if nextDst, retry := GetAssignToDstType(dst); retry { @@ -123,41 +128,41 @@ func (src *EnumType) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } -func (dst *EnumType) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *enumType) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - dst.Status = Null + dst.status = Null return nil } // Lookup the string in membersMap to avoid an allocation. if s, found := dst.membersMap[string(src)]; found { - dst.String = s + dst.value = s } else { // If an enum type is modified after the initial connection it is possible to receive an unexpected value. // Gracefully handle this situation. Purposely NOT modifying members and membersMap to allow for sharing members // and membersMap between connections. - dst.String = string(src) + dst.value = string(src) } - dst.Status = Present + dst.status = Present return nil } -func (dst *EnumType) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *enumType) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (src EnumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { +func (src enumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.status { case Null: return nil, nil case Undefined: return nil, errUndefined } - return append(buf, src.String...), nil + return append(buf, src.value...), nil } -func (src EnumType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src enumType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return src.EncodeText(ci, buf) } diff --git a/enum_type_test.go b/enum_type_test.go index 4dd88f2a..c1e2add0 100644 --- a/enum_type_test.go +++ b/enum_type_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func setupEnum(t *testing.T, conn *pgx.Conn) *pgtype.EnumType { +func setupEnum(t *testing.T, conn *pgx.Conn) pgtype.EnumType { _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") require.NoError(t, err) From 10838b39f64429e79d396b117ec9bfff94f6468e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 6 May 2020 14:45:55 -0500 Subject: [PATCH 0462/1158] Remove vscode settings --- .vscode/settings.json | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index a32b4d68..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "go.inferGopath": false, - "go.testEnvVars": { - "PGX_TEST_DATABASE": "user=postgres database=pgx_test host=127.0.0.1" - }, -} \ No newline at end of file From 37e976192b45fab340e6c8d7cf60dfbe799406f2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 May 2020 17:17:12 -0500 Subject: [PATCH 0463/1158] ScanRowValue accepts interface{} dst --- convert.go | 8 ++++---- custom_composite_test.go | 16 +--------------- pgtype.go | 2 -- record_test.go | 2 +- 4 files changed, 6 insertions(+), 22 deletions(-) diff --git a/convert.go b/convert.go index 8008d677..115f33a3 100644 --- a/convert.go +++ b/convert.go @@ -442,7 +442,7 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { // most of them implement BinaryDecoder interface. // // ScanRowValue takes ownership of src, caller MUST not use it after call -func ScanRowValue(ci *ConnInfo, src []byte, dst ...BinaryDecoder) error { +func ScanRowValue(ci *ConnInfo, src []byte, dst ...interface{}) error { fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) if err != nil { return err @@ -452,17 +452,17 @@ func ScanRowValue(ci *ConnInfo, src []byte, dst ...BinaryDecoder) error { return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", fieldCount, len(dst)) } - _, fieldBytes, eof, err := fieldIter.Next() + fieldOID, fieldBytes, eof, err := fieldIter.Next() for i := 0; !eof; i++ { if err != nil { return err } - if err = dst[i].DecodeBinary(ci, fieldBytes); err != nil { + if err = ci.Scan(fieldOID, BinaryFormatCode, fieldBytes, dst[i]); err != nil { return err } - _, fieldBytes, eof, err = fieldIter.Next() + fieldOID, fieldBytes, eof, err = fieldIter.Next() } return nil diff --git a/custom_composite_test.go b/custom_composite_test.go index 61ea91c5..f6f37ec7 100644 --- a/custom_composite_test.go +++ b/custom_composite_test.go @@ -20,21 +20,7 @@ func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs") } - a := pgtype.Int4{} - b := pgtype.Text{} - - if err := pgtype.ScanRowValue(ci, src, &a, &b); err != nil { - return err - } - - // type compatibility is checked by AssignTo - // only lossless assignments will succeed - if err := a.AssignTo(&dst.a); err != nil { - return err - } - - // AssignTo also deals with null value handling - if err := b.AssignTo(&dst.b); err != nil { + if err := pgtype.ScanRowValue(ci, src, &dst.a, &dst.b); err != nil { return err } diff --git a/pgtype.go b/pgtype.go index d0d4885c..eead52af 100644 --- a/pgtype.go +++ b/pgtype.go @@ -459,8 +459,6 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac } else { return errors.Errorf("%T is not a pgtype.TextDecoder", dt.Value) } - default: - return errors.Errorf("unknown format code: %v", formatCode) } if !isFastType { diff --git a/record_test.go b/record_test.go index 9516612e..3794fcd7 100644 --- a/record_test.go +++ b/record_test.go @@ -93,7 +93,7 @@ func TestScanRowValue(t *testing.T) { t.Fatal(err) } t.Run(tt.sql, func(t *testing.T) { - desc := []pgtype.BinaryDecoder{} + desc := []interface{}{} for _, f := range tt.expected.Fields { desc = append(desc, f.(pgtype.BinaryDecoder)) } From ff9bc5d68dd597c3f4259eadd759cfb2817ca43b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 7 May 2020 10:15:23 -0500 Subject: [PATCH 0464/1158] Merge binary package into pgtype package --- binary/record.go | 78 ----------------------------------------- composite.go | 76 +++++++++++++++++++++++++++++++++++++-- composite_bench_test.go | 11 +++--- convert.go | 9 +++-- record.go | 4 +-- 5 files changed, 84 insertions(+), 94 deletions(-) delete mode 100644 binary/record.go diff --git a/binary/record.go b/binary/record.go deleted file mode 100644 index 72b688a8..00000000 --- a/binary/record.go +++ /dev/null @@ -1,78 +0,0 @@ -package binary - -import ( - "encoding/binary" - - "github.com/jackc/pgio" - errors "golang.org/x/xerrors" -) - -type RecordFieldIter struct { - rp int - src []byte -} - -// NewRecordFieldIterator creates iterator over binary representation -// of record, aka ROW(), aka Composite -func NewRecordFieldIterator(src []byte) (RecordFieldIter, int, error) { - rp := 0 - if len(src[rp:]) < 4 { - return RecordFieldIter{}, 0, errors.Errorf("Record incomplete %v", src) - } - - fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - - return RecordFieldIter{ - rp: rp, - src: src, - }, fieldCount, nil -} - -// Next returns next field decoded from record. eof is returned if no -// more fields left to decode. -func (fi *RecordFieldIter) Next() (fieldOID uint32, buf []byte, eof bool, err error) { - if fi.rp == len(fi.src) { - eof = true - return - } - - if len(fi.src[fi.rp:]) < 8 { - err = errors.Errorf("Record incomplete %v", fi.src) - return - } - fieldOID = binary.BigEndian.Uint32(fi.src[fi.rp:]) - fi.rp += 4 - - fieldLen := int(int32(binary.BigEndian.Uint32(fi.src[fi.rp:]))) - fi.rp += 4 - - if fieldLen >= 0 { - if len(fi.src[fi.rp:]) < fieldLen { - err = errors.Errorf("Record incomplete rp=%d src=%v", fi.rp, fi.src) - return - } - buf = fi.src[fi.rp : fi.rp+fieldLen] - fi.rp += fieldLen - } - - return -} - -// RecordStart adds record header to the buf -func RecordStart(buf []byte, fieldCount int) []byte { - return pgio.AppendUint32(buf, uint32(fieldCount)) -} - -// RecordAdd adds record field to the buf -func RecordAdd(buf []byte, oid uint32, fieldBytes []byte) []byte { - buf = pgio.AppendUint32(buf, oid) - buf = pgio.AppendUint32(buf, uint32(len(fieldBytes))) - buf = append(buf, fieldBytes...) - return buf -} - -// RecordAddNull adds null value as a field to the buf -func RecordAddNull(buf []byte, oid uint32) []byte { - return pgio.AppendInt32(buf, int32(-1)) -} diff --git a/composite.go b/composite.go index 61034262..6ffe9acf 100644 --- a/composite.go +++ b/composite.go @@ -1,7 +1,9 @@ package pgtype import ( - "github.com/jackc/pgtype/binary" + "encoding/binary" + + "github.com/jackc/pgio" errors "golang.org/x/xerrors" ) @@ -82,7 +84,7 @@ func (dst *Composite) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { return nil } - fieldIter, fieldCount, err := binary.NewRecordFieldIterator(buf) + fieldIter, fieldCount, err := NewRecordFieldIterator(buf) if err != nil { return err } else if len(dst.fields) != fieldCount { @@ -151,3 +153,73 @@ func (dst *Composite) SetFields(values ...interface{}) error { dst.Status = Present return nil } + +type RecordFieldIter struct { + rp int + src []byte +} + +// NewRecordFieldIterator creates iterator over binary representation +// of record, aka ROW(), aka Composite +func NewRecordFieldIterator(src []byte) (RecordFieldIter, int, error) { + rp := 0 + if len(src[rp:]) < 4 { + return RecordFieldIter{}, 0, errors.Errorf("Record incomplete %v", src) + } + + fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + return RecordFieldIter{ + rp: rp, + src: src, + }, fieldCount, nil +} + +// Next returns next field decoded from record. eof is returned if no +// more fields left to decode. +func (fi *RecordFieldIter) Next() (fieldOID uint32, buf []byte, eof bool, err error) { + if fi.rp == len(fi.src) { + eof = true + return + } + + if len(fi.src[fi.rp:]) < 8 { + err = errors.Errorf("Record incomplete %v", fi.src) + return + } + fieldOID = binary.BigEndian.Uint32(fi.src[fi.rp:]) + fi.rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(fi.src[fi.rp:]))) + fi.rp += 4 + + if fieldLen >= 0 { + if len(fi.src[fi.rp:]) < fieldLen { + err = errors.Errorf("Record incomplete rp=%d src=%v", fi.rp, fi.src) + return + } + buf = fi.src[fi.rp : fi.rp+fieldLen] + fi.rp += fieldLen + } + + return +} + +// RecordStart adds record header to the buf +func RecordStart(buf []byte, fieldCount int) []byte { + return pgio.AppendUint32(buf, uint32(fieldCount)) +} + +// RecordAdd adds record field to the buf +func RecordAdd(buf []byte, oid uint32, fieldBytes []byte) []byte { + buf = pgio.AppendUint32(buf, oid) + buf = pgio.AppendUint32(buf, uint32(len(fieldBytes))) + buf = append(buf, fieldBytes...) + return buf +} + +// RecordAddNull adds null value as a field to the buf +func RecordAddNull(buf []byte, oid uint32) []byte { + return pgio.AppendInt32(buf, int32(-1)) +} diff --git a/composite_bench_test.go b/composite_bench_test.go index 429ce9b3..fd31e8ea 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/jackc/pgtype" - "github.com/jackc/pgtype/binary" errors "golang.org/x/xerrors" ) @@ -19,14 +18,14 @@ func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf fieldBytes := make([]byte, 0, 64) fieldBytes, _ = a.EncodeBinary(ci, fieldBytes[:0]) - newBuf = binary.RecordStart(buf, 2) - newBuf = binary.RecordAdd(newBuf, pgtype.Int4OID, fieldBytes) + newBuf = pgtype.RecordStart(buf, 2) + newBuf = pgtype.RecordAdd(newBuf, pgtype.Int4OID, fieldBytes) if src.B != nil { fieldBytes, _ = pgtype.Text{*src.B, pgtype.Present}.EncodeBinary(ci, fieldBytes[:0]) - newBuf = binary.RecordAdd(newBuf, pgtype.TextOID, fieldBytes) + newBuf = pgtype.RecordAdd(newBuf, pgtype.TextOID, fieldBytes) } else { - newBuf = binary.RecordAddNull(newBuf, pgtype.TextOID) + newBuf = pgtype.RecordAddNull(newBuf, pgtype.TextOID) } return } @@ -35,7 +34,7 @@ func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { a := pgtype.Int4{} b := pgtype.Text{} - fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) + fieldIter, fieldCount, err := pgtype.NewRecordFieldIterator(src) if err != nil { return err } diff --git a/convert.go b/convert.go index 115f33a3..91a32a60 100644 --- a/convert.go +++ b/convert.go @@ -5,7 +5,6 @@ import ( "reflect" "time" - "github.com/jackc/pgtype/binary" errors "golang.org/x/xerrors" ) @@ -443,7 +442,7 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { // // ScanRowValue takes ownership of src, caller MUST not use it after call func ScanRowValue(ci *ConnInfo, src []byte, dst ...interface{}) error { - fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) + fieldIter, fieldCount, err := NewRecordFieldIterator(src) if err != nil { return err } @@ -472,7 +471,7 @@ func ScanRowValue(ci *ConnInfo, src []byte, dst ...interface{}) error { func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) { fieldBytes := make([]byte, 0, 128) - newBuf = binary.RecordStart(buf, len(fields)) + newBuf = RecordStart(buf, len(fields)) for _, f := range fields { dt, ok := ci.DataTypeForValue(f) if !ok { @@ -487,9 +486,9 @@ func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err er if err != nil { return nil, err } - newBuf = binary.RecordAdd(newBuf, dt.OID, fieldBytes) + newBuf = RecordAdd(newBuf, dt.OID, fieldBytes) } else { - newBuf = binary.RecordAddNull(newBuf, dt.OID) + newBuf = RecordAddNull(newBuf, dt.OID) } } diff --git a/record.go b/record.go index 4e39f92a..b0c47185 100644 --- a/record.go +++ b/record.go @@ -3,8 +3,6 @@ package pgtype import ( "reflect" - "github.com/jackc/pgtype/binary" - errors "golang.org/x/xerrors" ) @@ -104,7 +102,7 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } - fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) + fieldIter, fieldCount, err := NewRecordFieldIterator(src) if err != nil { return err } From 452511dfc51d2f5948062f96c905849fbe1f4053 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 7 May 2020 13:28:28 -0500 Subject: [PATCH 0465/1158] Rename RecordFieldIter to CompositeBinaryScanner and adjust interface Use interface similar to bufio.Scanner and pgx.Rows. --- composite.go | 111 +++++++++++++++++++++++++--------------- composite_bench_test.go | 28 +++++----- convert.go | 18 +++---- record.go | 19 +++---- 4 files changed, 99 insertions(+), 77 deletions(-) diff --git a/composite.go b/composite.go index 6ffe9acf..4e6b68ca 100644 --- a/composite.go +++ b/composite.go @@ -84,31 +84,29 @@ func (dst *Composite) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { return nil } - fieldIter, fieldCount, err := NewRecordFieldIterator(buf) + scanner, err := NewCompositeBinaryScanner(buf) if err != nil { return err - } else if len(dst.fields) != fieldCount { - return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(dst.fields), fieldCount) + } + if len(dst.fields) != scanner.FieldCount() { + return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(dst.fields), scanner.FieldCount()) } - _, fieldBytes, eof, err := fieldIter.Next() - - for i := 0; !eof; i++ { - if err != nil { - return err - } - + for i := 0; scanner.Scan(); i++ { binaryDecoder, ok := dst.fields[i].(BinaryDecoder) if !ok { return errors.New("Composite field doesn't support binary protocol") } - if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil { return err } - - _, fieldBytes, eof, err = fieldIter.Next() } + + if scanner.Err() != nil { + return scanner.Err() + } + dst.Status = Present return nil @@ -154,56 +152,85 @@ func (dst *Composite) SetFields(values ...interface{}) error { return nil } -type RecordFieldIter struct { +type CompositeBinaryScanner struct { rp int src []byte + + fieldCount int32 + fieldBytes []byte + fieldOID uint32 + err error } -// NewRecordFieldIterator creates iterator over binary representation -// of record, aka ROW(), aka Composite -func NewRecordFieldIterator(src []byte) (RecordFieldIter, int, error) { +// NewCompositeBinaryScanner a scanner over a binary encoded composite balue. +func NewCompositeBinaryScanner(src []byte) (CompositeBinaryScanner, error) { rp := 0 if len(src[rp:]) < 4 { - return RecordFieldIter{}, 0, errors.Errorf("Record incomplete %v", src) + return CompositeBinaryScanner{}, errors.Errorf("Record incomplete %v", src) } - fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) rp += 4 - return RecordFieldIter{ - rp: rp, - src: src, - }, fieldCount, nil + return CompositeBinaryScanner{ + rp: rp, + src: src, + fieldCount: fieldCount, + }, nil } -// Next returns next field decoded from record. eof is returned if no -// more fields left to decode. -func (fi *RecordFieldIter) Next() (fieldOID uint32, buf []byte, eof bool, err error) { - if fi.rp == len(fi.src) { - eof = true - return +// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Scan returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeBinaryScanner) Scan() bool { + if cfs.err != nil { + return false } - if len(fi.src[fi.rp:]) < 8 { - err = errors.Errorf("Record incomplete %v", fi.src) - return + if cfs.rp == len(cfs.src) { + return false } - fieldOID = binary.BigEndian.Uint32(fi.src[fi.rp:]) - fi.rp += 4 - fieldLen := int(int32(binary.BigEndian.Uint32(fi.src[fi.rp:]))) - fi.rp += 4 + if len(cfs.src[cfs.rp:]) < 8 { + cfs.err = errors.Errorf("Record incomplete %v", cfs.src) + return false + } + cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:]) + cfs.rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:]))) + cfs.rp += 4 if fieldLen >= 0 { - if len(fi.src[fi.rp:]) < fieldLen { - err = errors.Errorf("Record incomplete rp=%d src=%v", fi.rp, fi.src) - return + if len(cfs.src[cfs.rp:]) < fieldLen { + cfs.err = errors.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src) + return false } - buf = fi.src[fi.rp : fi.rp+fieldLen] - fi.rp += fieldLen + cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen] + cfs.rp += fieldLen + } else { + cfs.fieldBytes = nil } - return + return true +} + +func (cfs *CompositeBinaryScanner) FieldCount() int { + return int(cfs.fieldCount) +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// OID returns the OID of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) OID() uint32 { + return cfs.fieldOID +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeBinaryScanner) Err() error { + return cfs.err } // RecordStart adds record header to the buf diff --git a/composite_bench_test.go b/composite_bench_test.go index fd31e8ea..fa0f9f61 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -34,29 +34,29 @@ func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { a := pgtype.Int4{} b := pgtype.Text{} - fieldIter, fieldCount, err := pgtype.NewRecordFieldIterator(src) + scanner, err := pgtype.NewCompositeBinaryScanner(src) if err != nil { return err } - if 2 != fieldCount { - return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=2", fieldCount) + if 2 != scanner.FieldCount() { + return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=2", scanner.FieldCount()) } - _, fieldBytes, eof, err := fieldIter.Next() - if eof || err != nil { - return errors.New("Bad record") - } - if err = a.DecodeBinary(ci, fieldBytes); err != nil { - return err + if scanner.Scan() { + if err = a.DecodeBinary(ci, scanner.Bytes()); err != nil { + return err + } } - _, fieldBytes, eof, err = fieldIter.Next() - if eof || err != nil { - return errors.New("Bad record") + if scanner.Scan() { + if err = b.DecodeBinary(ci, scanner.Bytes()); err != nil { + return err + } } - if err = b.DecodeBinary(ci, fieldBytes); err != nil { - return err + + if scanner.Err() != nil { + return scanner.Err() } dst.A = a.Int diff --git a/convert.go b/convert.go index 91a32a60..4fe659b3 100644 --- a/convert.go +++ b/convert.go @@ -442,26 +442,24 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { // // ScanRowValue takes ownership of src, caller MUST not use it after call func ScanRowValue(ci *ConnInfo, src []byte, dst ...interface{}) error { - fieldIter, fieldCount, err := NewRecordFieldIterator(src) + scanner, err := NewCompositeBinaryScanner(src) if err != nil { return err } - if len(dst) != fieldCount { - return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", fieldCount, len(dst)) + if len(dst) != scanner.FieldCount() { + return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", scanner.FieldCount(), len(dst)) } - fieldOID, fieldBytes, eof, err := fieldIter.Next() - for i := 0; !eof; i++ { + for i := 0; scanner.Scan(); i++ { + err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), dst[i]) if err != nil { return err } + } - if err = ci.Scan(fieldOID, BinaryFormatCode, fieldBytes, dst[i]); err != nil { - return err - } - - fieldOID, fieldBytes, eof, err = fieldIter.Next() + if scanner.Err() != nil { + return scanner.Err() } return nil diff --git a/record.go b/record.go index b0c47185..0d51ad4c 100644 --- a/record.go +++ b/record.go @@ -102,29 +102,26 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } - fieldIter, fieldCount, err := NewRecordFieldIterator(src) + scanner, err := NewCompositeBinaryScanner(src) if err != nil { return err } - fields := make([]Value, fieldCount) - fieldOID, fieldBytes, eof, err := fieldIter.Next() + fields := make([]Value, scanner.FieldCount()) - for i := 0; !eof; i++ { + for i := 0; scanner.Scan(); i++ { + binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i]) if err != nil { return err } - binaryDecoder, err := prepareNewBinaryDecoder(ci, fieldOID, &fields[i]) - if err != nil { + if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil { return err } + } - if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { - return err - } - - fieldOID, fieldBytes, eof, err = fieldIter.Next() + if scanner.Err() != nil { + return scanner.Err() } *dst = Record{Fields: fields, Status: Present} From 4a50a63f121988af42b813db7c637bd397f2d8ae Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 7 May 2020 19:48:48 -0500 Subject: [PATCH 0466/1158] Refactor Scan optimization Instead of hardcoding specific types and skipping type assertions based on that, only check if a destination is a (sql.Scanner) after a failed AssignTo. This is slightly slower in the non-decoder case and *very* slightly faster in the decoder. However, this approach is cleaner and has the potential for further optimizations. --- pgtype.go | 64 +++++++++++++++++++------------------------------------ 1 file changed, 22 insertions(+), 42 deletions(-) diff --git a/pgtype.go b/pgtype.go index eead52af..71babbc8 100644 --- a/pgtype.go +++ b/pgtype.go @@ -3,7 +3,6 @@ package pgtype import ( "database/sql" "reflect" - "time" errors "golang.org/x/xerrors" ) @@ -404,39 +403,17 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { } func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interface{}) error { - isFastType := false - switch dest.(type) { - case *int16: - isFastType = true - case *int32: - isFastType = true - case *int64: - isFastType = true - case *float32: - isFastType = true - case *float64: - isFastType = true - case *string: - isFastType = true - case *time.Time: - isFastType = true - case *[]byte: - isFastType = true - } - - if !isFastType { - switch formatCode { - case BinaryFormatCode: - if dest, ok := dest.(BinaryDecoder); ok { - return dest.DecodeBinary(ci, buf) - } - case TextFormatCode: - if dest, ok := dest.(TextDecoder); ok { - return dest.DecodeText(ci, buf) - } - default: - return errors.Errorf("unknown format code: %v", formatCode) + switch formatCode { + case BinaryFormatCode: + if dest, ok := dest.(BinaryDecoder); ok { + return dest.DecodeBinary(ci, buf) } + case TextFormatCode: + if dest, ok := dest.(TextDecoder); ok { + return dest.DecodeText(ci, buf) + } + default: + return errors.Errorf("unknown format code: %v", formatCode) } if dt, ok := ci.DataTypeForOID(oid); ok { @@ -461,17 +438,20 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac } } - if !isFastType { - if scanner, ok := dest.(sql.Scanner); ok { - sqlSrc, err := DatabaseSQLValue(ci, dt.Value) - if err != nil { - return err - } - return scanner.Scan(sqlSrc) - } + assignToErr := dt.Value.AssignTo(dest) + if assignToErr == nil { + return nil } - return dt.Value.AssignTo(dest) + if scanner, ok := dest.(sql.Scanner); ok { + sqlSrc, err := DatabaseSQLValue(ci, dt.Value) + if err != nil { + return err + } + return scanner.Scan(sqlSrc) + } + + return assignToErr } // We might be given a pointer to something that implements the decoder interface(s), From 97bbe6ae20e262f7f9d210cabebcb77dc4d871f0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 8 May 2020 16:04:16 -0500 Subject: [PATCH 0467/1158] Add RegisterDefaultPgType This allows registering a mapping of a Go type to a PostgreSQL type name. If the OID of a value to be encoded or decoded is unknown, this additional mapping will be used to determine a suitable data type. --- pgtype.go | 128 ++++++++++++++++++++++++++++++++++++++++--------- pgtype_test.go | 19 ++++++-- 2 files changed, 119 insertions(+), 28 deletions(-) diff --git a/pgtype.go b/pgtype.go index 71babbc8..af8d8661 100644 --- a/pgtype.go +++ b/pgtype.go @@ -2,7 +2,9 @@ package pgtype import ( "database/sql" + "net" "reflect" + "time" errors "golang.org/x/xerrors" ) @@ -207,19 +209,25 @@ type DataType struct { type ConnInfo struct { oidToDataType map[uint32]*DataType nameToDataType map[string]*DataType - reflectTypeToDataType map[reflect.Type]*DataType + reflectTypeToName map[reflect.Type]string oidToParamFormatCode map[uint32]int16 oidToResultFormatCode map[uint32]int16 + + reflectTypeToDataType map[reflect.Type]*DataType +} + +func newConnInfo() *ConnInfo { + return &ConnInfo{ + oidToDataType: make(map[uint32]*DataType), + nameToDataType: make(map[string]*DataType), + reflectTypeToName: make(map[reflect.Type]string), + oidToParamFormatCode: make(map[uint32]int16), + oidToResultFormatCode: make(map[uint32]int16), + } } func NewConnInfo() *ConnInfo { - ci := &ConnInfo{ - oidToDataType: make(map[uint32]*DataType, 128), - nameToDataType: make(map[string]*DataType, 128), - reflectTypeToDataType: make(map[reflect.Type]*DataType, 128), - oidToParamFormatCode: make(map[uint32]int16, 128), - oidToResultFormatCode: make(map[uint32]int16, 128), - } + ci := newConnInfo() ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID}) ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID}) @@ -286,6 +294,42 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID}) ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID}) + registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { + ci.RegisterDefaultPgType(value, name) + valueType := reflect.TypeOf(value) + + ci.RegisterDefaultPgType(reflect.New(valueType).Interface(), name) + + sliceType := reflect.SliceOf(valueType) + ci.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName) + + ci.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName) + } + + // Integer types that directly map to a PostgreSQL type + registerDefaultPgTypeVariants("int2", "_int2", int16(0)) + registerDefaultPgTypeVariants("int4", "_int4", int32(0)) + registerDefaultPgTypeVariants("int8", "_int8", int64(0)) + + // Integer types that do not have a direct match to a PostgreSQL type + registerDefaultPgTypeVariants("int8", "_int8", uint16(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint32(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint64(0)) + registerDefaultPgTypeVariants("int8", "_int8", int(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint(0)) + + registerDefaultPgTypeVariants("float4", "_float4", float32(0)) + registerDefaultPgTypeVariants("float8", "_float8", float64(0)) + + registerDefaultPgTypeVariants("bool", "_bool", false) + registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{}) + registerDefaultPgTypeVariants("text", "_text", "") + registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil)) + + registerDefaultPgTypeVariants("inet", "_inet", net.IP{}) + ci.RegisterDefaultPgType((*net.IPNet)(nil), "cidr") + ci.RegisterDefaultPgType([]*net.IPNet(nil), "_cidr") + return ci } @@ -302,16 +346,12 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) { } func (ci *ConnInfo) RegisterDataType(t DataType) { - tv, _ := t.Value.(TypeValue) - if tv != nil { + if tv, ok := t.Value.(TypeValue); ok { t.Value = tv.CloneTypeValue() } ci.oidToDataType[t.OID] = &t ci.nameToDataType[t.Name] = &t - if tv == nil { - ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t - } { var formatCode int16 @@ -336,6 +376,16 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { if d, ok := t.Value.(BinaryDecoder); ok { t.binaryDecoder = d } + + ci.reflectTypeToDataType = nil // Invalidated by type registration +} + +// RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be +// encoded or decoded is determined by the PostgreSQL OID. But if the OID of a value to be encoded or decoded is +// unknown, this additional mapping will be used by DataTypeForValue to determine a suitable data type. +func (ci *ConnInfo) RegisterDefaultPgType(value interface{}, name string) { + ci.reflectTypeToName[reflect.TypeOf(value)] = name + ci.reflectTypeToDataType = nil // Invalidated by registering a default type } func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) { @@ -348,13 +398,35 @@ func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { return dt, ok } -func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { +func (ci *ConnInfo) buildReflectTypeToDataType() { + ci.reflectTypeToDataType = make(map[reflect.Type]*DataType) + + for _, dt := range ci.oidToDataType { + if _, is := dt.Value.(TypeValue); !is { + ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt + } + } + + for reflectType, name := range ci.reflectTypeToName { + if dt, ok := ci.nameToDataType[name]; ok { + ci.reflectTypeToDataType[reflectType] = dt + } + } +} + +// DataTypeForValue finds a data type suitable for v. Use RegisterDataType to register types that can encode and decode +// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. +func (ci *ConnInfo) DataTypeForValue(v interface{}) (*DataType, bool) { + if ci.reflectTypeToDataType == nil { + ci.buildReflectTypeToDataType() + } + if tv, ok := v.(TypeValue); ok { dt, ok := ci.nameToDataType[tv.PgTypeName()] return dt, ok } - dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()] + dt, ok := ci.reflectTypeToDataType[reflect.TypeOf(v)] return dt, ok } @@ -376,13 +448,7 @@ func (ci *ConnInfo) ResultFormatCodeForOID(oid uint32) int16 { // DeepCopy makes a deep copy of the ConnInfo. func (ci *ConnInfo) DeepCopy() *ConnInfo { - ci2 := &ConnInfo{ - oidToDataType: make(map[uint32]*DataType, len(ci.oidToDataType)), - nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)), - reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)), - oidToParamFormatCode: make(map[uint32]int16, len(ci.oidToParamFormatCode)), - oidToResultFormatCode: make(map[uint32]int16, len(ci.oidToResultFormatCode)), - } + ci2 := newConnInfo() for _, dt := range ci.oidToDataType { var value Value @@ -399,6 +465,10 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { }) } + for t, n := range ci.reflectTypeToName { + ci2.reflectTypeToName[t] = n + } + return ci2 } @@ -416,7 +486,19 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac return errors.Errorf("unknown format code: %v", formatCode) } - if dt, ok := ci.DataTypeForOID(oid); ok { + var dt *DataType + + if oid == 0 { + if dataType, ok := ci.DataTypeForValue(dest); ok { + dt = dataType + } + } else { + if dataType, ok := ci.DataTypeForOID(oid); ok { + dt = dataType + } + } + + if dt != nil { switch formatCode { case BinaryFormatCode: if dt.binaryDecoder != nil { diff --git a/pgtype_test.go b/pgtype_test.go index dee5377d..664c5394 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -104,30 +104,39 @@ func (ct *pgCustomType) DecodeText(ci *pgtype.ConnInfo, buf []byte) error { return nil } -func TestConnInfoScanUnknownOIDToCustomType(t *testing.T) { - unknownOID := uint32(999999) +func TestConnInfoScanUnregisteredOIDToCustomType(t *testing.T) { + unregisteredOID := uint32(999999) ci := pgtype.NewConnInfo() var ct pgCustomType - err := ci.Scan(unknownOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct) + err := ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct) assert.NoError(t, err) assert.Equal(t, "foo", ct.a) assert.Equal(t, "bar", ct.b) // Scan value into pointer to custom type var pCt *pgCustomType - err = ci.Scan(unknownOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt) + err = ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt) assert.NoError(t, err) require.NotNil(t, pCt) assert.Equal(t, "foo", pCt.a) assert.Equal(t, "bar", pCt.b) // Scan null into pointer to custom type - err = ci.Scan(unknownOID, pgx.TextFormatCode, nil, &pCt) + err = ci.Scan(unregisteredOID, pgx.TextFormatCode, nil, &pCt) assert.NoError(t, err) assert.Nil(t, pCt) } +func TestConnInfoScanUnknownOIDTextFormat(t *testing.T) { + ci := pgtype.NewConnInfo() + + var n int32 + err := ci.Scan(0, pgx.TextFormatCode, []byte("123"), &n) + assert.NoError(t, err) + assert.EqualValues(t, 123, n) +} + func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) { ci := pgtype.NewConnInfo() src := []byte{0, 0, 0, 42} From c4e6445cc73142e773864439198a9ccf72767cb1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 May 2020 10:19:39 -0500 Subject: [PATCH 0468/1158] Explicitly test supported Go and PostgreSQL versions --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 87a0c058..0371101f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,8 @@ language: go go: - - 1.x + - 1.14.x + - 1.13.x - tip git: @@ -29,7 +30,6 @@ env: - PGVERSION=10 - PGVERSION=9.6 - PGVERSION=9.5 - - PGVERSION=9.4 cache: directories: From 7e66ab1e146c6da3b53021f4215c2b5e5b735b3c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 May 2020 23:52:48 -0500 Subject: [PATCH 0469/1158] Add scan plan system This can improve performance now and will be useful for the future transcoder system. --- pgtype.go | 310 ++++++++++++++++++++++++++++++++++++++++--------- pgtype_test.go | 38 ++++++ 2 files changed, 290 insertions(+), 58 deletions(-) diff --git a/pgtype.go b/pgtype.go index af8d8661..32c6da5a 100644 --- a/pgtype.go +++ b/pgtype.go @@ -2,6 +2,8 @@ package pgtype import ( "database/sql" + "encoding/binary" + "math" "net" "reflect" "time" @@ -472,76 +474,93 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { return ci2 } -func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interface{}) error { +// ScanPlan is a precompiled plan to scan into a particular destination. This requires care to use as it always scans +// to the same destination. +// +// This is a very low-level optimization. It should only be used to implement a PostgreSQL driver or custom type. +type ScanPlan interface { + // Scan scans src into dst. All parameters except src MUST be the same as were passed to PlanScan when this was + // created. + Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error +} + +type scanPlanDstBinaryDecoder struct { + d BinaryDecoder +} + +func (plan scanPlanDstBinaryDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.d.DecodeBinary(ci, src) +} + +type scanPlanDstTextDecoder struct { + d TextDecoder +} + +func (plan scanPlanDstTextDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.d.DecodeText(ci, src) +} + +type scanPlanDataTypeSQLScanner DataType + +func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + dt := (*DataType)(plan) + var err error switch formatCode { case BinaryFormatCode: - if dest, ok := dest.(BinaryDecoder); ok { - return dest.DecodeBinary(ci, buf) - } + err = dt.binaryDecoder.DecodeBinary(ci, src) case TextFormatCode: - if dest, ok := dest.(TextDecoder); ok { - return dest.DecodeText(ci, buf) - } - default: - return errors.Errorf("unknown format code: %v", formatCode) + err = dt.textDecoder.DecodeText(ci, src) + } + if err != nil { + return err } - var dt *DataType + scanner := dst.(sql.Scanner) + sqlSrc, err := DatabaseSQLValue(ci, dt.Value) + if err != nil { + return err + } + return scanner.Scan(sqlSrc) +} - if oid == 0 { - if dataType, ok := ci.DataTypeForValue(dest); ok { - dt = dataType - } +type scanPlanDataTypeAssignTo DataType + +func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + dt := (*DataType)(plan) + var err error + switch formatCode { + case BinaryFormatCode: + err = dt.binaryDecoder.DecodeBinary(ci, src) + case TextFormatCode: + err = dt.textDecoder.DecodeText(ci, src) + } + if err != nil { + return err + } + + return dt.Value.AssignTo(dst) +} + +type scanPlanSQLScanner struct{} + +func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := dst.(sql.Scanner) + if formatCode == BinaryFormatCode { + return scanner.Scan(src) } else { - if dataType, ok := ci.DataTypeForOID(oid); ok { - dt = dataType - } + return scanner.Scan(string(src)) } +} - if dt != nil { - switch formatCode { - case BinaryFormatCode: - if dt.binaryDecoder != nil { - err := dt.binaryDecoder.DecodeBinary(ci, buf) - if err != nil { - return err - } - } else { - return errors.Errorf("%T is not a pgtype.BinaryDecoder", dt.Value) - } - case TextFormatCode: - if dt.textDecoder != nil { - err := dt.textDecoder.DecodeText(ci, buf) - if err != nil { - return err - } - } else { - return errors.Errorf("%T is not a pgtype.TextDecoder", dt.Value) - } - } - - assignToErr := dt.Value.AssignTo(dest) - if assignToErr == nil { - return nil - } - - if scanner, ok := dest.(sql.Scanner); ok { - sqlSrc, err := DatabaseSQLValue(ci, dt.Value) - if err != nil { - return err - } - return scanner.Scan(sqlSrc) - } - - return assignToErr - } +type scanPlanReflection struct{} +func (scanPlanReflection) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { // We might be given a pointer to something that implements the decoder interface(s), // even though the pointer itself doesn't. - refVal := reflect.ValueOf(dest) + refVal := reflect.ValueOf(dst) if refVal.Kind() == reflect.Ptr && refVal.Type().Elem().Kind() == reflect.Ptr { // If the database returned NULL, then we set dest as nil to indicate that. - if buf == nil { + if src == nil { nilPtr := reflect.Zero(refVal.Type().Elem()) refVal.Elem().Set(nilPtr) return nil @@ -551,10 +570,185 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac // Then we can retry as that element. elemPtr := reflect.New(refVal.Type().Elem().Elem()) refVal.Elem().Set(elemPtr) - return ci.Scan(oid, formatCode, buf, elemPtr.Interface()) + + plan := ci.PlanScan(oid, formatCode, src, elemPtr.Interface()) + return plan.Scan(ci, oid, formatCode, src, elemPtr.Interface()) } - return scanUnknownType(oid, formatCode, buf, dest) + return scanUnknownType(oid, formatCode, src, dst) +} + +type scanPlanBinaryInt16 int16 + +func (plan *scanPlanBinaryInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return errors.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return errors.Errorf("invalid length for int2: %v", len(src)) + } + + *plan = scanPlanBinaryInt16(binary.BigEndian.Uint16(src)) + return nil +} + +type scanPlanBinaryInt32 int32 + +func (plan *scanPlanBinaryInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return errors.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return errors.Errorf("invalid length for int4: %v", len(src)) + } + + *plan = scanPlanBinaryInt32(binary.BigEndian.Uint32(src)) + return nil +} + +type scanPlanBinaryInt64 int64 + +func (plan *scanPlanBinaryInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return errors.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return errors.Errorf("invalid length for int8: %v", len(src)) + } + + *plan = scanPlanBinaryInt64(binary.BigEndian.Uint64(src)) + return nil +} + +type scanPlanBinaryFloat32 float32 + +func (plan *scanPlanBinaryFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return errors.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return errors.Errorf("invalid length for int4: %v", len(src)) + } + + n := int32(binary.BigEndian.Uint32(src)) + *plan = scanPlanBinaryFloat32(math.Float32frombits(uint32(n))) + return nil +} + +type scanPlanBinaryFloat64 float64 + +func (plan *scanPlanBinaryFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return errors.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return errors.Errorf("invalid length for int8: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + *plan = scanPlanBinaryFloat64(math.Float64frombits(uint64(n))) + return nil +} + +type scanPlanBinaryBytes []byte + +func (plan *scanPlanBinaryBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + *plan = scanPlanBinaryBytes(src) + return nil +} + +type scanPlanString string + +func (plan *scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return errors.Errorf("cannot scan null into %T", dst) + } + + *plan = scanPlanString(src) + return nil +} + +// PlanScan prepares a plan to scan a value into dst. +func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, buf []byte, dst interface{}) ScanPlan { + switch formatCode { + case BinaryFormatCode: + switch d := dst.(type) { + case *string: + switch oid { + case TextOID, VarcharOID: + return (*scanPlanString)(d) + } + case *int16: + if oid == Int2OID { + return (*scanPlanBinaryInt16)(d) + } + case *int32: + if oid == Int4OID { + return (*scanPlanBinaryInt32)(d) + } + case *int64: + if oid == Int8OID { + return (*scanPlanBinaryInt64)(d) + } + case *float32: + if oid == Float4OID { + return (*scanPlanBinaryFloat32)(d) + } + case *float64: + if oid == Float8OID { + return (*scanPlanBinaryFloat64)(d) + } + case *[]byte: + switch oid { + case ByteaOID, TextOID, VarcharOID: + return (*scanPlanBinaryBytes)(d) + } + case BinaryDecoder: + return scanPlanDstBinaryDecoder{d: d} + } + case TextFormatCode: + switch d := dst.(type) { + case *string: + return (*scanPlanString)(d) + case TextDecoder: + return scanPlanDstTextDecoder{d: d} + } + } + + var dt *DataType + + if oid == 0 { + if dataType, ok := ci.DataTypeForValue(dst); ok { + dt = dataType + } + } else { + if dataType, ok := ci.DataTypeForOID(oid); ok { + dt = dataType + } + } + + if dt != nil { + if _, ok := dst.(sql.Scanner); ok { + return (*scanPlanDataTypeSQLScanner)(dt) + } + return (*scanPlanDataTypeAssignTo)(dt) + } + + if _, ok := dst.(sql.Scanner); ok { + return scanPlanSQLScanner{} + } + + return scanPlanReflection{} +} + +func (ci *ConnInfo) Scan(oid uint32, formatCode int16, src []byte, dst interface{}) error { + plan := ci.PlanScan(oid, formatCode, src, dst) + return plan.Scan(ci, oid, formatCode, src, dst) } func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) error { diff --git a/pgtype_test.go b/pgtype_test.go index 664c5394..45b1b64d 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -170,3 +170,41 @@ func BenchmarkConnInfoScanInt4IntoGoInt32(b *testing.B) { } } } + +func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v pgtype.Int4 + + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + + for i := 0; i < b.N; i++ { + v = pgtype.Int4{} + err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != (pgtype.Int4{Int: 42, Status: pgtype.Present}) { + b.Fatal("scan failed due to bad value") + } + } +} + +func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v int32 + + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + + for i := 0; i < b.N; i++ { + v = 0 + err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != 42 { + b.Fatal("scan failed due to bad value") + } + } +} From 52729c1b77a09611b65b862310fc6dd2b9f77a3e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 10 May 2020 09:28:24 -0500 Subject: [PATCH 0470/1158] Back off some aggressive PlanScan optimizations PlanScan used to require the exact same value be used every time. While this was great for performance, on further consideration I think it is too much of a potential foot-gun. This moves back in the other direction. A plan tolerates a change in destination. It even detects a change in destination type and falls back to a new plan. Perfectly matched hot scan paths (e.g. PG int4 to Go int32) are still much faster than they were before this set of optimizations. The first scan of a destination that uses a decoder is faster due to not allocating. It's a little bit slower on subsequent runs than before this set of optimizations. But it is preferable to optimize for the most common scan targets (e.g. *int32, *int64, *string) over generic decoder destinations. In addition this fees pgx.connRows.Scan from having to check that the destination is unchanged. --- pgtype.go | 174 ++++++++++++++++++++++++++++++++----------------- pgtype_test.go | 17 +++++ 2 files changed, 131 insertions(+), 60 deletions(-) diff --git a/pgtype.go b/pgtype.go index 32c6da5a..fb60f067 100644 --- a/pgtype.go +++ b/pgtype.go @@ -474,35 +474,44 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { return ci2 } -// ScanPlan is a precompiled plan to scan into a particular destination. This requires care to use as it always scans -// to the same destination. -// -// This is a very low-level optimization. It should only be used to implement a PostgreSQL driver or custom type. +// ScanPlan is a precompiled plan to scan into a type of destination. type ScanPlan interface { - // Scan scans src into dst. All parameters except src MUST be the same as were passed to PlanScan when this was - // created. + // Scan scans src into dst. If the dst type has changed in an incompatible way a ScanPlan should automatically + // replan and scan. Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error } -type scanPlanDstBinaryDecoder struct { - d BinaryDecoder +type scanPlanDstBinaryDecoder struct{} + +func (scanPlanDstBinaryDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if d, ok := (dst).(BinaryDecoder); ok { + return d.DecodeBinary(ci, src) + } + + newPlan := ci.PlanScan(oid, formatCode, src, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) } -func (plan scanPlanDstBinaryDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.d.DecodeBinary(ci, src) -} - -type scanPlanDstTextDecoder struct { - d TextDecoder -} +type scanPlanDstTextDecoder struct{} func (plan scanPlanDstTextDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.d.DecodeText(ci, src) + if d, ok := (dst).(TextDecoder); ok { + return d.DecodeText(ci, src) + } + + newPlan := ci.PlanScan(oid, formatCode, src, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) } type scanPlanDataTypeSQLScanner DataType func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner, ok := dst.(sql.Scanner) + if !ok { + newPlan := ci.PlanScan(oid, formatCode, src, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + dt := (*DataType)(plan) var err error switch formatCode { @@ -515,7 +524,6 @@ func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCod return err } - scanner := dst.(sql.Scanner) sqlSrc, err := DatabaseSQLValue(ci, dt.Value) if err != nil { return err @@ -538,7 +546,18 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode return err } - return dt.Value.AssignTo(dst) + assignToErr := dt.Value.AssignTo(dst) + if assignToErr == nil { + return nil + } + + // assignToErr might have failed because the type of destination has changed + newPlan := ci.PlanScan(oid, formatCode, src, dst) + if newPlan, sameType := newPlan.(*scanPlanDataTypeAssignTo); !sameType { + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + + return assignToErr } type scanPlanSQLScanner struct{} @@ -578,9 +597,9 @@ func (scanPlanReflection) Scan(ci *ConnInfo, oid uint32, formatCode int16, src [ return scanUnknownType(oid, formatCode, src, dst) } -type scanPlanBinaryInt16 int16 +type scanPlanBinaryInt16 struct{} -func (plan *scanPlanBinaryInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { return errors.Errorf("cannot scan null into %T", dst) } @@ -589,13 +608,18 @@ func (plan *scanPlanBinaryInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16 return errors.Errorf("invalid length for int2: %v", len(src)) } - *plan = scanPlanBinaryInt16(binary.BigEndian.Uint16(src)) - return nil + if p, ok := (dst).(*int16); ok { + *p = int16(binary.BigEndian.Uint16(src)) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, src, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) } -type scanPlanBinaryInt32 int32 +type scanPlanBinaryInt32 struct{} -func (plan *scanPlanBinaryInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { return errors.Errorf("cannot scan null into %T", dst) } @@ -604,13 +628,18 @@ func (plan *scanPlanBinaryInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16 return errors.Errorf("invalid length for int4: %v", len(src)) } - *plan = scanPlanBinaryInt32(binary.BigEndian.Uint32(src)) - return nil + if p, ok := (dst).(*int32); ok { + *p = int32(binary.BigEndian.Uint32(src)) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, src, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) } -type scanPlanBinaryInt64 int64 +type scanPlanBinaryInt64 struct{} -func (plan *scanPlanBinaryInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { return errors.Errorf("cannot scan null into %T", dst) } @@ -619,13 +648,18 @@ func (plan *scanPlanBinaryInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16 return errors.Errorf("invalid length for int8: %v", len(src)) } - *plan = scanPlanBinaryInt64(binary.BigEndian.Uint64(src)) - return nil + if p, ok := (dst).(*int64); ok { + *p = int64(binary.BigEndian.Uint64(src)) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, src, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) } -type scanPlanBinaryFloat32 float32 +type scanPlanBinaryFloat32 struct{} -func (plan *scanPlanBinaryFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { return errors.Errorf("cannot scan null into %T", dst) } @@ -634,14 +668,19 @@ func (plan *scanPlanBinaryFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int return errors.Errorf("invalid length for int4: %v", len(src)) } - n := int32(binary.BigEndian.Uint32(src)) - *plan = scanPlanBinaryFloat32(math.Float32frombits(uint32(n))) - return nil + if p, ok := (dst).(*float32); ok { + n := int32(binary.BigEndian.Uint32(src)) + *p = float32(math.Float32frombits(uint32(n))) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, src, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) } -type scanPlanBinaryFloat64 float64 +type scanPlanBinaryFloat64 struct{} -func (plan *scanPlanBinaryFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { return errors.Errorf("cannot scan null into %T", dst) } @@ -650,73 +689,88 @@ func (plan *scanPlanBinaryFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int return errors.Errorf("invalid length for int8: %v", len(src)) } - n := int64(binary.BigEndian.Uint64(src)) - *plan = scanPlanBinaryFloat64(math.Float64frombits(uint64(n))) - return nil + if p, ok := (dst).(*float64); ok { + n := int64(binary.BigEndian.Uint64(src)) + *p = float64(math.Float64frombits(uint64(n))) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, src, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) } -type scanPlanBinaryBytes []byte +type scanPlanBinaryBytes struct{} -func (plan *scanPlanBinaryBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - *plan = scanPlanBinaryBytes(src) - return nil +func (scanPlanBinaryBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if p, ok := (dst).(*[]byte); ok { + *p = src + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, src, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) } -type scanPlanString string +type scanPlanString struct{} -func (plan *scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { return errors.Errorf("cannot scan null into %T", dst) } - *plan = scanPlanString(src) - return nil + if p, ok := (dst).(*string); ok { + *p = string(src) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, src, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) } // PlanScan prepares a plan to scan a value into dst. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, buf []byte, dst interface{}) ScanPlan { switch formatCode { case BinaryFormatCode: - switch d := dst.(type) { + switch dst.(type) { case *string: switch oid { case TextOID, VarcharOID: - return (*scanPlanString)(d) + return scanPlanString{} } case *int16: if oid == Int2OID { - return (*scanPlanBinaryInt16)(d) + return scanPlanBinaryInt16{} } case *int32: if oid == Int4OID { - return (*scanPlanBinaryInt32)(d) + return scanPlanBinaryInt32{} } case *int64: if oid == Int8OID { - return (*scanPlanBinaryInt64)(d) + return scanPlanBinaryInt64{} } case *float32: if oid == Float4OID { - return (*scanPlanBinaryFloat32)(d) + return scanPlanBinaryFloat32{} } case *float64: if oid == Float8OID { - return (*scanPlanBinaryFloat64)(d) + return scanPlanBinaryFloat64{} } case *[]byte: switch oid { case ByteaOID, TextOID, VarcharOID: - return (*scanPlanBinaryBytes)(d) + return scanPlanBinaryBytes{} } case BinaryDecoder: - return scanPlanDstBinaryDecoder{d: d} + return scanPlanDstBinaryDecoder{} } case TextFormatCode: - switch d := dst.(type) { + switch dst.(type) { case *string: - return (*scanPlanString)(d) + return scanPlanString{} case TextDecoder: - return scanPlanDstTextDecoder{d: d} + return scanPlanDstTextDecoder{} } } diff --git a/pgtype_test.go b/pgtype_test.go index 45b1b64d..e1c49666 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -154,6 +154,23 @@ func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) { } } +func TestScanPlanBinaryInt32ScanChangedType(t *testing.T) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v int32 + + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + require.NoError(t, err) + require.EqualValues(t, 42, v) + + var d pgtype.Int4 + err = plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &d) + require.NoError(t, err) + require.EqualValues(t, 42, d.Int) + require.EqualValues(t, pgtype.Present, d.Status) +} + func BenchmarkConnInfoScanInt4IntoGoInt32(b *testing.B) { ci := pgtype.NewConnInfo() src := []byte{0, 0, 0, 42} From a71c179ce378f5b18868d744f7ca4b877ac6e5dc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 10 May 2020 12:28:47 -0500 Subject: [PATCH 0471/1158] Extract nullAssignmentError --- convert.go | 4 ++-- pgtype.go | 9 +++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/convert.go b/convert.go index 4fe659b3..47227fd5 100644 --- a/convert.go +++ b/convert.go @@ -369,7 +369,7 @@ func NullAssignTo(dst interface{}) error { // AssignTo dst must always be a pointer if dstPtr.Kind() != reflect.Ptr { - return errors.Errorf("cannot assign NULL to %T", dst) + return &nullAssignmentError{dst: dst} } dstVal := dstPtr.Elem() @@ -380,7 +380,7 @@ func NullAssignTo(dst interface{}) error { return nil } - return errors.Errorf("cannot assign NULL to %T", dst) + return &nullAssignmentError{dst: dst} } var kindTypes map[reflect.Kind]reflect.Type diff --git a/pgtype.go b/pgtype.go index fb60f067..d58be882 100644 --- a/pgtype.go +++ b/pgtype.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql" "encoding/binary" + "fmt" "math" "net" "reflect" @@ -198,6 +199,14 @@ func (f BinaryEncoderFunc) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte var errUndefined = errors.New("cannot encode status undefined") var errBadStatus = errors.New("invalid status") +type nullAssignmentError struct { + dst interface{} +} + +func (e *nullAssignmentError) Error() string { + return fmt.Sprintf("cannot assign NULL to %T", e.dst) +} + type DataType struct { Value Value From cc4d1eafe02c5eb4b8fc60adb5588833b8120a86 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 10 May 2020 12:45:12 -0500 Subject: [PATCH 0472/1158] Doc tweaks and renames --- enum_type.go | 14 +++++++------- pgtype.go | 12 +++++------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/enum_type.go b/enum_type.go index eecd0237..6f52817a 100644 --- a/enum_type.go +++ b/enum_type.go @@ -16,14 +16,14 @@ type enumType struct { value string status Status - pgTypeName string // PostgreSQL type name + typeName string // PostgreSQL type name members []string // enum members membersMap map[string]string // map to quickly lookup member and reuse string instead of allocating } // NewEnumType initializes a new EnumType. It retains a read-only reference to members. members must not be changed. -func NewEnumType(pgTypeName string, members []string) EnumType { - et := &enumType{pgTypeName: pgTypeName, members: members} +func NewEnumType(typeName string, members []string) EnumType { + et := &enumType{typeName: typeName, members: members} et.membersMap = make(map[string]string, len(members)) for _, m := range members { et.membersMap[m] = m @@ -36,14 +36,14 @@ func (et *enumType) CloneTypeValue() Value { value: et.value, status: et.status, - pgTypeName: et.pgTypeName, + typeName: et.typeName, members: et.members, membersMap: et.membersMap, } } -func (et *enumType) PgTypeName() string { - return et.pgTypeName +func (et *enumType) TypeName() string { + return et.typeName } func (et *enumType) Members() []string { @@ -87,7 +87,7 @@ func (dst *enumType) Set(src interface{}) error { if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to enum %s", value, dst.pgTypeName) + return errors.Errorf("cannot convert %v to enum %s", value, dst.typeName) } return nil diff --git a/pgtype.go b/pgtype.go index d58be882..7c893360 100644 --- a/pgtype.go +++ b/pgtype.go @@ -128,10 +128,8 @@ type Value interface { AssignTo(dst interface{}) error } -// TypeValue represents values where instances represent a type. In the normal pgtype model a Go type maps to a -// PostgreSQL type and an instance of a Go type maps to a PostgreSQL value of that type. Implementors of TypeValue -// are different in that an instance represents a PostgreSQL type. This can be useful for representing types such -// as enums, composites, and arrays. +// TypeValue represents values where instances can represent different PostgreSQL types. This can be useful for +// representing types such as enums, composites, and arrays. // // In general, instances of TypeValue should not be used to directly represent a value. It should only be used as an // encoder and decoder internal to ConnInfo. @@ -140,8 +138,8 @@ type TypeValue interface { // in an EnumType. CloneTypeValue() Value - // PgTypeName returns the PostgreSQL name of this type. - PgTypeName() string + // TypeName returns the PostgreSQL name of this type. + TypeName() string } type BinaryDecoder interface { @@ -433,7 +431,7 @@ func (ci *ConnInfo) DataTypeForValue(v interface{}) (*DataType, bool) { } if tv, ok := v.(TypeValue); ok { - dt, ok := ci.nameToDataType[tv.PgTypeName()] + dt, ok := ci.nameToDataType[tv.TypeName()] return dt, ok } From 8cd94a14c75abb3e98c361d7bb2517380d82ab58 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 10 May 2020 14:05:16 -0500 Subject: [PATCH 0473/1158] Allow types to specify preference format result and param formats This will be useful for array and composite types that may have to support elements that may not support binary encoding. It also is slightly more convenient for text-ish types to have a default format of text. --- bpchar.go | 8 ++++++++ enum_type.go | 8 ++++++++ json.go | 8 ++++++++ jsonb.go | 8 ++++++++ pgtype.go | 20 ++++++++++++++++++-- pgtype_test.go | 20 ++++++++++++++++++++ text.go | 8 ++++++++ varchar.go | 8 ++++++++ 8 files changed, 86 insertions(+), 2 deletions(-) diff --git a/bpchar.go b/bpchar.go index f82e3724..e4d058e9 100644 --- a/bpchar.go +++ b/bpchar.go @@ -33,6 +33,10 @@ func (src *BPChar) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } +func (BPChar) PreferredResultFormat() int16 { + return TextFormatCode +} + func (dst *BPChar) DecodeText(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeText(ci, src) } @@ -41,6 +45,10 @@ func (dst *BPChar) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } +func (BPChar) PreferredParamFormat() int16 { + return TextFormatCode +} + func (src BPChar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (Text)(src).EncodeText(ci, buf) } diff --git a/enum_type.go b/enum_type.go index 6f52817a..1a6a4b46 100644 --- a/enum_type.go +++ b/enum_type.go @@ -128,6 +128,10 @@ func (src *enumType) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (enumType) PreferredResultFormat() int16 { + return TextFormatCode +} + func (dst *enumType) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { dst.status = Null @@ -152,6 +156,10 @@ func (dst *enumType) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } +func (enumType) PreferredParamFormat() int16 { + return TextFormatCode +} + func (src enumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.status { case Null: diff --git a/json.go b/json.go index c642c727..922da50d 100644 --- a/json.go +++ b/json.go @@ -113,6 +113,10 @@ func (src *JSON) AssignTo(dst interface{}) error { return nil } +func (JSON) PreferredResultFormat() int16 { + return TextFormatCode +} + func (dst *JSON) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = JSON{Status: Null} @@ -127,6 +131,10 @@ func (dst *JSON) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } +func (JSON) PreferredParamFormat() int16 { + return TextFormatCode +} + func (src JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: diff --git a/jsonb.go b/jsonb.go index 984c0973..c129ac9b 100644 --- a/jsonb.go +++ b/jsonb.go @@ -20,6 +20,10 @@ func (src *JSONB) AssignTo(dst interface{}) error { return (*JSON)(src).AssignTo(dst) } +func (JSONB) PreferredResultFormat() int16 { + return TextFormatCode +} + func (dst *JSONB) DecodeText(ci *ConnInfo, src []byte) error { return (*JSON)(dst).DecodeText(ci, src) } @@ -43,6 +47,10 @@ func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { } +func (JSONB) PreferredParamFormat() int16 { + return TextFormatCode +} + func (src JSONB) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (JSON)(src).EncodeText(ci, buf) } diff --git a/pgtype.go b/pgtype.go index 7c893360..ed676929 100644 --- a/pgtype.go +++ b/pgtype.go @@ -142,6 +142,18 @@ type TypeValue interface { TypeName() string } +// ResultFormatPreferrer allows a type to specify its preferred result format instead of it being inferred from +// whether it is also a BinaryDecoder. +type ResultFormatPreferrer interface { + PreferredResultFormat() int16 +} + +// ParamFormatPreferrer allows a type to specify its preferred param format instead of it being inferred from +// whether it is also a BinaryEncoder. +type ParamFormatPreferrer interface { + PreferredParamFormat() int16 +} + type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the // original SQL value is NULL. BinaryDecoder takes ownership of src. The @@ -364,7 +376,9 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { { var formatCode int16 - if _, ok := t.Value.(BinaryEncoder); ok { + if pfp, ok := t.Value.(ParamFormatPreferrer); ok { + formatCode = pfp.PreferredParamFormat() + } else if _, ok := t.Value.(BinaryEncoder); ok { formatCode = BinaryFormatCode } ci.oidToParamFormatCode[t.OID] = formatCode @@ -372,7 +386,9 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { { var formatCode int16 - if _, ok := t.Value.(BinaryDecoder); ok { + if rfp, ok := t.Value.(ResultFormatPreferrer); ok { + formatCode = rfp.PreferredResultFormat() + } else if _, ok := t.Value.(BinaryDecoder); ok { formatCode = BinaryFormatCode } ci.oidToResultFormatCode[t.OID] = formatCode diff --git a/pgtype_test.go b/pgtype_test.go index e1c49666..a96720d5 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -44,6 +44,26 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { return addr } +func TestConnInfoResultFormatCodeForOID(t *testing.T) { + ci := pgtype.NewConnInfo() + + // pgtype.JSONB implements BinaryDecoder but also implements ResultFormatPreferrer to override it to text. + assert.Equal(t, int16(pgtype.TextFormatCode), ci.ResultFormatCodeForOID(pgtype.JSONBOID)) + + // pgtype.Int4 implements BinaryDecoder but does not implement ResultFormatPreferrer so it should be binary. + assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ResultFormatCodeForOID(pgtype.Int4OID)) +} + +func TestConnInfoParamFormatCodeForOID(t *testing.T) { + ci := pgtype.NewConnInfo() + + // pgtype.JSONB implements BinaryEncoder but also implements ParamFormatPreferrer to override it to text. + assert.Equal(t, int16(pgtype.TextFormatCode), ci.ParamFormatCodeForOID(pgtype.JSONBOID)) + + // pgtype.Int4 implements BinaryEncoder but does not implement ParamFormatPreferrer so it should be binary. + assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ParamFormatCodeForOID(pgtype.Int4OID)) +} + func TestConnInfoScanUnknownOIDToStringsAndBytes(t *testing.T) { unknownOID := uint32(999999) srcBuf := []byte("foo") diff --git a/text.go b/text.go index 1f5d2a37..4c9e4a21 100644 --- a/text.go +++ b/text.go @@ -85,6 +85,10 @@ func (src *Text) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (Text) PreferredResultFormat() int16 { + return TextFormatCode +} + func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Text{Status: Null} @@ -99,6 +103,10 @@ func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } +func (Text) PreferredParamFormat() int16 { + return TextFormatCode +} + func (src Text) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: diff --git a/varchar.go b/varchar.go index e4fa6869..fea31d18 100644 --- a/varchar.go +++ b/varchar.go @@ -23,6 +23,10 @@ func (src *Varchar) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } +func (Varchar) PreferredResultFormat() int16 { + return TextFormatCode +} + func (dst *Varchar) DecodeText(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeText(ci, src) } @@ -31,6 +35,10 @@ func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } +func (Varchar) PreferredParamFormat() int16 { + return TextFormatCode +} + func (src Varchar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (Text)(src).EncodeText(ci, buf) } From 6cef4638ad804465973b8894af24111888cd1135 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 10 May 2020 14:11:24 -0500 Subject: [PATCH 0474/1158] Update pgx dependency for tests --- go.mod | 2 +- go.sum | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 35991562..35ba688e 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.12 require ( github.com/gofrs/uuid v3.2.0+incompatible github.com/jackc/pgio v1.0.0 - github.com/jackc/pgx/v4 v4.5.0 + github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9 github.com/lib/pq v1.3.0 github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc github.com/stretchr/testify v1.5.1 diff --git a/go.sum b/go.sum index 5e75654d..a4816869 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -10,6 +11,7 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= @@ -23,6 +25,8 @@ github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgconn v1.4.0 h1:E82UBzFyD752mvI+4RIl1WSxfO2ug64T+sLjvDBWTpA= github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= +github.com/jackc/pgconn v1.5.0 h1:oFSOilzIZkyg787M1fEmyMfOUUvwj0daqYMfaWwNL4o= +github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= @@ -45,6 +49,8 @@ github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01C github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= +github.com/jackc/pgtype v1.3.1-0.20200510045248-7e66ab1e146c/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= +github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 h1:ylEAOd688Duev/fxTmGdupsbyZfxNMdngIG14DoBKTM= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912 h1:YuOWGsSK5L4Fz81Olx5TNlZftmDuNrfv4ip0Yos77Tw= @@ -53,9 +59,13 @@ github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186 h1:ZQM8qLT/E/C github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= github.com/jackc/pgx/v4 v4.5.0 h1:mN7Z3n0uqPe29+tA4yLWyZNceYKgRvUWNk8qW+D066E= github.com/jackc/pgx/v4 v4.5.0/go.mod h1:EpAKPLdnTorwmPUUsqrPxy5fphV18j9q3wrfRXgo+kA= +github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9 h1:rche9LTjh3HEvkE6eb8ITYxRsgEKgBkODHrhdvDVX74= +github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9/go.mod h1:t3/cdRQl6fOLDxqtlyhe9UWgfIi9R8+8v8GKV5TRA/o= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= @@ -81,6 +91,7 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= @@ -105,16 +116,24 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 h1:0hQKqeLdqlt5iIwVOBErRisrHJAN57yOiPRQItI20fU= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 h1:3zb4D3T4G8jdExgVU/95+vQXfpEPiMdCaZgmGVxjNHM= +golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -133,8 +152,12 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -145,6 +168,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= From 1b3d694469966654768ac551936311af35095a15 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 10 May 2020 19:34:49 -0500 Subject: [PATCH 0475/1158] Add ArrayType --- array_type.go | 350 +++++++++++++++++++++++++++++++++++++++++++++ array_type_test.go | 62 ++++++++ pgtype.go | 9 ++ 3 files changed, 421 insertions(+) create mode 100644 array_type.go create mode 100644 array_type_test.go diff --git a/array_type.go b/array_type.go new file mode 100644 index 00000000..f25051d1 --- /dev/null +++ b/array_type.go @@ -0,0 +1,350 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "reflect" + + "github.com/jackc/pgio" + errors "golang.org/x/xerrors" +) + +// ArrayType represents an array type. While it implements Value, this is only in service of its type conversion duties +// when registered as a data type in a ConnType. It should not be used directly as a Value. ArrayType is a convenience +// type for types that do not have an concrete array type. +type ArrayType struct { + elements []ValueTranscoder + dimensions []ArrayDimension + status Status + + typeName string + newElement func() ValueTranscoder +} + +func NewArrayType(typeName string, newElement func() ValueTranscoder) *ArrayType { + return &ArrayType{typeName: typeName, newElement: newElement} +} + +func (at *ArrayType) CloneTypeValue() Value { + return &ArrayType{ + elements: at.elements, + dimensions: at.dimensions, + status: at.status, + + typeName: at.typeName, + newElement: at.newElement, + } +} + +func (at *ArrayType) TypeName() string { + return at.typeName +} + +func (dst *ArrayType) setNil() { + dst.elements = nil + dst.dimensions = nil + dst.status = Null +} + +func (dst *ArrayType) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + dst.setNil() + return nil + } + + sliceVal := reflect.ValueOf(src) + if sliceVal.Kind() != reflect.Slice { + return errors.Errorf("cannot set non-slice") + } + + if sliceVal.IsNil() { + dst.setNil() + return nil + } + + dst.elements = make([]ValueTranscoder, sliceVal.Len()) + for i := range dst.elements { + v := dst.newElement() + err := v.Set(sliceVal.Index(i).Interface()) + if err != nil { + return err + } + + dst.elements[i] = v + } + dst.dimensions = []ArrayDimension{{Length: int32(len(dst.elements)), LowerBound: 1}} + dst.status = Present + + return nil +} + +func (dst ArrayType) Get() interface{} { + switch dst.status { + case Present: + return dst.elements + case Null: + return nil + default: + return dst.status + } +} + +func (src *ArrayType) AssignTo(dst interface{}) error { + ptrSlice := reflect.ValueOf(dst) + if ptrSlice.Kind() != reflect.Ptr { + return errors.Errorf("cannot assign to non-pointer") + } + + sliceVal := ptrSlice.Elem() + sliceType := sliceVal.Type() + + if sliceType.Kind() != reflect.Slice { + return errors.Errorf("cannot assign to pointer to non-slice") + } + + switch src.status { + case Present: + slice := reflect.MakeSlice(sliceType, len(src.elements), len(src.elements)) + elemType := sliceType.Elem() + + for i := range src.elements { + ptrElem := reflect.New(elemType) + err := src.elements[i].AssignTo(ptrElem.Interface()) + if err != nil { + return err + } + + slice.Index(i).Set(ptrElem.Elem()) + } + + sliceVal.Set(slice) + return nil + case Null: + sliceVal.Set(reflect.Zero(sliceType)) + return nil + } + + return errors.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + dst.setNil() + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []ValueTranscoder + + if len(uta.Elements) > 0 { + elements = make([]ValueTranscoder, len(uta.Elements)) + + for i, s := range uta.Elements { + elem := dst.newElement() + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + dst.elements = elements + dst.dimensions = uta.Dimensions + dst.status = Present + + return nil +} + +func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + dst.setNil() + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = ArrayType{dimensions: arrayHeader.Dimensions, status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]ValueTranscoder, elementCount) + + for i := range elements { + elem := dst.newElement() + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elem.DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + + dst.elements = elements + dst.dimensions = arrayHeader.Dimensions + dst.status = Present + + return nil +} + +func (src ArrayType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.dimensions)) + dimElemCounts[len(src.dimensions)-1] = int(src.dimensions[len(src.dimensions)-1].Length) + for i := len(src.dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src ArrayType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.dimensions, + } + + { + value := src.newElement() + if dt, ok := ci.DataTypeForValue(value); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for element type %v", value) + } + } + + for i := range src.elements { + if src.elements[i].Get() == nil { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *ArrayType) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src ArrayType) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/array_type_test.go b/array_type_test.go new file mode 100644 index 00000000..d0812a67 --- /dev/null +++ b/array_type_test.go @@ -0,0 +1,62 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestArrayTypeValue(t *testing.T) { + arrayType := pgtype.NewArrayType("_text", func() pgtype.ValueTranscoder { return &pgtype.Text{} }) + + err := arrayType.Set(nil) + require.NoError(t, err) + + gotValue := arrayType.Get() + require.Nil(t, gotValue) + + slice := []string{"foo", "bar"} + err = arrayType.AssignTo(&slice) + require.NoError(t, err) + require.Nil(t, slice) + + err = arrayType.Set([]string{}) + require.NoError(t, err) + + gotValue = arrayType.Get() + require.Len(t, gotValue, 0) + + err = arrayType.AssignTo(&slice) + require.NoError(t, err) + require.EqualValues(t, []string{}, slice) + + err = arrayType.Set([]string{"baz", "quz"}) + require.NoError(t, err) + + gotValue = arrayType.Get() + require.Len(t, gotValue, 2) + + err = arrayType.AssignTo(&slice) + require.NoError(t, err) + require.EqualValues(t, []string{"baz", "quz"}, slice) +} + +func TestArrayTypeTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: pgtype.NewArrayType("_text", func() pgtype.ValueTranscoder { return &pgtype.Text{} }), + Name: "_text", + OID: pgtype.TextArrayOID, + }) + + var dstStrings []string + err := conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) + require.NoError(t, err) + + require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) +} diff --git a/pgtype.go b/pgtype.go index ed676929..7fee66b3 100644 --- a/pgtype.go +++ b/pgtype.go @@ -142,6 +142,15 @@ type TypeValue interface { TypeName() string } +// ValueTranscoder is a value that implements the text and binary encoding and decoding interfaces. +type ValueTranscoder interface { + Value + TextEncoder + BinaryEncoder + TextDecoder + BinaryDecoder +} + // ResultFormatPreferrer allows a type to specify its preferred result format instead of it being inferred from // whether it is also a BinaryDecoder. type ResultFormatPreferrer interface { From 36dbbd983d2fb03012f7c42dda3b16bdf2a92afd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 11 May 2020 17:21:21 -0500 Subject: [PATCH 0476/1158] Add CompositeFields type This adds support for the text format and removes the need for the ScanRowValue function. --- composite.go | 90 ++++++++++++++++++++++++++++ composite_fields.go | 76 +++++++++++++++++++++++ composite_fields_test.go | 126 +++++++++++++++++++++++++++++++++++++++ convert.go | 32 ---------- custom_composite_test.go | 2 +- record_test.go | 48 --------------- 6 files changed, 293 insertions(+), 81 deletions(-) create mode 100644 composite_fields.go create mode 100644 composite_fields_test.go diff --git a/composite.go b/composite.go index 4e6b68ca..59549736 100644 --- a/composite.go +++ b/composite.go @@ -233,6 +233,96 @@ func (cfs *CompositeBinaryScanner) Err() error { return cfs.err } +type CompositeTextScanner struct { + rp int + src []byte + + fieldBytes []byte + err error +} + +// NewCompositeTextScanner a scanner over a text encoded composite balue. +func NewCompositeTextScanner(src []byte) (CompositeTextScanner, error) { + if len(src) < 2 { + return CompositeTextScanner{}, errors.Errorf("Record incomplete %v", src) + } + + if src[0] != '(' { + return CompositeTextScanner{}, errors.Errorf("composite text format must start with '('") + } + + if src[len(src)-1] != ')' { + return CompositeTextScanner{}, errors.Errorf("composite text format must end with ')'") + } + + return CompositeTextScanner{ + rp: 1, + src: src, + }, nil +} + +// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Scan returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeTextScanner) Scan() bool { + if cfs.err != nil { + return false + } + + if cfs.rp == len(cfs.src) { + return false + } + + switch cfs.src[cfs.rp] { + case ',', ')': // null + cfs.rp++ + cfs.fieldBytes = nil + return true + case '"': // quoted value + cfs.rp++ + cfs.fieldBytes = make([]byte, 0, 16) + for { + ch := cfs.src[cfs.rp] + + if ch == '"' { + cfs.rp++ + if cfs.src[cfs.rp] == '"' { + cfs.fieldBytes = append(cfs.fieldBytes, '"') + cfs.rp++ + } else { + break + } + } else { + cfs.fieldBytes = append(cfs.fieldBytes, ch) + cfs.rp++ + } + } + cfs.rp++ + return true + default: // unquoted value + start := cfs.rp + for { + ch := cfs.src[cfs.rp] + if ch == ',' || ch == ')' { + break + } + cfs.rp++ + } + cfs.fieldBytes = cfs.src[start:cfs.rp] + cfs.rp++ + return true + } +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeTextScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeTextScanner) Err() error { + return cfs.err +} + // RecordStart adds record header to the buf func RecordStart(buf []byte, fieldCount int) []byte { return pgio.AppendUint32(buf, uint32(fieldCount)) diff --git a/composite_fields.go b/composite_fields.go new file mode 100644 index 00000000..64a17b55 --- /dev/null +++ b/composite_fields.go @@ -0,0 +1,76 @@ +package pgtype + +import ( + errors "golang.org/x/xerrors" +) + +// CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a +// nullable value use a *CompositeFields. It will be set to nil in case of null. +type CompositeFields []interface{} + +func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error { + if len(cf) == 0 { + return errors.Errorf("cannot decode into empty CompositeFields") + } + + if src == nil { + return errors.Errorf("cannot decode unexpected null into CompositeFields") + } + + scanner, err := NewCompositeBinaryScanner(src) + if err != nil { + return err + } + if len(cf) != scanner.FieldCount() { + return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), scanner.FieldCount()) + } + + for i := 0; scanner.Scan(); i++ { + err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), cf[i]) + if err != nil { + return err + } + } + + if scanner.Err() != nil { + return scanner.Err() + } + + return nil +} + +func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error { + if len(cf) == 0 { + return errors.Errorf("cannot decode into empty CompositeFields") + } + + if src == nil { + return errors.Errorf("cannot decode unexpected null into CompositeFields") + } + + scanner, err := NewCompositeTextScanner(src) + if err != nil { + return err + } + + fieldCount := 0 + + for i := 0; scanner.Scan(); i++ { + err := ci.Scan(0, TextFormatCode, scanner.Bytes(), cf[i]) + if err != nil { + return err + } + + fieldCount += 1 + } + + if scanner.Err() != nil { + return scanner.Err() + } + + if len(cf) != fieldCount { + return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), fieldCount) + } + + return nil +} diff --git a/composite_fields_test.go b/composite_fields_test.go new file mode 100644 index 00000000..d53e48ec --- /dev/null +++ b/composite_fields_test.go @@ -0,0 +1,126 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" +) + +func TestCompositeFieldsDecode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + formats := []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} + + // Assorted values + { + var a int32 + var b string + var c float64 + + for _, format := range formats { + err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b, &c}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.EqualValuesf(t, 1, a, "Format: %v", format) + assert.EqualValuesf(t, "hi", b, "Format: %v", format) + assert.EqualValuesf(t, 2.1, c, "Format: %v", format) + } + } + + // nulls, string "null", and empty string fields + { + var a pgtype.Text + var b string + var c pgtype.Text + var d string + var e pgtype.Text + + for _, format := range formats { + err := conn.QueryRow(context.Background(), "select row(null,'null',null,'',null)", pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.Nilf(t, a.Get(), "Format: %v", format) + assert.EqualValuesf(t, "null", b, "Format: %v", format) + assert.Nilf(t, c.Get(), "Format: %v", format) + assert.EqualValuesf(t, "", d, "Format: %v", format) + assert.Nilf(t, e.Get(), "Format: %v", format) + } + } + + // null record + { + var a pgtype.Text + var b string + cf := pgtype.CompositeFields{&a, &b} + + for _, format := range formats { + // Cannot scan nil into + err := conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( + cf, + ) + if assert.Errorf(t, err, "Format: %v", format) { + continue + } + assert.NotNilf(t, cf, "Format: %v", format) + + // But can scan nil into *pgtype.CompositeFields + err = conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( + &cf, + ) + if assert.Errorf(t, err, "Format: %v", format) { + continue + } + assert.Nilf(t, cf, "Format: %v", format) + } + } + + // quotes and special characters + { + var a, b, c, d string + + for _, format := range formats { + err := conn.QueryRow(context.Background(), `select row('"', 'foo bar', 'foo''bar', 'baz)bar')`, pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b, &c, &d}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.Equalf(t, `"`, a, "Format: %v", format) + assert.Equalf(t, `foo bar`, b, "Format: %v", format) + assert.Equalf(t, `foo'bar`, c, "Format: %v", format) + assert.Equalf(t, `baz)bar`, d, "Format: %v", format) + } + } + + // arrays + { + var a []string + var b []int64 + + for _, format := range formats { + err := conn.QueryRow(context.Background(), `select row(array['foo', 'bar', 'baz'], array[1,2,3])`, pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.EqualValuesf(t, []string{"foo", "bar", "baz"}, a, "Format: %v", format) + assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format) + } + } +} diff --git a/convert.go b/convert.go index 47227fd5..6e70e82e 100644 --- a/convert.go +++ b/convert.go @@ -433,38 +433,6 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { return nil, false } -// ScanRowValue decodes ROW()'s and composite type -// from src argument using provided decoders. Decoders should match -// order and count of fields of record being decoded. -// -// In practice you can pass pgtype.Value types as decoders, as -// most of them implement BinaryDecoder interface. -// -// ScanRowValue takes ownership of src, caller MUST not use it after call -func ScanRowValue(ci *ConnInfo, src []byte, dst ...interface{}) error { - scanner, err := NewCompositeBinaryScanner(src) - if err != nil { - return err - } - - if len(dst) != scanner.FieldCount() { - return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", scanner.FieldCount(), len(dst)) - } - - for i := 0; scanner.Scan(); i++ { - err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), dst[i]) - if err != nil { - return err - } - } - - if scanner.Err() != nil { - return scanner.Err() - } - - return nil -} - // EncodeRow builds a binary representation of row values (row(), composite types) func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) { fieldBytes := make([]byte, 0, 128) diff --git a/custom_composite_test.go b/custom_composite_test.go index f6f37ec7..a93a8ad0 100644 --- a/custom_composite_test.go +++ b/custom_composite_test.go @@ -20,7 +20,7 @@ func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs") } - if err := pgtype.ScanRowValue(ci, src, &dst.a, &dst.b); err != nil { + if err := (pgtype.CompositeFields{&dst.a, &dst.b}).DecodeBinary(ci, src); err != nil { return err } diff --git a/record_test.go b/record_test.go index 3794fcd7..240812a6 100644 --- a/record_test.go +++ b/record_test.go @@ -79,54 +79,6 @@ var recordTests = []struct { }, } -// row values are binary compatible with records, so we test our helper -// routines here -func TestScanRowValue(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - for i := 0; i < len(recordTests); i++ { - tt := recordTests[i] - psName := fmt.Sprintf("test%d", i) - _, err := conn.Prepare(context.Background(), psName, tt.sql) - if err != nil { - t.Fatal(err) - } - t.Run(tt.sql, func(t *testing.T) { - desc := []interface{}{} - for _, f := range tt.expected.Fields { - desc = append(desc, f.(pgtype.BinaryDecoder)) - } - - var raw pgtype.GenericBinary - - if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&raw); err != nil { - t.Error(err) - return - } - - if raw.Status == pgtype.Null { - // ScanRowValue deals with complete rows only, NULL values (but NOT null fields) - // should be handled by the calling code - return - } - - if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err != nil { - t.Error(err) - } - - // borrow fields from a neighbor test, this makes scan always fail - desc = desc[:0] - for _, f := range recordTests[(i+1)%len(recordTests)].expected.Fields { - desc = append(desc, f.(pgtype.BinaryDecoder)) - } - if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err == nil { - t.Error("Matching scan didn't fail, despite fields not mathching query result") - } - }) - } -} - func TestRecordTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) From 036101deb508cf0bd2b3310567ce14ab90a5e3a4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 11 May 2020 17:41:20 -0500 Subject: [PATCH 0477/1158] Allow scanning to nil as no-op --- pgtype.go | 4 ++++ pgtype_test.go | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/pgtype.go b/pgtype.go index 7fee66b3..193980ef 100644 --- a/pgtype.go +++ b/pgtype.go @@ -833,6 +833,10 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, buf []byte, dst inter } func (ci *ConnInfo) Scan(oid uint32, formatCode int16, src []byte, dst interface{}) error { + if dst == nil { + return nil + } + plan := ci.PlanScan(oid, formatCode, src, dst) return plan.Scan(ci, oid, formatCode, src, dst) } diff --git a/pgtype_test.go b/pgtype_test.go index a96720d5..6bdbe7c8 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -64,6 +64,13 @@ func TestConnInfoParamFormatCodeForOID(t *testing.T) { assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ParamFormatCodeForOID(pgtype.Int4OID)) } +func TestConnInfoScanNilIsNoOp(t *testing.T) { + ci := pgtype.NewConnInfo() + + err := ci.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), nil) + assert.NoError(t, err) +} + func TestConnInfoScanUnknownOIDToStringsAndBytes(t *testing.T) { unknownOID := uint32(999999) srcBuf := []byte("foo") From 08d071c0944e1d60af437ce4e11da55d7150385e Mon Sep 17 00:00:00 2001 From: Lukas Vogel Date: Fri, 8 May 2020 13:38:34 +0200 Subject: [PATCH 0478/1158] Handle IPv6 in connection URLs Previously IPv6 addresses were wrongly split and lead to a parse error. This commit fixes the behavior. --- config.go | 20 +++++++++++++++----- config_test.go | 46 +++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/config.go b/config.go index 06184b02..a9b19d67 100644 --- a/config.go +++ b/config.go @@ -399,13 +399,19 @@ func parseURLSettings(connString string) (map[string]string, error) { var hosts []string var ports []string for _, host := range strings.Split(url.Host, ",") { - parts := strings.SplitN(host, ":", 2) - if parts[0] != "" { - hosts = append(hosts, parts[0]) + if host == "" { + continue } - if len(parts) == 2 { - ports = append(ports, parts[1]) + if isIPOnly(host) { + hosts = append(hosts, strings.Trim(host, "[]")) + continue } + h, p, err := net.SplitHostPort(host) + if err != nil { + return nil, errors.Errorf("failed to split host:port in '%s', err: %w", host, err) + } + hosts = append(hosts, h) + ports = append(ports, p) } if len(hosts) > 0 { settings["host"] = strings.Join(hosts, ",") @@ -426,6 +432,10 @@ func parseURLSettings(connString string) (map[string]string, error) { return settings, nil } +func isIPOnly(host string) bool { + return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":") +} + var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} func parseDSNSettings(s string) (map[string]string, error) { diff --git a/config_test.go b/config_test.go index b6068cc8..d932a605 100644 --- a/config_test.go +++ b/config_test.go @@ -127,11 +127,11 @@ func TestParseConfig(t *testing.T) { name: "sslmode verify-ca", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, @@ -228,6 +228,42 @@ func TestParseConfig(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + name: "database url IPv4 with port", + connString: "postgresql://jack@127.0.0.1:5433/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "127.0.0.1", + Port: 5433, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url IPv6 with port", + connString: "postgresql://jack@[2001:db8::1]:5433/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "2001:db8::1", + Port: 5433, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url IPv6 no port", + connString: "postgresql://jack@[2001:db8::1]/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "2001:db8::1", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "DSN everything", connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema", From 4a6bd41a36f5451c35f8155e71e1793326cd70b7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 07:58:10 -0500 Subject: [PATCH 0479/1158] Rename Composite to CompositeType. This harmonizes the naming with EnumType and ArrayType. --- composite_bench_test.go | 4 ++-- composite.go => composite_type.go | 26 ++++++++++----------- composite_test.go => composite_type_test.go | 2 +- 3 files changed, 16 insertions(+), 16 deletions(-) rename composite.go => composite_type.go (90%) rename composite_test.go => composite_type_test.go (95%) diff --git a/composite_bench_test.go b/composite_bench_test.go index fa0f9f61..1a5a7492 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -100,7 +100,7 @@ func BenchmarkBinaryEncodingComposite(b *testing.B) { ci := pgtype.NewConnInfo() f1 := 2 f2 := ptrS("bar") - c := pgtype.NewComposite(&pgtype.Int4{}, &pgtype.Text{}) + c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) b.ResetTimer() for n := 0; n < b.N; n++ { @@ -164,7 +164,7 @@ func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { var f1 int var f2 *string - c := pgtype.NewComposite(&pgtype.Int4{}, &pgtype.Text{}) + c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) b.ResetTimer() for n := 0; n < b.N; n++ { diff --git a/composite.go b/composite_type.go similarity index 90% rename from composite.go rename to composite_type.go index 59549736..97d0c0d7 100644 --- a/composite.go +++ b/composite_type.go @@ -7,23 +7,23 @@ import ( errors "golang.org/x/xerrors" ) -type Composite struct { +type CompositeType struct { fields []Value Status Status } -// NewComposite creates a Composite object, which acts as a "schema" for +// NewCompositeType creates a Composite object, which acts as a "schema" for // SQL composite values. // To pass Composite as SQL parameter first set it's fields, either by -// passing initialized Value{} instances to NewComposite or by calling +// passing initialized Value{} instances to NewCompositeType or by calling // SetFields method // To read composite fields back pass result of Scan() method // to query Scan function. -func NewComposite(fields ...Value) *Composite { - return &Composite{fields, Present} +func NewCompositeType(fields ...Value) *CompositeType { + return &CompositeType{fields, Present} } -func (src Composite) Get() interface{} { +func (src CompositeType) Get() interface{} { switch src.Status { case Present: return src @@ -35,9 +35,9 @@ func (src Composite) Get() interface{} { } // Set is called internally when passing query arguments. -func (dst *Composite) Set(src interface{}) error { +func (dst *CompositeType) Set(src interface{}) error { if src == nil { - *dst = Composite{Status: Null} + *dst = CompositeType{Status: Null} return nil } @@ -60,11 +60,11 @@ func (dst *Composite) Set(src interface{}) error { } // AssignTo should never be called on composite value directly -func (src Composite) AssignTo(dst interface{}) error { +func (src CompositeType) AssignTo(dst interface{}) error { return errors.New("Pass Composite.Scan() to deconstruct composite") } -func (src Composite) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { +func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { switch src.Status { case Null: return nil, nil @@ -78,7 +78,7 @@ func (src Composite) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err // Opposite to Record, fields in a composite act as a "schema" // and decoding fails if SQL value can't be assigned due to // type mismatch -func (dst *Composite) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { +func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { if buf == nil { dst.Status = Null return nil @@ -118,7 +118,7 @@ func (dst *Composite) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { // Rest of arguments are set in the order of fields in the composite // // Use of Scan method doesn't modify original composite -func (src Composite) Scan(isNull *bool, dst ...interface{}) BinaryDecoderFunc { +func (src CompositeType) Scan(isNull *bool, dst ...interface{}) BinaryDecoderFunc { return func(ci *ConnInfo, buf []byte) error { if err := src.DecodeBinary(ci, buf); err != nil { return err @@ -139,7 +139,7 @@ func (src Composite) Scan(isNull *bool, dst ...interface{}) BinaryDecoderFunc { } // SetFields sets Composite's fields to corresponding values -func (dst *Composite) SetFields(values ...interface{}) error { +func (dst *CompositeType) SetFields(values ...interface{}) error { if len(values) != len(dst.fields) { return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields)) } diff --git a/composite_test.go b/composite_type_test.go similarity index 95% rename from composite_test.go rename to composite_type_test.go index ac0eb4d0..4f614fc5 100644 --- a/composite_test.go +++ b/composite_type_test.go @@ -31,7 +31,7 @@ create type mytype as ( var a int var b *string - c := pgtype.NewComposite(&pgtype.Int4{}, &pgtype.Text{}) + c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) c.SetFields(2, "bar") err = conn.QueryRow(context.Background(), "select $1::mytype", qrf, c). From c41160bcbbf7eb80030631af9af9fd5d86e9c5af Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 08:01:10 -0500 Subject: [PATCH 0480/1158] Make CompositeType status private --- composite_type.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/composite_type.go b/composite_type.go index 97d0c0d7..7cc620d5 100644 --- a/composite_type.go +++ b/composite_type.go @@ -9,7 +9,7 @@ import ( type CompositeType struct { fields []Value - Status Status + status Status } // NewCompositeType creates a Composite object, which acts as a "schema" for @@ -24,20 +24,20 @@ func NewCompositeType(fields ...Value) *CompositeType { } func (src CompositeType) Get() interface{} { - switch src.Status { + switch src.status { case Present: return src case Null: return nil default: - return src.Status + return src.status } } // Set is called internally when passing query arguments. func (dst *CompositeType) Set(src interface{}) error { if src == nil { - *dst = CompositeType{Status: Null} + *dst = CompositeType{status: Null} return nil } @@ -51,7 +51,7 @@ func (dst *CompositeType) Set(src interface{}) error { return err } } - dst.Status = Present + dst.status = Present default: return errors.Errorf("Can not convert %v to Composite", src) } @@ -65,7 +65,7 @@ func (src CompositeType) AssignTo(dst interface{}) error { } func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { - switch src.Status { + switch src.status { case Null: return nil, nil case Undefined: @@ -80,7 +80,7 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, // type mismatch func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { if buf == nil { - dst.Status = Null + dst.status = Null return nil } @@ -107,7 +107,7 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { return scanner.Err() } - dst.Status = Present + dst.status = Present return nil } @@ -124,7 +124,7 @@ func (src CompositeType) Scan(isNull *bool, dst ...interface{}) BinaryDecoderFun return err } - if src.Status == Null { + if src.status == Null { *isNull = true return nil } @@ -148,7 +148,7 @@ func (dst *CompositeType) SetFields(values ...interface{}) error { return err } } - dst.Status = Present + dst.status = Present return nil } From 247043b597bb92b99d689edf157c262e848338b7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 08:35:45 -0500 Subject: [PATCH 0481/1158] Merge SetFields functionality into Set --- composite_bench_test.go | 2 +- composite_type.go | 35 ++++++++++++++----------------- composite_type_test.go | 46 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 61 insertions(+), 22 deletions(-) diff --git a/composite_bench_test.go b/composite_bench_test.go index 1a5a7492..d4eb0ac7 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -104,7 +104,7 @@ func BenchmarkBinaryEncodingComposite(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - c.SetFields(f1, f2) + c.Set([]interface{}{f1, f2}) buf, _ = c.EncodeBinary(ci, buf[:0]) } x = buf diff --git a/composite_type.go b/composite_type.go index 7cc620d5..76b32b86 100644 --- a/composite_type.go +++ b/composite_type.go @@ -20,13 +20,17 @@ type CompositeType struct { // To read composite fields back pass result of Scan() method // to query Scan function. func NewCompositeType(fields ...Value) *CompositeType { - return &CompositeType{fields, Present} + return &CompositeType{fields, Undefined} } func (src CompositeType) Get() interface{} { switch src.status { case Present: - return src + results := make([]interface{}, len(src.fields)) + for i := range results { + results[i] = src.fields[i].Get() + } + return results case Null: return nil default: @@ -34,17 +38,16 @@ func (src CompositeType) Get() interface{} { } } -// Set is called internally when passing query arguments. func (dst *CompositeType) Set(src interface{}) error { if src == nil { - *dst = CompositeType{status: Null} + dst.status = Null return nil } switch value := src.(type) { - case []Value: + case []interface{}: if len(value) != len(dst.fields) { - return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields)) + return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.fields)) } for i, v := range value { if err := dst.fields[i].Set(v); err != nil { @@ -52,6 +55,12 @@ func (dst *CompositeType) Set(src interface{}) error { } } dst.status = Present + case *[]interface{}: + if value == nil { + dst.status = Null + return nil + } + return dst.Set(*value) default: return errors.Errorf("Can not convert %v to Composite", src) } @@ -138,20 +147,6 @@ func (src CompositeType) Scan(isNull *bool, dst ...interface{}) BinaryDecoderFun } } -// SetFields sets Composite's fields to corresponding values -func (dst *CompositeType) SetFields(values ...interface{}) error { - if len(values) != len(dst.fields) { - return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields)) - } - for i, v := range values { - if err := dst.fields[i].Set(v); err != nil { - return err - } - } - dst.status = Present - return nil -} - type CompositeBinaryScanner struct { rp int src []byte diff --git a/composite_type_test.go b/composite_type_test.go index 4f614fc5..3e38b6dc 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -4,11 +4,55 @@ import ( "context" "fmt" "os" + "testing" "github.com/jackc/pgtype" pgx "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" ) +func TestCompositeTypeSetAndGet(t *testing.T) { + ct := pgtype.NewCompositeType(&pgtype.Text{}, &pgtype.Int4{}) + assert.Equal(t, pgtype.Undefined, ct.Get()) + + nilTests := []struct { + src interface{} + }{ + {nil}, // nil interface + {(*[]interface{})(nil)}, // typed nil + } + + for i, tt := range nilTests { + err := ct.Set(tt.src) + assert.NoErrorf(t, err, "%d", i) + assert.Equal(t, nil, ct.Get()) + } + + compatibleValuesTests := []struct { + src []interface{} + expected []interface{} + }{ + { + src: []interface{}{"foo", int32(42)}, + expected: []interface{}{"foo", int32(42)}, + }, + { + src: []interface{}{nil, nil}, + expected: []interface{}{nil, nil}, + }, + { + src: []interface{}{&pgtype.Text{String: "hi", Status: pgtype.Present}, &pgtype.Int4{Int: 7, Status: pgtype.Present}}, + expected: []interface{}{"hi", int32(7)}, + }, + } + + for i, tt := range compatibleValuesTests { + err := ct.Set(tt.src) + assert.NoErrorf(t, err, "%d", i) + assert.EqualValues(t, tt.expected, ct.Get()) + } +} + //ExampleComposite demonstrates use of Row() function to pass and receive // back composite types without creating boilderplate custom types. func Example_composite() { @@ -32,7 +76,7 @@ create type mytype as ( var b *string c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) - c.SetFields(2, "bar") + c.Set([]interface{}{2, "bar"}) err = conn.QueryRow(context.Background(), "select $1::mytype", qrf, c). Scan(c.Scan(&isNull, &a, &b)) From bff2829b0f28e321ac8e441b82b4eb6867837dd1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 10:19:41 -0500 Subject: [PATCH 0482/1158] Move ComposteType.Scan functionality into AssignTo Also remove adapter functions that are no longer used. --- composite_bench_test.go | 11 ++- composite_type.go | 66 ++++++++++-------- composite_type_test.go | 145 +++++++++++++++++++++++++++++++++------- pgtype.go | 18 ----- 4 files changed, 167 insertions(+), 73 deletions(-) diff --git a/composite_bench_test.go b/composite_bench_test.go index d4eb0ac7..9eaf7632 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -160,7 +160,6 @@ var gf2 *string func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { ci := pgtype.NewConnInfo() buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) - var isNull bool var f1 int var f2 *string @@ -168,8 +167,14 @@ func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - err := c.Scan(&isNull, &f1, &f2).DecodeBinary(ci, buf) - E(err) + err := c.DecodeBinary(ci, buf) + if err != nil { + b.Fatal(err) + } + err = c.AssignTo([]interface{}{&f1, &f2}) + if err != nil { + b.Fatal(err) + } } gf1 = f1 gf2 = f2 diff --git a/composite_type.go b/composite_type.go index 76b32b86..53386f37 100644 --- a/composite_type.go +++ b/composite_type.go @@ -70,7 +70,45 @@ func (dst *CompositeType) Set(src interface{}) error { // AssignTo should never be called on composite value directly func (src CompositeType) AssignTo(dst interface{}) error { - return errors.New("Pass Composite.Scan() to deconstruct composite") + switch src.status { + case Present: + switch v := dst.(type) { + case []interface{}: + if len(v) != len(src.fields) { + return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.fields)) + } + for i := range src.fields { + if v[i] == nil { + continue + } + + assignToErr := src.fields[i].AssignTo(v[i]) + if assignToErr != nil { + // Try to use get / set instead -- this avoids every type having to be able to AssignTo type of self. + setSucceeded := false + if setter, ok := v[i].(Value); ok { + err := setter.Set(src.fields[i].Get()) + setSucceeded = err == nil + } + if !setSucceeded { + return errors.Errorf("unable to assign to dst[%d]: %v", i, assignToErr) + } + } + + } + return nil + case *[]interface{}: + return src.AssignTo(*v) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return errors.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { @@ -121,32 +159,6 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { return nil } -// Scan is a helper function to perform "nested" scan of -// a composite value when scanning a query result row. -// isNull is set if scanned value is NULL -// Rest of arguments are set in the order of fields in the composite -// -// Use of Scan method doesn't modify original composite -func (src CompositeType) Scan(isNull *bool, dst ...interface{}) BinaryDecoderFunc { - return func(ci *ConnInfo, buf []byte) error { - if err := src.DecodeBinary(ci, buf); err != nil { - return err - } - - if src.status == Null { - *isNull = true - return nil - } - - for i, f := range src.fields { - if err := f.AssignTo(dst[i]); err != nil { - return err - } - } - return nil - } -} - type CompositeBinaryScanner struct { rp int src []byte diff --git a/composite_type_test.go b/composite_type_test.go index 3e38b6dc..56b9318b 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -53,49 +53,144 @@ func TestCompositeTypeSetAndGet(t *testing.T) { } } -//ExampleComposite demonstrates use of Row() function to pass and receive -// back composite types without creating boilderplate custom types. +func TestCompositeTypeAssignTo(t *testing.T) { + ct := pgtype.NewCompositeType(&pgtype.Text{}, &pgtype.Int4{}) + + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var a string + var b int32 + + err = ct.AssignTo([]interface{}{&a, &b}) + assert.NoError(t, err) + + assert.Equal(t, "foo", a) + assert.Equal(t, int32(42), b) + } + + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var a pgtype.Text + var b pgtype.Int4 + + err = ct.AssignTo([]interface{}{&a, &b}) + assert.NoError(t, err) + + assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a) + assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b) + } + + // Allow nil destination component as no-op + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var b int32 + + err = ct.AssignTo([]interface{}{nil, &b}) + assert.NoError(t, err) + + assert.Equal(t, int32(42), b) + } + + // *[]interface{} dest when null + { + err := ct.Set(nil) + assert.NoError(t, err) + + var a pgtype.Text + var b pgtype.Int4 + dst := []interface{}{&a, &b} + + err = ct.AssignTo(&dst) + assert.NoError(t, err) + + assert.Nil(t, dst) + } + + // *[]interface{} dest when not null + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var a pgtype.Text + var b pgtype.Int4 + dst := []interface{}{&a, &b} + + err = ct.AssignTo(&dst) + assert.NoError(t, err) + + assert.NotNil(t, dst) + assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a) + assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b) + } +} + func Example_composite() { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - E(err) + if err != nil { + fmt.Println(err) + return + } defer conn.Close(context.Background()) - _, err = conn.Exec(context.Background(), `drop type if exists mytype; + _, err = conn.Exec(context.Background(), `drop type if exists mytype;`) + if err != nil { + fmt.Println(err) + return + } -create type mytype as ( + _, err = conn.Exec(context.Background(), `create type mytype as ( a int4, b text );`) - E(err) + if err != nil { + fmt.Println(err) + return + } defer conn.Exec(context.Background(), "drop type mytype") - qrf := pgx.QueryResultFormats{pgx.BinaryFormatCode} + var oid uint32 + err = conn.QueryRow(context.Background(), `select 'mytype'::regtype::oid`).Scan(&oid) + if err != nil { + fmt.Println(err) + return + } + + c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: c, Name: "mytype", OID: oid}) - var isNull bool var a int var b *string - c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) - c.Set([]interface{}{2, "bar"}) + err = conn.QueryRow(context.Background(), "select $1::mytype", []interface{}{2, "bar"}).Scan([]interface{}{&a, &b}) + if err != nil { + fmt.Println(err) + return + } - err = conn.QueryRow(context.Background(), "select $1::mytype", qrf, c). - Scan(c.Scan(&isNull, &a, &b)) + fmt.Printf("First: a=%d b=%s\n", a, *b) + + err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan([]interface{}{&a, &b}) + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("Second: a=%d b=%v\n", a, b) + + scanTarget := []interface{}{&a, &b} + err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(&scanTarget) E(err) - fmt.Printf("First: isNull=%v a=%d b=%s\n", isNull, a, *b) - - err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype", qrf).Scan(c.Scan(&isNull, &a, &b)) - E(err) - - fmt.Printf("Second: isNull=%v a=%d b=%v\n", isNull, a, b) - - err = conn.QueryRow(context.Background(), "select NULL::mytype", qrf).Scan(c.Scan(&isNull, &a, &b)) - E(err) - - fmt.Printf("Third: isNull=%v\n", isNull) + fmt.Printf("Third: isNull=%v\n", scanTarget == nil) // Output: - // First: isNull=false a=2 b=bar - // Second: isNull=false a=1 b= + // First: a=2 b=bar + // Second: a=1 b= // Third: isNull=true } diff --git a/pgtype.go b/pgtype.go index 193980ef..25f1a1d5 100644 --- a/pgtype.go +++ b/pgtype.go @@ -197,24 +197,6 @@ type TextEncoder interface { EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) } -//The BinaryDecoderFunc type is an adapter to allow the use of ordinary functions as BinaryDecoder types. -// If f is a function with the appropriate signature, BinaryDecoderFunc(f) is a BinaryDecoder that calls f. -type BinaryDecoderFunc func(ci *ConnInfo, src []byte) error - -// DecodeBinary calls f(ci, src) -func (f BinaryDecoderFunc) DecodeBinary(ci *ConnInfo, src []byte) error { - return f(ci, src) -} - -//The BinaryEncoderFunc type is an adapter to allow the use of ordinary functions as BinaryDecoder types. -// If f is a function with the appropriate signature, BinaryEncoderFunc(f) is a BinaryDecoder that calls f. -type BinaryEncoderFunc func(ci *ConnInfo, buf []byte) ([]byte, error) - -// EncodeBinary calls f(ci, buf) -func (f BinaryEncoderFunc) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { - return f(ci, buf) -} - var errUndefined = errors.New("cannot encode status undefined") var errBadStatus = errors.New("invalid status") From 682201a4fcd72d5f688c511de165db7bf2227b3a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 10:26:51 -0500 Subject: [PATCH 0483/1158] Rename CloneTypeValue to NewTypeValue --- array_type.go | 2 +- enum_type.go | 2 +- pgtype.go | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/array_type.go b/array_type.go index f25051d1..4fe772fd 100644 --- a/array_type.go +++ b/array_type.go @@ -25,7 +25,7 @@ func NewArrayType(typeName string, newElement func() ValueTranscoder) *ArrayType return &ArrayType{typeName: typeName, newElement: newElement} } -func (at *ArrayType) CloneTypeValue() Value { +func (at *ArrayType) NewTypeValue() Value { return &ArrayType{ elements: at.elements, dimensions: at.dimensions, diff --git a/enum_type.go b/enum_type.go index 1a6a4b46..231c21fd 100644 --- a/enum_type.go +++ b/enum_type.go @@ -31,7 +31,7 @@ func NewEnumType(typeName string, members []string) EnumType { return et } -func (et *enumType) CloneTypeValue() Value { +func (et *enumType) NewTypeValue() Value { return &enumType{ value: et.value, status: et.status, diff --git a/pgtype.go b/pgtype.go index 25f1a1d5..6a703994 100644 --- a/pgtype.go +++ b/pgtype.go @@ -134,9 +134,9 @@ type Value interface { // In general, instances of TypeValue should not be used to directly represent a value. It should only be used as an // encoder and decoder internal to ConnInfo. type TypeValue interface { - // CloneTypeValue duplicates a TypeValue including references to internal type information. e.g. the list of members + // NewTypeValue creates a TypeValue including references to internal type information. e.g. the list of members // in an EnumType. - CloneTypeValue() Value + NewTypeValue() Value // TypeName returns the PostgreSQL name of this type. TypeName() string @@ -359,7 +359,7 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) { func (ci *ConnInfo) RegisterDataType(t DataType) { if tv, ok := t.Value.(TypeValue); ok { - t.Value = tv.CloneTypeValue() + t.Value = tv.NewTypeValue() } ci.oidToDataType[t.OID] = &t @@ -469,7 +469,7 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { for _, dt := range ci.oidToDataType { var value Value if tv, ok := dt.Value.(TypeValue); ok { - value = tv.CloneTypeValue() + value = tv.NewTypeValue() } else { value = reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value) } From e5992d0aede8ce4bb05e94e4df3f4cae7771e45a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 10:28:13 -0500 Subject: [PATCH 0484/1158] TypeValue should include Value --- enum_type.go | 1 - pgtype.go | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/enum_type.go b/enum_type.go index 231c21fd..a72bb13f 100644 --- a/enum_type.go +++ b/enum_type.go @@ -5,7 +5,6 @@ import errors "golang.org/x/xerrors" // EnumType represents a enum type. While it implements Value, this is only in service of its type conversion duties // when registered as a data type in a ConnType. It should not be used directly as a Value. type EnumType interface { - Value TypeValue // Members returns possible members of this enumeration. The returned slice must not be modified. diff --git a/pgtype.go b/pgtype.go index 6a703994..5662f4c7 100644 --- a/pgtype.go +++ b/pgtype.go @@ -128,12 +128,14 @@ type Value interface { AssignTo(dst interface{}) error } -// TypeValue represents values where instances can represent different PostgreSQL types. This can be useful for +// TypeValue is a Value where instances can represent different PostgreSQL types. This can be useful for // representing types such as enums, composites, and arrays. // // In general, instances of TypeValue should not be used to directly represent a value. It should only be used as an // encoder and decoder internal to ConnInfo. type TypeValue interface { + Value + // NewTypeValue creates a TypeValue including references to internal type information. e.g. the list of members // in an EnumType. NewTypeValue() Value From 9cdd928cb8cdae92448358be3e8945c92d2d2a8e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 10:40:13 -0500 Subject: [PATCH 0485/1158] CompositeType implements TypeValue --- composite_bench_test.go | 4 ++-- composite_type.go | 25 ++++++++++++++++++++++--- composite_type_test.go | 6 +++--- pgtype.go | 22 +++++++++++----------- 4 files changed, 38 insertions(+), 19 deletions(-) diff --git a/composite_bench_test.go b/composite_bench_test.go index 9eaf7632..e1dd6d04 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -100,7 +100,7 @@ func BenchmarkBinaryEncodingComposite(b *testing.B) { ci := pgtype.NewConnInfo() f1 := 2 f2 := ptrS("bar") - c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) + c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{}) b.ResetTimer() for n := 0; n < b.N; n++ { @@ -163,7 +163,7 @@ func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { var f1 int var f2 *string - c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) + c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{}) b.ResetTimer() for n := 0; n < b.N; n++ { diff --git a/composite_type.go b/composite_type.go index 53386f37..03d88aea 100644 --- a/composite_type.go +++ b/composite_type.go @@ -8,8 +8,10 @@ import ( ) type CompositeType struct { - fields []Value status Status + + typeName string + fields []Value } // NewCompositeType creates a Composite object, which acts as a "schema" for @@ -19,8 +21,8 @@ type CompositeType struct { // SetFields method // To read composite fields back pass result of Scan() method // to query Scan function. -func NewCompositeType(fields ...Value) *CompositeType { - return &CompositeType{fields, Undefined} +func NewCompositeType(typeName string, fields ...Value) *CompositeType { + return &CompositeType{typeName: typeName, fields: fields} } func (src CompositeType) Get() interface{} { @@ -38,6 +40,23 @@ func (src CompositeType) Get() interface{} { } } +func (ct *CompositeType) NewTypeValue() Value { + a := &CompositeType{ + typeName: ct.typeName, + fields: make([]Value, len(ct.fields)), + } + + for i := range ct.fields { + a.fields[i] = NewValue(ct.fields[i]) + } + + return a +} + +func (ct *CompositeType) TypeName() string { + return ct.typeName +} + func (dst *CompositeType) Set(src interface{}) error { if src == nil { dst.status = Null diff --git a/composite_type_test.go b/composite_type_test.go index 56b9318b..92ecc849 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -12,7 +12,7 @@ import ( ) func TestCompositeTypeSetAndGet(t *testing.T) { - ct := pgtype.NewCompositeType(&pgtype.Text{}, &pgtype.Int4{}) + ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{}) assert.Equal(t, pgtype.Undefined, ct.Get()) nilTests := []struct { @@ -54,7 +54,7 @@ func TestCompositeTypeSetAndGet(t *testing.T) { } func TestCompositeTypeAssignTo(t *testing.T) { - ct := pgtype.NewCompositeType(&pgtype.Text{}, &pgtype.Int4{}) + ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{}) { err := ct.Set([]interface{}{"foo", int32(42)}) @@ -161,7 +161,7 @@ func Example_composite() { return } - c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) + c := pgtype.NewCompositeType("mytype", &pgtype.Int4{}, &pgtype.Text{}) conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: c, Name: "mytype", OID: oid}) var a int diff --git a/pgtype.go b/pgtype.go index 5662f4c7..091e98c4 100644 --- a/pgtype.go +++ b/pgtype.go @@ -360,9 +360,7 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) { } func (ci *ConnInfo) RegisterDataType(t DataType) { - if tv, ok := t.Value.(TypeValue); ok { - t.Value = tv.NewTypeValue() - } + t.Value = NewValue(t.Value) ci.oidToDataType[t.OID] = &t ci.nameToDataType[t.Name] = &t @@ -469,15 +467,8 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { ci2 := newConnInfo() for _, dt := range ci.oidToDataType { - var value Value - if tv, ok := dt.Value.(TypeValue); ok { - value = tv.NewTypeValue() - } else { - value = reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value) - } - ci2.RegisterDataType(DataType{ - Value: value, + Value: NewValue(dt.Value), Name: dt.Name, OID: dt.OID, }) @@ -844,6 +835,15 @@ func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) } } +// NewValue returns a new instance of the same type as v. +func NewValue(v Value) Value { + if tv, ok := v.(TypeValue); ok { + return tv.NewTypeValue() + } else { + return reflect.New(reflect.ValueOf(v).Elem().Type()).Interface().(Value) + } +} + var nameValues map[string]Value func init() { From e92ee69901b17859ae173519dbbffb911cdf0e31 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 10:41:50 -0500 Subject: [PATCH 0486/1158] Expose EnumType directly instead of behind interface --- enum_type.go | 39 ++++++++++++++++----------------------- enum_type_test.go | 2 +- 2 files changed, 17 insertions(+), 24 deletions(-) diff --git a/enum_type.go b/enum_type.go index a72bb13f..d3a1df5c 100644 --- a/enum_type.go +++ b/enum_type.go @@ -4,14 +4,7 @@ import errors "golang.org/x/xerrors" // EnumType represents a enum type. While it implements Value, this is only in service of its type conversion duties // when registered as a data type in a ConnType. It should not be used directly as a Value. -type EnumType interface { - TypeValue - - // Members returns possible members of this enumeration. The returned slice must not be modified. - Members() []string -} - -type enumType struct { +type EnumType struct { value string status Status @@ -21,8 +14,8 @@ type enumType struct { } // NewEnumType initializes a new EnumType. It retains a read-only reference to members. members must not be changed. -func NewEnumType(typeName string, members []string) EnumType { - et := &enumType{typeName: typeName, members: members} +func NewEnumType(typeName string, members []string) *EnumType { + et := &EnumType{typeName: typeName, members: members} et.membersMap = make(map[string]string, len(members)) for _, m := range members { et.membersMap[m] = m @@ -30,8 +23,8 @@ func NewEnumType(typeName string, members []string) EnumType { return et } -func (et *enumType) NewTypeValue() Value { - return &enumType{ +func (et *EnumType) NewTypeValue() Value { + return &EnumType{ value: et.value, status: et.status, @@ -41,17 +34,17 @@ func (et *enumType) NewTypeValue() Value { } } -func (et *enumType) TypeName() string { +func (et *EnumType) TypeName() string { return et.typeName } -func (et *enumType) Members() []string { +func (et *EnumType) Members() []string { return et.members } // Set assigns src to dst. Set purposely does not check that src is a member. This allows continued error free // operation in the event the PostgreSQL enum type is modified during a connection. -func (dst *enumType) Set(src interface{}) error { +func (dst *EnumType) Set(src interface{}) error { if src == nil { dst.status = Null return nil @@ -92,7 +85,7 @@ func (dst *enumType) Set(src interface{}) error { return nil } -func (dst enumType) Get() interface{} { +func (dst EnumType) Get() interface{} { switch dst.status { case Present: return dst.value @@ -103,7 +96,7 @@ func (dst enumType) Get() interface{} { } } -func (src *enumType) AssignTo(dst interface{}) error { +func (src *EnumType) AssignTo(dst interface{}) error { switch src.status { case Present: switch v := dst.(type) { @@ -127,11 +120,11 @@ func (src *enumType) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } -func (enumType) PreferredResultFormat() int16 { +func (EnumType) PreferredResultFormat() int16 { return TextFormatCode } -func (dst *enumType) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *EnumType) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { dst.status = Null return nil @@ -151,15 +144,15 @@ func (dst *enumType) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (dst *enumType) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *EnumType) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (enumType) PreferredParamFormat() int16 { +func (EnumType) PreferredParamFormat() int16 { return TextFormatCode } -func (src enumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src EnumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.status { case Null: return nil, nil @@ -170,6 +163,6 @@ func (src enumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, src.value...), nil } -func (src enumType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src EnumType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return src.EncodeText(ci, buf) } diff --git a/enum_type_test.go b/enum_type_test.go index c1e2add0..4dd88f2a 100644 --- a/enum_type_test.go +++ b/enum_type_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func setupEnum(t *testing.T, conn *pgx.Conn) pgtype.EnumType { +func setupEnum(t *testing.T, conn *pgx.Conn) *pgtype.EnumType { _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") require.NoError(t, err) From 218663463828a6358d2a3004c00180d9d986a511 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 11:55:24 -0500 Subject: [PATCH 0487/1158] Add CompositeFields encoders --- composite_fields.go | 112 +++++++++++++++++++++++++++++ composite_fields_test.go | 147 +++++++++++++++++++++++++++++++++++++++ composite_type.go | 14 ++++ 3 files changed, 273 insertions(+) diff --git a/composite_fields.go b/composite_fields.go index 64a17b55..751adce8 100644 --- a/composite_fields.go +++ b/composite_fields.go @@ -1,11 +1,17 @@ package pgtype import ( + "encoding/binary" + + "github.com/jackc/pgio" errors "golang.org/x/xerrors" ) // CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a // nullable value use a *CompositeFields. It will be set to nil in case of null. +// +// CompositeFields implements EncodeBinary and EncodeText. However, functionality is limited due to CompositeFields not +// knowing the PostgreSQL schema of the composite type. Prefer using a registered CompositeType. type CompositeFields []interface{} func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error { @@ -74,3 +80,109 @@ func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error { return nil } + +// EncodeText encodes composite fields into the text format. Prefer registering a CompositeType to using +// CompositeFields to encode directly. +func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + buf = append(buf, '(') + + fieldBuf := make([]byte, 0, 32) + + for _, f := range cf { + if f != nil { + fieldBuf = fieldBuf[0:0] + if textEncoder, ok := f.(TextEncoder); ok { + var err error + fieldBuf, err = textEncoder.EncodeText(ci, fieldBuf) + if err != nil { + return nil, err + } + if fieldBuf != nil { + buf = append(buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...) + } + } else { + dt, ok := ci.DataTypeForValue(f) + if !ok { + return nil, errors.Errorf("Unknown data type for %#v", f) + } + + err := dt.Value.Set(f) + if err != nil { + return nil, err + } + + if textEncoder, ok := dt.Value.(TextEncoder); ok { + var err error + fieldBuf, err = textEncoder.EncodeText(ci, fieldBuf) + if err != nil { + return nil, err + } + if fieldBuf != nil { + buf = append(buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...) + } + } else { + return nil, errors.Errorf("Cannot encode text format for %v", f) + } + } + } + buf = append(buf, ',') + } + + buf[len(buf)-1] = ')' + return buf, nil +} + +// EncodeBinary encodes composite fields into the binary format. Unlike CompositeType the schema of the destination is +// unknown. Prefer registering a CompositeType to using CompositeFields to encode directly. Because the binary +// composite format requires the OID of each field to be specified the only types that will work are those known to +// ConnInfo. +// +// In particular: +// +// * Nil cannot be used because there is no way to determine what type it. +// * Integer types must be exact matches. e.g. A Go int32 into a PostgreSQL bigint will fail. +// * No dereferencing will be done. e.g. *Text must be used instead of Text. +func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + buf = pgio.AppendUint32(buf, uint32(len(cf))) + + for _, f := range cf { + dt, ok := ci.DataTypeForValue(f) + if !ok { + return nil, errors.Errorf("Unknown OID for %#v", f) + } + + buf = pgio.AppendUint32(buf, dt.OID) + lengthPos := len(buf) + buf = pgio.AppendInt32(buf, -1) + + if binaryEncoder, ok := f.(BinaryEncoder); ok { + fieldBuf, err := binaryEncoder.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if fieldBuf != nil { + binary.BigEndian.PutUint32(buf[lengthPos:], uint32(len(fieldBuf)-len(buf))) + buf = fieldBuf + } + } else { + err := dt.Value.Set(f) + if err != nil { + return nil, err + } + if binaryEncoder, ok := dt.Value.(BinaryEncoder); ok { + fieldBuf, err := binaryEncoder.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if fieldBuf != nil { + binary.BigEndian.PutUint32(buf[lengthPos:], uint32(len(fieldBuf)-len(buf))) + buf = fieldBuf + } + } else { + return nil, errors.Errorf("Cannot encode binary format for %v", f) + } + } + } + + return buf, nil +} diff --git a/composite_fields_test.go b/composite_fields_test.go index d53e48ec..dc4d4c29 100644 --- a/composite_fields_test.go +++ b/composite_fields_test.go @@ -8,6 +8,7 @@ import ( "github.com/jackc/pgtype/testutil" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCompositeFieldsDecode(t *testing.T) { @@ -123,4 +124,150 @@ func TestCompositeFieldsDecode(t *testing.T) { assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format) } } + + // Skip nil fields + { + var a int32 + var c float64 + + for _, format := range formats { + err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, nil, &c}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.EqualValuesf(t, 1, a, "Format: %v", format) + assert.EqualValuesf(t, 2.1, c, "Format: %v", format) + } + } +} + +func TestCompositeFieldsEncode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists cf_encode; + +create type cf_encode as ( + a text, + b int4, + c text, + d float8, + e text +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type cf_encode") + + // Use simple protocol to force text or binary encoding + simpleProtocols := []bool{true, false} + + // Assorted values + { + var a string + var b int32 + var c string + var d float64 + var e string + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{"hi", int32(1), "ok", float64(2.1), "bye"}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "ok", c, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 2.1, d, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "bye", e, "Simple Protocol: %v", simpleProtocol) + } + } + } + + // untyped nil + { + var a pgtype.Text + var b int32 + var c string + var d pgtype.Float8 + var e pgtype.Text + + simpleProtocol := true + err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) + } + + // untyped nil cannot be represented in binary format because CompositeFields does not know the PostgreSQL schema + // of the composite type. + simpleProtocol = false + err = conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + assert.Errorf(t, err, "Simple Protocol: %v", simpleProtocol) + } + + // nulls, string "null", and empty string fields + { + var a pgtype.Text + var b int32 + var c string + var d pgtype.Float8 + var e pgtype.Text + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{&pgtype.Text{Status: pgtype.Null}, int32(1), "null", &pgtype.Float8{Status: pgtype.Null}, &pgtype.Text{Status: pgtype.Null}}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) + } + } + } + + // quotes and special characters + { + var a string + var b int32 + var c string + var d float64 + var e string + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow( + context.Background(), + `select $1::cf_encode`, + pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{`"`, int32(42), `foo'bar`, float64(1.2), `baz)bar`}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.Equalf(t, `"`, a, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, int32(42), b, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, `foo'bar`, c, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, float64(1.2), d, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, `baz)bar`, e, "Simple Protocol: %v", simpleProtocol) + } + } + } } diff --git a/composite_type.go b/composite_type.go index 03d88aea..b4b1ab28 100644 --- a/composite_type.go +++ b/composite_type.go @@ -2,6 +2,7 @@ package pgtype import ( "encoding/binary" + "strings" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -366,3 +367,16 @@ func RecordAdd(buf []byte, oid uint32, fieldBytes []byte) []byte { func RecordAddNull(buf []byte, oid uint32) []byte { return pgio.AppendInt32(buf, int32(-1)) } + +var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteCompositeField(src string) string { + return `"` + quoteCompositeReplacer.Replace(src) + `"` +} + +func QuoteCompositeFieldIfNeeded(src string) string { + if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) { + return quoteCompositeField(src) + } + return src +} From e51cb1ef09a161010a263455ce67125d9c42d8e5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 14:04:11 -0500 Subject: [PATCH 0488/1158] Add CompositeBinaryBuilder --- composite_bench_test.go | 23 +++++++------ composite_fields.go | 29 +++------------- composite_type.go | 76 ++++++++++++++++++++++++++++++++++------- convert.go | 23 ++++--------- 4 files changed, 87 insertions(+), 64 deletions(-) diff --git a/composite_bench_test.go b/composite_bench_test.go index e1dd6d04..4858ccad 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -3,6 +3,7 @@ package pgtype_test import ( "testing" + "github.com/jackc/pgio" "github.com/jackc/pgtype" errors "golang.org/x/xerrors" ) @@ -12,22 +13,22 @@ type MyCompositeRaw struct { B *string } -func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { - a := pgtype.Int4{src.A, pgtype.Present} +func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + buf = pgio.AppendUint32(buf, 2) - fieldBytes := make([]byte, 0, 64) - fieldBytes, _ = a.EncodeBinary(ci, fieldBytes[:0]) - - newBuf = pgtype.RecordStart(buf, 2) - newBuf = pgtype.RecordAdd(newBuf, pgtype.Int4OID, fieldBytes) + buf = pgio.AppendUint32(buf, pgtype.Int4OID) + buf = pgio.AppendInt32(buf, 4) + buf = pgio.AppendInt32(buf, src.A) + buf = pgio.AppendUint32(buf, pgtype.TextOID) if src.B != nil { - fieldBytes, _ = pgtype.Text{*src.B, pgtype.Present}.EncodeBinary(ci, fieldBytes[:0]) - newBuf = pgtype.RecordAdd(newBuf, pgtype.TextOID, fieldBytes) + buf = pgio.AppendInt32(buf, int32(len(*src.B))) + buf = append(buf, (*src.B)...) } else { - newBuf = pgtype.RecordAddNull(newBuf, pgtype.TextOID) + buf = pgio.AppendInt32(buf, -1) } - return + + return buf, nil } func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { diff --git a/composite_fields.go b/composite_fields.go index 751adce8..b97506eb 100644 --- a/composite_fields.go +++ b/composite_fields.go @@ -1,9 +1,6 @@ package pgtype import ( - "encoding/binary" - - "github.com/jackc/pgio" errors "golang.org/x/xerrors" ) @@ -143,7 +140,7 @@ func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { // * Integer types must be exact matches. e.g. A Go int32 into a PostgreSQL bigint will fail. // * No dereferencing will be done. e.g. *Text must be used instead of Text. func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - buf = pgio.AppendUint32(buf, uint32(len(cf))) + b := NewCompositeBinaryBuilder(ci, buf) for _, f := range cf { dt, ok := ci.DataTypeForValue(f) @@ -151,38 +148,20 @@ func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) return nil, errors.Errorf("Unknown OID for %#v", f) } - buf = pgio.AppendUint32(buf, dt.OID) - lengthPos := len(buf) - buf = pgio.AppendInt32(buf, -1) - if binaryEncoder, ok := f.(BinaryEncoder); ok { - fieldBuf, err := binaryEncoder.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if fieldBuf != nil { - binary.BigEndian.PutUint32(buf[lengthPos:], uint32(len(fieldBuf)-len(buf))) - buf = fieldBuf - } + b.AppendEncoder(dt.OID, binaryEncoder) } else { err := dt.Value.Set(f) if err != nil { return nil, err } if binaryEncoder, ok := dt.Value.(BinaryEncoder); ok { - fieldBuf, err := binaryEncoder.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if fieldBuf != nil { - binary.BigEndian.PutUint32(buf[lengthPos:], uint32(len(fieldBuf)-len(buf))) - buf = fieldBuf - } + b.AppendEncoder(dt.OID, binaryEncoder) } else { return nil, errors.Errorf("Cannot encode binary format for %v", f) } } } - return buf, nil + return b.Finish() } diff --git a/composite_type.go b/composite_type.go index b4b1ab28..99f0189f 100644 --- a/composite_type.go +++ b/composite_type.go @@ -350,22 +350,74 @@ func (cfs *CompositeTextScanner) Err() error { return cfs.err } -// RecordStart adds record header to the buf -func RecordStart(buf []byte, fieldCount int) []byte { - return pgio.AppendUint32(buf, uint32(fieldCount)) +type CompositeBinaryBuilder struct { + ci *ConnInfo + buf []byte + startIdx int + fieldCount uint32 + err error } -// RecordAdd adds record field to the buf -func RecordAdd(buf []byte, oid uint32, fieldBytes []byte) []byte { - buf = pgio.AppendUint32(buf, oid) - buf = pgio.AppendUint32(buf, uint32(len(fieldBytes))) - buf = append(buf, fieldBytes...) - return buf +func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder { + startIdx := len(buf) + buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields + return &CompositeBinaryBuilder{ci: ci, buf: buf, startIdx: startIdx} } -// RecordAddNull adds null value as a field to the buf -func RecordAddNull(buf []byte, oid uint32) []byte { - return pgio.AppendInt32(buf, int32(-1)) +func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { + if b.err != nil { + return + } + + dt, ok := b.ci.DataTypeForOID(oid) + if !ok { + b.err = errors.Errorf("unknown data type for OID: %d", oid) + return + } + + err := dt.Value.Set(field) + if err != nil { + b.err = err + return + } + + binaryEncoder, ok := dt.Value.(BinaryEncoder) + if !ok { + b.err = errors.Errorf("unable to encode binary for OID: %d", oid) + return + } + + b.AppendEncoder(oid, binaryEncoder) +} + +func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder) { + if b.err != nil { + return + } + + b.buf = pgio.AppendUint32(b.buf, oid) + lengthPos := len(b.buf) + b.buf = pgio.AppendInt32(b.buf, -1) + fieldBuf, err := field.EncodeBinary(b.ci, b.buf) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + binary.BigEndian.PutUint32(b.buf[lengthPos:], uint32(len(fieldBuf)-len(b.buf))) + b.buf = fieldBuf + } + + b.fieldCount++ +} + +func (b *CompositeBinaryBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount) + return b.buf, nil } var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) diff --git a/convert.go b/convert.go index 6e70e82e..f170e05b 100644 --- a/convert.go +++ b/convert.go @@ -435,30 +435,21 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { // EncodeRow builds a binary representation of row values (row(), composite types) func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) { - fieldBytes := make([]byte, 0, 128) + b := NewCompositeBinaryBuilder(ci, buf) - newBuf = RecordStart(buf, len(fields)) for _, f := range fields { dt, ok := ci.DataTypeForValue(f) if !ok { return nil, errors.Errorf("Unknown OID for %s", f) } - if f.Get() != nil { - binaryEncoder, ok := f.(BinaryEncoder) - if !ok { - return nil, errors.Errorf("record field doesn't implement binary encoding: %s", reflect.TypeOf(f).Name()) - } - fieldBytes, err = binaryEncoder.EncodeBinary(ci, fieldBytes[:0]) - if err != nil { - return nil, err - } - newBuf = RecordAdd(newBuf, dt.OID, fieldBytes) - } else { - newBuf = RecordAddNull(newBuf, dt.OID) + binaryEncoder, ok := f.(BinaryEncoder) + if !ok { + return nil, errors.Errorf("record field doesn't implement binary encoding: %s", reflect.TypeOf(f).Name()) } - + b.AppendEncoder(dt.OID, binaryEncoder) } - return + + return b.Finish() } func init() { From fcb385dccbdd133189d6349c0e402f40d18c248e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 15:04:14 -0500 Subject: [PATCH 0489/1158] Add ScanDecoder and ScanValue to composite scanners. Rename Scan to Next to disambiguate. --- composite_bench_test.go | 24 +------- composite_fields.go | 35 ++--------- composite_type.go | 132 ++++++++++++++++++++++++++++------------ record.go | 7 +-- 4 files changed, 104 insertions(+), 94 deletions(-) diff --git a/composite_bench_test.go b/composite_bench_test.go index 4858ccad..cff9d518 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -5,7 +5,6 @@ import ( "github.com/jackc/pgio" "github.com/jackc/pgtype" - errors "golang.org/x/xerrors" ) type MyCompositeRaw struct { @@ -35,26 +34,9 @@ func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { a := pgtype.Int4{} b := pgtype.Text{} - scanner, err := pgtype.NewCompositeBinaryScanner(src) - if err != nil { - return err - } - - if 2 != scanner.FieldCount() { - return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=2", scanner.FieldCount()) - } - - if scanner.Scan() { - if err = a.DecodeBinary(ci, scanner.Bytes()); err != nil { - return err - } - } - - if scanner.Scan() { - if err = b.DecodeBinary(ci, scanner.Bytes()); err != nil { - return err - } - } + scanner := pgtype.NewCompositeBinaryScanner(ci, src) + scanner.ScanDecoder(&a) + scanner.ScanDecoder(&b) if scanner.Err() != nil { return scanner.Err() diff --git a/composite_fields.go b/composite_fields.go index b97506eb..b2d9f844 100644 --- a/composite_fields.go +++ b/composite_fields.go @@ -20,19 +20,10 @@ func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error { return errors.Errorf("cannot decode unexpected null into CompositeFields") } - scanner, err := NewCompositeBinaryScanner(src) - if err != nil { - return err - } - if len(cf) != scanner.FieldCount() { - return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), scanner.FieldCount()) - } + scanner := NewCompositeBinaryScanner(ci, src) - for i := 0; scanner.Scan(); i++ { - err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), cf[i]) - if err != nil { - return err - } + for _, f := range cf { + scanner.ScanValue(f) } if scanner.Err() != nil { @@ -51,30 +42,16 @@ func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error { return errors.Errorf("cannot decode unexpected null into CompositeFields") } - scanner, err := NewCompositeTextScanner(src) - if err != nil { - return err - } + scanner := NewCompositeTextScanner(ci, src) - fieldCount := 0 - - for i := 0; scanner.Scan(); i++ { - err := ci.Scan(0, TextFormatCode, scanner.Bytes(), cf[i]) - if err != nil { - return err - } - - fieldCount += 1 + for _, f := range cf { + scanner.ScanValue(f) } if scanner.Err() != nil { return scanner.Err() } - if len(cf) != fieldCount { - return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), fieldCount) - } - return nil } diff --git a/composite_type.go b/composite_type.go index 99f0189f..f01e8e64 100644 --- a/composite_type.go +++ b/composite_type.go @@ -12,7 +12,7 @@ type CompositeType struct { status Status typeName string - fields []Value + fields []ValueTranscoder } // NewCompositeType creates a Composite object, which acts as a "schema" for @@ -22,7 +22,7 @@ type CompositeType struct { // SetFields method // To read composite fields back pass result of Scan() method // to query Scan function. -func NewCompositeType(typeName string, fields ...Value) *CompositeType { +func NewCompositeType(typeName string, fields ...ValueTranscoder) *CompositeType { return &CompositeType{typeName: typeName, fields: fields} } @@ -44,11 +44,11 @@ func (src CompositeType) Get() interface{} { func (ct *CompositeType) NewTypeValue() Value { a := &CompositeType{ typeName: ct.typeName, - fields: make([]Value, len(ct.fields)), + fields: make([]ValueTranscoder, len(ct.fields)), } for i := range ct.fields { - a.fields[i] = NewValue(ct.fields[i]) + a.fields[i] = NewValue(ct.fields[i]).(ValueTranscoder) } return a @@ -138,36 +138,34 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, case Undefined: return nil, errUndefined } - return EncodeRow(ci, buf, src.fields...) + + b := NewCompositeBinaryBuilder(ci, buf) + for _, f := range src.fields { + dt, ok := ci.DataTypeForValue(f) + if !ok { + return nil, errors.Errorf("unknown oid") + } + + b.AppendEncoder(dt.OID, f) + } + + return b.Finish() } // DecodeBinary implements BinaryDecoder interface. // Opposite to Record, fields in a composite act as a "schema" // and decoding fails if SQL value can't be assigned due to // type mismatch -func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { +func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { if buf == nil { dst.status = Null return nil } - scanner, err := NewCompositeBinaryScanner(buf) - if err != nil { - return err - } - if len(dst.fields) != scanner.FieldCount() { - return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(dst.fields), scanner.FieldCount()) - } + scanner := NewCompositeBinaryScanner(ci, buf) - for i := 0; scanner.Scan(); i++ { - binaryDecoder, ok := dst.fields[i].(BinaryDecoder) - if !ok { - return errors.New("Composite field doesn't support binary protocol") - } - - if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil { - return err - } + for _, f := range dst.fields { + scanner.ScanDecoder(f) } if scanner.Err() != nil { @@ -180,6 +178,7 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { } type CompositeBinaryScanner struct { + ci *ConnInfo rp int src []byte @@ -190,25 +189,52 @@ type CompositeBinaryScanner struct { } // NewCompositeBinaryScanner a scanner over a binary encoded composite balue. -func NewCompositeBinaryScanner(src []byte) (CompositeBinaryScanner, error) { +func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner { rp := 0 if len(src[rp:]) < 4 { - return CompositeBinaryScanner{}, errors.Errorf("Record incomplete %v", src) + return &CompositeBinaryScanner{err: errors.Errorf("Record incomplete %v", src)} } fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) rp += 4 - return CompositeBinaryScanner{ + return &CompositeBinaryScanner{ + ci: ci, rp: rp, src: src, fieldCount: fieldCount, - }, nil + } } -// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After -// Scan returns false, the Err method can be called to check if any errors occurred. -func (cfs *CompositeBinaryScanner) Scan() bool { +// ScanDecoder calls Next and decodes the result with d. +func (cfs *CompositeBinaryScanner) ScanDecoder(d BinaryDecoder) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = d.DecodeBinary(cfs.ci, cfs.fieldBytes) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// ScanDecoder calls Next and scans the result into d. +func (cfs *CompositeBinaryScanner) ScanValue(d interface{}) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = cfs.ci.Scan(cfs.OID(), BinaryFormatCode, cfs.Bytes(), d) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeBinaryScanner) Next() bool { if cfs.err != nil { return false } @@ -261,6 +287,7 @@ func (cfs *CompositeBinaryScanner) Err() error { } type CompositeTextScanner struct { + ci *ConnInfo rp int src []byte @@ -268,29 +295,56 @@ type CompositeTextScanner struct { err error } -// NewCompositeTextScanner a scanner over a text encoded composite balue. -func NewCompositeTextScanner(src []byte) (CompositeTextScanner, error) { +// NewCompositeTextScanner a scanner over a text encoded composite value. +func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner { if len(src) < 2 { - return CompositeTextScanner{}, errors.Errorf("Record incomplete %v", src) + return &CompositeTextScanner{err: errors.Errorf("Record incomplete %v", src)} } if src[0] != '(' { - return CompositeTextScanner{}, errors.Errorf("composite text format must start with '('") + return &CompositeTextScanner{err: errors.Errorf("composite text format must start with '('")} } if src[len(src)-1] != ')' { - return CompositeTextScanner{}, errors.Errorf("composite text format must end with ')'") + return &CompositeTextScanner{err: errors.Errorf("composite text format must end with ')'")} } - return CompositeTextScanner{ + return &CompositeTextScanner{ + ci: ci, rp: 1, src: src, - }, nil + } } -// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After -// Scan returns false, the Err method can be called to check if any errors occurred. -func (cfs *CompositeTextScanner) Scan() bool { +// ScanDecoder calls Next and decodes the result with d. +func (cfs *CompositeTextScanner) ScanDecoder(d TextDecoder) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = d.DecodeText(cfs.ci, cfs.fieldBytes) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// ScanDecoder calls Next and scans the result into d. +func (cfs *CompositeTextScanner) ScanValue(d interface{}) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = cfs.ci.Scan(0, TextFormatCode, cfs.Bytes(), d) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeTextScanner) Next() bool { if cfs.err != nil { return false } diff --git a/record.go b/record.go index 0d51ad4c..7899a881 100644 --- a/record.go +++ b/record.go @@ -102,14 +102,11 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } - scanner, err := NewCompositeBinaryScanner(src) - if err != nil { - return err - } + scanner := NewCompositeBinaryScanner(ci, src) fields := make([]Value, scanner.FieldCount()) - for i := 0; scanner.Scan(); i++ { + for i := 0; scanner.Next(); i++ { binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i]) if err != nil { return err From e45ef46424155812ce5be493fac400d67d1b05e0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 15:42:26 -0500 Subject: [PATCH 0490/1158] Refactor and add CompositeTextBuilder --- composite_fields.go | 47 +++------------------ composite_type.go | 92 ++++++++++++++++++++++++++++++++++++++++++ composite_type_test.go | 43 ++++++++++++++++++++ 3 files changed, 141 insertions(+), 41 deletions(-) diff --git a/composite_fields.go b/composite_fields.go index b2d9f844..af7bab1e 100644 --- a/composite_fields.go +++ b/composite_fields.go @@ -58,52 +58,17 @@ func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error { // EncodeText encodes composite fields into the text format. Prefer registering a CompositeType to using // CompositeFields to encode directly. func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - buf = append(buf, '(') - - fieldBuf := make([]byte, 0, 32) + b := NewCompositeTextBuilder(ci, buf) for _, f := range cf { - if f != nil { - fieldBuf = fieldBuf[0:0] - if textEncoder, ok := f.(TextEncoder); ok { - var err error - fieldBuf, err = textEncoder.EncodeText(ci, fieldBuf) - if err != nil { - return nil, err - } - if fieldBuf != nil { - buf = append(buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...) - } - } else { - dt, ok := ci.DataTypeForValue(f) - if !ok { - return nil, errors.Errorf("Unknown data type for %#v", f) - } - - err := dt.Value.Set(f) - if err != nil { - return nil, err - } - - if textEncoder, ok := dt.Value.(TextEncoder); ok { - var err error - fieldBuf, err = textEncoder.EncodeText(ci, fieldBuf) - if err != nil { - return nil, err - } - if fieldBuf != nil { - buf = append(buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...) - } - } else { - return nil, errors.Errorf("Cannot encode text format for %v", f) - } - } + if textEncoder, ok := f.(TextEncoder); ok { + b.AppendEncoder(textEncoder) + } else { + b.AppendValue(f) } - buf = append(buf, ',') } - buf[len(buf)-1] = ')' - return buf, nil + return b.Finish() } // EncodeBinary encodes composite fields into the binary format. Unlike CompositeType the schema of the destination is diff --git a/composite_type.go b/composite_type.go index f01e8e64..6baa639a 100644 --- a/composite_type.go +++ b/composite_type.go @@ -177,6 +177,27 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { return nil } +func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { + if buf == nil { + dst.status = Null + return nil + } + + scanner := NewCompositeTextScanner(ci, buf) + + for _, f := range dst.fields { + scanner.ScanDecoder(f) + } + + if scanner.Err() != nil { + return scanner.Err() + } + + dst.status = Present + + return nil +} + type CompositeBinaryScanner struct { ci *ConnInfo rp int @@ -474,6 +495,77 @@ func (b *CompositeBinaryBuilder) Finish() ([]byte, error) { return b.buf, nil } +type CompositeTextBuilder struct { + ci *ConnInfo + buf []byte + startIdx int + fieldCount uint32 + err error + fieldBuf [32]byte +} + +func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder { + buf = append(buf, '(') // allocate room for number of fields + return &CompositeTextBuilder{ci: ci, buf: buf} +} + +func (b *CompositeTextBuilder) AppendValue(field interface{}) { + if b.err != nil { + return + } + + if field == nil { + b.buf = append(b.buf, ',') + return + } + + dt, ok := b.ci.DataTypeForValue(field) + if !ok { + b.err = errors.Errorf("unknown data type for field: %v", field) + return + } + + err := dt.Value.Set(field) + if err != nil { + b.err = err + return + } + + textEncoder, ok := dt.Value.(TextEncoder) + if !ok { + b.err = errors.Errorf("unable to encode text for value: %v", field) + return + } + + b.AppendEncoder(textEncoder) +} + +func (b *CompositeTextBuilder) AppendEncoder(field TextEncoder) { + if b.err != nil { + return + } + + fieldBuf, err := field.EncodeText(b.ci, b.fieldBuf[0:0]) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + b.buf = append(b.buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...) + } + + b.buf = append(b.buf, ',') +} + +func (b *CompositeTextBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + b.buf[len(b.buf)-1] = ')' + return b.buf, nil +} + var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) func quoteCompositeField(src string) string { diff --git a/composite_type_test.go b/composite_type_test.go index 92ecc849..17d34251 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -7,8 +7,10 @@ import ( "testing" "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" pgx "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCompositeTypeSetAndGet(t *testing.T) { @@ -130,6 +132,47 @@ func TestCompositeTypeAssignTo(t *testing.T) { } } +func TestCompositeTypeTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists ct_test; + +create type ct_test as ( + a text, + b int4 +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type ct_test") + + var oid uint32 + err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid) + require.NoError(t, err) + + defer conn.Exec(context.Background(), "drop type ct_test") + + ct := pgtype.NewCompositeType("ct_test", &pgtype.Text{}, &pgtype.Int4{}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: "ct_test", OID: oid}) + + // Use simple protocol to force text or binary encoding + simpleProtocols := []bool{true, false} + + var a string + var b int32 + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{"hi", int32(42)}, + ).Scan( + []interface{}{&a, &b}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol) + } + } +} + func Example_composite() { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { From eebc6975def21dd3e2faa3a66b3627993ef2a0d4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 15:45:16 -0500 Subject: [PATCH 0491/1158] Add EncodeText support for CompositeType --- composite_type.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/composite_type.go b/composite_type.go index 6baa639a..4aa65c3d 100644 --- a/composite_type.go +++ b/composite_type.go @@ -198,6 +198,22 @@ func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { return nil } +func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { + switch src.status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + b := NewCompositeTextBuilder(ci, buf) + for _, f := range src.fields { + b.AppendEncoder(f) + } + + return b.Finish() +} + type CompositeBinaryScanner struct { ci *ConnInfo rp int From 506ea3683521cc1e4cc8c0e1836b4de11582d911 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 15:47:44 -0500 Subject: [PATCH 0492/1158] Do not export quoteCompositeFieldIfNeeded --- composite_type.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/composite_type.go b/composite_type.go index 4aa65c3d..c37aef27 100644 --- a/composite_type.go +++ b/composite_type.go @@ -567,7 +567,7 @@ func (b *CompositeTextBuilder) AppendEncoder(field TextEncoder) { return } if fieldBuf != nil { - b.buf = append(b.buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...) + b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...) } b.buf = append(b.buf, ',') @@ -588,7 +588,7 @@ func quoteCompositeField(src string) string { return `"` + quoteCompositeReplacer.Replace(src) + `"` } -func QuoteCompositeFieldIfNeeded(src string) string { +func quoteCompositeFieldIfNeeded(src string) string { if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) { return quoteCompositeField(src) } From 9a3923b6e06923d858663c0ad1690583ef48016b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 15:51:27 -0500 Subject: [PATCH 0493/1158] EncodeRow is superceded by CompositeFields --- convert.go | 19 ------------------- custom_composite_test.go | 2 +- 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/convert.go b/convert.go index f170e05b..45c226be 100644 --- a/convert.go +++ b/convert.go @@ -433,25 +433,6 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { return nil, false } -// EncodeRow builds a binary representation of row values (row(), composite types) -func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) { - b := NewCompositeBinaryBuilder(ci, buf) - - for _, f := range fields { - dt, ok := ci.DataTypeForValue(f) - if !ok { - return nil, errors.Errorf("Unknown OID for %s", f) - } - binaryEncoder, ok := f.(BinaryEncoder) - if !ok { - return nil, errors.Errorf("record field doesn't implement binary encoding: %s", reflect.TypeOf(f).Name()) - } - b.AppendEncoder(dt.OID, binaryEncoder) - } - - return b.Finish() -} - func init() { kindTypes = map[reflect.Kind]reflect.Type{ reflect.Bool: reflect.TypeOf(false), diff --git a/custom_composite_test.go b/custom_composite_test.go index a93a8ad0..296fcc90 100644 --- a/custom_composite_test.go +++ b/custom_composite_test.go @@ -36,7 +36,7 @@ func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, b = pgtype.Text{Status: pgtype.Null} } - return pgtype.EncodeRow(ci, buf, &a, &b) + return (pgtype.CompositeFields{&a, &b}).EncodeBinary(ci, buf) } func ptrS(s string) *string { From b3e1355a466d62bdb678216ce292a14115641f81 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 16:58:16 -0500 Subject: [PATCH 0494/1158] CompositeType can assign to struct via reflection --- composite_type.go | 76 +++++++++++++++++++++++++++++++++++------- composite_type_test.go | 17 ++++++++++ 2 files changed, 81 insertions(+), 12 deletions(-) diff --git a/composite_type.go b/composite_type.go index c37aef27..7f5ae694 100644 --- a/composite_type.go +++ b/composite_type.go @@ -2,6 +2,7 @@ package pgtype import ( "encoding/binary" + "reflect" "strings" "github.com/jackc/pgio" @@ -102,24 +103,19 @@ func (src CompositeType) AssignTo(dst interface{}) error { continue } - assignToErr := src.fields[i].AssignTo(v[i]) - if assignToErr != nil { - // Try to use get / set instead -- this avoids every type having to be able to AssignTo type of self. - setSucceeded := false - if setter, ok := v[i].(Value); ok { - err := setter.Set(src.fields[i].Get()) - setSucceeded = err == nil - } - if !setSucceeded { - return errors.Errorf("unable to assign to dst[%d]: %v", i, assignToErr) - } + err := assignToOrSet(src.fields[i], v[i]) + if err != nil { + return errors.Errorf("unable to assign to dst[%d]: %v", i, err) } - } return nil case *[]interface{}: return src.AssignTo(*v) default: + if isPtrStruct, err := src.assignToPtrStruct(dst); isPtrStruct { + return err + } + if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } @@ -131,6 +127,62 @@ func (src CompositeType) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func assignToOrSet(src Value, dst interface{}) error { + assignToErr := src.AssignTo(dst) + if assignToErr != nil { + // Try to use get / set instead -- this avoids every type having to be able to AssignTo type of self. + setSucceeded := false + if setter, ok := dst.(Value); ok { + err := setter.Set(src.Get()) + setSucceeded = err == nil + } + if !setSucceeded { + return assignToErr + } + } + + return nil +} + +func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) { + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() != reflect.Ptr { + return false, nil + } + + if dstValue.IsNil() { + return false, nil + } + + dstElemValue := dstValue.Elem() + dstElemType := dstElemValue.Type() + + if dstElemType.Kind() != reflect.Struct { + return false, nil + } + + exportedFields := make([]int, 0, dstElemType.NumField()) + for i := 0; i < dstElemType.NumField(); i++ { + sf := dstElemType.Field(i) + if sf.PkgPath == "" { + exportedFields = append(exportedFields, i) + } + } + + if len(exportedFields) != len(src.fields) { + return false, nil + } + + for i := range exportedFields { + err := assignToOrSet(src.fields[i], dstElemValue.Field(exportedFields[i]).Addr().Interface()) + if err != nil { + return true, errors.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err) + } + } + + return true, nil +} + func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { switch src.status { case Null: diff --git a/composite_type_test.go b/composite_type_test.go index 17d34251..0225e443 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -130,6 +130,23 @@ func TestCompositeTypeAssignTo(t *testing.T) { assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a) assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b) } + + // Struct fields positionally via reflection + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + s := struct { + A string + B int32 + }{} + + err = ct.AssignTo(&s) + if assert.NoError(t, err) { + assert.Equal(t, "foo", s.A) + assert.Equal(t, int32(42), s.B) + } + } } func TestCompositeTypeTranscode(t *testing.T) { From 0e2bc3467a62ad36ef4ae5fb861e5905c8643678 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 21:01:06 -0500 Subject: [PATCH 0495/1158] Fix ext/gofrs-uuid AssignTo *uuid.UUID --- ext/gofrs-uuid/uuid.go | 1 + 1 file changed, 1 insertion(+) diff --git a/ext/gofrs-uuid/uuid.go b/ext/gofrs-uuid/uuid.go index fec912bc..e29933c9 100644 --- a/ext/gofrs-uuid/uuid.go +++ b/ext/gofrs-uuid/uuid.go @@ -77,6 +77,7 @@ func (src *UUID) AssignTo(dst interface{}) error { switch v := dst.(type) { case *uuid.UUID: *v = src.UUID + return nil case *[16]byte: *v = [16]byte(src.UUID) return nil From ee0e207ee4db74e81614065445d1c8c149f042ea Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 13 May 2020 07:09:52 -0500 Subject: [PATCH 0496/1158] CompositeType fields contain name and oid --- composite_bench_test.go | 13 +++++- composite_type.go | 99 ++++++++++++++++++++++++++--------------- composite_type_test.go | 33 +++++++++++--- 3 files changed, 102 insertions(+), 43 deletions(-) diff --git a/composite_bench_test.go b/composite_bench_test.go index cff9d518..7aef8c4f 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -5,6 +5,7 @@ import ( "github.com/jackc/pgio" "github.com/jackc/pgtype" + "github.com/stretchr/testify/require" ) type MyCompositeRaw struct { @@ -83,7 +84,11 @@ func BenchmarkBinaryEncodingComposite(b *testing.B) { ci := pgtype.NewConnInfo() f1 := 2 f2 := ptrS("bar") - c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{}) + c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.Int4OID}, + {"b", pgtype.TextOID}, + }, ci) + require.NoError(b, err) b.ResetTimer() for n := 0; n < b.N; n++ { @@ -146,7 +151,11 @@ func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { var f1 int var f2 *string - c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{}) + c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.Int4OID}, + {"b", pgtype.TextOID}, + }, ci) + require.NoError(b, err) b.ResetTimer() for n := 0; n < b.N; n++ { diff --git a/composite_type.go b/composite_type.go index 7f5ae694..389bf178 100644 --- a/composite_type.go +++ b/composite_type.go @@ -9,30 +9,59 @@ import ( errors "golang.org/x/xerrors" ) +type CompositeTypeField struct { + Name string + OID uint32 +} + type CompositeType struct { status Status typeName string - fields []ValueTranscoder + + fields []CompositeTypeField + valueTranscoders []ValueTranscoder } -// NewCompositeType creates a Composite object, which acts as a "schema" for -// SQL composite values. -// To pass Composite as SQL parameter first set it's fields, either by -// passing initialized Value{} instances to NewCompositeType or by calling -// SetFields method -// To read composite fields back pass result of Scan() method -// to query Scan function. -func NewCompositeType(typeName string, fields ...ValueTranscoder) *CompositeType { - return &CompositeType{typeName: typeName, fields: fields} +// NewCompositeType creates a CompositeType from fields and ci. ci is used to find the ValueTranscoders used +// for fields. All field OIDs must be previously registered in ci. +func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo) (*CompositeType, error) { + valueTranscoders := make([]ValueTranscoder, len(fields)) + + for i := range fields { + dt, ok := ci.DataTypeForOID(fields[i].OID) + if !ok { + return nil, errors.Errorf("no data type registered for oid: %d", fields[i].OID) + } + + value := NewValue(dt.Value) + valueTranscoder, ok := value.(ValueTranscoder) + if !ok { + return nil, errors.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID) + } + + valueTranscoders[i] = valueTranscoder + } + + return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: valueTranscoders}, nil +} + +// NewCompositeTypeValues creates a CompositeType from fields and values. fields and values must have the same length. +// Prefer NewCompositeType unless overriding the transcoding of fields is required. +func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values []ValueTranscoder) (*CompositeType, error) { + if len(fields) != len(values) { + return nil, errors.New("fields and valueTranscoders must have same length") + } + + return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: values}, nil } func (src CompositeType) Get() interface{} { switch src.status { case Present: - results := make([]interface{}, len(src.fields)) + results := make([]interface{}, len(src.valueTranscoders)) for i := range results { - results[i] = src.fields[i].Get() + results[i] = src.valueTranscoders[i].Get() } return results case Null: @@ -44,12 +73,13 @@ func (src CompositeType) Get() interface{} { func (ct *CompositeType) NewTypeValue() Value { a := &CompositeType{ - typeName: ct.typeName, - fields: make([]ValueTranscoder, len(ct.fields)), + typeName: ct.typeName, + fields: ct.fields, + valueTranscoders: make([]ValueTranscoder, len(ct.valueTranscoders)), } - for i := range ct.fields { - a.fields[i] = NewValue(ct.fields[i]).(ValueTranscoder) + for i := range ct.valueTranscoders { + a.valueTranscoders[i] = NewValue(ct.valueTranscoders[i]).(ValueTranscoder) } return a @@ -59,6 +89,10 @@ func (ct *CompositeType) TypeName() string { return ct.typeName } +func (ct *CompositeType) Fields() []CompositeTypeField { + return ct.fields +} + func (dst *CompositeType) Set(src interface{}) error { if src == nil { dst.status = Null @@ -67,11 +101,11 @@ func (dst *CompositeType) Set(src interface{}) error { switch value := src.(type) { case []interface{}: - if len(value) != len(dst.fields) { - return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.fields)) + if len(value) != len(dst.valueTranscoders) { + return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders)) } for i, v := range value { - if err := dst.fields[i].Set(v); err != nil { + if err := dst.valueTranscoders[i].Set(v); err != nil { return err } } @@ -95,15 +129,15 @@ func (src CompositeType) AssignTo(dst interface{}) error { case Present: switch v := dst.(type) { case []interface{}: - if len(v) != len(src.fields) { - return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.fields)) + if len(v) != len(src.valueTranscoders) { + return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders)) } - for i := range src.fields { + for i := range src.valueTranscoders { if v[i] == nil { continue } - err := assignToOrSet(src.fields[i], v[i]) + err := assignToOrSet(src.valueTranscoders[i], v[i]) if err != nil { return errors.Errorf("unable to assign to dst[%d]: %v", i, err) } @@ -169,12 +203,12 @@ func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) { } } - if len(exportedFields) != len(src.fields) { + if len(exportedFields) != len(src.valueTranscoders) { return false, nil } for i := range exportedFields { - err := assignToOrSet(src.fields[i], dstElemValue.Field(exportedFields[i]).Addr().Interface()) + err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface()) if err != nil { return true, errors.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err) } @@ -192,13 +226,8 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, } b := NewCompositeBinaryBuilder(ci, buf) - for _, f := range src.fields { - dt, ok := ci.DataTypeForValue(f) - if !ok { - return nil, errors.Errorf("unknown oid") - } - - b.AppendEncoder(dt.OID, f) + for i := range src.valueTranscoders { + b.AppendEncoder(src.fields[i].OID, src.valueTranscoders[i]) } return b.Finish() @@ -216,7 +245,7 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { scanner := NewCompositeBinaryScanner(ci, buf) - for _, f := range dst.fields { + for _, f := range dst.valueTranscoders { scanner.ScanDecoder(f) } @@ -237,7 +266,7 @@ func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { scanner := NewCompositeTextScanner(ci, buf) - for _, f := range dst.fields { + for _, f := range dst.valueTranscoders { scanner.ScanDecoder(f) } @@ -259,7 +288,7 @@ func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, er } b := NewCompositeTextBuilder(ci, buf) - for _, f := range src.fields { + for _, f := range src.valueTranscoders { b.AppendEncoder(f) } diff --git a/composite_type_test.go b/composite_type_test.go index 0225e443..b32810ff 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -14,7 +14,12 @@ import ( ) func TestCompositeTypeSetAndGet(t *testing.T) { - ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{}) + ci := pgtype.NewConnInfo() + ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.TextOID}, + {"b", pgtype.Int4OID}, + }, ci) + require.NoError(t, err) assert.Equal(t, pgtype.Undefined, ct.Get()) nilTests := []struct { @@ -56,7 +61,12 @@ func TestCompositeTypeSetAndGet(t *testing.T) { } func TestCompositeTypeAssignTo(t *testing.T) { - ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{}) + ci := pgtype.NewConnInfo() + ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.TextOID}, + {"b", pgtype.Int4OID}, + }, ci) + require.NoError(t, err) { err := ct.Set([]interface{}{"foo", int32(42)}) @@ -168,8 +178,12 @@ create type ct_test as ( defer conn.Exec(context.Background(), "drop type ct_test") - ct := pgtype.NewCompositeType("ct_test", &pgtype.Text{}, &pgtype.Int4{}) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: "ct_test", OID: oid}) + ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{ + {"a", pgtype.TextOID}, + {"b", pgtype.Int4OID}, + }, conn.ConnInfo()) + require.NoError(t, err) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) // Use simple protocol to force text or binary encoding simpleProtocols := []bool{true, false} @@ -221,8 +235,15 @@ func Example_composite() { return } - c := pgtype.NewCompositeType("mytype", &pgtype.Int4{}, &pgtype.Text{}) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: c, Name: "mytype", OID: oid}) + ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{ + {"a", pgtype.Int4OID}, + {"b", pgtype.TextOID}, + }, conn.ConnInfo()) + if err != nil { + fmt.Println(err) + return + } + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) var a int var b *string From f8471ebfa8cfbbccd3bcb72748714f52dae9663d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 13 May 2020 07:11:10 -0500 Subject: [PATCH 0497/1158] ArrayType requires element OID --- array_type.go | 16 +++++----------- array_type_test.go | 4 ++-- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/array_type.go b/array_type.go index 4fe772fd..5de39818 100644 --- a/array_type.go +++ b/array_type.go @@ -18,11 +18,12 @@ type ArrayType struct { status Status typeName string + elementOID uint32 newElement func() ValueTranscoder } -func NewArrayType(typeName string, newElement func() ValueTranscoder) *ArrayType { - return &ArrayType{typeName: typeName, newElement: newElement} +func NewArrayType(typeName string, elementOID uint32, newElement func() ValueTranscoder) *ArrayType { + return &ArrayType{typeName: typeName, elementOID: elementOID, newElement: newElement} } func (at *ArrayType) NewTypeValue() Value { @@ -32,6 +33,7 @@ func (at *ArrayType) NewTypeValue() Value { status: at.status, typeName: at.typeName, + elementOID: at.elementOID, newElement: at.newElement, } } @@ -281,15 +283,7 @@ func (src ArrayType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { arrayHeader := ArrayHeader{ Dimensions: src.dimensions, - } - - { - value := src.newElement() - if dt, ok := ci.DataTypeForValue(value); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for element type %v", value) - } + ElementOID: int32(src.elementOID), } for i := range src.elements { diff --git a/array_type_test.go b/array_type_test.go index d0812a67..0f296bb5 100644 --- a/array_type_test.go +++ b/array_type_test.go @@ -10,7 +10,7 @@ import ( ) func TestArrayTypeValue(t *testing.T) { - arrayType := pgtype.NewArrayType("_text", func() pgtype.ValueTranscoder { return &pgtype.Text{} }) + arrayType := pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }) err := arrayType.Set(nil) require.NoError(t, err) @@ -49,7 +49,7 @@ func TestArrayTypeTranscode(t *testing.T) { defer testutil.MustCloseContext(t, conn) conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: pgtype.NewArrayType("_text", func() pgtype.ValueTranscoder { return &pgtype.Text{} }), + Value: pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), Name: "_text", OID: pgtype.TextArrayOID, }) From 6a1a9d05bc259886ab8987286f13ee0cfb5e1d13 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 13 May 2020 07:34:10 -0500 Subject: [PATCH 0498/1158] Add pgxtype package for simpler type registration --- go.mod | 1 + pgxtype/README.md | 3 + pgxtype/pgxtype.go | 145 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 149 insertions(+) create mode 100644 pgxtype/README.md create mode 100644 pgxtype/pgxtype.go diff --git a/go.mod b/go.mod index 35ba688e..c7738ac9 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.12 require ( github.com/gofrs/uuid v3.2.0+incompatible + github.com/jackc/pgconn v1.5.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9 github.com/lib/pq v1.3.0 diff --git a/pgxtype/README.md b/pgxtype/README.md new file mode 100644 index 00000000..a070111f --- /dev/null +++ b/pgxtype/README.md @@ -0,0 +1,3 @@ +# pgxtype + +pgxtype is a helper module that connects pgx and pgtype. This package is not currently covered by semantic version guarantees. i.e. The interfaces may change without a major version release of pgtype. diff --git a/pgxtype/pgxtype.go b/pgxtype/pgxtype.go new file mode 100644 index 00000000..041f2545 --- /dev/null +++ b/pgxtype/pgxtype.go @@ -0,0 +1,145 @@ +package pgxtype + +import ( + "context" + "errors" + + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" +) + +type Querier interface { + Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) + Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row +} + +// LoadDataType uses conn to inspect the database for typeName and produces a pgtype.DataType suitable for +// registration on ci. +func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeName string) (pgtype.DataType, error) { + var oid uint32 + + err := conn.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid) + if err != nil { + return pgtype.DataType{}, err + } + + var typtype string + + err = conn.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype) + if err != nil { + return pgtype.DataType{}, err + } + + switch typtype { + case "b": // array + elementOID, err := GetArrayElementOID(ctx, conn, oid) + if err != nil { + return pgtype.DataType{}, err + } + + var element pgtype.ValueTranscoder + if dt, ok := ci.DataTypeForOID(elementOID); ok { + if element, ok = dt.Value.(pgtype.ValueTranscoder); !ok { + return pgtype.DataType{}, errors.New("array element OID not registered as ValueTranscoder") + } + } + + newElement := func() pgtype.ValueTranscoder { + return pgtype.NewValue(element).(pgtype.ValueTranscoder) + } + + at := pgtype.NewArrayType(typeName, elementOID, newElement) + return pgtype.DataType{Value: at, Name: typeName, OID: oid}, nil + case "c": // composite + fields, err := GetCompositeFields(ctx, conn, oid) + if err != nil { + return pgtype.DataType{}, err + } + ct, err := pgtype.NewCompositeType(typeName, fields, ci) + if err != nil { + return pgtype.DataType{}, err + } + return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil + case "e": // enum + members, err := GetEnumMembers(ctx, conn, oid) + if err != nil { + return pgtype.DataType{}, err + } + return pgtype.DataType{Value: pgtype.NewEnumType(typeName, members), Name: typeName, OID: oid}, nil + default: + return pgtype.DataType{}, errors.New("unknown typtype") + } +} + +func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, error) { + var typelem uint32 + + err := conn.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem) + if err != nil { + return 0, err + } + + return typelem, nil +} + +// GetCompositeFields gets the fields of a composite type. +func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) { + var typrelid uint32 + + err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid) + if err != nil { + return nil, err + } + + var fields []pgtype.CompositeTypeField + + rows, err := conn.Query(ctx, `select attname, atttypid +from pg_attribute +where attrelid=$1 +order by attnum`, typrelid) + if err != nil { + return nil, err + } + + for rows.Next() { + var f pgtype.CompositeTypeField + err := rows.Scan(&f.Name, &f.OID) + if err != nil { + return nil, err + } + fields = append(fields, f) + } + + if rows.Err() != nil { + return nil, rows.Err() + } + + return fields, nil +} + +// GetEnumMembers gets the possible values of the enum by oid. +func GetEnumMembers(ctx context.Context, conn Querier, oid uint32) ([]string, error) { + members := []string{} + + rows, err := conn.Query(ctx, "select enumlabel from pg_enum where enumtypid=$1 order by enumsortorder", oid) + if err != nil { + return nil, err + } + + for rows.Next() { + var m string + err := rows.Scan(&m) + if err != nil { + return nil, err + } + members = append(members, m) + } + + if rows.Err() != nil { + return nil, rows.Err() + } + + return members, nil +} From 238967ec4e4c1d3e61d704fc815935d2104b5574 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 13 May 2020 08:05:19 -0500 Subject: [PATCH 0499/1158] Improve accuracy of numeric to float fixes #27 --- numeric.go | 17 +++++++---------- numeric_test.go | 1 + 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/numeric.go b/numeric.go index e6c58391..fc8e1789 100644 --- a/numeric.go +++ b/numeric.go @@ -291,19 +291,16 @@ func (dst *Numeric) toBigInt() (*big.Int, error) { } func (src *Numeric) toFloat64() (float64, error) { - f, err := strconv.ParseFloat(src.Int.String(), 64) + buf := make([]byte, 0, 32) + + buf = append(buf, src.Int.String()...) + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) + + f, err := strconv.ParseFloat(string(buf), 64) if err != nil { return 0, err } - if src.Exp > 0 { - for i := 0; i < int(src.Exp); i++ { - f *= 10 - } - } else if src.Exp < 0 { - for i := 0; i > int(src.Exp); i-- { - f /= 10 - } - } return f, nil } diff --git a/numeric_test.go b/numeric_test.go index b925be83..263c78b6 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -266,6 +266,7 @@ func TestNumericAssignTo(t *testing.T) { {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Status: pgtype.Present}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27 } for i, tt := range simpleTests { From 2ccb66fe2159792f5b28d01e65e2461795a6f854 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 May 2020 18:48:05 -0500 Subject: [PATCH 0500/1158] Doc fix --- pgconn.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pgconn.go b/pgconn.go index 4a6ef430..541f280e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -890,11 +890,11 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { // ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). // // paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or -// binary format. If paramFormats is nil all results will be in text protocol. ExecParams will panic if +// binary format. If paramFormats is nil all params are text format. ExecParams will panic if // len(paramFormats) is not 0, 1, or len(paramValues). // // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or -// binary format. If resultFormats is nil all results will be in text protocol. +// binary format. If resultFormats is nil all results will be in text format. // // ResultReader must be closed before PgConn can be used again. func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { @@ -917,11 +917,11 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] // paramValues are the parameter values. It must be encoded in the format given by paramFormats. // // paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or -// binary format. If paramFormats is nil all results will be in text protocol. ExecPrepared will panic if +// binary format. If paramFormats is nil all params are text format. ExecPrepared will panic if // len(paramFormats) is not 0, 1, or len(paramValues). // // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or -// binary format. If resultFormats is nil all results will be in text protocol. +// binary format. If resultFormats is nil all results will be in text format. // // ResultReader must be closed before PgConn can be used again. func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { From afff6abc6c79872e19ac2b84748de37883da964a Mon Sep 17 00:00:00 2001 From: Pablo Morelli Date: Wed, 20 May 2020 15:01:21 +0200 Subject: [PATCH 0501/1158] TID AssignTo string --- tid.go | 13 +++++++++++++ tid_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/tid.go b/tid.go index 98b95e2a..f7b80f94 100644 --- a/tid.go +++ b/tid.go @@ -44,6 +44,19 @@ func (dst TID) Get() interface{} { } func (src *TID) AssignTo(dst interface{}) error { + if src.Status == Present { + switch v := dst.(type) { + case *string: + *v = fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return errors.Errorf("unable to assign to %T", dst) + } + } + return errors.Errorf("cannot assign %v to %T", src, dst) } diff --git a/tid_test.go b/tid_test.go index 773bd96f..818be8af 100644 --- a/tid_test.go +++ b/tid_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "reflect" "testing" "github.com/jackc/pgtype" @@ -14,3 +15,49 @@ func TestTIDTranscode(t *testing.T) { &pgtype.TID{Status: pgtype.Null}, }) } + +func TestTIDAssignTo(t *testing.T) { + var s string + var sp *string + + simpleTests := []struct { + src pgtype.TID + dst interface{} + expected interface{} + }{ + {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, dst: &s, expected: "(42,43)"}, + {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, dst: &s, expected: "(4294967295,65535)"}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.TID + dst interface{} + expected interface{} + }{ + {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, dst: &sp, expected: "(42,43)"}, + {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, dst: &sp, expected: "(4294967295,65535)"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} + From 8c33aa24430a9bbd9d34af6d8c211ab632f22e17 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 May 2020 11:47:42 -0500 Subject: [PATCH 0502/1158] Remove CPU wasting empty default statement fixes #39 --- pgconn.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 541f280e..43edbb6b 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1151,7 +1151,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co default: signalMessageChan = pgConn.signalMessage() } - default: } } close(abortCopyChan) From 2647eff5675f7a45d02b82b633580357b11e05ad Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 25 May 2020 11:49:37 -0500 Subject: [PATCH 0503/1158] Fix ValidateConnect with cancelable context fixes #40 --- pgconn.go | 7 +++++++ pgconn_test.go | 7 +++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index 43edbb6b..5644904a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -288,6 +288,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle if config.ValidateConnect != nil { + // ValidateConnect may execute commands that cause the context to be watched again. Unwatch first to avoid + // the watch already in progress panic. This is that last thing done by this method so there is no need to + // restart the watch after ValidateConnect returns. + // + // See https://github.com/jackc/pgconn/issues/40. + pgConn.contextWatcher.Unwatch() + err := config.ValidateConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() diff --git a/pgconn_test.go b/pgconn_test.go index 9a75dede..6362c51b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -346,9 +346,12 @@ func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite config.RuntimeParams["default_transaction_read_only"] = "on" - conn, err := pgconn.ConnectConfig(context.Background(), config) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := pgconn.ConnectConfig(ctx, config) if !assert.NotNil(t, err) { - conn.Close(context.Background()) + conn.Close(ctx) } } From 8d541d00043bb65c60ee837927dc06c163729954 Mon Sep 17 00:00:00 2001 From: georgysavva Date: Mon, 1 Jun 2020 19:20:17 +0300 Subject: [PATCH 0504/1158] Add Config.Copy() method that return a smart copy of the config. --- config.go | 9 +++++++++ config_test.go | 33 +++++++++++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 ++ 4 files changed, 45 insertions(+) diff --git a/config.go b/config.go index 299d4784..6640038b 100644 --- a/config.go +++ b/config.go @@ -17,6 +17,8 @@ import ( "strings" "time" + "github.com/mohae/deepcopy" + "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" "github.com/jackc/pgproto3/v2" @@ -62,6 +64,13 @@ type Config struct { createdByParseConfig bool // Used to enforce created by ParseConfig rule. } +func (c *Config) Copy() *Config { + newConf := deepcopy.Copy(c).(*Config) + // We need to set this field manually because it's unexported and deep copy won't touch it. + newConf.createdByParseConfig = c.createdByParseConfig + return newConf +} + // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a // network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. type FallbackConfig struct { diff --git a/config_test.go b/config_test.go index 515ea6d3..72b775d4 100644 --- a/config_test.go +++ b/config_test.go @@ -1,6 +1,7 @@ package pgconn_test import ( + "context" "crypto/tls" "fmt" "io/ioutil" @@ -527,6 +528,38 @@ func TestParseConfig(t *testing.T) { } } +func TestConfigCopyReturnsEqualConfig(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5" + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") +} + +func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5" + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + copied.Port = uint16(5433) + + assert.Equal(t, uint16(5432), original.Port) +} + +func TestConfigCopyCanBeUsedToConnect(t *testing.T) { + connString := os.Getenv("PGX_TEST_CONN_STRING") + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assert.NotPanics(t, func() { + _, err = pgconn.ConnectConfig(context.Background(), copied) + }) + assert.NoError(t, err) +} + func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { if !assert.NotNil(t, expected) { return diff --git a/go.mod b/go.mod index 4dc095ca..841eccc7 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.1 github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 + github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 golang.org/x/text v0.3.2 diff --git a/go.sum b/go.sum index 23fb8b32..1514a339 100644 --- a/go.sum +++ b/go.sum @@ -54,6 +54,8 @@ github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= From a6d9265506df51336914769fd786ef8712997c31 Mon Sep 17 00:00:00 2001 From: georgysavva Date: Mon, 1 Jun 2020 20:52:08 +0300 Subject: [PATCH 0505/1158] Implement deep copy manually, stop using an external deep copy library. Add comment to the Config.Copy() method. --- config.go | 30 +++++++++++++++++++++++++----- config_test.go | 10 ++++++++-- go.mod | 1 - go.sum | 2 -- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/config.go b/config.go index 6640038b..7ed99096 100644 --- a/config.go +++ b/config.go @@ -17,8 +17,6 @@ import ( "strings" "time" - "github.com/mohae/deepcopy" - "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" "github.com/jackc/pgproto3/v2" @@ -64,10 +62,32 @@ type Config struct { createdByParseConfig bool // Used to enforce created by ParseConfig rule. } +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the TLSConfig field: +// according to the tls.Config docs it must not be modified after creation. func (c *Config) Copy() *Config { - newConf := deepcopy.Copy(c).(*Config) - // We need to set this field manually because it's unexported and deep copy won't touch it. - newConf.createdByParseConfig = c.createdByParseConfig + newConf := new(Config) + *newConf = *c + if newConf.TLSConfig != nil { + newConf.TLSConfig = c.TLSConfig.Clone() + } + if newConf.RuntimeParams != nil { + newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) + for k, v := range c.RuntimeParams { + newConf.RuntimeParams[k] = v + } + } + if newConf.Fallbacks != nil { + newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) + for i, fallback := range c.Fallbacks { + newFallback := new(FallbackConfig) + *newFallback = *fallback + if newFallback.TLSConfig != nil { + newFallback.TLSConfig = fallback.TLSConfig.Clone() + } + newConf.Fallbacks[i] = newFallback + } + } return newConf } diff --git a/config_test.go b/config_test.go index 72b775d4..ebe627b1 100644 --- a/config_test.go +++ b/config_test.go @@ -529,7 +529,7 @@ func TestParseConfig(t *testing.T) { } func TestConfigCopyReturnsEqualConfig(t *testing.T) { - connString := "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5" + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" original, err := pgconn.ParseConfig(connString) require.NoError(t, err) @@ -538,14 +538,20 @@ func TestConfigCopyReturnsEqualConfig(t *testing.T) { } func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { - connString := "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5" + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" original, err := pgconn.ParseConfig(connString) require.NoError(t, err) copied := original.Copy() + assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") + copied.Port = uint16(5433) + copied.RuntimeParams["foo"] = "bar" + copied.Fallbacks[0].Port = uint16(5433) assert.Equal(t, uint16(5432), original.Port) + assert.Equal(t, "", original.RuntimeParams["foo"]) + assert.Equal(t, uint16(5432), original.Fallbacks[0].Port) } func TestConfigCopyCanBeUsedToConnect(t *testing.T) { diff --git a/go.mod b/go.mod index 841eccc7..4dc095ca 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,6 @@ require ( github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.1 github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 - github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 golang.org/x/text v0.3.2 diff --git a/go.sum b/go.sum index 1514a339..23fb8b32 100644 --- a/go.sum +++ b/go.sum @@ -54,8 +54,6 @@ github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= -github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= From 3cbb81631a4047fa418c9ec594a63d861cd0e8fa Mon Sep 17 00:00:00 2001 From: leighhopcroft Date: Tue, 2 Jun 2020 18:35:58 +0100 Subject: [PATCH 0506/1158] added NaN support to Numeric.Set --- numeric.go | 6 ++++++ numeric_test.go | 2 ++ 2 files changed, 8 insertions(+) diff --git a/numeric.go b/numeric.go index fc8e1789..f4ddb789 100644 --- a/numeric.go +++ b/numeric.go @@ -64,12 +64,18 @@ func (dst *Numeric) Set(src interface{}) error { switch value := src.(type) { case float32: + if math.IsNaN(float64(value)) { + return nil + } num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) if err != nil { return err } *dst = Numeric{Int: num, Exp: exp, Status: Present} case float64: + if math.IsNaN(value) { + return nil + } num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) if err != nil { return err diff --git a/numeric_test.go b/numeric_test.go index 263c78b6..3b099b55 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -210,6 +210,8 @@ func TestNumericSet(t *testing.T) { {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}}, {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}}, {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, + {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Undefined}}, + {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Undefined}}, } for i, tt := range successfulTests { From b708c8b985ce0602ab017cab5f03a6ae7fd2f2be Mon Sep 17 00:00:00 2001 From: leighhopcroft Date: Tue, 2 Jun 2020 19:07:10 +0100 Subject: [PATCH 0507/1158] support NaN in Numeric.AssignTo --- numeric.go | 7 +++++++ numeric_test.go | 18 ++++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/numeric.go b/numeric.go index f4ddb789..644ee23f 100644 --- a/numeric.go +++ b/numeric.go @@ -267,6 +267,13 @@ func (src *Numeric) AssignTo(dst interface{}) error { } case Null: return NullAssignTo(dst) + case Undefined: + switch v := dst.(type) { + case *float32: + *v = float32(math.NaN()) + case *float64: + *v = math.NaN() + } } return nil diff --git a/numeric_test.go b/numeric_test.go index 3b099b55..ee72ff5e 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -269,6 +269,8 @@ func TestNumericAssignTo(t *testing.T) { {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Status: pgtype.Present}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27 + {src: &pgtype.Numeric{}, dst: &f64, expected: math.NaN()}, + {src: &pgtype.Numeric{}, dst: &f32, expected: float32(math.NaN())}, } for i, tt := range simpleTests { @@ -277,8 +279,20 @@ func TestNumericAssignTo(t *testing.T) { t.Errorf("%d: %v", i, err) } - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + dst := reflect.ValueOf(tt.dst).Elem().Interface() + switch dstTyped := dst.(type) { + case float32: + if math.IsNaN(float64(tt.expected.(float32))) && !math.IsNaN(float64(dstTyped)) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + case float64: + if math.IsNaN(tt.expected.(float64)) && !math.IsNaN(dstTyped) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + default: + if dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } } } From f2a2797a88765112814ec47c2b02bec97451278a Mon Sep 17 00:00:00 2001 From: leighhopcroft Date: Tue, 2 Jun 2020 20:14:51 +0100 Subject: [PATCH 0508/1158] support NaN in Numeric encode and decode methods --- numeric.go | 32 ++++++++++++++++++++++++-------- numeric_test.go | 11 ++++++++--- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/numeric.go b/numeric.go index 644ee23f..7ee517be 100644 --- a/numeric.go +++ b/numeric.go @@ -15,6 +15,11 @@ import ( // PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 const nbase = 10000 +const ( + pgNumericNaN = 0x000000000c000000 + pgNumericNaNSign = 0x0c00 +) + var big0 *big.Int = big.NewInt(0) var big1 *big.Int = big.NewInt(1) var big10 *big.Int = big.NewInt(10) @@ -323,6 +328,11 @@ func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { return nil } + if string(src) == "NaN" { + *dst = Numeric{} + return nil + } + num, exp, err := parseNumericString(string(src)) if err != nil { return err @@ -366,12 +376,6 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { rp := 0 ndigits := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 - - if ndigits == 0 { - *dst = Numeric{Int: big.NewInt(0), Status: Present} - return nil - } - weight := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 sign := int16(binary.BigEndian.Uint16(src[rp:])) @@ -379,6 +383,16 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { dscale := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 + if sign == pgNumericNaNSign { + *dst = Numeric{} + return nil + } + + if ndigits == 0 { + *dst = Numeric{Int: big.NewInt(0), Status: Present} + return nil + } + if len(src[rp:]) < int(ndigits)*2 { return errors.Errorf("numeric incomplete %v", src) } @@ -477,7 +491,8 @@ func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Null: return nil, nil case Undefined: - return nil, errUndefined + buf = append(buf, []byte("NaN")...) + return buf, nil } buf = append(buf, src.Int.String()...) @@ -491,7 +506,8 @@ func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Null: return nil, nil case Undefined: - return nil, errUndefined + buf = pgio.AppendUint64(buf, pgNumericNaN) + return buf, nil } var sign int16 diff --git a/numeric_test.go b/numeric_test.go index ee72ff5e..259f397e 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -344,6 +344,8 @@ func TestNumericEncodeDecodeBinary(t *testing.T) { 123, 0.000012345, 1.00002345, + math.NaN(), + float32(math.NaN()), } for i, tt := range tests { @@ -351,7 +353,7 @@ func TestNumericEncodeDecodeBinary(t *testing.T) { ci := pgtype.NewConnInfo() text, err := n.EncodeText(ci, nil) if err != nil { - t.Errorf("%d: %v", i, err) + t.Errorf("%d (EncodeText): %v", i, err) } return string(text) } @@ -360,10 +362,13 @@ func TestNumericEncodeDecodeBinary(t *testing.T) { encoded, err := numeric.EncodeBinary(ci, nil) if err != nil { - t.Errorf("%d: %v", i, err) + t.Errorf("%d (EncodeBinary): %v", i, err) } decoded := &pgtype.Numeric{} - decoded.DecodeBinary(ci, encoded) + err = decoded.DecodeBinary(ci, encoded) + if err != nil { + t.Errorf("%d (DecodeBinary): %v", i, err) + } text0 := toString(numeric) text1 := toString(decoded) From 43e4070cb4b816954a9d4e0618f1fb344e98e9c5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 5 Jun 2020 13:39:53 -0500 Subject: [PATCH 0509/1158] Better CompositeType and ArrayType Get implementation --- array_type.go | 6 +++++- composite_type.go | 6 +++--- composite_type_test.go | 8 ++++---- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/array_type.go b/array_type.go index 5de39818..9454021b 100644 --- a/array_type.go +++ b/array_type.go @@ -84,7 +84,11 @@ func (dst *ArrayType) Set(src interface{}) error { func (dst ArrayType) Get() interface{} { switch dst.status { case Present: - return dst.elements + elementValues := make([]interface{}, len(dst.elements)) + for i := range dst.elements { + elementValues[i] = dst.elements[i].Get() + } + return elementValues case Null: return nil default: diff --git a/composite_type.go b/composite_type.go index 389bf178..49ce70fa 100644 --- a/composite_type.go +++ b/composite_type.go @@ -59,9 +59,9 @@ func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values func (src CompositeType) Get() interface{} { switch src.status { case Present: - results := make([]interface{}, len(src.valueTranscoders)) - for i := range results { - results[i] = src.valueTranscoders[i].Get() + results := make(map[string]interface{}, len(src.valueTranscoders)) + for i := range src.valueTranscoders { + results[src.fields[i].Name] = src.valueTranscoders[i].Get() } return results case Null: diff --git a/composite_type_test.go b/composite_type_test.go index b32810ff..664fe36e 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -37,19 +37,19 @@ func TestCompositeTypeSetAndGet(t *testing.T) { compatibleValuesTests := []struct { src []interface{} - expected []interface{} + expected map[string]interface{} }{ { src: []interface{}{"foo", int32(42)}, - expected: []interface{}{"foo", int32(42)}, + expected: map[string]interface{}{"a": "foo", "b": int32(42)}, }, { src: []interface{}{nil, nil}, - expected: []interface{}{nil, nil}, + expected: map[string]interface{}{"a": nil, "b": nil}, }, { src: []interface{}{&pgtype.Text{String: "hi", Status: pgtype.Present}, &pgtype.Int4{Int: 7, Status: pgtype.Present}}, - expected: []interface{}{"hi", int32(7)}, + expected: map[string]interface{}{"a": "hi", "b": int32(7)}, }, } From 91a46ce219f5ebbea3ffb6070ac8c5bffa0ff954 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Jun 2020 08:22:44 -0500 Subject: [PATCH 0510/1158] Clarify and normalize Value semantics Previously, Get implicitly allowed returning a reference to an internal value (e.g. a []byte) but AssignTo was documented as requiring a deep copy. This inconsistency meant that either Get was unsafe or the deep copy in AssignTo was superfluous. In addition, Scan into a []byte skips going through Bytea and returns a []byte of the unparsed bytes directly. i.e. a reference not a copy. Standardize on allowing Get and AssignTo to return internal references but require a Value never mutate internal values - only replace them. --- pgtype.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/pgtype.go b/pgtype.go index 091e98c4..5fe78704 100644 --- a/pgtype.go +++ b/pgtype.go @@ -115,16 +115,23 @@ const ( BinaryFormatCode = 1 ) +// Value translates values to and from an internal canonical representation for the type. To actually be usable a type +// that implements Value should also implement some combination of BinaryDecoder, BinaryEncoder, TextDecoder, +// and TextEncoder. +// +// Operations that update a Value (e.g. Set, DecodeText, DecodeBinary) should entirely replace the value. e.g. Internal +// slices should be replaced not resized and reused. This allows Get and AssignTo to return a slice directly rather +// than incur a usually unnecessary copy. type Value interface { - // Set converts and assigns src to itself. + // Set converts and assigns src to itself. Value takes ownership of src. Set(src interface{}) error - // Get returns the simplest representation of Value. If no simpler representation is - // possible, then Get() returns Value. + // Get returns the simplest representation of Value. Get may return a pointer to an internal value but it must never + // mutate that value. e.g. If Get returns a []byte Value must never change the contents of the []byte. Get() interface{} - // AssignTo converts and assigns the Value to dst. It MUST make a deep copy of - // any reference types. + // AssignTo converts and assigns the Value to dst. AssignTo may a pointer to an internal value but it must never + // mutate that value. e.g. If Get returns a []byte Value must never change the contents of the []byte. AssignTo(dst interface{}) error } From f6355165a91cd2c3476675a4dd15504a3af85611 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Jun 2020 09:10:11 -0500 Subject: [PATCH 0511/1158] Remove superfluous argument from ScanPlan --- go.mod | 2 +- go.sum | 3 +++ pgtype.go | 28 ++++++++++++++-------------- pgtype_test.go | 6 +++--- 4 files changed, 21 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index c7738ac9..ebf3a3c5 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.12 require ( github.com/gofrs/uuid v3.2.0+incompatible - github.com/jackc/pgconn v1.5.0 + github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853 github.com/jackc/pgio v1.0.0 github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9 github.com/lib/pq v1.3.0 diff --git a/go.sum b/go.sum index a4816869..049ed43b 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/jackc/pgconn v1.4.0 h1:E82UBzFyD752mvI+4RIl1WSxfO2ug64T+sLjvDBWTpA= github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= github.com/jackc/pgconn v1.5.0 h1:oFSOilzIZkyg787M1fEmyMfOUUvwj0daqYMfaWwNL4o= github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= +github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853 h1:LRlrfJW9S99uiOCY8F/qLvX1yEY1TVAaCBHFb79yHBQ= +github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= @@ -51,6 +53,7 @@ github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrU github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= github.com/jackc/pgtype v1.3.1-0.20200510045248-7e66ab1e146c/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= +github.com/jackc/pgtype v1.3.1-0.20200513130519-238967ec4e4c/go.mod h1:f3c+S645fwV5ZqwPvLWZmmnAfPkmaTeLnXs0byan+aA= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 h1:ylEAOd688Duev/fxTmGdupsbyZfxNMdngIG14DoBKTM= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912 h1:YuOWGsSK5L4Fz81Olx5TNlZftmDuNrfv4ip0Yos77Tw= diff --git a/pgtype.go b/pgtype.go index 5fe78704..0997df6e 100644 --- a/pgtype.go +++ b/pgtype.go @@ -502,7 +502,7 @@ func (scanPlanDstBinaryDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, return d.DecodeBinary(ci, src) } - newPlan := ci.PlanScan(oid, formatCode, src, dst) + newPlan := ci.PlanScan(oid, formatCode, dst) return newPlan.Scan(ci, oid, formatCode, src, dst) } @@ -513,7 +513,7 @@ func (plan scanPlanDstTextDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int return d.DecodeText(ci, src) } - newPlan := ci.PlanScan(oid, formatCode, src, dst) + newPlan := ci.PlanScan(oid, formatCode, dst) return newPlan.Scan(ci, oid, formatCode, src, dst) } @@ -522,7 +522,7 @@ type scanPlanDataTypeSQLScanner DataType func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { scanner, ok := dst.(sql.Scanner) if !ok { - newPlan := ci.PlanScan(oid, formatCode, src, dst) + newPlan := ci.PlanScan(oid, formatCode, dst) return newPlan.Scan(ci, oid, formatCode, src, dst) } @@ -566,7 +566,7 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode } // assignToErr might have failed because the type of destination has changed - newPlan := ci.PlanScan(oid, formatCode, src, dst) + newPlan := ci.PlanScan(oid, formatCode, dst) if newPlan, sameType := newPlan.(*scanPlanDataTypeAssignTo); !sameType { return newPlan.Scan(ci, oid, formatCode, src, dst) } @@ -604,7 +604,7 @@ func (scanPlanReflection) Scan(ci *ConnInfo, oid uint32, formatCode int16, src [ elemPtr := reflect.New(refVal.Type().Elem().Elem()) refVal.Elem().Set(elemPtr) - plan := ci.PlanScan(oid, formatCode, src, elemPtr.Interface()) + plan := ci.PlanScan(oid, formatCode, elemPtr.Interface()) return plan.Scan(ci, oid, formatCode, src, elemPtr.Interface()) } @@ -627,7 +627,7 @@ func (scanPlanBinaryInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src return nil } - newPlan := ci.PlanScan(oid, formatCode, src, dst) + newPlan := ci.PlanScan(oid, formatCode, dst) return newPlan.Scan(ci, oid, formatCode, src, dst) } @@ -647,7 +647,7 @@ func (scanPlanBinaryInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src return nil } - newPlan := ci.PlanScan(oid, formatCode, src, dst) + newPlan := ci.PlanScan(oid, formatCode, dst) return newPlan.Scan(ci, oid, formatCode, src, dst) } @@ -667,7 +667,7 @@ func (scanPlanBinaryInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src return nil } - newPlan := ci.PlanScan(oid, formatCode, src, dst) + newPlan := ci.PlanScan(oid, formatCode, dst) return newPlan.Scan(ci, oid, formatCode, src, dst) } @@ -688,7 +688,7 @@ func (scanPlanBinaryFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, sr return nil } - newPlan := ci.PlanScan(oid, formatCode, src, dst) + newPlan := ci.PlanScan(oid, formatCode, dst) return newPlan.Scan(ci, oid, formatCode, src, dst) } @@ -709,7 +709,7 @@ func (scanPlanBinaryFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, sr return nil } - newPlan := ci.PlanScan(oid, formatCode, src, dst) + newPlan := ci.PlanScan(oid, formatCode, dst) return newPlan.Scan(ci, oid, formatCode, src, dst) } @@ -721,7 +721,7 @@ func (scanPlanBinaryBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src return nil } - newPlan := ci.PlanScan(oid, formatCode, src, dst) + newPlan := ci.PlanScan(oid, formatCode, dst) return newPlan.Scan(ci, oid, formatCode, src, dst) } @@ -737,12 +737,12 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt return nil } - newPlan := ci.PlanScan(oid, formatCode, src, dst) + newPlan := ci.PlanScan(oid, formatCode, dst) return newPlan.Scan(ci, oid, formatCode, src, dst) } // PlanScan prepares a plan to scan a value into dst. -func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, buf []byte, dst interface{}) ScanPlan { +func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { switch formatCode { case BinaryFormatCode: switch dst.(type) { @@ -819,7 +819,7 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, src []byte, dst interface return nil } - plan := ci.PlanScan(oid, formatCode, src, dst) + plan := ci.PlanScan(oid, formatCode, dst) return plan.Scan(ci, oid, formatCode, src, dst) } diff --git a/pgtype_test.go b/pgtype_test.go index 6bdbe7c8..b3a23676 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -186,7 +186,7 @@ func TestScanPlanBinaryInt32ScanChangedType(t *testing.T) { src := []byte{0, 0, 0, 42} var v int32 - plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) require.NoError(t, err) require.EqualValues(t, 42, v) @@ -220,7 +220,7 @@ func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) { src := []byte{0, 0, 0, 42} var v pgtype.Int4 - plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) for i := 0; i < b.N; i++ { v = pgtype.Int4{} @@ -239,7 +239,7 @@ func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { src := []byte{0, 0, 0, 42} var v int32 - plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) for i := 0; i < b.N; i++ { v = 0 From 937aec9841d240b77e63510dee382db97f9814c0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Jun 2020 09:55:14 -0500 Subject: [PATCH 0512/1158] Fix tests with newest pgx --- go.mod | 2 +- go.sum | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index ebf3a3c5..c70404df 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/gofrs/uuid v3.2.0+incompatible github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853 github.com/jackc/pgio v1.0.0 - github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9 + github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904 github.com/lib/pq v1.3.0 github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc github.com/stretchr/testify v1.5.1 diff --git a/go.sum b/go.sum index 049ed43b..ec5dd367 100644 --- a/go.sum +++ b/go.sum @@ -54,6 +54,7 @@ github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4 github.com/jackc/pgtype v1.3.1-0.20200510045248-7e66ab1e146c/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= github.com/jackc/pgtype v1.3.1-0.20200513130519-238967ec4e4c/go.mod h1:f3c+S645fwV5ZqwPvLWZmmnAfPkmaTeLnXs0byan+aA= +github.com/jackc/pgtype v1.3.1-0.20200606141011-f6355165a91c/go.mod h1:cvk9Bgu/VzJ9/lxTO5R5sf80p0DiucVtN7ZxvaC4GmQ= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 h1:ylEAOd688Duev/fxTmGdupsbyZfxNMdngIG14DoBKTM= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912 h1:YuOWGsSK5L4Fz81Olx5TNlZftmDuNrfv4ip0Yos77Tw= @@ -64,6 +65,10 @@ github.com/jackc/pgx/v4 v4.5.0 h1:mN7Z3n0uqPe29+tA4yLWyZNceYKgRvUWNk8qW+D066E= github.com/jackc/pgx/v4 v4.5.0/go.mod h1:EpAKPLdnTorwmPUUsqrPxy5fphV18j9q3wrfRXgo+kA= github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9 h1:rche9LTjh3HEvkE6eb8ITYxRsgEKgBkODHrhdvDVX74= github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9/go.mod h1:t3/cdRQl6fOLDxqtlyhe9UWgfIi9R8+8v8GKV5TRA/o= +github.com/jackc/pgx/v4 v4.6.1-0.20200606144914-81140f6c27c9 h1:uLmaWN4t6P8AHANy8+XCNmOHp9ya68meFRPtvlnxNow= +github.com/jackc/pgx/v4 v4.6.1-0.20200606144914-81140f6c27c9/go.mod h1:ZDaNWkt9sW1JMiNn0kdYBaLelIhw7Pg4qd+Vk6tw7Hg= +github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904 h1:SdGWuGg+Cpxq6Z+ArXt0nafaKeTvtKGEoW+yvycspUU= +github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904/go.mod h1:ZDaNWkt9sW1JMiNn0kdYBaLelIhw7Pg4qd+Vk6tw7Hg= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= @@ -86,10 +91,12 @@ github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU= github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -150,6 +157,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= From 36944b232f3846ddeb8ad110df5aa266639f3469 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Jun 2020 10:26:34 -0500 Subject: [PATCH 0513/1158] Fix hstore with empty string values --- hstore.go | 3 ++- hstore_test.go | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/hstore.go b/hstore.go index 3fe50ae5..ec510df7 100644 --- a/hstore.go +++ b/hstore.go @@ -168,6 +168,7 @@ func (src Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { firstPair := true + inElemBuf := make([]byte, 0, 32) for k, v := range src.Map { if firstPair { firstPair = false @@ -178,7 +179,7 @@ func (src Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { buf = append(buf, quoteHstoreElementIfNeeded(k)...) buf = append(buf, "=>"...) - elemBuf, err := v.EncodeText(ci, nil) + elemBuf, err := v.EncodeText(ci, inElemBuf) if err != nil { return nil, err } diff --git a/hstore_test.go b/hstore_test.go index ba6c9373..dce8baf2 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -15,6 +15,7 @@ func TestHstoreTranscode(t *testing.T) { values := []interface{}{ &pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(""), "bar": text(""), "baz": text("123")}, Status: pgtype.Present}, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, From 6cd2127b96fdbc7cdddcec0f8cdfbe6a322cbf24 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Jun 2020 10:44:22 -0500 Subject: [PATCH 0514/1158] Update pgproto3 dependency --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 4dc095ca..9b6baf5b 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.1 + github.com/jackc/pgproto3/v2 v2.0.2 github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 diff --git a/go.sum b/go.sum index 23fb8b32..2063a801 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M= github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.0.2 h1:q1Hsy66zh4vuNsajBUF2PNqfAMMfxU5mk594lPE9vjY= +github.com/jackc/pgproto3/v2 v2.0.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= From 59a0074b0a32d05ee32ccc39c6c9ca013013a69d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Jun 2020 10:51:44 -0500 Subject: [PATCH 0515/1158] Release v1.6.0 --- CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c4c3b2d2..68b151d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ +# 1.6.0 (June 6, 2020) + +* Fix panic when closing conn during cancellable query +* Fix behavior of sslmode=require with sslrootcert present (Petr Jediný) +* Fix field descriptions available after command concluded (Tobias Salzmann) +* Support connect_timeout (georgysavva) +* Handle IPv6 in connection URLs (Lukas Vogel) +* Fix ValidateConnect with cancelable context +* Improve CopyFrom performance +* Add Config.Copy (georgysavva) + # 1.5.0 (March 30, 2020) * Update golang.org/x/crypto for security fix From 9b79c87d648217d4e3559268367da148137a3e25 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Jun 2020 10:59:27 -0500 Subject: [PATCH 0516/1158] Update changelog --- CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 560abff3..b7ee9abb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ +# Unreleased + +* Add JSON support to ext/gofrs-uuid +* Performance improvements in Scan path +* Improved ext/shopspring-numeric binary decoding performance +* Add composite type support (Maxim Ivanov and Jack Christensen) +* Add better generic enum type support +* Add generic array type support +* Clarify and normalize Value semantics +* Fix hstore with empty string values + # 1.3.0 (March 30, 2020) * Get implemented on T instead of *T From a6d42976c615e0e13915864b0f7aa846877dcda8 Mon Sep 17 00:00:00 2001 From: georgysavva Date: Mon, 8 Jun 2020 13:18:54 +0300 Subject: [PATCH 0517/1158] Make it possible to scan destination of *interface{} type. --- go.sum | 4 ---- pgtype.go | 5 +++++ pgtype_test.go | 16 ++++++++++++++++ 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/go.sum b/go.sum index ec5dd367..464f0091 100644 --- a/go.sum +++ b/go.sum @@ -51,9 +51,7 @@ github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01C github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= -github.com/jackc/pgtype v1.3.1-0.20200510045248-7e66ab1e146c/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= -github.com/jackc/pgtype v1.3.1-0.20200513130519-238967ec4e4c/go.mod h1:f3c+S645fwV5ZqwPvLWZmmnAfPkmaTeLnXs0byan+aA= github.com/jackc/pgtype v1.3.1-0.20200606141011-f6355165a91c/go.mod h1:cvk9Bgu/VzJ9/lxTO5R5sf80p0DiucVtN7ZxvaC4GmQ= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 h1:ylEAOd688Duev/fxTmGdupsbyZfxNMdngIG14DoBKTM= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= @@ -65,8 +63,6 @@ github.com/jackc/pgx/v4 v4.5.0 h1:mN7Z3n0uqPe29+tA4yLWyZNceYKgRvUWNk8qW+D066E= github.com/jackc/pgx/v4 v4.5.0/go.mod h1:EpAKPLdnTorwmPUUsqrPxy5fphV18j9q3wrfRXgo+kA= github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9 h1:rche9LTjh3HEvkE6eb8ITYxRsgEKgBkODHrhdvDVX74= github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9/go.mod h1:t3/cdRQl6fOLDxqtlyhe9UWgfIi9R8+8v8GKV5TRA/o= -github.com/jackc/pgx/v4 v4.6.1-0.20200606144914-81140f6c27c9 h1:uLmaWN4t6P8AHANy8+XCNmOHp9ya68meFRPtvlnxNow= -github.com/jackc/pgx/v4 v4.6.1-0.20200606144914-81140f6c27c9/go.mod h1:ZDaNWkt9sW1JMiNn0kdYBaLelIhw7Pg4qd+Vk6tw7Hg= github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904 h1:SdGWuGg+Cpxq6Z+ArXt0nafaKeTvtKGEoW+yvycspUU= github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904/go.mod h1:ZDaNWkt9sW1JMiNn0kdYBaLelIhw7Pg4qd+Vk6tw7Hg= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= diff --git a/pgtype.go b/pgtype.go index 0997df6e..621a8b95 100644 --- a/pgtype.go +++ b/pgtype.go @@ -565,6 +565,11 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode return nil } + if dstPtr, ok := dst.(*interface{}); ok { + *dstPtr = dt.Value.Get() + return nil + } + // assignToErr might have failed because the type of destination has changed newPlan := ci.PlanScan(oid, formatCode, dst) if newPlan, sameType := newPlan.(*scanPlanDataTypeAssignTo); !sameType { diff --git a/pgtype_test.go b/pgtype_test.go index b3a23676..0c2bec83 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -71,6 +71,22 @@ func TestConnInfoScanNilIsNoOp(t *testing.T) { assert.NoError(t, err) } +func TestConnInfoScanTextFormatInterfacePtr(t *testing.T) { + ci := pgtype.NewConnInfo() + var got interface{} + err := ci.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &got) + require.NoError(t, err) + assert.Equal(t, "foo", got) +} + +func TestConnInfoScanBinaryFormatInterfacePtr(t *testing.T) { + ci := pgtype.NewConnInfo() + var got interface{} + err := ci.Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &got) + require.NoError(t, err) + assert.Equal(t, "foo", got) +} + func TestConnInfoScanUnknownOIDToStringsAndBytes(t *testing.T) { unknownOID := uint32(999999) srcBuf := []byte("foo") From 3e586004db8ff5a400374a9cef72e5964876a17c Mon Sep 17 00:00:00 2001 From: Jacob Powers Date: Tue, 9 Jun 2020 18:08:38 -0700 Subject: [PATCH 0518/1158] add travis config --- .travis.yml | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..4389d5da --- /dev/null +++ b/.travis.yml @@ -0,0 +1,35 @@ +# source: https://github.com/jackc/pgx/blob/master/.travis.yml + +language: go + +go: + - 1.14.x + - 1.13.x + - tip + +# Derived from https://github.com/lib/pq/blob/master/.travis.yml +before_install: + - ./travis/before_install.bash + +env: + global: + - GO111MODULE=on + - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test + + matrix: + - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx database=pgx_test" + - PGVERSION=12 + - PGVERSION=11 + - PGVERSION=10 + - PGVERSION=9.6 + - PGVERSION=9.5 + +before_script: + - ./travis/before_script.bash + +script: + - ./travis/script.bash + +matrix: + allow_failures: + - go: tip \ No newline at end of file From 96f49eb89bab4d53213baa3d3b130781b487e154 Mon Sep 17 00:00:00 2001 From: Jacob Powers Date: Tue, 9 Jun 2020 18:16:23 -0700 Subject: [PATCH 0519/1158] copy travis configs over from pgx --- travis/before_install.bash | 41 ++++++++++++++++++++++++++++++++++++++ travis/before_script.bash | 10 ++++++++++ travis/script.bash | 11 ++++++++++ 3 files changed, 62 insertions(+) create mode 100755 travis/before_install.bash create mode 100755 travis/before_script.bash create mode 100755 travis/script.bash diff --git a/travis/before_install.bash b/travis/before_install.bash new file mode 100755 index 00000000..c95969f9 --- /dev/null +++ b/travis/before_install.bash @@ -0,0 +1,41 @@ +#!/usr/bin/env bash +# source: https://github.com/jackc/pgx/blob/master/travis/before_install.bash + +set -eux + +if [ "${PGVERSION-}" != "" ] +then + sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common + sudo rm -rf /var/lib/postgresql + wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - + sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" + sudo apt-get update -qq + sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION + sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf + if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then + echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf + echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf + echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf + fi + sudo /etc/init.d/postgresql restart +fi + +if [ "${CRATEVERSION-}" != "" ] +then + docker run \ + -p "6543:5432" \ + -d \ + crate:"$CRATEVERSION" \ + crate \ + -Cnetwork.host=0.0.0.0 \ + -Ctransport.host=localhost \ + -Clicense.enterprise=false +fi diff --git a/travis/before_script.bash b/travis/before_script.bash new file mode 100755 index 00000000..5c412631 --- /dev/null +++ b/travis/before_script.bash @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# source: https://github.com/jackc/pgx/blob/master/travis/before_script.bash +set -eux + +if [ "${PGVERSION-}" != "" ] +then + psql -U postgres -c 'create database pgx_test' + psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' + psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" +fi diff --git a/travis/script.bash b/travis/script.bash new file mode 100755 index 00000000..6ee46ac3 --- /dev/null +++ b/travis/script.bash @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +# source: https://github.com/jackc/pgx/blob/master/travis/script.bash +set -eux + +if [ "${PGVERSION-}" != "" ] +then + go test -v -race ./... +elif [ "${CRATEVERSION-}" != "" ] +then + go test -v -race -run 'TestCrateDBConnect' +fi From 6d62aec6b1e288ed2a646ddb004f645ea02eb4ac Mon Sep 17 00:00:00 2001 From: Jacob Powers Date: Tue, 9 Jun 2020 18:31:49 -0700 Subject: [PATCH 0520/1158] remove irrelevant test from pgx --- .travis.yml | 1 - travis/script.bash | 8 +------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/.travis.yml b/.travis.yml index 4389d5da..d6762735 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,7 +17,6 @@ env: - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test matrix: - - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx database=pgx_test" - PGVERSION=12 - PGVERSION=11 - PGVERSION=10 diff --git a/travis/script.bash b/travis/script.bash index 6ee46ac3..1dfa2c20 100755 --- a/travis/script.bash +++ b/travis/script.bash @@ -2,10 +2,4 @@ # source: https://github.com/jackc/pgx/blob/master/travis/script.bash set -eux -if [ "${PGVERSION-}" != "" ] -then - go test -v -race ./... -elif [ "${CRATEVERSION-}" != "" ] -then - go test -v -race -run 'TestCrateDBConnect' -fi +go test -v -race ./... From 97e4debcc0714a10444593cc51b99e4deb9f6a00 Mon Sep 17 00:00:00 2001 From: Jacob Powers Date: Wed, 10 Jun 2020 08:27:56 -0700 Subject: [PATCH 0521/1158] disable test cases that require a binary sql snapshot --- aclitem_array_test.go | 2 +- aclitem_test.go | 2 +- go.sum | 1 - testutil/setup.sql | 0 4 files changed, 2 insertions(+), 3 deletions(-) create mode 100644 testutil/setup.sql diff --git a/aclitem_array_test.go b/aclitem_array_test.go index dafd13b0..f1dbc663 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -28,7 +28,7 @@ func TestACLItemArrayTranscode(t *testing.T) { Elements: []pgtype.ACLItem{ {String: "=r/postgres", Status: pgtype.Present}, {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + //{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, {String: "=r/postgres", Status: pgtype.Present}, {Status: pgtype.Null}, {String: "=r/postgres", Status: pgtype.Present}, diff --git a/aclitem_test.go b/aclitem_test.go index 480c457c..a37d7657 100644 --- a/aclitem_test.go +++ b/aclitem_test.go @@ -11,7 +11,7 @@ import ( func TestACLItemTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ &pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - &pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + //&pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, &pgtype.ACLItem{Status: pgtype.Null}, }) } diff --git a/go.sum b/go.sum index a4816869..4ad8b902 100644 --- a/go.sum +++ b/go.sum @@ -49,7 +49,6 @@ github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01C github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= -github.com/jackc/pgtype v1.3.1-0.20200510045248-7e66ab1e146c/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 h1:ylEAOd688Duev/fxTmGdupsbyZfxNMdngIG14DoBKTM= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= diff --git a/testutil/setup.sql b/testutil/setup.sql new file mode 100644 index 00000000..e69de29b From 0b762c6e268d5deeda378e54fcd161082d290ef6 Mon Sep 17 00:00:00 2001 From: leighhopcroft Date: Wed, 10 Jun 2020 16:59:08 +0100 Subject: [PATCH 0522/1158] updated to use boolean IsNaN field on Numeric --- numeric.go | 34 +++++++++++++++++++++++----------- numeric_test.go | 18 ++++++++++++------ 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/numeric.go b/numeric.go index 7ee517be..074c2edc 100644 --- a/numeric.go +++ b/numeric.go @@ -52,6 +52,7 @@ type Numeric struct { Int *big.Int Exp int32 Status Status + IsNaN bool } func (dst *Numeric) Set(src interface{}) error { @@ -70,6 +71,7 @@ func (dst *Numeric) Set(src interface{}) error { switch value := src.(type) { case float32: if math.IsNaN(float64(value)) { + *dst = Numeric{Status: Present, IsNaN: true} return nil } num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) @@ -79,6 +81,7 @@ func (dst *Numeric) Set(src interface{}) error { *dst = Numeric{Int: num, Exp: exp, Status: Present} case float64: if math.IsNaN(value) { + *dst = Numeric{Status: Present, IsNaN: true} return nil } num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) @@ -272,13 +275,6 @@ func (src *Numeric) AssignTo(dst interface{}) error { } case Null: return NullAssignTo(dst) - case Undefined: - switch v := dst.(type) { - case *float32: - *v = float32(math.NaN()) - case *float64: - *v = math.NaN() - } } return nil @@ -309,6 +305,10 @@ func (dst *Numeric) toBigInt() (*big.Int, error) { } func (src *Numeric) toFloat64() (float64, error) { + if src.IsNaN { + return math.NaN(), nil + } + buf := make([]byte, 0, 32) buf = append(buf, src.Int.String()...) @@ -328,8 +328,8 @@ func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { return nil } - if string(src) == "NaN" { - *dst = Numeric{} + if string(src) == "'NaN'" { // includes single quotes, see EncodeText for details. + *dst = Numeric{Status: Present, IsNaN: true} return nil } @@ -384,7 +384,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { rp += 2 if sign == pgNumericNaNSign { - *dst = Numeric{} + *dst = Numeric{Status: Present, IsNaN: true} return nil } @@ -491,7 +491,15 @@ func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Null: return nil, nil case Undefined: - buf = append(buf, []byte("NaN")...) + return nil, errUndefined + } + + if src.IsNaN { + // encode as 'NaN' including single quotes, + // "When writing this value [NaN] as a constant in an SQL command, + // you must put quotes around it, for example UPDATE table SET x = 'NaN'" + // https://www.postgresql.org/docs/9.3/datatype-numeric.html + buf = append(buf, "'NaN'"...) return buf, nil } @@ -506,6 +514,10 @@ func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Null: return nil, nil case Undefined: + return nil, errUndefined + } + + if src.IsNaN { buf = pgio.AppendUint64(buf, pgNumericNaN) return buf, nil } diff --git a/numeric_test.go b/numeric_test.go index 259f397e..4d9c5252 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -210,8 +210,8 @@ func TestNumericSet(t *testing.T) { {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}}, {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}}, {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, - {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Undefined}}, - {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Undefined}}, + {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, IsNaN: true}}, + {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, IsNaN: true}}, } for i, tt := range successfulTests { @@ -269,8 +269,8 @@ func TestNumericAssignTo(t *testing.T) { {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Status: pgtype.Present}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27 - {src: &pgtype.Numeric{}, dst: &f64, expected: math.NaN()}, - {src: &pgtype.Numeric{}, dst: &f32, expected: float32(math.NaN())}, + {src: &pgtype.Numeric{Status: pgtype.Present, IsNaN: true}, dst: &f64, expected: math.NaN()}, + {src: &pgtype.Numeric{Status: pgtype.Present, IsNaN: true}, dst: &f32, expected: float32(math.NaN())}, } for i, tt := range simpleTests { @@ -282,11 +282,17 @@ func TestNumericAssignTo(t *testing.T) { dst := reflect.ValueOf(tt.dst).Elem().Interface() switch dstTyped := dst.(type) { case float32: - if math.IsNaN(float64(tt.expected.(float32))) && !math.IsNaN(float64(dstTyped)) { + nanExpected := math.IsNaN(float64(tt.expected.(float32))) + if nanExpected && !math.IsNaN(float64(dstTyped)) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } else if !nanExpected && dst != tt.expected { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } case float64: - if math.IsNaN(tt.expected.(float64)) && !math.IsNaN(dstTyped) { + nanExpected := math.IsNaN(tt.expected.(float64)) + if nanExpected && !math.IsNaN(dstTyped) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } else if !nanExpected && dst != tt.expected { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } default: From de77c70f48df17454d767fb058e9e5a2ab9c89a6 Mon Sep 17 00:00:00 2001 From: Jacob Powers Date: Wed, 10 Jun 2020 09:01:34 -0700 Subject: [PATCH 0523/1158] enable hstore extension before running tests --- testutil/setup.sql | 0 travis/before_script.bash | 1 + 2 files changed, 1 insertion(+) delete mode 100644 testutil/setup.sql diff --git a/testutil/setup.sql b/testutil/setup.sql deleted file mode 100644 index e69de29b..00000000 diff --git a/travis/before_script.bash b/travis/before_script.bash index 5c412631..13147ab0 100755 --- a/travis/before_script.bash +++ b/travis/before_script.bash @@ -7,4 +7,5 @@ then psql -U postgres -c 'create database pgx_test' psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" + psql -U postgres pgx_test -c 'create extension if not exists hstore;' fi From 25d18b98e523a9f481e1c2ac778963a4103f83b3 Mon Sep 17 00:00:00 2001 From: Jacob Powers Date: Wed, 10 Jun 2020 09:26:59 -0700 Subject: [PATCH 0524/1158] fix regression --- aclitem_array_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/aclitem_array_test.go b/aclitem_array_test.go index f1dbc663..fb1e93fc 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -29,6 +29,7 @@ func TestACLItemArrayTranscode(t *testing.T) { {String: "=r/postgres", Status: pgtype.Present}, {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, //{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + {String: `postgres=arwdDxt/postgres`, Status: pgtype.Present}, // todo: remove after fixing above case {String: "=r/postgres", Status: pgtype.Present}, {Status: pgtype.Null}, {String: "=r/postgres", Status: pgtype.Present}, From 6b254a445e49cc9f23f25a1e2eca7cd98fd65850 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 11 Jun 2020 20:51:40 -0500 Subject: [PATCH 0525/1158] Fix doc for ParseConfig --- config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.go b/config.go index 7ed99096..279ae400 100644 --- a/config.go +++ b/config.go @@ -112,7 +112,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { return network, address } -// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same +// ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same // defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN. // It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the // .pgpass file. From a1b9eb4d4e06feaa3587b1633165b7a52c80b4e7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 11 Jun 2020 20:55:41 -0500 Subject: [PATCH 0526/1158] Fix parseServiceSettings not returning error --- config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 279ae400..f15b33b4 100644 --- a/config.go +++ b/config.go @@ -536,12 +536,12 @@ func parseDSNSettings(s string) (map[string]string, error) { func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { servicefile, err := pgservicefile.ReadServicefile(servicefilePath) if err != nil { - fmt.Errorf("failed to read service file: %v", servicefile) + return nil, fmt.Errorf("failed to read service file: %v", servicefile) } service, err := servicefile.GetService(serviceName) if err != nil { - fmt.Errorf("unable to find service: %v", servicefile) + return nil, fmt.Errorf("unable to find service: %v", servicefile) } nameMap := map[string]string{ From 7bcd9fbdaff6aeb6efd1c5ee488eb36c5acb1953 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 11 Jun 2020 21:35:32 -0500 Subject: [PATCH 0527/1158] Rename IsNaN to NaN --- numeric.go | 16 ++++++++-------- numeric_test.go | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/numeric.go b/numeric.go index 074c2edc..37a81edf 100644 --- a/numeric.go +++ b/numeric.go @@ -52,7 +52,7 @@ type Numeric struct { Int *big.Int Exp int32 Status Status - IsNaN bool + NaN bool } func (dst *Numeric) Set(src interface{}) error { @@ -71,7 +71,7 @@ func (dst *Numeric) Set(src interface{}) error { switch value := src.(type) { case float32: if math.IsNaN(float64(value)) { - *dst = Numeric{Status: Present, IsNaN: true} + *dst = Numeric{Status: Present, NaN: true} return nil } num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) @@ -81,7 +81,7 @@ func (dst *Numeric) Set(src interface{}) error { *dst = Numeric{Int: num, Exp: exp, Status: Present} case float64: if math.IsNaN(value) { - *dst = Numeric{Status: Present, IsNaN: true} + *dst = Numeric{Status: Present, NaN: true} return nil } num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) @@ -305,7 +305,7 @@ func (dst *Numeric) toBigInt() (*big.Int, error) { } func (src *Numeric) toFloat64() (float64, error) { - if src.IsNaN { + if src.NaN { return math.NaN(), nil } @@ -329,7 +329,7 @@ func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { } if string(src) == "'NaN'" { // includes single quotes, see EncodeText for details. - *dst = Numeric{Status: Present, IsNaN: true} + *dst = Numeric{Status: Present, NaN: true} return nil } @@ -384,7 +384,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { rp += 2 if sign == pgNumericNaNSign { - *dst = Numeric{Status: Present, IsNaN: true} + *dst = Numeric{Status: Present, NaN: true} return nil } @@ -494,7 +494,7 @@ func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } - if src.IsNaN { + if src.NaN { // encode as 'NaN' including single quotes, // "When writing this value [NaN] as a constant in an SQL command, // you must put quotes around it, for example UPDATE table SET x = 'NaN'" @@ -517,7 +517,7 @@ func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } - if src.IsNaN { + if src.NaN { buf = pgio.AppendUint64(buf, pgNumericNaN) return buf, nil } diff --git a/numeric_test.go b/numeric_test.go index 4d9c5252..675eddc4 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -210,8 +210,8 @@ func TestNumericSet(t *testing.T) { {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}}, {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}}, {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, - {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, IsNaN: true}}, - {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, IsNaN: true}}, + {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, NaN: true}}, + {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, NaN: true}}, } for i, tt := range successfulTests { @@ -269,8 +269,8 @@ func TestNumericAssignTo(t *testing.T) { {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Status: pgtype.Present}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27 - {src: &pgtype.Numeric{Status: pgtype.Present, IsNaN: true}, dst: &f64, expected: math.NaN()}, - {src: &pgtype.Numeric{Status: pgtype.Present, IsNaN: true}, dst: &f32, expected: float32(math.NaN())}, + {src: &pgtype.Numeric{Status: pgtype.Present, NaN: true}, dst: &f64, expected: math.NaN()}, + {src: &pgtype.Numeric{Status: pgtype.Present, NaN: true}, dst: &f32, expected: float32(math.NaN())}, } for i, tt := range simpleTests { From 09efc38390474f0d6b26977605a850b7600b0ad3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 11 Jun 2020 21:36:50 -0500 Subject: [PATCH 0528/1158] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b7ee9abb..d8b891c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * Add generic array type support * Clarify and normalize Value semantics * Fix hstore with empty string values +* Numeric supports NaN values (leighhopcroft) # 1.3.0 (March 30, 2020) From f27e874d554167e7bf0bead3d5b5bf8923abba05 Mon Sep 17 00:00:00 2001 From: Lukas Vogel Date: Fri, 12 Jun 2020 13:01:57 +0200 Subject: [PATCH 0529/1158] redact passwords in parse config errors Redact passwords when printing the parseConfigError in a best effort manner. This prevents people from leaking the password into logs, if they just print the error in logs. --- errors.go | 30 ++++++++++++++++++++++++++++-- errors_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ export_test.go | 11 +++++++++++ 3 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 errors_test.go create mode 100644 export_test.go diff --git a/errors.go b/errors.go index 7a21af98..b746c825 100644 --- a/errors.go +++ b/errors.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "net" + "net/url" + "regexp" "strings" errors "golang.org/x/xerrors" @@ -98,10 +100,11 @@ type parseConfigError struct { } func (e *parseConfigError) Error() string { + connString := redactPW(e.connString) if e.err == nil { - return fmt.Sprintf("cannot parse `%s`: %s", e.connString, e.msg) + return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg) } - return fmt.Sprintf("cannot parse `%s`: %s (%s)", e.connString, e.msg, e.err.Error()) + return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error()) } func (e *parseConfigError) Unwrap() error { @@ -164,3 +167,26 @@ func (e *writeError) SafeToRetry() bool { func (e *writeError) Unwrap() error { return e.err } + +func redactPW(connString string) string { + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { + if u, err := url.Parse(connString); err == nil { + return redactURL(u) + } + } + quotedDSN := regexp.MustCompile(`password='[^']*'`) + connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + plainDSN := regexp.MustCompile(`password=[^ ]*`) + connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + return connString +} + +func redactURL(u *url.URL) string { + if u == nil { + return "" + } + if _, pwSet := u.User.Password(); pwSet { + u.User = url.UserPassword(u.User.Username(), "xxxxx") + } + return u.String() +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 00000000..bef835f8 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,44 @@ +package pgconn_test + +import ( + "testing" + + "github.com/jackc/pgconn" + "github.com/stretchr/testify/assert" +) + +func TestConfigError(t *testing.T) { + tests := []struct { + name string + err error + expectedMsg string + }{ + { + name: "url with password", + err: pgconn.NewParseConfigError("postgresql://foo:password@host", "msg", nil), + expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg", + }, + { + name: "dsn with password unquoted", + err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil), + expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", + }, + { + name: "dsn with password quoted", + err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil), + expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", + }, + { + name: "weird url", + err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil), + expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg", + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.EqualError(t, tt.err, tt.expectedMsg) + }) + } +} diff --git a/export_test.go b/export_test.go new file mode 100644 index 00000000..2a0bad8b --- /dev/null +++ b/export_test.go @@ -0,0 +1,11 @@ +// File export_test exports some methods for better testing. + +package pgconn + +func NewParseConfigError(conn, msg string, err error) error { + return &parseConfigError{ + connString: conn, + msg: msg, + err: err, + } +} From 3105c6e7065650f46fee8585b6690687e5947d48 Mon Sep 17 00:00:00 2001 From: megaturbo Date: Wed, 17 Jun 2020 15:13:59 +0200 Subject: [PATCH 0530/1158] Add support for nullable types in Value.Get implementations --- bool.go | 12 ++++++++ date.go | 12 ++++++++ float4.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++ float8.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++ inet.go | 20 +++++++++++-- int2.go | 66 ++++++++++++++++++++++++++++++++++++++++++ int4.go | 66 ++++++++++++++++++++++++++++++++++++++++++ int8.go | 66 ++++++++++++++++++++++++++++++++++++++++++ macaddr.go | 12 ++++++++ numeric.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++ timestamp.go | 6 ++++ timestamptz.go | 6 ++++ uuid.go | 6 ++++ 13 files changed, 504 insertions(+), 2 deletions(-) diff --git a/bool.go b/bool.go index 8b03a1af..9ec5097f 100644 --- a/bool.go +++ b/bool.go @@ -35,6 +35,18 @@ func (dst *Bool) Set(src interface{}) error { return err } *dst = Bool{Bool: bb, Status: Present} + case *bool: + if value == nil { + *dst = Bool{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Bool{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingBoolType(src); ok { return dst.Set(originalSrc) diff --git a/date.go b/date.go index 37fb8302..59e225df 100644 --- a/date.go +++ b/date.go @@ -39,6 +39,18 @@ func (dst *Date) Set(src interface{}) error { *dst = Date{Time: value, Status: Present} case string: return dst.DecodeText(nil, []byte(value)) + case *time.Time: + if value == nil { + *dst = Date{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Date{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) diff --git a/float4.go b/float4.go index e33dfc75..5faad54d 100644 --- a/float4.go +++ b/float4.go @@ -89,6 +89,84 @@ func (dst *Float4) Set(src interface{}) error { return err } *dst = Float4{Float: float32(num), Status: Present} + case *float64: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *int8: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) diff --git a/float8.go b/float8.go index 41d0fe70..d7412301 100644 --- a/float8.go +++ b/float8.go @@ -79,6 +79,84 @@ func (dst *Float8) Set(src interface{}) error { return err } *dst = Float8{Float: float64(num), Status: Present} + case *float64: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *int8: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) diff --git a/inet.go b/inet.go index 7ab78bdf..f3dce87b 100644 --- a/inet.go +++ b/inet.go @@ -37,8 +37,6 @@ func (dst *Inet) Set(src interface{}) error { switch value := src.(type) { case net.IPNet: *dst = Inet{IPNet: &value, Status: Present} - case *net.IPNet: - *dst = Inet{IPNet: value, Status: Present} case net.IP: bitCount := len(value) * 8 mask := net.CIDRMask(bitCount, bitCount) @@ -49,6 +47,24 @@ func (dst *Inet) Set(src interface{}) error { return err } *dst = Inet{IPNet: ipnet, Status: Present} + case *net.IPNet: + if value == nil { + *dst = Inet{Status: Null} + } else { + return dst.Set(*value) + } + case *net.IP: + if value == nil { + *dst = Inet{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Inet{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) diff --git a/int2.go b/int2.go index 54bab272..67fa1acc 100644 --- a/int2.go +++ b/int2.go @@ -85,6 +85,72 @@ func (dst *Int2) Set(src interface{}) error { return err } *dst = Int2{Int: int16(num), Status: Present} + case *int8: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) diff --git a/int4.go b/int4.go index 66fe9155..c4ed6103 100644 --- a/int4.go +++ b/int4.go @@ -77,6 +77,72 @@ func (dst *Int4) Set(src interface{}) error { return err } *dst = Int4{Int: int32(num), Status: Present} + case *int8: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) diff --git a/int8.go b/int8.go index fd721142..445fef0d 100644 --- a/int8.go +++ b/int8.go @@ -68,6 +68,72 @@ func (dst *Int8) Set(src interface{}) error { return err } *dst = Int8{Int: num, Status: Present} + case *int8: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) diff --git a/macaddr.go b/macaddr.go index af0901b0..6cc14114 100644 --- a/macaddr.go +++ b/macaddr.go @@ -36,6 +36,18 @@ func (dst *Macaddr) Set(src interface{}) error { return err } *dst = Macaddr{Addr: addr, Status: Present} + case *net.HardwareAddr: + if value == nil { + *dst = Macaddr{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Macaddr{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) diff --git a/numeric.go b/numeric.go index 37a81edf..f2b04006 100644 --- a/numeric.go +++ b/numeric.go @@ -115,6 +115,84 @@ func (dst *Numeric) Set(src interface{}) error { return err } *dst = Numeric{Int: num, Exp: exp, Status: Present} + case *float64: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *int8: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) diff --git a/timestamp.go b/timestamp.go index de059f7e..88cb7672 100644 --- a/timestamp.go +++ b/timestamp.go @@ -40,6 +40,12 @@ func (dst *Timestamp) Set(src interface{}) error { switch value := src.(type) { case time.Time: *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} + case *time.Time: + if value == nil { + *dst = Timestamp{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) diff --git a/timestamptz.go b/timestamptz.go index 100f44a5..25ea659d 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -42,6 +42,12 @@ func (dst *Timestamptz) Set(src interface{}) error { switch value := src.(type) { case time.Time: *dst = Timestamptz{Time: value, Status: Present} + case *time.Time: + if value == nil { + *dst = Timestamptz{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) diff --git a/uuid.go b/uuid.go index bdbe17e4..634f6463 100644 --- a/uuid.go +++ b/uuid.go @@ -45,6 +45,12 @@ func (dst *UUID) Set(src interface{}) error { return err } *dst = UUID{Bytes: uuid, Status: Present} + case *string: + if value == nil { + *dst = UUID{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingUUIDType(src); ok { return dst.Set(originalSrc) From 066bc77610f0e459a3c097be13de3730e1c687a7 Mon Sep 17 00:00:00 2001 From: megaturbo Date: Wed, 17 Jun 2020 15:17:17 +0200 Subject: [PATCH 0531/1158] Add support for slice of nullable types in array types --- aclitem_array.go | 28 ++++++ bool_array.go | 28 ++++++ bpchar_array.go | 28 ++++++ cidr_array.go | 28 ++++++ date_array.go | 28 ++++++ enum_array.go | 28 ++++++ float4_array.go | 28 ++++++ float8_array.go | 28 ++++++ inet_array.go | 28 ++++++ int2_array.go | 224 +++++++++++++++++++++++++++++++++++++++++++ int4_array.go | 224 +++++++++++++++++++++++++++++++++++++++++++ int8_array.go | 224 +++++++++++++++++++++++++++++++++++++++++++ macaddr_array.go | 28 ++++++ numeric_array.go | 112 ++++++++++++++++++++++ text_array.go | 28 ++++++ timestamp_array.go | 28 ++++++ timestamptz_array.go | 28 ++++++ typed_array_gen.sh | 38 ++++---- uuid_array.go | 28 ++++++ varchar_array.go | 28 ++++++ 20 files changed, 1223 insertions(+), 19 deletions(-) diff --git a/aclitem_array.go b/aclitem_array.go index 1d3de130..064436fd 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -47,6 +47,25 @@ func (dst *ACLItemArray) Set(src interface{}) error { } } + case []*string: + if value == nil { + *dst = ACLItemArray{Status: Null} + } else if len(value) == 0 { + *dst = ACLItemArray{Status: Present} + } else { + elements := make([]ACLItem, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = ACLItemArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []ACLItem: if value == nil { *dst = ACLItemArray{Status: Null} @@ -94,6 +113,15 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { } return nil + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/bool_array.go b/bool_array.go index c1af1e1f..d5f89629 100644 --- a/bool_array.go +++ b/bool_array.go @@ -49,6 +49,25 @@ func (dst *BoolArray) Set(src interface{}) error { } } + case []*bool: + if value == nil { + *dst = BoolArray{Status: Null} + } else if len(value) == 0 { + *dst = BoolArray{Status: Present} + } else { + elements := make([]Bool, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BoolArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Bool: if value == nil { *dst = BoolArray{Status: Null} @@ -96,6 +115,15 @@ func (src *BoolArray) AssignTo(dst interface{}) error { } return nil + case *[]*bool: + *v = make([]*bool, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/bpchar_array.go b/bpchar_array.go index b6eeabd7..10d0d0f7 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -49,6 +49,25 @@ func (dst *BPCharArray) Set(src interface{}) error { } } + case []*string: + if value == nil { + *dst = BPCharArray{Status: Null} + } else if len(value) == 0 { + *dst = BPCharArray{Status: Present} + } else { + elements := make([]BPChar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BPCharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []BPChar: if value == nil { *dst = BPCharArray{Status: Null} @@ -96,6 +115,15 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { } return nil + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/cidr_array.go b/cidr_array.go index 4f3097a0..5231e208 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -69,6 +69,25 @@ func (dst *CIDRArray) Set(src interface{}) error { } } + case []*net.IP: + if value == nil { + *dst = CIDRArray{Status: Null} + } else if len(value) == 0 { + *dst = CIDRArray{Status: Present} + } else { + elements := make([]CIDR, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CIDRArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []CIDR: if value == nil { *dst = CIDRArray{Status: Null} @@ -125,6 +144,15 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { } return nil + case *[]*net.IP: + *v = make([]*net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/date_array.go b/date_array.go index 644e78fe..51d00da1 100644 --- a/date_array.go +++ b/date_array.go @@ -50,6 +50,25 @@ func (dst *DateArray) Set(src interface{}) error { } } + case []*time.Time: + if value == nil { + *dst = DateArray{Status: Null} + } else if len(value) == 0 { + *dst = DateArray{Status: Present} + } else { + elements := make([]Date, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = DateArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Date: if value == nil { *dst = DateArray{Status: Null} @@ -97,6 +116,15 @@ func (src *DateArray) AssignTo(dst interface{}) error { } return nil + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/enum_array.go b/enum_array.go index a31916dc..528cdb03 100644 --- a/enum_array.go +++ b/enum_array.go @@ -47,6 +47,25 @@ func (dst *EnumArray) Set(src interface{}) error { } } + case []*string: + if value == nil { + *dst = EnumArray{Status: Null} + } else if len(value) == 0 { + *dst = EnumArray{Status: Present} + } else { + elements := make([]GenericText, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = EnumArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []GenericText: if value == nil { *dst = EnumArray{Status: Null} @@ -94,6 +113,15 @@ func (src *EnumArray) AssignTo(dst interface{}) error { } return nil + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/float4_array.go b/float4_array.go index ccd718a1..bc9d4746 100644 --- a/float4_array.go +++ b/float4_array.go @@ -49,6 +49,25 @@ func (dst *Float4Array) Set(src interface{}) error { } } + case []*float32: + if value == nil { + *dst = Float4Array{Status: Null} + } else if len(value) == 0 { + *dst = Float4Array{Status: Present} + } else { + elements := make([]Float4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Float4: if value == nil { *dst = Float4Array{Status: Null} @@ -96,6 +115,15 @@ func (src *Float4Array) AssignTo(dst interface{}) error { } return nil + case *[]*float32: + *v = make([]*float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/float8_array.go b/float8_array.go index 740e8558..acc94b3f 100644 --- a/float8_array.go +++ b/float8_array.go @@ -49,6 +49,25 @@ func (dst *Float8Array) Set(src interface{}) error { } } + case []*float64: + if value == nil { + *dst = Float8Array{Status: Null} + } else if len(value) == 0 { + *dst = Float8Array{Status: Present} + } else { + elements := make([]Float8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Float8: if value == nil { *dst = Float8Array{Status: Null} @@ -96,6 +115,15 @@ func (src *Float8Array) AssignTo(dst interface{}) error { } return nil + case *[]*float64: + *v = make([]*float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/inet_array.go b/inet_array.go index a663d51d..6d9f11fb 100644 --- a/inet_array.go +++ b/inet_array.go @@ -69,6 +69,25 @@ func (dst *InetArray) Set(src interface{}) error { } } + case []*net.IP: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Inet: if value == nil { *dst = InetArray{Status: Null} @@ -125,6 +144,15 @@ func (src *InetArray) AssignTo(dst interface{}) error { } return nil + case *[]*net.IP: + *v = make([]*net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/int2_array.go b/int2_array.go index 98552171..35f73fee 100644 --- a/int2_array.go +++ b/int2_array.go @@ -49,6 +49,25 @@ func (dst *Int2Array) Set(src interface{}) error { } } + case []*int16: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint16: if value == nil { *dst = Int2Array{Status: Null} @@ -68,6 +87,25 @@ func (dst *Int2Array) Set(src interface{}) error { } } + case []*uint16: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int32: if value == nil { *dst = Int2Array{Status: Null} @@ -87,6 +125,25 @@ func (dst *Int2Array) Set(src interface{}) error { } } + case []*int32: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint32: if value == nil { *dst = Int2Array{Status: Null} @@ -106,6 +163,25 @@ func (dst *Int2Array) Set(src interface{}) error { } } + case []*uint32: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int64: if value == nil { *dst = Int2Array{Status: Null} @@ -125,6 +201,25 @@ func (dst *Int2Array) Set(src interface{}) error { } } + case []*int64: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint64: if value == nil { *dst = Int2Array{Status: Null} @@ -144,6 +239,25 @@ func (dst *Int2Array) Set(src interface{}) error { } } + case []*uint64: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int: if value == nil { *dst = Int2Array{Status: Null} @@ -163,6 +277,25 @@ func (dst *Int2Array) Set(src interface{}) error { } } + case []*int: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint: if value == nil { *dst = Int2Array{Status: Null} @@ -182,6 +315,25 @@ func (dst *Int2Array) Set(src interface{}) error { } } + case []*uint: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Int2: if value == nil { *dst = Int2Array{Status: Null} @@ -229,6 +381,15 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } return nil + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]uint16: *v = make([]uint16, len(src.Elements)) for i := range src.Elements { @@ -238,6 +399,15 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } return nil + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]int32: *v = make([]int32, len(src.Elements)) for i := range src.Elements { @@ -247,6 +417,15 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } return nil + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]uint32: *v = make([]uint32, len(src.Elements)) for i := range src.Elements { @@ -256,6 +435,15 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } return nil + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]int64: *v = make([]int64, len(src.Elements)) for i := range src.Elements { @@ -265,6 +453,15 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } return nil + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]uint64: *v = make([]uint64, len(src.Elements)) for i := range src.Elements { @@ -274,6 +471,15 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } return nil + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]int: *v = make([]int, len(src.Elements)) for i := range src.Elements { @@ -283,6 +489,15 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } return nil + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]uint: *v = make([]uint, len(src.Elements)) for i := range src.Elements { @@ -292,6 +507,15 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } return nil + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/int4_array.go b/int4_array.go index a52ab437..2ff32ee1 100644 --- a/int4_array.go +++ b/int4_array.go @@ -49,6 +49,25 @@ func (dst *Int4Array) Set(src interface{}) error { } } + case []*int16: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint16: if value == nil { *dst = Int4Array{Status: Null} @@ -68,6 +87,25 @@ func (dst *Int4Array) Set(src interface{}) error { } } + case []*uint16: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int32: if value == nil { *dst = Int4Array{Status: Null} @@ -87,6 +125,25 @@ func (dst *Int4Array) Set(src interface{}) error { } } + case []*int32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint32: if value == nil { *dst = Int4Array{Status: Null} @@ -106,6 +163,25 @@ func (dst *Int4Array) Set(src interface{}) error { } } + case []*uint32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int64: if value == nil { *dst = Int4Array{Status: Null} @@ -125,6 +201,25 @@ func (dst *Int4Array) Set(src interface{}) error { } } + case []*int64: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint64: if value == nil { *dst = Int4Array{Status: Null} @@ -144,6 +239,25 @@ func (dst *Int4Array) Set(src interface{}) error { } } + case []*uint64: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int: if value == nil { *dst = Int4Array{Status: Null} @@ -163,6 +277,25 @@ func (dst *Int4Array) Set(src interface{}) error { } } + case []*int: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint: if value == nil { *dst = Int4Array{Status: Null} @@ -182,6 +315,25 @@ func (dst *Int4Array) Set(src interface{}) error { } } + case []*uint: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Int4: if value == nil { *dst = Int4Array{Status: Null} @@ -229,6 +381,15 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } return nil + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]uint16: *v = make([]uint16, len(src.Elements)) for i := range src.Elements { @@ -238,6 +399,15 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } return nil + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]int32: *v = make([]int32, len(src.Elements)) for i := range src.Elements { @@ -247,6 +417,15 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } return nil + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]uint32: *v = make([]uint32, len(src.Elements)) for i := range src.Elements { @@ -256,6 +435,15 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } return nil + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]int64: *v = make([]int64, len(src.Elements)) for i := range src.Elements { @@ -265,6 +453,15 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } return nil + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]uint64: *v = make([]uint64, len(src.Elements)) for i := range src.Elements { @@ -274,6 +471,15 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } return nil + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]int: *v = make([]int, len(src.Elements)) for i := range src.Elements { @@ -283,6 +489,15 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } return nil + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]uint: *v = make([]uint, len(src.Elements)) for i := range src.Elements { @@ -292,6 +507,15 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } return nil + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/int8_array.go b/int8_array.go index f6d577f0..17968338 100644 --- a/int8_array.go +++ b/int8_array.go @@ -49,6 +49,25 @@ func (dst *Int8Array) Set(src interface{}) error { } } + case []*int16: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint16: if value == nil { *dst = Int8Array{Status: Null} @@ -68,6 +87,25 @@ func (dst *Int8Array) Set(src interface{}) error { } } + case []*uint16: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int32: if value == nil { *dst = Int8Array{Status: Null} @@ -87,6 +125,25 @@ func (dst *Int8Array) Set(src interface{}) error { } } + case []*int32: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint32: if value == nil { *dst = Int8Array{Status: Null} @@ -106,6 +163,25 @@ func (dst *Int8Array) Set(src interface{}) error { } } + case []*uint32: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int64: if value == nil { *dst = Int8Array{Status: Null} @@ -125,6 +201,25 @@ func (dst *Int8Array) Set(src interface{}) error { } } + case []*int64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint64: if value == nil { *dst = Int8Array{Status: Null} @@ -144,6 +239,25 @@ func (dst *Int8Array) Set(src interface{}) error { } } + case []*uint64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int: if value == nil { *dst = Int8Array{Status: Null} @@ -163,6 +277,25 @@ func (dst *Int8Array) Set(src interface{}) error { } } + case []*int: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint: if value == nil { *dst = Int8Array{Status: Null} @@ -182,6 +315,25 @@ func (dst *Int8Array) Set(src interface{}) error { } } + case []*uint: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Int8: if value == nil { *dst = Int8Array{Status: Null} @@ -229,6 +381,15 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } return nil + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]uint16: *v = make([]uint16, len(src.Elements)) for i := range src.Elements { @@ -238,6 +399,15 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } return nil + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]int32: *v = make([]int32, len(src.Elements)) for i := range src.Elements { @@ -247,6 +417,15 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } return nil + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]uint32: *v = make([]uint32, len(src.Elements)) for i := range src.Elements { @@ -256,6 +435,15 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } return nil + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]int64: *v = make([]int64, len(src.Elements)) for i := range src.Elements { @@ -265,6 +453,15 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } return nil + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]uint64: *v = make([]uint64, len(src.Elements)) for i := range src.Elements { @@ -274,6 +471,15 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } return nil + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]int: *v = make([]int, len(src.Elements)) for i := range src.Elements { @@ -283,6 +489,15 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } return nil + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]uint: *v = make([]uint, len(src.Elements)) for i := range src.Elements { @@ -292,6 +507,15 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } return nil + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/macaddr_array.go b/macaddr_array.go index 97b13537..72a4e8d4 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -50,6 +50,25 @@ func (dst *MacaddrArray) Set(src interface{}) error { } } + case []*net.HardwareAddr: + if value == nil { + *dst = MacaddrArray{Status: Null} + } else if len(value) == 0 { + *dst = MacaddrArray{Status: Present} + } else { + elements := make([]Macaddr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = MacaddrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Macaddr: if value == nil { *dst = MacaddrArray{Status: Null} @@ -97,6 +116,15 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { } return nil + case *[]*net.HardwareAddr: + *v = make([]*net.HardwareAddr, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/numeric_array.go b/numeric_array.go index 3cec9fea..e808669c 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -49,6 +49,25 @@ func (dst *NumericArray) Set(src interface{}) error { } } + case []*float32: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []float64: if value == nil { *dst = NumericArray{Status: Null} @@ -68,6 +87,25 @@ func (dst *NumericArray) Set(src interface{}) error { } } + case []*float64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int64: if value == nil { *dst = NumericArray{Status: Null} @@ -87,6 +125,25 @@ func (dst *NumericArray) Set(src interface{}) error { } } + case []*int64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint64: if value == nil { *dst = NumericArray{Status: Null} @@ -106,6 +163,25 @@ func (dst *NumericArray) Set(src interface{}) error { } } + case []*uint64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Numeric: if value == nil { *dst = NumericArray{Status: Null} @@ -153,6 +229,15 @@ func (src *NumericArray) AssignTo(dst interface{}) error { } return nil + case *[]*float32: + *v = make([]*float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]float64: *v = make([]float64, len(src.Elements)) for i := range src.Elements { @@ -162,6 +247,15 @@ func (src *NumericArray) AssignTo(dst interface{}) error { } return nil + case *[]*float64: + *v = make([]*float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]int64: *v = make([]int64, len(src.Elements)) for i := range src.Elements { @@ -171,6 +265,15 @@ func (src *NumericArray) AssignTo(dst interface{}) error { } return nil + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + case *[]uint64: *v = make([]uint64, len(src.Elements)) for i := range src.Elements { @@ -180,6 +283,15 @@ func (src *NumericArray) AssignTo(dst interface{}) error { } return nil + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/text_array.go b/text_array.go index 2130af84..969054f8 100644 --- a/text_array.go +++ b/text_array.go @@ -49,6 +49,25 @@ func (dst *TextArray) Set(src interface{}) error { } } + case []*string: + if value == nil { + *dst = TextArray{Status: Null} + } else if len(value) == 0 { + *dst = TextArray{Status: Present} + } else { + elements := make([]Text, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TextArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Text: if value == nil { *dst = TextArray{Status: Null} @@ -96,6 +115,15 @@ func (src *TextArray) AssignTo(dst interface{}) error { } return nil + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/timestamp_array.go b/timestamp_array.go index 49ac98fd..81fd85f8 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -50,6 +50,25 @@ func (dst *TimestampArray) Set(src interface{}) error { } } + case []*time.Time: + if value == nil { + *dst = TimestampArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestampArray{Status: Present} + } else { + elements := make([]Timestamp, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestampArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Timestamp: if value == nil { *dst = TimestampArray{Status: Null} @@ -97,6 +116,15 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { } return nil + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/timestamptz_array.go b/timestamptz_array.go index 2e26692b..48725e29 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -50,6 +50,25 @@ func (dst *TimestamptzArray) Set(src interface{}) error { } } + case []*time.Time: + if value == nil { + *dst = TimestamptzArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestamptzArray{Status: Present} + } else { + elements := make([]Timestamptz, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestamptzArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Timestamptz: if value == nil { *dst = TimestamptzArray{Status: Null} @@ -97,6 +116,15 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { } return nil + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 6fd49264..b96dc381 100755 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,26 +1,26 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16,[]int32,[]uint32,[]int64,[]uint64,[]int,[]uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]uint16,[]int32,[]uint32,[]int64,[]uint64,[]int,[]uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]uint16,[]int32,[]uint32,[]int64,[]uint64,[]int,[]uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time,[]*time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go -erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go -erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go -erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time,[]*time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32,[]*float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64,[]*float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go +erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr,[]*net.HardwareAddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go +erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string,[]*string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string,[]*string element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go +erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string,[]*string element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go -erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string,[]*string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go -erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64,[]int64,[]uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go -erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go +erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go +erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go # While the binary format is theoretically possible it is only practical to use the text format. -erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go +erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go goimports -w *_array.go diff --git a/uuid_array.go b/uuid_array.go index 4cd65017..0c02977f 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -87,6 +87,25 @@ func (dst *UUIDArray) Set(src interface{}) error { } } + case []*string: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []UUID: if value == nil { *dst = UUIDArray{Status: Null} @@ -152,6 +171,15 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { } return nil + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/varchar_array.go b/varchar_array.go index b13f29ce..5758ba62 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -49,6 +49,25 @@ func (dst *VarcharArray) Set(src interface{}) error { } } + case []*string: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + elements := make([]Varchar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = VarcharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []Varchar: if value == nil { *dst = VarcharArray{Status: Null} @@ -96,6 +115,15 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { } return nil + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) From bc07106f0e2f48fd1031ce2ba3c23e8778382b4e Mon Sep 17 00:00:00 2001 From: megaturbo Date: Wed, 17 Jun 2020 15:24:34 +0200 Subject: [PATCH 0532/1158] Add `Code generated` notice at the top of the file --- aclitem_array.go | 2 ++ bool_array.go | 2 ++ bpchar_array.go | 2 ++ bytea_array.go | 2 ++ cidr_array.go | 2 ++ date_array.go | 2 ++ enum_array.go | 2 ++ float4_array.go | 2 ++ float8_array.go | 2 ++ hstore_array.go | 2 ++ inet_array.go | 2 ++ int2_array.go | 2 ++ int4_array.go | 2 ++ int8_array.go | 2 ++ macaddr_array.go | 2 ++ numeric_array.go | 2 ++ text_array.go | 2 ++ timestamp_array.go | 2 ++ timestamptz_array.go | 2 ++ tstzrange_array.go | 2 ++ typed_array.go.erb | 2 ++ uuid_array.go | 2 ++ varchar_array.go | 2 ++ 23 files changed, 46 insertions(+) diff --git a/aclitem_array.go b/aclitem_array.go index 064436fd..2df0ccd4 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/bool_array.go b/bool_array.go index d5f89629..a8c75a25 100644 --- a/bool_array.go +++ b/bool_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/bpchar_array.go b/bpchar_array.go index 10d0d0f7..ed6fe703 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/bytea_array.go b/bytea_array.go index 6a45e4da..87d77f9e 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/cidr_array.go b/cidr_array.go index 5231e208..a2e025cc 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/date_array.go b/date_array.go index 51d00da1..fe185f67 100644 --- a/date_array.go +++ b/date_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/enum_array.go b/enum_array.go index 528cdb03..9312264c 100644 --- a/enum_array.go +++ b/enum_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/float4_array.go b/float4_array.go index bc9d4746..0e95c446 100644 --- a/float4_array.go +++ b/float4_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/float8_array.go b/float8_array.go index acc94b3f..240e88d6 100644 --- a/float8_array.go +++ b/float8_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/hstore_array.go b/hstore_array.go index 54909e42..b258cbdd 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/inet_array.go b/inet_array.go index 6d9f11fb..ca4c1a02 100644 --- a/inet_array.go +++ b/inet_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/int2_array.go b/int2_array.go index 35f73fee..ad2bd094 100644 --- a/int2_array.go +++ b/int2_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/int4_array.go b/int4_array.go index 2ff32ee1..15565f64 100644 --- a/int4_array.go +++ b/int4_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/int8_array.go b/int8_array.go index 17968338..e8e8823a 100644 --- a/int8_array.go +++ b/int8_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/macaddr_array.go b/macaddr_array.go index 72a4e8d4..616d6f85 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/numeric_array.go b/numeric_array.go index e808669c..e086ca7a 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/text_array.go b/text_array.go index 969054f8..d1583557 100644 --- a/text_array.go +++ b/text_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/timestamp_array.go b/timestamp_array.go index 81fd85f8..3b2c3141 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/timestamptz_array.go b/timestamptz_array.go index 48725e29..3328ec05 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/tstzrange_array.go b/tstzrange_array.go index 2c365645..c19a9bfa 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/typed_array.go.erb b/typed_array.go.erb index d8ae97dd..a3deea5b 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/uuid_array.go b/uuid_array.go index 0c02977f..06d2d576 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( diff --git a/varchar_array.go b/varchar_array.go index 5758ba62..32ca5941 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( From 35d2873de19d96a67465f97238aa9b67d35d7d01 Mon Sep 17 00:00:00 2001 From: tserakhau Date: Thu, 18 Jun 2020 17:11:54 +0300 Subject: [PATCH 0533/1158] Fix 490: Add jsonb arrays for pgx v4 --- jsonb_array.go | 300 +++++++++++++++++++++++++++++++++++++++++++++++++ pgtype.go | 2 + 2 files changed, 302 insertions(+) create mode 100644 jsonb_array.go diff --git a/jsonb_array.go b/jsonb_array.go new file mode 100644 index 00000000..7abc8193 --- /dev/null +++ b/jsonb_array.go @@ -0,0 +1,300 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type JSONBArray struct { + Elements []JSONB + Dimensions []ArrayDimension + Status Status +} + +func (dst *JSONBArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = JSONBArray{Status: Null} + return nil + } + + switch value := src.(type) { + + case []string: + if value == nil { + *dst = JSONBArray{Status: Null} + } else if len(value) == 0 { + *dst = JSONBArray{Status: Present} + } else { + elements := make([]JSONB, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = JSONBArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to JSONBArray", value) + } + + return nil +} + +func (dst *JSONBArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *JSONBArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *JSONBArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = JSONBArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []JSONB + + if len(uta.Elements) > 0 { + elements = make([]JSONB, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem JSONB + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = JSONBArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *JSONBArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = JSONBArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = JSONBArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]JSONB, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = JSONBArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *JSONBArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `"NULL"`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *JSONBArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("jsonb"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "text") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *JSONBArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *JSONBArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype.go b/pgtype.go index 621a8b95..5aa466d2 100644 --- a/pgtype.go +++ b/pgtype.go @@ -72,6 +72,7 @@ const ( UUIDOID = 2950 UUIDArrayOID = 2951 JSONBOID = 3802 + JSONBArrayOID = 3807 DaterangeOID = 3912 Int4rangeOID = 3904 NumrangeOID = 3906 @@ -878,6 +879,7 @@ func init() { "_timestamptz": &TimestamptzArray{}, "_uuid": &UUIDArray{}, "_varchar": &VarcharArray{}, + "_jsonb": &JSONBArray{}, "aclitem": &ACLItem{}, "bit": &Bit{}, "bool": &Bool{}, From 7cf5101bb27a95b5c5af77632b7dc0ddcef20690 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Jun 2020 12:58:15 -0500 Subject: [PATCH 0534/1158] Add NewConfig() refs #42 --- config.go | 21 ++++++++++++++++++++- pgconn_test.go | 18 ++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index f15b33b4..b6c27ce5 100644 --- a/config.go +++ b/config.go @@ -112,6 +112,18 @@ func NetworkAddress(host string, port uint16) (network, address string) { return network, address } +// NewConfig returns an *Config without parsing a connection string or reading the standard PG* environment variables. +// Host, Port, Database, User, and Password must be set before the config can be used to establish a connection. +func NewConfig() *Config { + return &Config{ + DialFunc: makeDefaultDialer().DialContext, + LookupFunc: makeDefaultResolver().LookupHost, + BuildFrontend: makeDefaultBuildFrontendFunc(8192), + RuntimeParams: map[string]string{}, + createdByParseConfig: true, + } +} + // ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same // defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN. // It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the @@ -154,7 +166,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are // usually but not always the environment variable name downcased and without the "PG" prefix. // -// Important TLS Security Notes: +// Important Security Notes: // // ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if // not set. @@ -162,6 +174,13 @@ func NetworkAddress(host string, port uint16) (network, address string) { // See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of // security each sslmode provides. // +// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of +// the Config struct. If the main TLS config is manually changed it will not affect the fallbacks. For example, in the +// case of sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try +// the fallback which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config +// is manually changed later but the unencrypted fallback is present. Remove or update all fallbacks or use NewConfig +// to build the config manually. +// // Other known differences with libpq: // // If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first. diff --git a/pgconn_test.go b/pgconn_test.go index 6362c51b..2d3e482b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -198,6 +198,24 @@ func TestConnectWithConnectionRefused(t *testing.T) { } } +func TestConnectConfigFromNewConfig(t *testing.T) { + t.Parallel() + + baseConfig, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + config := pgconn.NewConfig() + config.Host = baseConfig.Host + config.Port = baseConfig.Port + config.Database = baseConfig.Database + config.User = baseConfig.User + config.Password = baseConfig.Password + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) +} + func TestConnectCustomDialer(t *testing.T) { t.Parallel() From 41a185b6112f3e3a1d0aa867d766f8f6c2db2af2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Jun 2020 13:26:06 -0500 Subject: [PATCH 0535/1158] Allow converting intervals with months and days to duration It's a lossy conversion but so is numeric to float. fixes #42 --- interval.go | 8 ++++---- interval_test.go | 11 +++++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/interval.go b/interval.go index 3a91c595..309e880c 100644 --- a/interval.go +++ b/interval.go @@ -16,6 +16,8 @@ const ( microsecondsPerSecond = 1000000 microsecondsPerMinute = 60 * microsecondsPerSecond microsecondsPerHour = 60 * microsecondsPerMinute + microsecondsPerDay = 24 * microsecondsPerHour + microsecondsPerMonth = 30 * microsecondsPerDay ) type Interval struct { @@ -67,10 +69,8 @@ func (src *Interval) AssignTo(dst interface{}) error { case Present: switch v := dst.(type) { case *time.Duration: - if src.Days > 0 || src.Months > 0 { - return errors.Errorf("interval with months or days cannot be decoded into %T", dst) - } - *v = time.Duration(src.Microseconds) * time.Microsecond + us := int64(src.Months)*microsecondsPerMonth + int64(src.Days)*microsecondsPerDay + src.Microseconds + *v = time.Duration(us) * time.Microsecond return nil default: if nextDst, retry := GetAssignToDstType(dst); retry { diff --git a/interval_test.go b/interval_test.go index 6a4787e0..1ee094d7 100644 --- a/interval_test.go +++ b/interval_test.go @@ -2,9 +2,12 @@ package pgtype_test import ( "testing" + "time" "github.com/jackc/pgtype" "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestIntervalTranscode(t *testing.T) { @@ -61,3 +64,11 @@ func TestIntervalNormalize(t *testing.T) { }, }) } + +func TestIntervalLossyConversionToDuration(t *testing.T) { + interval := &pgtype.Interval{Months: 1, Days: 1, Status: pgtype.Present} + var d time.Duration + err := interval.AssignTo(&d) + require.NoError(t, err) + assert.EqualValues(t, int64(2678400000000000), d.Nanoseconds()) +} From 44f45c6c62198949596eceb23990eddeb9581b6b Mon Sep 17 00:00:00 2001 From: tserakhau Date: Sun, 21 Jun 2020 14:21:16 +0300 Subject: [PATCH 0536/1158] Use erb for jsonb array generation --- jsonb_array.go | 48 ++++++++++++++++++++++++++++++++-------------- typed_array_gen.sh | 1 + 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/jsonb_array.go b/jsonb_array.go index 7abc8193..fd78fc80 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -4,12 +4,12 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgio" + errors "golang.org/x/xerrors" ) type JSONBArray struct { - Elements []JSONB + Elements []Text Dimensions []ArrayDimension Status Status } @@ -21,6 +21,13 @@ func (dst *JSONBArray) Set(src interface{}) error { return nil } + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + switch value := src.(type) { case []string: @@ -29,7 +36,7 @@ func (dst *JSONBArray) Set(src interface{}) error { } else if len(value) == 0 { *dst = JSONBArray{Status: Present} } else { - elements := make([]JSONB, len(value)) + elements := make([]Text, len(value)) for i := range value { if err := elements[i].Set(value[i]); err != nil { return err @@ -42,6 +49,18 @@ func (dst *JSONBArray) Set(src interface{}) error { } } + case []Text: + if value == nil { + *dst = JSONBArray{Status: Null} + } else if len(value) == 0 { + *dst = JSONBArray{Status: Present} + } else { + *dst = JSONBArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) @@ -52,7 +71,7 @@ func (dst *JSONBArray) Set(src interface{}) error { return nil } -func (dst *JSONBArray) Get() interface{} { +func (dst JSONBArray) Get() interface{} { switch dst.Status { case Present: return dst @@ -81,6 +100,7 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } + return errors.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) @@ -100,13 +120,13 @@ func (dst *JSONBArray) DecodeText(ci *ConnInfo, src []byte) error { return err } - var elements []JSONB + var elements []Text if len(uta.Elements) > 0 { - elements = make([]JSONB, len(uta.Elements)) + elements = make([]Text, len(uta.Elements)) for i, s := range uta.Elements { - var elem JSONB + var elem Text var elemSrc []byte if s != "NULL" { elemSrc = []byte(s) @@ -147,7 +167,7 @@ func (dst *JSONBArray) DecodeBinary(ci *ConnInfo, src []byte) error { elementCount *= d.Length } - elements := make([]JSONB, elementCount) + elements := make([]Text, elementCount) for i := range elements { elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) @@ -167,7 +187,7 @@ func (dst *JSONBArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *JSONBArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src JSONBArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -209,7 +229,7 @@ func (src *JSONBArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if elemBuf == nil { - buf = append(buf, `"NULL"`...) + buf = append(buf, `NULL`...) } else { buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } @@ -224,7 +244,7 @@ func (src *JSONBArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *JSONBArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src JSONBArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -236,7 +256,7 @@ func (src *JSONBArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { Dimensions: src.Dimensions, } - if dt, ok := ci.DataTypeForName("jsonb"); ok { + if dt, ok := ci.DataTypeForName("text"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { return nil, errors.Errorf("unable to find oid for type name %v", "text") @@ -287,7 +307,7 @@ func (dst *JSONBArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *JSONBArray) Value() (driver.Value, error) { +func (src JSONBArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 6fd49264..523b2600 100755 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -19,6 +19,7 @@ erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[] erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64,[]int64,[]uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go +erb pgtype_array_type=JSONBArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go # While the binary format is theoretically possible it is only practical to use the text format. erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go From 66a0b33655ecdd86d9fcadd51478916699b341f8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 24 Jun 2020 08:40:34 -0500 Subject: [PATCH 0537/1158] Rerun typed_array_gen.sh --- jsonb_array.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jsonb_array.go b/jsonb_array.go index fd78fc80..daebfa7b 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -1,3 +1,5 @@ +// Code generated by erb. DO NOT EDIT. + package pgtype import ( From 473062b114e54e039d7af4a951e877b430ea0c67 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jun 2020 11:29:21 -0500 Subject: [PATCH 0538/1158] Remove NewConfig and add more docs for ParseConfig refs #42 --- config.go | 31 ++++++++++++------------------- pgconn_test.go | 18 ------------------ 2 files changed, 12 insertions(+), 37 deletions(-) diff --git a/config.go b/config.go index b6c27ce5..75292125 100644 --- a/config.go +++ b/config.go @@ -27,8 +27,8 @@ import ( type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error -// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and -// then it can be modified. A manually initialized Config will cause ConnectConfig to panic. +// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A +// manually initialized Config will cause ConnectConfig to panic. type Config struct { Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) Port uint16 @@ -112,18 +112,6 @@ func NetworkAddress(host string, port uint16) (network, address string) { return network, address } -// NewConfig returns an *Config without parsing a connection string or reading the standard PG* environment variables. -// Host, Port, Database, User, and Password must be set before the config can be used to establish a connection. -func NewConfig() *Config { - return &Config{ - DialFunc: makeDefaultDialer().DialContext, - LookupFunc: makeDefaultResolver().LookupHost, - BuildFrontend: makeDefaultBuildFrontendFunc(8192), - RuntimeParams: map[string]string{}, - createdByParseConfig: true, - } -} - // ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same // defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN. // It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the @@ -135,6 +123,11 @@ func NewConfig() *Config { // # Example URL // postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca // +// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done +// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be +// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should +// not be modified individually. They should all be modified or all left unchanged. +// // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated // values that will be tried in order. This can be used as part of a high availability system. See // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. @@ -175,11 +168,11 @@ func NewConfig() *Config { // security each sslmode provides. // // The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of -// the Config struct. If the main TLS config is manually changed it will not affect the fallbacks. For example, in the -// case of sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try -// the fallback which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config -// is manually changed later but the unencrypted fallback is present. Remove or update all fallbacks or use NewConfig -// to build the config manually. +// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of +// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback +// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually +// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting +// TLCConfig. // // Other known differences with libpq: // diff --git a/pgconn_test.go b/pgconn_test.go index 2d3e482b..6362c51b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -198,24 +198,6 @@ func TestConnectWithConnectionRefused(t *testing.T) { } } -func TestConnectConfigFromNewConfig(t *testing.T) { - t.Parallel() - - baseConfig, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - - config := pgconn.NewConfig() - config.Host = baseConfig.Host - config.Port = baseConfig.Port - config.Database = baseConfig.Database - config.User = baseConfig.User - config.Password = baseConfig.Password - - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - closeConn(t, conn) -} - func TestConnectCustomDialer(t *testing.T) { t.Parallel() From 82c2752e7151902340c65d2b61e483424b308c96 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jun 2020 11:35:23 -0500 Subject: [PATCH 0539/1158] Update golang.org/x/text to 0.3.3 golang.org/x/text had a vulnerability: https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14040 pgconn does not appear to use the affected code path, but it is still worth updating away from the vulnerable version. fixes #44 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 9b6baf5b..45aa8a46 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,6 @@ require ( github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 - golang.org/x/text v0.3.2 + golang.org/x/text v0.3.3 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 ) diff --git a/go.sum b/go.sum index 2063a801..29b3ebd8 100644 --- a/go.sum +++ b/go.sum @@ -102,6 +102,8 @@ golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= From 65717779e443e346ee1d7183f1ba1e2fb3947e7b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jun 2020 11:46:16 -0500 Subject: [PATCH 0540/1158] Fix crash when PGSERVICE not found --- config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 75292125..44953e0f 100644 --- a/config.go +++ b/config.go @@ -548,12 +548,12 @@ func parseDSNSettings(s string) (map[string]string, error) { func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { servicefile, err := pgservicefile.ReadServicefile(servicefilePath) if err != nil { - return nil, fmt.Errorf("failed to read service file: %v", servicefile) + return nil, fmt.Errorf("failed to read service file: %v", servicefilePath) } service, err := servicefile.GetService(serviceName) if err != nil { - return nil, fmt.Errorf("unable to find service: %v", servicefile) + return nil, fmt.Errorf("unable to find service: %v", serviceName) } nameMap := map[string]string{ From bd7ffdb480379b6d0e73a0bf7fdf9d7050f9fa54 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jun 2020 11:48:20 -0500 Subject: [PATCH 0541/1158] Update golang.org/x/crypto dependency --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 45aa8a46..2487c271 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/jackc/pgproto3/v2 v2.0.2 github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 github.com/stretchr/testify v1.5.1 - golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 + golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/text v0.3.3 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 ) diff --git a/go.sum b/go.sum index 29b3ebd8..2440dd48 100644 --- a/go.sum +++ b/go.sum @@ -87,6 +87,8 @@ golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1X golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 h1:3zb4D3T4G8jdExgVU/95+vQXfpEPiMdCaZgmGVxjNHM= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= From 503c2b445f76da704197860e7158bf75ce2a9ef0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jun 2020 11:51:30 -0500 Subject: [PATCH 0542/1158] Release v1.6.1 --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 68b151d8..25376301 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 1.6.1 (June 27, 2020) + +* Update golang.org/x/crypto to latest +* Update golang.org/x/text to 0.3.3 +* Fix error handling for bad PGSERVICE definition +* Redact passwords in ParseConfig errors (Lukas Vogel) + # 1.6.0 (June 6, 2020) * Fix panic when closing conn during cancellable query From c4e2b4bda398ba1f43210372e903493f5eb18f6d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jun 2020 12:24:46 -0500 Subject: [PATCH 0543/1158] Update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d8b891c1..57db99c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ * Clarify and normalize Value semantics * Fix hstore with empty string values * Numeric supports NaN values (leighhopcroft) +* Add slice of pointer support to array types (megaturbo) +* Add jsonb array type (tserakhau) +* Allow converting intervals with months and days to duration # 1.3.0 (March 30, 2020) From efe4704c57977307927227871fece9478a9777c7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jun 2020 12:25:17 -0500 Subject: [PATCH 0544/1158] Release v1.4.0 --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57db99c8..0c749d76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -# Unreleased +# 1.4.0 (June 27, 2020) * Add JSON support to ext/gofrs-uuid * Performance improvements in Scan path From 5576567c19b6dd9cf12a1777aa96522b7a69ca80 Mon Sep 17 00:00:00 2001 From: James Lawrence Date: Mon, 6 Jul 2020 11:27:15 -0400 Subject: [PATCH 0545/1158] support unformatted uuid hex string. adds the abiility to support uuids in the form: 000102030405060708090a0b0c0d0e0f --- uuid.go | 10 ++++++++-- uuid_test.go | 4 ++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/uuid.go b/uuid.go index 634f6463..9f9bbefd 100644 --- a/uuid.go +++ b/uuid.go @@ -100,10 +100,16 @@ func (src *UUID) AssignTo(dst interface{}) error { // parseUUID converts a string UUID in standard form to a byte array. func parseUUID(src string) (dst [16]byte, err error) { - if len(src) < 36 { + switch len(src) { + case 36: + src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] + case 32: + // dashes already stripped, assume valid + default: + // assume invalid. return dst, errors.Errorf("cannot parse UUID %v", src) } - src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] + buf, err := hex.DecodeString(src) if err != nil { return dst, err diff --git a/uuid_test.go b/uuid_test.go index f0480f9a..9f7b19e2 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -46,6 +46,10 @@ func TestUUIDSet(t *testing.T) { source: "00010203-0405-0607-0809-0a0b0c0d0e0f", result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, + { + source: "000102030405060708090a0b0c0d0e0f", + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, } for i, tt := range successfulTests { From 193ecfec7316e9b99c5155f8006aef9a6fc80321 Mon Sep 17 00:00:00 2001 From: bakape Date: Fri, 10 Jul 2020 19:41:25 +0300 Subject: [PATCH 0546/1158] optimise struct padding --- array_type.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/array_type.go b/array_type.go index 9454021b..32ce7ac4 100644 --- a/array_type.go +++ b/array_type.go @@ -15,11 +15,12 @@ import ( type ArrayType struct { elements []ValueTranscoder dimensions []ArrayDimension - status Status typeName string - elementOID uint32 newElement func() ValueTranscoder + + elementOID uint32 + status Status } func NewArrayType(typeName string, elementOID uint32, newElement func() ValueTranscoder) *ArrayType { From 12752ce5d63917f9fa710ba0b117aa1b550b43ba Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 13 Jul 2020 19:34:45 -0500 Subject: [PATCH 0547/1158] Update pgservicefile --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 2487c271..d3550ca8 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.2 - github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 + github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/text v0.3.3 diff --git a/go.sum b/go.sum index 2440dd48..0b144d0f 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/jackc/pgproto3/v2 v2.0.2 h1:q1Hsy66zh4vuNsajBUF2PNqfAMMfxU5mk594lPE9v github.com/jackc/pgproto3/v2 v2.0.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= From 7a3e774a5210e09e9c9471252e5e7f276e3d455c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 14 Jul 2020 11:58:03 -0500 Subject: [PATCH 0548/1158] Fix ArrayType DecodeBinary empty array breaks future reads --- array_type.go | 8 ++++++-- array_type_test.go | 22 ++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/array_type.go b/array_type.go index 32ce7ac4..04b8710c 100644 --- a/array_type.go +++ b/array_type.go @@ -185,8 +185,12 @@ func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error { return err } + var elements []ValueTranscoder + if len(arrayHeader.Dimensions) == 0 { - *dst = ArrayType{dimensions: arrayHeader.Dimensions, status: Present} + dst.elements = elements + dst.dimensions = arrayHeader.Dimensions + dst.status = Present return nil } @@ -195,7 +199,7 @@ func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error { elementCount *= d.Length } - elements := make([]ValueTranscoder, elementCount) + elements = make([]ValueTranscoder, elementCount) for i := range elements { elem := dst.newElement() diff --git a/array_type_test.go b/array_type_test.go index 0f296bb5..626df4dc 100644 --- a/array_type_test.go +++ b/array_type_test.go @@ -60,3 +60,25 @@ func TestArrayTypeTranscode(t *testing.T) { require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) } + +func TestArrayTypeEmptyArrayDoesNotBreakArrayType(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), + Name: "_text", + OID: pgtype.TextArrayOID, + }) + + var dstStrings []string + err := conn.QueryRow(context.Background(), "select '{}'::text[]").Scan(&dstStrings) + require.NoError(t, err) + + require.EqualValues(t, []string{}, dstStrings) + + err = conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) + require.NoError(t, err) + + require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) +} From 9295bf7483021745c921e818151ef3b735090b4f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 14 Jul 2020 12:07:27 -0500 Subject: [PATCH 0549/1158] Update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 25376301..c3088dd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.6.2 (July 14, 2020) + +* Update pgservicefile library + # 1.6.1 (June 27, 2020) * Update golang.org/x/crypto to latest From 271b0ac95ee4426f3495a2577b624296c5372a70 Mon Sep 17 00:00:00 2001 From: vahid-sohrabloo Date: Fri, 17 Jul 2020 20:31:10 +0430 Subject: [PATCH 0550/1158] AppendCertsFromPEM doesn't have error and removes pgTLSArgs AppendCertsFromPEM doesn't have error and removes pgTLSArgs because not used --- config.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/config.go b/config.go index 44953e0f..906ed7f4 100644 --- a/config.go +++ b/config.go @@ -571,13 +571,6 @@ func parseServiceSettings(servicefilePath, serviceName string) (map[string]strin return settings, nil } -type pgTLSArgs struct { - sslMode string - sslRootCert string - sslCert string - sslKey string -} - // configTLS uses libpq's TLS parameters to construct []*tls.Config. It is // necessary to allow returning multiple TLS configs as sslmode "allow" and // "prefer" allow fallback. @@ -662,7 +655,7 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { } if !caCertPool.AppendCertsFromPEM(caCert) { - return nil, errors.Errorf("unable to add CA to cert pool: %w", err) + return nil, errors.New("unable to add CA to cert pool") } tlsConfig.RootCAs = caCertPool From b939bc8d681d6e74a0b23f0a28edea25d012edf8 Mon Sep 17 00:00:00 2001 From: Yaz Saito Date: Tue, 21 Jul 2020 23:35:43 -0700 Subject: [PATCH 0551/1158] Fix encoding of a large composite data type If encoding a field caused a buffer reallocation, the its length would be written to a wrong place. --- composite_type.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composite_type.go b/composite_type.go index 49ce70fa..cbe0a245 100644 --- a/composite_type.go +++ b/composite_type.go @@ -576,7 +576,7 @@ func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder) return } if fieldBuf != nil { - binary.BigEndian.PutUint32(b.buf[lengthPos:], uint32(len(fieldBuf)-len(b.buf))) + binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf))) b.buf = fieldBuf } From 37c9edc242e83750fcfbef327001fd65603d63d0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 22 Jul 2020 06:43:39 -0500 Subject: [PATCH 0552/1158] Release v1.6.3 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c3088dd0..58481415 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.6.3 (July 22, 2020) + +* Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo) + # 1.6.2 (July 14, 2020) * Update pgservicefile library From 7673c8578d80adfbc0e76e2350ffa539f44e92bb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 22 Jul 2020 06:45:10 -0500 Subject: [PATCH 0553/1158] Update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c749d76..bd98fa1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.4.1 (July 14, 2020) + +* Fix ArrayType DecodeBinary empty array breaks future reads + # 1.4.0 (June 27, 2020) * Add JSON support to ext/gofrs-uuid From d831ba712a609d578f0fd6f25c13f4c8075eacc7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 22 Jul 2020 06:46:27 -0500 Subject: [PATCH 0554/1158] Release v1.4.2 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bd98fa1d..d117d239 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.4.2 (July 22, 2020) + +* Fix encoding of a large composite data type (Yaz Saito) + # 1.4.1 (July 14, 2020) * Fix ArrayType DecodeBinary empty array breaks future reads From 4e4c4ea5410aba437bc6d6e2c5a93c4acf6cce73 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 29 Jul 2020 21:47:23 -0500 Subject: [PATCH 0555/1158] Fix deadlock on error after CommandComplete but before ReadyForQuery See: https://github.com/jackc/pgx/issues/800 --- pgconn.go | 7 +++++- pgconn_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 5644904a..50607095 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1435,12 +1435,17 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { + // Keep the first error that is recorded. Store the error before checking if the command is already concluded to + // allow for receiving an error after CommandComplete but before ReadyForQuery. + if err != nil && rr.err == nil { + rr.err = err + } + if rr.commandConcluded { return } rr.commandTag = commandTag - rr.err = err rr.rowValues = nil rr.commandConcluded = true } diff --git a/pgconn_test.go b/pgconn_test.go index 6362c51b..379aa266 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1752,6 +1752,71 @@ func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { pgConn.Close(closeCtx) } +// https://github.com/jackc/pgx/issues/800 +func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) { + t.Parallel() + + steps := pgmock.AcceptUnauthenticatedConnRequestSteps() + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Bind{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Execute{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Sync{})) + steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + {Name: []byte("mock")}, + }})) + steps = append(steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")})) + steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"})) + + script := &pgmock.Script{Steps: steps} + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + conn, err := pgconn.Connect(ctx, connStr) + require.NoError(t, err) + + rr := conn.ExecParams(ctx, "mocked...", nil, nil, nil, nil) + + for rr.NextRow() { + } + + _, err = rr.Close() + require.Error(t, err) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { From 44079b0d2c9ac3629a8ea9cafe4d75568b376f9e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 29 Jul 2020 22:11:15 -0500 Subject: [PATCH 0556/1158] Fix panic on parsing DSN with trailing '=' Also correctly return error with leading '='. fixes #47 --- config.go | 7 ++++++- config_test.go | 11 +++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 906ed7f4..b2583546 100644 --- a/config.go +++ b/config.go @@ -497,7 +497,8 @@ func parseDSNSettings(s string) (map[string]string, error) { key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") - if s[0] != '\'' { + if len(s) == 0 { + } else if s[0] != '\'' { end := 0 for ; end < len(s); end++ { if asciiSpace[s[end]] == 1 { @@ -539,6 +540,10 @@ func parseDSNSettings(s string) (map[string]string, error) { key = k } + if key == "" { + return nil, errors.New("invalid dsn") + } + settings[key] = val } diff --git a/config_test.go b/config_test.go index ebe627b1..264eb299 100644 --- a/config_test.go +++ b/config_test.go @@ -528,6 +528,17 @@ func TestParseConfig(t *testing.T) { } } +// https://github.com/jackc/pgconn/issues/47 +func TestParseConfigDSNWithTrailingEmptyEqualDoesNotPanic(t *testing.T) { + _, err := pgconn.ParseConfig("host= user= password= port= database=") + require.NoError(t, err) +} + +func TestParseConfigDSNLeadingEqual(t *testing.T) { + _, err := pgconn.ParseConfig("= user=jack") + require.Error(t, err) +} + func TestConfigCopyReturnsEqualConfig(t *testing.T) { connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" original, err := pgconn.ParseConfig(connString) From f45b4d6b76091608f30b1f8ff5de046a32080d3d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 29 Jul 2020 22:17:02 -0500 Subject: [PATCH 0557/1158] Release v1.6.4 --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 58481415..a6668fb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 1.6.4 (July 29, 2020) + +* Fix deadlock on error after CommandComplete but before ReadyForQuery +* Fix panic on parsing DSN with trailing '=' + # 1.6.3 (July 22, 2020) * Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo) From 6d0b4c45e4183bbd69ca91bba3e2bd1b1854761b Mon Sep 17 00:00:00 2001 From: Matt Jibson Date: Fri, 31 Jul 2020 15:42:06 -0600 Subject: [PATCH 0558/1158] correctly encode CopyInResponse's format field --- copy_in_response.go | 1 + 1 file changed, 1 insertion(+) diff --git a/copy_in_response.go b/copy_in_response.go index 4439a032..5f2595b8 100644 --- a/copy_in_response.go +++ b/copy_in_response.go @@ -48,6 +48,7 @@ func (src *CopyInResponse) Encode(dst []byte) []byte { sp := len(dst) dst = pgio.AppendInt32(dst, -1) + dst = append(dst, src.OverallFormat) dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) From b6e34b44e5c0657be2eb7c36f5b12cc5c88dfe1f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 31 Jul 2020 17:04:18 -0500 Subject: [PATCH 0559/1158] Update pgproto3 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index d3550ca8..a20501c5 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.2 + github.com/jackc/pgproto3/v2 v2.0.3 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 diff --git a/go.sum b/go.sum index 0b144d0f..d3226ebc 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.2 h1:q1Hsy66zh4vuNsajBUF2PNqfAMMfxU5mk594lPE9vjY= github.com/jackc/pgproto3/v2 v2.0.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.0.3 h1:2S4PhE00mvdvaSiCYR1ZCmR1NAxeYfTSsqqSKxE1vzo= +github.com/jackc/pgproto3/v2 v2.0.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= From 2799a6e9a6b9c6b4211e72a8f3fee18280d29b6c Mon Sep 17 00:00:00 2001 From: Matt Jibson Date: Fri, 31 Jul 2020 16:13:23 -0600 Subject: [PATCH 0560/1158] mark CopyDone as frontend too --- copy_done.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/copy_done.go b/copy_done.go index d8b6e5d7..0e13282b 100644 --- a/copy_done.go +++ b/copy_done.go @@ -10,6 +10,9 @@ type CopyDone struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*CopyDone) Backend() {} +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CopyDone) Frontend() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *CopyDone) Decode(src []byte) error { From c894ca8b7d2a9e3dcf03f8cc319461a73f6a7fc6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Aug 2020 05:49:56 -0500 Subject: [PATCH 0561/1158] Update pgproto3 to v2.0.4 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index a20501c5..a74028c8 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.3 + github.com/jackc/pgproto3/v2 v2.0.4 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 diff --git a/go.sum b/go.sum index d3226ebc..61a2896b 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/jackc/pgproto3/v2 v2.0.2 h1:q1Hsy66zh4vuNsajBUF2PNqfAMMfxU5mk594lPE9v github.com/jackc/pgproto3/v2 v2.0.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.3 h1:2S4PhE00mvdvaSiCYR1ZCmR1NAxeYfTSsqqSKxE1vzo= github.com/jackc/pgproto3/v2 v2.0.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.0.4 h1:RHkX5ZUD9bl/kn0f9dYUWs1N7Nwvo1wwUYvKiR26Zco= +github.com/jackc/pgproto3/v2 v2.0.4/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= From 449a8a4f8e7a35dc38380391065a4f7122f9d21b Mon Sep 17 00:00:00 2001 From: Simo Haasanen Date: Fri, 7 Aug 2020 13:10:32 +0100 Subject: [PATCH 0562/1158] Add multidimensional array and slice support. Adds array support - previously only slices were supported. Adds new test cases for multidimensional arrays and slices. All previous test cases are unmodified and passed (fully backwards compatible). Removes hard-coded type conversions for arrays, instead now relies on the type support of the array element's type conversion support. Less maintenance for arrays, new type conversions are automatically supported when array's element gains new type support. Simplifies typed_array_gen.sh generator script by removing the hard-coded single-dimensional types for arrays. Only typed_array.go.erb and typed_array_gen.sh have been changed + 1 new auxiliary function in array.go file + additional tests in test files for each array. Other changes are from generated code. --- aclitem_array.go | 212 ++++++++----- aclitem_array_test.go | 171 +++++++++++ array.go | 22 ++ bool_array.go | 212 ++++++++----- bool_array_test.go | 125 ++++++++ bpchar_array.go | 212 ++++++++----- bytea_array.go | 184 +++++++++--- bytea_array_test.go | 104 +++++++ cidr_array.go | 241 ++++++++------- cidr_array_test.go | 144 +++++++++ date_array.go | 213 +++++++++----- date_array_test.go | 179 +++++++++++ enum_array.go | 212 ++++++++----- enum_array_test.go | 125 ++++++++ float4_array.go | 212 ++++++++----- float4_array_test.go | 125 ++++++++ float8_array.go | 212 ++++++++----- float8_array_test.go | 101 +++++++ hstore_array.go | 184 +++++++++--- hstore_array_test.go | 250 +++++++++++++++- inet_array.go | 241 ++++++++------- inet_array_test.go | 144 +++++++++ int2_array.go | 604 +++++++++----------------------------- int2_array_test.go | 125 ++++++++ int4_array.go | 604 +++++++++----------------------------- int4_array_test.go | 125 ++++++++ int8_array.go | 604 +++++++++----------------------------- int8_array_test.go | 125 ++++++++ jsonb_array.go | 184 +++++++++--- macaddr_array.go | 213 +++++++++----- macaddr_array_test.go | 152 ++++++++++ numeric_array.go | 380 +++++++++--------------- numeric_array_test.go | 125 ++++++++ text_array.go | 212 ++++++++----- text_array_test.go | 125 ++++++++ timestamp_array.go | 213 +++++++++----- timestamp_array_test.go | 143 +++++++++ timestamptz_array.go | 213 +++++++++----- timestamptz_array_test.go | 179 +++++++++++ tstzrange_array.go | 165 +++++++++-- typed_array.go.erb | 187 ++++++++---- typed_array_gen.sh | 46 +-- uuid_array.go | 268 +++++++++-------- uuid_array_test.go | 152 ++++++++++ varchar_array.go | 212 ++++++++----- varchar_array_test.go | 125 ++++++++ 46 files changed, 6193 insertions(+), 3113 deletions(-) diff --git a/aclitem_array.go b/aclitem_array.go index 2df0ccd4..09a64fb6 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -4,6 +4,7 @@ package pgtype import ( "database/sql/driver" + "reflect" errors "golang.org/x/xerrors" ) @@ -28,68 +29,94 @@ func (dst *ACLItemArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = ACLItemArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = ACLItemArray{Status: Null} - } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} - } else { - elements := make([]ACLItem, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ACLItemArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = ACLItemArray{Status: Null} - } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} - } else { - elements := make([]ACLItem, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ACLItemArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []ACLItem: - if value == nil { - *dst = ACLItemArray{Status: Null} - } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} - } else { - *dst = ACLItemArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for ACLItemArray", src) + } + if elementsLength == 0 { + *dst = ACLItemArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to ACLItemArray", value) + return errors.Errorf("cannot convert %v to ACLItemArray", src) + } + + *dst = ACLItemArray{ + Elements: make([]ACLItem, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]ACLItem, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to ACLItemArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *ACLItemArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to ACLItemArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in ACLItemArray", err) + } + index++ + + return index, nil +} + func (dst ACLItemArray) Get() interface{} { switch dst.Status { case Present: @@ -104,32 +131,26 @@ func (dst ACLItemArray) Get() interface{} { func (src *ACLItemArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -137,6 +158,49 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from ACLItemArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = ACLItemArray{Status: Null} diff --git a/aclitem_array_test.go b/aclitem_array_test.go index fb1e93fc..73e9ce71 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -69,6 +69,74 @@ func TestACLItemArraySet(t *testing.T) { source: (([]string)(nil)), result: pgtype.ACLItemArray{Status: pgtype.Null}, }, + { + source: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -88,6 +156,10 @@ func TestACLItemArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.ACLItemArray @@ -117,6 +189,78 @@ func TestACLItemArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + }, } for i, tt := range simpleTests { @@ -142,6 +286,33 @@ func TestACLItemArrayAssignTo(t *testing.T) { }, dst: &stringSlice, }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, } for i, tt := range errorTests { diff --git a/array.go b/array.go index bd3a993b..b779cd9d 100644 --- a/array.go +++ b/array.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "io" + "reflect" "strconv" "strings" "unicode" @@ -350,3 +351,24 @@ func QuoteArrayElementIfNeeded(src string) string { } return src } + +func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, elementsLength int) ([]ArrayDimension, int, bool) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + length := value.Len() + if 0 == elementsLength { + elementsLength = length + } else { + elementsLength *= length + } + dimensions = append(dimensions, ArrayDimension{Length: int32(length), LowerBound: 1}) + for i := 0; i < length; i++ { + if d, l, ok := findDimensionsFromValue(value.Index(i), dimensions, elementsLength); ok { + return d, l, true + } + } + } + return dimensions, elementsLength, true +} diff --git a/bool_array.go b/bool_array.go index a8c75a25..6569d5ca 100644 --- a/bool_array.go +++ b/bool_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *BoolArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = BoolArray{Status: Null} + return nil + } - case []bool: - if value == nil { - *dst = BoolArray{Status: Null} - } else if len(value) == 0 { - *dst = BoolArray{Status: Present} - } else { - elements := make([]Bool, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BoolArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*bool: - if value == nil { - *dst = BoolArray{Status: Null} - } else if len(value) == 0 { - *dst = BoolArray{Status: Present} - } else { - elements := make([]Bool, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BoolArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Bool: - if value == nil { - *dst = BoolArray{Status: Null} - } else if len(value) == 0 { - *dst = BoolArray{Status: Present} - } else { - *dst = BoolArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for BoolArray", src) + } + if elementsLength == 0 { + *dst = BoolArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to BoolArray", value) + return errors.Errorf("cannot convert %v to BoolArray", src) + } + + *dst = BoolArray{ + Elements: make([]Bool, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Bool, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to BoolArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *BoolArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to BoolArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in BoolArray", err) + } + index++ + + return index, nil +} + func (dst BoolArray) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst BoolArray) Get() interface{} { func (src *BoolArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]bool: - *v = make([]bool, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*bool: - *v = make([]*bool, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from BoolArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = BoolArray{Status: Null} diff --git a/bool_array_test.go b/bool_array_test.go index bef94622..7f31e252 100644 --- a/bool_array_test.go +++ b/bool_array_test.go @@ -68,6 +68,54 @@ func TestBoolArraySet(t *testing.T) { source: (([]bool)(nil)), result: pgtype.BoolArray{Status: pgtype.Null}, }, + { + source: [][]bool{{true}, {false}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]bool{{true}, {false}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -87,6 +135,10 @@ func TestBoolArrayAssignTo(t *testing.T) { var boolSlice []bool type _boolSlice []bool var namedBoolSlice _boolSlice + var boolSliceDim2 [][]bool + var boolSliceDim4 [][][][]bool + var boolArrayDim2 [2][1]bool + var boolArrayDim4 [2][1][1][3]bool simpleTests := []struct { src pgtype.BoolArray @@ -116,6 +168,58 @@ func TestBoolArrayAssignTo(t *testing.T) { dst: &boolSlice, expected: (([]bool)(nil)), }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]bool{{true}, {false}}, + dst: &boolSliceDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + dst: &boolSliceDim4, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]bool{{true}, {false}}, + dst: &boolArrayDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + dst: &boolArrayDim4, + }, } for i, tt := range simpleTests { @@ -141,6 +245,27 @@ func TestBoolArrayAssignTo(t *testing.T) { }, dst: &boolSlice, }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &boolArrayDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &boolSlice, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &boolArrayDim4, + }, } for i, tt := range errorTests { diff --git a/bpchar_array.go b/bpchar_array.go index ed6fe703..8aef8330 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *BPCharArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = BPCharArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = BPCharArray{Status: Null} - } else if len(value) == 0 { - *dst = BPCharArray{Status: Present} - } else { - elements := make([]BPChar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BPCharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = BPCharArray{Status: Null} - } else if len(value) == 0 { - *dst = BPCharArray{Status: Present} - } else { - elements := make([]BPChar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BPCharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []BPChar: - if value == nil { - *dst = BPCharArray{Status: Null} - } else if len(value) == 0 { - *dst = BPCharArray{Status: Present} - } else { - *dst = BPCharArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for BPCharArray", src) + } + if elementsLength == 0 { + *dst = BPCharArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to BPCharArray", value) + return errors.Errorf("cannot convert %v to BPCharArray", src) + } + + *dst = BPCharArray{ + Elements: make([]BPChar, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]BPChar, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to BPCharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *BPCharArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to BPCharArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in BPCharArray", err) + } + index++ + + return index, nil +} + func (dst BPCharArray) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst BPCharArray) Get() interface{} { func (src *BPCharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from BPCharArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = BPCharArray{Status: Null} diff --git a/bytea_array.go b/bytea_array.go index 87d77f9e..3addb99a 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,49 +31,94 @@ func (dst *ByteaArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = ByteaArray{Status: Null} + return nil + } - case [][]byte: - if value == nil { - *dst = ByteaArray{Status: Null} - } else if len(value) == 0 { - *dst = ByteaArray{Status: Present} - } else { - elements := make([]Bytea, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ByteaArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Bytea: - if value == nil { - *dst = ByteaArray{Status: Null} - } else if len(value) == 0 { - *dst = ByteaArray{Status: Present} - } else { - *dst = ByteaArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for ByteaArray", src) + } + if elementsLength == 0 { + *dst = ByteaArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to ByteaArray", value) + return errors.Errorf("cannot convert %v to ByteaArray", src) + } + + *dst = ByteaArray{ + Elements: make([]Bytea, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Bytea, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to ByteaArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *ByteaArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to ByteaArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in ByteaArray", err) + } + index++ + + return index, nil +} + func (dst ByteaArray) Get() interface{} { switch dst.Status { case Present: @@ -87,23 +133,26 @@ func (dst ByteaArray) Get() interface{} { func (src *ByteaArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -111,6 +160,49 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from ByteaArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = ByteaArray{Status: Null} diff --git a/bytea_array_test.go b/bytea_array_test.go index a4eb2d91..f40005a2 100644 --- a/bytea_array_test.go +++ b/bytea_array_test.go @@ -68,6 +68,54 @@ func TestByteaArraySet(t *testing.T) { source: (([][]byte)(nil)), result: pgtype.ByteaArray{Status: pgtype.Null}, }, + { + source: [][][]byte{{{1}}, {{2}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1][]byte{{{1}}, {{2}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -85,6 +133,10 @@ func TestByteaArraySet(t *testing.T) { func TestByteaArrayAssignTo(t *testing.T) { var byteByteSlice [][]byte + var byteByteSliceDim2 [][][]byte + var byteByteSliceDim4 [][][][][]byte + var byteByteArraySliceDim2 [2][1][]byte + var byteByteArraySliceDim4 [2][1][1][3][]byte simpleTests := []struct { src pgtype.ByteaArray @@ -105,6 +157,58 @@ func TestByteaArrayAssignTo(t *testing.T) { dst: &byteByteSlice, expected: (([][]byte)(nil)), }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteByteSliceDim2, + expected: [][][]byte{{{1}}, {{2}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &byteByteSliceDim4, + expected: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteByteArraySliceDim2, + expected: [2][1][]byte{{{1}}, {{2}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &byteByteArraySliceDim4, + expected: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + }, } for i, tt := range simpleTests { diff --git a/cidr_array.go b/cidr_array.go index a2e025cc..1ef2f428 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "net" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,87 +31,94 @@ func (dst *CIDRArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = CIDRArray{Status: Null} + return nil + } - case []*net.IPNet: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []net.IP: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*net.IP: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []CIDR: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - *dst = CIDRArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for CIDRArray", src) + } + if elementsLength == 0 { + *dst = CIDRArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to CIDRArray", value) + return errors.Errorf("cannot convert %v to CIDRArray", src) + } + + *dst = CIDRArray{ + Elements: make([]CIDR, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]CIDR, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to CIDRArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *CIDRArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to CIDRArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in CIDRArray", err) + } + index++ + + return index, nil +} + func (dst CIDRArray) Get() interface{} { switch dst.Status { case Present: @@ -126,41 +133,26 @@ func (dst CIDRArray) Get() interface{} { func (src *CIDRArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.IP: - *v = make([]*net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -168,6 +160,49 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from CIDRArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = CIDRArray{Status: Null} diff --git a/cidr_array_test.go b/cidr_array_test.go index 421aec4e..b1769c38 100644 --- a/cidr_array_test.go +++ b/cidr_array_test.go @@ -80,6 +80,74 @@ func TestCIDRArraySet(t *testing.T) { source: (([]net.IP)(nil)), result: pgtype.CIDRArray{Status: pgtype.Null}, }, + { + source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -98,6 +166,10 @@ func TestCIDRArraySet(t *testing.T) { func TestCIDRArrayAssignTo(t *testing.T) { var ipnetSlice []*net.IPNet var ipSlice []net.IP + var ipSliceDim2 [][]net.IP + var ipnetSliceDim4 [][][][]*net.IPNet + var ipArrayDim2 [2][1]net.IP + var ipnetArrayDim4 [2][1][1][3]*net.IPNet simpleTests := []struct { src pgtype.CIDRArray @@ -150,6 +222,78 @@ func TestCIDRArrayAssignTo(t *testing.T) { dst: &ipSlice, expected: (([]net.IP)(nil)), }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipSliceDim2, + expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetSliceDim4, + expected: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipArrayDim2, + expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetArrayDim4, + expected: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, } for i, tt := range simpleTests { diff --git a/date_array.go b/date_array.go index fe185f67..4ccdafe0 100644 --- a/date_array.go +++ b/date_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "time" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,68 +31,94 @@ func (dst *DateArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = DateArray{Status: Null} + return nil + } - case []time.Time: - if value == nil { - *dst = DateArray{Status: Null} - } else if len(value) == 0 { - *dst = DateArray{Status: Present} - } else { - elements := make([]Date, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = DateArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*time.Time: - if value == nil { - *dst = DateArray{Status: Null} - } else if len(value) == 0 { - *dst = DateArray{Status: Present} - } else { - elements := make([]Date, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = DateArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Date: - if value == nil { - *dst = DateArray{Status: Null} - } else if len(value) == 0 { - *dst = DateArray{Status: Present} - } else { - *dst = DateArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for DateArray", src) + } + if elementsLength == 0 { + *dst = DateArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to DateArray", value) + return errors.Errorf("cannot convert %v to DateArray", src) + } + + *dst = DateArray{ + Elements: make([]Date, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Date, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to DateArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *DateArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to DateArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in DateArray", err) + } + index++ + + return index, nil +} + func (dst DateArray) Get() interface{} { switch dst.Status { case Present: @@ -107,32 +133,26 @@ func (dst DateArray) Get() interface{} { func (src *DateArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -140,6 +160,49 @@ func (src *DateArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from DateArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = DateArray{Status: Null} diff --git a/date_array_test.go b/date_array_test.go index 9f4a96a9..089c7dd4 100644 --- a/date_array_test.go +++ b/date_array_test.go @@ -69,6 +69,78 @@ func TestDateArraySet(t *testing.T) { source: (([]time.Time)(nil)), result: pgtype.DateArray{Status: pgtype.Null}, }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -86,6 +158,10 @@ func TestDateArraySet(t *testing.T) { func TestDateArrayAssignTo(t *testing.T) { var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time simpleTests := []struct { src pgtype.DateArray @@ -106,6 +182,82 @@ func TestDateArrayAssignTo(t *testing.T) { dst: &timeSlice, expected: (([]time.Time)(nil)), }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, } for i, tt := range simpleTests { @@ -131,6 +283,33 @@ func TestDateArrayAssignTo(t *testing.T) { }, dst: &timeSlice, }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeSlice, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + }, } for i, tt := range errorTests { diff --git a/enum_array.go b/enum_array.go index 9312264c..2c83db24 100644 --- a/enum_array.go +++ b/enum_array.go @@ -4,6 +4,7 @@ package pgtype import ( "database/sql/driver" + "reflect" errors "golang.org/x/xerrors" ) @@ -28,68 +29,94 @@ func (dst *EnumArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = EnumArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = EnumArray{Status: Null} - } else if len(value) == 0 { - *dst = EnumArray{Status: Present} - } else { - elements := make([]GenericText, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = EnumArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = EnumArray{Status: Null} - } else if len(value) == 0 { - *dst = EnumArray{Status: Present} - } else { - elements := make([]GenericText, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = EnumArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []GenericText: - if value == nil { - *dst = EnumArray{Status: Null} - } else if len(value) == 0 { - *dst = EnumArray{Status: Present} - } else { - *dst = EnumArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for EnumArray", src) + } + if elementsLength == 0 { + *dst = EnumArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to EnumArray", value) + return errors.Errorf("cannot convert %v to EnumArray", src) + } + + *dst = EnumArray{ + Elements: make([]GenericText, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]GenericText, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to EnumArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *EnumArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to EnumArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in EnumArray", err) + } + index++ + + return index, nil +} + func (dst EnumArray) Get() interface{} { switch dst.Status { case Present: @@ -104,32 +131,26 @@ func (dst EnumArray) Get() interface{} { func (src *EnumArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -137,6 +158,49 @@ func (src *EnumArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from EnumArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = EnumArray{Status: Null} diff --git a/enum_array_test.go b/enum_array_test.go index 406c6b47..91a81ab6 100644 --- a/enum_array_test.go +++ b/enum_array_test.go @@ -67,6 +67,54 @@ func TestEnumArrayArraySet(t *testing.T) { source: (([]string)(nil)), result: pgtype.EnumArray{Status: pgtype.Null}, }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -86,6 +134,10 @@ func TestEnumArrayArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.EnumArray @@ -115,6 +167,58 @@ func TestEnumArrayArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, } for i, tt := range simpleTests { @@ -140,6 +244,27 @@ func TestEnumArrayArrayAssignTo(t *testing.T) { }, dst: &stringSlice, }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, } for i, tt := range errorTests { diff --git a/float4_array.go b/float4_array.go index 0e95c446..78d1a860 100644 --- a/float4_array.go +++ b/float4_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *Float4Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Float4Array{Status: Null} + return nil + } - case []float32: - if value == nil { - *dst = Float4Array{Status: Null} - } else if len(value) == 0 { - *dst = Float4Array{Status: Present} - } else { - elements := make([]Float4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*float32: - if value == nil { - *dst = Float4Array{Status: Null} - } else if len(value) == 0 { - *dst = Float4Array{Status: Present} - } else { - elements := make([]Float4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Float4: - if value == nil { - *dst = Float4Array{Status: Null} - } else if len(value) == 0 { - *dst = Float4Array{Status: Present} - } else { - *dst = Float4Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Float4Array", src) + } + if elementsLength == 0 { + *dst = Float4Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Float4Array", value) + return errors.Errorf("cannot convert %v to Float4Array", src) + } + + *dst = Float4Array{ + Elements: make([]Float4, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Float4, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Float4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Float4Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Float4Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Float4Array", err) + } + index++ + + return index, nil +} + func (dst Float4Array) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst Float4Array) Get() interface{} { func (src *Float4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]float32: - *v = make([]float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float32: - *v = make([]*float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Float4Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4Array{Status: Null} diff --git a/float4_array_test.go b/float4_array_test.go index 658b3381..23a94ee8 100644 --- a/float4_array_test.go +++ b/float4_array_test.go @@ -68,6 +68,54 @@ func TestFloat4ArraySet(t *testing.T) { source: (([]float32)(nil)), result: pgtype.Float4Array{Status: pgtype.Null}, }, + { + source: [][]float32{{1}, {2}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]float32{{1}, {2}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -86,6 +134,10 @@ func TestFloat4ArraySet(t *testing.T) { func TestFloat4ArrayAssignTo(t *testing.T) { var float32Slice []float32 var namedFloat32Slice _float32Slice + var float32SliceDim2 [][]float32 + var float32SliceDim4 [][][][]float32 + var float32ArrayDim2 [2][1]float32 + var float32ArrayDim4 [2][1][1][3]float32 simpleTests := []struct { src pgtype.Float4Array @@ -115,6 +167,58 @@ func TestFloat4ArrayAssignTo(t *testing.T) { dst: &float32Slice, expected: (([]float32)(nil)), }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]float32{{1}, {2}}, + dst: &float32SliceDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float32SliceDim4, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]float32{{1}, {2}}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float32ArrayDim4, + }, } for i, tt := range simpleTests { @@ -140,6 +244,27 @@ func TestFloat4ArrayAssignTo(t *testing.T) { }, dst: &float32Slice, }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32Slice, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/float8_array.go b/float8_array.go index 240e88d6..19223c52 100644 --- a/float8_array.go +++ b/float8_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *Float8Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Float8Array{Status: Null} + return nil + } - case []float64: - if value == nil { - *dst = Float8Array{Status: Null} - } else if len(value) == 0 { - *dst = Float8Array{Status: Present} - } else { - elements := make([]Float8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*float64: - if value == nil { - *dst = Float8Array{Status: Null} - } else if len(value) == 0 { - *dst = Float8Array{Status: Present} - } else { - elements := make([]Float8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Float8: - if value == nil { - *dst = Float8Array{Status: Null} - } else if len(value) == 0 { - *dst = Float8Array{Status: Present} - } else { - *dst = Float8Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Float8Array", src) + } + if elementsLength == 0 { + *dst = Float8Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Float8Array", value) + return errors.Errorf("cannot convert %v to Float8Array", src) + } + + *dst = Float8Array{ + Elements: make([]Float8, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Float8, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Float8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Float8Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Float8Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Float8Array", err) + } + index++ + + return index, nil +} + func (dst Float8Array) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst Float8Array) Get() interface{} { func (src *Float8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]float64: - *v = make([]float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float64: - *v = make([]*float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Float8Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8Array{Status: Null} diff --git a/float8_array_test.go b/float8_array_test.go index 2e29a19f..052ab3f3 100644 --- a/float8_array_test.go +++ b/float8_array_test.go @@ -68,6 +68,30 @@ func TestFloat8ArraySet(t *testing.T) { source: (([]float64)(nil)), result: pgtype.Float8Array{Status: pgtype.Null}, }, + { + source: [][]float64{{1}, {2}}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -86,6 +110,10 @@ func TestFloat8ArraySet(t *testing.T) { func TestFloat8ArrayAssignTo(t *testing.T) { var float64Slice []float64 var namedFloat64Slice _float64Slice + var float64SliceDim2 [][]float64 + var float64SliceDim4 [][][][]float64 + var float64ArrayDim2 [2][1]float64 + var float64ArrayDim4 [2][1][1][3]float64 simpleTests := []struct { src pgtype.Float8Array @@ -115,6 +143,58 @@ func TestFloat8ArrayAssignTo(t *testing.T) { dst: &float64Slice, expected: (([]float64)(nil)), }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]float64{{1}, {2}}, + dst: &float64SliceDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float64SliceDim4, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]float64{{1}, {2}}, + dst: &float64ArrayDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float64ArrayDim4, + }, } for i, tt := range simpleTests { @@ -140,6 +220,27 @@ func TestFloat8ArrayAssignTo(t *testing.T) { }, dst: &float64Slice, }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float64ArrayDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float64Slice, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float64ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/hstore_array.go b/hstore_array.go index b258cbdd..8764aae7 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,49 +31,94 @@ func (dst *HstoreArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = HstoreArray{Status: Null} + return nil + } - case []map[string]string: - if value == nil { - *dst = HstoreArray{Status: Null} - } else if len(value) == 0 { - *dst = HstoreArray{Status: Present} - } else { - elements := make([]Hstore, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = HstoreArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Hstore: - if value == nil { - *dst = HstoreArray{Status: Null} - } else if len(value) == 0 { - *dst = HstoreArray{Status: Present} - } else { - *dst = HstoreArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for HstoreArray", src) + } + if elementsLength == 0 { + *dst = HstoreArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to HstoreArray", value) + return errors.Errorf("cannot convert %v to HstoreArray", src) + } + + *dst = HstoreArray{ + Elements: make([]Hstore, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Hstore, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to HstoreArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *HstoreArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to HstoreArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in HstoreArray", err) + } + index++ + + return index, nil +} + func (dst HstoreArray) Get() interface{} { switch dst.Status { case Present: @@ -87,23 +133,26 @@ func (dst HstoreArray) Get() interface{} { func (src *HstoreArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]map[string]string: - *v = make([]map[string]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -111,6 +160,49 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from HstoreArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = HstoreArray{Status: Null} diff --git a/hstore_array_test.go b/hstore_array_test.go index 32b91840..fac66b4a 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -131,7 +131,7 @@ func TestHstoreArrayTranscode(t *testing.T) { func TestHstoreArraySet(t *testing.T) { successfulTests := []struct { - src []map[string]string + src interface{} result pgtype.HstoreArray }{ { @@ -147,6 +147,118 @@ func TestHstoreArraySet(t *testing.T) { Status: pgtype.Present, }, }, + { + src: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + }, + { + src: [][][][]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + }, + { + src: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + }, + { + src: [2][1][1][3]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + }, } for i, tt := range successfulTests { @@ -163,12 +275,16 @@ func TestHstoreArraySet(t *testing.T) { } func TestHstoreArrayAssignTo(t *testing.T) { - var m []map[string]string + var hstoreSlice []map[string]string + var hstoreSliceDim2 [][]map[string]string + var hstoreSliceDim4 [][][][]map[string]string + var hstoreArrayDim2 [2][1]map[string]string + var hstoreArrayDim4 [2][1][1][3]map[string]string simpleTests := []struct { src pgtype.HstoreArray - dst *[]map[string]string - expected []map[string]string + dst interface{} + expected interface{} }{ { src: pgtype.HstoreArray{ @@ -181,9 +297,127 @@ func TestHstoreArrayAssignTo(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, - dst: &m, + dst: &hstoreSlice, expected: []map[string]string{{"foo": "bar"}}}, - {src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &m, expected: (([]map[string]string)(nil))}, + { + src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &hstoreSliceDim2, + expected: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + dst: &hstoreSliceDim4, + expected: [][][][]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &hstoreArrayDim2, + expected: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + dst: &hstoreArrayDim4, + expected: [2][1][1][3]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + }, } for i, tt := range simpleTests { @@ -192,8 +426,8 @@ func TestHstoreArrayAssignTo(t *testing.T) { t.Errorf("%d: %v", i, err) } - if !reflect.DeepEqual(*tt.dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } } } diff --git a/inet_array.go b/inet_array.go index ca4c1a02..91f5d6e8 100644 --- a/inet_array.go +++ b/inet_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "net" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,87 +31,94 @@ func (dst *InetArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = InetArray{Status: Null} + return nil + } - case []*net.IPNet: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []net.IP: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*net.IP: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Inet: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - *dst = InetArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for InetArray", src) + } + if elementsLength == 0 { + *dst = InetArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to InetArray", value) + return errors.Errorf("cannot convert %v to InetArray", src) + } + + *dst = InetArray{ + Elements: make([]Inet, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Inet, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to InetArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *InetArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to InetArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in InetArray", err) + } + index++ + + return index, nil +} + func (dst InetArray) Get() interface{} { switch dst.Status { case Present: @@ -126,41 +133,26 @@ func (dst InetArray) Get() interface{} { func (src *InetArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.IP: - *v = make([]*net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -168,6 +160,49 @@ func (src *InetArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from InetArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = InetArray{Status: Null} diff --git a/inet_array_test.go b/inet_array_test.go index 6737aac0..d78b91c0 100644 --- a/inet_array_test.go +++ b/inet_array_test.go @@ -80,6 +80,74 @@ func TestInetArraySet(t *testing.T) { source: (([]net.IP)(nil)), result: pgtype.InetArray{Status: pgtype.Null}, }, + { + source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -98,6 +166,10 @@ func TestInetArraySet(t *testing.T) { func TestInetArrayAssignTo(t *testing.T) { var ipnetSlice []*net.IPNet var ipSlice []net.IP + var ipSliceDim2 [][]net.IP + var ipnetSliceDim4 [][][][]*net.IPNet + var ipArrayDim2 [2][1]net.IP + var ipnetArrayDim4 [2][1][1][3]*net.IPNet simpleTests := []struct { src pgtype.InetArray @@ -150,6 +222,78 @@ func TestInetArrayAssignTo(t *testing.T) { dst: &ipSlice, expected: (([]net.IP)(nil)), }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipSliceDim2, + expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetSliceDim4, + expected: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipArrayDim2, + expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetArrayDim4, + expected: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, } for i, tt := range simpleTests { diff --git a/int2_array.go b/int2_array.go index ad2bd094..06febf01 100644 --- a/int2_array.go +++ b/int2_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,334 +31,94 @@ func (dst *Int2Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Int2Array{Status: Null} + return nil + } - case []int16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int32: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int32: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint32: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint32: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int64: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int64: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint64: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint64: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Int2: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - *dst = Int2Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Int2Array", src) + } + if elementsLength == 0 { + *dst = Int2Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int2Array", value) + return errors.Errorf("cannot convert %v to Int2Array", src) + } + + *dst = Int2Array{ + Elements: make([]Int2, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int2, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Int2Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Int2Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Int2Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Int2Array", err) + } + index++ + + return index, nil +} + func (dst Int2Array) Get() interface{} { switch dst.Status { case Present: @@ -372,158 +133,26 @@ func (dst Int2Array) Get() interface{} { func (src *Int2Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -531,6 +160,49 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Int2Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2Array{Status: Null} diff --git a/int2_array_test.go b/int2_array_test.go index 22f71745..dfe84c19 100644 --- a/int2_array_test.go +++ b/int2_array_test.go @@ -110,6 +110,54 @@ func TestInt2ArraySet(t *testing.T) { source: (([]int16)(nil)), result: pgtype.Int2Array{Status: pgtype.Null}, }, + { + source: [][]int16{{1}, {2}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]int16{{1}, {2}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -129,6 +177,10 @@ func TestInt2ArrayAssignTo(t *testing.T) { var int16Slice []int16 var uint16Slice []uint16 var namedInt16Slice _int16Slice + var int16SliceDim2 [][]int16 + var int16SliceDim4 [][][][]int16 + var int16ArrayDim2 [2][1]int16 + var int16ArrayDim4 [2][1][1][3]int16 simpleTests := []struct { src pgtype.Int2Array @@ -167,6 +219,58 @@ func TestInt2ArrayAssignTo(t *testing.T) { dst: &int16Slice, expected: (([]int16)(nil)), }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]int16{{1}, {2}}, + dst: &int16SliceDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int16SliceDim4, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]int16{{1}, {2}}, + dst: &int16ArrayDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int16ArrayDim4, + }, } for i, tt := range simpleTests { @@ -200,6 +304,27 @@ func TestInt2ArrayAssignTo(t *testing.T) { }, dst: &uint16Slice, }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int16ArrayDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int16Slice, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &int16ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/int4_array.go b/int4_array.go index 15565f64..189bd238 100644 --- a/int4_array.go +++ b/int4_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,334 +31,94 @@ func (dst *Int4Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Int4Array{Status: Null} + return nil + } - case []int16: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int16: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint16: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint16: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int64: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int64: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint64: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint64: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Int4: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - *dst = Int4Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Int4Array", src) + } + if elementsLength == 0 { + *dst = Int4Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int4Array", value) + return errors.Errorf("cannot convert %v to Int4Array", src) + } + + *dst = Int4Array{ + Elements: make([]Int4, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int4, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Int4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Int4Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Int4Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Int4Array", err) + } + index++ + + return index, nil +} + func (dst Int4Array) Get() interface{} { switch dst.Status { case Present: @@ -372,158 +133,26 @@ func (dst Int4Array) Get() interface{} { func (src *Int4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -531,6 +160,49 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Int4Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4Array{Status: Null} diff --git a/int4_array_test.go b/int4_array_test.go index c839c1c9..35b791d3 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -116,6 +116,54 @@ func TestInt4ArraySet(t *testing.T) { source: (([]int32)(nil)), result: pgtype.Int4Array{Status: pgtype.Null}, }, + { + source: [][]int32{{1}, {2}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]int32{{1}, {2}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -143,6 +191,10 @@ func TestInt4ArrayAssignTo(t *testing.T) { var int32Slice []int32 var uint32Slice []uint32 var namedInt32Slice _int32Slice + var int32SliceDim2 [][]int32 + var int32SliceDim4 [][][][]int32 + var int32ArrayDim2 [2][1]int32 + var int32ArrayDim4 [2][1][1][3]int32 simpleTests := []struct { src pgtype.Int4Array @@ -181,6 +233,58 @@ func TestInt4ArrayAssignTo(t *testing.T) { dst: &int32Slice, expected: (([]int32)(nil)), }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]int32{{1}, {2}}, + dst: &int32SliceDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int32SliceDim4, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]int32{{1}, {2}}, + dst: &int32ArrayDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int32ArrayDim4, + }, } for i, tt := range simpleTests { @@ -214,6 +318,27 @@ func TestInt4ArrayAssignTo(t *testing.T) { }, dst: &uint32Slice, }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int32ArrayDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int32Slice, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &int32ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/int8_array.go b/int8_array.go index e8e8823a..edb232cb 100644 --- a/int8_array.go +++ b/int8_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,334 +31,94 @@ func (dst *Int8Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Int8Array{Status: Null} + return nil + } - case []int16: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int16: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint16: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint16: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int32: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int32: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint32: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint32: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Int8: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - *dst = Int8Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Int8Array", src) + } + if elementsLength == 0 { + *dst = Int8Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int8Array", value) + return errors.Errorf("cannot convert %v to Int8Array", src) + } + + *dst = Int8Array{ + Elements: make([]Int8, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int8, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Int8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Int8Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Int8Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Int8Array", err) + } + index++ + + return index, nil +} + func (dst Int8Array) Get() interface{} { switch dst.Status { case Present: @@ -372,158 +133,26 @@ func (dst Int8Array) Get() interface{} { func (src *Int8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -531,6 +160,49 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Int8Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8Array{Status: Null} diff --git a/int8_array_test.go b/int8_array_test.go index e9e7acfb..d65b875a 100644 --- a/int8_array_test.go +++ b/int8_array_test.go @@ -117,6 +117,54 @@ func TestInt8ArraySet(t *testing.T) { source: (([]int64)(nil)), result: pgtype.Int8Array{Status: pgtype.Null}, }, + { + source: [][]int64{{1}, {2}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]int64{{1}, {2}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -136,6 +184,10 @@ func TestInt8ArrayAssignTo(t *testing.T) { var int64Slice []int64 var uint64Slice []uint64 var namedInt64Slice _int64Slice + var int64SliceDim2 [][]int64 + var int64SliceDim4 [][][][]int64 + var int64ArrayDim2 [2][1]int64 + var int64ArrayDim4 [2][1][1][3]int64 simpleTests := []struct { src pgtype.Int8Array @@ -174,6 +226,58 @@ func TestInt8ArrayAssignTo(t *testing.T) { dst: &int64Slice, expected: (([]int64)(nil)), }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]int64{{1}, {2}}, + dst: &int64SliceDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int64SliceDim4, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]int64{{1}, {2}}, + dst: &int64ArrayDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int64ArrayDim4, + }, } for i, tt := range simpleTests { @@ -207,6 +311,27 @@ func TestInt8ArrayAssignTo(t *testing.T) { }, dst: &uint64Slice, }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int64ArrayDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int64Slice, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &int64ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/jsonb_array.go b/jsonb_array.go index daebfa7b..c5a40a1d 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,49 +31,94 @@ func (dst *JSONBArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = JSONBArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = JSONBArray{Status: Null} - } else if len(value) == 0 { - *dst = JSONBArray{Status: Present} - } else { - elements := make([]Text, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = JSONBArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Text: - if value == nil { - *dst = JSONBArray{Status: Null} - } else if len(value) == 0 { - *dst = JSONBArray{Status: Present} - } else { - *dst = JSONBArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for JSONBArray", src) + } + if elementsLength == 0 { + *dst = JSONBArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to JSONBArray", value) + return errors.Errorf("cannot convert %v to JSONBArray", src) + } + + *dst = JSONBArray{ + Elements: make([]Text, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Text, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to JSONBArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *JSONBArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to JSONBArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in JSONBArray", err) + } + index++ + + return index, nil +} + func (dst JSONBArray) Get() interface{} { switch dst.Status { case Present: @@ -87,23 +133,26 @@ func (dst JSONBArray) Get() interface{} { func (src *JSONBArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -111,6 +160,49 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from JSONBArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *JSONBArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = JSONBArray{Status: Null} diff --git a/macaddr_array.go b/macaddr_array.go index 616d6f85..398db1fe 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "net" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,68 +31,94 @@ func (dst *MacaddrArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = MacaddrArray{Status: Null} + return nil + } - case []net.HardwareAddr: - if value == nil { - *dst = MacaddrArray{Status: Null} - } else if len(value) == 0 { - *dst = MacaddrArray{Status: Present} - } else { - elements := make([]Macaddr, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = MacaddrArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*net.HardwareAddr: - if value == nil { - *dst = MacaddrArray{Status: Null} - } else if len(value) == 0 { - *dst = MacaddrArray{Status: Present} - } else { - elements := make([]Macaddr, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = MacaddrArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Macaddr: - if value == nil { - *dst = MacaddrArray{Status: Null} - } else if len(value) == 0 { - *dst = MacaddrArray{Status: Present} - } else { - *dst = MacaddrArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for MacaddrArray", src) + } + if elementsLength == 0 { + *dst = MacaddrArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to MacaddrArray", value) + return errors.Errorf("cannot convert %v to MacaddrArray", src) + } + + *dst = MacaddrArray{ + Elements: make([]Macaddr, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Macaddr, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to MacaddrArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *MacaddrArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to MacaddrArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in MacaddrArray", err) + } + index++ + + return index, nil +} + func (dst MacaddrArray) Get() interface{} { switch dst.Status { case Present: @@ -107,32 +133,26 @@ func (dst MacaddrArray) Get() interface{} { func (src *MacaddrArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]net.HardwareAddr: - *v = make([]net.HardwareAddr, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.HardwareAddr: - *v = make([]*net.HardwareAddr, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -140,6 +160,49 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from MacaddrArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *MacaddrArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = MacaddrArray{Status: Null} diff --git a/macaddr_array_test.go b/macaddr_array_test.go index d2b0a73b..647db8cf 100644 --- a/macaddr_array_test.go +++ b/macaddr_array_test.go @@ -44,6 +44,78 @@ func TestMacaddrArraySet(t *testing.T) { source: (([]net.HardwareAddr)(nil)), result: pgtype.MacaddrArray{Status: pgtype.Null}, }, + { + source: [][]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -61,6 +133,10 @@ func TestMacaddrArraySet(t *testing.T) { func TestMacaddrArrayAssignTo(t *testing.T) { var macaddrSlice []net.HardwareAddr + var macaddrSliceDim2 [][]net.HardwareAddr + var macaddrSliceDim4 [][][][]net.HardwareAddr + var macaddrArrayDim2 [2][1]net.HardwareAddr + var macaddrArrayDim4 [2][1][1][3]net.HardwareAddr simpleTests := []struct { src pgtype.MacaddrArray @@ -90,6 +166,82 @@ func TestMacaddrArrayAssignTo(t *testing.T) { dst: &macaddrSlice, expected: (([]net.HardwareAddr)(nil)), }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &macaddrSliceDim2, + expected: [][]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &macaddrSliceDim4, + expected: [][][][]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &macaddrArrayDim2, + expected: [2][1]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &macaddrArrayDim4, + expected: [2][1][1][3]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + }, } for i, tt := range simpleTests { diff --git a/numeric_array.go b/numeric_array.go index e086ca7a..dec81535 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,182 +31,94 @@ func (dst *NumericArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = NumericArray{Status: Null} + return nil + } - case []float32: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*float32: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []float64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*float64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Numeric: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - *dst = NumericArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for NumericArray", src) + } + if elementsLength == 0 { + *dst = NumericArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to NumericArray", value) + return errors.Errorf("cannot convert %v to NumericArray", src) + } + + *dst = NumericArray{ + Elements: make([]Numeric, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Numeric, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to NumericArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *NumericArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to NumericArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in NumericArray", err) + } + index++ + + return index, nil +} + func (dst NumericArray) Get() interface{} { switch dst.Status { case Present: @@ -220,86 +133,26 @@ func (dst NumericArray) Get() interface{} { func (src *NumericArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]float32: - *v = make([]float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float32: - *v = make([]*float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]float64: - *v = make([]float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float64: - *v = make([]*float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -307,6 +160,49 @@ func (src *NumericArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from NumericArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = NumericArray{Status: Null} diff --git a/numeric_array_test.go b/numeric_array_test.go index eafd31be..29300bf0 100644 --- a/numeric_array_test.go +++ b/numeric_array_test.go @@ -91,6 +91,54 @@ func TestNumericArraySet(t *testing.T) { source: (([]float32)(nil)), result: pgtype.NumericArray{Status: pgtype.Null}, }, + { + source: [][]float32{{1}, {2}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]float32{{1}, {2}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -109,6 +157,10 @@ func TestNumericArraySet(t *testing.T) { func TestNumericArrayAssignTo(t *testing.T) { var float32Slice []float32 var float64Slice []float64 + var float32SliceDim2 [][]float32 + var float32SliceDim4 [][][][]float32 + var float32ArrayDim2 [2][1]float32 + var float32ArrayDim4 [2][1][1][3]float32 simpleTests := []struct { src pgtype.NumericArray @@ -138,6 +190,58 @@ func TestNumericArrayAssignTo(t *testing.T) { dst: &float32Slice, expected: (([]float32)(nil)), }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32SliceDim2, + expected: [][]float32{{1}, {2}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &float32SliceDim4, + expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32ArrayDim2, + expected: [2][1]float32{{1}, {2}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &float32ArrayDim4, + expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + }, } for i, tt := range simpleTests { @@ -163,6 +267,27 @@ func TestNumericArrayAssignTo(t *testing.T) { }, dst: &float32Slice, }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32Slice, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/text_array.go b/text_array.go index d1583557..31ed04ac 100644 --- a/text_array.go +++ b/text_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *TextArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = TextArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = TextArray{Status: Null} - } else if len(value) == 0 { - *dst = TextArray{Status: Present} - } else { - elements := make([]Text, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TextArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = TextArray{Status: Null} - } else if len(value) == 0 { - *dst = TextArray{Status: Present} - } else { - elements := make([]Text, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TextArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Text: - if value == nil { - *dst = TextArray{Status: Null} - } else if len(value) == 0 { - *dst = TextArray{Status: Present} - } else { - *dst = TextArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TextArray", src) + } + if elementsLength == 0 { + *dst = TextArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TextArray", value) + return errors.Errorf("cannot convert %v to TextArray", src) + } + + *dst = TextArray{ + Elements: make([]Text, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Text, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TextArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *TextArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to TextArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in TextArray", err) + } + index++ + + return index, nil +} + func (dst TextArray) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst TextArray) Get() interface{} { func (src *TextArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *TextArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from TextArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TextArray{Status: Null} diff --git a/text_array_test.go b/text_array_test.go index a29ce617..125d6034 100644 --- a/text_array_test.go +++ b/text_array_test.go @@ -68,6 +68,54 @@ func TestTextArraySet(t *testing.T) { source: (([]string)(nil)), result: pgtype.TextArray{Status: pgtype.Null}, }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -87,6 +135,10 @@ func TestTextArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.TextArray @@ -116,6 +168,58 @@ func TestTextArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, } for i, tt := range simpleTests { @@ -141,6 +245,27 @@ func TestTextArrayAssignTo(t *testing.T) { }, dst: &stringSlice, }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, } for i, tt := range errorTests { diff --git a/timestamp_array.go b/timestamp_array.go index 3b2c3141..355b29c5 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "time" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,68 +31,94 @@ func (dst *TimestampArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = TimestampArray{Status: Null} + return nil + } - case []time.Time: - if value == nil { - *dst = TimestampArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestampArray{Status: Present} - } else { - elements := make([]Timestamp, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestampArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*time.Time: - if value == nil { - *dst = TimestampArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestampArray{Status: Present} - } else { - elements := make([]Timestamp, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestampArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Timestamp: - if value == nil { - *dst = TimestampArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestampArray{Status: Present} - } else { - *dst = TimestampArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TimestampArray", src) + } + if elementsLength == 0 { + *dst = TimestampArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TimestampArray", value) + return errors.Errorf("cannot convert %v to TimestampArray", src) + } + + *dst = TimestampArray{ + Elements: make([]Timestamp, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Timestamp, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TimestampArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *TimestampArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to TimestampArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in TimestampArray", err) + } + index++ + + return index, nil +} + func (dst TimestampArray) Get() interface{} { switch dst.Status { case Present: @@ -107,32 +133,26 @@ func (dst TimestampArray) Get() interface{} { func (src *TimestampArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -140,6 +160,49 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from TimestampArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestampArray{Status: Null} diff --git a/timestamp_array_test.go b/timestamp_array_test.go index d7632fa3..c6f32d20 100644 --- a/timestamp_array_test.go +++ b/timestamp_array_test.go @@ -85,6 +85,42 @@ func TestTimestampArraySet(t *testing.T) { source: (([]time.Time)(nil)), result: pgtype.TimestampArray{Status: pgtype.Null}, }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -102,6 +138,10 @@ func TestTimestampArraySet(t *testing.T) { func TestTimestampArrayAssignTo(t *testing.T) { var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time simpleTests := []struct { src pgtype.TimestampArray @@ -122,6 +162,82 @@ func TestTimestampArrayAssignTo(t *testing.T) { dst: &timeSlice, expected: (([]time.Time)(nil)), }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, } for i, tt := range simpleTests { @@ -147,6 +263,33 @@ func TestTimestampArrayAssignTo(t *testing.T) { }, dst: &timeSlice, }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeSlice, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + }, } for i, tt := range errorTests { diff --git a/timestamptz_array.go b/timestamptz_array.go index 3328ec05..94a791b6 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "time" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,68 +31,94 @@ func (dst *TimestamptzArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = TimestamptzArray{Status: Null} + return nil + } - case []time.Time: - if value == nil { - *dst = TimestamptzArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestamptzArray{Status: Present} - } else { - elements := make([]Timestamptz, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestamptzArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*time.Time: - if value == nil { - *dst = TimestamptzArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestamptzArray{Status: Present} - } else { - elements := make([]Timestamptz, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestamptzArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Timestamptz: - if value == nil { - *dst = TimestamptzArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestamptzArray{Status: Present} - } else { - *dst = TimestamptzArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TimestamptzArray", src) + } + if elementsLength == 0 { + *dst = TimestamptzArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TimestamptzArray", value) + return errors.Errorf("cannot convert %v to TimestamptzArray", src) + } + + *dst = TimestamptzArray{ + Elements: make([]Timestamptz, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Timestamptz, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TimestamptzArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *TimestamptzArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to TimestamptzArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in TimestamptzArray", err) + } + index++ + + return index, nil +} + func (dst TimestamptzArray) Get() interface{} { switch dst.Status { case Present: @@ -107,32 +133,26 @@ func (dst TimestamptzArray) Get() interface{} { func (src *TimestamptzArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -140,6 +160,49 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from TimestamptzArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestamptzArray{Status: Null} diff --git a/timestamptz_array_test.go b/timestamptz_array_test.go index 8a4cfd1d..f4e80413 100644 --- a/timestamptz_array_test.go +++ b/timestamptz_array_test.go @@ -85,6 +85,78 @@ func TestTimestamptzArraySet(t *testing.T) { source: (([]time.Time)(nil)), result: pgtype.TimestamptzArray{Status: pgtype.Null}, }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -102,6 +174,10 @@ func TestTimestamptzArraySet(t *testing.T) { func TestTimestamptzArrayAssignTo(t *testing.T) { var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time simpleTests := []struct { src pgtype.TimestamptzArray @@ -122,6 +198,82 @@ func TestTimestamptzArrayAssignTo(t *testing.T) { dst: &timeSlice, expected: (([]time.Time)(nil)), }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, } for i, tt := range simpleTests { @@ -147,6 +299,33 @@ func TestTimestamptzArrayAssignTo(t *testing.T) { }, dst: &timeSlice, }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeSlice, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + }, } for i, tt := range errorTests { diff --git a/tstzrange_array.go b/tstzrange_array.go index c19a9bfa..f5043c65 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,30 +31,94 @@ func (dst *TstzrangeArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = TstzrangeArray{Status: Null} + return nil + } - case []Tstzrange: - if value == nil { - *dst = TstzrangeArray{Status: Null} - } else if len(value) == 0 { - *dst = TstzrangeArray{Status: Present} - } else { - *dst = TstzrangeArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TstzrangeArray", src) + } + if elementsLength == 0 { + *dst = TstzrangeArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TstzrangeArray", value) + return errors.Errorf("cannot convert %v to TstzrangeArray", src) + } + + *dst = TstzrangeArray{ + Elements: make([]Tstzrange, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Tstzrange, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TstzrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *TstzrangeArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to TstzrangeArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in TstzrangeArray", err) + } + index++ + + return index, nil +} + func (dst TstzrangeArray) Get() interface{} { switch dst.Status { case Present: @@ -68,23 +133,26 @@ func (dst TstzrangeArray) Get() interface{} { func (src *TstzrangeArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]Tstzrange: - *v = make([]Tstzrange, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -92,6 +160,49 @@ func (src *TstzrangeArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from TstzrangeArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *TstzrangeArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TstzrangeArray{Status: Null} diff --git a/typed_array.go.erb b/typed_array.go.erb index a3deea5b..fb964ec8 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -30,51 +30,94 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { } } - switch value := src.(type) { - <% go_array_types.split(",").each do |t| %> - <% if t != "[]#{pgtype_element_type}" %> - case <%= t %>: - if value == nil { - *dst = <%= pgtype_array_type %>{Status: Null} - } else if len(value) == 0 { - *dst = <%= pgtype_array_type %>{Status: Present} - } else { - elements := make([]<%= pgtype_element_type %>, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = <%= pgtype_array_type %>{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - <% end %> - <% end %> - case []<%= pgtype_element_type %>: - if value == nil { - *dst = <%= pgtype_array_type %>{Status: Null} - } else if len(value) == 0 { - *dst = <%= pgtype_array_type %>{Status: Present} - } else { - *dst = <%= pgtype_array_type %>{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status : Present, - } - } - default: + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for <%= pgtype_array_type %>", src) + } + if elementsLength == 0 { + *dst = <%= pgtype_array_type %>{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>", value) + return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>", src) + } + + *dst = <%= pgtype_array_type %> { + Elements: make([]<%= pgtype_element_type %>, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]<%= pgtype_element_type %>, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *<%= pgtype_array_type %>) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to <%= pgtype_array_type %>") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in <%= pgtype_array_type %>", err) + } + index++ + + return index, nil +} + func (dst <%= pgtype_array_type %>) Get() interface{} { switch dst.Status { case Present: @@ -89,23 +132,26 @@ func (dst <%= pgtype_array_type %>) Get() interface{} { func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - <% go_array_types.split(",").each do |t| %> - case *<%= t %>: - *v = make(<%= t %>, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - <% end %> - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -113,6 +159,49 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from <%= pgtype_array_type %>") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 607d3bc3..8c594944 100755 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,27 +1,27 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time,[]*time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time,[]*time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32,[]*float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64,[]*float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go -erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr,[]*net.HardwareAddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go -erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string,[]*string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string,[]*string element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go -erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string,[]*string element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go -erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string,[]*string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go -erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go -erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go -erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go -erb pgtype_array_type=JSONBArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go +erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go +erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go +erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go +erb pgtype_array_type=NumericArray pgtype_element_type=Numeric element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go +erb pgtype_array_type=UUIDArray pgtype_element_type=UUID element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go +erb pgtype_array_type=JSONBArray pgtype_element_type=Text element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go # While the binary format is theoretically possible it is only practical to use the text format. -erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go +erb pgtype_array_type=EnumArray pgtype_element_type=GenericText text_null=NULL binary_format=false typed_array.go.erb > enum_array.go goimports -w *_array.go diff --git a/uuid_array.go b/uuid_array.go index 06d2d576..e2c86cf8 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,106 +31,94 @@ func (dst *UUIDArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = UUIDArray{Status: Null} + return nil + } - case [][16]byte: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case [][]byte: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []string: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []UUID: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - *dst = UUIDArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for UUIDArray", src) + } + if elementsLength == 0 { + *dst = UUIDArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to UUIDArray", value) + return errors.Errorf("cannot convert %v to UUIDArray", src) + } + + *dst = UUIDArray{ + Elements: make([]UUID, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]UUID, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to UUIDArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *UUIDArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to UUIDArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in UUIDArray", err) + } + index++ + + return index, nil +} + func (dst UUIDArray) Get() interface{} { switch dst.Status { case Present: @@ -144,50 +133,26 @@ func (dst UUIDArray) Get() interface{} { func (src *UUIDArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[][16]byte: - *v = make([][16]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -195,6 +160,49 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from UUIDArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *UUIDArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = UUIDArray{Status: Null} diff --git a/uuid_array_test.go b/uuid_array_test.go index d5446920..cdb212bb 100644 --- a/uuid_array_test.go +++ b/uuid_array_test.go @@ -123,6 +123,78 @@ func TestUUIDArraySet(t *testing.T) { source: ([]string)(nil), result: pgtype.UUIDArray{Status: pgtype.Null}, }, + { + source: [][][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -142,6 +214,10 @@ func TestUUIDArrayAssignTo(t *testing.T) { var byteArraySlice [][16]byte var byteSliceSlice [][]byte var stringSlice []string + var byteArraySliceDim2 [][][16]byte + var stringSliceDim4 [][][][]string + var byteArrayDim2 [2][1][16]byte + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.UUIDArray @@ -190,6 +266,82 @@ func TestUUIDArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: ([]string)(nil), }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteArraySliceDim2, + expected: [][][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteArrayDim2, + expected: [2][1][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + }, } for i, tt := range simpleTests { diff --git a/varchar_array.go b/varchar_array.go index 32ca5941..ec378ed7 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *VarcharArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = VarcharArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = VarcharArray{Status: Null} - } else if len(value) == 0 { - *dst = VarcharArray{Status: Present} - } else { - elements := make([]Varchar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = VarcharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = VarcharArray{Status: Null} - } else if len(value) == 0 { - *dst = VarcharArray{Status: Present} - } else { - elements := make([]Varchar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = VarcharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Varchar: - if value == nil { - *dst = VarcharArray{Status: Null} - } else if len(value) == 0 { - *dst = VarcharArray{Status: Present} - } else { - *dst = VarcharArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for VarcharArray", src) + } + if elementsLength == 0 { + *dst = VarcharArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to VarcharArray", value) + return errors.Errorf("cannot convert %v to VarcharArray", src) + } + + *dst = VarcharArray{ + Elements: make([]Varchar, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Varchar, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to VarcharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *VarcharArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to VarcharArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in VarcharArray", err) + } + index++ + + return index, nil +} + func (dst VarcharArray) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst VarcharArray) Get() interface{} { func (src *VarcharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from VarcharArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = VarcharArray{Status: Null} diff --git a/varchar_array_test.go b/varchar_array_test.go index 9ad80862..3b0e65ed 100644 --- a/varchar_array_test.go +++ b/varchar_array_test.go @@ -68,6 +68,54 @@ func TestVarcharArraySet(t *testing.T) { source: (([]string)(nil)), result: pgtype.VarcharArray{Status: pgtype.Null}, }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -87,6 +135,10 @@ func TestVarcharArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.VarcharArray @@ -116,6 +168,58 @@ func TestVarcharArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, } for i, tt := range simpleTests { @@ -141,6 +245,27 @@ func TestVarcharArrayAssignTo(t *testing.T) { }, dst: &stringSlice, }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, } for i, tt := range errorTests { From b90570feb55e00d4977de788e840bd3b10d5414f Mon Sep 17 00:00:00 2001 From: Simo Haasanen Date: Sat, 8 Aug 2020 19:51:37 +0100 Subject: [PATCH 0563/1158] Restored more optimised array type conversions for a few select 1D-slice types. Results of calls to the reflect lib are now stored as local variables for small performance gains. --- aclitem_array.go | 188 ++++++++++---- bool_array.go | 188 ++++++++++---- bpchar_array.go | 188 ++++++++++---- bytea_array.go | 160 ++++++++---- cidr_array.go | 217 ++++++++++++---- date_array.go | 189 ++++++++++---- enum_array.go | 188 ++++++++++---- float4_array.go | 188 ++++++++++---- float8_array.go | 188 ++++++++++---- hstore_array.go | 160 ++++++++---- inet_array.go | 217 ++++++++++++---- int2_array.go | 580 +++++++++++++++++++++++++++++++++++++++---- int4_array.go | 580 +++++++++++++++++++++++++++++++++++++++---- int8_array.go | 580 +++++++++++++++++++++++++++++++++++++++---- jsonb_array.go | 160 ++++++++---- macaddr_array.go | 189 ++++++++++---- numeric_array.go | 356 ++++++++++++++++++++++---- text_array.go | 188 ++++++++++---- timestamp_array.go | 189 ++++++++++---- timestamptz_array.go | 189 ++++++++++---- tstzrange_array.go | 143 +++++++---- typed_array.go.erb | 164 ++++++++---- typed_array_gen.sh | 46 ++-- uuid_array.go | 244 ++++++++++++++---- varchar_array.go | 188 ++++++++++---- 25 files changed, 4594 insertions(+), 1273 deletions(-) diff --git a/aclitem_array.go b/aclitem_array.go index 09a64fb6..52b67d85 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -29,56 +29,110 @@ func (dst *ACLItemArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = ACLItemArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for ACLItemArray", src) - } - if elementsLength == 0 { - *dst = ACLItemArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to ACLItemArray", src) - } - - *dst = ACLItemArray{ - Elements: make([]ACLItem, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []string: + if value == nil { + *dst = ACLItemArray{Status: Null} + } else if len(value) == 0 { + *dst = ACLItemArray{Status: Present} + } else { + elements := make([]ACLItem, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]ACLItem, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = ACLItemArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*string: + if value == nil { + *dst = ACLItemArray{Status: Null} + } else if len(value) == 0 { + *dst = ACLItemArray{Status: Present} + } else { + elements := make([]ACLItem, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = ACLItemArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []ACLItem: + if value == nil { + *dst = ACLItemArray{Status: Null} + } else if len(value) == 0 { + *dst = ACLItemArray{Status: Present} + } else { + *dst = ACLItemArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = ACLItemArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for ACLItemArray", src) + } + if elementsLength == 0 { + *dst = ACLItemArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to ACLItemArray", src) + } + + *dst = ACLItemArray{ + Elements: make([]ACLItem, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]ACLItem, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to ACLItemArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to ACLItemArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -93,10 +147,11 @@ func (dst *ACLItemArray) setRecursive(value reflect.Value, index, dimension int) break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -131,6 +186,30 @@ func (dst ACLItemArray) Get() interface{} { func (src *ACLItemArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -169,10 +248,12 @@ func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -190,11 +271,14 @@ func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from ACLItemArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from ACLItemArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/bool_array.go b/bool_array.go index 6569d5ca..6a4b3454 100644 --- a/bool_array.go +++ b/bool_array.go @@ -31,56 +31,110 @@ func (dst *BoolArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = BoolArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for BoolArray", src) - } - if elementsLength == 0 { - *dst = BoolArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to BoolArray", src) - } - - *dst = BoolArray{ - Elements: make([]Bool, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []bool: + if value == nil { + *dst = BoolArray{Status: Null} + } else if len(value) == 0 { + *dst = BoolArray{Status: Present} + } else { + elements := make([]Bool, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Bool, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = BoolArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*bool: + if value == nil { + *dst = BoolArray{Status: Null} + } else if len(value) == 0 { + *dst = BoolArray{Status: Present} + } else { + elements := make([]Bool, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BoolArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Bool: + if value == nil { + *dst = BoolArray{Status: Null} + } else if len(value) == 0 { + *dst = BoolArray{Status: Present} + } else { + *dst = BoolArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = BoolArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for BoolArray", src) + } + if elementsLength == 0 { + *dst = BoolArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to BoolArray", src) + } + + *dst = BoolArray{ + Elements: make([]Bool, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Bool, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to BoolArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to BoolArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +149,11 @@ func (dst *BoolArray) setRecursive(value reflect.Value, index, dimension int) (i break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +188,30 @@ func (dst BoolArray) Get() interface{} { func (src *BoolArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]bool: + *v = make([]bool, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*bool: + *v = make([]*bool, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +250,12 @@ func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +273,14 @@ func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension in if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from BoolArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from BoolArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/bpchar_array.go b/bpchar_array.go index 8aef8330..1f79a3fe 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -31,56 +31,110 @@ func (dst *BPCharArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = BPCharArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for BPCharArray", src) - } - if elementsLength == 0 { - *dst = BPCharArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to BPCharArray", src) - } - - *dst = BPCharArray{ - Elements: make([]BPChar, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []string: + if value == nil { + *dst = BPCharArray{Status: Null} + } else if len(value) == 0 { + *dst = BPCharArray{Status: Present} + } else { + elements := make([]BPChar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]BPChar, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = BPCharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*string: + if value == nil { + *dst = BPCharArray{Status: Null} + } else if len(value) == 0 { + *dst = BPCharArray{Status: Present} + } else { + elements := make([]BPChar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BPCharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []BPChar: + if value == nil { + *dst = BPCharArray{Status: Null} + } else if len(value) == 0 { + *dst = BPCharArray{Status: Present} + } else { + *dst = BPCharArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = BPCharArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for BPCharArray", src) + } + if elementsLength == 0 { + *dst = BPCharArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to BPCharArray", src) + } + + *dst = BPCharArray{ + Elements: make([]BPChar, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]BPChar, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to BPCharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to BPCharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +149,11 @@ func (dst *BPCharArray) setRecursive(value reflect.Value, index, dimension int) break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +188,30 @@ func (dst BPCharArray) Get() interface{} { func (src *BPCharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +250,12 @@ func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +273,14 @@ func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from BPCharArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from BPCharArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/bytea_array.go b/bytea_array.go index 3addb99a..17136554 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -31,56 +31,91 @@ func (dst *ByteaArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = ByteaArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for ByteaArray", src) - } - if elementsLength == 0 { - *dst = ByteaArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to ByteaArray", src) - } - - *dst = ByteaArray{ - Elements: make([]Bytea, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case [][]byte: + if value == nil { + *dst = ByteaArray{Status: Null} + } else if len(value) == 0 { + *dst = ByteaArray{Status: Present} + } else { + elements := make([]Bytea, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Bytea, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = ByteaArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Bytea: + if value == nil { + *dst = ByteaArray{Status: Null} + } else if len(value) == 0 { + *dst = ByteaArray{Status: Present} + } else { + *dst = ByteaArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = ByteaArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for ByteaArray", src) + } + if elementsLength == 0 { + *dst = ByteaArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to ByteaArray", src) + } + + *dst = ByteaArray{ + Elements: make([]Bytea, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Bytea, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to ByteaArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to ByteaArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +130,11 @@ func (dst *ByteaArray) setRecursive(value reflect.Value, index, dimension int) ( break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +169,21 @@ func (dst ByteaArray) Get() interface{} { func (src *ByteaArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +222,12 @@ func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension i length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +245,14 @@ func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension i if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from ByteaArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from ByteaArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/cidr_array.go b/cidr_array.go index 1ef2f428..770c4b8c 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "net" "reflect" "github.com/jackc/pgio" @@ -31,56 +32,129 @@ func (dst *CIDRArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = CIDRArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for CIDRArray", src) - } - if elementsLength == 0 { - *dst = CIDRArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to CIDRArray", src) - } - - *dst = CIDRArray{ - Elements: make([]CIDR, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []*net.IPNet: + if value == nil { + *dst = CIDRArray{Status: Null} + } else if len(value) == 0 { + *dst = CIDRArray{Status: Present} + } else { + elements := make([]CIDR, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]CIDR, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = CIDRArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []net.IP: + if value == nil { + *dst = CIDRArray{Status: Null} + } else if len(value) == 0 { + *dst = CIDRArray{Status: Present} + } else { + elements := make([]CIDR, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CIDRArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*net.IP: + if value == nil { + *dst = CIDRArray{Status: Null} + } else if len(value) == 0 { + *dst = CIDRArray{Status: Present} + } else { + elements := make([]CIDR, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CIDRArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []CIDR: + if value == nil { + *dst = CIDRArray{Status: Null} + } else if len(value) == 0 { + *dst = CIDRArray{Status: Present} + } else { + *dst = CIDRArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = CIDRArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for CIDRArray", src) + } + if elementsLength == 0 { + *dst = CIDRArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to CIDRArray", src) + } + + *dst = CIDRArray{ + Elements: make([]CIDR, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]CIDR, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to CIDRArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to CIDRArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +169,11 @@ func (dst *CIDRArray) setRecursive(value reflect.Value, index, dimension int) (i break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +208,39 @@ func (dst CIDRArray) Get() interface{} { func (src *CIDRArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]*net.IPNet: + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]net.IP: + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*net.IP: + *v = make([]*net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +279,12 @@ func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +302,14 @@ func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension in if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from CIDRArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from CIDRArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/date_array.go b/date_array.go index 4ccdafe0..7ba93daa 100644 --- a/date_array.go +++ b/date_array.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "encoding/binary" "reflect" + "time" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,56 +32,110 @@ func (dst *DateArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = DateArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for DateArray", src) - } - if elementsLength == 0 { - *dst = DateArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to DateArray", src) - } - - *dst = DateArray{ - Elements: make([]Date, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []time.Time: + if value == nil { + *dst = DateArray{Status: Null} + } else if len(value) == 0 { + *dst = DateArray{Status: Present} + } else { + elements := make([]Date, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Date, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = DateArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*time.Time: + if value == nil { + *dst = DateArray{Status: Null} + } else if len(value) == 0 { + *dst = DateArray{Status: Present} + } else { + elements := make([]Date, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = DateArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Date: + if value == nil { + *dst = DateArray{Status: Null} + } else if len(value) == 0 { + *dst = DateArray{Status: Present} + } else { + *dst = DateArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = DateArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for DateArray", src) + } + if elementsLength == 0 { + *dst = DateArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to DateArray", src) + } + + *dst = DateArray{ + Elements: make([]Date, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Date, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to DateArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to DateArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +150,11 @@ func (dst *DateArray) setRecursive(value reflect.Value, index, dimension int) (i break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +189,30 @@ func (dst DateArray) Get() interface{} { func (src *DateArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +251,12 @@ func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +274,14 @@ func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension in if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from DateArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from DateArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/enum_array.go b/enum_array.go index 2c83db24..561d4495 100644 --- a/enum_array.go +++ b/enum_array.go @@ -29,56 +29,110 @@ func (dst *EnumArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = EnumArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for EnumArray", src) - } - if elementsLength == 0 { - *dst = EnumArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to EnumArray", src) - } - - *dst = EnumArray{ - Elements: make([]GenericText, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []string: + if value == nil { + *dst = EnumArray{Status: Null} + } else if len(value) == 0 { + *dst = EnumArray{Status: Present} + } else { + elements := make([]GenericText, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]GenericText, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = EnumArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*string: + if value == nil { + *dst = EnumArray{Status: Null} + } else if len(value) == 0 { + *dst = EnumArray{Status: Present} + } else { + elements := make([]GenericText, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = EnumArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []GenericText: + if value == nil { + *dst = EnumArray{Status: Null} + } else if len(value) == 0 { + *dst = EnumArray{Status: Present} + } else { + *dst = EnumArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = EnumArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for EnumArray", src) + } + if elementsLength == 0 { + *dst = EnumArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to EnumArray", src) + } + + *dst = EnumArray{ + Elements: make([]GenericText, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]GenericText, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to EnumArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to EnumArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -93,10 +147,11 @@ func (dst *EnumArray) setRecursive(value reflect.Value, index, dimension int) (i break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -131,6 +186,30 @@ func (dst EnumArray) Get() interface{} { func (src *EnumArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -169,10 +248,12 @@ func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -190,11 +271,14 @@ func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension in if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from EnumArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from EnumArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/float4_array.go b/float4_array.go index 78d1a860..829708e1 100644 --- a/float4_array.go +++ b/float4_array.go @@ -31,56 +31,110 @@ func (dst *Float4Array) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = Float4Array{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for Float4Array", src) - } - if elementsLength == 0 { - *dst = Float4Array{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Float4Array", src) - } - - *dst = Float4Array{ - Elements: make([]Float4, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []float32: + if value == nil { + *dst = Float4Array{Status: Null} + } else if len(value) == 0 { + *dst = Float4Array{Status: Present} + } else { + elements := make([]Float4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Float4, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = Float4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*float32: + if value == nil { + *dst = Float4Array{Status: Null} + } else if len(value) == 0 { + *dst = Float4Array{Status: Present} + } else { + elements := make([]Float4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Float4: + if value == nil { + *dst = Float4Array{Status: Null} + } else if len(value) == 0 { + *dst = Float4Array{Status: Present} + } else { + *dst = Float4Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Float4Array{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Float4Array", src) + } + if elementsLength == 0 { + *dst = Float4Array{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Float4Array", src) + } + + *dst = Float4Array{ + Elements: make([]Float4, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Float4, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to Float4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Float4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +149,11 @@ func (dst *Float4Array) setRecursive(value reflect.Value, index, dimension int) break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +188,30 @@ func (dst Float4Array) Get() interface{} { func (src *Float4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]float32: + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float32: + *v = make([]*float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +250,12 @@ func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +273,14 @@ func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from Float4Array") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from Float4Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/float8_array.go b/float8_array.go index 19223c52..6932cb88 100644 --- a/float8_array.go +++ b/float8_array.go @@ -31,56 +31,110 @@ func (dst *Float8Array) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = Float8Array{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for Float8Array", src) - } - if elementsLength == 0 { - *dst = Float8Array{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Float8Array", src) - } - - *dst = Float8Array{ - Elements: make([]Float8, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []float64: + if value == nil { + *dst = Float8Array{Status: Null} + } else if len(value) == 0 { + *dst = Float8Array{Status: Present} + } else { + elements := make([]Float8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Float8, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = Float8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*float64: + if value == nil { + *dst = Float8Array{Status: Null} + } else if len(value) == 0 { + *dst = Float8Array{Status: Present} + } else { + elements := make([]Float8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Float8: + if value == nil { + *dst = Float8Array{Status: Null} + } else if len(value) == 0 { + *dst = Float8Array{Status: Present} + } else { + *dst = Float8Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Float8Array{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Float8Array", src) + } + if elementsLength == 0 { + *dst = Float8Array{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Float8Array", src) + } + + *dst = Float8Array{ + Elements: make([]Float8, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Float8, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to Float8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Float8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +149,11 @@ func (dst *Float8Array) setRecursive(value reflect.Value, index, dimension int) break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +188,30 @@ func (dst Float8Array) Get() interface{} { func (src *Float8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]float64: + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float64: + *v = make([]*float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +250,12 @@ func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +273,14 @@ func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from Float8Array") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from Float8Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/hstore_array.go b/hstore_array.go index 8764aae7..4dc172be 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -31,56 +31,91 @@ func (dst *HstoreArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = HstoreArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for HstoreArray", src) - } - if elementsLength == 0 { - *dst = HstoreArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to HstoreArray", src) - } - - *dst = HstoreArray{ - Elements: make([]Hstore, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []map[string]string: + if value == nil { + *dst = HstoreArray{Status: Null} + } else if len(value) == 0 { + *dst = HstoreArray{Status: Present} + } else { + elements := make([]Hstore, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Hstore, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = HstoreArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Hstore: + if value == nil { + *dst = HstoreArray{Status: Null} + } else if len(value) == 0 { + *dst = HstoreArray{Status: Present} + } else { + *dst = HstoreArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = HstoreArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for HstoreArray", src) + } + if elementsLength == 0 { + *dst = HstoreArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to HstoreArray", src) + } + + *dst = HstoreArray{ + Elements: make([]Hstore, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Hstore, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to HstoreArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to HstoreArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +130,11 @@ func (dst *HstoreArray) setRecursive(value reflect.Value, index, dimension int) break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +169,21 @@ func (dst HstoreArray) Get() interface{} { func (src *HstoreArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]map[string]string: + *v = make([]map[string]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +222,12 @@ func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +245,14 @@ func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from HstoreArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from HstoreArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/inet_array.go b/inet_array.go index 91f5d6e8..75f1328f 100644 --- a/inet_array.go +++ b/inet_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "net" "reflect" "github.com/jackc/pgio" @@ -31,56 +32,129 @@ func (dst *InetArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = InetArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for InetArray", src) - } - if elementsLength == 0 { - *dst = InetArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to InetArray", src) - } - - *dst = InetArray{ - Elements: make([]Inet, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []*net.IPNet: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Inet, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []net.IP: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*net.IP: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Inet: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + *dst = InetArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = InetArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for InetArray", src) + } + if elementsLength == 0 { + *dst = InetArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to InetArray", src) + } + + *dst = InetArray{ + Elements: make([]Inet, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Inet, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to InetArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to InetArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +169,11 @@ func (dst *InetArray) setRecursive(value reflect.Value, index, dimension int) (i break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +208,39 @@ func (dst InetArray) Get() interface{} { func (src *InetArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]*net.IPNet: + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]net.IP: + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*net.IP: + *v = make([]*net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +279,12 @@ func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +302,14 @@ func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension in if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from InetArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from InetArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/int2_array.go b/int2_array.go index 06febf01..ede35bac 100644 --- a/int2_array.go +++ b/int2_array.go @@ -31,56 +31,376 @@ func (dst *Int2Array) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = Int2Array{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for Int2Array", src) - } - if elementsLength == 0 { - *dst = Int2Array{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Int2Array", src) - } - - *dst = Int2Array{ - Elements: make([]Int2, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []int16: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Int2, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int16: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint16: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint16: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int32: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int32: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint32: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint32: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int64: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int64: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint64: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Int2: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + *dst = Int2Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Int2Array{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Int2Array", src) + } + if elementsLength == 0 { + *dst = Int2Array{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Int2Array", src) + } + + *dst = Int2Array{ + Elements: make([]Int2, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int2, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to Int2Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Int2Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +415,11 @@ func (dst *Int2Array) setRecursive(value reflect.Value, index, dimension int) (i break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +454,156 @@ func (dst Int2Array) Get() interface{} { func (src *Int2Array) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +642,12 @@ func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +665,14 @@ func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension in if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from Int2Array") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from Int2Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/int4_array.go b/int4_array.go index 189bd238..b0856da9 100644 --- a/int4_array.go +++ b/int4_array.go @@ -31,56 +31,376 @@ func (dst *Int4Array) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = Int4Array{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for Int4Array", src) - } - if elementsLength == 0 { - *dst = Int4Array{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Int4Array", src) - } - - *dst = Int4Array{ - Elements: make([]Int4, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []int16: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Int4, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int16: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint16: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint16: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int64: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int64: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint64: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Int4: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + *dst = Int4Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Int4Array{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Int4Array", src) + } + if elementsLength == 0 { + *dst = Int4Array{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Int4Array", src) + } + + *dst = Int4Array{ + Elements: make([]Int4, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int4, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to Int4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Int4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +415,11 @@ func (dst *Int4Array) setRecursive(value reflect.Value, index, dimension int) (i break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +454,156 @@ func (dst Int4Array) Get() interface{} { func (src *Int4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +642,12 @@ func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +665,14 @@ func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension in if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from Int4Array") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from Int4Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/int8_array.go b/int8_array.go index edb232cb..c95ebef5 100644 --- a/int8_array.go +++ b/int8_array.go @@ -31,56 +31,376 @@ func (dst *Int8Array) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = Int8Array{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for Int8Array", src) - } - if elementsLength == 0 { - *dst = Int8Array{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Int8Array", src) - } - - *dst = Int8Array{ - Elements: make([]Int8, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []int16: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Int8, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int16: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint16: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint16: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int32: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int32: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint32: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint32: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Int8: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + *dst = Int8Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Int8Array{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Int8Array", src) + } + if elementsLength == 0 { + *dst = Int8Array{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Int8Array", src) + } + + *dst = Int8Array{ + Elements: make([]Int8, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int8, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to Int8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Int8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +415,11 @@ func (dst *Int8Array) setRecursive(value reflect.Value, index, dimension int) (i break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +454,156 @@ func (dst Int8Array) Get() interface{} { func (src *Int8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +642,12 @@ func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +665,14 @@ func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension in if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from Int8Array") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from Int8Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/jsonb_array.go b/jsonb_array.go index c5a40a1d..faf2d364 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -31,56 +31,91 @@ func (dst *JSONBArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = JSONBArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for JSONBArray", src) - } - if elementsLength == 0 { - *dst = JSONBArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to JSONBArray", src) - } - - *dst = JSONBArray{ - Elements: make([]Text, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []string: + if value == nil { + *dst = JSONBArray{Status: Null} + } else if len(value) == 0 { + *dst = JSONBArray{Status: Present} + } else { + elements := make([]Text, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Text, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = JSONBArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Text: + if value == nil { + *dst = JSONBArray{Status: Null} + } else if len(value) == 0 { + *dst = JSONBArray{Status: Present} + } else { + *dst = JSONBArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = JSONBArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for JSONBArray", src) + } + if elementsLength == 0 { + *dst = JSONBArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to JSONBArray", src) + } + + *dst = JSONBArray{ + Elements: make([]Text, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Text, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to JSONBArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to JSONBArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +130,11 @@ func (dst *JSONBArray) setRecursive(value reflect.Value, index, dimension int) ( break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +169,21 @@ func (dst JSONBArray) Get() interface{} { func (src *JSONBArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +222,12 @@ func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension i length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +245,14 @@ func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension i if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from JSONBArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from JSONBArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/macaddr_array.go b/macaddr_array.go index 398db1fe..6f75ffbc 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "net" "reflect" "github.com/jackc/pgio" @@ -31,56 +32,110 @@ func (dst *MacaddrArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = MacaddrArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for MacaddrArray", src) - } - if elementsLength == 0 { - *dst = MacaddrArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to MacaddrArray", src) - } - - *dst = MacaddrArray{ - Elements: make([]Macaddr, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []net.HardwareAddr: + if value == nil { + *dst = MacaddrArray{Status: Null} + } else if len(value) == 0 { + *dst = MacaddrArray{Status: Present} + } else { + elements := make([]Macaddr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Macaddr, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = MacaddrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*net.HardwareAddr: + if value == nil { + *dst = MacaddrArray{Status: Null} + } else if len(value) == 0 { + *dst = MacaddrArray{Status: Present} + } else { + elements := make([]Macaddr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = MacaddrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Macaddr: + if value == nil { + *dst = MacaddrArray{Status: Null} + } else if len(value) == 0 { + *dst = MacaddrArray{Status: Present} + } else { + *dst = MacaddrArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = MacaddrArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for MacaddrArray", src) + } + if elementsLength == 0 { + *dst = MacaddrArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to MacaddrArray", src) + } + + *dst = MacaddrArray{ + Elements: make([]Macaddr, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Macaddr, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to MacaddrArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to MacaddrArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +150,11 @@ func (dst *MacaddrArray) setRecursive(value reflect.Value, index, dimension int) break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +189,30 @@ func (dst MacaddrArray) Get() interface{} { func (src *MacaddrArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]net.HardwareAddr: + *v = make([]net.HardwareAddr, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*net.HardwareAddr: + *v = make([]*net.HardwareAddr, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +251,12 @@ func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +274,14 @@ func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from MacaddrArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from MacaddrArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/numeric_array.go b/numeric_array.go index dec81535..e848b133 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -31,56 +31,224 @@ func (dst *NumericArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = NumericArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for NumericArray", src) - } - if elementsLength == 0 { - *dst = NumericArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to NumericArray", src) - } - - *dst = NumericArray{ - Elements: make([]Numeric, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []float32: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Numeric, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*float32: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []float64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*float64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Numeric: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + *dst = NumericArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = NumericArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for NumericArray", src) + } + if elementsLength == 0 { + *dst = NumericArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to NumericArray", src) + } + + *dst = NumericArray{ + Elements: make([]Numeric, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Numeric, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to NumericArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to NumericArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +263,11 @@ func (dst *NumericArray) setRecursive(value reflect.Value, index, dimension int) break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +302,84 @@ func (dst NumericArray) Get() interface{} { func (src *NumericArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]float32: + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float32: + *v = make([]*float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]float64: + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float64: + *v = make([]*float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +418,12 @@ func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +441,14 @@ func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from NumericArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from NumericArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/text_array.go b/text_array.go index 31ed04ac..c6a950f8 100644 --- a/text_array.go +++ b/text_array.go @@ -31,56 +31,110 @@ func (dst *TextArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = TextArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for TextArray", src) - } - if elementsLength == 0 { - *dst = TextArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to TextArray", src) - } - - *dst = TextArray{ - Elements: make([]Text, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []string: + if value == nil { + *dst = TextArray{Status: Null} + } else if len(value) == 0 { + *dst = TextArray{Status: Present} + } else { + elements := make([]Text, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Text, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = TextArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*string: + if value == nil { + *dst = TextArray{Status: Null} + } else if len(value) == 0 { + *dst = TextArray{Status: Present} + } else { + elements := make([]Text, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TextArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Text: + if value == nil { + *dst = TextArray{Status: Null} + } else if len(value) == 0 { + *dst = TextArray{Status: Present} + } else { + *dst = TextArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TextArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TextArray", src) + } + if elementsLength == 0 { + *dst = TextArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to TextArray", src) + } + + *dst = TextArray{ + Elements: make([]Text, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Text, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to TextArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TextArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +149,11 @@ func (dst *TextArray) setRecursive(value reflect.Value, index, dimension int) (i break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +188,30 @@ func (dst TextArray) Get() interface{} { func (src *TextArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +250,12 @@ func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +273,14 @@ func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension in if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from TextArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from TextArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/timestamp_array.go b/timestamp_array.go index 355b29c5..d0254d47 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "encoding/binary" "reflect" + "time" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,56 +32,110 @@ func (dst *TimestampArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = TimestampArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for TimestampArray", src) - } - if elementsLength == 0 { - *dst = TimestampArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to TimestampArray", src) - } - - *dst = TimestampArray{ - Elements: make([]Timestamp, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []time.Time: + if value == nil { + *dst = TimestampArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestampArray{Status: Present} + } else { + elements := make([]Timestamp, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Timestamp, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = TimestampArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*time.Time: + if value == nil { + *dst = TimestampArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestampArray{Status: Present} + } else { + elements := make([]Timestamp, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestampArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Timestamp: + if value == nil { + *dst = TimestampArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestampArray{Status: Present} + } else { + *dst = TimestampArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TimestampArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TimestampArray", src) + } + if elementsLength == 0 { + *dst = TimestampArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to TimestampArray", src) + } + + *dst = TimestampArray{ + Elements: make([]Timestamp, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Timestamp, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to TimestampArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TimestampArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +150,11 @@ func (dst *TimestampArray) setRecursive(value reflect.Value, index, dimension in break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +189,30 @@ func (dst TimestampArray) Get() interface{} { func (src *TimestampArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +251,12 @@ func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimensi length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +274,14 @@ func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimensi if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from TimestampArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from TimestampArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/timestamptz_array.go b/timestamptz_array.go index 94a791b6..97ce2715 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "encoding/binary" "reflect" + "time" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,56 +32,110 @@ func (dst *TimestamptzArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = TimestamptzArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for TimestamptzArray", src) - } - if elementsLength == 0 { - *dst = TimestamptzArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to TimestamptzArray", src) - } - - *dst = TimestamptzArray{ - Elements: make([]Timestamptz, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []time.Time: + if value == nil { + *dst = TimestamptzArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestamptzArray{Status: Present} + } else { + elements := make([]Timestamptz, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Timestamptz, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = TimestamptzArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*time.Time: + if value == nil { + *dst = TimestamptzArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestamptzArray{Status: Present} + } else { + elements := make([]Timestamptz, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestamptzArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Timestamptz: + if value == nil { + *dst = TimestamptzArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestamptzArray{Status: Present} + } else { + *dst = TimestamptzArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TimestamptzArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TimestamptzArray", src) + } + if elementsLength == 0 { + *dst = TimestamptzArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to TimestamptzArray", src) + } + + *dst = TimestamptzArray{ + Elements: make([]Timestamptz, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Timestamptz, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to TimestamptzArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TimestamptzArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +150,11 @@ func (dst *TimestamptzArray) setRecursive(value reflect.Value, index, dimension break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +189,30 @@ func (dst TimestamptzArray) Get() interface{} { func (src *TimestamptzArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +251,12 @@ func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimen length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +274,14 @@ func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimen if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from TimestamptzArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from TimestamptzArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/tstzrange_array.go b/tstzrange_array.go index f5043c65..02a98e66 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -31,56 +31,72 @@ func (dst *TstzrangeArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = TstzrangeArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for TstzrangeArray", src) - } - if elementsLength == 0 { - *dst = TstzrangeArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to TstzrangeArray", src) - } - - *dst = TstzrangeArray{ - Elements: make([]Tstzrange, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } + case []Tstzrange: + if value == nil { + *dst = TstzrangeArray{Status: Null} + } else if len(value) == 0 { + *dst = TstzrangeArray{Status: Present} + } else { + *dst = TstzrangeArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, } - dst.Elements = make([]Tstzrange, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TstzrangeArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TstzrangeArray", src) + } + if elementsLength == 0 { + *dst = TstzrangeArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to TstzrangeArray", src) + } + + *dst = TstzrangeArray{ + Elements: make([]Tstzrange, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Tstzrange, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to TstzrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TstzrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +111,11 @@ func (dst *TstzrangeArray) setRecursive(value reflect.Value, index, dimension in break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +150,21 @@ func (dst TstzrangeArray) Get() interface{} { func (src *TstzrangeArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]Tstzrange: + *v = make([]Tstzrange, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +203,12 @@ func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimensi length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +226,14 @@ func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimensi if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from TstzrangeArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from TstzrangeArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/typed_array.go.erb b/typed_array.go.erb index fb964ec8..5bf582b2 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -30,56 +30,93 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = <%= pgtype_array_type %>{Status: Null} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for <%= pgtype_array_type %>", src) - } - if elementsLength == 0 { - *dst = <%= pgtype_array_type %>{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>", src) - } - - *dst = <%= pgtype_array_type %> { - Elements: make([]<%= pgtype_element_type %>, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + switch value := src.(type) { + <% go_array_types.split(",").each do |t| %> + <% if t != "[]#{pgtype_element_type}" %> + case <%= t %>: + if value == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + } else if len(value) == 0 { + *dst = <%= pgtype_array_type %>{Status: Present} + } else { + elements := make([]<%= pgtype_element_type %>, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]<%= pgtype_element_type %>, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = <%= pgtype_array_type %>{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + <% end %> + <% end %> + case []<%= pgtype_element_type %>: + if value == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + } else if len(value) == 0 { + *dst = <%= pgtype_array_type %>{Status: Present} + } else { + *dst = <%= pgtype_array_type %>{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status : Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for <%= pgtype_array_type %>", src) + } + if elementsLength == 0 { + *dst = <%= pgtype_array_type %>{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>", src) + } + + *dst = <%= pgtype_array_type %> { + Elements: make([]<%= pgtype_element_type %>, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]<%= pgtype_element_type %>, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -94,10 +131,11 @@ func (dst *<%= pgtype_array_type %>) setRecursive(value reflect.Value, index, di break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -132,6 +170,21 @@ func (dst <%= pgtype_array_type %>) Get() interface{} { func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1{ + switch v := dst.(type) { + <% go_array_types.split(",").each do |t| %> + case *<%= t %>: + *v = make(<%= t %>, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + <% end %> + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -170,10 +223,12 @@ func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, inde length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -191,11 +246,14 @@ func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, inde if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr(){ return 0, errors.Errorf("cannot assign all values from <%= pgtype_array_type %>") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from <%= pgtype_array_type %>") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 8c594944..607d3bc3 100755 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,27 +1,27 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go -erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go -erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go -erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go -erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go -erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go -erb pgtype_array_type=NumericArray pgtype_element_type=Numeric element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go -erb pgtype_array_type=UUIDArray pgtype_element_type=UUID element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go -erb pgtype_array_type=JSONBArray pgtype_element_type=Text element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time,[]*time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time,[]*time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32,[]*float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64,[]*float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go +erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr,[]*net.HardwareAddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go +erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string,[]*string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string,[]*string element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go +erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string,[]*string element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string,[]*string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go +erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go +erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go +erb pgtype_array_type=JSONBArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go # While the binary format is theoretically possible it is only practical to use the text format. -erb pgtype_array_type=EnumArray pgtype_element_type=GenericText text_null=NULL binary_format=false typed_array.go.erb > enum_array.go +erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go goimports -w *_array.go diff --git a/uuid_array.go b/uuid_array.go index e2c86cf8..09c6878f 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -31,56 +31,148 @@ func (dst *UUIDArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = UUIDArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for UUIDArray", src) - } - if elementsLength == 0 { - *dst = UUIDArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to UUIDArray", src) - } - - *dst = UUIDArray{ - Elements: make([]UUID, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case [][16]byte: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]UUID, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case [][]byte: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []string: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*string: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []UUID: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + *dst = UUIDArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = UUIDArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for UUIDArray", src) + } + if elementsLength == 0 { + *dst = UUIDArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to UUIDArray", src) + } + + *dst = UUIDArray{ + Elements: make([]UUID, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]UUID, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to UUIDArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to UUIDArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +187,11 @@ func (dst *UUIDArray) setRecursive(value reflect.Value, index, dimension int) (i break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +226,48 @@ func (dst UUIDArray) Get() interface{} { func (src *UUIDArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[][16]byte: + *v = make([][16]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +306,12 @@ func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +329,14 @@ func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension in if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from UUIDArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from UUIDArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ diff --git a/varchar_array.go b/varchar_array.go index ec378ed7..ad19d423 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -31,56 +31,110 @@ func (dst *VarcharArray) Set(src interface{}) error { } } - value := reflect.ValueOf(src) - if !value.IsValid() || value.IsZero() { - *dst = VarcharArray{Status: Null} - return nil - } + switch value := src.(type) { - dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) - if !ok { - return errors.Errorf("cannot find dimensions of %v for VarcharArray", src) - } - if elementsLength == 0 { - *dst = VarcharArray{Status: Present} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to VarcharArray", src) - } - - *dst = VarcharArray{ - Elements: make([]Varchar, elementsLength), - Dimensions: dimensions, - Status: Present, - } - elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) + case []string: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + elements := make([]Varchar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err } } - dst.Elements = make([]Varchar, elementsLength) - elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) - if err != nil { + *dst = VarcharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*string: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + elements := make([]Varchar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = VarcharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Varchar: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + *dst = VarcharArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = VarcharArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for VarcharArray", src) + } + if elementsLength == 0 { + *dst = VarcharArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to VarcharArray", src) + } + + *dst = VarcharArray{ + Elements: make([]Varchar, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Varchar, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { return err } - } else { - return err } - } - if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to VarcharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to VarcharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } } return nil @@ -95,10 +149,11 @@ func (dst *VarcharArray) setRecursive(value reflect.Value, index, dimension int) break } - if int32(value.Len()) != dst.Dimensions[dimension].Length { + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") } - for i := 0; i < value.Len(); i++ { + for i := 0; i < valueLen; i++ { var err error index, err = dst.setRecursive(value.Index(i), index, dimension+1) if err != nil { @@ -133,6 +188,30 @@ func (dst VarcharArray) Get() interface{} { func (src *VarcharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Dimensions) == 1 { + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -171,10 +250,12 @@ func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { - if value.Type().Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + typ := value.Type() + typLen := typ.Len() + if typLen != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) } - value.Set(reflect.New(value.Type()).Elem()) + value.Set(reflect.New(typ).Elem()) } else { value.Set(reflect.MakeSlice(value.Type(), length, length)) } @@ -192,11 +273,14 @@ func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension if len(src.Dimensions) != dimension { return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } - if !value.CanAddr() || !value.Addr().CanInterface() { + if !value.CanAddr() { return 0, errors.Errorf("cannot assign all values from VarcharArray") } - err := src.Elements[index].AssignTo(value.Addr().Interface()) - if err != nil { + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from VarcharArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err } index++ From ec14212d300cdcfab05dd26dd401941c2e540a61 Mon Sep 17 00:00:00 2001 From: Simo Haasanen Date: Sun, 9 Aug 2020 09:17:40 +0100 Subject: [PATCH 0564/1158] Add comments to explain the use of reflection after type assertion. Removes one local variable, which is used twice only in an error. --- aclitem_array.go | 13 ++++++++++--- bool_array.go | 13 ++++++++++--- bpchar_array.go | 13 ++++++++++--- bytea_array.go | 13 ++++++++++--- cidr_array.go | 13 ++++++++++--- date_array.go | 13 ++++++++++--- enum_array.go | 13 ++++++++++--- float4_array.go | 13 ++++++++++--- float8_array.go | 13 ++++++++++--- hstore_array.go | 13 ++++++++++--- inet_array.go | 13 ++++++++++--- int2_array.go | 13 ++++++++++--- int4_array.go | 13 ++++++++++--- int8_array.go | 13 ++++++++++--- jsonb_array.go | 13 ++++++++++--- macaddr_array.go | 13 ++++++++++--- numeric_array.go | 13 ++++++++++--- text_array.go | 13 ++++++++++--- timestamp_array.go | 13 ++++++++++--- timestamptz_array.go | 13 ++++++++++--- tstzrange_array.go | 13 ++++++++++--- typed_array.go.erb | 13 ++++++++++--- uuid_array.go | 13 ++++++++++--- varchar_array.go | 13 ++++++++++--- 24 files changed, 240 insertions(+), 72 deletions(-) diff --git a/aclitem_array.go b/aclitem_array.go index 52b67d85..260bbe4c 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -29,6 +29,7 @@ func (dst *ACLItemArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []string: @@ -82,6 +83,9 @@ func (dst *ACLItemArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = ACLItemArray{Status: Null} @@ -187,6 +191,7 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]string: @@ -210,6 +215,9 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -249,9 +257,8 @@ func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/bool_array.go b/bool_array.go index 6a4b3454..149b0c9f 100644 --- a/bool_array.go +++ b/bool_array.go @@ -31,6 +31,7 @@ func (dst *BoolArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []bool: @@ -84,6 +85,9 @@ func (dst *BoolArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = BoolArray{Status: Null} @@ -189,6 +193,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]bool: @@ -212,6 +217,9 @@ func (src *BoolArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -251,9 +259,8 @@ func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/bpchar_array.go b/bpchar_array.go index 1f79a3fe..d28d22ac 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -31,6 +31,7 @@ func (dst *BPCharArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []string: @@ -84,6 +85,9 @@ func (dst *BPCharArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = BPCharArray{Status: Null} @@ -189,6 +193,7 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]string: @@ -212,6 +217,9 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -251,9 +259,8 @@ func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/bytea_array.go b/bytea_array.go index 17136554..26956edb 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -31,6 +31,7 @@ func (dst *ByteaArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case [][]byte: @@ -65,6 +66,9 @@ func (dst *ByteaArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = ByteaArray{Status: Null} @@ -170,6 +174,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[][]byte: @@ -184,6 +189,9 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -223,9 +231,8 @@ func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension i length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/cidr_array.go b/cidr_array.go index 770c4b8c..d6108fe2 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -32,6 +32,7 @@ func (dst *CIDRArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []*net.IPNet: @@ -104,6 +105,9 @@ func (dst *CIDRArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = CIDRArray{Status: Null} @@ -209,6 +213,7 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]*net.IPNet: @@ -241,6 +246,9 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -280,9 +288,8 @@ func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/date_array.go b/date_array.go index 7ba93daa..e1b6061a 100644 --- a/date_array.go +++ b/date_array.go @@ -32,6 +32,7 @@ func (dst *DateArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []time.Time: @@ -85,6 +86,9 @@ func (dst *DateArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = DateArray{Status: Null} @@ -190,6 +194,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]time.Time: @@ -213,6 +218,9 @@ func (src *DateArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -252,9 +260,8 @@ func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/enum_array.go b/enum_array.go index 561d4495..b2fb063c 100644 --- a/enum_array.go +++ b/enum_array.go @@ -29,6 +29,7 @@ func (dst *EnumArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []string: @@ -82,6 +83,9 @@ func (dst *EnumArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = EnumArray{Status: Null} @@ -187,6 +191,7 @@ func (src *EnumArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]string: @@ -210,6 +215,9 @@ func (src *EnumArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -249,9 +257,8 @@ func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/float4_array.go b/float4_array.go index 829708e1..7e750df8 100644 --- a/float4_array.go +++ b/float4_array.go @@ -31,6 +31,7 @@ func (dst *Float4Array) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []float32: @@ -84,6 +85,9 @@ func (dst *Float4Array) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = Float4Array{Status: Null} @@ -189,6 +193,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]float32: @@ -212,6 +217,9 @@ func (src *Float4Array) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -251,9 +259,8 @@ func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/float8_array.go b/float8_array.go index 6932cb88..12520722 100644 --- a/float8_array.go +++ b/float8_array.go @@ -31,6 +31,7 @@ func (dst *Float8Array) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []float64: @@ -84,6 +85,9 @@ func (dst *Float8Array) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = Float8Array{Status: Null} @@ -189,6 +193,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]float64: @@ -212,6 +217,9 @@ func (src *Float8Array) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -251,9 +259,8 @@ func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/hstore_array.go b/hstore_array.go index 4dc172be..d2ff2874 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -31,6 +31,7 @@ func (dst *HstoreArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []map[string]string: @@ -65,6 +66,9 @@ func (dst *HstoreArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = HstoreArray{Status: Null} @@ -170,6 +174,7 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]map[string]string: @@ -184,6 +189,9 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -223,9 +231,8 @@ func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/inet_array.go b/inet_array.go index 75f1328f..7133fc0b 100644 --- a/inet_array.go +++ b/inet_array.go @@ -32,6 +32,7 @@ func (dst *InetArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []*net.IPNet: @@ -104,6 +105,9 @@ func (dst *InetArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = InetArray{Status: Null} @@ -209,6 +213,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]*net.IPNet: @@ -241,6 +246,9 @@ func (src *InetArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -280,9 +288,8 @@ func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/int2_array.go b/int2_array.go index ede35bac..b64e0689 100644 --- a/int2_array.go +++ b/int2_array.go @@ -31,6 +31,7 @@ func (dst *Int2Array) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []int16: @@ -350,6 +351,9 @@ func (dst *Int2Array) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = Int2Array{Status: Null} @@ -455,6 +459,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]int16: @@ -604,6 +609,9 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -643,9 +651,8 @@ func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/int4_array.go b/int4_array.go index b0856da9..01613d39 100644 --- a/int4_array.go +++ b/int4_array.go @@ -31,6 +31,7 @@ func (dst *Int4Array) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []int16: @@ -350,6 +351,9 @@ func (dst *Int4Array) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = Int4Array{Status: Null} @@ -455,6 +459,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]int16: @@ -604,6 +609,9 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -643,9 +651,8 @@ func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/int8_array.go b/int8_array.go index c95ebef5..0babbe43 100644 --- a/int8_array.go +++ b/int8_array.go @@ -31,6 +31,7 @@ func (dst *Int8Array) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []int16: @@ -350,6 +351,9 @@ func (dst *Int8Array) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = Int8Array{Status: Null} @@ -455,6 +459,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]int16: @@ -604,6 +609,9 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -643,9 +651,8 @@ func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/jsonb_array.go b/jsonb_array.go index faf2d364..1e82843d 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -31,6 +31,7 @@ func (dst *JSONBArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []string: @@ -65,6 +66,9 @@ func (dst *JSONBArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = JSONBArray{Status: Null} @@ -170,6 +174,7 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]string: @@ -184,6 +189,9 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -223,9 +231,8 @@ func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension i length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/macaddr_array.go b/macaddr_array.go index 6f75ffbc..94a009fd 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -32,6 +32,7 @@ func (dst *MacaddrArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []net.HardwareAddr: @@ -85,6 +86,9 @@ func (dst *MacaddrArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = MacaddrArray{Status: Null} @@ -190,6 +194,7 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]net.HardwareAddr: @@ -213,6 +218,9 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -252,9 +260,8 @@ func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/numeric_array.go b/numeric_array.go index e848b133..884e8b14 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -31,6 +31,7 @@ func (dst *NumericArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []float32: @@ -198,6 +199,9 @@ func (dst *NumericArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = NumericArray{Status: Null} @@ -303,6 +307,7 @@ func (src *NumericArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]float32: @@ -380,6 +385,9 @@ func (src *NumericArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -419,9 +427,8 @@ func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/text_array.go b/text_array.go index c6a950f8..b2825b29 100644 --- a/text_array.go +++ b/text_array.go @@ -31,6 +31,7 @@ func (dst *TextArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []string: @@ -84,6 +85,9 @@ func (dst *TextArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = TextArray{Status: Null} @@ -189,6 +193,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]string: @@ -212,6 +217,9 @@ func (src *TextArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -251,9 +259,8 @@ func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/timestamp_array.go b/timestamp_array.go index d0254d47..0bc30f17 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -32,6 +32,7 @@ func (dst *TimestampArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []time.Time: @@ -85,6 +86,9 @@ func (dst *TimestampArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = TimestampArray{Status: Null} @@ -190,6 +194,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]time.Time: @@ -213,6 +218,9 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -252,9 +260,8 @@ func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimensi length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/timestamptz_array.go b/timestamptz_array.go index 97ce2715..313bde81 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -32,6 +32,7 @@ func (dst *TimestamptzArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []time.Time: @@ -85,6 +86,9 @@ func (dst *TimestamptzArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = TimestamptzArray{Status: Null} @@ -190,6 +194,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]time.Time: @@ -213,6 +218,9 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -252,9 +260,8 @@ func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimen length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/tstzrange_array.go b/tstzrange_array.go index 02a98e66..216182df 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -31,6 +31,7 @@ func (dst *TstzrangeArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []Tstzrange: @@ -46,6 +47,9 @@ func (dst *TstzrangeArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = TstzrangeArray{Status: Null} @@ -151,6 +155,7 @@ func (src *TstzrangeArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]Tstzrange: @@ -165,6 +170,9 @@ func (src *TstzrangeArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -204,9 +212,8 @@ func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimensi length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/typed_array.go.erb b/typed_array.go.erb index 5bf582b2..809c7884 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -30,6 +30,7 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { <% go_array_types.split(",").each do |t| %> <% if t != "[]#{pgtype_element_type}" %> @@ -66,6 +67,9 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = <%= pgtype_array_type %>{Status: Null} @@ -171,6 +175,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1{ + // Attempt to match to select common types: switch v := dst.(type) { <% go_array_types.split(",").each do |t| %> case *<%= t %>: @@ -185,6 +190,9 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -224,9 +232,8 @@ func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, inde length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/uuid_array.go b/uuid_array.go index 09c6878f..47e348f3 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -31,6 +31,7 @@ func (dst *UUIDArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case [][16]byte: @@ -122,6 +123,9 @@ func (dst *UUIDArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = UUIDArray{Status: Null} @@ -227,6 +231,7 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[][16]byte: @@ -268,6 +273,9 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -307,9 +315,8 @@ func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension in length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { diff --git a/varchar_array.go b/varchar_array.go index ad19d423..e68614bb 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -31,6 +31,7 @@ func (dst *VarcharArray) Set(src interface{}) error { } } + // Attempt to match to select common types: switch value := src.(type) { case []string: @@ -84,6 +85,9 @@ func (dst *VarcharArray) Set(src interface{}) error { } } default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { *dst = VarcharArray{Status: Null} @@ -189,6 +193,7 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: if len(src.Dimensions) == 1 { + // Attempt to match to select common types: switch v := dst.(type) { case *[]string: @@ -212,6 +217,9 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { } } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices value := reflect.ValueOf(dst) if value.Kind() == reflect.Ptr { value = value.Elem() @@ -251,9 +259,8 @@ func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension length := int(src.Dimensions[dimension].Length) if reflect.Array == kind { typ := value.Type() - typLen := typ.Len() - if typLen != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typLen) + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { From 3eb5432c4738bc58b1e52a91c58decf07324130f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Aug 2020 22:00:21 -0500 Subject: [PATCH 0565/1158] Add PgConn.CleanupChan --- CHANGELOG.md | 4 ++++ helper_test.go | 5 +++++ pgconn.go | 22 +++++++++++++++++++++- pgconn_test.go | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a6668fb0..8b988590 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# Unreleased + +* Add PgConn.CleanupChan so connection pools can determine when async close is complete + # 1.6.4 (July 29, 2020) * Fix deadlock on error after CommandComplete but before ReadyForQuery diff --git a/helper_test.go b/helper_test.go index 1a3ca75e..abb04905 100644 --- a/helper_test.go +++ b/helper_test.go @@ -15,6 +15,11 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() require.NoError(t, conn.Close(ctx)) + select { + case <-conn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } // Do a simple query to ensure the connection is still usable diff --git a/pgconn.go b/pgconn.go index 50607095..c132b26b 100644 --- a/pgconn.go +++ b/pgconn.go @@ -89,6 +89,8 @@ type PgConn struct { resultReader ResultReader multiResultReader MultiResultReader contextWatcher *ctxwatch.ContextWatcher + + cleanupChan chan struct{} } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -201,6 +203,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn := new(PgConn) pgConn.config = config pgConn.wbuf = make([]byte, 0, wbufLen) + pgConn.cleanupChan = make(chan struct{}) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -504,6 +507,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error { } pgConn.status = connStatusClosed + defer close(pgConn.cleanupChan) defer pgConn.conn.Close() if ctx != context.Background() { @@ -538,6 +542,7 @@ func (pgConn *PgConn) asyncClose() { pgConn.status = connStatusClosed go func() { + defer close(pgConn.cleanupChan) defer pgConn.conn.Close() deadline := time.Now().Add(time.Second * 15) @@ -554,7 +559,21 @@ func (pgConn *PgConn) asyncClose() { }() } +// CleanupChan returns a channel that will be closed after all underlying resources have been cleaned up. A closed +// connection is no longer usable, but underlying resources, in particular the net.Conn, may not have finished closing +// yet. This is because certain errors such as a context cancellation require that the interrupted function call return +// immediately, but the error may also cause the connection to be closed. In these cases the underlying resources are +// closed asynchronously. +// +// This is only likely to be useful to connection pools. It gives them a way avoid establishing a new connection while +// an old connection is still being cleaned up and thereby exceeding the maximum pool size. +func (pgConn *PgConn) CleanupChan() chan (struct{}) { + return pgConn.cleanupChan +} + // IsClosed reports if the connection has been closed. +// +// CleanupChan() can be used to determine if all cleanup has been completed. func (pgConn *PgConn) IsClosed() bool { return pgConn.status < connStatusIdle } @@ -1585,7 +1604,8 @@ func Construct(hc *HijackedConn) (*PgConn, error) { status: connStatusIdle, - wbuf: make([]byte, 0, wbufLen), + wbuf: make([]byte, 0, wbufLen), + cleanupChan: make(chan struct{}), } pgConn.contextWatcher = ctxwatch.NewContextWatcher( diff --git a/pgconn_test.go b/pgconn_test.go index 379aa266..56afc1c2 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -547,6 +547,11 @@ func TestConnExecContextCanceled(t *testing.T) { err = multiResult.Close() assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } func TestConnExecContextPrecanceled(t *testing.T) { @@ -680,6 +685,11 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } func TestConnExecParamsPrecanceled(t *testing.T) { @@ -824,6 +834,11 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } func TestConnExecPreparedPrecanceled(t *testing.T) { @@ -1306,6 +1321,11 @@ func TestConnCopyToCanceled(t *testing.T) { assert.Equal(t, pgconn.CommandTag(nil), res) assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } func TestConnCopyToPrecanceled(t *testing.T) { @@ -1397,6 +1417,11 @@ func TestConnCopyFromCanceled(t *testing.T) { assert.Error(t, err) assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } func TestConnCopyFromPrecanceled(t *testing.T) { @@ -1647,6 +1672,11 @@ func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { err = multiResult.Close() assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } otherConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) @@ -1750,6 +1780,11 @@ func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { closeCtx, _ := context.WithCancel(context.Background()) pgConn.Close(closeCtx) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } // https://github.com/jackc/pgx/issues/800 From fdfc783345f6b5df05b2039666c59ebd29a7e683 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Aug 2020 22:08:40 -0500 Subject: [PATCH 0566/1158] Rename CleanupChan to CleanupDone --- CHANGELOG.md | 2 +- helper_test.go | 2 +- pgconn.go | 18 +++++++++--------- pgconn_test.go | 14 +++++++------- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b988590..497e00a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Unreleased -* Add PgConn.CleanupChan so connection pools can determine when async close is complete +* Add PgConn.CleanupDone so connection pools can determine when async close is complete # 1.6.4 (July 29, 2020) diff --git a/helper_test.go b/helper_test.go index abb04905..87613dc9 100644 --- a/helper_test.go +++ b/helper_test.go @@ -16,7 +16,7 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) { defer cancel() require.NoError(t, conn.Close(ctx)) select { - case <-conn.CleanupChan(): + case <-conn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } diff --git a/pgconn.go b/pgconn.go index c132b26b..d031b7a1 100644 --- a/pgconn.go +++ b/pgconn.go @@ -90,7 +90,7 @@ type PgConn struct { multiResultReader MultiResultReader contextWatcher *ctxwatch.ContextWatcher - cleanupChan chan struct{} + cleanupDone chan struct{} } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -203,7 +203,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn := new(PgConn) pgConn.config = config pgConn.wbuf = make([]byte, 0, wbufLen) - pgConn.cleanupChan = make(chan struct{}) + pgConn.cleanupDone = make(chan struct{}) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -507,7 +507,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error { } pgConn.status = connStatusClosed - defer close(pgConn.cleanupChan) + defer close(pgConn.cleanupDone) defer pgConn.conn.Close() if ctx != context.Background() { @@ -542,7 +542,7 @@ func (pgConn *PgConn) asyncClose() { pgConn.status = connStatusClosed go func() { - defer close(pgConn.cleanupChan) + defer close(pgConn.cleanupDone) defer pgConn.conn.Close() deadline := time.Now().Add(time.Second * 15) @@ -559,7 +559,7 @@ func (pgConn *PgConn) asyncClose() { }() } -// CleanupChan returns a channel that will be closed after all underlying resources have been cleaned up. A closed +// CleanupDone returns a channel that will be closed after all underlying resources have been cleaned up. A closed // connection is no longer usable, but underlying resources, in particular the net.Conn, may not have finished closing // yet. This is because certain errors such as a context cancellation require that the interrupted function call return // immediately, but the error may also cause the connection to be closed. In these cases the underlying resources are @@ -567,13 +567,13 @@ func (pgConn *PgConn) asyncClose() { // // This is only likely to be useful to connection pools. It gives them a way avoid establishing a new connection while // an old connection is still being cleaned up and thereby exceeding the maximum pool size. -func (pgConn *PgConn) CleanupChan() chan (struct{}) { - return pgConn.cleanupChan +func (pgConn *PgConn) CleanupDone() chan (struct{}) { + return pgConn.cleanupDone } // IsClosed reports if the connection has been closed. // -// CleanupChan() can be used to determine if all cleanup has been completed. +// CleanupDone() can be used to determine if all cleanup has been completed. func (pgConn *PgConn) IsClosed() bool { return pgConn.status < connStatusIdle } @@ -1605,7 +1605,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) { status: connStatusIdle, wbuf: make([]byte, 0, wbufLen), - cleanupChan: make(chan struct{}), + cleanupDone: make(chan struct{}), } pgConn.contextWatcher = ctxwatch.NewContextWatcher( diff --git a/pgconn_test.go b/pgconn_test.go index 56afc1c2..f6750a60 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -548,7 +548,7 @@ func TestConnExecContextCanceled(t *testing.T) { assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } @@ -686,7 +686,7 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.True(t, pgConn.IsClosed()) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } @@ -835,7 +835,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } @@ -1322,7 +1322,7 @@ func TestConnCopyToCanceled(t *testing.T) { assert.True(t, pgConn.IsClosed()) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } @@ -1418,7 +1418,7 @@ func TestConnCopyFromCanceled(t *testing.T) { assert.True(t, pgConn.IsClosed()) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } @@ -1673,7 +1673,7 @@ func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } @@ -1781,7 +1781,7 @@ func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { closeCtx, _ := context.WithCancel(context.Background()) pgConn.Close(closeCtx) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } From 1debbfeec4c2b878d81b3491499f0bc6f5c5a40d Mon Sep 17 00:00:00 2001 From: Sebastiaan Mannem Date: Sun, 2 Aug 2020 16:43:47 +0200 Subject: [PATCH 0567/1158] Adding SendBytesWithResults option to receive data after sending a message (used by copy-both) --- pgconn.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/pgconn.go b/pgconn.go index d031b7a1..decdc03d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -904,6 +904,51 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { return multiResult } +// SendBytesWithResults sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as +// error to call SendBytes while reading the result of a query. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +// +// So far this only seems required with CopyDone handling. +func (pgConn *PgConn) SendBytesWithResults(ctx context.Context, buf []byte) *MultiResultReader { + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } + + pgConn.multiResultReader = MultiResultReader{ + pgConn: pgConn, + ctx: ctx, + } + multiResult := &pgConn.multiResultReader + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + pgConn.contextWatcher.Unwatch() + multiResult.closed = true + multiResult.err = &writeError{err: err, safeToRetry: n == 0} + pgConn.unlock() + return multiResult + } + + return multiResult +} + // ExecParams executes a command via the PostgreSQL extended query protocol. // // sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, From 5db484908cf74895bb9e03414d1ba022a24e11bd Mon Sep 17 00:00:00 2001 From: Sebastiaan Mannem Date: Sun, 23 Aug 2020 00:21:46 +0200 Subject: [PATCH 0568/1158] Changing SendBytesWithResults to ReceiveResults (that only does the reading). --- pgconn.go | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/pgconn.go b/pgconn.go index decdc03d..e2ab5c13 100644 --- a/pgconn.go +++ b/pgconn.go @@ -904,14 +904,12 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { return multiResult } -// SendBytesWithResults sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as -// error to call SendBytes while reading the result of a query. +// ReceiveResults reads the result that might be returned by Postgres after a SendBytes +// (e.a. after sending a CopyDone in a copy-both situation). // // This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. // See https://www.postgresql.org/docs/current/protocol.html. -// -// So far this only seems required with CopyDone handling. -func (pgConn *PgConn) SendBytesWithResults(ctx context.Context, buf []byte) *MultiResultReader { +func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, @@ -936,16 +934,6 @@ func (pgConn *PgConn) SendBytesWithResults(ctx context.Context, buf []byte) *Mul pgConn.contextWatcher.Watch(ctx) } - n, err := pgConn.conn.Write(buf) - if err != nil { - pgConn.asyncClose() - pgConn.contextWatcher.Unwatch() - multiResult.closed = true - multiResult.err = &writeError{err: err, safeToRetry: n == 0} - pgConn.unlock() - return multiResult - } - return multiResult } From 08088ecf9a92d8ba11a0784f7ff093bd5dfde1bd Mon Sep 17 00:00:00 2001 From: Yuli Khodorkovskiy Date: Mon, 4 May 2020 13:30:57 -0400 Subject: [PATCH 0569/1158] Fix notification response Notification response was missing the PID in the Encode function --- notification_response.go | 1 + 1 file changed, 1 insertion(+) diff --git a/notification_response.go b/notification_response.go index cd83c5ba..e762eb96 100644 --- a/notification_response.go +++ b/notification_response.go @@ -46,6 +46,7 @@ func (src *NotificationResponse) Encode(dst []byte) []byte { sp := len(dst) dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, src.PID) dst = append(dst, src.Channel...) dst = append(dst, 0) dst = append(dst, src.Payload...) From 79b05217d14ece98b13c69ba3358b47248ab4bbc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 4 Sep 2020 18:41:34 -0500 Subject: [PATCH 0570/1158] Fix JSONBArray to have elements of JSONB --- jsonb_array.go | 50 +++++++++++++++++++++++++++++++++++---------- jsonb_array_test.go | 36 ++++++++++++++++++++++++++++++++ typed_array_gen.sh | 2 +- 3 files changed, 76 insertions(+), 12 deletions(-) create mode 100644 jsonb_array_test.go diff --git a/jsonb_array.go b/jsonb_array.go index 1e82843d..8f51b789 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -12,7 +12,7 @@ import ( ) type JSONBArray struct { - Elements []Text + Elements []JSONB Dimensions []ArrayDimension Status Status } @@ -40,7 +40,7 @@ func (dst *JSONBArray) Set(src interface{}) error { } else if len(value) == 0 { *dst = JSONBArray{Status: Present} } else { - elements := make([]Text, len(value)) + elements := make([]JSONB, len(value)) for i := range value { if err := elements[i].Set(value[i]); err != nil { return err @@ -53,7 +53,26 @@ func (dst *JSONBArray) Set(src interface{}) error { } } - case []Text: + case [][]byte: + if value == nil { + *dst = JSONBArray{Status: Null} + } else if len(value) == 0 { + *dst = JSONBArray{Status: Present} + } else { + elements := make([]JSONB, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = JSONBArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []JSONB: if value == nil { *dst = JSONBArray{Status: Null} } else if len(value) == 0 { @@ -91,7 +110,7 @@ func (dst *JSONBArray) Set(src interface{}) error { } *dst = JSONBArray{ - Elements: make([]Text, elementsLength), + Elements: make([]JSONB, elementsLength), Dimensions: dimensions, Status: Present, } @@ -108,7 +127,7 @@ func (dst *JSONBArray) Set(src interface{}) error { elementsLength *= int(dim.Length) } } - dst.Elements = make([]Text, elementsLength) + dst.Elements = make([]JSONB, elementsLength) elementCount, err = dst.setRecursive(reflectedValue, 0, 0) if err != nil { return err @@ -186,6 +205,15 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { } return nil + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + } } @@ -277,13 +305,13 @@ func (dst *JSONBArray) DecodeText(ci *ConnInfo, src []byte) error { return err } - var elements []Text + var elements []JSONB if len(uta.Elements) > 0 { - elements = make([]Text, len(uta.Elements)) + elements = make([]JSONB, len(uta.Elements)) for i, s := range uta.Elements { - var elem Text + var elem JSONB var elemSrc []byte if s != "NULL" { elemSrc = []byte(s) @@ -324,7 +352,7 @@ func (dst *JSONBArray) DecodeBinary(ci *ConnInfo, src []byte) error { elementCount *= d.Length } - elements := make([]Text, elementCount) + elements := make([]JSONB, elementCount) for i := range elements { elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) @@ -413,10 +441,10 @@ func (src JSONBArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { Dimensions: src.Dimensions, } - if dt, ok := ci.DataTypeForName("text"); ok { + if dt, ok := ci.DataTypeForName("jsonb"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "text") + return nil, errors.Errorf("unable to find oid for type name %v", "jsonb") } for i := range src.Elements { diff --git a/jsonb_array_test.go b/jsonb_array_test.go new file mode 100644 index 00000000..65f1777a --- /dev/null +++ b/jsonb_array_test.go @@ -0,0 +1,36 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestJSONBArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "jsonb[]", []interface{}{ + &pgtype.JSONBArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.JSONBArray{ + Elements: []pgtype.JSONB{ + {Bytes: []byte(`"foo"`), Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.JSONBArray{Status: pgtype.Null}, + &pgtype.JSONBArray{ + Elements: []pgtype.JSONB{ + {Bytes: []byte(`"foo"`), Status: pgtype.Present}, + {Bytes: []byte("null"), Status: pgtype.Present}, + {Bytes: []byte("42"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, + Status: pgtype.Present, + }, + }) +} diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 607d3bc3..fe9eb62b 100755 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -19,7 +19,7 @@ erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[] erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go -erb pgtype_array_type=JSONBArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go +erb pgtype_array_type=JSONBArray pgtype_element_type=JSONB go_array_types=[]string,[][]byte element_type_name=jsonb text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go # While the binary format is theoretically possible it is only practical to use the text format. erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go From 9da6afcad782f26368737c3ef06ed8ba867f8292 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Sep 2020 10:56:22 -0500 Subject: [PATCH 0571/1158] Fix selecting empty array Failing test was in pgx: TestReadingValueAfterEmptyArray --- aclitem_array.go | 2 +- bool_array.go | 2 +- bpchar_array.go | 2 +- bytea_array.go | 2 +- cidr_array.go | 2 +- date_array.go | 2 +- enum_array.go | 2 +- float4_array.go | 2 +- float8_array.go | 2 +- hstore_array.go | 2 +- inet_array.go | 2 +- int2_array.go | 2 +- int4_array.go | 2 +- int8_array.go | 2 +- jsonb_array.go | 2 +- macaddr_array.go | 2 +- numeric_array.go | 2 +- text_array.go | 2 +- timestamp_array.go | 2 +- timestamptz_array.go | 2 +- tstzrange_array.go | 2 +- typed_array.go.erb | 8 ++++---- uuid_array.go | 2 +- varchar_array.go | 2 +- 24 files changed, 27 insertions(+), 27 deletions(-) diff --git a/aclitem_array.go b/aclitem_array.go index 260bbe4c..673951a6 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -190,7 +190,7 @@ func (dst ACLItemArray) Get() interface{} { func (src *ACLItemArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/bool_array.go b/bool_array.go index 149b0c9f..ed411a28 100644 --- a/bool_array.go +++ b/bool_array.go @@ -192,7 +192,7 @@ func (dst BoolArray) Get() interface{} { func (src *BoolArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/bpchar_array.go b/bpchar_array.go index d28d22ac..0b92ba30 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -192,7 +192,7 @@ func (dst BPCharArray) Get() interface{} { func (src *BPCharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/bytea_array.go b/bytea_array.go index 26956edb..f80980bd 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -173,7 +173,7 @@ func (dst ByteaArray) Get() interface{} { func (src *ByteaArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/cidr_array.go b/cidr_array.go index d6108fe2..0b902cca 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -212,7 +212,7 @@ func (dst CIDRArray) Get() interface{} { func (src *CIDRArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/date_array.go b/date_array.go index e1b6061a..b306589e 100644 --- a/date_array.go +++ b/date_array.go @@ -193,7 +193,7 @@ func (dst DateArray) Get() interface{} { func (src *DateArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/enum_array.go b/enum_array.go index b2fb063c..4b6d2af4 100644 --- a/enum_array.go +++ b/enum_array.go @@ -190,7 +190,7 @@ func (dst EnumArray) Get() interface{} { func (src *EnumArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/float4_array.go b/float4_array.go index 7e750df8..22577023 100644 --- a/float4_array.go +++ b/float4_array.go @@ -192,7 +192,7 @@ func (dst Float4Array) Get() interface{} { func (src *Float4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/float8_array.go b/float8_array.go index 12520722..6c309700 100644 --- a/float8_array.go +++ b/float8_array.go @@ -192,7 +192,7 @@ func (dst Float8Array) Get() interface{} { func (src *Float8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/hstore_array.go b/hstore_array.go index d2ff2874..413e3993 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -173,7 +173,7 @@ func (dst HstoreArray) Get() interface{} { func (src *HstoreArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/inet_array.go b/inet_array.go index 7133fc0b..c4368ebc 100644 --- a/inet_array.go +++ b/inet_array.go @@ -212,7 +212,7 @@ func (dst InetArray) Get() interface{} { func (src *InetArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/int2_array.go b/int2_array.go index b64e0689..71ccc0c4 100644 --- a/int2_array.go +++ b/int2_array.go @@ -458,7 +458,7 @@ func (dst Int2Array) Get() interface{} { func (src *Int2Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/int4_array.go b/int4_array.go index 01613d39..09b23c2f 100644 --- a/int4_array.go +++ b/int4_array.go @@ -458,7 +458,7 @@ func (dst Int4Array) Get() interface{} { func (src *Int4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/int8_array.go b/int8_array.go index 0babbe43..93a902b0 100644 --- a/int8_array.go +++ b/int8_array.go @@ -458,7 +458,7 @@ func (dst Int8Array) Get() interface{} { func (src *Int8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/jsonb_array.go b/jsonb_array.go index 8f51b789..98970dcf 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -192,7 +192,7 @@ func (dst JSONBArray) Get() interface{} { func (src *JSONBArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/macaddr_array.go b/macaddr_array.go index 94a009fd..eafa5482 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -193,7 +193,7 @@ func (dst MacaddrArray) Get() interface{} { func (src *MacaddrArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/numeric_array.go b/numeric_array.go index 884e8b14..806557bc 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -306,7 +306,7 @@ func (dst NumericArray) Get() interface{} { func (src *NumericArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/text_array.go b/text_array.go index b2825b29..03f72d37 100644 --- a/text_array.go +++ b/text_array.go @@ -192,7 +192,7 @@ func (dst TextArray) Get() interface{} { func (src *TextArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/timestamp_array.go b/timestamp_array.go index 0bc30f17..27f6e867 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -193,7 +193,7 @@ func (dst TimestampArray) Get() interface{} { func (src *TimestampArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/timestamptz_array.go b/timestamptz_array.go index 313bde81..4db5c979 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -193,7 +193,7 @@ func (dst TimestamptzArray) Get() interface{} { func (src *TimestamptzArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/tstzrange_array.go b/tstzrange_array.go index 216182df..2c9492f4 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -154,7 +154,7 @@ func (dst TstzrangeArray) Get() interface{} { func (src *TstzrangeArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/typed_array.go.erb b/typed_array.go.erb index 809c7884..c4c797de 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -134,7 +134,7 @@ func (dst *<%= pgtype_array_type %>) setRecursive(value reflect.Value, index, di if len(dst.Dimensions) == dimension { break } - + valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") @@ -174,7 +174,7 @@ func (dst <%= pgtype_array_type %>) Get() interface{} { func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1{ + if len(src.Dimensions) <= 1{ // Attempt to match to select common types: switch v := dst.(type) { <% go_array_types.split(",").each do |t| %> @@ -189,7 +189,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { <% end %> } } - + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -211,7 +211,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { if elementCount != len(src.Elements) { return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } - + return nil case Null: return NullAssignTo(dst) diff --git a/uuid_array.go b/uuid_array.go index 47e348f3..035fb114 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -230,7 +230,7 @@ func (dst UUIDArray) Get() interface{} { func (src *UUIDArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/varchar_array.go b/varchar_array.go index e68614bb..95ab48f3 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -192,7 +192,7 @@ func (dst VarcharArray) Get() interface{} { func (src *VarcharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Dimensions) == 1 { + if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { From e7d2b057a716db954f25b3dc144eaa775f656eb7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Sep 2020 11:13:53 -0500 Subject: [PATCH 0572/1158] Text formatted values except bytea can be directly scanned to []byte This significantly improves performance of scanning text to []byte as it avoids multiple allocations and copies. --- pgtype.go | 6 +++++- pgtype_test.go | 8 ++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/pgtype.go b/pgtype.go index 5aa466d2..df5078a9 100644 --- a/pgtype.go +++ b/pgtype.go @@ -779,7 +779,7 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan } case *[]byte: switch oid { - case ByteaOID, TextOID, VarcharOID: + case ByteaOID, TextOID, VarcharOID, JSONOID: return scanPlanBinaryBytes{} } case BinaryDecoder: @@ -789,6 +789,10 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan switch dst.(type) { case *string: return scanPlanString{} + case *[]byte: + if oid != ByteaOID { + return scanPlanBinaryBytes{} + } case TextDecoder: return scanPlanDstTextDecoder{} } diff --git a/pgtype_test.go b/pgtype_test.go index 0c2bec83..32ce0a99 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -79,6 +79,14 @@ func TestConnInfoScanTextFormatInterfacePtr(t *testing.T) { assert.Equal(t, "foo", got) } +func TestConnInfoScanTextFormatNonByteaIntoByteSlice(t *testing.T) { + ci := pgtype.NewConnInfo() + var got []byte + err := ci.Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &got) + require.NoError(t, err) + assert.Equal(t, []byte("{}"), got) +} + func TestConnInfoScanBinaryFormatInterfacePtr(t *testing.T) { ci := pgtype.NewConnInfo() var got interface{} From fede0ce5d6582beb0dfd3785d5e10e00a438dd68 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Sep 2020 11:29:10 -0500 Subject: [PATCH 0573/1158] Document that received messages are only valid until the next receive. --- backend.go | 2 +- frontend.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend.go b/backend.go index cd7e8ce2..1f854c69 100644 --- a/backend.go +++ b/backend.go @@ -91,7 +91,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { } } -// Receive receives a message from the frontend. +// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive. func (b *Backend) Receive() (FrontendMessage, error) { if !b.partialMsg { header, err := b.cr.Next(5) diff --git a/frontend.go b/frontend.go index 3298d7e6..b8f545ca 100644 --- a/frontend.go +++ b/frontend.go @@ -65,7 +65,7 @@ func translateEOFtoErrUnexpectedEOF(err error) error { return err } -// Receive receives a message from the backend. +// Receive receives a message from the backend. The returned message is only valid until the next call to Receive. func (f *Frontend) Receive() (BackendMessage, error) { if !f.partialMsg { header, err := f.cr.Next(5) From 0d4f029683fc678cb3084b4dd714e1fde88856e3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Sep 2020 13:14:11 -0500 Subject: [PATCH 0574/1158] Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded Previously, it wasn't loaded until NextRow was called the first time. --- pgconn.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++++-- pgconn_test.go | 38 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 84 insertions(+), 4 deletions(-) diff --git a/pgconn.go b/pgconn.go index e2ab5c13..ff812069 100644 --- a/pgconn.go +++ b/pgconn.go @@ -84,6 +84,8 @@ type PgConn struct { bufferingReceiveMsg pgproto3.BackendMessage bufferingReceiveErr error + peekedMsg pgproto3.BackendMessage + // Reusable / preallocated resources wbuf []byte // write buffer resultReader ResultReader @@ -427,8 +429,12 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa return msg, err } -// receiveMessage receives a message without setting up context cancellation -func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { +// peekMessage peeks at the next message without setting up context cancellation. +func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { + if pgConn.peekedMsg != nil { + return pgConn.peekedMsg, nil + } + var msg pgproto3.BackendMessage var err error if pgConn.bufferingReceive { @@ -455,6 +461,23 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { return nil, err } + pgConn.peekedMsg = msg + return msg, nil +} + +// receiveMessage receives a message without setting up context cancellation +func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := pgConn.peekMessage() + if err != nil { + // Close on anything other than timeout error - everything else is fatal + if err, ok := err.(net.Error); !(ok && err.Timeout()) { + pgConn.asyncClose() + } + + return nil, err + } + pgConn.peekedMsg = nil + switch msg := msg.(type) { case *pgproto3.ReadyForQuery: pgConn.txStatus = msg.TxStatus @@ -1044,7 +1067,10 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() + return } + + result.readUntilRowDescription() } // CopyTo executes the copy command sql and copies the results to w. @@ -1454,6 +1480,26 @@ func (rr *ResultReader) Close() (CommandTag, error) { return rr.commandTag, rr.err } +// readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any +// error will be stored in the ResultReader. +func (rr *ResultReader) readUntilRowDescription() { + for !rr.commandConcluded { + // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. + // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are + // manually used to construct a query that does not issue a describe statement. + msg, _ := rr.pgConn.peekMessage() + if _, ok := msg.(*pgproto3.DataRow); ok { + return + } + + // Consume the message + msg, _ = rr.receiveMessage() + if _, ok := msg.(*pgproto3.RowDescription); ok { + return + } + } +} + func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { if rr.multiResultReader == nil { msg, err = rr.pgConn.receiveMessage() diff --git a/pgconn_test.go b/pgconn_test.go index f6750a60..24200e73 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -481,6 +481,34 @@ func TestConnExecMultipleQueries(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + mrr := pgConn.Exec(context.Background(), "select 'Hello, world' as msg; select 1 as num") + + require.True(t, mrr.NextResult()) + require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), mrr.ResultReader().FieldDescriptions()[0].Name) + _, err = mrr.ResultReader().Close() + require.NoError(t, err) + + require.True(t, mrr.NextResult()) + require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) + assert.Equal(t, []byte("num"), mrr.ResultReader().FieldDescriptions()[0].Name) + _, err = mrr.ResultReader().Close() + require.NoError(t, err) + + require.False(t, mrr.NextResult()) + + require.NoError(t, mrr.Close()) + + ensureConnValid(t, pgConn) +} + func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() @@ -578,7 +606,10 @@ func TestConnExecParams(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) - result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + rowCount := 0 for result.NextRow() { rowCount += 1 @@ -734,13 +765,16 @@ func TestConnExecPrepared(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text as msg", nil) require.NoError(t, err) require.NotNil(t, psd) assert.Len(t, psd.ParamOIDs, 1) assert.Len(t, psd.Fields, 1) result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + rowCount := 0 for result.NextRow() { rowCount += 1 From b6b3a8631050ce8a4398a68c4c165269c1a14450 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Sep 2020 13:26:56 -0500 Subject: [PATCH 0575/1158] Update CI Go versions --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 0371101f..95dce226 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: go go: + - 1.15.x - 1.14.x - - 1.13.x - tip git: From be69c1c10b10bcaeb5cb7d1e7b72022060c4222d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 10 Sep 2020 19:40:46 -0500 Subject: [PATCH 0576/1158] Fix parseDSNSettings with bad backslash fixes #49 --- config.go | 3 +++ config_test.go | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/config.go b/config.go index b2583546..b05727ca 100644 --- a/config.go +++ b/config.go @@ -506,6 +506,9 @@ func parseDSNSettings(s string) (map[string]string, error) { } if s[end] == '\\' { end++ + if end == len(s) { + return nil, errors.New("invalid backslash") + } } } val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) diff --git a/config_test.go b/config_test.go index 264eb299..d322f65a 100644 --- a/config_test.go +++ b/config_test.go @@ -539,6 +539,13 @@ func TestParseConfigDSNLeadingEqual(t *testing.T) { require.Error(t, err) } +// https://github.com/jackc/pgconn/issues/49 +func TestParseConfigDSNTrailingBackslash(t *testing.T) { + _, err := pgconn.ParseConfig(`x=x\`) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid backslash") +} + func TestConfigCopyReturnsEqualConfig(t *testing.T) { connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" original, err := pgconn.ParseConfig(connString) From d540ca39be4f4f307552b89256cb6afe998221ff Mon Sep 17 00:00:00 2001 From: bakmataliev Date: Fri, 11 Sep 2020 16:24:48 +0300 Subject: [PATCH 0577/1158] New marshalers have been added --- .gitignore | 1 + point.go | 78 ++++++++++++++++++++++++++++- point_test.go | 134 ++++++++++++++++++++++++++++++++++++++++++++++++++ uuid.go | 16 ++++++ 4 files changed, 228 insertions(+), 1 deletion(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..723ef36f --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea \ No newline at end of file diff --git a/point.go b/point.go index 87993656..9961f624 100644 --- a/point.go +++ b/point.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "math" + "regexp" "strconv" "strings" @@ -22,8 +23,62 @@ type Point struct { Status Status } +var nullRE = regexp.MustCompile("^null$") + func (dst *Point) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Point", src) + if src == nil { + dst.Status = Null + return nil + } + err := errors.Errorf("cannot convert %v to Point", src) + var p *Point + switch value := src.(type) { + case string: + p, err = parsePoint([]byte(value)) + case []byte: + if nullRE.Match(value) { + dst.Status = Null + return nil + } + p, err = parsePoint(value) + default: + return err + } + if err != nil { + return err + } + *dst = *p + return nil +} + +var pointRE = regexp.MustCompile("^\\(\\d+\\.\\d+,\\s?\\d+\\.\\d+\\)$") +var chunkRE = regexp.MustCompile("\\d+\\.\\d+") + +func parsePoint(p []byte) (*Point, error) { + err := errors.Errorf("cannot parse %s", p) + if pointRE.Match(p) { + chunks := chunkRE.FindAll(p, 2) + if len(chunks) != 2 { + return nil, err + } + x, xErr := strconv.ParseFloat(string(chunks[0]), 64) + y, yErr := strconv.ParseFloat(string(chunks[1]), 64) + if xErr != nil || yErr != nil { + return nil, err + } + return &Point{ + P: Vec2{ + X: x, + Y: y, + }, + Status: Present, + }, nil + } else if nullRE.Match(p) { + return &Point{ + Status: Null, + }, nil + } + return nil, err } func (dst Point) Get() interface{} { @@ -140,3 +195,24 @@ func (dst *Point) Scan(src interface{}) error { func (src Point) Value() (driver.Value, error) { return EncodeValueText(src) } + +func (src Point) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(fmt.Sprintf("(%g, %g)", src.P.X, src.P.Y)), nil + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + return nil, errBadStatus +} + +func (dst *Point) UnmarshalJSON(point []byte) error { + p, err := parsePoint(point) + if err != nil { + return err + } + *dst = *p + return nil +} diff --git a/point_test.go b/point_test.go index 0d191b5e..9a659cbc 100644 --- a/point_test.go +++ b/point_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "reflect" "testing" "github.com/jackc/pgtype" @@ -14,3 +15,136 @@ func TestPointTranscode(t *testing.T) { &pgtype.Point{Status: pgtype.Null}, }) } + +func TestPoint_Set(t *testing.T) { + tests := []struct { + name string + arg interface{} + status pgtype.Status + wantErr bool + }{ + { + name: "first", + arg: "(12312.123123, 123123.123123)", + status: pgtype.Present, + wantErr: false, + }, + { + name: "second", + arg: "(1231s2.123123, 123123.123123)", + status: pgtype.Undefined, + wantErr: true, + }, + { + name: "third", + arg: []byte("(122.123123,123.123123)"), + status: pgtype.Present, + wantErr: false, + }, + { + name: "third", + arg: nil, + status: pgtype.Null, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &pgtype.Point{} + if err := dst.Set(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) + } + if dst.Status != tt.status { + t.Errorf("Expected status: %v; got: %v", tt.status, dst.Status) + } + }) + } +} + +func TestPoint_MarshalJSON(t *testing.T) { + tests := []struct { + name string + point pgtype.Point + want []byte + wantErr bool + }{ + { + name: "first", + point: pgtype.Point{ + P: pgtype.Vec2{}, + Status: 0, + }, + want: nil, + wantErr: true, + }, + { + name: "second", + point: pgtype.Point{ + P: pgtype.Vec2{X: 12.245, Y: 432.12}, + Status: pgtype.Present, + }, + want: []byte("(12.245, 432.12)"), + wantErr: false, + }, + { + name: "third", + point: pgtype.Point{ + P: pgtype.Vec2{}, + Status: pgtype.Null, + }, + want: []byte("null"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.point.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPoint_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + status pgtype.Status + arg []byte + wantErr bool + }{ + { + name: "first", + status: pgtype.Present, + arg: []byte("(123.123, 54.12)"), + wantErr: false, + }, + { + name: "second", + status: pgtype.Undefined, + arg: []byte("(123.123, 54.1sad2)"), + wantErr: true, + }, + { + name: "third", + status: pgtype.Null, + arg: []byte("null"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &pgtype.Point{} + if err := dst.UnmarshalJSON(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if dst.Status != tt.status { + t.Errorf("Status mismatch: %v != %v", dst.Status, tt.status) + } + }) + } +} diff --git a/uuid.go b/uuid.go index 9f9bbefd..caaef2a7 100644 --- a/uuid.go +++ b/uuid.go @@ -203,3 +203,19 @@ func (dst *UUID) Scan(src interface{}) error { func (src UUID) Value() (driver.Value, error) { return EncodeValueText(src) } + +func (src UUID) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(encodeUUID(src.Bytes)), nil + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + return nil, errBadStatus +} + +func (dst *UUID) UnmarshalJSON(bytes []byte) error { + return dst.Set(bytes) +} From cd9b888ff6ea0d039677380d4a859e0c5b953cc4 Mon Sep 17 00:00:00 2001 From: bakmataliev Date: Fri, 11 Sep 2020 16:28:49 +0300 Subject: [PATCH 0578/1158] Remove unnecessary check for null --- point.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/point.go b/point.go index 9961f624..37187117 100644 --- a/point.go +++ b/point.go @@ -36,10 +36,6 @@ func (dst *Point) Set(src interface{}) error { case string: p, err = parsePoint([]byte(value)) case []byte: - if nullRE.Match(value) { - dst.Status = Null - return nil - } p, err = parsePoint(value) default: return err From 6777e0294b5de77ea2022318b75b2ce7336b63cc Mon Sep 17 00:00:00 2001 From: bakmataliev Date: Tue, 15 Sep 2020 13:24:17 +0300 Subject: [PATCH 0579/1158] eliminate regex dep --- point.go | 57 +++++++++++++++++++++++---------------------------- point_test.go | 10 ++++----- 2 files changed, 31 insertions(+), 36 deletions(-) diff --git a/point.go b/point.go index 37187117..55c6c8d1 100644 --- a/point.go +++ b/point.go @@ -1,11 +1,11 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/binary" "fmt" "math" - "regexp" "strconv" "strings" @@ -23,8 +23,6 @@ type Point struct { Status Status } -var nullRE = regexp.MustCompile("^null$") - func (dst *Point) Set(src interface{}) error { if src == nil { dst.Status = Null @@ -47,34 +45,31 @@ func (dst *Point) Set(src interface{}) error { return nil } -var pointRE = regexp.MustCompile("^\\(\\d+\\.\\d+,\\s?\\d+\\.\\d+\\)$") -var chunkRE = regexp.MustCompile("\\d+\\.\\d+") - -func parsePoint(p []byte) (*Point, error) { - err := errors.Errorf("cannot parse %s", p) - if pointRE.Match(p) { - chunks := chunkRE.FindAll(p, 2) - if len(chunks) != 2 { - return nil, err - } - x, xErr := strconv.ParseFloat(string(chunks[0]), 64) - y, yErr := strconv.ParseFloat(string(chunks[1]), 64) - if xErr != nil || yErr != nil { - return nil, err - } - return &Point{ - P: Vec2{ - X: x, - Y: y, - }, - Status: Present, - }, nil - } else if nullRE.Match(p) { - return &Point{ - Status: Null, - }, nil +func parsePoint(src []byte) (*Point, error) { + if src == nil || bytes.Compare(src, []byte("null")) == 0 { + return &Point{Status: Null}, nil } - return nil, err + + if len(src) < 5 { + return nil, errors.Errorf("invalid length for point: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return nil, errors.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return nil, err + } + + y, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return nil, err + } + + return &Point{P: Vec2{x, y}, Status: Present}, nil } func (dst Point) Get() interface{} { @@ -195,7 +190,7 @@ func (src Point) Value() (driver.Value, error) { func (src Point) MarshalJSON() ([]byte, error) { switch src.Status { case Present: - return []byte(fmt.Sprintf("(%g, %g)", src.P.X, src.P.Y)), nil + return []byte(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y)), nil case Null: return []byte("null"), nil case Undefined: diff --git a/point_test.go b/point_test.go index 9a659cbc..3601cf02 100644 --- a/point_test.go +++ b/point_test.go @@ -25,13 +25,13 @@ func TestPoint_Set(t *testing.T) { }{ { name: "first", - arg: "(12312.123123, 123123.123123)", + arg: "(12312.123123,123123.123123)", status: pgtype.Present, wantErr: false, }, { name: "second", - arg: "(1231s2.123123, 123123.123123)", + arg: "(1231s2.123123,123123.123123)", status: pgtype.Undefined, wantErr: true, }, @@ -83,7 +83,7 @@ func TestPoint_MarshalJSON(t *testing.T) { P: pgtype.Vec2{X: 12.245, Y: 432.12}, Status: pgtype.Present, }, - want: []byte("(12.245, 432.12)"), + want: []byte("(12.245,432.12)"), wantErr: false, }, { @@ -120,13 +120,13 @@ func TestPoint_UnmarshalJSON(t *testing.T) { { name: "first", status: pgtype.Present, - arg: []byte("(123.123, 54.12)"), + arg: []byte("(123.123,54.12)"), wantErr: false, }, { name: "second", status: pgtype.Undefined, - arg: []byte("(123.123, 54.1sad2)"), + arg: []byte("(123.123,54.1sad2)"), wantErr: true, }, { From fbe354aea17873cb129a792b32fe0717b0482935 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 15 Sep 2020 17:21:13 -0500 Subject: [PATCH 0580/1158] Remove editor specific .gitignore --- .gitignore | 1 - 1 file changed, 1 deletion(-) delete mode 100644 .gitignore diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 723ef36f..00000000 --- a/.gitignore +++ /dev/null @@ -1 +0,0 @@ -.idea \ No newline at end of file From 835cf1b0689d054f58d82f3970d59abae56178dc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 17 Sep 2020 17:03:30 -0500 Subject: [PATCH 0581/1158] Fix: Bind.MarshalJSON when ParameterFormatCodes is nil or single element refs #10 --- bind.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/bind.go b/bind.go index 65b4c1d6..52372095 100644 --- a/bind.go +++ b/bind.go @@ -151,7 +151,14 @@ func (src Bind) MarshalJSON() ([]byte, error) { continue } - if src.ParameterFormatCodes[i] == 0 { + textFormat := true + if len(src.ParameterFormatCodes) == 1 { + textFormat = src.ParameterFormatCodes[0] == 0 + } else if len(src.ParameterFormatCodes) > 1 { + textFormat = src.ParameterFormatCodes[i] == 0 + } + + if textFormat { formattedParameters[i] = map[string]string{"text": string(p)} } else { formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)} From d7f92427adf195b84fe5f21caefd2309519e07e9 Mon Sep 17 00:00:00 2001 From: Bekmamat Date: Sat, 19 Sep 2020 21:50:56 +0300 Subject: [PATCH 0582/1158] fixed marshaling and unmarshaling --- point.go | 10 ++++-- point_test.go | 24 ++++++------- uuid.go | 17 +++++++-- uuid_test.go | 98 +++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 132 insertions(+), 17 deletions(-) diff --git a/point.go b/point.go index 55c6c8d1..8e6bacf2 100644 --- a/point.go +++ b/point.go @@ -53,7 +53,9 @@ func parsePoint(src []byte) (*Point, error) { if len(src) < 5 { return nil, errors.Errorf("invalid length for point: %v", len(src)) } - + if src[0] == '"' && src[len(src)-1] == '"' { + src = src[1 : len(src)-1] + } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) if len(parts) < 2 { return nil, errors.Errorf("invalid format for point") @@ -190,7 +192,11 @@ func (src Point) Value() (driver.Value, error) { func (src Point) MarshalJSON() ([]byte, error) { switch src.Status { case Present: - return []byte(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y)), nil + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y)) + buff.WriteByte('"') + return buff.Bytes(), nil case Null: return []byte("null"), nil case Undefined: diff --git a/point_test.go b/point_test.go index 3601cf02..63f8df07 100644 --- a/point_test.go +++ b/point_test.go @@ -72,7 +72,7 @@ func TestPoint_MarshalJSON(t *testing.T) { name: "first", point: pgtype.Point{ P: pgtype.Vec2{}, - Status: 0, + Status: pgtype.Undefined, }, want: nil, wantErr: true, @@ -83,7 +83,7 @@ func TestPoint_MarshalJSON(t *testing.T) { P: pgtype.Vec2{X: 12.245, Y: 432.12}, Status: pgtype.Present, }, - want: []byte("(12.245,432.12)"), + want: []byte(`"(12.245,432.12)"`), wantErr: false, }, { @@ -113,26 +113,26 @@ func TestPoint_MarshalJSON(t *testing.T) { func TestPoint_UnmarshalJSON(t *testing.T) { tests := []struct { name string - status pgtype.Status + status pgtype.Status arg []byte wantErr bool }{ { - name: "first", - status: pgtype.Present, - arg: []byte("(123.123,54.12)"), + name: "first", + status: pgtype.Present, + arg: []byte(`"(123.123,54.12)"`), wantErr: false, }, { - name: "second", - status: pgtype.Undefined, - arg: []byte("(123.123,54.1sad2)"), + name: "second", + status: pgtype.Undefined, + arg: []byte(`"(123.123,54.1sad2)"`), wantErr: true, }, { - name: "third", - status: pgtype.Null, - arg: []byte("null"), + name: "third", + status: pgtype.Null, + arg: []byte("null"), wantErr: false, }, } diff --git a/uuid.go b/uuid.go index caaef2a7..b1681a78 100644 --- a/uuid.go +++ b/uuid.go @@ -1,6 +1,7 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/hex" "fmt" @@ -207,7 +208,11 @@ func (src UUID) Value() (driver.Value, error) { func (src UUID) MarshalJSON() ([]byte, error) { switch src.Status { case Present: - return []byte(encodeUUID(src.Bytes)), nil + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(encodeUUID(src.Bytes)) + buff.WriteByte('"') + return buff.Bytes(), nil case Null: return []byte("null"), nil case Undefined: @@ -216,6 +221,12 @@ func (src UUID) MarshalJSON() ([]byte, error) { return nil, errBadStatus } -func (dst *UUID) UnmarshalJSON(bytes []byte) error { - return dst.Set(bytes) +func (dst *UUID) UnmarshalJSON(src []byte) error { + if bytes.Compare(src, []byte("null")) == 0 { + return dst.Set(nil) + } + if len(src) != 38 { + return errors.Errorf("invalid length for UUID: %v", len(src)) + } + return dst.Set(string(src[1 : len(src)-1])) } diff --git a/uuid_test.go b/uuid_test.go index 9f7b19e2..8de5b9f6 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -2,6 +2,7 @@ package pgtype_test import ( "bytes" + "reflect" "testing" "github.com/jackc/pgtype" @@ -127,3 +128,100 @@ func TestUUIDAssignTo(t *testing.T) { } } + +func TestUUID_MarshalJSON(t *testing.T) { + tests := []struct { + name string + src pgtype.UUID + want []byte + wantErr bool + }{ + { + name: "first", + src: pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Status: pgtype.Present, + }, + want: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), + wantErr: false, + }, + { + name: "second", + src: pgtype.UUID{ + Bytes: [16]byte{}, + Status: pgtype.Undefined, + }, + want: nil, + wantErr: true, + }, + { + name: "third", + src: pgtype.UUID{ + Bytes: [16]byte{}, + Status: pgtype.Null, + }, + want: []byte("null"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.src.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUUID_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + want *pgtype.UUID + src []byte + wantErr bool + }{ + { + name: "first", + want: &pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Status: pgtype.Present, + }, + src: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), + wantErr: false, + }, + { + name: "second", + want: &pgtype.UUID{ + Bytes: [16]byte{}, + Status: pgtype.Null, + }, + src: []byte("null"), + wantErr: false, + }, + { + name: "third", + want: &pgtype.UUID{ + Bytes: [16]byte{}, + Status: pgtype.Undefined, + }, + src: []byte("1d485a7a-6d18-4599-8c6c-34425616887a"), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &pgtype.UUID{} + if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} From 28d24269e93ebc5aacc9271320226e6faae0c4dc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Sep 2020 11:35:23 -0500 Subject: [PATCH 0583/1158] Upgrade pgproto3 to v2.0.5 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index a74028c8..f2c10401 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.4 + github.com/jackc/pgproto3/v2 v2.0.5 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 diff --git a/go.sum b/go.sum index 61a2896b..08a11e19 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ github.com/jackc/pgproto3/v2 v2.0.3 h1:2S4PhE00mvdvaSiCYR1ZCmR1NAxeYfTSsqqSKxE1v github.com/jackc/pgproto3/v2 v2.0.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.4 h1:RHkX5ZUD9bl/kn0f9dYUWs1N7Nwvo1wwUYvKiR26Zco= github.com/jackc/pgproto3/v2 v2.0.4/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.0.5 h1:NUbEWPmCQZbMmYlTjVoNPhc0CfnYyz2bfUAh6A5ZVJM= +github.com/jackc/pgproto3/v2 v2.0.5/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= From 035868ca0c24b120f199e4bef6ac29a333e76baa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Sep 2020 11:39:23 -0500 Subject: [PATCH 0584/1158] Release v1.7.0 --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 497e00a1..e7444fcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ -# Unreleased +# 1.7.0 (September 26, 2020) +* Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded +* Add ReceiveResults (Sebastiaan Mannem) +* Fix parsing DSN connection with bad backslash * Add PgConn.CleanupDone so connection pools can determine when async close is complete # 1.6.4 (July 29, 2020) From 116eba440170191ad93916df893f6543a05324b3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Sep 2020 11:48:29 -0500 Subject: [PATCH 0585/1158] Release v1.5.0 --- CHANGELOG.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d117d239..774f0c1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,12 @@ +# 1.5.0 (September 26, 2020) + +* Add slice of slice mapping to multi-dimensional arrays (Simo Haasanen) +* Fix JSONBArray +* Fix selecting empty array +* Text formatted values except bytea can be directly scanned to []byte +* Add JSON marshalling for UUID (bakmataliev) +* Improve point type conversions (bakmataliev) + # 1.4.2 (July 22, 2020) * Fix encoding of a large composite data type (Yaz Saito) From 909d814f6574a3f4d666929a46da2b6f82acae01 Mon Sep 17 00:00:00 2001 From: lqu3j Date: Tue, 29 Sep 2020 13:10:38 +0800 Subject: [PATCH 0586/1158] support float64, float32 convert to int2, int4, int8 --- int2.go | 22 ++++++++++++++++++++++ int4.go | 22 ++++++++++++++++++++++ int8.go | 22 ++++++++++++++++++++++ 3 files changed, 66 insertions(+) diff --git a/int2.go b/int2.go index 67fa1acc..b7517881 100644 --- a/int2.go +++ b/int2.go @@ -85,6 +85,16 @@ func (dst *Int2) Set(src interface{}) error { return err } *dst = Int2{Int: int16(num), Status: Present} + case float32: + if value > math.MaxInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case float64: + if value > math.MaxInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} case *int8: if value == nil { *dst = Int2{Status: Null} @@ -151,6 +161,18 @@ func (dst *Int2) Set(src interface{}) error { } else { return dst.Set(*value) } + case *float32: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *float64: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) diff --git a/int4.go b/int4.go index c4ed6103..66652bbe 100644 --- a/int4.go +++ b/int4.go @@ -77,6 +77,16 @@ func (dst *Int4) Set(src interface{}) error { return err } *dst = Int4{Int: int32(num), Status: Present} + case float32: + if value > math.MaxInt32 { + return errors.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} + case float64: + if value > math.MaxInt32 { + return errors.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} case *int8: if value == nil { *dst = Int4{Status: Null} @@ -143,6 +153,18 @@ func (dst *Int4) Set(src interface{}) error { } else { return dst.Set(*value) } + case *float32: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *float64: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) diff --git a/int8.go b/int8.go index 445fef0d..f0114194 100644 --- a/int8.go +++ b/int8.go @@ -68,6 +68,16 @@ func (dst *Int8) Set(src interface{}) error { return err } *dst = Int8{Int: num, Status: Present} + case float32: + if value > math.MaxInt64 { + return errors.Errorf("%d is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Status: Present} + case float64: + if value > math.MaxInt64 { + return errors.Errorf("%d is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Status: Present} case *int8: if value == nil { *dst = Int8{Status: Null} @@ -134,6 +144,18 @@ func (dst *Int8) Set(src interface{}) error { } else { return dst.Set(*value) } + case *float32: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *float64: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) From 376361f53ddd86ad4381ee4e5d5a802fa33c3de7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Oct 2020 08:36:40 -0500 Subject: [PATCH 0587/1158] Add tests for Int(2|4|8).Set accepting float(32|64) --- int2_test.go | 2 ++ int4_test.go | 2 ++ int8_test.go | 2 ++ 3 files changed, 6 insertions(+) diff --git a/int2_test.go b/int2_test.go index cf8acd30..178eb278 100644 --- a/int2_test.go +++ b/int2_test.go @@ -37,6 +37,8 @@ func TestInt2Set(t *testing.T) { {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: float32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, } diff --git a/int4_test.go b/int4_test.go index c679de74..ae01114f 100644 --- a/int4_test.go +++ b/int4_test.go @@ -37,6 +37,8 @@ func TestInt4Set(t *testing.T) { {source: uint16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, {source: uint32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, {source: uint64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: float32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, {source: _int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, } diff --git a/int8_test.go b/int8_test.go index fb6f581b..4e28e374 100644 --- a/int8_test.go +++ b/int8_test.go @@ -37,6 +37,8 @@ func TestInt8Set(t *testing.T) { {source: uint16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, {source: uint32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, {source: uint64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: float32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, {source: _int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, } From 416f037e777022678e4138e780381b8f5b58364f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 5 Oct 2020 19:39:05 -0500 Subject: [PATCH 0588/1158] Fix docs for Timeout --- errors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/errors.go b/errors.go index b746c825..164b0848 100644 --- a/errors.go +++ b/errors.go @@ -20,7 +20,7 @@ func SafeToRetry(err error) bool { } // Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a -// context.Canceled, context.Canceled or an implementer of net.Error where Timeout() is true. +// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. func Timeout(err error) bool { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return true From 66c36ff24fdbb4a032ff317393441f370a6ab385 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Agsj=C3=B6?= Date: Thu, 8 Oct 2020 09:38:42 +0200 Subject: [PATCH 0589/1158] Support setting infinite timestamps --- timestamp.go | 2 ++ timestamp_test.go | 2 ++ timestamptz.go | 2 ++ timestamptz_test.go | 2 ++ 4 files changed, 8 insertions(+) diff --git a/timestamp.go b/timestamp.go index 88cb7672..0e127695 100644 --- a/timestamp.go +++ b/timestamp.go @@ -46,6 +46,8 @@ func (dst *Timestamp) Set(src interface{}) error { } else { return dst.Set(*value) } + case InfinityModifier: + *dst = Timestamp{InfinityModifier: value, Status: Present} default: if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) diff --git a/timestamp_test.go b/timestamp_test.go index 2fdc7171..b2fbda94 100644 --- a/timestamp_test.go +++ b/timestamp_test.go @@ -92,6 +92,8 @@ func TestTimestampSet(t *testing.T) { {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: pgtype.Infinity, result: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, + {source: pgtype.NegativeInfinity, result: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, } for i, tt := range successfulTests { diff --git a/timestamptz.go b/timestamptz.go index 25ea659d..d54974af 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -48,6 +48,8 @@ func (dst *Timestamptz) Set(src interface{}) error { } else { return dst.Set(*value) } + case InfinityModifier: + *dst = Timestamptz{InfinityModifier: value, Status: Present} default: if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) diff --git a/timestamptz_test.go b/timestamptz_test.go index a088fc08..828184b7 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -91,6 +91,8 @@ func TestTimestamptzSet(t *testing.T) { {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), Status: pgtype.Present}}, {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: pgtype.Infinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, + {source: pgtype.NegativeInfinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, } for i, tt := range successfulTests { From 2dca42ee7d3456ee38ba6a101de0661a0b4ee663 Mon Sep 17 00:00:00 2001 From: duohedron Date: Tue, 6 Oct 2020 08:41:57 +0200 Subject: [PATCH 0590/1158] Add Set(string|[]Vec2|[]float64) to Polygon --- polygon.go | 45 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/polygon.go b/polygon.go index 653b04c1..99a3c45a 100644 --- a/polygon.go +++ b/polygon.go @@ -18,7 +18,50 @@ type Polygon struct { } func (dst *Polygon) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Polygon", src) + if src == nil { + dst.Status = Null + return nil + } + err := errors.Errorf("cannot convert %v to Polygon", src) + var p *Polygon + switch value := src.(type) { + case string: + p, err = parseString(value) + case []Vec2: + p = &Polygon{Status: Present, P: value} + err = nil + case []float64: + p, err = parseFloat64(value) + default: + return err + } + if err != nil { + return err + } + *dst = *p + return nil +} + +func parseString(src string) (*Polygon, error) { + p := &Polygon{} + err := p.DecodeText(nil, []byte(src)) + return p, err +} + +func parseFloat64(src []float64) (*Polygon, error) { + p := &Polygon{Status: Null} + if len(src) == 0 { + return p, nil + } + if len(src)%2 != 0 { + return p, errors.Errorf("invalid length for polygon: %v", len(src)) + } + p.Status = Present + p.P = make([]Vec2, 0) + for i := 0; i < len(src); i += 2 { + p.P = append(p.P, Vec2{X: src[i], Y: src[i+1]}) + } + return p, nil } func (dst Polygon) Get() interface{} { From e09987f1d687b49de408699139f9bb9d263fd884 Mon Sep 17 00:00:00 2001 From: duohedron Date: Tue, 6 Oct 2020 08:43:41 +0200 Subject: [PATCH 0591/1158] Add tests to Polygon --- polygon_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/polygon_test.go b/polygon_test.go index f8b02ca2..7bda2628 100644 --- a/polygon_test.go +++ b/polygon_test.go @@ -20,3 +20,45 @@ func TestPolygonTranscode(t *testing.T) { &pgtype.Polygon{Status: pgtype.Null}, }) } + +func TestPolygon_Set(t *testing.T) { + tests := []struct { + name string + arg interface{} + status pgtype.Status + wantErr bool + }{ + { + name: "string", + arg: "((3.14,1.678901234),(7.1,5.234),(5.0,3.234))", + status: pgtype.Present, + wantErr: false, + }, { + name: "[]float64", + arg: []float64{1, 2, 3.45, 6.78, 1.23, 4.567, 8.9, 1.0}, + status: pgtype.Present, + wantErr: false, + }, { + name: "[]Vec2", + arg: []pgtype.Vec2{{1, 2}, {2.3, 4.5}, {6.78, 9.123}}, + status: pgtype.Present, + wantErr: false, + }, { + name: "null", + arg: nil, + status: pgtype.Null, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &pgtype.Polygon{} + if err := dst.Set(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) + } + if dst.Status != tt.status { + t.Errorf("Expected status: %v; got: %v", tt.status, dst.Status) + } + }) + } +} From 6166c99b7719a110092800554b1efc73c1f46539 Mon Sep 17 00:00:00 2001 From: duohedron Date: Tue, 6 Oct 2020 09:05:55 +0200 Subject: [PATCH 0592/1158] Add Undefined status to invalid Polygon --- polygon.go | 1 + 1 file changed, 1 insertion(+) diff --git a/polygon.go b/polygon.go index 99a3c45a..5c7f564d 100644 --- a/polygon.go +++ b/polygon.go @@ -54,6 +54,7 @@ func parseFloat64(src []float64) (*Polygon, error) { return p, nil } if len(src)%2 != 0 { + p.Status = Undefined return p, errors.Errorf("invalid length for polygon: %v", len(src)) } p.Status = Present From 8aa7211df5f42e6aee0eb2db27e5eeb579147b74 Mon Sep 17 00:00:00 2001 From: duohedron Date: Tue, 6 Oct 2020 09:06:38 +0200 Subject: [PATCH 0593/1158] Add tests to Polygon --- polygon_test.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/polygon_test.go b/polygon_test.go index 7bda2628..1a139444 100644 --- a/polygon_test.go +++ b/polygon_test.go @@ -48,6 +48,31 @@ func TestPolygon_Set(t *testing.T) { arg: nil, status: pgtype.Null, wantErr: false, + }, { + name: "invalid_string_1", + arg: "((3.14,1.678901234),(7.1,5.234),(5.0,3.234x))", + status: pgtype.Undefined, + wantErr: true, + }, { + name: "invalid_string_2", + arg: "(3,4)", + status: pgtype.Undefined, + wantErr: true, + }, { + name: "invalid_[]float64", + arg: []float64{1, 2, 3.45, 6.78, 1.23, 4.567, 8.9}, + status: pgtype.Undefined, + wantErr: true, + }, { + name: "invalid_type", + arg: []int{1, 2, 3, 6}, + status: pgtype.Undefined, + wantErr: true, + }, { + name: "empty_[]float64", + arg: []float64{}, + status: pgtype.Null, + wantErr: false, }, } for _, tt := range tests { From b55f972f49ebc1a7e757c0e12e445839fc1277c1 Mon Sep 17 00:00:00 2001 From: duohedron Date: Tue, 6 Oct 2020 09:32:08 +0200 Subject: [PATCH 0594/1158] Add comment to Polygon.Set() --- polygon.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/polygon.go b/polygon.go index 5c7f564d..a99051d0 100644 --- a/polygon.go +++ b/polygon.go @@ -17,6 +17,12 @@ type Polygon struct { Status Status } +// Set converts src to dest. +// +// src can be nil, string, []float64, and []pgtype.Vec2. +// +// If src is string the format must be ((x1,y1),(x2,y2),...,(xn,yn)). +// Important that there are no spaces in it. func (dst *Polygon) Set(src interface{}) error { if src == nil { dst.Status = Null From 2bc8c67e4a62c892b203013d01b59110f648ee4b Mon Sep 17 00:00:00 2001 From: duohedron Date: Thu, 8 Oct 2020 14:31:02 +0200 Subject: [PATCH 0595/1158] Fix misleading names parseString and parseFloat64 in polygon.go --- polygon.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/polygon.go b/polygon.go index a99051d0..5124af7f 100644 --- a/polygon.go +++ b/polygon.go @@ -32,12 +32,12 @@ func (dst *Polygon) Set(src interface{}) error { var p *Polygon switch value := src.(type) { case string: - p, err = parseString(value) + p, err = stringToPolygon(value) case []Vec2: p = &Polygon{Status: Present, P: value} err = nil case []float64: - p, err = parseFloat64(value) + p, err = float64ToPolygon(value) default: return err } @@ -48,13 +48,13 @@ func (dst *Polygon) Set(src interface{}) error { return nil } -func parseString(src string) (*Polygon, error) { +func stringToPolygon(src string) (*Polygon, error) { p := &Polygon{} err := p.DecodeText(nil, []byte(src)) return p, err } -func parseFloat64(src []float64) (*Polygon, error) { +func float64ToPolygon(src []float64) (*Polygon, error) { p := &Polygon{Status: Null} if len(src) == 0 { return p, nil From e92478ec70e12ec4c64ca707d32dc41b164ee6b8 Mon Sep 17 00:00:00 2001 From: Tomas Volf Date: Tue, 13 Oct 2020 15:26:09 +0200 Subject: [PATCH 0596/1158] Fix Inet.Set to handle nil net.IP correctly When nil IP is returned from net.ParseIP, it is accepted into Inet type, but not properly marked as being Null. That introduces issues later on when calling for example EncodeBinary, since it does not assume this can happen. This commit resolves that by properly detecting zero-length net.IP and setting status to Null if that is the case. --- inet.go | 10 +++++++--- inet_test.go | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/inet.go b/inet.go index f3dce87b..b4498191 100644 --- a/inet.go +++ b/inet.go @@ -38,9 +38,13 @@ func (dst *Inet) Set(src interface{}) error { case net.IPNet: *dst = Inet{IPNet: &value, Status: Present} case net.IP: - bitCount := len(value) * 8 - mask := net.CIDRMask(bitCount, bitCount) - *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Status: Present} + if len(value) == 0 { + *dst = Inet{Status: Null} + } else { + bitCount := len(value) * 8 + mask := net.CIDRMask(bitCount, bitCount) + *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Status: Present} + } case string: _, ipnet, err := net.ParseCIDR(value) if err != nil { diff --git a/inet_test.go b/inet_test.go index 8257a63d..cb420a51 100644 --- a/inet_test.go +++ b/inet_test.go @@ -35,6 +35,7 @@ func TestInetSet(t *testing.T) { {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: net.ParseIP(""), result: pgtype.Inet{Status: pgtype.Null}}, } for i, tt := range successfulTests { From 9639a69d451f55456f598c1aa8b93053f8df3088 Mon Sep 17 00:00:00 2001 From: Simo Haasanen Date: Tue, 20 Oct 2020 19:52:05 +0100 Subject: [PATCH 0597/1158] Adds checks for zero length arrays. Assigning values from nil or zero length elements or dimensions now return immediately as there are no values to assign. --- aclitem_array.go | 4 ++++ aclitem_array_test.go | 5 +++++ bool_array.go | 4 ++++ bool_array_test.go | 5 +++++ bpchar_array.go | 4 ++++ bytea_array.go | 4 ++++ bytea_array_test.go | 5 +++++ cidr_array.go | 4 ++++ cidr_array_test.go | 10 ++++++++++ date_array.go | 4 ++++ date_array_test.go | 5 +++++ enum_array.go | 4 ++++ enum_array_test.go | 5 +++++ float4_array.go | 4 ++++ float4_array_test.go | 5 +++++ float8_array.go | 4 ++++ float8_array_test.go | 5 +++++ hstore_array.go | 4 ++++ hstore_array_test.go | 3 +++ inet_array.go | 4 ++++ inet_array_test.go | 10 ++++++++++ int2_array.go | 4 ++++ int2_array_test.go | 5 +++++ int4_array.go | 4 ++++ int4_array_test.go | 5 +++++ int8_array.go | 4 ++++ int8_array_test.go | 5 +++++ jsonb_array.go | 4 ++++ macaddr_array.go | 4 ++++ macaddr_array_test.go | 5 +++++ numeric_array.go | 4 ++++ numeric_array_test.go | 5 +++++ text_array.go | 4 ++++ text_array_test.go | 5 +++++ timestamp_array.go | 4 ++++ timestamp_array_test.go | 5 +++++ timestamptz_array.go | 4 ++++ timestamptz_array_test.go | 5 +++++ tstzrange_array.go | 4 ++++ typed_array.go.erb | 4 ++++ uuid_array.go | 4 ++++ uuid_array_test.go | 11 +++++++++++ varchar_array.go | 4 ++++ varchar_array_test.go | 5 +++++ 44 files changed, 210 insertions(+) diff --git a/aclitem_array.go b/aclitem_array.go index 673951a6..95f74aa7 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -190,6 +190,10 @@ func (dst ACLItemArray) Get() interface{} { func (src *ACLItemArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/aclitem_array_test.go b/aclitem_array_test.go index 73e9ce71..8ebb8c8d 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -189,6 +189,11 @@ func TestACLItemArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.ACLItemArray{Status: pgtype.Present}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, { src: pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ diff --git a/bool_array.go b/bool_array.go index ed411a28..c9447db1 100644 --- a/bool_array.go +++ b/bool_array.go @@ -192,6 +192,10 @@ func (dst BoolArray) Get() interface{} { func (src *BoolArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/bool_array_test.go b/bool_array_test.go index 7f31e252..4ff26bb5 100644 --- a/bool_array_test.go +++ b/bool_array_test.go @@ -168,6 +168,11 @@ func TestBoolArrayAssignTo(t *testing.T) { dst: &boolSlice, expected: (([]bool)(nil)), }, + { + src: pgtype.BoolArray{Status: pgtype.Present}, + dst: &boolSlice, + expected: (([]bool)(nil)), + }, { src: pgtype.BoolArray{ Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, diff --git a/bpchar_array.go b/bpchar_array.go index 0b92ba30..f814930f 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -192,6 +192,10 @@ func (dst BPCharArray) Get() interface{} { func (src *BPCharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/bytea_array.go b/bytea_array.go index f80980bd..618a2f4b 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -173,6 +173,10 @@ func (dst ByteaArray) Get() interface{} { func (src *ByteaArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/bytea_array_test.go b/bytea_array_test.go index f40005a2..4964771c 100644 --- a/bytea_array_test.go +++ b/bytea_array_test.go @@ -157,6 +157,11 @@ func TestByteaArrayAssignTo(t *testing.T) { dst: &byteByteSlice, expected: (([][]byte)(nil)), }, + { + src: pgtype.ByteaArray{Status: pgtype.Present}, + dst: &byteByteSlice, + expected: (([][]byte)(nil)), + }, { src: pgtype.ByteaArray{ Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, diff --git a/cidr_array.go b/cidr_array.go index 0b902cca..8ea7d7a6 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -212,6 +212,10 @@ func (dst CIDRArray) Get() interface{} { func (src *CIDRArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/cidr_array_test.go b/cidr_array_test.go index b1769c38..aa933b62 100644 --- a/cidr_array_test.go +++ b/cidr_array_test.go @@ -217,11 +217,21 @@ func TestCIDRArrayAssignTo(t *testing.T) { dst: &ipnetSlice, expected: (([]*net.IPNet)(nil)), }, + { + src: pgtype.CIDRArray{Status: pgtype.Present}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, { src: pgtype.CIDRArray{Status: pgtype.Null}, dst: &ipSlice, expected: (([]net.IP)(nil)), }, + { + src: pgtype.CIDRArray{Status: pgtype.Present}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, { src: pgtype.CIDRArray{ Elements: []pgtype.CIDR{ diff --git a/date_array.go b/date_array.go index b306589e..dc4cb2e3 100644 --- a/date_array.go +++ b/date_array.go @@ -193,6 +193,10 @@ func (dst DateArray) Get() interface{} { func (src *DateArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/date_array_test.go b/date_array_test.go index 089c7dd4..8791c31f 100644 --- a/date_array_test.go +++ b/date_array_test.go @@ -182,6 +182,11 @@ func TestDateArrayAssignTo(t *testing.T) { dst: &timeSlice, expected: (([]time.Time)(nil)), }, + { + src: pgtype.DateArray{Status: pgtype.Present}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, { src: pgtype.DateArray{ Elements: []pgtype.Date{ diff --git a/enum_array.go b/enum_array.go index 4b6d2af4..f5312a04 100644 --- a/enum_array.go +++ b/enum_array.go @@ -190,6 +190,10 @@ func (dst EnumArray) Get() interface{} { func (src *EnumArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/enum_array_test.go b/enum_array_test.go index 91a81ab6..9db8b49f 100644 --- a/enum_array_test.go +++ b/enum_array_test.go @@ -167,6 +167,11 @@ func TestEnumArrayArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.EnumArray{Status: pgtype.Present}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, { src: pgtype.EnumArray{ Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, diff --git a/float4_array.go b/float4_array.go index 22577023..88dd84ab 100644 --- a/float4_array.go +++ b/float4_array.go @@ -192,6 +192,10 @@ func (dst Float4Array) Get() interface{} { func (src *Float4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/float4_array_test.go b/float4_array_test.go index 23a94ee8..88d35fd6 100644 --- a/float4_array_test.go +++ b/float4_array_test.go @@ -167,6 +167,11 @@ func TestFloat4ArrayAssignTo(t *testing.T) { dst: &float32Slice, expected: (([]float32)(nil)), }, + { + src: pgtype.Float4Array{Status: pgtype.Present}, + dst: &float32Slice, + expected: (([]float32)(nil)), + }, { src: pgtype.Float4Array{ Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, diff --git a/float8_array.go b/float8_array.go index 6c309700..9d79a449 100644 --- a/float8_array.go +++ b/float8_array.go @@ -192,6 +192,10 @@ func (dst Float8Array) Get() interface{} { func (src *Float8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/float8_array_test.go b/float8_array_test.go index 052ab3f3..d7bf6ac3 100644 --- a/float8_array_test.go +++ b/float8_array_test.go @@ -143,6 +143,11 @@ func TestFloat8ArrayAssignTo(t *testing.T) { dst: &float64Slice, expected: (([]float64)(nil)), }, + { + src: pgtype.Float8Array{Status: pgtype.Present}, + dst: &float64Slice, + expected: (([]float64)(nil)), + }, { src: pgtype.Float8Array{ Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, diff --git a/hstore_array.go b/hstore_array.go index 413e3993..d0b34b3c 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -173,6 +173,10 @@ func (dst HstoreArray) Get() interface{} { func (src *HstoreArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/hstore_array_test.go b/hstore_array_test.go index fac66b4a..3d85545a 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -302,6 +302,9 @@ func TestHstoreArrayAssignTo(t *testing.T) { { src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), }, + { + src: pgtype.HstoreArray{Status: pgtype.Present}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), + }, { src: pgtype.HstoreArray{ Elements: []pgtype.Hstore{ diff --git a/inet_array.go b/inet_array.go index c4368ebc..2058db81 100644 --- a/inet_array.go +++ b/inet_array.go @@ -212,6 +212,10 @@ func (dst InetArray) Get() interface{} { func (src *InetArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/inet_array_test.go b/inet_array_test.go index d78b91c0..5beab960 100644 --- a/inet_array_test.go +++ b/inet_array_test.go @@ -217,11 +217,21 @@ func TestInetArrayAssignTo(t *testing.T) { dst: &ipnetSlice, expected: (([]*net.IPNet)(nil)), }, + { + src: pgtype.InetArray{Status: pgtype.Present}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, { src: pgtype.InetArray{Status: pgtype.Null}, dst: &ipSlice, expected: (([]net.IP)(nil)), }, + { + src: pgtype.InetArray{Status: pgtype.Present}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, { src: pgtype.InetArray{ Elements: []pgtype.Inet{ diff --git a/int2_array.go b/int2_array.go index 71ccc0c4..bf6a6284 100644 --- a/int2_array.go +++ b/int2_array.go @@ -458,6 +458,10 @@ func (dst Int2Array) Get() interface{} { func (src *Int2Array) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/int2_array_test.go b/int2_array_test.go index dfe84c19..da669f7d 100644 --- a/int2_array_test.go +++ b/int2_array_test.go @@ -219,6 +219,11 @@ func TestInt2ArrayAssignTo(t *testing.T) { dst: &int16Slice, expected: (([]int16)(nil)), }, + { + src: pgtype.Int2Array{Status: pgtype.Present}, + dst: &int16Slice, + expected: (([]int16)(nil)), + }, { src: pgtype.Int2Array{ Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, diff --git a/int4_array.go b/int4_array.go index 09b23c2f..05e10dc3 100644 --- a/int4_array.go +++ b/int4_array.go @@ -458,6 +458,10 @@ func (dst Int4Array) Get() interface{} { func (src *Int4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/int4_array_test.go b/int4_array_test.go index 35b791d3..a5aad827 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -233,6 +233,11 @@ func TestInt4ArrayAssignTo(t *testing.T) { dst: &int32Slice, expected: (([]int32)(nil)), }, + { + src: pgtype.Int4Array{Status: pgtype.Present}, + dst: &int32Slice, + expected: (([]int32)(nil)), + }, { src: pgtype.Int4Array{ Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, diff --git a/int8_array.go b/int8_array.go index 93a902b0..d149558f 100644 --- a/int8_array.go +++ b/int8_array.go @@ -458,6 +458,10 @@ func (dst Int8Array) Get() interface{} { func (src *Int8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/int8_array_test.go b/int8_array_test.go index d65b875a..b0ee97ee 100644 --- a/int8_array_test.go +++ b/int8_array_test.go @@ -226,6 +226,11 @@ func TestInt8ArrayAssignTo(t *testing.T) { dst: &int64Slice, expected: (([]int64)(nil)), }, + { + src: pgtype.Int8Array{Status: pgtype.Present}, + dst: &int64Slice, + expected: (([]int64)(nil)), + }, { src: pgtype.Int8Array{ Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, diff --git a/jsonb_array.go b/jsonb_array.go index 98970dcf..36411b9d 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -192,6 +192,10 @@ func (dst JSONBArray) Get() interface{} { func (src *JSONBArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/macaddr_array.go b/macaddr_array.go index eafa5482..2ec5971e 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -193,6 +193,10 @@ func (dst MacaddrArray) Get() interface{} { func (src *MacaddrArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/macaddr_array_test.go b/macaddr_array_test.go index 647db8cf..6359a374 100644 --- a/macaddr_array_test.go +++ b/macaddr_array_test.go @@ -166,6 +166,11 @@ func TestMacaddrArrayAssignTo(t *testing.T) { dst: &macaddrSlice, expected: (([]net.HardwareAddr)(nil)), }, + { + src: pgtype.MacaddrArray{Status: pgtype.Present}, + dst: &macaddrSlice, + expected: (([]net.HardwareAddr)(nil)), + }, { src: pgtype.MacaddrArray{ Elements: []pgtype.Macaddr{ diff --git a/numeric_array.go b/numeric_array.go index 806557bc..7c044c8c 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -306,6 +306,10 @@ func (dst NumericArray) Get() interface{} { func (src *NumericArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/numeric_array_test.go b/numeric_array_test.go index 29300bf0..def8150d 100644 --- a/numeric_array_test.go +++ b/numeric_array_test.go @@ -190,6 +190,11 @@ func TestNumericArrayAssignTo(t *testing.T) { dst: &float32Slice, expected: (([]float32)(nil)), }, + { + src: pgtype.NumericArray{Status: pgtype.Present}, + dst: &float32Slice, + expected: (([]float32)(nil)), + }, { src: pgtype.NumericArray{ Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, diff --git a/text_array.go b/text_array.go index 03f72d37..01b5e6e6 100644 --- a/text_array.go +++ b/text_array.go @@ -192,6 +192,10 @@ func (dst TextArray) Get() interface{} { func (src *TextArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/text_array_test.go b/text_array_test.go index 125d6034..a538c617 100644 --- a/text_array_test.go +++ b/text_array_test.go @@ -168,6 +168,11 @@ func TestTextArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.TextArray{Status: pgtype.Present}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, { src: pgtype.TextArray{ Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, diff --git a/timestamp_array.go b/timestamp_array.go index 27f6e867..ee6037b0 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -193,6 +193,10 @@ func (dst TimestampArray) Get() interface{} { func (src *TimestampArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/timestamp_array_test.go b/timestamp_array_test.go index c6f32d20..85db94bb 100644 --- a/timestamp_array_test.go +++ b/timestamp_array_test.go @@ -162,6 +162,11 @@ func TestTimestampArrayAssignTo(t *testing.T) { dst: &timeSlice, expected: (([]time.Time)(nil)), }, + { + src: pgtype.TimestampArray{Status: pgtype.Present}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, { src: pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ diff --git a/timestamptz_array.go b/timestamptz_array.go index 4db5c979..327b3ebc 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -193,6 +193,10 @@ func (dst TimestamptzArray) Get() interface{} { func (src *TimestamptzArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/timestamptz_array_test.go b/timestamptz_array_test.go index f4e80413..a4e1dded 100644 --- a/timestamptz_array_test.go +++ b/timestamptz_array_test.go @@ -198,6 +198,11 @@ func TestTimestamptzArrayAssignTo(t *testing.T) { dst: &timeSlice, expected: (([]time.Time)(nil)), }, + { + src: pgtype.TimestamptzArray{Status: pgtype.Present}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, { src: pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ diff --git a/tstzrange_array.go b/tstzrange_array.go index 2c9492f4..cac377af 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -154,6 +154,10 @@ func (dst TstzrangeArray) Get() interface{} { func (src *TstzrangeArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/typed_array.go.erb b/typed_array.go.erb index c4c797de..6d34b0e1 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -174,6 +174,10 @@ func (dst <%= pgtype_array_type %>) Get() interface{} { func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1{ // Attempt to match to select common types: switch v := dst.(type) { diff --git a/uuid_array.go b/uuid_array.go index 035fb114..33f2e62c 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -230,6 +230,10 @@ func (dst UUIDArray) Get() interface{} { func (src *UUIDArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/uuid_array_test.go b/uuid_array_test.go index cdb212bb..a1e14a04 100644 --- a/uuid_array_test.go +++ b/uuid_array_test.go @@ -214,6 +214,7 @@ func TestUUIDArrayAssignTo(t *testing.T) { var byteArraySlice [][16]byte var byteSliceSlice [][]byte var stringSlice []string + var byteSlice []byte var byteArraySliceDim2 [][][16]byte var stringSliceDim4 [][][][]string var byteArrayDim2 [2][1][16]byte @@ -252,6 +253,16 @@ func TestUUIDArrayAssignTo(t *testing.T) { dst: &byteSliceSlice, expected: ([][]byte)(nil), }, + { + src: pgtype.UUIDArray{Status: pgtype.Present}, + dst: &byteSlice, + expected: ([]byte)(nil), + }, + { + src: pgtype.UUIDArray{Status: pgtype.Present}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, { src: pgtype.UUIDArray{ Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, diff --git a/varchar_array.go b/varchar_array.go index 95ab48f3..b3b030b8 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -192,6 +192,10 @@ func (dst VarcharArray) Get() interface{} { func (src *VarcharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: + if len(src.Elements) == 0 || len(src.Dimensions) == 0 { + // No values to assign + return nil + } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { diff --git a/varchar_array_test.go b/varchar_array_test.go index 3b0e65ed..ca9a15b7 100644 --- a/varchar_array_test.go +++ b/varchar_array_test.go @@ -168,6 +168,11 @@ func TestVarcharArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.VarcharArray{Status: pgtype.Present}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, { src: pgtype.VarcharArray{ Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, From 9d7fc8e63aa911ad288a838222a1fc2b359a3426 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Oct 2020 09:21:42 -0500 Subject: [PATCH 0598/1158] AssignTo pointer to pointer to slice and named types fixes #69 --- aclitem_array.go | 11 +++++------ bool_array.go | 11 +++++------ bpchar_array.go | 11 +++++------ bytea_array.go | 11 +++++------ cidr_array.go | 11 +++++------ date_array.go | 11 +++++------ enum_array.go | 11 +++++------ float4_array.go | 11 +++++------ float8_array.go | 11 +++++------ hstore_array.go | 11 +++++------ inet_array.go | 11 +++++------ int2_array.go | 11 +++++------ int4_array.go | 11 +++++------ int8_array.go | 11 +++++------ jsonb_array.go | 11 +++++------ macaddr_array.go | 11 +++++------ numeric_array.go | 11 +++++------ text_array.go | 11 +++++------ timestamp_array.go | 11 +++++------ timestamptz_array.go | 11 +++++------ tstzrange_array.go | 11 +++++------ typed_array.go.erb | 11 +++++------ uuid_array.go | 11 +++++------ varchar_array.go | 11 +++++------ 24 files changed, 120 insertions(+), 144 deletions(-) diff --git a/aclitem_array.go b/aclitem_array.go index 95f74aa7..229136ec 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -219,6 +219,11 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -226,12 +231,6 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/bool_array.go b/bool_array.go index c9447db1..9c960b0f 100644 --- a/bool_array.go +++ b/bool_array.go @@ -221,6 +221,11 @@ func (src *BoolArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -228,12 +233,6 @@ func (src *BoolArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/bpchar_array.go b/bpchar_array.go index f814930f..89a14aa0 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -221,6 +221,11 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -228,12 +233,6 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/bytea_array.go b/bytea_array.go index 618a2f4b..425ef954 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -193,6 +193,11 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -200,12 +205,6 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/cidr_array.go b/cidr_array.go index 8ea7d7a6..4ad10d8b 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -250,6 +250,11 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -257,12 +262,6 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/date_array.go b/date_array.go index dc4cb2e3..b29eee67 100644 --- a/date_array.go +++ b/date_array.go @@ -222,6 +222,11 @@ func (src *DateArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -229,12 +234,6 @@ func (src *DateArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/enum_array.go b/enum_array.go index f5312a04..76caac95 100644 --- a/enum_array.go +++ b/enum_array.go @@ -219,6 +219,11 @@ func (src *EnumArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -226,12 +231,6 @@ func (src *EnumArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/float4_array.go b/float4_array.go index 88dd84ab..d314563c 100644 --- a/float4_array.go +++ b/float4_array.go @@ -221,6 +221,11 @@ func (src *Float4Array) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -228,12 +233,6 @@ func (src *Float4Array) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/float8_array.go b/float8_array.go index 9d79a449..60d1a6d2 100644 --- a/float8_array.go +++ b/float8_array.go @@ -221,6 +221,11 @@ func (src *Float8Array) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -228,12 +233,6 @@ func (src *Float8Array) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/hstore_array.go b/hstore_array.go index d0b34b3c..02abe870 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -193,6 +193,11 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -200,12 +205,6 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/inet_array.go b/inet_array.go index 2058db81..4f8211ab 100644 --- a/inet_array.go +++ b/inet_array.go @@ -250,6 +250,11 @@ func (src *InetArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -257,12 +262,6 @@ func (src *InetArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/int2_array.go b/int2_array.go index bf6a6284..180db652 100644 --- a/int2_array.go +++ b/int2_array.go @@ -613,6 +613,11 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -620,12 +625,6 @@ func (src *Int2Array) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/int4_array.go b/int4_array.go index 05e10dc3..d36071a0 100644 --- a/int4_array.go +++ b/int4_array.go @@ -613,6 +613,11 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -620,12 +625,6 @@ func (src *Int4Array) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/int8_array.go b/int8_array.go index d149558f..3adb2f02 100644 --- a/int8_array.go +++ b/int8_array.go @@ -613,6 +613,11 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -620,12 +625,6 @@ func (src *Int8Array) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/jsonb_array.go b/jsonb_array.go index 36411b9d..562b0654 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -221,6 +221,11 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -228,12 +233,6 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/macaddr_array.go b/macaddr_array.go index 2ec5971e..511cd9ca 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -222,6 +222,11 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -229,12 +234,6 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/numeric_array.go b/numeric_array.go index 7c044c8c..e3c18600 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -389,6 +389,11 @@ func (src *NumericArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -396,12 +401,6 @@ func (src *NumericArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/text_array.go b/text_array.go index 01b5e6e6..5d0215c2 100644 --- a/text_array.go +++ b/text_array.go @@ -221,6 +221,11 @@ func (src *TextArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -228,12 +233,6 @@ func (src *TextArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/timestamp_array.go b/timestamp_array.go index ee6037b0..2495f2c9 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -222,6 +222,11 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -229,12 +234,6 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/timestamptz_array.go b/timestamptz_array.go index 327b3ebc..7ebcf9da 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -222,6 +222,11 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -229,12 +234,6 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/tstzrange_array.go b/tstzrange_array.go index cac377af..dae022d0 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -174,6 +174,11 @@ func (src *TstzrangeArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -181,12 +186,6 @@ func (src *TstzrangeArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/typed_array.go.erb b/typed_array.go.erb index 6d34b0e1..9951bfcb 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -194,6 +194,11 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -201,12 +206,6 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/uuid_array.go b/uuid_array.go index 33f2e62c..89cadd91 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -277,6 +277,11 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -284,12 +289,6 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { diff --git a/varchar_array.go b/varchar_array.go index b3b030b8..fd8de8a4 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -221,6 +221,11 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { } } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + // Fallback to reflection if an optimised match was not found. // The reflection is necessary for arrays and multidimensional slices, // but it comes with a 20-50% performance penalty for large arrays/slices @@ -228,12 +233,6 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { if value.Kind() == reflect.Ptr { value = value.Elem() } - if !value.CanSet() { - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return errors.Errorf("unable to assign to %T", dst) - } elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { From f3f5b70a872eb9875c7bc0cbc6f7b3876c08d92b Mon Sep 17 00:00:00 2001 From: Feike Steenbergen Date: Thu, 29 Oct 2020 18:59:15 +0100 Subject: [PATCH 0599/1158] Ensure the example code snippet compiles again There were 2 errors when using the example code: - not enough arguments in call to pgConn.Close - no new variables on left side of := With these changes, the example works again. --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 5d14e914..6a68e230 100644 --- a/README.md +++ b/README.md @@ -15,16 +15,16 @@ pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { log.Fatalln("pgconn failed to connect:", err) } -defer pgConn.Close() +defer pgConn.Close(context.Background()) result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) for result.NextRow() { fmt.Println("User 123 has email:", string(result.Values()[0])) } -_, err := result.Close() +_, err = result.Close() if err != nil { log.Fatalln("failed reading result:", err) -}) +} ``` ## Testing From 340bfece2c33b6375414a694688d05b56f6c31af Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 29 Oct 2020 21:20:28 -0500 Subject: [PATCH 0600/1158] Do not asyncClose in response to a FATAL PG error This will reduce spurious server log messages on authentication failures. See https://github.com/jackc/pgconn/pull/53. --- pgconn.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index ff812069..3652cedb 100644 --- a/pgconn.go +++ b/pgconn.go @@ -485,7 +485,8 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { - pgConn.asyncClose() + pgConn.status = connStatusClosed + pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: From 9c2888b49ee8af394820dd9dd5c66ec81cea7685 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Oct 2020 16:25:01 -0500 Subject: [PATCH 0601/1158] Release v1.7.1 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7444fcd..e9753526 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.7.1 (October 31, 2020) + +* Do not asyncClose after receiving FATAL error from PostgreSQL server + # 1.7.0 (September 26, 2020) * Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded From af0ca3a39b16dc19b75f40fd5fe38a79c9b0b5a8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Oct 2020 17:12:16 -0500 Subject: [PATCH 0602/1158] Fix simple protocol empty array and original recursive empty array issue Original issue https://github.com/jackc/pgtype/issues/68 This crash occurred in the recursive assignment system used to support multidimensional arrays. This was fixed in 9639a69d451f55456f598c1aa8b93053f8df3088. However, that fix incorrectly used nil instead of an empty slice. In hindsight, it appears the fundamental error is that an assignment to a slice of a type that is not specified is handled with the recursive / reflection path. Or another way of looking at it is as an unexpected feature where []T can now be scanned if individual elements are assignable to T even if []T is not specifically handled. But this new reflection / recursive path did not handle empty arrays. This fix handles the reflection path for an empty slice by allocating an empty slice. --- aclitem_array.go | 11 +++++++---- aclitem_array_test.go | 2 +- bool_array.go | 11 +++++++---- bool_array_test.go | 2 +- bpchar_array.go | 11 +++++++---- bytea_array.go | 11 +++++++---- bytea_array_test.go | 2 +- cidr_array.go | 11 +++++++---- cidr_array_test.go | 4 ++-- date_array.go | 11 +++++++---- date_array_test.go | 2 +- enum_array.go | 11 +++++++---- enum_array_test.go | 2 +- float4_array.go | 11 +++++++---- float4_array_test.go | 2 +- float8_array.go | 11 +++++++---- float8_array_test.go | 2 +- hstore_array.go | 11 +++++++---- hstore_array_test.go | 2 +- inet_array.go | 11 +++++++---- inet_array_test.go | 4 ++-- int2_array.go | 11 +++++++---- int2_array_test.go | 2 +- int4_array.go | 11 +++++++---- int4_array_test.go | 2 +- int8_array.go | 11 +++++++---- int8_array_test.go | 2 +- jsonb_array.go | 11 +++++++---- macaddr_array.go | 11 +++++++---- macaddr_array_test.go | 2 +- numeric_array.go | 11 +++++++---- numeric_array_test.go | 2 +- text_array.go | 11 +++++++---- text_array_test.go | 2 +- timestamp_array.go | 11 +++++++---- timestamp_array_test.go | 2 +- timestamptz_array.go | 11 +++++++---- timestamptz_array_test.go | 2 +- tstzrange_array.go | 11 +++++++---- typed_array.go.erb | 11 +++++++---- uuid_array.go | 11 +++++++---- uuid_array_test.go | 4 ++-- varchar_array.go | 11 +++++++---- varchar_array_test.go | 2 +- 44 files changed, 191 insertions(+), 119 deletions(-) diff --git a/aclitem_array.go b/aclitem_array.go index 229136ec..f4b14433 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -190,10 +190,6 @@ func (dst ACLItemArray) Get() interface{} { func (src *ACLItemArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -232,6 +228,13 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/aclitem_array_test.go b/aclitem_array_test.go index 8ebb8c8d..8f015f40 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -192,7 +192,7 @@ func TestACLItemArrayAssignTo(t *testing.T) { { src: pgtype.ACLItemArray{Status: pgtype.Present}, dst: &stringSlice, - expected: (([]string)(nil)), + expected: []string{}, }, { src: pgtype.ACLItemArray{ diff --git a/bool_array.go b/bool_array.go index 9c960b0f..41c6deda 100644 --- a/bool_array.go +++ b/bool_array.go @@ -192,10 +192,6 @@ func (dst BoolArray) Get() interface{} { func (src *BoolArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -234,6 +230,13 @@ func (src *BoolArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/bool_array_test.go b/bool_array_test.go index 4ff26bb5..be567e59 100644 --- a/bool_array_test.go +++ b/bool_array_test.go @@ -171,7 +171,7 @@ func TestBoolArrayAssignTo(t *testing.T) { { src: pgtype.BoolArray{Status: pgtype.Present}, dst: &boolSlice, - expected: (([]bool)(nil)), + expected: []bool{}, }, { src: pgtype.BoolArray{ diff --git a/bpchar_array.go b/bpchar_array.go index 89a14aa0..5fd7381a 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -192,10 +192,6 @@ func (dst BPCharArray) Get() interface{} { func (src *BPCharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -234,6 +230,13 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/bytea_array.go b/bytea_array.go index 425ef954..9b5c9ee9 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -173,10 +173,6 @@ func (dst ByteaArray) Get() interface{} { func (src *ByteaArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -206,6 +202,13 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/bytea_array_test.go b/bytea_array_test.go index 4964771c..27c0382e 100644 --- a/bytea_array_test.go +++ b/bytea_array_test.go @@ -160,7 +160,7 @@ func TestByteaArrayAssignTo(t *testing.T) { { src: pgtype.ByteaArray{Status: pgtype.Present}, dst: &byteByteSlice, - expected: (([][]byte)(nil)), + expected: [][]byte{}, }, { src: pgtype.ByteaArray{ diff --git a/cidr_array.go b/cidr_array.go index 4ad10d8b..06192ddd 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -212,10 +212,6 @@ func (dst CIDRArray) Get() interface{} { func (src *CIDRArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -263,6 +259,13 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/cidr_array_test.go b/cidr_array_test.go index aa933b62..74c063fa 100644 --- a/cidr_array_test.go +++ b/cidr_array_test.go @@ -220,7 +220,7 @@ func TestCIDRArrayAssignTo(t *testing.T) { { src: pgtype.CIDRArray{Status: pgtype.Present}, dst: &ipnetSlice, - expected: (([]*net.IPNet)(nil)), + expected: []*net.IPNet{}, }, { src: pgtype.CIDRArray{Status: pgtype.Null}, @@ -230,7 +230,7 @@ func TestCIDRArrayAssignTo(t *testing.T) { { src: pgtype.CIDRArray{Status: pgtype.Present}, dst: &ipSlice, - expected: (([]net.IP)(nil)), + expected: []net.IP{}, }, { src: pgtype.CIDRArray{ diff --git a/date_array.go b/date_array.go index b29eee67..1961bf20 100644 --- a/date_array.go +++ b/date_array.go @@ -193,10 +193,6 @@ func (dst DateArray) Get() interface{} { func (src *DateArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -235,6 +231,13 @@ func (src *DateArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/date_array_test.go b/date_array_test.go index 8791c31f..4458abfe 100644 --- a/date_array_test.go +++ b/date_array_test.go @@ -185,7 +185,7 @@ func TestDateArrayAssignTo(t *testing.T) { { src: pgtype.DateArray{Status: pgtype.Present}, dst: &timeSlice, - expected: (([]time.Time)(nil)), + expected: []time.Time{}, }, { src: pgtype.DateArray{ diff --git a/enum_array.go b/enum_array.go index 76caac95..ebe838ad 100644 --- a/enum_array.go +++ b/enum_array.go @@ -190,10 +190,6 @@ func (dst EnumArray) Get() interface{} { func (src *EnumArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -232,6 +228,13 @@ func (src *EnumArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/enum_array_test.go b/enum_array_test.go index 9db8b49f..659340f0 100644 --- a/enum_array_test.go +++ b/enum_array_test.go @@ -170,7 +170,7 @@ func TestEnumArrayArrayAssignTo(t *testing.T) { { src: pgtype.EnumArray{Status: pgtype.Present}, dst: &stringSlice, - expected: (([]string)(nil)), + expected: []string{}, }, { src: pgtype.EnumArray{ diff --git a/float4_array.go b/float4_array.go index d314563c..44ba1fee 100644 --- a/float4_array.go +++ b/float4_array.go @@ -192,10 +192,6 @@ func (dst Float4Array) Get() interface{} { func (src *Float4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -234,6 +230,13 @@ func (src *Float4Array) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/float4_array_test.go b/float4_array_test.go index 88d35fd6..db438999 100644 --- a/float4_array_test.go +++ b/float4_array_test.go @@ -170,7 +170,7 @@ func TestFloat4ArrayAssignTo(t *testing.T) { { src: pgtype.Float4Array{Status: pgtype.Present}, dst: &float32Slice, - expected: (([]float32)(nil)), + expected: []float32{}, }, { src: pgtype.Float4Array{ diff --git a/float8_array.go b/float8_array.go index 60d1a6d2..1065190d 100644 --- a/float8_array.go +++ b/float8_array.go @@ -192,10 +192,6 @@ func (dst Float8Array) Get() interface{} { func (src *Float8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -234,6 +230,13 @@ func (src *Float8Array) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/float8_array_test.go b/float8_array_test.go index d7bf6ac3..85cb8f43 100644 --- a/float8_array_test.go +++ b/float8_array_test.go @@ -146,7 +146,7 @@ func TestFloat8ArrayAssignTo(t *testing.T) { { src: pgtype.Float8Array{Status: pgtype.Present}, dst: &float64Slice, - expected: (([]float64)(nil)), + expected: []float64{}, }, { src: pgtype.Float8Array{ diff --git a/hstore_array.go b/hstore_array.go index 02abe870..3899ae49 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -173,10 +173,6 @@ func (dst HstoreArray) Get() interface{} { func (src *HstoreArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -206,6 +202,13 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/hstore_array_test.go b/hstore_array_test.go index 3d85545a..672eca4a 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -303,7 +303,7 @@ func TestHstoreArrayAssignTo(t *testing.T) { src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), }, { - src: pgtype.HstoreArray{Status: pgtype.Present}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), + src: pgtype.HstoreArray{Status: pgtype.Present}, dst: &hstoreSlice, expected: []map[string]string{}, }, { src: pgtype.HstoreArray{ diff --git a/inet_array.go b/inet_array.go index 4f8211ab..5de138c0 100644 --- a/inet_array.go +++ b/inet_array.go @@ -212,10 +212,6 @@ func (dst InetArray) Get() interface{} { func (src *InetArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -263,6 +259,13 @@ func (src *InetArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/inet_array_test.go b/inet_array_test.go index 5beab960..46dc7d12 100644 --- a/inet_array_test.go +++ b/inet_array_test.go @@ -220,7 +220,7 @@ func TestInetArrayAssignTo(t *testing.T) { { src: pgtype.InetArray{Status: pgtype.Present}, dst: &ipnetSlice, - expected: (([]*net.IPNet)(nil)), + expected: []*net.IPNet{}, }, { src: pgtype.InetArray{Status: pgtype.Null}, @@ -230,7 +230,7 @@ func TestInetArrayAssignTo(t *testing.T) { { src: pgtype.InetArray{Status: pgtype.Present}, dst: &ipSlice, - expected: (([]net.IP)(nil)), + expected: []net.IP{}, }, { src: pgtype.InetArray{ diff --git a/int2_array.go b/int2_array.go index 180db652..6b4e4c8a 100644 --- a/int2_array.go +++ b/int2_array.go @@ -458,10 +458,6 @@ func (dst Int2Array) Get() interface{} { func (src *Int2Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -626,6 +622,13 @@ func (src *Int2Array) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/int2_array_test.go b/int2_array_test.go index da669f7d..17c37360 100644 --- a/int2_array_test.go +++ b/int2_array_test.go @@ -222,7 +222,7 @@ func TestInt2ArrayAssignTo(t *testing.T) { { src: pgtype.Int2Array{Status: pgtype.Present}, dst: &int16Slice, - expected: (([]int16)(nil)), + expected: []int16{}, }, { src: pgtype.Int2Array{ diff --git a/int4_array.go b/int4_array.go index d36071a0..8801947d 100644 --- a/int4_array.go +++ b/int4_array.go @@ -458,10 +458,6 @@ func (dst Int4Array) Get() interface{} { func (src *Int4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -626,6 +622,13 @@ func (src *Int4Array) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/int4_array_test.go b/int4_array_test.go index a5aad827..110512a9 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -236,7 +236,7 @@ func TestInt4ArrayAssignTo(t *testing.T) { { src: pgtype.Int4Array{Status: pgtype.Present}, dst: &int32Slice, - expected: (([]int32)(nil)), + expected: []int32{}, }, { src: pgtype.Int4Array{ diff --git a/int8_array.go b/int8_array.go index 3adb2f02..13e20fca 100644 --- a/int8_array.go +++ b/int8_array.go @@ -458,10 +458,6 @@ func (dst Int8Array) Get() interface{} { func (src *Int8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -626,6 +622,13 @@ func (src *Int8Array) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/int8_array_test.go b/int8_array_test.go index b0ee97ee..1d42a278 100644 --- a/int8_array_test.go +++ b/int8_array_test.go @@ -229,7 +229,7 @@ func TestInt8ArrayAssignTo(t *testing.T) { { src: pgtype.Int8Array{Status: pgtype.Present}, dst: &int64Slice, - expected: (([]int64)(nil)), + expected: []int64{}, }, { src: pgtype.Int8Array{ diff --git a/jsonb_array.go b/jsonb_array.go index 562b0654..f44f7fa5 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -192,10 +192,6 @@ func (dst JSONBArray) Get() interface{} { func (src *JSONBArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -234,6 +230,13 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/macaddr_array.go b/macaddr_array.go index 511cd9ca..5a27046f 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -193,10 +193,6 @@ func (dst MacaddrArray) Get() interface{} { func (src *MacaddrArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -235,6 +231,13 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/macaddr_array_test.go b/macaddr_array_test.go index 6359a374..c1a8b72d 100644 --- a/macaddr_array_test.go +++ b/macaddr_array_test.go @@ -169,7 +169,7 @@ func TestMacaddrArrayAssignTo(t *testing.T) { { src: pgtype.MacaddrArray{Status: pgtype.Present}, dst: &macaddrSlice, - expected: (([]net.HardwareAddr)(nil)), + expected: []net.HardwareAddr{}, }, { src: pgtype.MacaddrArray{ diff --git a/numeric_array.go b/numeric_array.go index e3c18600..c281bfb3 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -306,10 +306,6 @@ func (dst NumericArray) Get() interface{} { func (src *NumericArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -402,6 +398,13 @@ func (src *NumericArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/numeric_array_test.go b/numeric_array_test.go index def8150d..7c1e8c3b 100644 --- a/numeric_array_test.go +++ b/numeric_array_test.go @@ -193,7 +193,7 @@ func TestNumericArrayAssignTo(t *testing.T) { { src: pgtype.NumericArray{Status: pgtype.Present}, dst: &float32Slice, - expected: (([]float32)(nil)), + expected: []float32{}, }, { src: pgtype.NumericArray{ diff --git a/text_array.go b/text_array.go index 5d0215c2..599764d8 100644 --- a/text_array.go +++ b/text_array.go @@ -192,10 +192,6 @@ func (dst TextArray) Get() interface{} { func (src *TextArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -234,6 +230,13 @@ func (src *TextArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/text_array_test.go b/text_array_test.go index a538c617..5a2317e3 100644 --- a/text_array_test.go +++ b/text_array_test.go @@ -171,7 +171,7 @@ func TestTextArrayAssignTo(t *testing.T) { { src: pgtype.TextArray{Status: pgtype.Present}, dst: &stringSlice, - expected: (([]string)(nil)), + expected: []string{}, }, { src: pgtype.TextArray{ diff --git a/timestamp_array.go b/timestamp_array.go index 2495f2c9..2f7176b8 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -193,10 +193,6 @@ func (dst TimestampArray) Get() interface{} { func (src *TimestampArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -235,6 +231,13 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/timestamp_array_test.go b/timestamp_array_test.go index 85db94bb..54d15b24 100644 --- a/timestamp_array_test.go +++ b/timestamp_array_test.go @@ -165,7 +165,7 @@ func TestTimestampArrayAssignTo(t *testing.T) { { src: pgtype.TimestampArray{Status: pgtype.Present}, dst: &timeSlice, - expected: (([]time.Time)(nil)), + expected: []time.Time{}, }, { src: pgtype.TimestampArray{ diff --git a/timestamptz_array.go b/timestamptz_array.go index 7ebcf9da..a10aaa8b 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -193,10 +193,6 @@ func (dst TimestamptzArray) Get() interface{} { func (src *TimestamptzArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -235,6 +231,13 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/timestamptz_array_test.go b/timestamptz_array_test.go index a4e1dded..9856e4e7 100644 --- a/timestamptz_array_test.go +++ b/timestamptz_array_test.go @@ -201,7 +201,7 @@ func TestTimestamptzArrayAssignTo(t *testing.T) { { src: pgtype.TimestamptzArray{Status: pgtype.Present}, dst: &timeSlice, - expected: (([]time.Time)(nil)), + expected: []time.Time{}, }, { src: pgtype.TimestamptzArray{ diff --git a/tstzrange_array.go b/tstzrange_array.go index dae022d0..7e57acfe 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -154,10 +154,6 @@ func (dst TstzrangeArray) Get() interface{} { func (src *TstzrangeArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -187,6 +183,13 @@ func (src *TstzrangeArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/typed_array.go.erb b/typed_array.go.erb index 9951bfcb..eb1a642e 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -174,10 +174,6 @@ func (dst <%= pgtype_array_type %>) Get() interface{} { func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1{ // Attempt to match to select common types: switch v := dst.(type) { @@ -207,6 +203,13 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/uuid_array.go b/uuid_array.go index 89cadd91..fc1ea3b3 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -230,10 +230,6 @@ func (dst UUIDArray) Get() interface{} { func (src *UUIDArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -290,6 +286,13 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/uuid_array_test.go b/uuid_array_test.go index a1e14a04..7d822e7a 100644 --- a/uuid_array_test.go +++ b/uuid_array_test.go @@ -256,12 +256,12 @@ func TestUUIDArrayAssignTo(t *testing.T) { { src: pgtype.UUIDArray{Status: pgtype.Present}, dst: &byteSlice, - expected: ([]byte)(nil), + expected: []byte{}, }, { src: pgtype.UUIDArray{Status: pgtype.Present}, dst: &stringSlice, - expected: (([]string)(nil)), + expected: []string{}, }, { src: pgtype.UUIDArray{ diff --git a/varchar_array.go b/varchar_array.go index fd8de8a4..9326c72d 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -192,10 +192,6 @@ func (dst VarcharArray) Get() interface{} { func (src *VarcharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - if len(src.Elements) == 0 || len(src.Dimensions) == 0 { - // No values to assign - return nil - } if len(src.Dimensions) <= 1 { // Attempt to match to select common types: switch v := dst.(type) { @@ -234,6 +230,13 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { value = value.Elem() } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + elementCount, err := src.assignToRecursive(value, 0, 0) if err != nil { return err diff --git a/varchar_array_test.go b/varchar_array_test.go index ca9a15b7..5fb7326d 100644 --- a/varchar_array_test.go +++ b/varchar_array_test.go @@ -171,7 +171,7 @@ func TestVarcharArrayAssignTo(t *testing.T) { { src: pgtype.VarcharArray{Status: pgtype.Present}, dst: &stringSlice, - expected: (([]string)(nil)), + expected: []string{}, }, { src: pgtype.VarcharArray{ From 36a8da55cc3dffea7318d45c4b8c7f8ac5dd1dde Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 3 Nov 2020 08:28:53 -0600 Subject: [PATCH 0603/1158] Fix Timestamptz.DecodeText with too short text fixes #74 --- timestamp_test.go | 8 ++++++++ timestamptz.go | 4 ++-- timestamptz_test.go | 8 ++++++++ tstzrange_test.go | 8 ++++++++ 4 files changed, 26 insertions(+), 2 deletions(-) diff --git a/timestamp_test.go b/timestamp_test.go index b2fbda94..74cb1221 100644 --- a/timestamp_test.go +++ b/timestamp_test.go @@ -7,6 +7,7 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" ) func TestTimestampTranscode(t *testing.T) { @@ -77,6 +78,13 @@ func TestTimestampNanosecondsTruncated(t *testing.T) { } } +// https://github.com/jackc/pgtype/issues/74 +func TestTimestampDecodeTextInvalid(t *testing.T) { + tstz := &pgtype.Timestamp{} + err := tstz.DecodeText(nil, []byte(`eeeee`)) + require.Error(t, err) +} + func TestTimestampSet(t *testing.T) { type _time time.Time diff --git a/timestamptz.go b/timestamptz.go index d54974af..a79bd66e 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -111,9 +111,9 @@ func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} default: var format string - if sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+' { + if len(sbuf) >= 9 && (sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+') { format = pgTimestamptzSecondFormat - } else if sbuf[len(sbuf)-6] == '-' || sbuf[len(sbuf)-6] == '+' { + } else if len(sbuf) >= 6 && (sbuf[len(sbuf)-6] == '-' || sbuf[len(sbuf)-6] == '+') { format = pgTimestamptzMinuteFormat } else { format = pgTimestamptzHourFormat diff --git a/timestamptz_test.go b/timestamptz_test.go index 828184b7..769c9239 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -7,6 +7,7 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" ) func TestTimestamptzTranscode(t *testing.T) { @@ -77,6 +78,13 @@ func TestTimestamptzNanosecondsTruncated(t *testing.T) { } } +// https://github.com/jackc/pgtype/issues/74 +func TestTimestamptzDecodeTextInvalid(t *testing.T) { + tstz := &pgtype.Timestamptz{} + err := tstz.DecodeText(nil, []byte(`eeeee`)) + require.Error(t, err) +} + func TestTimestamptzSet(t *testing.T) { type _time time.Time diff --git a/tstzrange_test.go b/tstzrange_test.go index b3d3ff6c..f8e2c2c5 100644 --- a/tstzrange_test.go +++ b/tstzrange_test.go @@ -6,6 +6,7 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" ) func TestTstzrangeTranscode(t *testing.T) { @@ -39,3 +40,10 @@ func TestTstzrangeTranscode(t *testing.T) { a.Upper.InfinityModifier == b.Upper.InfinityModifier }) } + +// https://github.com/jackc/pgtype/issues/74 +func TestTstzRangeDecodeTextInvalid(t *testing.T) { + tstzrange := &pgtype.Tstzrange{} + err := tstzrange.DecodeText(nil, []byte(`[eeee,)`)) + require.Error(t, err) +} From c34a8731b6a6e347de08f8326371255ab2d7da0f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 3 Nov 2020 19:15:07 -0600 Subject: [PATCH 0604/1158] Data row value slices need to be capacity limited Otherwise, appending to a slice that came from a data row could overwrite adjacent memory. --- data_row.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_row.go b/data_row.go index d908e7b2..5fa3c5d8 100644 --- a/data_row.go +++ b/data_row.go @@ -54,7 +54,7 @@ func (dst *DataRow) Decode(src []byte) error { return &invalidMessageFormatErr{messageType: "DataRow"} } - dst.Values[i] = src[rp : rp+msgSize] + dst.Values[i] = src[rp : rp+msgSize : rp+msgSize] rp += msgSize } } From 0f17ba2cf3b307aeddfa5cd6ada0d1fe7ad3e46c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 3 Nov 2020 19:17:52 -0600 Subject: [PATCH 0605/1158] Fix unconstrained data value slices See https://github.com/jackc/pgx/issues/859 --- go.mod | 2 +- go.sum | 2 ++ pgconn_test.go | 26 ++++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index f2c10401..7e578765 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.5 + github.com/jackc/pgproto3/v2 v2.0.6 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 diff --git a/go.sum b/go.sum index 08a11e19..f3eb0e08 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/jackc/pgproto3/v2 v2.0.4 h1:RHkX5ZUD9bl/kn0f9dYUWs1N7Nwvo1wwUYvKiR26Z github.com/jackc/pgproto3/v2 v2.0.4/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.5 h1:NUbEWPmCQZbMmYlTjVoNPhc0CfnYyz2bfUAh6A5ZVJM= github.com/jackc/pgproto3/v2 v2.0.5/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.0.6 h1:b1105ZGEMFe7aCvrT1Cca3VoVb4ZFMaFJLJcg/3zD+8= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= diff --git a/pgconn_test.go b/pgconn_test.go index 24200e73..b71e7d3f 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -758,6 +758,32 @@ func TestConnExecParamsEmptySQL(t *testing.T) { ensureConnValid(t, pgConn) } +// https://github.com/jackc/pgx/issues/859 +func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + assert.Equal(t, len(result.Values()[0]), cap(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + func TestConnExecPrepared(t *testing.T) { t.Parallel() From b82b993fa8aa3fd6d8aac15689301db049d5504f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 3 Nov 2020 19:20:03 -0600 Subject: [PATCH 0606/1158] Release v1.7.2 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9753526..92b1de06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.7.2 (November 3, 2020) + +* Fix data value slices into work buffer with capacities larger than length. + # 1.7.1 (October 31, 2020) * Do not asyncClose after receiving FATAL error from PostgreSQL server From 740b3a511515d881db506c90190948f4f312dd31 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Nov 2020 07:31:56 -0600 Subject: [PATCH 0607/1158] Fix: Text array parsing disambiguates NULL and "NULL". This solution is a little awkward, but it avoids breaking backwards compatibility. fixes #78 --- aclitem_array.go | 2 +- array.go | 22 ++++++++++++---------- array_test.go | 10 ++++++++++ bool_array.go | 2 +- bpchar_array.go | 2 +- bytea_array.go | 2 +- cidr_array.go | 2 +- date_array.go | 2 +- enum_array.go | 2 +- float4_array.go | 2 +- float8_array.go | 2 +- hstore_array.go | 2 +- inet_array.go | 2 +- int2_array.go | 2 +- int4_array.go | 2 +- int8_array.go | 2 +- jsonb_array.go | 2 +- macaddr_array.go | 2 +- numeric_array.go | 2 +- text_array.go | 2 +- text_array_test.go | 12 ++++++++++++ timestamp_array.go | 2 +- timestamptz_array.go | 2 +- tstzrange_array.go | 2 +- typed_array.go.erb | 2 +- uuid_array.go | 2 +- varchar_array.go | 2 +- 27 files changed, 58 insertions(+), 34 deletions(-) diff --git a/aclitem_array.go b/aclitem_array.go index f4b14433..501074d6 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -317,7 +317,7 @@ func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem ACLItem var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/array.go b/array.go index b779cd9d..93c91897 100644 --- a/array.go +++ b/array.go @@ -82,6 +82,7 @@ func (src ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { type UntypedTextArray struct { Elements []string + Quoted []bool Dimensions []ArrayDimension } @@ -196,13 +197,14 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { } default: buf.UnreadRune() - value, err := arrayParseValue(buf) + value, quoted, err := arrayParseValue(buf) if err != nil { return nil, errors.Errorf("invalid array value: %v", err) } if currentDim == counterDim { implicitDimensions[currentDim].Length++ } + dst.Quoted = append(dst.Quoted, quoted) dst.Elements = append(dst.Elements, value) } @@ -239,10 +241,10 @@ func skipWhitespace(buf *bytes.Buffer) { } } -func arrayParseValue(buf *bytes.Buffer) (string, error) { +func arrayParseValue(buf *bytes.Buffer) (string, bool, error) { r, _, err := buf.ReadRune() if err != nil { - return "", err + return "", false, err } if r == '"' { return arrayParseQuotedValue(buf) @@ -254,41 +256,41 @@ func arrayParseValue(buf *bytes.Buffer) (string, error) { for { r, _, err := buf.ReadRune() if err != nil { - return "", err + return "", false, err } switch r { case ',', '}': buf.UnreadRune() - return s.String(), nil + return s.String(), false, nil } s.WriteRune(r) } } -func arrayParseQuotedValue(buf *bytes.Buffer) (string, error) { +func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) { s := &bytes.Buffer{} for { r, _, err := buf.ReadRune() if err != nil { - return "", err + return "", false, err } switch r { case '\\': r, _, err = buf.ReadRune() if err != nil { - return "", err + return "", false, err } case '"': r, _, err = buf.ReadRune() if err != nil { - return "", err + return "", false, err } buf.UnreadRune() - return s.String(), nil + return s.String(), true, nil } s.WriteRune(r) } diff --git a/array_test.go b/array_test.go index 486171b8..2f3b9237 100644 --- a/array_test.go +++ b/array_test.go @@ -16,6 +16,7 @@ func TestParseUntypedTextArray(t *testing.T) { source: "{}", result: pgtype.UntypedTextArray{ Elements: nil, + Quoted: nil, Dimensions: nil, }, }, @@ -23,6 +24,7 @@ func TestParseUntypedTextArray(t *testing.T) { source: "{1}", result: pgtype.UntypedTextArray{ Elements: []string{"1"}, + Quoted: []bool{false}, Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, @@ -30,6 +32,7 @@ func TestParseUntypedTextArray(t *testing.T) { source: "{a,b}", result: pgtype.UntypedTextArray{ Elements: []string{"a", "b"}, + Quoted: []bool{false, false}, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, }, }, @@ -37,6 +40,7 @@ func TestParseUntypedTextArray(t *testing.T) { source: `{"NULL"}`, result: pgtype.UntypedTextArray{ Elements: []string{"NULL"}, + Quoted: []bool{true}, Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, @@ -44,6 +48,7 @@ func TestParseUntypedTextArray(t *testing.T) { source: `{""}`, result: pgtype.UntypedTextArray{ Elements: []string{""}, + Quoted: []bool{true}, Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, @@ -51,6 +56,7 @@ func TestParseUntypedTextArray(t *testing.T) { source: `{"He said, \"Hello.\""}`, result: pgtype.UntypedTextArray{ Elements: []string{`He said, "Hello."`}, + Quoted: []bool{true}, Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, @@ -58,6 +64,7 @@ func TestParseUntypedTextArray(t *testing.T) { source: "{{a,b},{c,d},{e,f}}", result: pgtype.UntypedTextArray{ Elements: []string{"a", "b", "c", "d", "e", "f"}, + Quoted: []bool{false, false, false, false, false, false}, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, }, }, @@ -65,6 +72,7 @@ func TestParseUntypedTextArray(t *testing.T) { source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}", result: pgtype.UntypedTextArray{ Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"}, + Quoted: []bool{false, false, false, false, false, false, false, false, false, false, false, false}, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 1}, {Length: 3, LowerBound: 1}, @@ -76,6 +84,7 @@ func TestParseUntypedTextArray(t *testing.T) { source: "[4:4]={1}", result: pgtype.UntypedTextArray{ Elements: []string{"1"}, + Quoted: []bool{false}, Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 4}}, }, }, @@ -83,6 +92,7 @@ func TestParseUntypedTextArray(t *testing.T) { source: "[4:5][2:3]={{a,b},{c,d}}", result: pgtype.UntypedTextArray{ Elements: []string{"a", "b", "c", "d"}, + Quoted: []bool{false, false, false, false}, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, diff --git a/bool_array.go b/bool_array.go index 41c6deda..232863ec 100644 --- a/bool_array.go +++ b/bool_array.go @@ -319,7 +319,7 @@ func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Bool var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/bpchar_array.go b/bpchar_array.go index 5fd7381a..aad7c144 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -319,7 +319,7 @@ func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem BPChar var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/bytea_array.go b/bytea_array.go index 9b5c9ee9..1dee05fa 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -291,7 +291,7 @@ func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Bytea var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/cidr_array.go b/cidr_array.go index 06192ddd..645c641a 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -348,7 +348,7 @@ func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem CIDR var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/date_array.go b/date_array.go index 1961bf20..a546a854 100644 --- a/date_array.go +++ b/date_array.go @@ -320,7 +320,7 @@ func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Date var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/enum_array.go b/enum_array.go index ebe838ad..d497dead 100644 --- a/enum_array.go +++ b/enum_array.go @@ -317,7 +317,7 @@ func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem GenericText var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/float4_array.go b/float4_array.go index 44ba1fee..c399697d 100644 --- a/float4_array.go +++ b/float4_array.go @@ -319,7 +319,7 @@ func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Float4 var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/float8_array.go b/float8_array.go index 1065190d..9a961c2f 100644 --- a/float8_array.go +++ b/float8_array.go @@ -319,7 +319,7 @@ func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Float8 var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/hstore_array.go b/hstore_array.go index 3899ae49..0be072cc 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -291,7 +291,7 @@ func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Hstore var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/inet_array.go b/inet_array.go index 5de138c0..d5d0a665 100644 --- a/inet_array.go +++ b/inet_array.go @@ -348,7 +348,7 @@ func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Inet var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/int2_array.go b/int2_array.go index 6b4e4c8a..8aeb7d46 100644 --- a/int2_array.go +++ b/int2_array.go @@ -711,7 +711,7 @@ func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Int2 var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/int4_array.go b/int4_array.go index 8801947d..76ca811e 100644 --- a/int4_array.go +++ b/int4_array.go @@ -711,7 +711,7 @@ func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Int4 var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/int8_array.go b/int8_array.go index 13e20fca..45d8447f 100644 --- a/int8_array.go +++ b/int8_array.go @@ -711,7 +711,7 @@ func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Int8 var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/jsonb_array.go b/jsonb_array.go index f44f7fa5..c8ef1fcd 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -319,7 +319,7 @@ func (dst *JSONBArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem JSONB var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/macaddr_array.go b/macaddr_array.go index 5a27046f..7f78c304 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -320,7 +320,7 @@ func (dst *MacaddrArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Macaddr var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/numeric_array.go b/numeric_array.go index c281bfb3..49c70855 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -487,7 +487,7 @@ func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Numeric var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/text_array.go b/text_array.go index 599764d8..d7125237 100644 --- a/text_array.go +++ b/text_array.go @@ -319,7 +319,7 @@ func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Text var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/text_array_test.go b/text_array_test.go index 5a2317e3..a5d050f6 100644 --- a/text_array_test.go +++ b/text_array_test.go @@ -6,8 +6,20 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +// https://github.com/jackc/pgtype/issues/78 +func TestTextArrayDecodeTextNull(t *testing.T) { + textArray := &pgtype.TextArray{} + err := textArray.DecodeText(nil, []byte(`{abc,"NULL",NULL,def}`)) + require.NoError(t, err) + require.Len(t, textArray.Elements, 4) + assert.Equal(t, pgtype.Present, textArray.Elements[1].Status) + assert.Equal(t, pgtype.Null, textArray.Elements[2].Status) +} + func TestTextArrayTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "text[]", []interface{}{ &pgtype.TextArray{ diff --git a/timestamp_array.go b/timestamp_array.go index 2f7176b8..bb819017 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -320,7 +320,7 @@ func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Timestamp var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/timestamptz_array.go b/timestamptz_array.go index a10aaa8b..f028e0f9 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -320,7 +320,7 @@ func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Timestamptz var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/tstzrange_array.go b/tstzrange_array.go index 7e57acfe..4f03d838 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -272,7 +272,7 @@ func (dst *TstzrangeArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Tstzrange var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/typed_array.go.erb b/typed_array.go.erb index eb1a642e..60665270 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -292,7 +292,7 @@ func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error for i, s := range uta.Elements { var elem <%= pgtype_element_type %> var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/uuid_array.go b/uuid_array.go index fc1ea3b3..894bbd40 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -375,7 +375,7 @@ func (dst *UUIDArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem UUID var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) diff --git a/varchar_array.go b/varchar_array.go index 9326c72d..d515c2a4 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -319,7 +319,7 @@ func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { for i, s := range uta.Elements { var elem Varchar var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } err = elem.DecodeText(ci, elemSrc) From a885de9c949c36c1359edc6de00cff0bc4b16bb1 Mon Sep 17 00:00:00 2001 From: Ethan Pailes Date: Mon, 9 Nov 2020 08:20:34 -0500 Subject: [PATCH 0608/1158] stmtcache: add new StatementErrored method This patch adds a new StatementErrored method to the stmtcache. This routine MUST be called by users of the cache whenever the execution of a statement results in an error. This will allow the cache to make an intelligent decision about whether or not the statement needs to be purged from the cache. --- stmtcache/lru.go | 50 ++++++++++++++++++++++++++++++ stmtcache/lru_test.go | 69 ++++++++++++++++++++++++++++++++++++++++++ stmtcache/stmtcache.go | 8 +++++ 3 files changed, 127 insertions(+) diff --git a/stmtcache/lru.go b/stmtcache/lru.go index d82ced19..2f183f90 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -20,6 +20,7 @@ type LRU struct { m map[string]*list.Element l *list.List psNamePrefix string + stmtsToClear []string } // NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. @@ -41,6 +42,17 @@ func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { + // flush an outstanding bad statements + txStatus := c.conn.TxStatus() + if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 { + for _, stmt := range c.stmtsToClear { + err := c.clearStmt(ctx, stmt) + if err != nil { + return nil, err + } + } + } + if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) return el.Value.(*pgconn.StatementDescription), nil @@ -76,6 +88,44 @@ func (c *LRU) Clear(ctx context.Context) error { return nil } +func (c *LRU) StatementErrored(ctx context.Context, sql string, err error) error { + pgErr, ok := err.(*pgconn.PgError) + if !ok { + // we don't know how to handle this error + return nil + } + + isInvalidCachedPlanError := pgErr.Severity == "ERROR" && + pgErr.Code == "0A000" && + pgErr.Message == "cached plan must not change result type" + if !isInvalidCachedPlanError { + // only flush if a plan has been changed out from under us + return nil + } + + c.stmtsToClear = append(c.stmtsToClear, sql) + + return nil +} + +func (c *LRU) clearStmt(ctx context.Context, sql string) error { + elem, inMap := c.m[sql] + if !inMap { + // The statement probably fell off the back of the list. In that case, we've + // ensured that it isn't in the cache, so we can declare victory. + return nil + } + + c.l.Remove(elem) + + psd := elem.Value.(*pgconn.StatementDescription) + delete(c.m, psd.SQL) + if c.mode == ModePrepare { + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() + } + return nil +} + // Len returns the number of cached prepared statement descriptions. func (c *LRU) Len() int { return c.l.Len() diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index d2902dbb..75925509 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -59,6 +59,75 @@ func TestLRUModePrepare(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } +func TestLRUStmtInvalidation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + // we construct a fake error because its not super straightforward to actually call + // a prepared statement from the LRU cache without the helper routines which live + // in pgx proper. + fakeInvalidCachePlanError := &pgconn.PgError{ + Severity: "ERROR", + Code: "0A000", + Message: "cached plan must not change result type", + } + + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) + + // + // outside of a transaction, we eagerly flush the statement + // + + _, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + err = cache.StatementErrored(ctx, "select 1", fakeInvalidCachePlanError) + require.NoError(t, err) + _, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + + // + // within an errored transaction, we defer the flush to after the first get + // that happens after the transaction is rolled back + // + + _, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + res := conn.Exec(ctx, "begin") + require.NoError(t, res.Close()) + require.Equal(t, byte('T'), conn.TxStatus()) + + res = conn.Exec(ctx, "selec") + require.Error(t, res.Close()) + require.Equal(t, byte('E'), conn.TxStatus()) + + err = cache.StatementErrored(ctx, "select 1", fakeInvalidCachePlanError) + require.EqualValues(t, 1, cache.Len()) + + res = conn.Exec(ctx, "rollback") + require.NoError(t, res.Close()) + + _, err = cache.Get(ctx, "select 2") + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) +} + func TestLRUModePrepareStress(t *testing.T) { t.Parallel() diff --git a/stmtcache/stmtcache.go b/stmtcache/stmtcache.go index 96215799..6e88ba54 100644 --- a/stmtcache/stmtcache.go +++ b/stmtcache/stmtcache.go @@ -20,6 +20,14 @@ type Cache interface { // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. Clear(ctx context.Context) error + // StatementErrored informs the cache that the given statement resulted in an error when it + // was last used against the database. In some cases, this will cause the cache to flush + // the statement from the cache. It will only do so when the underlying `*pgconn.PgConn` + // is not currently in a transaction. If the connection is in the middle of a transaction, + // the bad statement will instead be flushed during the next call to Get that occurrs outside + // of a transaction. + StatementErrored(ctx context.Context, sql string, err error) error + // Len returns the number of cached prepared statement descriptions. Len() int From 426124b32fb35daaee23175487b5a4117e38244e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 11 Nov 2020 15:48:49 -0600 Subject: [PATCH 0609/1158] Add stmtcache.LRU test thjat integrates over the database --- stmtcache/lru_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index 75925509..58a0c378 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -128,6 +128,44 @@ func TestLRUStmtInvalidation(t *testing.T) { require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) } +func TestLRUStmtInvalidationIntegration(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) + + result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a text)", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + sql := "select * from stmtcache_table" + sd1, err := cache.Get(ctx, sql) + require.NoError(t, err) + + result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() + require.NoError(t, result.Err) + + result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() + require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)") + + cache.StatementErrored(ctx, sql, result.Err) + + sd2, err := cache.Get(ctx, sql) + require.NoError(t, err) + require.NotEqual(t, sd1.Name, sd2.Name) + + result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read() + require.NoError(t, result.Err) +} + func TestLRUModePrepareStress(t *testing.T) { t.Parallel() From cba610c245265ff50ea3c56a9961da218ed7d730 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 11 Nov 2020 15:52:59 -0600 Subject: [PATCH 0610/1158] StatementErrored does not need context nor return an error --- stmtcache/lru.go | 14 ++++---------- stmtcache/lru_test.go | 7 +++---- stmtcache/stmtcache.go | 10 ++++------ 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/stmtcache/lru.go b/stmtcache/lru.go index 2f183f90..f58f2ac3 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -88,24 +88,18 @@ func (c *LRU) Clear(ctx context.Context) error { return nil } -func (c *LRU) StatementErrored(ctx context.Context, sql string, err error) error { +func (c *LRU) StatementErrored(sql string, err error) { pgErr, ok := err.(*pgconn.PgError) if !ok { - // we don't know how to handle this error - return nil + return } isInvalidCachedPlanError := pgErr.Severity == "ERROR" && pgErr.Code == "0A000" && pgErr.Message == "cached plan must not change result type" - if !isInvalidCachedPlanError { - // only flush if a plan has been changed out from under us - return nil + if isInvalidCachedPlanError { + c.stmtsToClear = append(c.stmtsToClear, sql) } - - c.stmtsToClear = append(c.stmtsToClear, sql) - - return nil } func (c *LRU) clearStmt(ctx context.Context, sql string) error { diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index 58a0c378..2d620905 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -89,8 +89,7 @@ func TestLRUStmtInvalidation(t *testing.T) { require.EqualValues(t, 1, cache.Len()) require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) - err = cache.StatementErrored(ctx, "select 1", fakeInvalidCachePlanError) - require.NoError(t, err) + cache.StatementErrored("select 1", fakeInvalidCachePlanError) _, err = cache.Get(ctx, "select 2") require.NoError(t, err) require.EqualValues(t, 1, cache.Len()) @@ -117,7 +116,7 @@ func TestLRUStmtInvalidation(t *testing.T) { require.Error(t, res.Close()) require.Equal(t, byte('E'), conn.TxStatus()) - err = cache.StatementErrored(ctx, "select 1", fakeInvalidCachePlanError) + cache.StatementErrored("select 1", fakeInvalidCachePlanError) require.EqualValues(t, 1, cache.Len()) res = conn.Exec(ctx, "rollback") @@ -156,7 +155,7 @@ func TestLRUStmtInvalidationIntegration(t *testing.T) { result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)") - cache.StatementErrored(ctx, sql, result.Err) + cache.StatementErrored(sql, result.Err) sd2, err := cache.Get(ctx, sql) require.NoError(t, err) diff --git a/stmtcache/stmtcache.go b/stmtcache/stmtcache.go index 6e88ba54..d083e1b4 100644 --- a/stmtcache/stmtcache.go +++ b/stmtcache/stmtcache.go @@ -21,12 +21,10 @@ type Cache interface { Clear(ctx context.Context) error // StatementErrored informs the cache that the given statement resulted in an error when it - // was last used against the database. In some cases, this will cause the cache to flush - // the statement from the cache. It will only do so when the underlying `*pgconn.PgConn` - // is not currently in a transaction. If the connection is in the middle of a transaction, - // the bad statement will instead be flushed during the next call to Get that occurrs outside - // of a transaction. - StatementErrored(ctx context.Context, sql string, err error) error + // was last used against the database. In some cases, this will cause the cache to maer that + // statement as bad. The bad statement will instead be flushed during the next call to Get + // that occurs outside of a failed transaction. + StatementErrored(sql string, err error) // Len returns the number of cached prepared statement descriptions. Len() int From 88b6398594fc9aa6abeae5dc4ddb453cf08b76b9 Mon Sep 17 00:00:00 2001 From: Roman Tkachenko Date: Tue, 17 Nov 2020 14:36:02 -0800 Subject: [PATCH 0611/1158] Add CopyData and CopyDone messages support to Backend --- backend.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/backend.go b/backend.go index 1f854c69..cc6f1f03 100644 --- a/backend.go +++ b/backend.go @@ -16,6 +16,8 @@ type Backend struct { cancelRequest CancelRequest _close Close copyFail CopyFail + copyData CopyData + copyDone CopyDone describe Describe execute Execute flush Flush @@ -116,6 +118,10 @@ func (b *Backend) Receive() (FrontendMessage, error) { msg = &b.execute case 'f': msg = &b.copyFail + case 'd': + msg = &b.copyData + case 'c': + msg = &b.copyDone case 'H': msg = &b.flush case 'P': From 00d516f5c4fdecd15973cf216fb1e39716c94346 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 27 Nov 2020 11:56:21 -0600 Subject: [PATCH 0612/1158] Fix panic on assigning empty array to non-slice or array See https://github.com/jackc/pgx/issues/881 --- aclitem_array.go | 6 ++++++ array_test.go | 12 ++++++++++++ bool_array.go | 6 ++++++ bpchar_array.go | 6 ++++++ bytea_array.go | 6 ++++++ cidr_array.go | 6 ++++++ date_array.go | 6 ++++++ enum_array.go | 6 ++++++ float4_array.go | 6 ++++++ float8_array.go | 6 ++++++ hstore_array.go | 6 ++++++ inet_array.go | 6 ++++++ int2_array.go | 6 ++++++ int4_array.go | 6 ++++++ int8_array.go | 6 ++++++ jsonb_array.go | 6 ++++++ macaddr_array.go | 6 ++++++ numeric_array.go | 6 ++++++ text_array.go | 6 ++++++ timestamp_array.go | 6 ++++++ timestamptz_array.go | 6 ++++++ tstzrange_array.go | 6 ++++++ typed_array.go.erb | 6 ++++++ uuid_array.go | 6 ++++++ varchar_array.go | 6 ++++++ 25 files changed, 156 insertions(+) diff --git a/aclitem_array.go b/aclitem_array.go index 501074d6..bf7bba93 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -228,6 +228,12 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/array_test.go b/array_test.go index 2f3b9237..d2120677 100644 --- a/array_test.go +++ b/array_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/jackc/pgtype" + "github.com/stretchr/testify/require" ) func TestParseUntypedTextArray(t *testing.T) { @@ -113,3 +114,14 @@ func TestParseUntypedTextArray(t *testing.T) { } } } + +// https://github.com/jackc/pgx/issues/881 +func TestArrayAssignToEmptyToNonSlice(t *testing.T) { + var a pgtype.Int4Array + err := a.Set([]int32{}) + require.NoError(t, err) + + var iface interface{} + err = a.AssignTo(&iface) + require.EqualError(t, err, "cannot assign *pgtype.Int4Array to *interface {}") +} diff --git a/bool_array.go b/bool_array.go index 232863ec..2659321e 100644 --- a/bool_array.go +++ b/bool_array.go @@ -230,6 +230,12 @@ func (src *BoolArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/bpchar_array.go b/bpchar_array.go index aad7c144..d48b2b53 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -230,6 +230,12 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/bytea_array.go b/bytea_array.go index 1dee05fa..14d8afad 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -202,6 +202,12 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/cidr_array.go b/cidr_array.go index 645c641a..3ac1b183 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -259,6 +259,12 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/date_array.go b/date_array.go index a546a854..0c623b8f 100644 --- a/date_array.go +++ b/date_array.go @@ -231,6 +231,12 @@ func (src *DateArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/enum_array.go b/enum_array.go index d497dead..cf7c7066 100644 --- a/enum_array.go +++ b/enum_array.go @@ -228,6 +228,12 @@ func (src *EnumArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/float4_array.go b/float4_array.go index c399697d..91b3b0e2 100644 --- a/float4_array.go +++ b/float4_array.go @@ -230,6 +230,12 @@ func (src *Float4Array) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/float8_array.go b/float8_array.go index 9a961c2f..559ee292 100644 --- a/float8_array.go +++ b/float8_array.go @@ -230,6 +230,12 @@ func (src *Float8Array) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/hstore_array.go b/hstore_array.go index 0be072cc..a44ea629 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -202,6 +202,12 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/inet_array.go b/inet_array.go index d5d0a665..30adeabb 100644 --- a/inet_array.go +++ b/inet_array.go @@ -259,6 +259,12 @@ func (src *InetArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/int2_array.go b/int2_array.go index 8aeb7d46..f4bd64cc 100644 --- a/int2_array.go +++ b/int2_array.go @@ -622,6 +622,12 @@ func (src *Int2Array) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/int4_array.go b/int4_array.go index 76ca811e..528310ff 100644 --- a/int4_array.go +++ b/int4_array.go @@ -622,6 +622,12 @@ func (src *Int4Array) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/int8_array.go b/int8_array.go index 45d8447f..b1e52a97 100644 --- a/int8_array.go +++ b/int8_array.go @@ -622,6 +622,12 @@ func (src *Int8Array) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/jsonb_array.go b/jsonb_array.go index c8ef1fcd..5d658ed5 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -230,6 +230,12 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/macaddr_array.go b/macaddr_array.go index 7f78c304..0ac2618e 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -231,6 +231,12 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/numeric_array.go b/numeric_array.go index 49c70855..1c2ae489 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -398,6 +398,12 @@ func (src *NumericArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/text_array.go b/text_array.go index d7125237..afdc507b 100644 --- a/text_array.go +++ b/text_array.go @@ -230,6 +230,12 @@ func (src *TextArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/timestamp_array.go b/timestamp_array.go index bb819017..5256f185 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -231,6 +231,12 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/timestamptz_array.go b/timestamptz_array.go index f028e0f9..47408c02 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -231,6 +231,12 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/tstzrange_array.go b/tstzrange_array.go index 4f03d838..6d9bfe3b 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -183,6 +183,12 @@ func (src *TstzrangeArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/typed_array.go.erb b/typed_array.go.erb index 60665270..52f14592 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -203,6 +203,12 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/uuid_array.go b/uuid_array.go index 894bbd40..c6970d52 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -286,6 +286,12 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/varchar_array.go b/varchar_array.go index d515c2a4..f3a9b001 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -230,6 +230,12 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { value = value.Elem() } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + if len(src.Elements) == 0 { if value.Kind() == reflect.Slice { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) From 3742d6209e5f0a4b70b173477c6c40a0aaf21ce9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 3 Dec 2020 19:12:18 -0600 Subject: [PATCH 0613/1158] Release v1.8.0 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 92b1de06..787853b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.8.0 (December 3, 2020) + +* Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes) + # 1.7.2 (November 3, 2020) * Fix data value slices into work buffer with capacities larger than length. From 7a47d60bbd54ab1af0fd1027eb2272765ee7264d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 3 Dec 2020 19:18:40 -0600 Subject: [PATCH 0614/1158] Update missing changelog entries for v1.6.0 and v1.6.1 --- CHANGELOG.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 774f0c1c..87ff6178 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,16 @@ +# 1.6.1 (October 31, 2020) + +* Fix simple protocol empty array support + +# 1.6.0 (October 24, 2020) + +* Fix AssignTo pointer to pointer to slice and named types. +* Fix zero length array assignment (Simo Haasanen) +* Add float64, float32 convert to int2, int4, int8 (lqu3j) +* Support setting infinite timestamps (Erik Agsjö) +* Polygon improvements (duohedron) +* Fix Inet.Set with nil (Tomas Volf) + # 1.5.0 (September 26, 2020) * Add slice of slice mapping to multi-dimensional arrays (Simo Haasanen) From 880863b70a9b560d2c9fb47a465cf8b9d0b0afe9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 3 Dec 2020 19:20:11 -0600 Subject: [PATCH 0615/1158] Release v1.6.2 --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 87ff6178..38eb89cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +# 1.6.2 (December 3, 2020) + +* Fix panic on assigning empty array to non-slice or array +* Fix text array parsing disambiguates NULL and "NULL" +* Fix Timestamptz.DecodeText with too short text + # 1.6.1 (October 31, 2020) * Fix simple protocol empty array support From a581247a126989e47bee6507f0b0f3c0ac9b4167 Mon Sep 17 00:00:00 2001 From: "ip.novikov" Date: Sat, 5 Dec 2020 15:28:01 +0300 Subject: [PATCH 0616/1158] Add check for url with broken password replace broken password in parseConfigError message --- errors.go | 2 ++ errors_test.go | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/errors.go b/errors.go index 164b0848..369c8ca3 100644 --- a/errors.go +++ b/errors.go @@ -178,6 +178,8 @@ func redactPW(connString string) string { connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx") plainDSN := regexp.MustCompile(`password=[^ ]*`) connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + brokenURL := regexp.MustCompile(`:\w.*@`) + connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@") return connString } diff --git a/errors_test.go b/errors_test.go index bef835f8..1bff3656 100644 --- a/errors_test.go +++ b/errors_test.go @@ -33,6 +33,16 @@ func TestConfigError(t *testing.T) { err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil), expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg", }, + { + name: "weird url with slash in password", + err: pgconn.NewParseConfigError("postgres://user:pass/word@host:5432/db_name", "msg", nil), + expectedMsg: "cannot parse `postgres://user:xxxxxx@host:5432/db_name`: msg", + }, + { + name: "url without password", + err: pgconn.NewParseConfigError("postgresql://other@host/db", "msg", nil), + expectedMsg: "cannot parse `postgresql://other@host/db`: msg", + }, } for _, tt := range tests { tt := tt From e0d22c1100233860131b45abe453a1c196391f98 Mon Sep 17 00:00:00 2001 From: "ip.novikov" Date: Sat, 5 Dec 2020 22:11:52 +0300 Subject: [PATCH 0617/1158] improve regexp get shortest sequence between : and @ --- errors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/errors.go b/errors.go index 369c8ca3..b37b1d97 100644 --- a/errors.go +++ b/errors.go @@ -178,7 +178,7 @@ func redactPW(connString string) string { connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx") plainDSN := regexp.MustCompile(`password=[^ ]*`) connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx") - brokenURL := regexp.MustCompile(`:\w.*@`) + brokenURL := regexp.MustCompile(`:[^:@]+?@`) connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@") return connString } From b77cee2a28e57a61ae079dcd411d417695d6e270 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 23 Dec 2020 11:17:02 -0600 Subject: [PATCH 0618/1158] Fix scanning int into **sql.Scanner implementor See https://github.com/jackc/pgx/issues/897. --- convert.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/convert.go b/convert.go index 45c226be..193f771f 100644 --- a/convert.go +++ b/convert.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql" "math" "reflect" "time" @@ -277,6 +278,8 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { return errors.Errorf("%d is less than zero for uint64", srcVal) } *v = uint64(srcVal) + case sql.Scanner: + return v.Scan(srcVal) default: if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { el := v.Elem() From e276d9b832bfd155cb35b9080b21827bc5d0f996 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 23 Dec 2020 12:21:34 -0600 Subject: [PATCH 0619/1158] Add more documentation to TxStatus --- pgconn.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 3652cedb..53e32252 100644 --- a/pgconn.go +++ b/pgconn.go @@ -512,7 +512,14 @@ func (pgConn *PgConn) PID() uint32 { return pgConn.pid } -// TxStatus returns the current TxStatus as reported by the server. +// TxStatus returns the current TxStatus as reported by the server in the ReadyForQuery message. +// +// Possible return values: +// 'I' - idle / not in transaction +// 'T' - in a transaction +// 'E' - in a failed transaction +// +// See https://www.postgresql.org/docs/current/protocol-message-formats.html. func (pgConn *PgConn) TxStatus() byte { return pgConn.txStatus } From 1213b6977451a98e03ac7db11dbc41e52f08154b Mon Sep 17 00:00:00 2001 From: Yuli Khodorkovskiy Date: Wed, 16 Dec 2020 03:42:55 +0000 Subject: [PATCH 0620/1158] Add support to ErrorResponse for unlocalized severity Add missing 'V' field for unlocalized severity added in PG versions 9.6 and greater. See https://www.postgresql.org/docs/current/protocol-error-fields.html --- error_response.go | 43 ++++++++++++++++++++++++++----------------- frontend_test.go | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 17 deletions(-) diff --git a/error_response.go b/error_response.go index d444798b..4eb0a196 100644 --- a/error_response.go +++ b/error_response.go @@ -7,23 +7,24 @@ import ( ) type ErrorResponse struct { - Severity string - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string UnknownFields map[byte]string } @@ -56,6 +57,8 @@ func (dst *ErrorResponse) Decode(src []byte) error { switch k { case 'S': dst.Severity = v + case 'V': + dst.SeverityUnlocalized = v case 'C': dst.Code = v case 'M': @@ -123,6 +126,11 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { buf.WriteString(src.Severity) buf.WriteByte(0) } + if src.SeverityUnlocalized != "" { + buf.WriteByte('V') + buf.WriteString(src.SeverityUnlocalized) + buf.WriteByte(0) + } if src.Code != "" { buf.WriteByte('C') buf.WriteString(src.Code) @@ -210,6 +218,7 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { buf.WriteString(v) buf.WriteByte(0) } + buf.WriteByte(0) binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) diff --git a/frontend_test.go b/frontend_test.go index 002da759..d202451f 100644 --- a/frontend_test.go +++ b/frontend_test.go @@ -6,6 +6,7 @@ import ( "github.com/jackc/pgproto3/v2" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type interruptReader struct { @@ -78,3 +79,39 @@ func TestFrontendReceiveUnexpectedEOF(t *testing.T) { assert.Nil(t, msg) assert.Equal(t, io.ErrUnexpectedEOF, err) } + +func TestErrorResponse(t *testing.T) { + t.Parallel() + + want := &pgproto3.ErrorResponse{ + Severity: "ERROR", + SeverityUnlocalized: "ERROR", + Message: `column "foo" does not exist`, + File: "parse_relation.c", + Code: "42703", + Position: 8, + Line: 3513, + Routine: "errorMissingColumn", + } + + raw := []byte{ + 'E', 0, 0, 0, 'f', + 'S', 'E', 'R', 'R', 'O', 'R', 0, + 'V', 'E', 'R', 'R', 'O', 'R', 0, + 'C', '4', '2', '7', '0', '3', 0, + 'M', 'c', 'o', 'l', 'u', 'm', 'n', 32, '"', 'f', 'o', 'o', '"', 32, 'd', 'o', 'e', 's', 32, 'n', 'o', 't', 32, 'e', 'x', 'i', 's', 't', 0, + 'P', '8', 0, + 'F', 'p', 'a', 'r', 's', 'e', '_', 'r', 'e', 'l', 'a', 't', 'i', 'o', 'n', '.', 'c', 0, + 'L', '3', '5', '1', '3', 0, + 'R', 'e', 'r', 'r', 'o', 'r', 'M', 'i', 's', 's', 'i', 'n', 'g', 'C', 'o', 'l', 'u', 'm', 'n', 0, 0, + } + + server := &interruptReader{} + server.push(raw) + + frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil) + + got, err := frontend.Receive() + require.NoError(t, err) + assert.Equal(t, want, got) +} From 97f8f6a25a82a73217b26c10c8cd6d639c015f1b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 28 Dec 2020 12:22:56 -0600 Subject: [PATCH 0621/1158] Begin CI with Github Actions --- .github/workflows/ci.yml | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..06c9a8d2 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,25 @@ +name: CI + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + + test: + name: Test + runs-on: ubuntu-latest + steps: + + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ^1.13 + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: Test + run: go test -v ./... From ea92194719fc3d45ec59205e902a33c823701ee2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 28 Dec 2020 12:35:01 -0600 Subject: [PATCH 0622/1158] Add PostgreSQL service to CI --- .github/workflows/ci.yml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 06c9a8d2..e7160525 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,6 +11,18 @@ jobs: test: name: Test runs-on: ubuntu-latest + + services: + postgres: + image: postgres + env: + POSTGRES_PASSWORD: secret + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + steps: - name: Set up Go 1.x @@ -23,3 +35,6 @@ jobs: - name: Test run: go test -v ./... + env: + PGHOST: postgres + PGPASSWORD: secret From e2115310b7ba477c4ea4990d71a087f266f9b19e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 28 Dec 2020 12:50:42 -0600 Subject: [PATCH 0623/1158] More CI --- .github/workflows/ci.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e7160525..242ad7b3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,6 +22,8 @@ jobs: --health-interval 10s --health-timeout 5s --health-retries 5 + ports: + - 5432:5432 steps: @@ -36,5 +38,6 @@ jobs: - name: Test run: go test -v ./... env: - PGHOST: postgres + PGHOST: localhost + PGUSER: postgres PGPASSWORD: secret From be67555d02d8c58d0b65a26eff462746663a0fff Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 28 Dec 2020 12:56:41 -0600 Subject: [PATCH 0624/1158] Another CI tweak --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 242ad7b3..30e17b6a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,3 +41,4 @@ jobs: PGHOST: localhost PGUSER: postgres PGPASSWORD: secret + PGSSLMODE: disable From 6e11216708bb4c097c3b54cc19de371d036a30dc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 28 Dec 2020 13:02:34 -0600 Subject: [PATCH 0625/1158] Yet another CI tweak --- .github/workflows/ci.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 30e17b6a..4b5a72f2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,6 +35,14 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 + - name: Create hstore extension + run: psql -c 'create extension hstore' + env: + PGHOST: localhost + PGUSER: postgres + PGPASSWORD: secret + PGSSLMODE: disable + - name: Test run: go test -v ./... env: From b23d41c3992700f1dd4e412b6a3303a2207a47f3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 28 Dec 2020 13:11:36 -0600 Subject: [PATCH 0626/1158] Add CI badge --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 6848acc5..77d59b31 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ [![](https://godoc.org/github.com/jackc/pgtype?status.svg)](https://godoc.org/github.com/jackc/pgtype) +![CI](https://github.com/jackc/pgtype/workflows/CI/badge.svg) # pgtype From 724bf94515c9c2f671860d27d278cb8593f6055b Mon Sep 17 00:00:00 2001 From: Moshe Katz Date: Sat, 2 Jan 2021 22:51:02 -0500 Subject: [PATCH 0627/1158] use proper pgpass location on Windows --- config.go | 43 -------------------------------------- config_test.go | 18 ++++++++++++++-- defaults.go | 51 +++++++++++++++++++++++++++++++++++++++++++++ defaults_windows.go | 46 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 113 insertions(+), 45 deletions(-) create mode 100644 defaults.go create mode 100644 defaults_windows.go diff --git a/config.go b/config.go index b05727ca..e4ee244a 100644 --- a/config.go +++ b/config.go @@ -11,7 +11,6 @@ import ( "net" "net/url" "os" - "os/user" "path/filepath" "strconv" "strings" @@ -338,48 +337,6 @@ func ParseConfig(connString string) (*Config, error) { return config, nil } -func defaultSettings() map[string]string { - settings := make(map[string]string) - - settings["host"] = defaultHost() - settings["port"] = "5432" - - // Default to the OS user name. Purposely ignoring err getting user name from - // OS. The client application will simply have to specify the user in that - // case (which they typically will be doing anyway). - user, err := user.Current() - if err == nil { - settings["user"] = user.Username - settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") - settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") - } - - settings["target_session_attrs"] = "any" - - settings["min_read_buffer_size"] = "8192" - - return settings -} - -// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost -// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it -// checks the existence of common locations. -func defaultHost() string { - candidatePaths := []string{ - "/var/run/postgresql", // Debian - "/private/tmp", // OSX - homebrew - "/tmp", // standard PostgreSQL - } - - for _, path := range candidatePaths { - if _, err := os.Stat(path); err == nil { - return path - } - } - - return "localhost" -} - func mergeSettings(settingSets ...map[string]string) map[string]string { settings := make(map[string]string) diff --git a/config_test.go b/config_test.go index d322f65a..f6391672 100644 --- a/config_test.go +++ b/config_test.go @@ -7,6 +7,8 @@ import ( "io/ioutil" "os" "os/user" + "runtime" + "strings" "testing" "time" @@ -21,7 +23,13 @@ func TestParseConfig(t *testing.T) { var osUserName string osUser, err := user.Current() if err == nil { - osUserName = osUser.Username + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") { + osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:] + } else { + osUserName = osUser.Username + } } tests := []struct { @@ -630,7 +638,13 @@ func TestParseConfigEnvLibpq(t *testing.T) { var osUserName string osUser, err := user.Current() if err == nil { - osUserName = osUser.Username + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") { + osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:] + } else { + osUserName = osUser.Username + } } pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} diff --git a/defaults.go b/defaults.go new file mode 100644 index 00000000..d3313481 --- /dev/null +++ b/defaults.go @@ -0,0 +1,51 @@ +// +build !windows + +package pgconn + +import ( + "os" + "os/user" + "path/filepath" +) + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + if err == nil { + settings["user"] = user.Username + settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + } + + settings["target_session_attrs"] = "any" + + settings["min_read_buffer_size"] = "8192" + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + candidatePaths := []string{ + "/var/run/postgresql", // Debian + "/private/tmp", // OSX - homebrew + "/tmp", // standard PostgreSQL + } + + for _, path := range candidatePaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + + return "localhost" +} diff --git a/defaults_windows.go b/defaults_windows.go new file mode 100644 index 00000000..55243700 --- /dev/null +++ b/defaults_windows.go @@ -0,0 +1,46 @@ +package pgconn + +import ( + "os" + "os/user" + "path/filepath" + "strings" +) + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + appData := os.Getenv("APPDATA") + if err == nil { + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + username := user.Username + if strings.Contains(username, "\\") { + username = username[strings.LastIndex(username, "\\")+1:] + } + + settings["user"] = username + settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + } + + settings["target_session_attrs"] = "any" + + settings["min_read_buffer_size"] = "8192" + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + return "localhost" +} From 1e141d8c32939b0c0fb2fac854cd37fc543b7835 Mon Sep 17 00:00:00 2001 From: Vasilii Novikov Date: Mon, 4 Jan 2021 13:48:11 +0300 Subject: [PATCH 0628/1158] Add tsrange array type. --- pgtype.go | 6 + tsrange_array.go | 470 +++++++++++++++++++++++++++++++++++++++++++++ typed_array_gen.sh | 1 + 3 files changed, 477 insertions(+) create mode 100644 tsrange_array.go diff --git a/pgtype.go b/pgtype.go index df5078a9..c5e537cd 100644 --- a/pgtype.go +++ b/pgtype.go @@ -77,7 +77,9 @@ const ( Int4rangeOID = 3904 NumrangeOID = 3906 TsrangeOID = 3908 + TsrangeArrayOID = 3909 TstzrangeOID = 3910 + TstzrangeArrayOID = 3911 Int8rangeOID = 3926 ) @@ -309,7 +311,9 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Timestamp{}, Name: "timestamp", OID: TimestampOID}) ci.RegisterDataType(DataType{Value: &Timestamptz{}, Name: "timestamptz", OID: TimestamptzOID}) ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) + ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) + ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) ci.RegisterDataType(DataType{Value: &Unknown{}, Name: "unknown", OID: UnknownOID}) ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID}) ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID}) @@ -924,7 +928,9 @@ func init() { "timestamp": &Timestamp{}, "timestamptz": &Timestamptz{}, "tsrange": &Tsrange{}, + "_tsrange": &TsrangeArray{}, "tstzrange": &Tstzrange{}, + "_tstzrange": &TstzrangeArray{}, "unknown": &Unknown{}, "uuid": &UUID{}, "varbit": &Varbit{}, diff --git a/tsrange_array.go b/tsrange_array.go new file mode 100644 index 00000000..15053f75 --- /dev/null +++ b/tsrange_array.go @@ -0,0 +1,470 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "reflect" + + "github.com/jackc/pgio" + errors "golang.org/x/xerrors" +) + +type TsrangeArray struct { + Elements []Tsrange + Dimensions []ArrayDimension + Status Status +} + +func (dst *TsrangeArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TsrangeArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []Tsrange: + if value == nil { + *dst = TsrangeArray{Status: Null} + } else if len(value) == 0 { + *dst = TsrangeArray{Status: Present} + } else { + *dst = TsrangeArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TsrangeArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TsrangeArray", src) + } + if elementsLength == 0 { + *dst = TsrangeArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to TsrangeArray", src) + } + + *dst = TsrangeArray{ + Elements: make([]Tsrange, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Tsrange, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TsrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *TsrangeArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to TsrangeArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in TsrangeArray", err) + } + index++ + + return index, nil +} + +func (dst TsrangeArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *TsrangeArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]Tsrange: + *v = make([]Tsrange, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return errors.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *TsrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, errors.Errorf("cannot assign all values from TsrangeArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, errors.Errorf("cannot assign all values from TsrangeArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *TsrangeArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TsrangeArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Tsrange + + if len(uta.Elements) > 0 { + elements = make([]Tsrange, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Tsrange + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TsrangeArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TsrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TsrangeArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TsrangeArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Tsrange, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TsrangeArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src TsrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src TsrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("tsrange"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "tsrange") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TsrangeArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TsrangeArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/typed_array_gen.sh b/typed_array_gen.sh index fe9eb62b..ea28be07 100755 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -5,6 +5,7 @@ erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[ erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time,[]*time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go +erb pgtype_array_type=TsrangeArray pgtype_element_type=Tsrange go_array_types=[]Tsrange element_type_name=tsrange text_null=NULL binary_format=true typed_array.go.erb > tsrange_array.go erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time,[]*time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32,[]*float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64,[]*float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go From 59b79a2e49183d58630a10d848d05e3e45a4630e Mon Sep 17 00:00:00 2001 From: Stephane Martin Date: Wed, 6 Jan 2021 14:20:46 +0100 Subject: [PATCH 0629/1158] Fix: escaped strings when they start or end with a newline char (jackc/pgtype#86) --- array.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array.go b/array.go index 93c91897..6063c9e6 100644 --- a/array.go +++ b/array.go @@ -348,7 +348,7 @@ func quoteArrayElement(src string) string { } func QuoteArrayElementIfNeeded(src string) string { - if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `{},"\`) { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || src[0] == ' ' || src[0] == '\n' || src[len(src)-1] == ' ' || src[len(src)-1] == '\n' || strings.ContainsAny(src, `{},"\`) { return quoteArrayElement(src) } return src From 6830cc09847cfe17ae59177e7f81b67312496108 Mon Sep 17 00:00:00 2001 From: Stephane Martin Date: Sun, 10 Jan 2021 01:05:56 +0100 Subject: [PATCH 0630/1158] Fix: also consider \r, \f, \t as whitespace (jackc/pgtype#86) --- array.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/array.go b/array.go index 6063c9e6..4e22166e 100644 --- a/array.go +++ b/array.go @@ -347,8 +347,13 @@ func quoteArrayElement(src string) string { return `"` + quoteArrayReplacer.Replace(src) + `"` } +func isSpace(ch byte) bool { + // see https://github.com/postgres/postgres/blob/REL_12_STABLE/src/backend/parser/scansup.c#L224 + return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\f' +} + func QuoteArrayElementIfNeeded(src string) string { - if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || src[0] == ' ' || src[0] == '\n' || src[len(src)-1] == ' ' || src[len(src)-1] == '\n' || strings.ContainsAny(src, `{},"\`) { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { return quoteArrayElement(src) } return src From 120139a206078c030cdab77ee1d05984bb503fe5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 14 Jan 2021 18:22:18 -0600 Subject: [PATCH 0631/1158] Add link to PG docs for connString format fixes #62 --- config.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/config.go b/config.go index e4ee244a..38e94f26 100644 --- a/config.go +++ b/config.go @@ -112,9 +112,10 @@ func NetworkAddress(host string, port uint16) (network, address string) { } // ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same -// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN. -// It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the -// .pgpass file. +// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches +// the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be +// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. // // # Example DSN // user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca From 7d8845a9d8f32c059555e20783828da2534e52f8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 12:47:34 -0600 Subject: [PATCH 0632/1158] Initial import from pgtype --- .github/workflows/ci.yml | 52 ++++++++++++++++++++++++++++++++++++++++ README.md | 1 + 2 files changed, 53 insertions(+) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..27ea2d4d --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,52 @@ +name: CI + +on: + push: + branches: [ github-ci-wip ] + pull_request: + branches: [ github-ci-wip ] + +jobs: + + test: + name: Test + runs-on: ubuntu-latest + + services: + postgres: + image: postgres + env: + POSTGRES_PASSWORD: secret + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + + steps: + + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ^1.13 + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: Create hstore extension + run: psql -c 'create extension hstore' + env: + PGHOST: localhost + PGUSER: postgres + PGPASSWORD: secret + PGSSLMODE: disable + + - name: Test + run: go test -v ./... + env: + PGHOST: localhost + PGUSER: postgres + PGPASSWORD: secret + PGSSLMODE: disable diff --git a/README.md b/README.md index 6a68e230..d7238c39 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ [![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn) [![Build Status](https://travis-ci.org/jackc/pgconn.svg)](https://travis-ci.org/jackc/pgconn) +![CI](https://github.com/jackc/pgtype/workflows/CI/badge.svg) # pgconn From 63bcdfde61d2395e45710c959bed950feeaa5bde Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 12:48:58 -0600 Subject: [PATCH 0633/1158] Fix CI link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d7238c39..feead016 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ [![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn) [![Build Status](https://travis-ci.org/jackc/pgconn.svg)](https://travis-ci.org/jackc/pgconn) -![CI](https://github.com/jackc/pgtype/workflows/CI/badge.svg) +![CI](https://github.com/jackc/pgconn/workflows/CI/badge.svg) # pgconn From 6c2a423dbc25d634270b04ecaac7a1d644037945 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 12:58:25 -0600 Subject: [PATCH 0634/1158] Try to debug failing CI test --- config_test.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/config_test.go b/config_test.go index f6391672..7b8b8937 100644 --- a/config_test.go +++ b/config_test.go @@ -568,7 +568,16 @@ func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { original, err := pgconn.ParseConfig(connString) require.NoError(t, err) + fmt.Printf("original: %#v\n", original) + for i, f := range original.Fallbacks { + fmt.Printf("original fallback %d: %#v\n", i, f) + } + copied := original.Copy() + fmt.Printf("copied: %#v\n", copied) + for i, f := range copied.Fallbacks { + fmt.Printf("copied fallback %d: %#v\n", i, f) + } assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") copied.Port = uint16(5433) From a9c2b5c3cbb210352546ca3763dd259d7b752771 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 13:01:27 -0600 Subject: [PATCH 0635/1158] Revert "Try to debug failing CI test" This reverts commit 6c2a423dbc25d634270b04ecaac7a1d644037945. --- config_test.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/config_test.go b/config_test.go index 7b8b8937..f6391672 100644 --- a/config_test.go +++ b/config_test.go @@ -568,16 +568,7 @@ func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { original, err := pgconn.ParseConfig(connString) require.NoError(t, err) - fmt.Printf("original: %#v\n", original) - for i, f := range original.Fallbacks { - fmt.Printf("original fallback %d: %#v\n", i, f) - } - copied := original.Copy() - fmt.Printf("copied: %#v\n", copied) - for i, f := range copied.Fallbacks { - fmt.Printf("copied fallback %d: %#v\n", i, f) - } assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") copied.Port = uint16(5433) From 74517d73154ecdf045aad3fedcf47d66499b5548 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 13:03:56 -0600 Subject: [PATCH 0636/1158] Fix test when PGSSLMODE=disable When PGSSLMODE=disable no fallback config was created which would cause the check that fallbacks are deep copied to crash on: copied.Fallbacks[0].Port = uint16(5433) --- config_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config_test.go b/config_test.go index f6391672..e869422d 100644 --- a/config_test.go +++ b/config_test.go @@ -564,7 +564,7 @@ func TestConfigCopyReturnsEqualConfig(t *testing.T) { } func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { - connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5&sslmode=prefer" original, err := pgconn.ParseConfig(connString) require.NoError(t, err) From eb322859067bf699fbfe8e8a5a8c6c89a1f5ff7e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 16:22:38 -0600 Subject: [PATCH 0637/1158] Use native PostgreSQL package Also remove travis integration. --- .github/workflows/ci.yml | 32 ++++-------- .travis.yml | 49 ------------------- {travis => ci}/script.bash | 0 .../before_install.bash => ci/setup_test.bash | 12 +++++ travis/before_script.bash | 17 ------- 5 files changed, 21 insertions(+), 89 deletions(-) delete mode 100644 .travis.yml rename {travis => ci}/script.bash (100%) rename travis/before_install.bash => ci/setup_test.bash (73%) delete mode 100755 travis/before_script.bash diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 27ea2d4d..3e3c1ed8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,19 +12,6 @@ jobs: name: Test runs-on: ubuntu-latest - services: - postgres: - image: postgres - env: - POSTGRES_PASSWORD: secret - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 - steps: - name: Set up Go 1.x @@ -35,18 +22,17 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: Create hstore extension - run: psql -c 'create extension hstore' + - name: Setup database server for testing + run: ci/setup_test.bash env: - PGHOST: localhost - PGUSER: postgres - PGPASSWORD: secret - PGSSLMODE: disable + PGVERSION: 12 - name: Test run: go test -v ./... env: - PGHOST: localhost - PGUSER: postgres - PGPASSWORD: secret - PGSSLMODE: disable + PGX_TEST_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test + PGX_TEST_UNIX_SOCKET_CONN_STRING: "host=/var/run/postgresql dbname=pgx_test" + PGX_TEST_TCP_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test + PGX_TEST_TLS_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + PGX_TEST_MD5_PASSWORD_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test + PGX_TEST_PLAIN_PASSWORD_CONN_STRING: postgres://pgx_pw:secret@127.0.0.1/pgx_test diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 95dce226..00000000 --- a/.travis.yml +++ /dev/null @@ -1,49 +0,0 @@ -language: go - -go: - - 1.15.x - - 1.14.x - - tip - -git: - depth: 1 - -# Derived from https://github.com/lib/pq/blob/master/.travis.yml -before_install: - - ./travis/before_install.bash - -env: - global: - - GO111MODULE=on - - GOPROXY=https://proxy.golang.org - - GOFLAGS=-mod=readonly - - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" - - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - - PGX_TEST_TLS_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - - PGX_TEST_MD5_PASSWORD_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - - PGX_TEST_PLAIN_PASSWORD_CONN_STRING=postgres://pgx_pw:secret@127.0.0.1/pgx_test - matrix: - - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx dbname=pgx_test" - - PGVERSION=12 - - PGVERSION=11 - - PGVERSION=10 - - PGVERSION=9.6 - - PGVERSION=9.5 - -cache: - directories: - - $HOME/.cache/go-build - - $HOME/gopath/pkg/mod - -before_script: - - ./travis/before_script.bash - -install: go mod download - -script: - - ./travis/script.bash - -matrix: - allow_failures: - - go: tip diff --git a/travis/script.bash b/ci/script.bash similarity index 100% rename from travis/script.bash rename to ci/script.bash diff --git a/travis/before_install.bash b/ci/setup_test.bash similarity index 73% rename from travis/before_install.bash rename to ci/setup_test.bash index 23c7d9cf..78e30383 100755 --- a/travis/before_install.bash +++ b/ci/setup_test.bash @@ -24,6 +24,18 @@ then echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf fi sudo /etc/init.d/postgresql restart + + # The tricky test user, below, has to actually exist so that it can be used in a test + # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. + psql -U postgres -c 'create database pgx_test' + psql -U postgres pgx_test -c 'create extension hstore' + psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' + psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user travis" + psql -U postgres -c "create user pgx_replication with replication password 'secret'" + psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" fi if [ "${CRATEVERSION-}" != "" ] diff --git a/travis/before_script.bash b/travis/before_script.bash deleted file mode 100755 index 923b7d06..00000000 --- a/travis/before_script.bash +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env bash -set -eux - -if [ "${PGVERSION-}" != "" ] -then - # The tricky test user, below, has to actually exist so that it can be used in a test - # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. - psql -U postgres -c 'create database pgx_test' - psql -U postgres pgx_test -c 'create extension hstore' - psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' - psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user travis" - psql -U postgres -c "create user pgx_replication with replication password 'secret'" - psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" -fi From c107f909a2aba0b35ed9817cafc6acf872861a89 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 16:28:27 -0600 Subject: [PATCH 0638/1158] Create user for Unix domain socket --- ci/setup_test.bash | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/setup_test.bash b/ci/setup_test.bash index 78e30383..144e93fd 100755 --- a/ci/setup_test.bash +++ b/ci/setup_test.bash @@ -33,7 +33,7 @@ then psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user travis" + psql -U postgres -c "create user `whoami`" psql -U postgres -c "create user pgx_replication with replication password 'secret'" psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" fi From c10c60cad5d4a336c46cfb324c7879185266f34b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 16:38:58 -0600 Subject: [PATCH 0639/1158] Add build matrix for Go and PG --- .github/workflows/ci.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3e3c1ed8..b37ca273 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,12 +12,17 @@ jobs: name: Test runs-on: ubuntu-latest + strategy: + matrix: + go_version: [1.14, 1.15] + pg_version: [9.6, 10, 11, 12, 13] + steps: - name: Set up Go 1.x uses: actions/setup-go@v2 with: - go-version: ^1.13 + go-version: ${{ matrix.go_version }} - name: Check out code into the Go module directory uses: actions/checkout@v2 @@ -25,7 +30,7 @@ jobs: - name: Setup database server for testing run: ci/setup_test.bash env: - PGVERSION: 12 + PGVERSION: ${{ matrix.pg_version }} - name: Test run: go test -v ./... From ed0090f61043e2bce64be49da76fe6b7e4a1fbca Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 16:44:17 -0600 Subject: [PATCH 0640/1158] Use race detector on Github CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b37ca273..5acb0eea 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,7 +33,7 @@ jobs: PGVERSION: ${{ matrix.pg_version }} - name: Test - run: go test -v ./... + run: go test -v -race ./... env: PGX_TEST_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test PGX_TEST_UNIX_SOCKET_CONN_STRING: "host=/var/run/postgresql dbname=pgx_test" From 609cd81d64b4689ca9126322fd54f5eecaaf909f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 16:47:51 -0600 Subject: [PATCH 0641/1158] Remove obsolete Travis badge --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index feead016..c651f483 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,4 @@ [![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn) -[![Build Status](https://travis-ci.org/jackc/pgconn.svg)](https://travis-ci.org/jackc/pgconn) ![CI](https://github.com/jackc/pgconn/workflows/CI/badge.svg) # pgconn From 9cf57526250f6cd3e6cbf4fd7269c882e66898ce Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 16:48:51 -0600 Subject: [PATCH 0642/1158] Change Github CI to run on master --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5acb0eea..862235ae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,9 @@ name: CI on: push: - branches: [ github-ci-wip ] + branches: [ master ] pull_request: - branches: [ github-ci-wip ] + branches: [ master ] jobs: From a78ab5bdcda1e98bb43673be2ffda39435b91fda Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Feb 2021 09:39:42 -0600 Subject: [PATCH 0643/1158] Test should abort if cannot setup database --- pgconn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index b71e7d3f..76156420 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -990,7 +990,7 @@ func TestConnExecBatchDeferredError(t *testing.T) { insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + require.NoError(t, err) batch := &pgconn.Batch{} From d05c52217a6e39cdc3ad75808786189aace7b71b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Feb 2021 10:47:22 -0600 Subject: [PATCH 0644/1158] Initial CockroachDB testing --- pgconn_test.go | 90 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 80 insertions(+), 10 deletions(-) diff --git a/pgconn_test.go b/pgconn_test.go index 76156420..564a0c51 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -524,9 +524,18 @@ func TestConnExecMultipleQueriesError(t *testing.T) { t.Errorf("unexpected error: %v", err) } - assert.Len(t, results, 1) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "1", string(results[0].Rows[0][0])) + if pgConn.ParameterStatus("crdb_version") != "" { + // CockroachDB starts the second query result set and then sends the divide by zero error. + require.Len(t, results, 2) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) + assert.Len(t, results[1].Rows, 0) + } else { + // PostgreSQL sends the divide by zero and never sends the second query result set. + require.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) + } ensureConnValid(t, pgConn) } @@ -538,6 +547,10 @@ func TestConnExecDeferredError(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + setupSQL := `create temporary table t ( id text primary key, n int not null, @@ -630,6 +643,10 @@ func TestConnExecParamsDeferredError(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + setupSQL := `create temporary table t ( id text primary key, n int not null, @@ -860,14 +877,19 @@ func TestConnExecPreparedTooManyParams(t *testing.T) { sql := "values" + strings.Join(params, ", ") psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + if pgConn.ParameterStatus("crdb_version") != "" { + // CockroachDB rejects preparing a statement with more than 65535 parameters. + require.EqualError(t, err, "ERROR: more than 65535 arguments to prepared statement: 65536 (SQLSTATE 08P01)") + } else { + // PostgreSQL accepts preparing a statement with more than 65535 parameters and only fails when executing it through the extended protocol. + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.EqualError(t, result.Err, "extended protocol limited to 65535 parameters") + } ensureConnValid(t, pgConn) } @@ -981,6 +1003,10 @@ func TestConnExecBatchDeferredError(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + setupSQL := `create temporary table t ( id text primary key, n int not null, @@ -1161,6 +1187,10 @@ func TestConnOnNotice(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support PL/PGSQL (https://github.com/cockroachdb/cockroach/issues/17511)") + } + multiResult := pgConn.Exec(context.Background(), `do $$ begin raise notice 'hello, world'; @@ -1187,6 +1217,10 @@ func TestConnOnNotification(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + } + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() require.NoError(t, err) @@ -1219,6 +1253,10 @@ func TestConnWaitForNotification(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + } + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() require.NoError(t, err) @@ -1279,6 +1317,10 @@ func TestConnCopyToSmall(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does support COPY TO") + } + _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int2, b int4, @@ -1317,6 +1359,10 @@ func TestConnCopyToLarge(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does support COPY TO") + } + _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int2, b int4, @@ -1372,6 +1418,10 @@ func TestConnCopyToCanceled(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + outputWriter := &bytes.Buffer{} ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -1415,6 +1465,10 @@ func TestConnCopyFrom(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)") + } + _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int4, b varchar @@ -1451,6 +1505,10 @@ func TestConnCopyFromCanceled(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int4, b varchar @@ -1528,6 +1586,10 @@ func TestConnCopyFromGzipReader(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)") + } + _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int4, b varchar @@ -1627,6 +1689,10 @@ func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support triggers (https://github.com/cockroachdb/cockroach/issues/28296)") + } + _, err = pgConn.Exec(ctx, `create temporary table sentences( t text, ts tsvector @@ -1693,6 +1759,10 @@ func TestConnCancelRequest(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a From 4bde08d1a63976925a721ab2f4e000ad594fb34f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Feb 2021 11:19:09 -0600 Subject: [PATCH 0645/1158] LRU statement cache tests handle CockroackDB --- stmtcache/lru_test.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index 2d620905..a4108155 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -5,6 +5,7 @@ import ( "fmt" "math/rand" "os" + "regexp" "testing" "time" @@ -239,7 +240,19 @@ func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgCon require.NoError(t, result.Err) var statements []string for _, r := range result.Rows { - statements = append(statements, string(r[0])) + statement := string(r[0]) + if conn.ParameterStatus("crdb_version") != "" { + if statement == "PREPARE AS select statement from pg_prepared_statements" { + // CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it. + continue + } + + // CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended + // protocol will PostgreSQL does not. Normalize the statement. + re := regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `) + statement = re.ReplaceAllString(statement, "") + } + statements = append(statements, statement) } return statements } From abeb337246854b40048ee995343b94dad92867d8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Feb 2021 09:28:14 -0600 Subject: [PATCH 0646/1158] Accept nil *time.Time in Time.Set --- time.go | 6 ++++++ time_test.go | 3 +++ 2 files changed, 9 insertions(+) diff --git a/time.go b/time.go index 16a2a393..237c4b5b 100644 --- a/time.go +++ b/time.go @@ -42,6 +42,12 @@ func (dst *Time) Set(src interface{}) error { int64(value.Second())*microsecondsPerSecond + int64(value.Nanosecond())/1000 *dst = Time{Microseconds: usec, Status: Present} + case *time.Time: + if value == nil { + *dst = Time{Status: Null} + } else { + return dst.Set(*value) + } default: if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) diff --git a/time_test.go b/time_test.go index bf6365ef..0af42b1e 100644 --- a/time_test.go +++ b/time_test.go @@ -48,6 +48,9 @@ func TestTimeSet(t *testing.T) { {source: time.Date(1970, 1, 1, 0, 0, 0, 1000, time.UTC), result: pgtype.Time{Microseconds: 1, Status: pgtype.Present}}, {source: time.Date(1999, 12, 31, 23, 59, 59, 999999999, time.UTC), result: pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}}, {source: time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local), result: pgtype.Time{Microseconds: 2, Status: pgtype.Present}}, + {source: func(t time.Time) *time.Time { return &t }(time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local)), result: pgtype.Time{Microseconds: 2, Status: pgtype.Present}}, + {source: nil, result: pgtype.Time{Status: pgtype.Null}}, + {source: (*time.Time)(nil), result: pgtype.Time{Status: pgtype.Null}}, {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 3000, time.UTC)), result: pgtype.Time{Microseconds: 3, Status: pgtype.Present}}, } From fb88a34cb4995248d154b5eaadde52136de25547 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Feb 2021 16:40:16 -0600 Subject: [PATCH 0647/1158] Skip test with known issue on CockroachDB --- pgconn_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index 564a0c51..87edefc2 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1098,6 +1098,10 @@ func TestConnExecBatchImplicitTransaction(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/44803)") + } + _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() require.NoError(t, err) From b9a1aad8d94163ffdf29aaedb48b78d3e2329ee3 Mon Sep 17 00:00:00 2001 From: Georges Varouchas Date: Thu, 4 Mar 2021 17:58:49 +0100 Subject: [PATCH 0648/1158] add failing test to highlight issue #65 if frontend returns a message with "Severity: FATAL", even after calling "conn.Close()", the 'CleanupDone()' channel is still blocking --- frontend_test.go | 70 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 frontend_test.go diff --git a/frontend_test.go b/frontend_test.go new file mode 100644 index 00000000..b82552bf --- /dev/null +++ b/frontend_test.go @@ -0,0 +1,70 @@ +package pgconn_test + +import ( + "context" + "io" + "os" + "testing" + + "github.com/jackc/pgconn" + "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// frontendWrapper allows to hijack a regular frontend, and inject a specific response +type frontendWrapper struct { + front pgconn.Frontend + + msg pgproto3.BackendMessage +} + +// frontendWrapper implements the pgconn.Frontend interface +var _ pgconn.Frontend = (*frontendWrapper)(nil) + +func (f *frontendWrapper) Receive() (pgproto3.BackendMessage, error) { + if f.msg != nil { + return f.msg, nil + } + + return f.front.Receive() +} + +func TestFrontendFatalErrExec(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + buildFrontend := config.BuildFrontend + var front *frontendWrapper + + config.BuildFrontend = func(r io.Reader, w io.Writer) pgconn.Frontend { + wrapped := buildFrontend(r, w) + front = &frontendWrapper{wrapped, nil} + + return front + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.NotNil(t, conn) + require.NotNil(t, front) + + // set frontend to return a "FATAL" message on next call + front.msg = &pgproto3.ErrorResponse{Severity: "FATAL", Message: "unit testing fatal error"} + + _, err = conn.Exec(context.Background(), "SELECT 1").ReadAll() + assert.Error(t, err) + + err = conn.Close(context.Background()) + assert.NoError(t, err) + + select { + case <-conn.CleanupDone(): + t.Log("ok, CleanupDone() is not blocking") + + default: + assert.Fail(t, "connection closed but CleanupDone() still blocking") + } +} From 36c8fb8257391de896e4c934ace6e82ea5631f3a Mon Sep 17 00:00:00 2001 From: Georges Varouchas Date: Thu, 4 Mar 2021 18:07:41 +0100 Subject: [PATCH 0649/1158] fix #65 : close cleanupDone channel on "FATAL" messages --- pgconn.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pgconn.go b/pgconn.go index 53e32252..0c1717ff 100644 --- a/pgconn.go +++ b/pgconn.go @@ -487,6 +487,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { if msg.Severity == "FATAL" { pgConn.status = connStatusClosed pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. + close(pgConn.cleanupDone) return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: From 3b0400a0d401491f45add1f347ed0383ca6a76a1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 14:42:22 -0600 Subject: [PATCH 0650/1158] Test Go 1.15 and 1.16 in CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 862235ae..fa5c9e8f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - go_version: [1.14, 1.15] + go_version: [1.15, 1.16] pg_version: [9.6, 10, 11, 12, 13] steps: From cf5894e0927e66175468e7622712d2a4c6df0964 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 14:45:33 -0600 Subject: [PATCH 0651/1158] Use std errors instead of golang.org/x/xerrors New error functionality was introduced in Go 1.13. pgconn only officially supports 1.15+. Transitional xerrors package can now be removed. --- auth_scram.go | 6 +++--- config.go | 8 ++++---- errors.go | 3 +-- go.mod | 1 - pgconn.go | 5 +++-- pgconn_test.go | 6 ++---- 6 files changed, 13 insertions(+), 16 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index 665fc2c2..6a143fcd 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -18,13 +18,13 @@ import ( "crypto/rand" "crypto/sha256" "encoding/base64" + "errors" "fmt" "strconv" "github.com/jackc/pgproto3/v2" "golang.org/x/crypto/pbkdf2" "golang.org/x/text/secure/precis" - errors "golang.org/x/xerrors" ) const clientNonceLen = 18 @@ -192,12 +192,12 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { var err error sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) if err != nil { - return errors.Errorf("invalid SCRAM salt received from server: %w", err) + return fmt.Errorf("invalid SCRAM salt received from server: %w", err) } sc.iterations, err = strconv.Atoi(string(iterationsStr)) if err != nil || sc.iterations <= 0 { - return errors.Errorf("invalid SCRAM iteration count received from server: %w", err) + return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err) } if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { diff --git a/config.go b/config.go index 38e94f26..c162d3c3 100644 --- a/config.go +++ b/config.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" "fmt" "io" "io/ioutil" @@ -20,7 +21,6 @@ import ( "github.com/jackc/pgpassfile" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgservicefile" - errors "golang.org/x/xerrors" ) type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error @@ -409,7 +409,7 @@ func parseURLSettings(connString string) (map[string]string, error) { } h, p, err := net.SplitHostPort(host) if err != nil { - return nil, errors.Errorf("failed to split host:port in '%s', err: %w", host, err) + return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err) } hosts = append(hosts, h) ports = append(ports, p) @@ -617,7 +617,7 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { caPath := sslrootcert caCert, err := ioutil.ReadFile(caPath) if err != nil { - return nil, errors.Errorf("unable to read CA file: %w", err) + return nil, fmt.Errorf("unable to read CA file: %w", err) } if !caCertPool.AppendCertsFromPEM(caCert) { @@ -635,7 +635,7 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { if sslcert != "" && sslkey != "" { cert, err := tls.LoadX509KeyPair(sslcert, sslkey) if err != nil { - return nil, errors.Errorf("unable to read cert: %w", err) + return nil, fmt.Errorf("unable to read cert: %w", err) } tlsConfig.Certificates = []tls.Certificate{cert} diff --git a/errors.go b/errors.go index b37b1d97..77adfcf0 100644 --- a/errors.go +++ b/errors.go @@ -2,13 +2,12 @@ package pgconn import ( "context" + "errors" "fmt" "net" "net/url" "regexp" "strings" - - errors "golang.org/x/xerrors" ) // SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. diff --git a/go.mod b/go.mod index 7e578765..2dc0cd4d 100644 --- a/go.mod +++ b/go.mod @@ -12,5 +12,4 @@ require ( github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/text v0.3.3 - golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 ) diff --git a/pgconn.go b/pgconn.go index 0c1717ff..20233e57 100644 --- a/pgconn.go +++ b/pgconn.go @@ -6,6 +6,8 @@ import ( "crypto/tls" "encoding/binary" "encoding/hex" + "errors" + "fmt" "io" "math" "net" @@ -16,7 +18,6 @@ import ( "github.com/jackc/pgconn/internal/ctxwatch" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" - errors "golang.org/x/xerrors" ) const ( @@ -1043,7 +1044,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by } if len(paramValues) > math.MaxUint16 { - result.concludeCommand(nil, errors.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true pgConn.unlock() return result diff --git a/pgconn_test.go b/pgconn_test.go index 87edefc2..7ceda791 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -5,6 +5,7 @@ import ( "compress/gzip" "context" "crypto/tls" + "errors" "fmt" "io" "io/ioutil" @@ -17,12 +18,9 @@ import ( "testing" "time" - "github.com/jackc/pgmock" - "github.com/jackc/pgconn" + "github.com/jackc/pgmock" "github.com/jackc/pgproto3/v2" - errors "golang.org/x/xerrors" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) From a0350a932a7e4313c547e36e6e2e8b7ccd8ce3d1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 15:01:44 -0600 Subject: [PATCH 0652/1158] ci.yml consistently uses kebab case --- .github/workflows/ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fa5c9e8f..77d32cb7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,15 +14,15 @@ jobs: strategy: matrix: - go_version: [1.15, 1.16] - pg_version: [9.6, 10, 11, 12, 13] + go-version: [1.15, 1.16] + pg-version: [9.6, 10, 11, 12, 13] steps: - name: Set up Go 1.x uses: actions/setup-go@v2 with: - go-version: ${{ matrix.go_version }} + go-version: ${{ matrix.go-version }} - name: Check out code into the Go module directory uses: actions/checkout@v2 @@ -30,7 +30,7 @@ jobs: - name: Setup database server for testing run: ci/setup_test.bash env: - PGVERSION: ${{ matrix.pg_version }} + PGVERSION: ${{ matrix.pg-version }} - name: Test run: go test -v -race ./... From 7de3392269f1eb7d43900b8406392ea767fae479 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 15:15:03 -0600 Subject: [PATCH 0653/1158] Manually specify all build matrix options - Saves some CI time by only testing older version of Go once - Specify connection --- .github/workflows/ci.yml | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 77d32cb7..67ffeaab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,8 +14,30 @@ jobs: strategy: matrix: - go-version: [1.15, 1.16] - pg-version: [9.6, 10, 11, 12, 13] + include: + - go-version: 1.15 + pg-version: 13 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + - go-version: 1.16 + pg-version: 9.6 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + - go-version: 1.16 + pg-version: 10 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + - go-version: 1.16 + pg-version: 11 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + - go-version: 1.16 + pg-version: 12 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + - go-version: 1.16 + pg-version: 13 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test steps: @@ -35,9 +57,9 @@ jobs: - name: Test run: go test -v -race ./... env: - PGX_TEST_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_UNIX_SOCKET_CONN_STRING: "host=/var/run/postgresql dbname=pgx_test" - PGX_TEST_TCP_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_TLS_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - PGX_TEST_MD5_PASSWORD_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_PLAIN_PASSWORD_CONN_STRING: postgres://pgx_pw:secret@127.0.0.1/pgx_test + PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }} + PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }} + PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }} + PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }} + PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }} + PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }} From 1e905d8e38f6c9344707931ccd2afa03a2f34273 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 15:20:03 -0600 Subject: [PATCH 0654/1158] Refactor connection strings into build matrix This is in preparation for adding CockroachDB to the build matrix. --- .github/workflows/ci.yml | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 67ffeaab..6880ae90 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,24 +14,38 @@ jobs: strategy: matrix: + go-version: [1.15, 1.16] + pg-version: [9.6, 10, 11, 12, 13] include: - - go-version: 1.15 - pg-version: 13 + - pg-version: 9.6 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - - go-version: 1.16 - pg-version: 9.6 + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 10 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - - go-version: 1.16 - pg-version: 10 + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 11 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - - go-version: 1.16 - pg-version: 11 + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 12 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - - go-version: 1.16 - pg-version: 12 - pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - - go-version: 1.16 - pg-version: 13 + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 13 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test From 0d307bcc5e8ce129be1875bce1595a397aa46140 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 15:49:50 -0600 Subject: [PATCH 0655/1158] Add CockroachDB to CI --- .github/workflows/ci.yml | 6 ++++-- ci/setup_test.bash | 10 +++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6880ae90..d84462da 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,12 +10,12 @@ jobs: test: name: Test - runs-on: ubuntu-latest + runs-on: ubuntu-18.04 strategy: matrix: go-version: [1.15, 1.16] - pg-version: [9.6, 10, 11, 12, 13] + pg-version: [9.6, 10, 11, 12, 13, cockroachdb] include: - pg-version: 9.6 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test @@ -52,6 +52,8 @@ jobs: pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: cockroachdb + pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" steps: diff --git a/ci/setup_test.bash b/ci/setup_test.bash index 144e93fd..f71bd98c 100755 --- a/ci/setup_test.bash +++ b/ci/setup_test.bash @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -eux -if [ "${PGVERSION-}" != "" ] +if [[ "${PGVERSION-}" =~ ^[0-9.]+$ ]] then sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common sudo rm -rf /var/lib/postgresql @@ -38,6 +38,14 @@ then psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" fi +if [[ "${PGVERSION-}" =~ ^cockroach ]] +then + wget -qO- https://binaries.cockroachdb.com/cockroach-v20.2.5.linux-amd64.tgz | tar xvz + sudo mv cockroach-v20.2.5.linux-amd64/cockroach /usr/local/bin/ + cockroach start-single-node --insecure --background --listen-addr=localhost + cockroach sql --insecure -e 'create database pgx_test' +fi + if [ "${CRATEVERSION-}" != "" ] then docker run \ From 5daa019e4eb52df3409ebf17c83116b7c0e827e5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 16:08:38 -0600 Subject: [PATCH 0656/1158] Update README.md to authentication test setup --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c651f483..1c698a11 100644 --- a/README.md +++ b/README.md @@ -52,5 +52,5 @@ PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./... Pgconn supports multiple connection types and means of authentication. These tests are optional. They will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being -skipped. Most developers will not need to enable these tests. See `travis.yml` for an example set up if you need change +skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change authentication code. From 0f1bda20b06513437fbbe380d444cf1404ce6a2c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 11 Mar 2021 19:48:47 -0600 Subject: [PATCH 0657/1158] Fix numeric NaN support fixes #93 --- numeric.go | 14 +++++--------- numeric_test.go | 5 ++++- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/numeric.go b/numeric.go index f2b04006..4d966d5e 100644 --- a/numeric.go +++ b/numeric.go @@ -16,8 +16,8 @@ import ( const nbase = 10000 const ( - pgNumericNaN = 0x000000000c000000 - pgNumericNaNSign = 0x0c00 + pgNumericNaN = 0x00000000c0000000 + pgNumericNaNSign = 0xc000 ) var big0 *big.Int = big.NewInt(0) @@ -406,7 +406,7 @@ func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { return nil } - if string(src) == "'NaN'" { // includes single quotes, see EncodeText for details. + if string(src) == "NaN" { *dst = Numeric{Status: Present, NaN: true} return nil } @@ -456,7 +456,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { rp += 2 weight := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 - sign := int16(binary.BigEndian.Uint16(src[rp:])) + sign := uint16(binary.BigEndian.Uint16(src[rp:])) rp += 2 dscale := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 @@ -573,11 +573,7 @@ func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } if src.NaN { - // encode as 'NaN' including single quotes, - // "When writing this value [NaN] as a constant in an SQL command, - // you must put quotes around it, for example UPDATE table SET x = 'NaN'" - // https://www.postgresql.org/docs/9.3/datatype-numeric.html - buf = append(buf, "'NaN'"...) + buf = append(buf, "NaN"...) return buf, nil } diff --git a/numeric_test.go b/numeric_test.go index 675eddc4..81595cb3 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -15,7 +15,8 @@ import ( func numericEqual(left, right *pgtype.Numeric) bool { return left.Status == right.Status && left.Exp == right.Exp && - ((left.Int == nil && right.Int == nil) || (left.Int != nil && right.Int != nil && left.Int.Cmp(right.Int) == 0)) + ((left.Int == nil && right.Int == nil) || (left.Int != nil && right.Int != nil && left.Int.Cmp(right.Int) == 0)) && + left.NaN == right.NaN } // For test purposes only. @@ -117,6 +118,8 @@ func TestNumericNormalize(t *testing.T) { func TestNumericTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + &pgtype.Numeric{NaN: true, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, From 26ccb4ee08e9895ad83905cbfbd7dc782261f8c3 Mon Sep 17 00:00:00 2001 From: Andrey Borodin Date: Wed, 10 Mar 2021 22:19:41 +0500 Subject: [PATCH 0658/1158] Resume fallback on server error When server responds with "TLS required" or too "many connections for role" fallbacks are not traversed any further. This could be OK, but fallbacks without TLS are added autoatically so that if we have multiple hosts requiring TLS we never traverse beyond first one. --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 20233e57..a245159d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -151,7 +151,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err if err == nil { break } else if err, ok := err.(*PgError); ok { - return nil, &connectError{config: config, msg: "server error", err: err} + err = &connectError{config: config, msg: "server error", err: err} } } From 70be4b4a02e4c00a3cf4199749f60a0544e12d9b Mon Sep 17 00:00:00 2001 From: Andrey Borodin Date: Wed, 10 Mar 2021 22:29:01 +0500 Subject: [PATCH 0659/1158] Fix incoherent type assignment --- pgconn.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index a245159d..826d70e9 100644 --- a/pgconn.go +++ b/pgconn.go @@ -150,8 +150,8 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err pgConn, err = connect(ctx, config, fc) if err == nil { break - } else if err, ok := err.(*PgError); ok { - err = &connectError{config: config, msg: "server error", err: err} + } else if pgerr, ok := err.(*PgError); ok { + err = &connectError{config: config, msg: "server error", err: pgerr} } } From b6027e37f43987793a1e39b97b99598777218547 Mon Sep 17 00:00:00 2001 From: Andrey Borodin Date: Fri, 12 Mar 2021 11:48:43 +0500 Subject: [PATCH 0660/1158] Stop fallback in case of invalid password --- pgconn.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pgconn.go b/pgconn.go index 826d70e9..668808aa 100644 --- a/pgconn.go +++ b/pgconn.go @@ -152,6 +152,10 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err break } else if pgerr, ok := err.(*PgError); ok { err = &connectError{config: config, msg: "server error", err: pgerr} + ERRCODE_INVALID_PASSWORD := "28P01" + if pgerr.Code == ERRCODE_INVALID_PASSWORD { + break; + } } } From 8990c125cf4a71bcf938328b43d52a289053725e Mon Sep 17 00:00:00 2001 From: Andrey Borodin Date: Fri, 12 Mar 2021 11:55:01 +0500 Subject: [PATCH 0661/1158] Stop fallback on ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION --- pgconn.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pgconn.go b/pgconn.go index 668808aa..197aad4a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -152,9 +152,10 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err break } else if pgerr, ok := err.(*PgError); ok { err = &connectError{config: config, msg: "server error", err: pgerr} - ERRCODE_INVALID_PASSWORD := "28P01" - if pgerr.Code == ERRCODE_INVALID_PASSWORD { - break; + ERRCODE_INVALID_PASSWORD := "28P01" // worng password + ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION := "28000" // db does not exist + if pgerr.Code == ERRCODE_INVALID_PASSWORD || pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION { + break } } } From aa897205768c9fdb17c148465e60517260823c96 Mon Sep 17 00:00:00 2001 From: drewdogg Date: Fri, 12 Mar 2021 18:08:10 -0700 Subject: [PATCH 0662/1158] go 1.13 --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index c70404df..990e79f3 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/jackc/pgtype -go 1.12 +go 1.13 require ( github.com/gofrs/uuid v3.2.0+incompatible From e8f75629d095956d7bff81362bdcd17e37d02464 Mon Sep 17 00:00:00 2001 From: Ethan Pailes Date: Mon, 22 Mar 2021 13:51:08 -0400 Subject: [PATCH 0663/1158] upgrade x/crypto to avoid CVE-2020-9283 I found this when scanning for security issues in some dependencies. I doubt that this CVE will impact pgconn since I don't think it uses the ssh cropto module, but I think it is worth being fairly agressive about upgrading security sensative libraries and this doesn't seem to be a breaking change. --- go.mod | 2 +- go.sum | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 2dc0cd4d..e9003cb7 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,6 @@ require ( github.com/jackc/pgproto3/v2 v2.0.6 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 - golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 + golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 golang.org/x/text v0.3.3 ) diff --git a/go.sum b/go.sum index f3eb0e08..58bb1286 100644 --- a/go.sum +++ b/go.sum @@ -99,10 +99,13 @@ golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 h1:3zb4D3T4G8jdExgVU/95+v golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -111,6 +114,8 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= From dd160540c4760d444a45b8666422ba30f564c26e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 25 Mar 2021 09:01:59 -0400 Subject: [PATCH 0664/1158] Use Go 1.13 errors instead of xerrors --- aclitem.go | 11 +++--- aclitem_array.go | 31 ++++++++--------- array.go | 36 +++++++++---------- array_type.go | 12 +++---- bool.go | 15 ++++---- bool_array.go | 32 ++++++++--------- box.go | 11 +++--- bpchar_array.go | 32 ++++++++--------- bytea.go | 13 ++++--- bytea_array.go | 32 ++++++++--------- cidr_array.go | 32 ++++++++--------- circle.go | 11 +++--- composite_fields.go | 16 ++++----- composite_type.go | 41 +++++++++++----------- convert.go | 55 ++++++++++++++--------------- custom_composite_test.go | 2 +- database_sql.go | 3 +- date.go | 14 ++++---- date_array.go | 32 ++++++++--------- daterange.go | 24 ++++++------- enum_array.go | 31 ++++++++--------- enum_type.go | 8 ++--- ext/gofrs-uuid/uuid.go | 16 ++++----- ext/shopspring-numeric/decimal.go | 54 ++++++++++++++-------------- float4.go | 20 +++++------ float4_array.go | 32 ++++++++--------- float8.go | 16 ++++----- float8_array.go | 32 ++++++++--------- go.mod | 1 - hstore.go | 34 +++++++++--------- hstore_array.go | 32 ++++++++--------- inet.go | 17 +++++---- inet_array.go | 32 ++++++++--------- int2.go | 36 +++++++++---------- int2_array.go | 32 ++++++++--------- int4.go | 30 ++++++++-------- int4_array.go | 32 ++++++++--------- int4range.go | 24 ++++++------- int8.go | 20 +++++------ int8_array.go | 32 ++++++++--------- int8range.go | 24 ++++++------- interval.go | 23 ++++++------ json.go | 8 ++--- jsonb.go | 7 ++-- jsonb_array.go | 32 ++++++++--------- line.go | 13 ++++--- lseg.go | 11 +++--- macaddr.go | 13 ++++--- macaddr_array.go | 32 ++++++++--------- numeric.go | 58 +++++++++++++++---------------- numeric_array.go | 32 ++++++++--------- numrange.go | 24 ++++++------- oid.go | 12 +++---- path.go | 13 ++++--- pgtype.go | 29 ++++++++-------- pgtype_test.go | 2 +- pguint32.go | 14 ++++---- point.go | 17 +++++---- polygon.go | 15 ++++---- qchar.go | 33 +++++++++--------- range.go | 39 ++++++++++----------- record.go | 13 ++++--- text.go | 11 +++--- text_array.go | 32 ++++++++--------- tid.go | 15 ++++---- time.go | 23 ++++++------ timestamp.go | 18 +++++----- timestamp_array.go | 32 ++++++++--------- timestamptz.go | 14 ++++---- timestamptz_array.go | 32 ++++++++--------- tsrange.go | 24 ++++++------- tsrange_array.go | 32 ++++++++--------- tstzrange.go | 24 ++++++------- tstzrange_array.go | 32 ++++++++--------- typed_array.go.erb | 30 ++++++++-------- typed_range.go.erb | 22 ++++++------ uuid.go | 18 +++++----- uuid_array.go | 32 ++++++++--------- varbit.go | 10 +++--- varchar_array.go | 32 ++++++++--------- 80 files changed, 927 insertions(+), 956 deletions(-) diff --git a/aclitem.go b/aclitem.go index d2fe7529..9f6587be 100644 --- a/aclitem.go +++ b/aclitem.go @@ -2,8 +2,7 @@ package pgtype import ( "database/sql/driver" - - errors "golang.org/x/xerrors" + "fmt" ) // ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem @@ -49,7 +48,7 @@ func (dst *ACLItem) Set(src interface{}) error { if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to ACLItem", value) + return fmt.Errorf("cannot convert %v to ACLItem", value) } return nil @@ -77,13 +76,13 @@ func (src *ACLItem) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error { @@ -123,7 +122,7 @@ func (dst *ACLItem) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/aclitem_array.go b/aclitem_array.go index bf7bba93..4e3be3bd 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -4,9 +4,8 @@ package pgtype import ( "database/sql/driver" + "fmt" "reflect" - - errors "golang.org/x/xerrors" ) type ACLItemArray struct { @@ -94,7 +93,7 @@ func (dst *ACLItemArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for ACLItemArray", src) + return fmt.Errorf("cannot find dimensions of %v for ACLItemArray", src) } if elementsLength == 0 { *dst = ACLItemArray{Status: Present} @@ -104,7 +103,7 @@ func (dst *ACLItemArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to ACLItemArray", src) + return fmt.Errorf("cannot convert %v to ACLItemArray", src) } *dst = ACLItemArray{ @@ -135,7 +134,7 @@ func (dst *ACLItemArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to ACLItemArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to ACLItemArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -153,7 +152,7 @@ func (dst *ACLItemArray) setRecursive(value reflect.Value, index, dimension int) valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -166,10 +165,10 @@ func (dst *ACLItemArray) setRecursive(value reflect.Value, index, dimension int) return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to ACLItemArray") + return 0, fmt.Errorf("cannot convert all values to ACLItemArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in ACLItemArray", err) + return 0, fmt.Errorf("%v in ACLItemArray", err) } index++ @@ -231,7 +230,7 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -246,7 +245,7 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -254,7 +253,7 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -270,7 +269,7 @@ func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -288,14 +287,14 @@ func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from ACLItemArray") + return 0, fmt.Errorf("cannot assign all values from ACLItemArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from ACLItemArray") + return 0, fmt.Errorf("cannot assign all values from ACLItemArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -412,7 +411,7 @@ func (dst *ACLItemArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/array.go b/array.go index 4e22166e..3d5930c1 100644 --- a/array.go +++ b/array.go @@ -3,6 +3,7 @@ package pgtype import ( "bytes" "encoding/binary" + "fmt" "io" "reflect" "strconv" @@ -10,7 +11,6 @@ import ( "unicode" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) // Information on the internals of PostgreSQL arrays can be found in @@ -30,7 +30,7 @@ type ArrayDimension struct { func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { if len(src) < 12 { - return 0, errors.Errorf("array header too short: %d", len(src)) + return 0, fmt.Errorf("array header too short: %d", len(src)) } rp := 0 @@ -48,7 +48,7 @@ func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { dst.Dimensions = make([]ArrayDimension, numDims) } if len(src) < 12+numDims*8 { - return 0, errors.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) + return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) } for i := range dst.Dimensions { dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) @@ -95,7 +95,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { r, _, err := buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %v", err) } var explicitDimensions []ArrayDimension @@ -107,41 +107,41 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %v", err) } if r == '=' { break } else if r != '[' { - return nil, errors.Errorf("invalid array, expected '[' or '=' got %v", r) + return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r) } lower, err := arrayParseInteger(buf) if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %v", err) } r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %v", err) } if r != ':' { - return nil, errors.Errorf("invalid array, expected ':' got %v", r) + return nil, fmt.Errorf("invalid array, expected ':' got %v", r) } upper, err := arrayParseInteger(buf) if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %v", err) } r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %v", err) } if r != ']' { - return nil, errors.Errorf("invalid array, expected ']' got %v", r) + return nil, fmt.Errorf("invalid array, expected ']' got %v", r) } explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) @@ -149,12 +149,12 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %v", err) } } if r != '{' { - return nil, errors.Errorf("invalid array, expected '{': %v", err) + return nil, fmt.Errorf("invalid array, expected '{': %v", err) } implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} @@ -163,7 +163,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %v", err) } if r == '{' { @@ -180,7 +180,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %v", err) } switch r { @@ -199,7 +199,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { buf.UnreadRune() value, quoted, err := arrayParseValue(buf) if err != nil { - return nil, errors.Errorf("invalid array value: %v", err) + return nil, fmt.Errorf("invalid array value: %v", err) } if currentDim == counterDim { implicitDimensions[currentDim].Length++ @@ -216,7 +216,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { skipWhitespace(buf) if buf.Len() > 0 { - return nil, errors.Errorf("unexpected trailing data: %v", buf.String()) + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) } if len(dst.Elements) == 0 { diff --git a/array_type.go b/array_type.go index 04b8710c..1bd0244b 100644 --- a/array_type.go +++ b/array_type.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) // ArrayType represents an array type. While it implements Value, this is only in service of its type conversion duties @@ -58,7 +58,7 @@ func (dst *ArrayType) Set(src interface{}) error { sliceVal := reflect.ValueOf(src) if sliceVal.Kind() != reflect.Slice { - return errors.Errorf("cannot set non-slice") + return fmt.Errorf("cannot set non-slice") } if sliceVal.IsNil() { @@ -100,14 +100,14 @@ func (dst ArrayType) Get() interface{} { func (src *ArrayType) AssignTo(dst interface{}) error { ptrSlice := reflect.ValueOf(dst) if ptrSlice.Kind() != reflect.Ptr { - return errors.Errorf("cannot assign to non-pointer") + return fmt.Errorf("cannot assign to non-pointer") } sliceVal := ptrSlice.Elem() sliceType := sliceVal.Type() if sliceType.Kind() != reflect.Slice { - return errors.Errorf("cannot assign to pointer to non-slice") + return fmt.Errorf("cannot assign to pointer to non-slice") } switch src.status { @@ -132,7 +132,7 @@ func (src *ArrayType) AssignTo(dst interface{}) error { return nil } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { @@ -336,7 +336,7 @@ func (dst *ArrayType) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/bool.go b/bool.go index 9ec5097f..676c8e5d 100644 --- a/bool.go +++ b/bool.go @@ -3,9 +3,8 @@ package pgtype import ( "database/sql/driver" "encoding/json" + "fmt" "strconv" - - errors "golang.org/x/xerrors" ) type Bool struct { @@ -51,7 +50,7 @@ func (dst *Bool) Set(src interface{}) error { if originalSrc, ok := underlyingBoolType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Bool", value) + return fmt.Errorf("cannot convert %v to Bool", value) } return nil @@ -79,13 +78,13 @@ func (src *Bool) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { @@ -95,7 +94,7 @@ func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) != 1 { - return errors.Errorf("invalid length for bool: %v", len(src)) + return fmt.Errorf("invalid length for bool: %v", len(src)) } *dst = Bool{Bool: src[0] == 't', Status: Present} @@ -109,7 +108,7 @@ func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 1 { - return errors.Errorf("invalid length for bool: %v", len(src)) + return fmt.Errorf("invalid length for bool: %v", len(src)) } *dst = Bool{Bool: src[0] == 1, Status: Present} @@ -169,7 +168,7 @@ func (dst *Bool) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/bool_array.go b/bool_array.go index 2659321e..6558d971 100644 --- a/bool_array.go +++ b/bool_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type BoolArray struct { @@ -96,7 +96,7 @@ func (dst *BoolArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for BoolArray", src) + return fmt.Errorf("cannot find dimensions of %v for BoolArray", src) } if elementsLength == 0 { *dst = BoolArray{Status: Present} @@ -106,7 +106,7 @@ func (dst *BoolArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to BoolArray", src) + return fmt.Errorf("cannot convert %v to BoolArray", src) } *dst = BoolArray{ @@ -137,7 +137,7 @@ func (dst *BoolArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to BoolArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to BoolArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -155,7 +155,7 @@ func (dst *BoolArray) setRecursive(value reflect.Value, index, dimension int) (i valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -168,10 +168,10 @@ func (dst *BoolArray) setRecursive(value reflect.Value, index, dimension int) (i return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to BoolArray") + return 0, fmt.Errorf("cannot convert all values to BoolArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in BoolArray", err) + return 0, fmt.Errorf("%v in BoolArray", err) } index++ @@ -233,7 +233,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -248,7 +248,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -256,7 +256,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -272,7 +272,7 @@ func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension in if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -290,14 +290,14 @@ func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension in return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from BoolArray") + return 0, fmt.Errorf("cannot assign all values from BoolArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from BoolArray") + return 0, fmt.Errorf("cannot assign all values from BoolArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -456,7 +456,7 @@ func (src BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("bool"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "bool") + return nil, fmt.Errorf("unable to find oid for type name %v", "bool") } for i := range src.Elements { @@ -500,7 +500,7 @@ func (dst *BoolArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/box.go b/box.go index 75d50f98..27fb829e 100644 --- a/box.go +++ b/box.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Box struct { @@ -18,7 +17,7 @@ type Box struct { } func (dst *Box) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Box", src) + return fmt.Errorf("cannot convert %v to Box", src) } func (dst Box) Get() interface{} { @@ -33,7 +32,7 @@ func (dst Box) Get() interface{} { } func (src *Box) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { @@ -43,7 +42,7 @@ func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 11 { - return errors.Errorf("invalid length for Box: %v", len(src)) + return fmt.Errorf("invalid length for Box: %v", len(src)) } str := string(src[1:]) @@ -90,7 +89,7 @@ func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 32 { - return errors.Errorf("invalid length for Box: %v", len(src)) + return fmt.Errorf("invalid length for Box: %v", len(src)) } x1 := binary.BigEndian.Uint64(src) @@ -157,7 +156,7 @@ func (dst *Box) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/bpchar_array.go b/bpchar_array.go index d48b2b53..8e792214 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type BPCharArray struct { @@ -96,7 +96,7 @@ func (dst *BPCharArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for BPCharArray", src) + return fmt.Errorf("cannot find dimensions of %v for BPCharArray", src) } if elementsLength == 0 { *dst = BPCharArray{Status: Present} @@ -106,7 +106,7 @@ func (dst *BPCharArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to BPCharArray", src) + return fmt.Errorf("cannot convert %v to BPCharArray", src) } *dst = BPCharArray{ @@ -137,7 +137,7 @@ func (dst *BPCharArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to BPCharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to BPCharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -155,7 +155,7 @@ func (dst *BPCharArray) setRecursive(value reflect.Value, index, dimension int) valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -168,10 +168,10 @@ func (dst *BPCharArray) setRecursive(value reflect.Value, index, dimension int) return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to BPCharArray") + return 0, fmt.Errorf("cannot convert all values to BPCharArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in BPCharArray", err) + return 0, fmt.Errorf("%v in BPCharArray", err) } index++ @@ -233,7 +233,7 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -248,7 +248,7 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -256,7 +256,7 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -272,7 +272,7 @@ func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -290,14 +290,14 @@ func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from BPCharArray") + return 0, fmt.Errorf("cannot assign all values from BPCharArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from BPCharArray") + return 0, fmt.Errorf("cannot assign all values from BPCharArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -456,7 +456,7 @@ func (src BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("bpchar"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "bpchar") + return nil, fmt.Errorf("unable to find oid for type name %v", "bpchar") } for i := range src.Elements { @@ -500,7 +500,7 @@ func (dst *BPCharArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/bytea.go b/bytea.go index b9e4d15a..67eba350 100644 --- a/bytea.go +++ b/bytea.go @@ -3,8 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/hex" - - errors "golang.org/x/xerrors" + "fmt" ) type Bytea struct { @@ -36,7 +35,7 @@ func (dst *Bytea) Set(src interface{}) error { if originalSrc, ok := underlyingBytesType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Bytea", value) + return fmt.Errorf("cannot convert %v to Bytea", value) } return nil @@ -66,13 +65,13 @@ func (src *Bytea) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } // DecodeText only supports the hex format. This has been the default since @@ -84,7 +83,7 @@ func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 2 || src[0] != '\\' || src[1] != 'x' { - return errors.Errorf("invalid hex format") + return fmt.Errorf("invalid hex format") } buf := make([]byte, (len(src)-2)/2) @@ -148,7 +147,7 @@ func (dst *Bytea) Scan(src interface{}) error { return nil } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/bytea_array.go b/bytea_array.go index 14d8afad..69d1ceb9 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type ByteaArray struct { @@ -77,7 +77,7 @@ func (dst *ByteaArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for ByteaArray", src) + return fmt.Errorf("cannot find dimensions of %v for ByteaArray", src) } if elementsLength == 0 { *dst = ByteaArray{Status: Present} @@ -87,7 +87,7 @@ func (dst *ByteaArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to ByteaArray", src) + return fmt.Errorf("cannot convert %v to ByteaArray", src) } *dst = ByteaArray{ @@ -118,7 +118,7 @@ func (dst *ByteaArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to ByteaArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to ByteaArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -136,7 +136,7 @@ func (dst *ByteaArray) setRecursive(value reflect.Value, index, dimension int) ( valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -149,10 +149,10 @@ func (dst *ByteaArray) setRecursive(value reflect.Value, index, dimension int) ( return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to ByteaArray") + return 0, fmt.Errorf("cannot convert all values to ByteaArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in ByteaArray", err) + return 0, fmt.Errorf("%v in ByteaArray", err) } index++ @@ -205,7 +205,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -220,7 +220,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -228,7 +228,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -244,7 +244,7 @@ func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension i if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -262,14 +262,14 @@ func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension i return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from ByteaArray") + return 0, fmt.Errorf("cannot assign all values from ByteaArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from ByteaArray") + return 0, fmt.Errorf("cannot assign all values from ByteaArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -428,7 +428,7 @@ func (src ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("bytea"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "bytea") + return nil, fmt.Errorf("unable to find oid for type name %v", "bytea") } for i := range src.Elements { @@ -472,7 +472,7 @@ func (dst *ByteaArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/cidr_array.go b/cidr_array.go index 3ac1b183..783c599c 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -5,11 +5,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "net" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type CIDRArray struct { @@ -116,7 +116,7 @@ func (dst *CIDRArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for CIDRArray", src) + return fmt.Errorf("cannot find dimensions of %v for CIDRArray", src) } if elementsLength == 0 { *dst = CIDRArray{Status: Present} @@ -126,7 +126,7 @@ func (dst *CIDRArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to CIDRArray", src) + return fmt.Errorf("cannot convert %v to CIDRArray", src) } *dst = CIDRArray{ @@ -157,7 +157,7 @@ func (dst *CIDRArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to CIDRArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to CIDRArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -175,7 +175,7 @@ func (dst *CIDRArray) setRecursive(value reflect.Value, index, dimension int) (i valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -188,10 +188,10 @@ func (dst *CIDRArray) setRecursive(value reflect.Value, index, dimension int) (i return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to CIDRArray") + return 0, fmt.Errorf("cannot convert all values to CIDRArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in CIDRArray", err) + return 0, fmt.Errorf("%v in CIDRArray", err) } index++ @@ -262,7 +262,7 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -277,7 +277,7 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -285,7 +285,7 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -301,7 +301,7 @@ func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension in if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -319,14 +319,14 @@ func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension in return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from CIDRArray") + return 0, fmt.Errorf("cannot assign all values from CIDRArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from CIDRArray") + return 0, fmt.Errorf("cannot assign all values from CIDRArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -485,7 +485,7 @@ func (src CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("cidr"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "cidr") + return nil, fmt.Errorf("unable to find oid for type name %v", "cidr") } for i := range src.Elements { @@ -529,7 +529,7 @@ func (dst *CIDRArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/circle.go b/circle.go index d3f8b38a..4279650e 100644 --- a/circle.go +++ b/circle.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Circle struct { @@ -19,7 +18,7 @@ type Circle struct { } func (dst *Circle) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Circle", src) + return fmt.Errorf("cannot convert %v to Circle", src) } func (dst Circle) Get() interface{} { @@ -34,7 +33,7 @@ func (dst Circle) Get() interface{} { } func (src *Circle) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { @@ -44,7 +43,7 @@ func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 9 { - return errors.Errorf("invalid length for Circle: %v", len(src)) + return fmt.Errorf("invalid length for Circle: %v", len(src)) } str := string(src[2:]) @@ -80,7 +79,7 @@ func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 24 { - return errors.Errorf("invalid length for Circle: %v", len(src)) + return fmt.Errorf("invalid length for Circle: %v", len(src)) } x := binary.BigEndian.Uint64(src) @@ -142,7 +141,7 @@ func (dst *Circle) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/composite_fields.go b/composite_fields.go index af7bab1e..b6d09fcf 100644 --- a/composite_fields.go +++ b/composite_fields.go @@ -1,8 +1,6 @@ package pgtype -import ( - errors "golang.org/x/xerrors" -) +import "fmt" // CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a // nullable value use a *CompositeFields. It will be set to nil in case of null. @@ -13,11 +11,11 @@ type CompositeFields []interface{} func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error { if len(cf) == 0 { - return errors.Errorf("cannot decode into empty CompositeFields") + return fmt.Errorf("cannot decode into empty CompositeFields") } if src == nil { - return errors.Errorf("cannot decode unexpected null into CompositeFields") + return fmt.Errorf("cannot decode unexpected null into CompositeFields") } scanner := NewCompositeBinaryScanner(ci, src) @@ -35,11 +33,11 @@ func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error { func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error { if len(cf) == 0 { - return errors.Errorf("cannot decode into empty CompositeFields") + return fmt.Errorf("cannot decode into empty CompositeFields") } if src == nil { - return errors.Errorf("cannot decode unexpected null into CompositeFields") + return fmt.Errorf("cannot decode unexpected null into CompositeFields") } scanner := NewCompositeTextScanner(ci, src) @@ -87,7 +85,7 @@ func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) for _, f := range cf { dt, ok := ci.DataTypeForValue(f) if !ok { - return nil, errors.Errorf("Unknown OID for %#v", f) + return nil, fmt.Errorf("Unknown OID for %#v", f) } if binaryEncoder, ok := f.(BinaryEncoder); ok { @@ -100,7 +98,7 @@ func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) if binaryEncoder, ok := dt.Value.(BinaryEncoder); ok { b.AppendEncoder(dt.OID, binaryEncoder) } else { - return nil, errors.Errorf("Cannot encode binary format for %v", f) + return nil, fmt.Errorf("Cannot encode binary format for %v", f) } } } diff --git a/composite_type.go b/composite_type.go index cbe0a245..7c8dbcd5 100644 --- a/composite_type.go +++ b/composite_type.go @@ -2,11 +2,12 @@ package pgtype import ( "encoding/binary" + "errors" + "fmt" "reflect" "strings" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type CompositeTypeField struct { @@ -31,13 +32,13 @@ func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo for i := range fields { dt, ok := ci.DataTypeForOID(fields[i].OID) if !ok { - return nil, errors.Errorf("no data type registered for oid: %d", fields[i].OID) + return nil, fmt.Errorf("no data type registered for oid: %d", fields[i].OID) } value := NewValue(dt.Value) valueTranscoder, ok := value.(ValueTranscoder) if !ok { - return nil, errors.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID) + return nil, fmt.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID) } valueTranscoders[i] = valueTranscoder @@ -102,7 +103,7 @@ func (dst *CompositeType) Set(src interface{}) error { switch value := src.(type) { case []interface{}: if len(value) != len(dst.valueTranscoders) { - return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders)) + return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders)) } for i, v := range value { if err := dst.valueTranscoders[i].Set(v); err != nil { @@ -117,7 +118,7 @@ func (dst *CompositeType) Set(src interface{}) error { } return dst.Set(*value) default: - return errors.Errorf("Can not convert %v to Composite", src) + return fmt.Errorf("Can not convert %v to Composite", src) } return nil @@ -130,7 +131,7 @@ func (src CompositeType) AssignTo(dst interface{}) error { switch v := dst.(type) { case []interface{}: if len(v) != len(src.valueTranscoders) { - return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders)) + return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders)) } for i := range src.valueTranscoders { if v[i] == nil { @@ -139,7 +140,7 @@ func (src CompositeType) AssignTo(dst interface{}) error { err := assignToOrSet(src.valueTranscoders[i], v[i]) if err != nil { - return errors.Errorf("unable to assign to dst[%d]: %v", i, err) + return fmt.Errorf("unable to assign to dst[%d]: %v", i, err) } } return nil @@ -153,12 +154,12 @@ func (src CompositeType) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func assignToOrSet(src Value, dst interface{}) error { @@ -210,7 +211,7 @@ func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) { for i := range exportedFields { err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface()) if err != nil { - return true, errors.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err) + return true, fmt.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err) } } @@ -310,7 +311,7 @@ type CompositeBinaryScanner struct { func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner { rp := 0 if len(src[rp:]) < 4 { - return &CompositeBinaryScanner{err: errors.Errorf("Record incomplete %v", src)} + return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)} } fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) @@ -362,7 +363,7 @@ func (cfs *CompositeBinaryScanner) Next() bool { } if len(cfs.src[cfs.rp:]) < 8 { - cfs.err = errors.Errorf("Record incomplete %v", cfs.src) + cfs.err = fmt.Errorf("Record incomplete %v", cfs.src) return false } cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:]) @@ -373,7 +374,7 @@ func (cfs *CompositeBinaryScanner) Next() bool { if fieldLen >= 0 { if len(cfs.src[cfs.rp:]) < fieldLen { - cfs.err = errors.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src) + cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src) return false } cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen] @@ -416,15 +417,15 @@ type CompositeTextScanner struct { // NewCompositeTextScanner a scanner over a text encoded composite value. func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner { if len(src) < 2 { - return &CompositeTextScanner{err: errors.Errorf("Record incomplete %v", src)} + return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)} } if src[0] != '(' { - return &CompositeTextScanner{err: errors.Errorf("composite text format must start with '('")} + return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")} } if src[len(src)-1] != ')' { - return &CompositeTextScanner{err: errors.Errorf("composite text format must end with ')'")} + return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")} } return &CompositeTextScanner{ @@ -543,7 +544,7 @@ func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { dt, ok := b.ci.DataTypeForOID(oid) if !ok { - b.err = errors.Errorf("unknown data type for OID: %d", oid) + b.err = fmt.Errorf("unknown data type for OID: %d", oid) return } @@ -555,7 +556,7 @@ func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { binaryEncoder, ok := dt.Value.(BinaryEncoder) if !ok { - b.err = errors.Errorf("unable to encode binary for OID: %d", oid) + b.err = fmt.Errorf("unable to encode binary for OID: %d", oid) return } @@ -618,7 +619,7 @@ func (b *CompositeTextBuilder) AppendValue(field interface{}) { dt, ok := b.ci.DataTypeForValue(field) if !ok { - b.err = errors.Errorf("unknown data type for field: %v", field) + b.err = fmt.Errorf("unknown data type for field: %v", field) return } @@ -630,7 +631,7 @@ func (b *CompositeTextBuilder) AppendValue(field interface{}) { textEncoder, ok := dt.Value.(TextEncoder) if !ok { - b.err = errors.Errorf("unable to encode text for value: %v", field) + b.err = fmt.Errorf("unable to encode text for value: %v", field) return } diff --git a/convert.go b/convert.go index 193f771f..8ae599b9 100644 --- a/convert.go +++ b/convert.go @@ -2,11 +2,10 @@ package pgtype import ( "database/sql" + "fmt" "math" "reflect" "time" - - errors "golang.org/x/xerrors" ) const maxUint = ^uint(0) @@ -212,70 +211,70 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { switch v := dst.(type) { case *int: if srcVal < int64(minInt) { - return errors.Errorf("%d is less than minimum value for int", srcVal) + return fmt.Errorf("%d is less than minimum value for int", srcVal) } else if srcVal > int64(maxInt) { - return errors.Errorf("%d is greater than maximum value for int", srcVal) + return fmt.Errorf("%d is greater than maximum value for int", srcVal) } *v = int(srcVal) case *int8: if srcVal < math.MinInt8 { - return errors.Errorf("%d is less than minimum value for int8", srcVal) + return fmt.Errorf("%d is less than minimum value for int8", srcVal) } else if srcVal > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for int8", srcVal) + return fmt.Errorf("%d is greater than maximum value for int8", srcVal) } *v = int8(srcVal) case *int16: if srcVal < math.MinInt16 { - return errors.Errorf("%d is less than minimum value for int16", srcVal) + return fmt.Errorf("%d is less than minimum value for int16", srcVal) } else if srcVal > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for int16", srcVal) + return fmt.Errorf("%d is greater than maximum value for int16", srcVal) } *v = int16(srcVal) case *int32: if srcVal < math.MinInt32 { - return errors.Errorf("%d is less than minimum value for int32", srcVal) + return fmt.Errorf("%d is less than minimum value for int32", srcVal) } else if srcVal > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for int32", srcVal) + return fmt.Errorf("%d is greater than maximum value for int32", srcVal) } *v = int32(srcVal) case *int64: if srcVal < math.MinInt64 { - return errors.Errorf("%d is less than minimum value for int64", srcVal) + return fmt.Errorf("%d is less than minimum value for int64", srcVal) } else if srcVal > math.MaxInt64 { - return errors.Errorf("%d is greater than maximum value for int64", srcVal) + return fmt.Errorf("%d is greater than maximum value for int64", srcVal) } *v = int64(srcVal) case *uint: if srcVal < 0 { - return errors.Errorf("%d is less than zero for uint", srcVal) + return fmt.Errorf("%d is less than zero for uint", srcVal) } else if uint64(srcVal) > uint64(maxUint) { - return errors.Errorf("%d is greater than maximum value for uint", srcVal) + return fmt.Errorf("%d is greater than maximum value for uint", srcVal) } *v = uint(srcVal) case *uint8: if srcVal < 0 { - return errors.Errorf("%d is less than zero for uint8", srcVal) + return fmt.Errorf("%d is less than zero for uint8", srcVal) } else if srcVal > math.MaxUint8 { - return errors.Errorf("%d is greater than maximum value for uint8", srcVal) + return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) } *v = uint8(srcVal) case *uint16: if srcVal < 0 { - return errors.Errorf("%d is less than zero for uint32", srcVal) + return fmt.Errorf("%d is less than zero for uint32", srcVal) } else if srcVal > math.MaxUint16 { - return errors.Errorf("%d is greater than maximum value for uint16", srcVal) + return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) } *v = uint16(srcVal) case *uint32: if srcVal < 0 { - return errors.Errorf("%d is less than zero for uint32", srcVal) + return fmt.Errorf("%d is less than zero for uint32", srcVal) } else if srcVal > math.MaxUint32 { - return errors.Errorf("%d is greater than maximum value for uint32", srcVal) + return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) } *v = uint32(srcVal) case *uint64: if srcVal < 0 { - return errors.Errorf("%d is less than zero for uint64", srcVal) + return fmt.Errorf("%d is less than zero for uint64", srcVal) } *v = uint64(srcVal) case sql.Scanner: @@ -293,22 +292,22 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { return int64AssignTo(srcVal, srcStatus, el.Interface()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if el.OverflowInt(int64(srcVal)) { - return errors.Errorf("cannot put %d into %T", srcVal, dst) + return fmt.Errorf("cannot put %d into %T", srcVal, dst) } el.SetInt(int64(srcVal)) return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if srcVal < 0 { - return errors.Errorf("%d is less than zero for %T", srcVal, dst) + return fmt.Errorf("%d is less than zero for %T", srcVal, dst) } if el.OverflowUint(uint64(srcVal)) { - return errors.Errorf("cannot put %d into %T", srcVal, dst) + return fmt.Errorf("cannot put %d into %T", srcVal, dst) } el.SetUint(uint64(srcVal)) return nil } } - return errors.Errorf("cannot assign %v into %T", srcVal, dst) + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) } return nil } @@ -322,7 +321,7 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { } } - return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { @@ -350,7 +349,7 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { } } } - return errors.Errorf("cannot assign %v into %T", srcVal, dst) + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) } return nil } @@ -364,7 +363,7 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { } } - return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } func NullAssignTo(dst interface{}) error { diff --git a/custom_composite_test.go b/custom_composite_test.go index 296fcc90..9ca8dd5e 100644 --- a/custom_composite_test.go +++ b/custom_composite_test.go @@ -2,12 +2,12 @@ package pgtype_test import ( "context" + "errors" "fmt" "os" "github.com/jackc/pgtype" pgx "github.com/jackc/pgx/v4" - errors "golang.org/x/xerrors" ) type MyType struct { diff --git a/database_sql.go b/database_sql.go index f54a750d..9d1cf822 100644 --- a/database_sql.go +++ b/database_sql.go @@ -2,8 +2,7 @@ package pgtype import ( "database/sql/driver" - - errors "golang.org/x/xerrors" + "errors" ) func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { diff --git a/date.go b/date.go index 59e225df..e8d21a78 100644 --- a/date.go +++ b/date.go @@ -4,10 +4,10 @@ import ( "database/sql/driver" "encoding/binary" "encoding/json" + "fmt" "time" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Date struct { @@ -55,7 +55,7 @@ func (dst *Date) Set(src interface{}) error { if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Date", value) + return fmt.Errorf("cannot convert %v to Date", value) } return nil @@ -81,7 +81,7 @@ func (src *Date) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: if src.InfinityModifier != None { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } *v = src.Time return nil @@ -89,13 +89,13 @@ func (src *Date) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { @@ -129,7 +129,7 @@ func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 4 { - return errors.Errorf("invalid length for date: %v", len(src)) + return fmt.Errorf("invalid length for date: %v", len(src)) } dayOffset := int32(binary.BigEndian.Uint32(src)) @@ -213,7 +213,7 @@ func (dst *Date) Scan(src interface{}) error { return nil } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/date_array.go b/date_array.go index 0c623b8f..24152fa0 100644 --- a/date_array.go +++ b/date_array.go @@ -5,11 +5,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "time" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type DateArray struct { @@ -97,7 +97,7 @@ func (dst *DateArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for DateArray", src) + return fmt.Errorf("cannot find dimensions of %v for DateArray", src) } if elementsLength == 0 { *dst = DateArray{Status: Present} @@ -107,7 +107,7 @@ func (dst *DateArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to DateArray", src) + return fmt.Errorf("cannot convert %v to DateArray", src) } *dst = DateArray{ @@ -138,7 +138,7 @@ func (dst *DateArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to DateArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to DateArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -156,7 +156,7 @@ func (dst *DateArray) setRecursive(value reflect.Value, index, dimension int) (i valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -169,10 +169,10 @@ func (dst *DateArray) setRecursive(value reflect.Value, index, dimension int) (i return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to DateArray") + return 0, fmt.Errorf("cannot convert all values to DateArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in DateArray", err) + return 0, fmt.Errorf("%v in DateArray", err) } index++ @@ -234,7 +234,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -249,7 +249,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -257,7 +257,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -273,7 +273,7 @@ func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension in if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -291,14 +291,14 @@ func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension in return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from DateArray") + return 0, fmt.Errorf("cannot assign all values from DateArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from DateArray") + return 0, fmt.Errorf("cannot assign all values from DateArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -457,7 +457,7 @@ func (src DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("date"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "date") + return nil, fmt.Errorf("unable to find oid for type name %v", "date") } for i := range src.Elements { @@ -501,7 +501,7 @@ func (dst *DateArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/daterange.go b/daterange.go index 7b9af795..63164a5a 100644 --- a/daterange.go +++ b/daterange.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" + "fmt" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Daterange struct { @@ -30,7 +30,7 @@ func (dst *Daterange) Set(src interface{}) error { case string: return dst.DecodeText(nil, []byte(value)) default: - return errors.Errorf("cannot convert %v to Daterange", src) + return fmt.Errorf("cannot convert %v to Daterange", src) } return nil @@ -48,7 +48,7 @@ func (dst Daterange) Get() interface{} { } func (src *Daterange) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Daterange) DecodeText(ci *ConnInfo, src []byte) error { @@ -137,7 +137,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -147,7 +147,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -158,7 +158,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -168,7 +168,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -192,7 +192,7 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -202,7 +202,7 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -218,7 +218,7 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -233,7 +233,7 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -258,7 +258,7 @@ func (dst *Daterange) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/enum_array.go b/enum_array.go index cf7c7066..59b5a3ed 100644 --- a/enum_array.go +++ b/enum_array.go @@ -4,9 +4,8 @@ package pgtype import ( "database/sql/driver" + "fmt" "reflect" - - errors "golang.org/x/xerrors" ) type EnumArray struct { @@ -94,7 +93,7 @@ func (dst *EnumArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for EnumArray", src) + return fmt.Errorf("cannot find dimensions of %v for EnumArray", src) } if elementsLength == 0 { *dst = EnumArray{Status: Present} @@ -104,7 +103,7 @@ func (dst *EnumArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to EnumArray", src) + return fmt.Errorf("cannot convert %v to EnumArray", src) } *dst = EnumArray{ @@ -135,7 +134,7 @@ func (dst *EnumArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to EnumArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to EnumArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -153,7 +152,7 @@ func (dst *EnumArray) setRecursive(value reflect.Value, index, dimension int) (i valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -166,10 +165,10 @@ func (dst *EnumArray) setRecursive(value reflect.Value, index, dimension int) (i return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to EnumArray") + return 0, fmt.Errorf("cannot convert all values to EnumArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in EnumArray", err) + return 0, fmt.Errorf("%v in EnumArray", err) } index++ @@ -231,7 +230,7 @@ func (src *EnumArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -246,7 +245,7 @@ func (src *EnumArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -254,7 +253,7 @@ func (src *EnumArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -270,7 +269,7 @@ func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension in if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -288,14 +287,14 @@ func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension in return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from EnumArray") + return 0, fmt.Errorf("cannot assign all values from EnumArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from EnumArray") + return 0, fmt.Errorf("cannot assign all values from EnumArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -412,7 +411,7 @@ func (dst *EnumArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/enum_type.go b/enum_type.go index d3a1df5c..d340320f 100644 --- a/enum_type.go +++ b/enum_type.go @@ -1,6 +1,6 @@ package pgtype -import errors "golang.org/x/xerrors" +import "fmt" // EnumType represents a enum type. While it implements Value, this is only in service of its type conversion duties // when registered as a data type in a ConnType. It should not be used directly as a Value. @@ -79,7 +79,7 @@ func (dst *EnumType) Set(src interface{}) error { if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to enum %s", value, dst.typeName) + return fmt.Errorf("cannot convert %v to enum %s", value, dst.typeName) } return nil @@ -111,13 +111,13 @@ func (src *EnumType) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (EnumType) PreferredResultFormat() int16 { diff --git a/ext/gofrs-uuid/uuid.go b/ext/gofrs-uuid/uuid.go index e29933c9..a5e0a3c3 100644 --- a/ext/gofrs-uuid/uuid.go +++ b/ext/gofrs-uuid/uuid.go @@ -2,8 +2,8 @@ package uuid import ( "database/sql/driver" - - errors "golang.org/x/xerrors" + "errors" + "fmt" "github.com/gofrs/uuid" "github.com/jackc/pgtype" @@ -37,7 +37,7 @@ func (dst *UUID) Set(src interface{}) error { *dst = UUID{UUID: uuid.UUID(value), Status: pgtype.Present} case []byte: if len(value) != 16 { - return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) } *dst = UUID{Status: pgtype.Present} copy(dst.UUID[:], value) @@ -51,7 +51,7 @@ func (dst *UUID) Set(src interface{}) error { // If all else fails see if pgtype.UUID can handle it. If so, translate through that. pgUUID := &pgtype.UUID{} if err := pgUUID.Set(value); err != nil { - return errors.Errorf("cannot convert %v to UUID", value) + return fmt.Errorf("cannot convert %v to UUID", value) } *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Status: pgUUID.Status} @@ -92,13 +92,13 @@ func (src *UUID) AssignTo(dst interface{}) error { if nextDst, retry := pgtype.GetAssignToDstType(v); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case pgtype.Null: return pgtype.NullAssignTo(dst) } - return errors.Errorf("cannot assign %v into %T", src, dst) + return fmt.Errorf("cannot assign %v into %T", src, dst) } func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { @@ -123,7 +123,7 @@ func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { } if len(src) != 16 { - return errors.Errorf("invalid length for UUID: %v", len(src)) + return fmt.Errorf("invalid length for UUID: %v", len(src)) } *dst = UUID{Status: pgtype.Present} @@ -167,7 +167,7 @@ func (dst *UUID) Scan(src interface{}) error { return dst.DecodeText(nil, src) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index 148589a4..e8694111 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -2,10 +2,10 @@ package numeric import ( "database/sql/driver" + "errors" + "fmt" "strconv" - errors "golang.org/x/xerrors" - "github.com/jackc/pgtype" "github.com/shopspring/decimal" ) @@ -78,17 +78,17 @@ func (dst *Numeric) Set(src interface{}) error { // If all else fails see if pgtype.Numeric can handle it. If so, translate through that. num := &pgtype.Numeric{} if err := num.Set(value); err != nil { - return errors.Errorf("cannot convert %v to Numeric", value) + return fmt.Errorf("cannot convert %v to Numeric", value) } buf, err := num.EncodeText(nil, nil) if err != nil { - return errors.Errorf("cannot convert %v to Numeric", value) + return fmt.Errorf("cannot convert %v to Numeric", value) } dec, err := decimal.NewFromString(string(buf)) if err != nil { - return errors.Errorf("cannot convert %v to Numeric", value) + return fmt.Errorf("cannot convert %v to Numeric", value) } *dst = Numeric{Decimal: dec, Status: pgtype.Present} } @@ -121,99 +121,99 @@ func (src *Numeric) AssignTo(dst interface{}) error { *v = f case *int: if src.Decimal.Exponent() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize) if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } *v = int(n) case *int8: if src.Decimal.Exponent() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, 8) if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } *v = int8(n) case *int16: if src.Decimal.Exponent() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, 16) if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } *v = int16(n) case *int32: if src.Decimal.Exponent() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, 32) if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } *v = int32(n) case *int64: if src.Decimal.Exponent() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, 64) if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } *v = int64(n) case *uint: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize) if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } *v = uint(n) case *uint8: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, 8) if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } *v = uint8(n) case *uint16: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, 16) if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } *v = uint16(n) case *uint32: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, 32) if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } *v = uint32(n) case *uint64: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, 64) if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) + return fmt.Errorf("cannot convert %v to %T", dst, *v) } *v = uint64(n) default: if nextDst, retry := pgtype.GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case pgtype.Null: return pgtype.NullAssignTo(dst) @@ -300,7 +300,7 @@ func (dst *Numeric) Scan(src interface{}) error { return dst.DecodeText(nil, src) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/float4.go b/float4.go index 5faad54d..89b9e8fa 100644 --- a/float4.go +++ b/float4.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "math" "strconv" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Float4 struct { @@ -46,42 +46,42 @@ func (dst *Float4) Set(src interface{}) error { if int32(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return errors.Errorf("%v cannot be exactly represented as float32", value) + return fmt.Errorf("%v cannot be exactly represented as float32", value) } case uint32: f32 := float32(value) if uint32(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return errors.Errorf("%v cannot be exactly represented as float32", value) + return fmt.Errorf("%v cannot be exactly represented as float32", value) } case int64: f32 := float32(value) if int64(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return errors.Errorf("%v cannot be exactly represented as float32", value) + return fmt.Errorf("%v cannot be exactly represented as float32", value) } case uint64: f32 := float32(value) if uint64(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return errors.Errorf("%v cannot be exactly represented as float32", value) + return fmt.Errorf("%v cannot be exactly represented as float32", value) } case int: f32 := float32(value) if int(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return errors.Errorf("%v cannot be exactly represented as float32", value) + return fmt.Errorf("%v cannot be exactly represented as float32", value) } case uint: f32 := float32(value) if uint(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return errors.Errorf("%v cannot be exactly represented as float32", value) + return fmt.Errorf("%v cannot be exactly represented as float32", value) } case string: num, err := strconv.ParseFloat(value, 32) @@ -171,7 +171,7 @@ func (dst *Float4) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Float8", value) + return fmt.Errorf("cannot convert %v to Float8", value) } return nil @@ -214,7 +214,7 @@ func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 4 { - return errors.Errorf("invalid length for float4: %v", len(src)) + return fmt.Errorf("invalid length for float4: %v", len(src)) } n := int32(binary.BigEndian.Uint32(src)) @@ -266,7 +266,7 @@ func (dst *Float4) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/float4_array.go b/float4_array.go index 91b3b0e2..41f2ec8f 100644 --- a/float4_array.go +++ b/float4_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Float4Array struct { @@ -96,7 +96,7 @@ func (dst *Float4Array) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for Float4Array", src) + return fmt.Errorf("cannot find dimensions of %v for Float4Array", src) } if elementsLength == 0 { *dst = Float4Array{Status: Present} @@ -106,7 +106,7 @@ func (dst *Float4Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Float4Array", src) + return fmt.Errorf("cannot convert %v to Float4Array", src) } *dst = Float4Array{ @@ -137,7 +137,7 @@ func (dst *Float4Array) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to Float4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to Float4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -155,7 +155,7 @@ func (dst *Float4Array) setRecursive(value reflect.Value, index, dimension int) valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -168,10 +168,10 @@ func (dst *Float4Array) setRecursive(value reflect.Value, index, dimension int) return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to Float4Array") + return 0, fmt.Errorf("cannot convert all values to Float4Array") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in Float4Array", err) + return 0, fmt.Errorf("%v in Float4Array", err) } index++ @@ -233,7 +233,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -248,7 +248,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -256,7 +256,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -272,7 +272,7 @@ func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -290,14 +290,14 @@ func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from Float4Array") + return 0, fmt.Errorf("cannot assign all values from Float4Array") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from Float4Array") + return 0, fmt.Errorf("cannot assign all values from Float4Array") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -456,7 +456,7 @@ func (src Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("float4"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "float4") + return nil, fmt.Errorf("unable to find oid for type name %v", "float4") } for i := range src.Elements { @@ -500,7 +500,7 @@ func (dst *Float4Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/float8.go b/float8.go index d7412301..4d9e7116 100644 --- a/float8.go +++ b/float8.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "math" "strconv" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Float8 struct { @@ -50,28 +50,28 @@ func (dst *Float8) Set(src interface{}) error { if int64(f64) == value { *dst = Float8{Float: f64, Status: Present} } else { - return errors.Errorf("%v cannot be exactly represented as float64", value) + return fmt.Errorf("%v cannot be exactly represented as float64", value) } case uint64: f64 := float64(value) if uint64(f64) == value { *dst = Float8{Float: f64, Status: Present} } else { - return errors.Errorf("%v cannot be exactly represented as float64", value) + return fmt.Errorf("%v cannot be exactly represented as float64", value) } case int: f64 := float64(value) if int(f64) == value { *dst = Float8{Float: f64, Status: Present} } else { - return errors.Errorf("%v cannot be exactly represented as float64", value) + return fmt.Errorf("%v cannot be exactly represented as float64", value) } case uint: f64 := float64(value) if uint(f64) == value { *dst = Float8{Float: f64, Status: Present} } else { - return errors.Errorf("%v cannot be exactly represented as float64", value) + return fmt.Errorf("%v cannot be exactly represented as float64", value) } case string: num, err := strconv.ParseFloat(value, 64) @@ -161,7 +161,7 @@ func (dst *Float8) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Float8", value) + return fmt.Errorf("cannot convert %v to Float8", value) } return nil @@ -204,7 +204,7 @@ func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return errors.Errorf("invalid length for float4: %v", len(src)) + return fmt.Errorf("invalid length for float4: %v", len(src)) } n := int64(binary.BigEndian.Uint64(src)) @@ -256,7 +256,7 @@ func (dst *Float8) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/float8_array.go b/float8_array.go index 559ee292..836ee19d 100644 --- a/float8_array.go +++ b/float8_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Float8Array struct { @@ -96,7 +96,7 @@ func (dst *Float8Array) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for Float8Array", src) + return fmt.Errorf("cannot find dimensions of %v for Float8Array", src) } if elementsLength == 0 { *dst = Float8Array{Status: Present} @@ -106,7 +106,7 @@ func (dst *Float8Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Float8Array", src) + return fmt.Errorf("cannot convert %v to Float8Array", src) } *dst = Float8Array{ @@ -137,7 +137,7 @@ func (dst *Float8Array) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to Float8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to Float8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -155,7 +155,7 @@ func (dst *Float8Array) setRecursive(value reflect.Value, index, dimension int) valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -168,10 +168,10 @@ func (dst *Float8Array) setRecursive(value reflect.Value, index, dimension int) return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to Float8Array") + return 0, fmt.Errorf("cannot convert all values to Float8Array") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in Float8Array", err) + return 0, fmt.Errorf("%v in Float8Array", err) } index++ @@ -233,7 +233,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -248,7 +248,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -256,7 +256,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -272,7 +272,7 @@ func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -290,14 +290,14 @@ func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from Float8Array") + return 0, fmt.Errorf("cannot assign all values from Float8Array") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from Float8Array") + return 0, fmt.Errorf("cannot assign all values from Float8Array") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -456,7 +456,7 @@ func (src Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("float8"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "float8") + return nil, fmt.Errorf("unable to find oid for type name %v", "float8") } for i := range src.Elements { @@ -500,7 +500,7 @@ func (dst *Float8Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/go.mod b/go.mod index 990e79f3..f213388a 100644 --- a/go.mod +++ b/go.mod @@ -10,5 +10,4 @@ require ( github.com/lib/pq v1.3.0 github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc github.com/stretchr/testify v1.5.1 - golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 ) diff --git a/hstore.go b/hstore.go index ec510df7..18b413c6 100644 --- a/hstore.go +++ b/hstore.go @@ -4,12 +4,12 @@ import ( "bytes" "database/sql/driver" "encoding/binary" + "errors" + "fmt" "strings" "unicode" "unicode/utf8" - errors "golang.org/x/xerrors" - "github.com/jackc/pgio" ) @@ -41,7 +41,7 @@ func (dst *Hstore) Set(src interface{}) error { } *dst = Hstore{Map: m, Status: Present} default: - return errors.Errorf("cannot convert %v to Hstore", src) + return fmt.Errorf("cannot convert %v to Hstore", src) } return nil @@ -66,7 +66,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { *v = make(map[string]string, len(src.Map)) for k, val := range src.Map { if val.Status != Present { - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } (*v)[k] = val.String } @@ -75,13 +75,13 @@ func (src *Hstore) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { @@ -113,7 +113,7 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { rp := 0 if len(src[rp:]) < 4 { - return errors.Errorf("hstore incomplete %v", src) + return fmt.Errorf("hstore incomplete %v", src) } pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 @@ -122,19 +122,19 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { for i := 0; i < pairCount; i++ { if len(src[rp:]) < 4 { - return errors.Errorf("hstore incomplete %v", src) + return fmt.Errorf("hstore incomplete %v", src) } keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 if len(src[rp:]) < keyLen { - return errors.Errorf("hstore incomplete %v", src) + return fmt.Errorf("hstore incomplete %v", src) } key := string(src[rp : rp+keyLen]) rp += keyLen if len(src[rp:]) < 4 { - return errors.Errorf("hstore incomplete %v", src) + return fmt.Errorf("hstore incomplete %v", src) } valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 @@ -338,13 +338,13 @@ func parseHstore(s string) (k []string, v []Text, err error) { case r == 'N': state = hsNul default: - err = errors.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) + err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) } default: - err = errors.Errorf("Invalid character after '=', expecting '>'") + err = fmt.Errorf("Invalid character after '=', expecting '>'") } } else { - err = errors.Errorf("Invalid character '%c' after value, expecting '='", r) + err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) } case hsVal: switch r { @@ -381,7 +381,7 @@ func parseHstore(s string) (k []string, v []Text, err error) { values = append(values, Text{Status: Null}) state = hsNext } else { - err = errors.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) } case hsNext: if r == ',' { @@ -393,10 +393,10 @@ func parseHstore(s string) (k []string, v []Text, err error) { r, end = p.Consume() state = hsKey default: - err = errors.Errorf("Invalid character '%c' after ', ', expecting \"", r) + err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) } } else { - err = errors.Errorf("Invalid character '%c' after value, expecting ','", r) + err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) } } @@ -430,7 +430,7 @@ func (dst *Hstore) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/hstore_array.go b/hstore_array.go index a44ea629..47b4b3ff 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type HstoreArray struct { @@ -77,7 +77,7 @@ func (dst *HstoreArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for HstoreArray", src) + return fmt.Errorf("cannot find dimensions of %v for HstoreArray", src) } if elementsLength == 0 { *dst = HstoreArray{Status: Present} @@ -87,7 +87,7 @@ func (dst *HstoreArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to HstoreArray", src) + return fmt.Errorf("cannot convert %v to HstoreArray", src) } *dst = HstoreArray{ @@ -118,7 +118,7 @@ func (dst *HstoreArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to HstoreArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to HstoreArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -136,7 +136,7 @@ func (dst *HstoreArray) setRecursive(value reflect.Value, index, dimension int) valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -149,10 +149,10 @@ func (dst *HstoreArray) setRecursive(value reflect.Value, index, dimension int) return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to HstoreArray") + return 0, fmt.Errorf("cannot convert all values to HstoreArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in HstoreArray", err) + return 0, fmt.Errorf("%v in HstoreArray", err) } index++ @@ -205,7 +205,7 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -220,7 +220,7 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -228,7 +228,7 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -244,7 +244,7 @@ func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -262,14 +262,14 @@ func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from HstoreArray") + return 0, fmt.Errorf("cannot assign all values from HstoreArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from HstoreArray") + return 0, fmt.Errorf("cannot assign all values from HstoreArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -428,7 +428,7 @@ func (src HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("hstore"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "hstore") + return nil, fmt.Errorf("unable to find oid for type name %v", "hstore") } for i := range src.Elements { @@ -472,7 +472,7 @@ func (dst *HstoreArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/inet.go b/inet.go index b4498191..101b9ab4 100644 --- a/inet.go +++ b/inet.go @@ -2,9 +2,8 @@ package pgtype import ( "database/sql/driver" + "fmt" "net" - - errors "golang.org/x/xerrors" ) // Network address family is dependent on server socket.h value for AF_INET. @@ -73,7 +72,7 @@ func (dst *Inet) Set(src interface{}) error { if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Inet", value) + return fmt.Errorf("cannot convert %v to Inet", value) } return nil @@ -104,7 +103,7 @@ func (src *Inet) AssignTo(dst interface{}) error { return nil case *net.IP: if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } *v = make(net.IP, len(src.IPNet.IP)) copy(*v, src.IPNet.IP) @@ -113,13 +112,13 @@ func (src *Inet) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { @@ -157,7 +156,7 @@ func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 && len(src) != 20 { - return errors.Errorf("Received an invalid size for a inet: %d", len(src)) + return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) } // ignore family @@ -202,7 +201,7 @@ func (src Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case net.IPv6len: family = defaultAFInet6 default: - return nil, errors.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) + return nil, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) } buf = append(buf, family) @@ -234,7 +233,7 @@ func (dst *Inet) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/inet_array.go b/inet_array.go index 30adeabb..2460a1c4 100644 --- a/inet_array.go +++ b/inet_array.go @@ -5,11 +5,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "net" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type InetArray struct { @@ -116,7 +116,7 @@ func (dst *InetArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for InetArray", src) + return fmt.Errorf("cannot find dimensions of %v for InetArray", src) } if elementsLength == 0 { *dst = InetArray{Status: Present} @@ -126,7 +126,7 @@ func (dst *InetArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to InetArray", src) + return fmt.Errorf("cannot convert %v to InetArray", src) } *dst = InetArray{ @@ -157,7 +157,7 @@ func (dst *InetArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to InetArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to InetArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -175,7 +175,7 @@ func (dst *InetArray) setRecursive(value reflect.Value, index, dimension int) (i valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -188,10 +188,10 @@ func (dst *InetArray) setRecursive(value reflect.Value, index, dimension int) (i return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to InetArray") + return 0, fmt.Errorf("cannot convert all values to InetArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in InetArray", err) + return 0, fmt.Errorf("%v in InetArray", err) } index++ @@ -262,7 +262,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -277,7 +277,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -285,7 +285,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -301,7 +301,7 @@ func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension in if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -319,14 +319,14 @@ func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension in return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from InetArray") + return 0, fmt.Errorf("cannot assign all values from InetArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from InetArray") + return 0, fmt.Errorf("cannot assign all values from InetArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -485,7 +485,7 @@ func (src InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("inet"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "inet") + return nil, fmt.Errorf("unable to find oid for type name %v", "inet") } for i := range src.Elements { @@ -529,7 +529,7 @@ func (dst *InetArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int2.go b/int2.go index b7517881..3eb5aeb5 100644 --- a/int2.go +++ b/int2.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "math" "strconv" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Int2 struct { @@ -37,46 +37,46 @@ func (dst *Int2) Set(src interface{}) error { *dst = Int2{Int: int16(value), Status: Present} case uint16: if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) + return fmt.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case int32: if value < math.MinInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) + return fmt.Errorf("%d is greater than maximum value for Int2", value) } if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) + return fmt.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case uint32: if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) + return fmt.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case int64: if value < math.MinInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) + return fmt.Errorf("%d is greater than maximum value for Int2", value) } if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) + return fmt.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case uint64: if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) + return fmt.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case int: if value < math.MinInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) + return fmt.Errorf("%d is greater than maximum value for Int2", value) } if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) + return fmt.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case uint: if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) + return fmt.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case string: @@ -87,12 +87,12 @@ func (dst *Int2) Set(src interface{}) error { *dst = Int2{Int: int16(num), Status: Present} case float32: if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) + return fmt.Errorf("%f is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case float64: if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) + return fmt.Errorf("%f is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case *int8: @@ -177,7 +177,7 @@ func (dst *Int2) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int2", value) + return fmt.Errorf("cannot convert %v to Int2", value) } return nil @@ -220,7 +220,7 @@ func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 2 { - return errors.Errorf("invalid length for int2: %v", len(src)) + return fmt.Errorf("invalid length for int2: %v", len(src)) } n := int16(binary.BigEndian.Uint16(src)) @@ -260,10 +260,10 @@ func (dst *Int2) Scan(src interface{}) error { switch src := src.(type) { case int64: if src < math.MinInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", src) + return fmt.Errorf("%d is greater than maximum value for Int2", src) } if src > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", src) + return fmt.Errorf("%d is greater than maximum value for Int2", src) } *dst = Int2{Int: int16(src), Status: Present} return nil @@ -275,7 +275,7 @@ func (dst *Int2) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int2_array.go b/int2_array.go index f4bd64cc..a5133845 100644 --- a/int2_array.go +++ b/int2_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Int2Array struct { @@ -362,7 +362,7 @@ func (dst *Int2Array) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for Int2Array", src) + return fmt.Errorf("cannot find dimensions of %v for Int2Array", src) } if elementsLength == 0 { *dst = Int2Array{Status: Present} @@ -372,7 +372,7 @@ func (dst *Int2Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int2Array", src) + return fmt.Errorf("cannot convert %v to Int2Array", src) } *dst = Int2Array{ @@ -403,7 +403,7 @@ func (dst *Int2Array) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to Int2Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to Int2Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -421,7 +421,7 @@ func (dst *Int2Array) setRecursive(value reflect.Value, index, dimension int) (i valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -434,10 +434,10 @@ func (dst *Int2Array) setRecursive(value reflect.Value, index, dimension int) (i return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to Int2Array") + return 0, fmt.Errorf("cannot convert all values to Int2Array") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in Int2Array", err) + return 0, fmt.Errorf("%v in Int2Array", err) } index++ @@ -625,7 +625,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -640,7 +640,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -648,7 +648,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -664,7 +664,7 @@ func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension in if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -682,14 +682,14 @@ func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension in return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from Int2Array") + return 0, fmt.Errorf("cannot assign all values from Int2Array") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from Int2Array") + return 0, fmt.Errorf("cannot assign all values from Int2Array") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -848,7 +848,7 @@ func (src Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("int2"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "int2") + return nil, fmt.Errorf("unable to find oid for type name %v", "int2") } for i := range src.Elements { @@ -892,7 +892,7 @@ func (dst *Int2Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int4.go b/int4.go index 66652bbe..22b48e5e 100644 --- a/int4.go +++ b/int4.go @@ -4,11 +4,11 @@ import ( "database/sql/driver" "encoding/binary" "encoding/json" + "fmt" "math" "strconv" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Int4 struct { @@ -42,33 +42,33 @@ func (dst *Int4) Set(src interface{}) error { *dst = Int4{Int: int32(value), Status: Present} case uint32: if value > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) + return fmt.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case int64: if value < math.MinInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) + return fmt.Errorf("%d is greater than maximum value for Int4", value) } if value > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) + return fmt.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case uint64: if value > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) + return fmt.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case int: if value < math.MinInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) + return fmt.Errorf("%d is greater than maximum value for Int4", value) } if value > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) + return fmt.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case uint: if value > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) + return fmt.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case string: @@ -79,12 +79,12 @@ func (dst *Int4) Set(src interface{}) error { *dst = Int4{Int: int32(num), Status: Present} case float32: if value > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) + return fmt.Errorf("%f is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case float64: if value > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) + return fmt.Errorf("%f is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case *int8: @@ -169,7 +169,7 @@ func (dst *Int4) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int4", value) + return fmt.Errorf("cannot convert %v to Int4", value) } return nil @@ -212,7 +212,7 @@ func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 4 { - return errors.Errorf("invalid length for int4: %v", len(src)) + return fmt.Errorf("invalid length for int4: %v", len(src)) } n := int32(binary.BigEndian.Uint32(src)) @@ -252,10 +252,10 @@ func (dst *Int4) Scan(src interface{}) error { switch src := src.(type) { case int64: if src < math.MinInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", src) + return fmt.Errorf("%d is greater than maximum value for Int4", src) } if src > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", src) + return fmt.Errorf("%d is greater than maximum value for Int4", src) } *dst = Int4{Int: int32(src), Status: Present} return nil @@ -267,7 +267,7 @@ func (dst *Int4) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int4_array.go b/int4_array.go index 528310ff..de26236f 100644 --- a/int4_array.go +++ b/int4_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Int4Array struct { @@ -362,7 +362,7 @@ func (dst *Int4Array) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for Int4Array", src) + return fmt.Errorf("cannot find dimensions of %v for Int4Array", src) } if elementsLength == 0 { *dst = Int4Array{Status: Present} @@ -372,7 +372,7 @@ func (dst *Int4Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int4Array", src) + return fmt.Errorf("cannot convert %v to Int4Array", src) } *dst = Int4Array{ @@ -403,7 +403,7 @@ func (dst *Int4Array) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to Int4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to Int4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -421,7 +421,7 @@ func (dst *Int4Array) setRecursive(value reflect.Value, index, dimension int) (i valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -434,10 +434,10 @@ func (dst *Int4Array) setRecursive(value reflect.Value, index, dimension int) (i return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to Int4Array") + return 0, fmt.Errorf("cannot convert all values to Int4Array") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in Int4Array", err) + return 0, fmt.Errorf("%v in Int4Array", err) } index++ @@ -625,7 +625,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -640,7 +640,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -648,7 +648,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -664,7 +664,7 @@ func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension in if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -682,14 +682,14 @@ func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension in return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from Int4Array") + return 0, fmt.Errorf("cannot assign all values from Int4Array") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from Int4Array") + return 0, fmt.Errorf("cannot assign all values from Int4Array") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -848,7 +848,7 @@ func (src Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("int4"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "int4") + return nil, fmt.Errorf("unable to find oid for type name %v", "int4") } for i := range src.Elements { @@ -892,7 +892,7 @@ func (dst *Int4Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int4range.go b/int4range.go index 442f2501..c7f51fa6 100644 --- a/int4range.go +++ b/int4range.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" + "fmt" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Int4range struct { @@ -30,7 +30,7 @@ func (dst *Int4range) Set(src interface{}) error { case string: return dst.DecodeText(nil, []byte(value)) default: - return errors.Errorf("cannot convert %v to Int4range", src) + return fmt.Errorf("cannot convert %v to Int4range", src) } return nil @@ -48,7 +48,7 @@ func (dst Int4range) Get() interface{} { } func (src *Int4range) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Int4range) DecodeText(ci *ConnInfo, src []byte) error { @@ -137,7 +137,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -147,7 +147,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -158,7 +158,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -168,7 +168,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -192,7 +192,7 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -202,7 +202,7 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -218,7 +218,7 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -233,7 +233,7 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -258,7 +258,7 @@ func (dst *Int4range) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int8.go b/int8.go index f0114194..0e089979 100644 --- a/int8.go +++ b/int8.go @@ -4,11 +4,11 @@ import ( "database/sql/driver" "encoding/binary" "encoding/json" + "fmt" "math" "strconv" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Int8 struct { @@ -46,20 +46,20 @@ func (dst *Int8) Set(src interface{}) error { *dst = Int8{Int: int64(value), Status: Present} case uint64: if value > math.MaxInt64 { - return errors.Errorf("%d is greater than maximum value for Int8", value) + return fmt.Errorf("%d is greater than maximum value for Int8", value) } *dst = Int8{Int: int64(value), Status: Present} case int: if int64(value) < math.MinInt64 { - return errors.Errorf("%d is greater than maximum value for Int8", value) + return fmt.Errorf("%d is greater than maximum value for Int8", value) } if int64(value) > math.MaxInt64 { - return errors.Errorf("%d is greater than maximum value for Int8", value) + return fmt.Errorf("%d is greater than maximum value for Int8", value) } *dst = Int8{Int: int64(value), Status: Present} case uint: if uint64(value) > math.MaxInt64 { - return errors.Errorf("%d is greater than maximum value for Int8", value) + return fmt.Errorf("%d is greater than maximum value for Int8", value) } *dst = Int8{Int: int64(value), Status: Present} case string: @@ -70,12 +70,12 @@ func (dst *Int8) Set(src interface{}) error { *dst = Int8{Int: num, Status: Present} case float32: if value > math.MaxInt64 { - return errors.Errorf("%d is greater than maximum value for Int8", value) + return fmt.Errorf("%f is greater than maximum value for Int8", value) } *dst = Int8{Int: int64(value), Status: Present} case float64: if value > math.MaxInt64 { - return errors.Errorf("%d is greater than maximum value for Int8", value) + return fmt.Errorf("%f is greater than maximum value for Int8", value) } *dst = Int8{Int: int64(value), Status: Present} case *int8: @@ -160,7 +160,7 @@ func (dst *Int8) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int8", value) + return fmt.Errorf("cannot convert %v to Int8", value) } return nil @@ -203,7 +203,7 @@ func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return errors.Errorf("invalid length for int8: %v", len(src)) + return fmt.Errorf("invalid length for int8: %v", len(src)) } n := int64(binary.BigEndian.Uint64(src)) @@ -253,7 +253,7 @@ func (dst *Int8) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int8_array.go b/int8_array.go index b1e52a97..e405b326 100644 --- a/int8_array.go +++ b/int8_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Int8Array struct { @@ -362,7 +362,7 @@ func (dst *Int8Array) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for Int8Array", src) + return fmt.Errorf("cannot find dimensions of %v for Int8Array", src) } if elementsLength == 0 { *dst = Int8Array{Status: Present} @@ -372,7 +372,7 @@ func (dst *Int8Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int8Array", src) + return fmt.Errorf("cannot convert %v to Int8Array", src) } *dst = Int8Array{ @@ -403,7 +403,7 @@ func (dst *Int8Array) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to Int8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to Int8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -421,7 +421,7 @@ func (dst *Int8Array) setRecursive(value reflect.Value, index, dimension int) (i valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -434,10 +434,10 @@ func (dst *Int8Array) setRecursive(value reflect.Value, index, dimension int) (i return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to Int8Array") + return 0, fmt.Errorf("cannot convert all values to Int8Array") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in Int8Array", err) + return 0, fmt.Errorf("%v in Int8Array", err) } index++ @@ -625,7 +625,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -640,7 +640,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -648,7 +648,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -664,7 +664,7 @@ func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension in if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -682,14 +682,14 @@ func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension in return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from Int8Array") + return 0, fmt.Errorf("cannot assign all values from Int8Array") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from Int8Array") + return 0, fmt.Errorf("cannot assign all values from Int8Array") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -848,7 +848,7 @@ func (src Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("int8"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "int8") + return nil, fmt.Errorf("unable to find oid for type name %v", "int8") } for i := range src.Elements { @@ -892,7 +892,7 @@ func (dst *Int8Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/int8range.go b/int8range.go index 92fcb136..71369373 100644 --- a/int8range.go +++ b/int8range.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" + "fmt" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Int8range struct { @@ -30,7 +30,7 @@ func (dst *Int8range) Set(src interface{}) error { case string: return dst.DecodeText(nil, []byte(value)) default: - return errors.Errorf("cannot convert %v to Int8range", src) + return fmt.Errorf("cannot convert %v to Int8range", src) } return nil @@ -48,7 +48,7 @@ func (dst Int8range) Get() interface{} { } func (src *Int8range) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { @@ -137,7 +137,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -147,7 +147,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -158,7 +158,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -168,7 +168,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -192,7 +192,7 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -202,7 +202,7 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -218,7 +218,7 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -233,7 +233,7 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -258,7 +258,7 @@ func (dst *Int8range) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/interval.go b/interval.go index 309e880c..b01fbb7c 100644 --- a/interval.go +++ b/interval.go @@ -9,7 +9,6 @@ import ( "time" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) const ( @@ -47,7 +46,7 @@ func (dst *Interval) Set(src interface{}) error { if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Interval", value) + return fmt.Errorf("cannot convert %v to Interval", value) } return nil @@ -76,13 +75,13 @@ func (src *Interval) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { @@ -100,7 +99,7 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { for i := 0; i < len(parts)-1; i += 2 { scalar, err := strconv.ParseInt(parts[i], 10, 64) if err != nil { - return errors.Errorf("bad interval format") + return fmt.Errorf("bad interval format") } switch parts[i+1] { @@ -116,7 +115,7 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { if len(parts)%2 == 1 { timeParts := strings.SplitN(parts[len(parts)-1], ":", 3) if len(timeParts) != 3 { - return errors.Errorf("bad interval format") + return fmt.Errorf("bad interval format") } var negative bool @@ -127,26 +126,26 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { hours, err := strconv.ParseInt(timeParts[0], 10, 64) if err != nil { - return errors.Errorf("bad interval hour format: %s", timeParts[0]) + return fmt.Errorf("bad interval hour format: %s", timeParts[0]) } minutes, err := strconv.ParseInt(timeParts[1], 10, 64) if err != nil { - return errors.Errorf("bad interval minute format: %s", timeParts[1]) + return fmt.Errorf("bad interval minute format: %s", timeParts[1]) } secondParts := strings.SplitN(timeParts[2], ".", 2) seconds, err := strconv.ParseInt(secondParts[0], 10, 64) if err != nil { - return errors.Errorf("bad interval second format: %s", secondParts[0]) + return fmt.Errorf("bad interval second format: %s", secondParts[0]) } var uSeconds int64 if len(secondParts) == 2 { uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64) if err != nil { - return errors.Errorf("bad interval decimal format: %s", secondParts[1]) + return fmt.Errorf("bad interval decimal format: %s", secondParts[1]) } for i := 0; i < 6-len(secondParts[1]); i++ { @@ -175,7 +174,7 @@ func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 16 { - return errors.Errorf("Received an invalid size for a interval: %d", len(src)) + return fmt.Errorf("Received an invalid size for a interval: %d", len(src)) } microseconds := int64(binary.BigEndian.Uint64(src)) @@ -249,7 +248,7 @@ func (dst *Interval) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/json.go b/json.go index 922da50d..32bef5e7 100644 --- a/json.go +++ b/json.go @@ -3,8 +3,8 @@ package pgtype import ( "database/sql/driver" "encoding/json" - - errors "golang.org/x/xerrors" + "errors" + "fmt" ) type JSON struct { @@ -82,7 +82,7 @@ func (src *JSON) AssignTo(dst interface{}) error { if src.Status == Present { *v = string(src.Bytes) } else { - return errors.Errorf("cannot assign non-present status to %T", dst) + return fmt.Errorf("cannot assign non-present status to %T", dst) } case **string: if src.Status == Present { @@ -166,7 +166,7 @@ func (dst *JSON) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/jsonb.go b/jsonb.go index c129ac9b..c9dafc93 100644 --- a/jsonb.go +++ b/jsonb.go @@ -2,8 +2,7 @@ package pgtype import ( "database/sql/driver" - - errors "golang.org/x/xerrors" + "fmt" ) type JSONB JSON @@ -35,11 +34,11 @@ func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) == 0 { - return errors.Errorf("jsonb too short") + return fmt.Errorf("jsonb too short") } if src[0] != 1 { - return errors.Errorf("unknown jsonb version number %d", src[0]) + return fmt.Errorf("unknown jsonb version number %d", src[0]) } *dst = JSONB{Bytes: src[1:], Status: Present} diff --git a/jsonb_array.go b/jsonb_array.go index 5d658ed5..c4b7cd3d 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type JSONBArray struct { @@ -96,7 +96,7 @@ func (dst *JSONBArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for JSONBArray", src) + return fmt.Errorf("cannot find dimensions of %v for JSONBArray", src) } if elementsLength == 0 { *dst = JSONBArray{Status: Present} @@ -106,7 +106,7 @@ func (dst *JSONBArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to JSONBArray", src) + return fmt.Errorf("cannot convert %v to JSONBArray", src) } *dst = JSONBArray{ @@ -137,7 +137,7 @@ func (dst *JSONBArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to JSONBArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to JSONBArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -155,7 +155,7 @@ func (dst *JSONBArray) setRecursive(value reflect.Value, index, dimension int) ( valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -168,10 +168,10 @@ func (dst *JSONBArray) setRecursive(value reflect.Value, index, dimension int) ( return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to JSONBArray") + return 0, fmt.Errorf("cannot convert all values to JSONBArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in JSONBArray", err) + return 0, fmt.Errorf("%v in JSONBArray", err) } index++ @@ -233,7 +233,7 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -248,7 +248,7 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -256,7 +256,7 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -272,7 +272,7 @@ func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension i if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -290,14 +290,14 @@ func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension i return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from JSONBArray") + return 0, fmt.Errorf("cannot assign all values from JSONBArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from JSONBArray") + return 0, fmt.Errorf("cannot assign all values from JSONBArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -456,7 +456,7 @@ func (src JSONBArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("jsonb"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "jsonb") + return nil, fmt.Errorf("unable to find oid for type name %v", "jsonb") } for i := range src.Elements { @@ -500,7 +500,7 @@ func (dst *JSONBArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/line.go b/line.go index 737f5d86..3564b174 100644 --- a/line.go +++ b/line.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Line struct { @@ -18,7 +17,7 @@ type Line struct { } func (dst *Line) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Line", src) + return fmt.Errorf("cannot convert %v to Line", src) } func (dst Line) Get() interface{} { @@ -33,7 +32,7 @@ func (dst Line) Get() interface{} { } func (src *Line) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { @@ -43,12 +42,12 @@ func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 7 { - return errors.Errorf("invalid length for Line: %v", len(src)) + return fmt.Errorf("invalid length for Line: %v", len(src)) } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 3) if len(parts) < 3 { - return errors.Errorf("invalid format for line") + return fmt.Errorf("invalid format for line") } a, err := strconv.ParseFloat(parts[0], 64) @@ -77,7 +76,7 @@ func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 24 { - return errors.Errorf("invalid length for Line: %v", len(src)) + return fmt.Errorf("invalid length for Line: %v", len(src)) } a := binary.BigEndian.Uint64(src) @@ -140,7 +139,7 @@ func (dst *Line) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/lseg.go b/lseg.go index a16dcea3..5c4babb6 100644 --- a/lseg.go +++ b/lseg.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Lseg struct { @@ -18,7 +17,7 @@ type Lseg struct { } func (dst *Lseg) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Lseg", src) + return fmt.Errorf("cannot convert %v to Lseg", src) } func (dst Lseg) Get() interface{} { @@ -33,7 +32,7 @@ func (dst Lseg) Get() interface{} { } func (src *Lseg) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { @@ -43,7 +42,7 @@ func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 11 { - return errors.Errorf("invalid length for Lseg: %v", len(src)) + return fmt.Errorf("invalid length for Lseg: %v", len(src)) } str := string(src[2:]) @@ -90,7 +89,7 @@ func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 32 { - return errors.Errorf("invalid length for Lseg: %v", len(src)) + return fmt.Errorf("invalid length for Lseg: %v", len(src)) } x1 := binary.BigEndian.Uint64(src) @@ -157,7 +156,7 @@ func (dst *Lseg) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/macaddr.go b/macaddr.go index 6cc14114..1d3cfe7b 100644 --- a/macaddr.go +++ b/macaddr.go @@ -2,9 +2,8 @@ package pgtype import ( "database/sql/driver" + "fmt" "net" - - errors "golang.org/x/xerrors" ) type Macaddr struct { @@ -52,7 +51,7 @@ func (dst *Macaddr) Set(src interface{}) error { if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Macaddr", value) + return fmt.Errorf("cannot convert %v to Macaddr", value) } return nil @@ -84,13 +83,13 @@ func (src *Macaddr) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { @@ -115,7 +114,7 @@ func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 6 { - return errors.Errorf("Received an invalid size for a macaddr: %d", len(src)) + return fmt.Errorf("Received an invalid size for a macaddr: %d", len(src)) } addr := make(net.HardwareAddr, 6) @@ -165,7 +164,7 @@ func (dst *Macaddr) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/macaddr_array.go b/macaddr_array.go index 0ac2618e..bdb1f203 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -5,11 +5,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "net" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type MacaddrArray struct { @@ -97,7 +97,7 @@ func (dst *MacaddrArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for MacaddrArray", src) + return fmt.Errorf("cannot find dimensions of %v for MacaddrArray", src) } if elementsLength == 0 { *dst = MacaddrArray{Status: Present} @@ -107,7 +107,7 @@ func (dst *MacaddrArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to MacaddrArray", src) + return fmt.Errorf("cannot convert %v to MacaddrArray", src) } *dst = MacaddrArray{ @@ -138,7 +138,7 @@ func (dst *MacaddrArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to MacaddrArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to MacaddrArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -156,7 +156,7 @@ func (dst *MacaddrArray) setRecursive(value reflect.Value, index, dimension int) valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -169,10 +169,10 @@ func (dst *MacaddrArray) setRecursive(value reflect.Value, index, dimension int) return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to MacaddrArray") + return 0, fmt.Errorf("cannot convert all values to MacaddrArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in MacaddrArray", err) + return 0, fmt.Errorf("%v in MacaddrArray", err) } index++ @@ -234,7 +234,7 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -249,7 +249,7 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -257,7 +257,7 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -273,7 +273,7 @@ func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -291,14 +291,14 @@ func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from MacaddrArray") + return 0, fmt.Errorf("cannot assign all values from MacaddrArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from MacaddrArray") + return 0, fmt.Errorf("cannot assign all values from MacaddrArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -457,7 +457,7 @@ func (src MacaddrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("macaddr"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "macaddr") + return nil, fmt.Errorf("unable to find oid for type name %v", "macaddr") } for i := range src.Elements { @@ -501,7 +501,7 @@ func (dst *MacaddrArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/numeric.go b/numeric.go index 4d966d5e..a7efa704 100644 --- a/numeric.go +++ b/numeric.go @@ -3,13 +3,13 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "math" "math/big" "strconv" "strings" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) // PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 @@ -197,7 +197,7 @@ func (dst *Numeric) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Numeric", value) + return fmt.Errorf("cannot convert %v to Numeric", value) } return nil @@ -236,10 +236,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt) > 0 { - return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt) < 0 { - return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = int(normalizedInt.Int64()) case *int8: @@ -248,10 +248,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt8) > 0 { - return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt8) < 0 { - return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = int8(normalizedInt.Int64()) case *int16: @@ -260,10 +260,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt16) > 0 { - return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt16) < 0 { - return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = int16(normalizedInt.Int64()) case *int32: @@ -272,10 +272,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt32) > 0 { - return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt32) < 0 { - return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = int32(normalizedInt.Int64()) case *int64: @@ -284,10 +284,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt64) > 0 { - return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt64) < 0 { - return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = normalizedInt.Int64() case *uint: @@ -296,9 +296,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint) > 0 { - return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = uint(normalizedInt.Uint64()) case *uint8: @@ -307,9 +307,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint8) > 0 { - return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = uint8(normalizedInt.Uint64()) case *uint16: @@ -318,9 +318,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint16) > 0 { - return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = uint16(normalizedInt.Uint64()) case *uint32: @@ -329,9 +329,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint32) > 0 { - return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = uint32(normalizedInt.Uint64()) case *uint64: @@ -340,16 +340,16 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint64) > 0 { - return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = normalizedInt.Uint64() default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) @@ -377,7 +377,7 @@ func (dst *Numeric) toBigInt() (*big.Int, error) { remainder := &big.Int{} num.DivMod(num, div, remainder) if remainder.Cmp(big0) != 0 { - return nil, errors.Errorf("cannot convert %v to integer", dst) + return nil, fmt.Errorf("cannot convert %v to integer", dst) } return num, nil } @@ -435,7 +435,7 @@ func parseNumericString(str string) (n *big.Int, exp int32, err error) { accum := &big.Int{} if _, ok := accum.SetString(digits, 10); !ok { - return nil, 0, errors.Errorf("%s is not a number", str) + return nil, 0, fmt.Errorf("%s is not a number", str) } return accum, exp, nil @@ -448,7 +448,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) < 8 { - return errors.Errorf("numeric incomplete %v", src) + return fmt.Errorf("numeric incomplete %v", src) } rp := 0 @@ -472,7 +472,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src[rp:]) < int(ndigits)*2 { - return errors.Errorf("numeric incomplete %v", src) + return fmt.Errorf("numeric incomplete %v", src) } accum := &big.Int{} @@ -493,7 +493,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { case 4: mul = bigNBaseX4 default: - return errors.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) + return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) } accum.Mul(accum, mul) } @@ -695,7 +695,7 @@ func (dst *Numeric) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/numeric_array.go b/numeric_array.go index 1c2ae489..31899dec 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type NumericArray struct { @@ -210,7 +210,7 @@ func (dst *NumericArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for NumericArray", src) + return fmt.Errorf("cannot find dimensions of %v for NumericArray", src) } if elementsLength == 0 { *dst = NumericArray{Status: Present} @@ -220,7 +220,7 @@ func (dst *NumericArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to NumericArray", src) + return fmt.Errorf("cannot convert %v to NumericArray", src) } *dst = NumericArray{ @@ -251,7 +251,7 @@ func (dst *NumericArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to NumericArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to NumericArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -269,7 +269,7 @@ func (dst *NumericArray) setRecursive(value reflect.Value, index, dimension int) valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -282,10 +282,10 @@ func (dst *NumericArray) setRecursive(value reflect.Value, index, dimension int) return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to NumericArray") + return 0, fmt.Errorf("cannot convert all values to NumericArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in NumericArray", err) + return 0, fmt.Errorf("%v in NumericArray", err) } index++ @@ -401,7 +401,7 @@ func (src *NumericArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -416,7 +416,7 @@ func (src *NumericArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -424,7 +424,7 @@ func (src *NumericArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -440,7 +440,7 @@ func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -458,14 +458,14 @@ func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from NumericArray") + return 0, fmt.Errorf("cannot assign all values from NumericArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from NumericArray") + return 0, fmt.Errorf("cannot assign all values from NumericArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -624,7 +624,7 @@ func (src NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("numeric"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "numeric") + return nil, fmt.Errorf("unable to find oid for type name %v", "numeric") } for i := range src.Elements { @@ -668,7 +668,7 @@ func (dst *NumericArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/numrange.go b/numrange.go index 40467686..3d5951a2 100644 --- a/numrange.go +++ b/numrange.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" + "fmt" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Numrange struct { @@ -30,7 +30,7 @@ func (dst *Numrange) Set(src interface{}) error { case string: return dst.DecodeText(nil, []byte(value)) default: - return errors.Errorf("cannot convert %v to Numrange", src) + return fmt.Errorf("cannot convert %v to Numrange", src) } return nil @@ -48,7 +48,7 @@ func (dst Numrange) Get() interface{} { } func (src *Numrange) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { @@ -137,7 +137,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -147,7 +147,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -158,7 +158,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -168,7 +168,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -192,7 +192,7 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -202,7 +202,7 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -218,7 +218,7 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -233,7 +233,7 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -258,7 +258,7 @@ func (dst *Numrange) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/oid.go b/oid.go index 593a5261..31677e89 100644 --- a/oid.go +++ b/oid.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "strconv" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) // OID (Object Identifier Type) is, according to @@ -20,7 +20,7 @@ type OID uint32 func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - return errors.Errorf("cannot decode nil into OID") + return fmt.Errorf("cannot decode nil into OID") } n, err := strconv.ParseUint(string(src), 10, 32) @@ -34,11 +34,11 @@ func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { func (dst *OID) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - return errors.Errorf("cannot decode nil into OID") + return fmt.Errorf("cannot decode nil into OID") } if len(src) != 4 { - return errors.Errorf("invalid length: %v", len(src)) + return fmt.Errorf("invalid length: %v", len(src)) } n := binary.BigEndian.Uint32(src) @@ -57,7 +57,7 @@ func (src OID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *OID) Scan(src interface{}) error { if src == nil { - return errors.Errorf("cannot scan NULL into %T", src) + return fmt.Errorf("cannot scan NULL into %T", src) } switch src := src.(type) { @@ -72,7 +72,7 @@ func (dst *OID) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/path.go b/path.go index c5031330..9f89969e 100644 --- a/path.go +++ b/path.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Path struct { @@ -19,7 +18,7 @@ type Path struct { } func (dst *Path) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Path", src) + return fmt.Errorf("cannot convert %v to Path", src) } func (dst Path) Get() interface{} { @@ -34,7 +33,7 @@ func (dst Path) Get() interface{} { } func (src *Path) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { @@ -44,7 +43,7 @@ func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 7 { - return errors.Errorf("invalid length for Path: %v", len(src)) + return fmt.Errorf("invalid length for Path: %v", len(src)) } closed := src[0] == '(' @@ -87,7 +86,7 @@ func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) < 5 { - return errors.Errorf("invalid length for Path: %v", len(src)) + return fmt.Errorf("invalid length for Path: %v", len(src)) } closed := src[0] == 1 @@ -96,7 +95,7 @@ func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { rp := 5 if 5+pointCount*16 != len(src) { - return errors.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) + return fmt.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) } points := make([]Vec2, pointCount) @@ -187,7 +186,7 @@ func (dst *Path) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype.go b/pgtype.go index c5e537cd..f1d40146 100644 --- a/pgtype.go +++ b/pgtype.go @@ -3,13 +3,12 @@ package pgtype import ( "database/sql" "encoding/binary" + "errors" "fmt" "math" "net" "reflect" "time" - - errors "golang.org/x/xerrors" ) // PostgreSQL oids for common types @@ -625,11 +624,11 @@ type scanPlanBinaryInt16 struct{} func (scanPlanBinaryInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { - return errors.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan null into %T", dst) } if len(src) != 2 { - return errors.Errorf("invalid length for int2: %v", len(src)) + return fmt.Errorf("invalid length for int2: %v", len(src)) } if p, ok := (dst).(*int16); ok { @@ -645,11 +644,11 @@ type scanPlanBinaryInt32 struct{} func (scanPlanBinaryInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { - return errors.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan null into %T", dst) } if len(src) != 4 { - return errors.Errorf("invalid length for int4: %v", len(src)) + return fmt.Errorf("invalid length for int4: %v", len(src)) } if p, ok := (dst).(*int32); ok { @@ -665,11 +664,11 @@ type scanPlanBinaryInt64 struct{} func (scanPlanBinaryInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { - return errors.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan null into %T", dst) } if len(src) != 8 { - return errors.Errorf("invalid length for int8: %v", len(src)) + return fmt.Errorf("invalid length for int8: %v", len(src)) } if p, ok := (dst).(*int64); ok { @@ -685,11 +684,11 @@ type scanPlanBinaryFloat32 struct{} func (scanPlanBinaryFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { - return errors.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan null into %T", dst) } if len(src) != 4 { - return errors.Errorf("invalid length for int4: %v", len(src)) + return fmt.Errorf("invalid length for int4: %v", len(src)) } if p, ok := (dst).(*float32); ok { @@ -706,11 +705,11 @@ type scanPlanBinaryFloat64 struct{} func (scanPlanBinaryFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { - return errors.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan null into %T", dst) } if len(src) != 8 { - return errors.Errorf("invalid length for int8: %v", len(src)) + return fmt.Errorf("invalid length for int8: %v", len(src)) } if p, ok := (dst).(*float64); ok { @@ -739,7 +738,7 @@ type scanPlanString struct{} func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { - return errors.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan null into %T", dst) } if p, ok := (dst).(*string); ok { @@ -841,7 +840,7 @@ func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) switch dest := dest.(type) { case *string: if formatCode == BinaryFormatCode { - return errors.Errorf("unknown oid %d in binary format cannot be scanned into %T", oid, dest) + return fmt.Errorf("unknown oid %d in binary format cannot be scanned into %T", oid, dest) } *dest = string(buf) return nil @@ -852,7 +851,7 @@ func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) if nextDst, retry := GetAssignToDstType(dest); retry { return scanUnknownType(oid, formatCode, buf, nextDst) } - return errors.Errorf("unknown oid %d cannot be scanned into %T", oid, dest) + return fmt.Errorf("unknown oid %d cannot be scanned into %T", oid, dest) } } diff --git a/pgtype_test.go b/pgtype_test.go index 32ce0a99..f46ec12a 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -2,6 +2,7 @@ package pgtype_test import ( "bytes" + "errors" "net" "testing" @@ -11,7 +12,6 @@ import ( _ "github.com/lib/pq" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - errors "golang.org/x/xerrors" ) // Test for renamed types diff --git a/pguint32.go b/pguint32.go index a245d2c9..a0e88ca2 100644 --- a/pguint32.go +++ b/pguint32.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "math" "strconv" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) // pguint32 is the core type that is used to implement PostgreSQL types such as @@ -24,16 +24,16 @@ func (dst *pguint32) Set(src interface{}) error { switch value := src.(type) { case int64: if value < 0 { - return errors.Errorf("%d is less than minimum value for pguint32", value) + return fmt.Errorf("%d is less than minimum value for pguint32", value) } if value > math.MaxUint32 { - return errors.Errorf("%d is greater than maximum value for pguint32", value) + return fmt.Errorf("%d is greater than maximum value for pguint32", value) } *dst = pguint32{Uint: uint32(value), Status: Present} case uint32: *dst = pguint32{Uint: value, Status: Present} default: - return errors.Errorf("cannot convert %v to pguint32", value) + return fmt.Errorf("cannot convert %v to pguint32", value) } return nil @@ -58,7 +58,7 @@ func (src *pguint32) AssignTo(dst interface{}) error { if src.Status == Present { *v = src.Uint } else { - return errors.Errorf("cannot assign %v into %T", src, dst) + return fmt.Errorf("cannot assign %v into %T", src, dst) } case **uint32: if src.Status == Present { @@ -94,7 +94,7 @@ func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 4 { - return errors.Errorf("invalid length: %v", len(src)) + return fmt.Errorf("invalid length: %v", len(src)) } n := binary.BigEndian.Uint32(src) @@ -146,7 +146,7 @@ func (dst *pguint32) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/point.go b/point.go index 8e6bacf2..0c799106 100644 --- a/point.go +++ b/point.go @@ -10,7 +10,6 @@ import ( "strings" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Vec2 struct { @@ -28,7 +27,7 @@ func (dst *Point) Set(src interface{}) error { dst.Status = Null return nil } - err := errors.Errorf("cannot convert %v to Point", src) + err := fmt.Errorf("cannot convert %v to Point", src) var p *Point switch value := src.(type) { case string: @@ -51,14 +50,14 @@ func parsePoint(src []byte) (*Point, error) { } if len(src) < 5 { - return nil, errors.Errorf("invalid length for point: %v", len(src)) + return nil, fmt.Errorf("invalid length for point: %v", len(src)) } if src[0] == '"' && src[len(src)-1] == '"' { src = src[1 : len(src)-1] } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) if len(parts) < 2 { - return nil, errors.Errorf("invalid format for point") + return nil, fmt.Errorf("invalid format for point") } x, err := strconv.ParseFloat(parts[0], 64) @@ -86,7 +85,7 @@ func (dst Point) Get() interface{} { } func (src *Point) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { @@ -96,12 +95,12 @@ func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 5 { - return errors.Errorf("invalid length for point: %v", len(src)) + return fmt.Errorf("invalid length for point: %v", len(src)) } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) if len(parts) < 2 { - return errors.Errorf("invalid format for point") + return fmt.Errorf("invalid format for point") } x, err := strconv.ParseFloat(parts[0], 64) @@ -125,7 +124,7 @@ func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 16 { - return errors.Errorf("invalid length for point: %v", len(src)) + return fmt.Errorf("invalid length for point: %v", len(src)) } x := binary.BigEndian.Uint64(src) @@ -181,7 +180,7 @@ func (dst *Point) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/polygon.go b/polygon.go index 5124af7f..207cadc0 100644 --- a/polygon.go +++ b/polygon.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Polygon struct { @@ -28,7 +27,7 @@ func (dst *Polygon) Set(src interface{}) error { dst.Status = Null return nil } - err := errors.Errorf("cannot convert %v to Polygon", src) + err := fmt.Errorf("cannot convert %v to Polygon", src) var p *Polygon switch value := src.(type) { case string: @@ -61,7 +60,7 @@ func float64ToPolygon(src []float64) (*Polygon, error) { } if len(src)%2 != 0 { p.Status = Undefined - return p, errors.Errorf("invalid length for polygon: %v", len(src)) + return p, fmt.Errorf("invalid length for polygon: %v", len(src)) } p.Status = Present p.P = make([]Vec2, 0) @@ -83,7 +82,7 @@ func (dst Polygon) Get() interface{} { } func (src *Polygon) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { @@ -93,7 +92,7 @@ func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 7 { - return errors.Errorf("invalid length for Polygon: %v", len(src)) + return fmt.Errorf("invalid length for Polygon: %v", len(src)) } points := make([]Vec2, 0) @@ -135,14 +134,14 @@ func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) < 5 { - return errors.Errorf("invalid length for Polygon: %v", len(src)) + return fmt.Errorf("invalid length for Polygon: %v", len(src)) } pointCount := int(binary.BigEndian.Uint32(src)) rp := 4 if 4+pointCount*16 != len(src) { - return errors.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) + return fmt.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) } points := make([]Vec2, pointCount) @@ -218,7 +217,7 @@ func (dst *Polygon) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/qchar.go b/qchar.go index 93964058..574f6066 100644 --- a/qchar.go +++ b/qchar.go @@ -1,10 +1,9 @@ package pgtype import ( + "fmt" "math" "strconv" - - errors "golang.org/x/xerrors" ) // QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C @@ -41,59 +40,59 @@ func (dst *QChar) Set(src interface{}) error { *dst = QChar{Int: value, Status: Present} case uint8: if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) + return fmt.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case int16: if value < math.MinInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) + return fmt.Errorf("%d is greater than maximum value for QChar", value) } if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) + return fmt.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case uint16: if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) + return fmt.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case int32: if value < math.MinInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) + return fmt.Errorf("%d is greater than maximum value for QChar", value) } if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) + return fmt.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case uint32: if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) + return fmt.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case int64: if value < math.MinInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) + return fmt.Errorf("%d is greater than maximum value for QChar", value) } if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) + return fmt.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case uint64: if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) + return fmt.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case int: if value < math.MinInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) + return fmt.Errorf("%d is greater than maximum value for QChar", value) } if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) + return fmt.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case uint: if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) + return fmt.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case string: @@ -106,7 +105,7 @@ func (dst *QChar) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to QChar", value) + return fmt.Errorf("cannot convert %v to QChar", value) } return nil @@ -134,7 +133,7 @@ func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 1 { - return errors.Errorf(`invalid length for "char": %v`, len(src)) + return fmt.Errorf(`invalid length for "char": %v`, len(src)) } *dst = QChar{Int: int8(src[0]), Status: Present} diff --git a/range.go b/range.go index 35b80ced..e999f6a9 100644 --- a/range.go +++ b/range.go @@ -3,8 +3,7 @@ package pgtype import ( "bytes" "encoding/binary" - - errors "golang.org/x/xerrors" + "fmt" ) type BoundType byte @@ -41,7 +40,7 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { r, _, err := buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid lower bound: %v", err) + return nil, fmt.Errorf("invalid lower bound: %v", err) } switch r { case '(': @@ -49,12 +48,12 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { case '[': utr.LowerType = Inclusive default: - return nil, errors.Errorf("missing lower bound, instead got: %v", string(r)) + return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) } r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid lower value: %v", err) + return nil, fmt.Errorf("invalid lower value: %v", err) } buf.UnreadRune() @@ -63,21 +62,21 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { } else { utr.Lower, err = rangeParseValue(buf) if err != nil { - return nil, errors.Errorf("invalid lower value: %v", err) + return nil, fmt.Errorf("invalid lower value: %v", err) } } r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("missing range separator: %v", err) + return nil, fmt.Errorf("missing range separator: %v", err) } if r != ',' { - return nil, errors.Errorf("missing range separator: %v", r) + return nil, fmt.Errorf("missing range separator: %v", r) } r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid upper value: %v", err) + return nil, fmt.Errorf("invalid upper value: %v", err) } if r == ')' || r == ']' { @@ -86,12 +85,12 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { buf.UnreadRune() utr.Upper, err = rangeParseValue(buf) if err != nil { - return nil, errors.Errorf("invalid upper value: %v", err) + return nil, fmt.Errorf("invalid upper value: %v", err) } r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("missing upper bound: %v", err) + return nil, fmt.Errorf("missing upper bound: %v", err) } switch r { case ')': @@ -99,14 +98,14 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { case ']': utr.UpperType = Inclusive default: - return nil, errors.Errorf("missing upper bound, instead got: %v", string(r)) + return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) } } skipWhitespace(buf) if buf.Len() > 0 { - return nil, errors.Errorf("unexpected trailing data: %v", buf.String()) + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) } return utr, nil @@ -202,7 +201,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { ubr := &UntypedBinaryRange{} if len(src) == 0 { - return nil, errors.Errorf("range too short: %v", len(src)) + return nil, fmt.Errorf("range too short: %v", len(src)) } rangeType := src[0] @@ -210,7 +209,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { if rangeType&emptyMask > 0 { if len(src[rp:]) > 0 { - return nil, errors.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) + return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) } ubr.LowerType = Empty ubr.UpperType = Empty @@ -235,13 +234,13 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { if len(src[rp:]) > 0 { - return nil, errors.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) + return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) } return ubr, nil } if len(src[rp:]) < 4 { - return nil, errors.Errorf("too few bytes for size: %v", src[rp:]) + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) } valueLen := int(binary.BigEndian.Uint32(src[rp:])) rp += 4 @@ -254,14 +253,14 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { } else { ubr.Upper = val if len(src[rp:]) > 0 { - return nil, errors.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) } return ubr, nil } if ubr.UpperType != Unbounded { if len(src[rp:]) < 4 { - return nil, errors.Errorf("too few bytes for size: %v", src[rp:]) + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) } valueLen := int(binary.BigEndian.Uint32(src[rp:])) rp += 4 @@ -270,7 +269,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { } if len(src[rp:]) > 0 { - return nil, errors.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) } return ubr, nil diff --git a/record.go b/record.go index 7899a881..718c3570 100644 --- a/record.go +++ b/record.go @@ -1,9 +1,8 @@ package pgtype import ( + "fmt" "reflect" - - errors "golang.org/x/xerrors" ) // Record is the generic PostgreSQL record type such as is created with the @@ -33,7 +32,7 @@ func (dst *Record) Set(src interface{}) error { case []Value: *dst = Record{Fields: value, Status: Present} default: - return errors.Errorf("cannot convert %v to Record", src) + return fmt.Errorf("cannot convert %v to Record", src) } return nil @@ -68,13 +67,13 @@ func (src *Record) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDecoder, error) { @@ -83,11 +82,11 @@ func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDec if dt, ok := ci.DataTypeForOID(fieldOID); ok { binaryDecoder, _ = dt.Value.(BinaryDecoder) } else { - return nil, errors.Errorf("unknown oid while decoding record: %v", fieldOID) + return nil, fmt.Errorf("unknown oid while decoding record: %v", fieldOID) } if binaryDecoder == nil { - return nil, errors.Errorf("no binary decoder registered for: %v", fieldOID) + return nil, fmt.Errorf("no binary decoder registered for: %v", fieldOID) } // Duplicate struct to scan into diff --git a/text.go b/text.go index 4c9e4a21..6b01d1b4 100644 --- a/text.go +++ b/text.go @@ -3,8 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/json" - - errors "golang.org/x/xerrors" + "fmt" ) type Text struct { @@ -44,7 +43,7 @@ func (dst *Text) Set(src interface{}) error { if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Text", value) + return fmt.Errorf("cannot convert %v to Text", value) } return nil @@ -76,13 +75,13 @@ func (src *Text) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (Text) PreferredResultFormat() int16 { @@ -138,7 +137,7 @@ func (dst *Text) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/text_array.go b/text_array.go index afdc507b..2461966b 100644 --- a/text_array.go +++ b/text_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type TextArray struct { @@ -96,7 +96,7 @@ func (dst *TextArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for TextArray", src) + return fmt.Errorf("cannot find dimensions of %v for TextArray", src) } if elementsLength == 0 { *dst = TextArray{Status: Present} @@ -106,7 +106,7 @@ func (dst *TextArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TextArray", src) + return fmt.Errorf("cannot convert %v to TextArray", src) } *dst = TextArray{ @@ -137,7 +137,7 @@ func (dst *TextArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to TextArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to TextArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -155,7 +155,7 @@ func (dst *TextArray) setRecursive(value reflect.Value, index, dimension int) (i valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -168,10 +168,10 @@ func (dst *TextArray) setRecursive(value reflect.Value, index, dimension int) (i return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to TextArray") + return 0, fmt.Errorf("cannot convert all values to TextArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in TextArray", err) + return 0, fmt.Errorf("%v in TextArray", err) } index++ @@ -233,7 +233,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -248,7 +248,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -256,7 +256,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -272,7 +272,7 @@ func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension in if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -290,14 +290,14 @@ func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension in return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from TextArray") + return 0, fmt.Errorf("cannot assign all values from TextArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from TextArray") + return 0, fmt.Errorf("cannot assign all values from TextArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -456,7 +456,7 @@ func (src TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("text"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "text") + return nil, fmt.Errorf("unable to find oid for type name %v", "text") } for i := range src.Elements { @@ -500,7 +500,7 @@ func (dst *TextArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/tid.go b/tid.go index f7b80f94..4bb57f64 100644 --- a/tid.go +++ b/tid.go @@ -8,7 +8,6 @@ import ( "strings" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) // TID is PostgreSQL's Tuple Identifier type. @@ -29,7 +28,7 @@ type TID struct { } func (dst *TID) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to TID", src) + return fmt.Errorf("cannot convert %v to TID", src) } func (dst TID) Get() interface{} { @@ -53,11 +52,11 @@ func (src *TID) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } } - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { @@ -67,12 +66,12 @@ func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 5 { - return errors.Errorf("invalid length for tid: %v", len(src)) + return fmt.Errorf("invalid length for tid: %v", len(src)) } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) if len(parts) < 2 { - return errors.Errorf("invalid format for tid") + return fmt.Errorf("invalid format for tid") } blockNumber, err := strconv.ParseUint(parts[0], 10, 32) @@ -96,7 +95,7 @@ func (dst *TID) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 6 { - return errors.Errorf("invalid length for tid: %v", len(src)) + return fmt.Errorf("invalid length for tid: %v", len(src)) } *dst = TID{ @@ -148,7 +147,7 @@ func (dst *TID) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/time.go b/time.go index 237c4b5b..f7a28870 100644 --- a/time.go +++ b/time.go @@ -8,7 +8,6 @@ import ( "time" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) // Time represents the PostgreSQL time type. The PostgreSQL time is a time of day without time zone. @@ -52,7 +51,7 @@ func (dst *Time) Set(src interface{}) error { if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Time", value) + return fmt.Errorf("cannot convert %v to Time", value) } return nil @@ -77,7 +76,7 @@ func (src *Time) AssignTo(dst interface{}) error { // 24:00:00 is max allowed time in PostgreSQL, but time.Time will normalize that to 00:00:00 the next day. var maxRepresentableByTime int64 = 24*60*60*1000000 - 1 if src.Microseconds > maxRepresentableByTime { - return errors.Errorf("%d microseconds cannot be represented as time.Time", src.Microseconds) + return fmt.Errorf("%d microseconds cannot be represented as time.Time", src.Microseconds) } usec := src.Microseconds @@ -94,13 +93,13 @@ func (src *Time) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } // DecodeText decodes from src into dst. @@ -113,24 +112,24 @@ func (dst *Time) DecodeText(ci *ConnInfo, src []byte) error { s := string(src) if len(s) < 8 { - return errors.Errorf("cannot decode %v into Time", s) + return fmt.Errorf("cannot decode %v into Time", s) } hours, err := strconv.ParseInt(s[0:2], 10, 64) if err != nil { - return errors.Errorf("cannot decode %v into Time", s) + return fmt.Errorf("cannot decode %v into Time", s) } usec := hours * microsecondsPerHour minutes, err := strconv.ParseInt(s[3:5], 10, 64) if err != nil { - return errors.Errorf("cannot decode %v into Time", s) + return fmt.Errorf("cannot decode %v into Time", s) } usec += minutes * microsecondsPerMinute seconds, err := strconv.ParseInt(s[6:8], 10, 64) if err != nil { - return errors.Errorf("cannot decode %v into Time", s) + return fmt.Errorf("cannot decode %v into Time", s) } usec += seconds * microsecondsPerSecond @@ -138,7 +137,7 @@ func (dst *Time) DecodeText(ci *ConnInfo, src []byte) error { fraction := s[9:] n, err := strconv.ParseInt(fraction, 10, 64) if err != nil { - return errors.Errorf("cannot decode %v into Time", s) + return fmt.Errorf("cannot decode %v into Time", s) } for i := len(fraction); i < 6; i++ { @@ -161,7 +160,7 @@ func (dst *Time) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return errors.Errorf("invalid length for time: %v", len(src)) + return fmt.Errorf("invalid length for time: %v", len(src)) } usec := int64(binary.BigEndian.Uint64(src)) @@ -223,7 +222,7 @@ func (dst *Time) Scan(src interface{}) error { return dst.Set(src) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/timestamp.go b/timestamp.go index 0e127695..46644115 100644 --- a/timestamp.go +++ b/timestamp.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "time" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) const pgTimestampFormat = "2006-01-02 15:04:05.999999999" @@ -52,7 +52,7 @@ func (dst *Timestamp) Set(src interface{}) error { if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Timestamp", value) + return fmt.Errorf("cannot convert %v to Timestamp", value) } return nil @@ -78,7 +78,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: if src.InfinityModifier != None { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } *v = src.Time return nil @@ -86,13 +86,13 @@ func (src *Timestamp) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } // DecodeText decodes from src into dst. The decoded time is considered to @@ -130,7 +130,7 @@ func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return errors.Errorf("invalid length for timestamp: %v", len(src)) + return fmt.Errorf("invalid length for timestamp: %v", len(src)) } microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) @@ -159,7 +159,7 @@ func (src Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } if src.Time.Location() != time.UTC { - return nil, errors.Errorf("cannot encode non-UTC time into timestamp") + return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") } var s string @@ -186,7 +186,7 @@ func (src Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } if src.Time.Location() != time.UTC { - return nil, errors.Errorf("cannot encode non-UTC time into timestamp") + return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") } var microsecSinceY2K int64 @@ -222,7 +222,7 @@ func (dst *Timestamp) Scan(src interface{}) error { return nil } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/timestamp_array.go b/timestamp_array.go index 5256f185..e12481e3 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -5,11 +5,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "time" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type TimestampArray struct { @@ -97,7 +97,7 @@ func (dst *TimestampArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for TimestampArray", src) + return fmt.Errorf("cannot find dimensions of %v for TimestampArray", src) } if elementsLength == 0 { *dst = TimestampArray{Status: Present} @@ -107,7 +107,7 @@ func (dst *TimestampArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TimestampArray", src) + return fmt.Errorf("cannot convert %v to TimestampArray", src) } *dst = TimestampArray{ @@ -138,7 +138,7 @@ func (dst *TimestampArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to TimestampArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to TimestampArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -156,7 +156,7 @@ func (dst *TimestampArray) setRecursive(value reflect.Value, index, dimension in valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -169,10 +169,10 @@ func (dst *TimestampArray) setRecursive(value reflect.Value, index, dimension in return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to TimestampArray") + return 0, fmt.Errorf("cannot convert all values to TimestampArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in TimestampArray", err) + return 0, fmt.Errorf("%v in TimestampArray", err) } index++ @@ -234,7 +234,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -249,7 +249,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -257,7 +257,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -273,7 +273,7 @@ func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimensi if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -291,14 +291,14 @@ func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimensi return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from TimestampArray") + return 0, fmt.Errorf("cannot assign all values from TimestampArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from TimestampArray") + return 0, fmt.Errorf("cannot assign all values from TimestampArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -457,7 +457,7 @@ func (src TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) if dt, ok := ci.DataTypeForName("timestamp"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "timestamp") + return nil, fmt.Errorf("unable to find oid for type name %v", "timestamp") } for i := range src.Elements { @@ -501,7 +501,7 @@ func (dst *TimestampArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/timestamptz.go b/timestamptz.go index a79bd66e..e0743060 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -4,10 +4,10 @@ import ( "database/sql/driver" "encoding/binary" "encoding/json" + "fmt" "time" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" @@ -54,7 +54,7 @@ func (dst *Timestamptz) Set(src interface{}) error { if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Timestamptz", value) + return fmt.Errorf("cannot convert %v to Timestamptz", value) } return nil @@ -80,7 +80,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: if src.InfinityModifier != None { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } *v = src.Time return nil @@ -88,13 +88,13 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } - return errors.Errorf("unable to assign to %T", dst) + return fmt.Errorf("unable to assign to %T", dst) } case Null: return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { @@ -137,7 +137,7 @@ func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return errors.Errorf("invalid length for timestamptz: %v", len(src)) + return fmt.Errorf("invalid length for timestamptz: %v", len(src)) } microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) @@ -219,7 +219,7 @@ func (dst *Timestamptz) Scan(src interface{}) error { return nil } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/timestamptz_array.go b/timestamptz_array.go index 47408c02..a3b4b263 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -5,11 +5,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "time" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type TimestamptzArray struct { @@ -97,7 +97,7 @@ func (dst *TimestamptzArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for TimestamptzArray", src) + return fmt.Errorf("cannot find dimensions of %v for TimestamptzArray", src) } if elementsLength == 0 { *dst = TimestamptzArray{Status: Present} @@ -107,7 +107,7 @@ func (dst *TimestamptzArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TimestamptzArray", src) + return fmt.Errorf("cannot convert %v to TimestamptzArray", src) } *dst = TimestamptzArray{ @@ -138,7 +138,7 @@ func (dst *TimestamptzArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to TimestamptzArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to TimestamptzArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -156,7 +156,7 @@ func (dst *TimestamptzArray) setRecursive(value reflect.Value, index, dimension valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -169,10 +169,10 @@ func (dst *TimestamptzArray) setRecursive(value reflect.Value, index, dimension return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to TimestamptzArray") + return 0, fmt.Errorf("cannot convert all values to TimestamptzArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in TimestamptzArray", err) + return 0, fmt.Errorf("%v in TimestamptzArray", err) } index++ @@ -234,7 +234,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -249,7 +249,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -257,7 +257,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -273,7 +273,7 @@ func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimen if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -291,14 +291,14 @@ func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimen return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from TimestamptzArray") + return 0, fmt.Errorf("cannot assign all values from TimestamptzArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from TimestamptzArray") + return 0, fmt.Errorf("cannot assign all values from TimestamptzArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -457,7 +457,7 @@ func (src TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, erro if dt, ok := ci.DataTypeForName("timestamptz"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "timestamptz") + return nil, fmt.Errorf("unable to find oid for type name %v", "timestamptz") } for i := range src.Elements { @@ -501,7 +501,7 @@ func (dst *TimestamptzArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/tsrange.go b/tsrange.go index 6ca12aed..19ecf446 100644 --- a/tsrange.go +++ b/tsrange.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" + "fmt" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Tsrange struct { @@ -30,7 +30,7 @@ func (dst *Tsrange) Set(src interface{}) error { case string: return dst.DecodeText(nil, []byte(value)) default: - return errors.Errorf("cannot convert %v to Tsrange", src) + return fmt.Errorf("cannot convert %v to Tsrange", src) } return nil @@ -48,7 +48,7 @@ func (dst Tsrange) Get() interface{} { } func (src *Tsrange) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Tsrange) DecodeText(ci *ConnInfo, src []byte) error { @@ -137,7 +137,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -147,7 +147,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -158,7 +158,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -168,7 +168,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -192,7 +192,7 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -202,7 +202,7 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -218,7 +218,7 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -233,7 +233,7 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -258,7 +258,7 @@ func (dst *Tsrange) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/tsrange_array.go b/tsrange_array.go index 15053f75..c64048eb 100644 --- a/tsrange_array.go +++ b/tsrange_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type TsrangeArray struct { @@ -58,7 +58,7 @@ func (dst *TsrangeArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for TsrangeArray", src) + return fmt.Errorf("cannot find dimensions of %v for TsrangeArray", src) } if elementsLength == 0 { *dst = TsrangeArray{Status: Present} @@ -68,7 +68,7 @@ func (dst *TsrangeArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TsrangeArray", src) + return fmt.Errorf("cannot convert %v to TsrangeArray", src) } *dst = TsrangeArray{ @@ -99,7 +99,7 @@ func (dst *TsrangeArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to TsrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to TsrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -117,7 +117,7 @@ func (dst *TsrangeArray) setRecursive(value reflect.Value, index, dimension int) valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -130,10 +130,10 @@ func (dst *TsrangeArray) setRecursive(value reflect.Value, index, dimension int) return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to TsrangeArray") + return 0, fmt.Errorf("cannot convert all values to TsrangeArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in TsrangeArray", err) + return 0, fmt.Errorf("%v in TsrangeArray", err) } index++ @@ -186,7 +186,7 @@ func (src *TsrangeArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -201,7 +201,7 @@ func (src *TsrangeArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -209,7 +209,7 @@ func (src *TsrangeArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *TsrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -225,7 +225,7 @@ func (src *TsrangeArray) assignToRecursive(value reflect.Value, index, dimension if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -243,14 +243,14 @@ func (src *TsrangeArray) assignToRecursive(value reflect.Value, index, dimension return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from TsrangeArray") + return 0, fmt.Errorf("cannot assign all values from TsrangeArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from TsrangeArray") + return 0, fmt.Errorf("cannot assign all values from TsrangeArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -409,7 +409,7 @@ func (src TsrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("tsrange"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "tsrange") + return nil, fmt.Errorf("unable to find oid for type name %v", "tsrange") } for i := range src.Elements { @@ -453,7 +453,7 @@ func (dst *TsrangeArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/tstzrange.go b/tstzrange.go index 1b05c3ea..25576308 100644 --- a/tstzrange.go +++ b/tstzrange.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" + "fmt" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Tstzrange struct { @@ -30,7 +30,7 @@ func (dst *Tstzrange) Set(src interface{}) error { case string: return dst.DecodeText(nil, []byte(value)) default: - return errors.Errorf("cannot convert %v to Tstzrange", src) + return fmt.Errorf("cannot convert %v to Tstzrange", src) } return nil @@ -48,7 +48,7 @@ func (dst Tstzrange) Get() interface{} { } func (src *Tstzrange) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Tstzrange) DecodeText(ci *ConnInfo, src []byte) error { @@ -137,7 +137,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -147,7 +147,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -158,7 +158,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -168,7 +168,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -192,7 +192,7 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -202,7 +202,7 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -218,7 +218,7 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -233,7 +233,7 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -258,7 +258,7 @@ func (dst *Tstzrange) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/tstzrange_array.go b/tstzrange_array.go index 6d9bfe3b..a216820a 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type TstzrangeArray struct { @@ -58,7 +58,7 @@ func (dst *TstzrangeArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for TstzrangeArray", src) + return fmt.Errorf("cannot find dimensions of %v for TstzrangeArray", src) } if elementsLength == 0 { *dst = TstzrangeArray{Status: Present} @@ -68,7 +68,7 @@ func (dst *TstzrangeArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TstzrangeArray", src) + return fmt.Errorf("cannot convert %v to TstzrangeArray", src) } *dst = TstzrangeArray{ @@ -99,7 +99,7 @@ func (dst *TstzrangeArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to TstzrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to TstzrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -117,7 +117,7 @@ func (dst *TstzrangeArray) setRecursive(value reflect.Value, index, dimension in valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -130,10 +130,10 @@ func (dst *TstzrangeArray) setRecursive(value reflect.Value, index, dimension in return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to TstzrangeArray") + return 0, fmt.Errorf("cannot convert all values to TstzrangeArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in TstzrangeArray", err) + return 0, fmt.Errorf("%v in TstzrangeArray", err) } index++ @@ -186,7 +186,7 @@ func (src *TstzrangeArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -201,7 +201,7 @@ func (src *TstzrangeArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -209,7 +209,7 @@ func (src *TstzrangeArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -225,7 +225,7 @@ func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimensi if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -243,14 +243,14 @@ func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimensi return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from TstzrangeArray") + return 0, fmt.Errorf("cannot assign all values from TstzrangeArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from TstzrangeArray") + return 0, fmt.Errorf("cannot assign all values from TstzrangeArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -409,7 +409,7 @@ func (src TstzrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) if dt, ok := ci.DataTypeForName("tstzrange"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "tstzrange") + return nil, fmt.Errorf("unable to find oid for type name %v", "tstzrange") } for i := range src.Elements { @@ -453,7 +453,7 @@ func (dst *TstzrangeArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/typed_array.go.erb b/typed_array.go.erb index 52f14592..5788626b 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -78,7 +78,7 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for <%= pgtype_array_type %>", src) + return fmt.Errorf("cannot find dimensions of %v for <%= pgtype_array_type %>", src) } if elementsLength == 0 { *dst = <%= pgtype_array_type %>{Status: Present} @@ -88,7 +88,7 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>", src) + return fmt.Errorf("cannot convert %v to <%= pgtype_array_type %>", src) } *dst = <%= pgtype_array_type %> { @@ -119,7 +119,7 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to <%= pgtype_array_type %>, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -137,7 +137,7 @@ func (dst *<%= pgtype_array_type %>) setRecursive(value reflect.Value, index, di valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -150,10 +150,10 @@ func (dst *<%= pgtype_array_type %>) setRecursive(value reflect.Value, index, di return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to <%= pgtype_array_type %>") + return 0, fmt.Errorf("cannot convert all values to <%= pgtype_array_type %>") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in <%= pgtype_array_type %>", err) + return 0, fmt.Errorf("%v in <%= pgtype_array_type %>", err) } index++ @@ -206,7 +206,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -221,7 +221,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -229,7 +229,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -245,7 +245,7 @@ func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, inde if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -263,14 +263,14 @@ func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, inde return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr(){ - return 0, errors.Errorf("cannot assign all values from <%= pgtype_array_type %>") + return 0, fmt.Errorf("cannot assign all values from <%= pgtype_array_type %>") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from <%= pgtype_array_type %>") + return 0, fmt.Errorf("cannot assign all values from <%= pgtype_array_type %>") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -432,7 +432,7 @@ func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") + return nil, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") } for i := range src.Elements { @@ -477,7 +477,7 @@ func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/typed_range.go.erb b/typed_range.go.erb index e21b6cda..5625587a 100644 --- a/typed_range.go.erb +++ b/typed_range.go.erb @@ -32,7 +32,7 @@ func (dst *<%= range_type %>) Set(src interface{}) error { case string: return dst.DecodeText(nil, []byte(value)) default: - return errors.Errorf("cannot convert %v to <%= range_type %>", src) + return fmt.Errorf("cannot convert %v to <%= range_type %>", src) } return nil @@ -50,7 +50,7 @@ func (dst <%= range_type %>) Get() interface{} { } func (src *<%= range_type %>) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { @@ -139,7 +139,7 @@ func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error case Empty: return append(buf, "empty"...), nil default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -149,7 +149,7 @@ func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -160,7 +160,7 @@ func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error if err != nil { return nil, err } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -170,7 +170,7 @@ func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error case Inclusive: buf = append(buf, ']') default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -194,7 +194,7 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err case Empty: return append(buf, emptyMask), nil default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -204,7 +204,7 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err rangeType |= upperUnboundedMask case Exclusive: default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -220,7 +220,7 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err return nil, err } if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -235,7 +235,7 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err return nil, err } if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -260,7 +260,7 @@ func (dst *<%= range_type %>) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/uuid.go b/uuid.go index b1681a78..fa0be07f 100644 --- a/uuid.go +++ b/uuid.go @@ -5,8 +5,6 @@ import ( "database/sql/driver" "encoding/hex" "fmt" - - errors "golang.org/x/xerrors" ) type UUID struct { @@ -33,7 +31,7 @@ func (dst *UUID) Set(src interface{}) error { case []byte: if value != nil { if len(value) != 16 { - return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) } *dst = UUID{Status: Present} copy(dst.Bytes[:], value) @@ -56,7 +54,7 @@ func (dst *UUID) Set(src interface{}) error { if originalSrc, ok := underlyingUUIDType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to UUID", value) + return fmt.Errorf("cannot convert %v to UUID", value) } return nil @@ -96,7 +94,7 @@ func (src *UUID) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot assign %v into %T", src, dst) + return fmt.Errorf("cannot assign %v into %T", src, dst) } // parseUUID converts a string UUID in standard form to a byte array. @@ -108,7 +106,7 @@ func parseUUID(src string) (dst [16]byte, err error) { // dashes already stripped, assume valid default: // assume invalid. - return dst, errors.Errorf("cannot parse UUID %v", src) + return dst, fmt.Errorf("cannot parse UUID %v", src) } buf, err := hex.DecodeString(src) @@ -132,7 +130,7 @@ func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) != 36 { - return errors.Errorf("invalid length for UUID: %v", len(src)) + return fmt.Errorf("invalid length for UUID: %v", len(src)) } buf, err := parseUUID(string(src)) @@ -151,7 +149,7 @@ func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 16 { - return errors.Errorf("invalid length for UUID: %v", len(src)) + return fmt.Errorf("invalid length for UUID: %v", len(src)) } *dst = UUID{Status: Present} @@ -197,7 +195,7 @@ func (dst *UUID) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. @@ -226,7 +224,7 @@ func (dst *UUID) UnmarshalJSON(src []byte) error { return dst.Set(nil) } if len(src) != 38 { - return errors.Errorf("invalid length for UUID: %v", len(src)) + return fmt.Errorf("invalid length for UUID: %v", len(src)) } return dst.Set(string(src[1 : len(src)-1])) } diff --git a/uuid_array.go b/uuid_array.go index c6970d52..00721ef9 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type UUIDArray struct { @@ -134,7 +134,7 @@ func (dst *UUIDArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for UUIDArray", src) + return fmt.Errorf("cannot find dimensions of %v for UUIDArray", src) } if elementsLength == 0 { *dst = UUIDArray{Status: Present} @@ -144,7 +144,7 @@ func (dst *UUIDArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to UUIDArray", src) + return fmt.Errorf("cannot convert %v to UUIDArray", src) } *dst = UUIDArray{ @@ -175,7 +175,7 @@ func (dst *UUIDArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to UUIDArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to UUIDArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -193,7 +193,7 @@ func (dst *UUIDArray) setRecursive(value reflect.Value, index, dimension int) (i valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -206,10 +206,10 @@ func (dst *UUIDArray) setRecursive(value reflect.Value, index, dimension int) (i return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to UUIDArray") + return 0, fmt.Errorf("cannot convert all values to UUIDArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in UUIDArray", err) + return 0, fmt.Errorf("%v in UUIDArray", err) } index++ @@ -289,7 +289,7 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -304,7 +304,7 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -312,7 +312,7 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -328,7 +328,7 @@ func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension in if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -346,14 +346,14 @@ func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension in return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from UUIDArray") + return 0, fmt.Errorf("cannot assign all values from UUIDArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from UUIDArray") + return 0, fmt.Errorf("cannot assign all values from UUIDArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -512,7 +512,7 @@ func (src UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("uuid"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "uuid") + return nil, fmt.Errorf("unable to find oid for type name %v", "uuid") } for i := range src.Elements { @@ -556,7 +556,7 @@ func (dst *UUIDArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/varbit.go b/varbit.go index 7461bab3..f24dc5bc 100644 --- a/varbit.go +++ b/varbit.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type Varbit struct { @@ -15,7 +15,7 @@ type Varbit struct { } func (dst *Varbit) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Varbit", src) + return fmt.Errorf("cannot convert %v to Varbit", src) } func (dst Varbit) Get() interface{} { @@ -30,7 +30,7 @@ func (dst Varbit) Get() interface{} { } func (src *Varbit) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) + return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Varbit) DecodeText(ci *ConnInfo, src []byte) error { @@ -65,7 +65,7 @@ func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) < 4 { - return errors.Errorf("invalid length for varbit: %v", len(src)) + return fmt.Errorf("invalid length for varbit: %v", len(src)) } bitLen := int32(binary.BigEndian.Uint32(src)) @@ -124,7 +124,7 @@ func (dst *Varbit) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/varchar_array.go b/varchar_array.go index f3a9b001..8a309a3f 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -5,10 +5,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "fmt" "reflect" "github.com/jackc/pgio" - errors "golang.org/x/xerrors" ) type VarcharArray struct { @@ -96,7 +96,7 @@ func (dst *VarcharArray) Set(src interface{}) error { dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) if !ok { - return errors.Errorf("cannot find dimensions of %v for VarcharArray", src) + return fmt.Errorf("cannot find dimensions of %v for VarcharArray", src) } if elementsLength == 0 { *dst = VarcharArray{Status: Present} @@ -106,7 +106,7 @@ func (dst *VarcharArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to VarcharArray", src) + return fmt.Errorf("cannot convert %v to VarcharArray", src) } *dst = VarcharArray{ @@ -137,7 +137,7 @@ func (dst *VarcharArray) Set(src interface{}) error { } } if elementCount != len(dst.Elements) { - return errors.Errorf("cannot convert %v to VarcharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + return fmt.Errorf("cannot convert %v to VarcharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } } @@ -155,7 +155,7 @@ func (dst *VarcharArray) setRecursive(value reflect.Value, index, dimension int) valueLen := value.Len() if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") } for i := 0; i < valueLen; i++ { var err error @@ -168,10 +168,10 @@ func (dst *VarcharArray) setRecursive(value reflect.Value, index, dimension int) return index, nil } if !value.CanInterface() { - return 0, errors.Errorf("cannot convert all values to VarcharArray") + return 0, fmt.Errorf("cannot convert all values to VarcharArray") } if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, errors.Errorf("%v in VarcharArray", err) + return 0, fmt.Errorf("%v in VarcharArray", err) } index++ @@ -233,7 +233,7 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { switch value.Kind() { case reflect.Array, reflect.Slice: default: - return errors.Errorf("cannot assign %T to %T", src, dst) + return fmt.Errorf("cannot assign %T to %T", src, dst) } if len(src.Elements) == 0 { @@ -248,7 +248,7 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { return err } if elementCount != len(src.Elements) { - return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) } return nil @@ -256,7 +256,7 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return errors.Errorf("cannot decode %#v into %T", src, dst) + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -272,7 +272,7 @@ func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension if reflect.Array == kind { typ := value.Type() if typ.Len() != length { - return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) } value.Set(reflect.New(typ).Elem()) } else { @@ -290,14 +290,14 @@ func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension return index, nil } if len(src.Dimensions) != dimension { - return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) } if !value.CanAddr() { - return 0, errors.Errorf("cannot assign all values from VarcharArray") + return 0, fmt.Errorf("cannot assign all values from VarcharArray") } addr := value.Addr() if !addr.CanInterface() { - return 0, errors.Errorf("cannot assign all values from VarcharArray") + return 0, fmt.Errorf("cannot assign all values from VarcharArray") } if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { return 0, err @@ -456,7 +456,7 @@ func (src VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("varchar"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, errors.Errorf("unable to find oid for type name %v", "varchar") + return nil, fmt.Errorf("unable to find oid for type name %v", "varchar") } for i := range src.Elements { @@ -500,7 +500,7 @@ func (dst *VarcharArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return errors.Errorf("cannot scan %T", src) + return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. From 63e2dbefaf2f441e96977c596129b610d90f116c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 25 Mar 2021 09:03:46 -0400 Subject: [PATCH 0665/1158] Update copyright date --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index dd9e7be9..5c486c39 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013 Jack Christensen +Copyright (c) 2013-2021 Jack Christensen MIT License From cdb667b5b002eb70aaac3666814309b07539895d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 25 Mar 2021 09:09:55 -0400 Subject: [PATCH 0666/1158] Update copyright date --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index c1c4f50f..aebadd6c 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2019 Jack Christensen +Copyright (c) 2019-2021 Jack Christensen MIT License From 464a7d88d9ccf1ca9f76a84984d95a5657ac3faa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 25 Mar 2021 09:15:34 -0400 Subject: [PATCH 0667/1158] Release v1.8.1 --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 787853b2..c377b3ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 1.8.1 (March 25, 2021) + +* Better connection string sanitization (ip.novikov) +* Use proper pgpass location on Windows (Moshe Katz) +* Use errors instead of golang.org/x/xerrors +* Resume fallback on server error in Connect (Andrey Borodin) + # 1.8.0 (December 3, 2020) * Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes) From 4a3a424dff9a94723972bbe0510950feb7465087 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 25 Mar 2021 09:16:43 -0400 Subject: [PATCH 0668/1158] Release v1.7.0 --- CHANGELOG.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 38eb89cd..d89f6ddc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,12 @@ +# 1.7.0 (March 25, 2021) + +* Fix scanning int into **sql.Scanner implementor +* Add tsrange array type (Vasilii Novikov) +* Fix: escaped strings when they start or end with a newline char (Stephane Martin) +* Accept nil *time.Time in Time.Set +* Fix numeric NaN support +* Use Go 1.13 errors instead of xerrors + # 1.6.2 (December 3, 2020) * Fix panic on assigning empty array to non-slice or array From 3f76b98073687a376f84a10c0972c3dd0c5de55c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 9 Apr 2021 18:20:06 -0500 Subject: [PATCH 0669/1158] Allow dbname query parameter in URL conn string fixes #69 --- config.go | 8 ++++++++ config_test.go | 12 ++++++++++++ 2 files changed, 20 insertions(+) diff --git a/config.go b/config.go index c162d3c3..6991e1de 100644 --- a/config.go +++ b/config.go @@ -426,7 +426,15 @@ func parseURLSettings(connString string) (map[string]string, error) { settings["database"] = database } + nameMap := map[string]string{ + "dbname": "database", + } + for k, v := range url.Query() { + if k2, present := nameMap[k]; present { + k = k2 + } + settings[k] = v[0] } diff --git a/config_test.go b/config_test.go index e869422d..11dd23dc 100644 --- a/config_test.go +++ b/config_test.go @@ -227,6 +227,18 @@ func TestParseConfig(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + name: "database url dbname", + connString: "postgres://localhost/?dbname=foo&sslmode=disable", + config: &pgconn.Config{ + User: osUserName, + Host: "localhost", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "database url postgresql protocol", connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable", From 4380e23ae1c8c8b983ccabdc570eef807c4f4b8e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Apr 2021 08:08:34 -0500 Subject: [PATCH 0670/1158] CompositeTextScanner handles backslash escapes fixes https://github.com/jackc/pgx/issues/874 --- composite_type.go | 4 ++++ composite_type_test.go | 43 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/composite_type.go b/composite_type.go index 7c8dbcd5..32e0aa26 100644 --- a/composite_type.go +++ b/composite_type.go @@ -491,6 +491,10 @@ func (cfs *CompositeTextScanner) Next() bool { } else { break } + } else if ch == '\\' { + cfs.rp++ + cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp]) + cfs.rp++ } else { cfs.fieldBytes = append(cfs.fieldBytes, ch) cfs.rp++ diff --git a/composite_type_test.go b/composite_type_test.go index 664fe36e..2349a67d 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -204,6 +204,49 @@ create type ct_test as ( } } +// https://github.com/jackc/pgx/issues/874 +func TestCompositeTypeTextDecodeNested(t *testing.T) { + newCompositeType := func(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { + fields := make([]pgtype.CompositeTypeField, len(fieldNames)) + for i, name := range fieldNames { + fields[i] = pgtype.CompositeTypeField{Name: name} + } + + rowType, err := pgtype.NewCompositeTypeValues(name, fields, vals) + require.NoError(t, err) + return rowType + } + + dimensionsType := func() pgtype.ValueTranscoder { + return newCompositeType( + "dimensions", + []string{"width", "height"}, + &pgtype.Int4{}, + &pgtype.Int4{}, + ) + } + productImageType := func() pgtype.ValueTranscoder { + return newCompositeType( + "product_image_type", + []string{"source", "dimensions"}, + &pgtype.Text{}, + dimensionsType(), + ) + } + productImageSetType := newCompositeType( + "product_image_set_type", + []string{"name", "orig_image", "images"}, + &pgtype.Text{}, + productImageType(), + pgtype.NewArrayType("product_image", 0, func() pgtype.ValueTranscoder { + return productImageType() + }), + ) + + err := productImageSetType.DecodeText(nil, []byte(`(name,"(img1,""(11,11)"")","{""(img2,\\""(22,22)\\"")"",""(img3,\\""(33,33)\\"")""}")`)) + require.NoError(t, err) +} + func Example_composite() { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { From cae98b5e457ed06ba70cec57d1282ba0a695077b Mon Sep 17 00:00:00 2001 From: Rueian Date: Mon, 3 May 2021 22:19:50 +0800 Subject: [PATCH 0671/1158] Register JSONBArray at NewConnInfo() --- pgtype.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pgtype.go b/pgtype.go index f1d40146..4a680844 100644 --- a/pgtype.go +++ b/pgtype.go @@ -293,6 +293,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID}) ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) + ci.RegisterDataType(DataType{Value: &JSONBArray{}, Name: "_jsonb", OID: JSONBArrayOID}) ci.RegisterDataType(DataType{Value: &Line{}, Name: "line", OID: LineOID}) ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID}) ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) From fb42201c18fcd016c235d4b613f76b2fc1599588 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 May 2021 18:39:31 -0500 Subject: [PATCH 0672/1158] Fix default host when parsing URL without host but with port fixes https://github.com/jackc/pgconn/issues/72 --- config.go | 8 ++++++-- config_test.go | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 6991e1de..16480589 100644 --- a/config.go +++ b/config.go @@ -411,8 +411,12 @@ func parseURLSettings(connString string) (map[string]string, error) { if err != nil { return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err) } - hosts = append(hosts, h) - ports = append(ports, p) + if h != "" { + hosts = append(hosts, h) + } + if p != "" { + ports = append(ports, p) + } } if len(hosts) > 0 { settings["host"] = strings.Join(hosts, ",") diff --git a/config_test.go b/config_test.go index 11dd23dc..d29173d1 100644 --- a/config_test.go +++ b/config_test.go @@ -32,6 +32,10 @@ func TestParseConfig(t *testing.T) { } } + config, err := pgconn.ParseConfig("") + require.NoError(t, err) + defaultHost := config.Host + tests := []struct { name string connString string @@ -428,6 +432,20 @@ func TestParseConfig(t *testing.T) { }, }, }, + // https://github.com/jackc/pgconn/issues/72 + { + name: "URL without host but with port still uses default host", + connString: "postgres://jack:secret@:1/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: defaultHost, + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "DSN multiple hosts one port", connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable", From 0977e29341917778a13cb5751801e1a96e54da31 Mon Sep 17 00:00:00 2001 From: Ivan Daunis Date: Mon, 10 May 2021 19:08:06 -0700 Subject: [PATCH 0673/1158] Support pointers of wrapping structs --- convert.go | 25 ++++++++++++++++++++++--- uuid_test.go | 18 ++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/convert.go b/convert.go index 8ae599b9..7c8ff198 100644 --- a/convert.go +++ b/convert.go @@ -8,9 +8,11 @@ import ( "time" ) -const maxUint = ^uint(0) -const maxInt = int(maxUint >> 1) -const minInt = -maxInt - 1 +const ( + maxUint = ^uint(0) + maxInt = int(maxUint >> 1) + minInt = -maxInt - 1 +) // underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 func underlyingNumberType(val interface{}) (interface{}, bool) { @@ -432,6 +434,23 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { } } + if dstVal.Kind() == reflect.Struct { + if dstVal.Type().NumField() == 1 && dstVal.Type().Field(0).Anonymous { + dstPtr = dstVal.Field(0).Addr() + nested := dstVal.Type().Field(0).Type + if nested.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[nested.Elem().Kind()]; ok { + baseArrayType := reflect.PtrTo(reflect.ArrayOf(nested.Len(), baseElemType)) + nextDst := dstPtr.Convert(baseArrayType) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + } + if _, ok := kindTypes[nested.Kind()]; ok && dstPtr.CanInterface() { + return dstPtr.Interface(), true + } + } + } + return nil, false } diff --git a/uuid_test.go b/uuid_test.go index 8de5b9f6..5a93ea8d 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -16,6 +16,10 @@ func TestUUIDTranscode(t *testing.T) { }) } +type SomeUUIDWrapper struct { + SomeUUIDType +} + type SomeUUIDType [16]byte func TestUUIDSet(t *testing.T) { @@ -127,6 +131,20 @@ func TestUUIDAssignTo(t *testing.T) { } } + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst SomeUUIDWrapper + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst.SomeUUIDType != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } } func TestUUID_MarshalJSON(t *testing.T) { From 5bca076182676b0364693d811b93d8273663b814 Mon Sep 17 00:00:00 2001 From: Ivan Daunis Date: Mon, 17 May 2021 14:11:56 -0700 Subject: [PATCH 0674/1158] Refactor to interface convert --- convert.go | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/convert.go b/convert.go index 7c8ff198..de9ba9ba 100644 --- a/convert.go +++ b/convert.go @@ -389,6 +389,11 @@ func NullAssignTo(dst interface{}) error { var kindTypes map[reflect.Kind]reflect.Type +func toInterface(dst reflect.Value, t reflect.Type) (interface{}, bool) { + nextDst := dst.Convert(t) + return nextDst.Interface(), dst.Type() != nextDst.Type() +} + // GetAssignToDstType attempts to convert dst to something AssignTo can assign // to. If dst is a pointer to pointer it allocates a value and returns the // dereferences pointer. If dst is a named type such as *Foo where Foo is type @@ -414,23 +419,18 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { // if dst is pointer to a base type that has been renamed if baseValType, ok := kindTypes[dstVal.Kind()]; ok { - nextDst := dstPtr.Convert(reflect.PtrTo(baseValType)) - return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + return toInterface(dstPtr, reflect.PtrTo(baseValType)) } if dstVal.Kind() == reflect.Slice { if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { - baseSliceType := reflect.PtrTo(reflect.SliceOf(baseElemType)) - nextDst := dstPtr.Convert(baseSliceType) - return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + return toInterface(dstPtr, reflect.PtrTo(reflect.SliceOf(baseElemType))) } } if dstVal.Kind() == reflect.Array { if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { - baseArrayType := reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType)) - nextDst := dstPtr.Convert(baseArrayType) - return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType))) } } @@ -440,9 +440,7 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { nested := dstVal.Type().Field(0).Type if nested.Kind() == reflect.Array { if baseElemType, ok := kindTypes[nested.Elem().Kind()]; ok { - baseArrayType := reflect.PtrTo(reflect.ArrayOf(nested.Len(), baseElemType)) - nextDst := dstPtr.Convert(baseArrayType) - return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(nested.Len(), baseElemType))) } } if _, ok := kindTypes[nested.Kind()]; ok && dstPtr.CanInterface() { From ba924e5715ad0b06cf1e1ddcc343ded6e9420cf4 Mon Sep 17 00:00:00 2001 From: Henrique Vicente Date: Sun, 16 May 2021 02:05:24 +0200 Subject: [PATCH 0675/1158] json: Implement json.Unmarshaler for messages. This will allow using pgmockproxy output as ingestion data for pgmock. --- authentication_md5_password.go | 22 +++++++++++ authentication_sasl_continue.go | 19 ++++++++++ authentication_sasl_final.go | 19 ++++++++++ bind.go | 35 +++++++++++++++++ close.go | 25 ++++++++++++ command_complete.go | 18 +++++++++ copy_both_response.go | 25 ++++++++++++ copy_data.go | 18 +++++++++ copy_in_response.go | 25 ++++++++++++ copy_out_response.go | 25 ++++++++++++ data_row.go | 25 ++++++++++++ describe.go | 24 ++++++++++++ error_response.go | 67 +++++++++++++++++++++++++++++++++ function_call_response.go | 18 +++++++++ pgproto3.go | 17 ++++++++- ready_for_query.go | 21 +++++++++++ row_description.go | 31 +++++++++++++++ sasl_initial_response.go | 23 +++++++++++ sasl_response.go | 16 ++++++++ 19 files changed, 472 insertions(+), 1 deletion(-) diff --git a/authentication_md5_password.go b/authentication_md5_password.go index d505d264..b80bd992 100644 --- a/authentication_md5_password.go +++ b/authentication_md5_password.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -41,3 +42,24 @@ func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { dst = append(dst, src.Salt[:]...) return dst } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Salt string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.Salt) != 4 { + return errors.New("invalid salt size") + } + + copy(dst.Salt[:], []byte(msg.Salt)[:4]) + return nil +} diff --git a/authentication_sasl_continue.go b/authentication_sasl_continue.go index 1b918a6e..62a16c76 100644 --- a/authentication_sasl_continue.go +++ b/authentication_sasl_continue.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -46,3 +47,21 @@ func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { return dst } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/authentication_sasl_final.go b/authentication_sasl_final.go index 11d35660..de5e454a 100644 --- a/authentication_sasl_final.go +++ b/authentication_sasl_final.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -46,3 +47,21 @@ func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { return dst } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/bind.go b/bind.go index 52372095..57585c4d 100644 --- a/bind.go +++ b/bind.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "fmt" "github.com/jackc/pgio" ) @@ -181,3 +182,37 @@ func (src Bind) MarshalJSON() ([]byte, error) { ResultFormatCodes: src.ResultFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Bind) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters []map[string]string + ResultFormatCodes []int16 + } + err := json.Unmarshal(data, &msg) + if err != nil { + return err + } + bind := &Bind{ + DestinationPortal: msg.DestinationPortal, + PreparedStatement: msg.PreparedStatement, + ParameterFormatCodes: msg.ParameterFormatCodes, + Parameters: make([][]byte, len(msg.Parameters)), + ResultFormatCodes: msg.ResultFormatCodes, + } + for n, parameter := range msg.Parameters { + bind.Parameters[n], err = getValueFromJSON(parameter) + if err != nil { + return fmt.Errorf("cannot get param %d: %w", n, err) + } + } + return nil +} diff --git a/close.go b/close.go index 38296909..a45f2b93 100644 --- a/close.go +++ b/close.go @@ -3,6 +3,7 @@ package pgproto3 import ( "bytes" "encoding/json" + "errors" "github.com/jackc/pgio" ) @@ -62,3 +63,27 @@ func (src Close) MarshalJSON() ([]byte, error) { Name: src.Name, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Close) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + ObjectType string + Name string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.ObjectType) != 1 { + return errors.New("invalid length for Close.ObjectType") + } + + dst.ObjectType = byte(msg.ObjectType[0]) + dst.Name = msg.Name + return nil +} diff --git a/command_complete.go b/command_complete.go index b5106fda..cdc49f39 100644 --- a/command_complete.go +++ b/command_complete.go @@ -51,3 +51,21 @@ func (src CommandComplete) MarshalJSON() ([]byte, error) { CommandTag: string(src.CommandTag), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CommandComplete) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + CommandTag string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.CommandTag = []byte(msg.CommandTag) + return nil +} diff --git a/copy_both_response.go b/copy_both_response.go index 2d58f820..fbd985d8 100644 --- a/copy_both_response.go +++ b/copy_both_response.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" "github.com/jackc/pgio" ) @@ -68,3 +69,27 @@ func (src CopyBothResponse) MarshalJSON() ([]byte, error) { ColumnFormatCodes: src.ColumnFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyBothResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/copy_data.go b/copy_data.go index 7d6002fe..128aa198 100644 --- a/copy_data.go +++ b/copy_data.go @@ -42,3 +42,21 @@ func (src CopyData) MarshalJSON() ([]byte, error) { Data: hex.EncodeToString(src.Data), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyData) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/copy_in_response.go b/copy_in_response.go index 5f2595b8..80733adc 100644 --- a/copy_in_response.go +++ b/copy_in_response.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" "github.com/jackc/pgio" ) @@ -69,3 +70,27 @@ func (src CopyInResponse) MarshalJSON() ([]byte, error) { ColumnFormatCodes: src.ColumnFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyInResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyInResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/copy_out_response.go b/copy_out_response.go index 8538dfc7..5e607e3a 100644 --- a/copy_out_response.go +++ b/copy_out_response.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" "github.com/jackc/pgio" ) @@ -69,3 +70,27 @@ func (src CopyOutResponse) MarshalJSON() ([]byte, error) { ColumnFormatCodes: src.ColumnFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyOutResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/data_row.go b/data_row.go index 5fa3c5d8..63768761 100644 --- a/data_row.go +++ b/data_row.go @@ -115,3 +115,28 @@ func (src DataRow) MarshalJSON() ([]byte, error) { Values: formattedValues, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *DataRow) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Values []map[string]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Values = make([][]byte, len(msg.Values)) + for n, parameter := range msg.Values { + var err error + dst.Values[n], err = getValueFromJSON(parameter) + if err != nil { + return err + } + } + return nil +} diff --git a/describe.go b/describe.go index 308f582e..0d825db1 100644 --- a/describe.go +++ b/describe.go @@ -3,6 +3,7 @@ package pgproto3 import ( "bytes" "encoding/json" + "errors" "github.com/jackc/pgio" ) @@ -62,3 +63,26 @@ func (src Describe) MarshalJSON() ([]byte, error) { Name: src.Name, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Describe) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + ObjectType string + Name string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.ObjectType) != 1 { + return errors.New("invalid length for Describe.ObjectType") + } + + dst.ObjectType = byte(msg.ObjectType[0]) + dst.Name = msg.Name + return nil +} diff --git a/error_response.go b/error_response.go index 4eb0a196..9bbd78f4 100644 --- a/error_response.go +++ b/error_response.go @@ -3,6 +3,8 @@ package pgproto3 import ( "bytes" "encoding/binary" + "encoding/json" + "fmt" "strconv" ) @@ -225,3 +227,68 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { return buf.Bytes() } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ErrorResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[string]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Severity = msg.Severity + dst.SeverityUnlocalized = msg.SeverityUnlocalized + dst.Code = msg.Code + dst.Message = msg.Message + dst.Detail = msg.Detail + dst.Hint = msg.Hint + dst.Position = msg.Position + dst.InternalPosition = msg.InternalPosition + dst.InternalQuery = msg.InternalQuery + dst.Where = msg.Where + dst.SchemaName = msg.SchemaName + dst.TableName = msg.TableName + dst.ColumnName = msg.ColumnName + dst.DataTypeName = msg.DataTypeName + dst.ConstraintName = msg.ConstraintName + dst.File = msg.File + dst.Line = msg.Line + dst.Routine = msg.Routine + + if msg.UnknownFields != nil { + dst.UnknownFields = map[byte]string{} + } + for k, v := range msg.UnknownFields { + if len(k) != 1 { + return fmt.Errorf("invalid UnknownFields field %q value", k) + } + dst.UnknownFields[k[0]] = v + } + + return nil +} diff --git a/function_call_response.go b/function_call_response.go index 5cc2d4d2..53d64222 100644 --- a/function_call_response.go +++ b/function_call_response.go @@ -81,3 +81,21 @@ func (src FunctionCallResponse) MarshalJSON() ([]byte, error) { Result: formattedValue, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Result map[string]string + } + err := json.Unmarshal(data, &msg) + if err != nil { + return err + } + dst.Result, err = getValueFromJSON(msg.Result) + return err +} diff --git a/pgproto3.go b/pgproto3.go index fe7b085b..5b39362c 100644 --- a/pgproto3.go +++ b/pgproto3.go @@ -1,6 +1,10 @@ package pgproto3 -import "fmt" +import ( + "encoding/hex" + "errors" + "fmt" +) // Message is the interface implemented by an object that can decode and encode // a particular PostgreSQL message. @@ -40,3 +44,14 @@ type invalidMessageFormatErr struct { func (e *invalidMessageFormatErr) Error() string { return fmt.Sprintf("%s body is invalid", e.messageType) } + +// getValueFromJSON gets the value from a protocol message representation in JSON. +func getValueFromJSON(v map[string]string) ([]byte, error) { + if text, ok := v["text"]; ok { + return []byte(text), nil + } + if binary, ok := v["binary"]; ok { + return hex.DecodeString(binary) + } + return nil, errors.New("unknown protocol representation") +} diff --git a/ready_for_query.go b/ready_for_query.go index 879afe39..67a39be3 100644 --- a/ready_for_query.go +++ b/ready_for_query.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/json" + "errors" ) type ReadyForQuery struct { @@ -38,3 +39,23 @@ func (src ReadyForQuery) MarshalJSON() ([]byte, error) { TxStatus: string(src.TxStatus), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ReadyForQuery) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + TxStatus string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.TxStatus) != 1 { + return errors.New("invalid length for ReadyForQuery.TxStatus") + } + dst.TxStatus = msg.TxStatus[0] + return nil +} diff --git a/row_description.go b/row_description.go index d9b8c7c9..a2e0d28e 100644 --- a/row_description.go +++ b/row_description.go @@ -132,3 +132,34 @@ func (src RowDescription) MarshalJSON() ([]byte, error) { Fields: src.Fields, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *RowDescription) UnmarshalJSON(data []byte) error { + var msg struct { + Fields []struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 + } + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + dst.Fields = make([]FieldDescription, len(msg.Fields)) + for n, field := range msg.Fields { + dst.Fields[n] = FieldDescription{ + Name: []byte(field.Name), + TableOID: field.TableOID, + TableAttributeNumber: field.TableAttributeNumber, + DataTypeOID: field.DataTypeOID, + DataTypeSize: field.DataTypeSize, + TypeModifier: field.TypeModifier, + Format: field.Format, + } + } + return nil +} diff --git a/sasl_initial_response.go b/sasl_initial_response.go index 0bf8a9e5..ce994c51 100644 --- a/sasl_initial_response.go +++ b/sasl_initial_response.go @@ -67,3 +67,26 @@ func (src SASLInitialResponse) MarshalJSON() ([]byte, error) { Data: hex.EncodeToString(src.Data), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + AuthMechanism string + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + decodedData, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.AuthMechanism = msg.AuthMechanism + dst.Data = decodedData + return nil +} diff --git a/sasl_response.go b/sasl_response.go index 21be6d75..df60c5f7 100644 --- a/sasl_response.go +++ b/sasl_response.go @@ -41,3 +41,19 @@ func (src SASLResponse) MarshalJSON() ([]byte, error) { Data: hex.EncodeToString(src.Data), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *SASLResponse) UnmarshalJSON(data []byte) error { + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + decoded, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.Data = decoded + return nil +} From 9c2c389e06738fc2fb5e3c15b5d51b125435b5a0 Mon Sep 17 00:00:00 2001 From: Henrique Vicente Date: Mon, 17 May 2021 02:11:29 +0200 Subject: [PATCH 0676/1158] json: fix implementation of json Unmarshalers. * AuthenticationMD5Password was wrong and is not needed * Bind was wrong * ErrorResponse is not needed * Minor improvements for reliability --- authentication_md5_password.go | 22 -- bind.go | 14 +- error_response.go | 67 ----- json_test.go | 508 +++++++++++++++++++++++++++++++++ pgproto3.go | 3 + sasl_initial_response.go | 12 +- sasl_response.go | 10 +- 7 files changed, 530 insertions(+), 106 deletions(-) create mode 100644 json_test.go diff --git a/authentication_md5_password.go b/authentication_md5_password.go index b80bd992..d505d264 100644 --- a/authentication_md5_password.go +++ b/authentication_md5_password.go @@ -2,7 +2,6 @@ package pgproto3 import ( "encoding/binary" - "encoding/json" "errors" "github.com/jackc/pgio" @@ -42,24 +41,3 @@ func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { dst = append(dst, src.Salt[:]...) return dst } - -// UnmarshalJSON implements encoding/json.Unmarshaler. -func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error { - // Ignore null, like in the main JSON package. - if string(data) == "null" { - return nil - } - - var msg struct { - Salt string - } - if err := json.Unmarshal(data, &msg); err != nil { - return err - } - if len(msg.Salt) != 4 { - return errors.New("invalid salt size") - } - - copy(dst.Salt[:], []byte(msg.Salt)[:4]) - return nil -} diff --git a/bind.go b/bind.go index 57585c4d..e9664f59 100644 --- a/bind.go +++ b/bind.go @@ -201,15 +201,13 @@ func (dst *Bind) UnmarshalJSON(data []byte) error { if err != nil { return err } - bind := &Bind{ - DestinationPortal: msg.DestinationPortal, - PreparedStatement: msg.PreparedStatement, - ParameterFormatCodes: msg.ParameterFormatCodes, - Parameters: make([][]byte, len(msg.Parameters)), - ResultFormatCodes: msg.ResultFormatCodes, - } + dst.DestinationPortal = msg.DestinationPortal + dst.PreparedStatement = msg.PreparedStatement + dst.ParameterFormatCodes = msg.ParameterFormatCodes + dst.Parameters = make([][]byte, len(msg.Parameters)) + dst.ResultFormatCodes = msg.ResultFormatCodes for n, parameter := range msg.Parameters { - bind.Parameters[n], err = getValueFromJSON(parameter) + dst.Parameters[n], err = getValueFromJSON(parameter) if err != nil { return fmt.Errorf("cannot get param %d: %w", n, err) } diff --git a/error_response.go b/error_response.go index 9bbd78f4..4eb0a196 100644 --- a/error_response.go +++ b/error_response.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/binary" - "encoding/json" - "fmt" "strconv" ) @@ -227,68 +225,3 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { return buf.Bytes() } - -// UnmarshalJSON implements encoding/json.Unmarshaler. -func (dst *ErrorResponse) UnmarshalJSON(data []byte) error { - // Ignore null, like in the main JSON package. - if string(data) == "null" { - return nil - } - - var msg struct { - Severity string - SeverityUnlocalized string // only in 9.6 and greater - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string - - UnknownFields map[string]string - } - if err := json.Unmarshal(data, &msg); err != nil { - return err - } - - dst.Severity = msg.Severity - dst.SeverityUnlocalized = msg.SeverityUnlocalized - dst.Code = msg.Code - dst.Message = msg.Message - dst.Detail = msg.Detail - dst.Hint = msg.Hint - dst.Position = msg.Position - dst.InternalPosition = msg.InternalPosition - dst.InternalQuery = msg.InternalQuery - dst.Where = msg.Where - dst.SchemaName = msg.SchemaName - dst.TableName = msg.TableName - dst.ColumnName = msg.ColumnName - dst.DataTypeName = msg.DataTypeName - dst.ConstraintName = msg.ConstraintName - dst.File = msg.File - dst.Line = msg.Line - dst.Routine = msg.Routine - - if msg.UnknownFields != nil { - dst.UnknownFields = map[byte]string{} - } - for k, v := range msg.UnknownFields { - if len(k) != 1 { - return fmt.Errorf("invalid UnknownFields field %q value", k) - } - dst.UnknownFields[k[0]] = v - } - - return nil -} diff --git a/json_test.go b/json_test.go new file mode 100644 index 00000000..c73807ab --- /dev/null +++ b/json_test.go @@ -0,0 +1,508 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" + "reflect" + "testing" +) + +func TestJSONUnmarshalAuthenticationMD5Password(t *testing.T) { + data := []byte(`{"Type":"AuthenticationMD5Password", "Salt":[97,98,99,100]}`) + want := AuthenticationMD5Password{ + Salt: [4]byte{'a', 'b', 'c', 'd'}, + } + + var got AuthenticationMD5Password + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationMD5Password struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationSASL(t *testing.T) { + data := []byte(`{"Type":"AuthenticationSASL", "AuthMechanisms":[]}`) + want := AuthenticationSASL{ + AuthMechanisms: []string{}, + } + + var got AuthenticationSASL + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationSASL struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) { + data := []byte(`{"Type":"AuthenticationSASLContinue"}`) + want := AuthenticationSASLContinue{ + Data: []byte{}, + } + + var got AuthenticationSASLContinue + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationSASLContinue struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationSASLFinal(t *testing.T) { + data := []byte(`{"Type":"AuthenticationSASLFinal"}`) + want := AuthenticationSASLFinal{ + Data: []byte{}, + } + + var got AuthenticationSASLFinal + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationSASLFinal struct doesn't match expected value") + } +} + +func TestJSONUnmarshalBackendKeyData(t *testing.T) { + data := []byte(`{"Type":"BackendKeyData","ProcessID":8864,"SecretKey":3641487067}`) + want := BackendKeyData{ + ProcessID: 8864, + SecretKey: 3641487067, + } + + var got BackendKeyData + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled BackendKeyData struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCommandComplete(t *testing.T) { + data := []byte(`{"Type":"CommandComplete","CommandTag":"SELECT 1"}`) + want := CommandComplete{ + CommandTag: []byte("SELECT 1"), + } + + var got CommandComplete + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CommandComplete struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyBothResponse(t *testing.T) { + data := []byte(`{"Type":"CopyBothResponse", "OverallFormat": "W"}`) + want := CopyBothResponse{ + OverallFormat: 'W', + } + + var got CopyBothResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyBothResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyData(t *testing.T) { + data := []byte(`{"Type":"CopyData"}`) + want := CopyData{ + Data: []byte{}, + } + + var got CopyData + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyData struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyInResponse(t *testing.T) { + data := []byte(`{"Type":"CopyBothResponse", "OverallFormat": "W"}`) + want := CopyBothResponse{ + OverallFormat: 'W', + } + + var got CopyBothResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyBothResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyOutResponse(t *testing.T) { + data := []byte(`{"Type":"CopyOutResponse", "OverallFormat": "W"}`) + want := CopyOutResponse{ + OverallFormat: 'W', + } + + var got CopyOutResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyOutResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalDataRow(t *testing.T) { + data := []byte(`{"Type":"DataRow","Values":[{"text":"abc"},{"text":"this is a test"},{"binary":"000263d3114d2e34"}]}`) + want := DataRow{ + Values: [][]byte{ + []byte("abc"), + []byte("this is a test"), + {0, 2, 99, 211, 17, 77, 46, 52}, + }, + } + + var got DataRow + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled DataRow struct doesn't match expected value") + } +} + +func TestJSONUnmarshalErrorResponse(t *testing.T) { + data := []byte(`{"Type":"ErrorResponse", "UnknownFields": {"97": "foo"}}`) + want := ErrorResponse{ + UnknownFields: map[byte]string{ + 'a': "foo", + }, + } + + var got ErrorResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ErrorResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalFunctionCallResponse(t *testing.T) { + data := []byte(`{"Type":"FunctionCallResponse"}`) + want := FunctionCallResponse{} + + var got FunctionCallResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled FunctionCallResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalNoticeResponse(t *testing.T) { + data := []byte(`{"Type":"NoticeResponse", "UnknownFields": {"97": "foo"}}`) + want := NoticeResponse{ + UnknownFields: map[byte]string{ + 'a': "foo", + }, + } + + var got NoticeResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled NoticeResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalNotificationResponse(t *testing.T) { + data := []byte(`{"Type":"NotificationResponse"}`) + want := NotificationResponse{} + + var got NotificationResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled NotificationResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalParameterDescription(t *testing.T) { + data := []byte(`{"Type":"ParameterDescription", "ParameterOIDs": [25]}`) + want := ParameterDescription{ + ParameterOIDs: []uint32{25}, + } + + var got ParameterDescription + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ParameterDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalParameterStatus(t *testing.T) { + data := []byte(`{"Type":"ParameterStatus","Name":"TimeZone","Value":"Europe/Amsterdam"}`) + want := ParameterStatus{ + Name: "TimeZone", + Value: "Europe/Amsterdam", + } + + var got ParameterStatus + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ParameterDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalReadyForQuery(t *testing.T) { + data := []byte(`{"Type":"ReadyForQuery","TxStatus":"I"}`) + want := ReadyForQuery{ + TxStatus: 'I', + } + + var got ReadyForQuery + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ParameterDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalRowDescription(t *testing.T) { + data := []byte(`{"Type":"RowDescription","Fields":[{"Name":"generate_series","TableOID":0,"TableAttributeNumber":0,"DataTypeOID":23,"DataTypeSize":4,"TypeModifier":-1,"Format":0}]}`) + want := RowDescription{ + Fields: []FieldDescription{ + { + Name: []byte("generate_series"), + DataTypeOID: 23, + DataTypeSize: 4, + TypeModifier: -1, + }, + }, + } + + var got RowDescription + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled RowDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalBind(t *testing.T) { + var testCases = []struct { + desc string + data []byte + }{ + { + "textual", + []byte(`{"Type":"Bind","DestinationPortal":"","PreparedStatement":"lrupsc_1_0","ParameterFormatCodes":[0],"Parameters":[{"text":"ABC-123"}],"ResultFormatCodes":[0,0,0,0,0,1,1]}`), + }, + { + "binary", + []byte(`{"Type":"Bind","DestinationPortal":"","PreparedStatement":"lrupsc_1_0","ParameterFormatCodes":[0],"Parameters":[{"binary":"` + hex.EncodeToString([]byte("ABC-123")) + `"}],"ResultFormatCodes":[0,0,0,0,0,1,1]}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + var want = Bind{ + PreparedStatement: "lrupsc_1_0", + ParameterFormatCodes: []int16{0}, + Parameters: [][]byte{[]byte("ABC-123")}, + ResultFormatCodes: []int16{0, 0, 0, 0, 0, 1, 1}, + } + + var got Bind + if err := json.Unmarshal(tc.data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Bind struct doesn't match expected value") + } + }) + } +} + +func TestJSONUnmarshalCancelRequest(t *testing.T) { + data := []byte(`{"Type":"CancelRequest","ProcessID":8864,"SecretKey":3641487067}`) + want := CancelRequest{ + ProcessID: 8864, + SecretKey: 3641487067, + } + + var got CancelRequest + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CancelRequest struct doesn't match expected value") + } +} + +func TestJSONUnmarshalClose(t *testing.T) { + data := []byte(`{"Type":"Close","ObjectType":"S","Name":"abc"}`) + want := Close{ + ObjectType: 'S', + Name: "abc", + } + + var got Close + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Close struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyFail(t *testing.T) { + data := []byte(`{"Type":"CopyFail","Message":"abc"}`) + want := CopyFail{ + Message: "abc", + } + + var got CopyFail + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyFail struct doesn't match expected value") + } +} + +func TestJSONUnmarshalDescribe(t *testing.T) { + data := []byte(`{"Type":"Describe","ObjectType":"S","Name":"abc"}`) + want := Describe{ + ObjectType: 'S', + Name: "abc", + } + + var got Describe + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Describe struct doesn't match expected value") + } +} + +func TestJSONUnmarshalExecute(t *testing.T) { + data := []byte(`{"Type":"Execute","Portal":"","MaxRows":0}`) + want := Execute{} + + var got Execute + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Execute struct doesn't match expected value") + } +} + +func TestJSONUnmarshalParse(t *testing.T) { + data := []byte(`{"Type":"Parse","Name":"lrupsc_1_0","Query":"SELECT id, name FROM t WHERE id = $1","ParameterOIDs":null}`) + want := Parse{ + Name: "lrupsc_1_0", + Query: "SELECT id, name FROM t WHERE id = $1", + } + + var got Parse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Parse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalPasswordMessage(t *testing.T) { + data := []byte(`{"Type":"PasswordMessage","Password":"abcdef"}`) + want := PasswordMessage{ + Password: "abcdef", + } + + var got PasswordMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled PasswordMessage struct doesn't match expected value") + } +} + +func TestJSONUnmarshalQuery(t *testing.T) { + data := []byte(`{"Type":"Query","String":"SELECT 1"}`) + want := Query{ + String: "SELECT 1", + } + + var got Query + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Query struct doesn't match expected value") + } +} + +func TestJSONUnmarshalSASLInitialResponse(t *testing.T) { + data := []byte(`{"Type":"SASLInitialResponse"}`) + want := SASLInitialResponse{} + + var got SASLInitialResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled SASLInitialResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalSASLResponse(t *testing.T) { + data := []byte(`{"Type":"SASLResponse","Message":"abc"}`) + want := SASLResponse{} + + var got SASLResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled SASLResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalStartupMessage(t *testing.T) { + data := []byte(`{"Type":"StartupMessage","ProtocolVersion":196608,"Parameters":{"database":"testing","user":"postgres"}}`) + want := StartupMessage{ + ProtocolVersion: 196608, + Parameters: map[string]string{ + "database": "testing", + "user": "postgres", + }, + } + + var got StartupMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled StartupMessage struct doesn't match expected value") + } +} diff --git a/pgproto3.go b/pgproto3.go index 5b39362c..fb0782cf 100644 --- a/pgproto3.go +++ b/pgproto3.go @@ -47,6 +47,9 @@ func (e *invalidMessageFormatErr) Error() string { // getValueFromJSON gets the value from a protocol message representation in JSON. func getValueFromJSON(v map[string]string) ([]byte, error) { + if v == nil { + return nil, nil + } if text, ok := v["text"]; ok { return []byte(text), nil } diff --git a/sasl_initial_response.go b/sasl_initial_response.go index ce994c51..f7e5f36a 100644 --- a/sasl_initial_response.go +++ b/sasl_initial_response.go @@ -82,11 +82,13 @@ func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &msg); err != nil { return err } - decodedData, err := hex.DecodeString(msg.Data) - if err != nil { - return err - } dst.AuthMechanism = msg.AuthMechanism - dst.Data = decodedData + if msg.Data != "" { + decoded, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.Data = decoded + } return nil } diff --git a/sasl_response.go b/sasl_response.go index df60c5f7..41fb4c39 100644 --- a/sasl_response.go +++ b/sasl_response.go @@ -50,10 +50,12 @@ func (dst *SASLResponse) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &msg); err != nil { return err } - decoded, err := hex.DecodeString(msg.Data) - if err != nil { - return err + if msg.Data != "" { + decoded, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.Data = decoded } - dst.Data = decoded return nil } From 28c20e93c0f5d81d5e2f324ac81a1dc0ab8649a4 Mon Sep 17 00:00:00 2001 From: Yuli Khodorkovskiy Date: Thu, 27 May 2021 14:47:56 -0400 Subject: [PATCH 0677/1158] Fix json marshal/unmarshal implementations Fix marshal/unmarshal for: - authentication_{cleartext_password, md5_password, ok, sasl, sasl_continue, sasl_final} - error_response --- authentication_cleartext_password.go | 10 +++ authentication_md5_password.go | 31 ++++++++ authentication_ok.go | 10 +++ authentication_sasl.go | 12 +++ authentication_sasl_continue.go | 11 +++ authentication_sasl_final.go | 11 +++ error_response.go | 107 +++++++++++++++++++++++++++ json_test.go | 80 ++++++++++++++++++-- 8 files changed, 264 insertions(+), 8 deletions(-) diff --git a/authentication_cleartext_password.go b/authentication_cleartext_password.go index dd82c7a7..1b87a718 100644 --- a/authentication_cleartext_password.go +++ b/authentication_cleartext_password.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -37,3 +38,12 @@ func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte { dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) return dst } + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "AuthenticationCleartextPassword", + }) +} diff --git a/authentication_md5_password.go b/authentication_md5_password.go index d505d264..95795b31 100644 --- a/authentication_md5_password.go +++ b/authentication_md5_password.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -41,3 +42,33 @@ func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { dst = append(dst, src.Salt[:]...) return dst } + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Salt [4]byte + }{ + Type: "AuthenticationMD5Password", + Salt: src.Salt, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Salt [4]byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Salt = msg.Salt + return nil +} diff --git a/authentication_ok.go b/authentication_ok.go index 7b13c6e0..ad69b907 100644 --- a/authentication_ok.go +++ b/authentication_ok.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -37,3 +38,12 @@ func (src *AuthenticationOk) Encode(dst []byte) []byte { dst = pgio.AppendUint32(dst, AuthTypeOk) return dst } + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationOk) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "AuthenticationOK", + }) +} diff --git a/authentication_sasl.go b/authentication_sasl.go index c57ae32d..d2b09750 100644 --- a/authentication_sasl.go +++ b/authentication_sasl.go @@ -3,6 +3,7 @@ package pgproto3 import ( "bytes" "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -58,3 +59,14 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte { return dst } + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationSASL) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + AuthMechanisms []string + }{ + Type: "AuthenticationSASL", + AuthMechanisms: src.AuthMechanisms, + }) +} diff --git a/authentication_sasl_continue.go b/authentication_sasl_continue.go index 62a16c76..d258065f 100644 --- a/authentication_sasl_continue.go +++ b/authentication_sasl_continue.go @@ -48,6 +48,17 @@ func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "AuthenticationSASLContinue", + Data: string(src.Data), + }) +} + // UnmarshalJSON implements encoding/json.Unmarshaler. func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error { // Ignore null, like in the main JSON package. diff --git a/authentication_sasl_final.go b/authentication_sasl_final.go index de5e454a..6a681d73 100644 --- a/authentication_sasl_final.go +++ b/authentication_sasl_final.go @@ -48,6 +48,17 @@ func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Unmarshaler. +func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "AuthenticationSASLFinal", + Data: string(src.Data), + }) +} + // UnmarshalJSON implements encoding/json.Unmarshaler. func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error { // Ignore null, like in the main JSON package. diff --git a/error_response.go b/error_response.go index 4eb0a196..ec51e019 100644 --- a/error_response.go +++ b/error_response.go @@ -3,6 +3,7 @@ package pgproto3 import ( "bytes" "encoding/binary" + "encoding/json" "strconv" ) @@ -225,3 +226,109 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { return buf.Bytes() } + +// MarshalJSON implements encoding/json.Marshaler. +func (src ErrorResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string + }{ + Type: "ErrorResponse", + Severity: src.Severity, + SeverityUnlocalized: src.SeverityUnlocalized, + Code: src.Code, + Message: src.Message, + Detail: src.Detail, + Hint: src.Hint, + Position: src.Position, + InternalPosition: src.InternalPosition, + InternalQuery: src.InternalQuery, + Where: src.Where, + SchemaName: src.SchemaName, + TableName: src.TableName, + ColumnName: src.ColumnName, + DataTypeName: src.DataTypeName, + ConstraintName: src.ConstraintName, + File: src.File, + Line: src.Line, + Routine: src.Routine, + UnknownFields: src.UnknownFields, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ErrorResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Severity = msg.Severity + dst.SeverityUnlocalized = msg.SeverityUnlocalized + dst.Code = msg.Code + dst.Message = msg.Message + dst.Detail = msg.Detail + dst.Hint = msg.Hint + dst.Position = msg.Position + dst.InternalPosition = msg.InternalPosition + dst.InternalQuery = msg.InternalQuery + dst.Where = msg.Where + dst.SchemaName = msg.SchemaName + dst.TableName = msg.TableName + dst.ColumnName = msg.ColumnName + dst.DataTypeName = msg.DataTypeName + dst.ConstraintName = msg.ConstraintName + dst.File = msg.File + dst.Line = msg.Line + dst.Routine = msg.Routine + + dst.UnknownFields = msg.UnknownFields + + return nil +} diff --git a/json_test.go b/json_test.go index c73807ab..eab26252 100644 --- a/json_test.go +++ b/json_test.go @@ -23,9 +23,9 @@ func TestJSONUnmarshalAuthenticationMD5Password(t *testing.T) { } func TestJSONUnmarshalAuthenticationSASL(t *testing.T) { - data := []byte(`{"Type":"AuthenticationSASL", "AuthMechanisms":[]}`) + data := []byte(`{"Type":"AuthenticationSASL","AuthMechanisms":["SCRAM-SHA-256"]}`) want := AuthenticationSASL{ - AuthMechanisms: []string{}, + []string{"SCRAM-SHA-256"}, } var got AuthenticationSASL @@ -38,9 +38,9 @@ func TestJSONUnmarshalAuthenticationSASL(t *testing.T) { } func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) { - data := []byte(`{"Type":"AuthenticationSASLContinue"}`) + data := []byte(`{"Type":"AuthenticationSASLContinue", "Data":"1"}`) want := AuthenticationSASLContinue{ - Data: []byte{}, + Data: []byte{'1'}, } var got AuthenticationSASLContinue @@ -53,9 +53,9 @@ func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) { } func TestJSONUnmarshalAuthenticationSASLFinal(t *testing.T) { - data := []byte(`{"Type":"AuthenticationSASLFinal"}`) + data := []byte(`{"Type":"AuthenticationSASLFinal", "Data":"1"}`) want := AuthenticationSASLFinal{ - Data: []byte{}, + Data: []byte{'1'}, } var got AuthenticationSASLFinal @@ -463,8 +463,11 @@ func TestJSONUnmarshalQuery(t *testing.T) { } func TestJSONUnmarshalSASLInitialResponse(t *testing.T) { - data := []byte(`{"Type":"SASLInitialResponse"}`) - want := SASLInitialResponse{} + data := []byte(`{"Type":"SASLInitialResponse", "AuthMechanism":"SCRAM-SHA-256", "Data": "6D"}`) + want := SASLInitialResponse{ + AuthMechanism: "SCRAM-SHA-256", + Data: []byte{109}, + } var got SASLInitialResponse if err := json.Unmarshal(data, &got); err != nil { @@ -506,3 +509,64 @@ func TestJSONUnmarshalStartupMessage(t *testing.T) { t.Error("unmarshaled StartupMessage struct doesn't match expected value") } } + +func TestAuthenticationOK(t *testing.T) { + data := []byte(`{"Type":"AuthenticationOK"}`) + want := AuthenticationOk{} + + var got AuthenticationOk + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationOK struct doesn't match expected value") + } +} + +func TestAuthenticationCleartextPassword(t *testing.T) { + data := []byte(`{"Type":"AuthenticationCleartextPassword"}`) + want := AuthenticationCleartextPassword{} + + var got AuthenticationCleartextPassword + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationCleartextPassword struct doesn't match expected value") + } +} + +func TestAuthenticationMD5Password(t *testing.T) { + data := []byte(`{"Type":"AuthenticationMD5Password","Salt":[1,2,3,4]}`) + want := AuthenticationMD5Password{ + Salt: [4]byte{1, 2, 3, 4}, + } + + var got AuthenticationMD5Password + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationMD5Password struct doesn't match expected value") + } +} + +func TestErrorResponse(t *testing.T) { + data := []byte(`{"Type":"ErrorResponse","UnknownFields":{"112":"foo"},"Code": "Fail","Position":1,"Message":"this is an error"}`) + want := ErrorResponse{ + UnknownFields: map[byte]string{ + 'p': "foo", + }, + Code: "Fail", + Position: 1, + Message: "this is an error", + } + + var got ErrorResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ErrorResponse struct doesn't match expected value") + } +} From 7c9e8407262f7bfb750aef36c3e49bbff6596d35 Mon Sep 17 00:00:00 2001 From: Yuli Khodorkovskiy Date: Thu, 27 May 2021 14:48:11 -0400 Subject: [PATCH 0678/1158] Add support for identifying authentication messages The pgprotocol overloads 'p' messages with PasswordMessage, SASLInitialResponse, SASLResponse, and GSSResponse. This patch allows contextual identification of the message by setting the authType in the frontend and then setting this value in the backend when a AuthenticationResponseMessage is received. --- authentication_cleartext_password.go | 3 ++ authentication_md5_password.go | 3 ++ authentication_ok.go | 3 ++ authentication_sasl.go | 3 ++ authentication_sasl_continue.go | 3 ++ authentication_sasl_final.go | 3 ++ backend.go | 81 +++++++++++++++++++++------- frontend.go | 27 ++++++++-- password_message.go | 3 ++ pgproto3.go | 5 ++ 10 files changed, 113 insertions(+), 21 deletions(-) diff --git a/authentication_cleartext_password.go b/authentication_cleartext_password.go index 1b87a718..241fa600 100644 --- a/authentication_cleartext_password.go +++ b/authentication_cleartext_password.go @@ -15,6 +15,9 @@ type AuthenticationCleartextPassword struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationCleartextPassword) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationCleartextPassword) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { diff --git a/authentication_md5_password.go b/authentication_md5_password.go index 95795b31..32ec0390 100644 --- a/authentication_md5_password.go +++ b/authentication_md5_password.go @@ -16,6 +16,9 @@ type AuthenticationMD5Password struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationMD5Password) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationMD5Password) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationMD5Password) Decode(src []byte) error { diff --git a/authentication_ok.go b/authentication_ok.go index ad69b907..2b476fe5 100644 --- a/authentication_ok.go +++ b/authentication_ok.go @@ -15,6 +15,9 @@ type AuthenticationOk struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationOk) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationOk) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationOk) Decode(src []byte) error { diff --git a/authentication_sasl.go b/authentication_sasl.go index d2b09750..bdcb2c36 100644 --- a/authentication_sasl.go +++ b/authentication_sasl.go @@ -17,6 +17,9 @@ type AuthenticationSASL struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationSASL) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationSASL) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationSASL) Decode(src []byte) error { diff --git a/authentication_sasl_continue.go b/authentication_sasl_continue.go index d258065f..7f4a9c23 100644 --- a/authentication_sasl_continue.go +++ b/authentication_sasl_continue.go @@ -16,6 +16,9 @@ type AuthenticationSASLContinue struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationSASLContinue) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLContinue) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationSASLContinue) Decode(src []byte) error { diff --git a/authentication_sasl_final.go b/authentication_sasl_final.go index 6a681d73..d82b9ee4 100644 --- a/authentication_sasl_final.go +++ b/authentication_sasl_final.go @@ -16,6 +16,9 @@ type AuthenticationSASLFinal struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationSASLFinal) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLFinal) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationSASLFinal) Decode(src []byte) error { diff --git a/backend.go b/backend.go index cc6f1f03..232aa11d 100644 --- a/backend.go +++ b/backend.go @@ -12,27 +12,27 @@ type Backend struct { w io.Writer // Frontend message flyweights - bind Bind - cancelRequest CancelRequest - _close Close - copyFail CopyFail - copyData CopyData - copyDone CopyDone - describe Describe - execute Execute - flush Flush - gssEncRequest GSSEncRequest - parse Parse - passwordMessage PasswordMessage - query Query - sslRequest SSLRequest - startupMessage StartupMessage - sync Sync - terminate Terminate + bind Bind + cancelRequest CancelRequest + _close Close + copyFail CopyFail + copyData CopyData + copyDone CopyDone + describe Describe + execute Execute + flush Flush + gssEncRequest GSSEncRequest + parse Parse + query Query + sslRequest SSLRequest + startupMessage StartupMessage + sync Sync + terminate Terminate bodyLen int msgType byte partialMsg bool + authType uint32 } // NewBackend creates a new Backend. @@ -127,7 +127,19 @@ func (b *Backend) Receive() (FrontendMessage, error) { case 'P': msg = &b.parse case 'p': - msg = &b.passwordMessage + switch b.authType { + case AuthTypeSASL: + msg = &SASLInitialResponse{} + case AuthTypeSASLContinue: + msg = &SASLResponse{} + case AuthTypeSASLFinal: + msg = &SASLResponse{} + case AuthTypeCleartextPassword, AuthTypeMD5Password: + fallthrough + default: + // to maintain backwards compatability + msg = &PasswordMessage{} + } case 'Q': msg = &b.query case 'S': @@ -148,3 +160,36 @@ func (b *Backend) Receive() (FrontendMessage, error) { err = msg.Decode(msgBody) return msg, err } + +// SetAuthType sets the authentication type in the backend. +// Since multiple message types can start with 'p', SetAuthType allows +// contextual identification of FrontendMessages. For example, in the +// PG message flow documentation for PasswordMessage: +// +// Byte1('p') +// +// Identifies the message as a password response. Note that this is also used for +// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from +// the context. +// +// Since the Frontend does not know about the state of a backend, it is important +// to call SetAuthType() after an authentication request is received by the Frontend. +func (b *Backend) SetAuthType(authType uint32) error { + switch authType { + case AuthTypeOk, + AuthTypeCleartextPassword, + AuthTypeMD5Password, + AuthTypeSCMCreds, + AuthTypeGSS, + AuthTypeGSSCont, + AuthTypeSSPI, + AuthTypeSASL, + AuthTypeSASLContinue, + AuthTypeSASLFinal: + b.authType = authType + default: + return fmt.Errorf("authType not recognized: %d", authType) + } + + return nil +} diff --git a/frontend.go b/frontend.go index b8f545ca..c33dfb08 100644 --- a/frontend.go +++ b/frontend.go @@ -45,6 +45,7 @@ type Frontend struct { bodyLen int msgType byte partialMsg bool + authType uint32 } // NewFrontend creates a new Frontend. @@ -146,10 +147,16 @@ func (f *Frontend) Receive() (BackendMessage, error) { } // Authentication message type constants. +// See src/include/libpq/pqcomm.h for all +// constants. const ( AuthTypeOk = 0 AuthTypeCleartextPassword = 3 AuthTypeMD5Password = 5 + AuthTypeSCMCreds = 6 + AuthTypeGSS = 7 + AuthTypeGSSCont = 8 + AuthTypeSSPI = 9 AuthTypeSASL = 10 AuthTypeSASLContinue = 11 AuthTypeSASLFinal = 12 @@ -159,15 +166,23 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er if len(src) < 4 { return nil, errors.New("authentication message too short") } - authType := binary.BigEndian.Uint32(src[:4]) + f.authType = binary.BigEndian.Uint32(src[:4]) - switch authType { + switch f.authType { case AuthTypeOk: return &f.authenticationOk, nil case AuthTypeCleartextPassword: return &f.authenticationCleartextPassword, nil case AuthTypeMD5Password: return &f.authenticationMD5Password, nil + case AuthTypeSCMCreds: + return nil, errors.New("AuthTypeSCMCreds is unimplemented") + case AuthTypeGSS: + return nil, errors.New("AuthTypeGSS is unimplemented") + case AuthTypeGSSCont: + return nil, errors.New("AuthTypeGSSCont is unimplemented") + case AuthTypeSSPI: + return nil, errors.New("AuthTypeSSPI is unimplemented") case AuthTypeSASL: return &f.authenticationSASL, nil case AuthTypeSASLContinue: @@ -175,6 +190,12 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er case AuthTypeSASLFinal: return &f.authenticationSASLFinal, nil default: - return nil, fmt.Errorf("unknown authentication type: %d", authType) + return nil, fmt.Errorf("unknown authentication type: %d", f.authType) } } + +// GetAuthType returns the authType used in the current state of the frontend. +// See SetAuthType for more information. +func (f *Frontend) GetAuthType() uint32 { + return f.authType +} diff --git a/password_message.go b/password_message.go index 4b68b31a..cae76c50 100644 --- a/password_message.go +++ b/password_message.go @@ -14,6 +14,9 @@ type PasswordMessage struct { // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*PasswordMessage) Frontend() {} +// Frontend identifies this message as an authentication response. +func (*PasswordMessage) InitialResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *PasswordMessage) Decode(src []byte) error { diff --git a/pgproto3.go b/pgproto3.go index fb0782cf..70c825e3 100644 --- a/pgproto3.go +++ b/pgproto3.go @@ -27,6 +27,11 @@ type BackendMessage interface { Backend() // no-op method to distinguish frontend from backend methods } +type AuthenticationResponseMessage interface { + BackendMessage + AuthenticationResponse() // no-op method to distinguish authentication responses +} + type invalidMessageLenErr struct { messageType string expectedLen int From 821e0521e464a4f88219ba97c1d4be77c15cd8e8 Mon Sep 17 00:00:00 2001 From: Sivabalan Thirunavukkarasu Date: Thu, 17 Jun 2021 19:43:59 +0800 Subject: [PATCH 0679/1158] Updating dependency versions --- go.mod | 4 +- go.sum | 347 +++++++++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 317 insertions(+), 34 deletions(-) diff --git a/go.mod b/go.mod index f213388a..e79435f6 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,9 @@ go 1.13 require ( github.com/gofrs/uuid v3.2.0+incompatible - github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853 + github.com/jackc/pgconn v1.8.1 github.com/jackc/pgio v1.0.0 - github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904 + github.com/jackc/pgx/v4 v4.11.0 github.com/lib/pq v1.3.0 github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc github.com/stretchr/testify v1.5.1 diff --git a/go.sum b/go.sum index 464f0091..c053fb49 100644 --- a/go.sum +++ b/go.sum @@ -1,34 +1,139 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= +github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= +github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= +github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= +github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= +github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= +github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= +github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= +github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= +github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= +github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= +github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4dUb/I5gc9Hdhagfvm9+RyrPryS/auMzxE= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= +github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= +github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= +github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= +github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= +github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= +github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= +github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4safvEdbitLhGGK48rN6g= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= +github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.10.0/go.mod h1:xUsJbQ/Fp4kEt7AFgCuvyX4a71u8h9jB8tj/ORgOZ7o= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= +github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= +github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= +github.com/hashicorp/consul/api v1.3.0/go.mod h1:MmDNSzIMUjNpY/mQ398R4bk2FnqQLoPndWW5VkKPlCE= +github.com/hashicorp/consul/sdk v0.3.0/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= +github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= +github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= +github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= +github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= +github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= +github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= +github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= -github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= -github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3 h1:ZFYpB74Kq8xE9gmfxCmXD6QxZ27ja+j3HwGFc+YurhQ= github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= -github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb h1:d6GP9szHvXVopAOAnZ7WhRnF3Xdxrylmm/9jnfmW4Ag= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= -github.com/jackc/pgconn v1.4.0 h1:E82UBzFyD752mvI+4RIl1WSxfO2ug64T+sLjvDBWTpA= github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= -github.com/jackc/pgconn v1.5.0 h1:oFSOilzIZkyg787M1fEmyMfOUUvwj0daqYMfaWwNL4o= github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= -github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853 h1:LRlrfJW9S99uiOCY8F/qLvX1yEY1TVAaCBHFb79yHBQ= github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= +github.com/jackc/pgconn v1.8.1 h1:ySBX7Q87vOMqKU2bbmKbUvtYhauDFclYbNDYIE1/h6s= +github.com/jackc/pgconn v1.8.1/go.mod h1:JV6m6b6jhjdmzchES0drzCcYcAHS1OPD5xu3OZ/lE2g= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= @@ -37,41 +142,48 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M= github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= +github.com/jackc/pgproto3/v2 v2.0.6 h1:b1105ZGEMFe7aCvrT1Cca3VoVb4ZFMaFJLJcg/3zD+8= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= github.com/jackc/pgtype v1.3.1-0.20200606141011-f6355165a91c/go.mod h1:cvk9Bgu/VzJ9/lxTO5R5sf80p0DiucVtN7ZxvaC4GmQ= -github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96 h1:ylEAOd688Duev/fxTmGdupsbyZfxNMdngIG14DoBKTM= +github.com/jackc/pgtype v1.7.0/go.mod h1:ZnHF+rMePVqDKaOfJVI4Q8IVvAQMryDlDkZnKOI75BE= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= -github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912 h1:YuOWGsSK5L4Fz81Olx5TNlZftmDuNrfv4ip0Yos77Tw= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= -github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186 h1:ZQM8qLT/E/CGD6XX0E6q9FAwxJYmWpJufzmLMaFuzgQ= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= -github.com/jackc/pgx/v4 v4.5.0 h1:mN7Z3n0uqPe29+tA4yLWyZNceYKgRvUWNk8qW+D066E= github.com/jackc/pgx/v4 v4.5.0/go.mod h1:EpAKPLdnTorwmPUUsqrPxy5fphV18j9q3wrfRXgo+kA= -github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9 h1:rche9LTjh3HEvkE6eb8ITYxRsgEKgBkODHrhdvDVX74= github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9/go.mod h1:t3/cdRQl6fOLDxqtlyhe9UWgfIi9R8+8v8GKV5TRA/o= -github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904 h1:SdGWuGg+Cpxq6Z+ArXt0nafaKeTvtKGEoW+yvycspUU= github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904/go.mod h1:ZDaNWkt9sW1JMiNn0kdYBaLelIhw7Pg4qd+Vk6tw7Hg= +github.com/jackc/pgx/v4 v4.11.0 h1:J86tSWd3Y7nKjwT/43xZBvpi04keQWx8gNC2YkdJhZI= +github.com/jackc/pgx/v4 v4.11.0/go.mod h1:i62xJgdrtVDsnL3U8ekyrQXEwGNTRoG7/8r+CIdYfcc= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -79,105 +191,276 @@ github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU= github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= +github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= +github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= +github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= +github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= +github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= +github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= +github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= +github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg= +github.com/nats-io/jwt v0.3.2/go.mod h1:/euKqTS1ZD+zzjYrY7pseZrTtWQSjujC7xjPc8wL6eU= +github.com/nats-io/nats-server/v2 v2.1.2/go.mod h1:Afk+wRZqkMQs/p45uXdrVLuab3gwv3Z8C4HTBu8GD/k= +github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzEE/Zbp4w= +github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= +github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs= +github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= +github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/op/go-logging v0.0.0-20160315200505-970db520ece7/go.mod h1:HzydrMdWErDVzsI23lYNej1Htcns9BCg93Dk0bBINWk= +github.com/opentracing-contrib/go-observer v0.0.0-20170622124052-a52f23424492/go.mod h1:Ngi6UdF0k5OKD5t5wlmGhe/EDKPoUM3BXZSSfIuJbis= +github.com/opentracing/basictracer-go v1.0.0/go.mod h1:QfBfYuafItcjQuMwinw9GhYKwFXS9KnPs5lxoYwgW74= +github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/openzipkin-contrib/zipkin-go-opentracing v0.4.5/go.mod h1:/wsWhb9smxSfWAKL3wpBW7V8scJMt8N8gnaMCS9E/cA= +github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= +github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= +github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= +github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= +github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= +github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= +github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac= +github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= +github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs= +github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.3.0/go.mod h1:hJaj2vgQTGQmVCsAACORcieXFeDPbaTKGT+JTgUa3og= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.1.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= +github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= -github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= +github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= -github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE= +github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc h1:jUIKcSPO9MoMJBbEoyE/RJoE8vz7Mb8AjvifMMwSyvY= github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= +github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= +github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= +github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= +github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= +github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= +go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= +go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= +go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= +go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= +golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 h1:0hQKqeLdqlt5iIwVOBErRisrHJAN57yOiPRQItI20fU= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 h1:3zb4D3T4G8jdExgVU/95+vQXfpEPiMdCaZgmGVxjNHM= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190530194941-fb225487d101/go.mod h1:z3L6/3dTEVtUr6QSP8miRzeRqwQOioJ9I66odjN4I7s= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM= +google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= +google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= +google.golang.org/grpc v1.22.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/gcfg.v1 v1.2.3/go.mod h1:yesOnuUOFQAhST5vPY4nbZsb/huCgGGXlipJsBn0b3o= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= +gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= +sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU= From 2c22da0155d719b3de0d6a332fb25aa5f6a2beda Mon Sep 17 00:00:00 2001 From: Sivabalan Thirunavukkarasu Date: Thu, 17 Jun 2021 20:46:51 +0800 Subject: [PATCH 0680/1158] Bumping versions for other dependencies --- go.mod | 8 ++++---- go.sum | 15 ++++++++++----- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index e79435f6..dd2449e6 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,11 @@ module github.com/jackc/pgtype go 1.13 require ( - github.com/gofrs/uuid v3.2.0+incompatible + github.com/gofrs/uuid v4.0.0+incompatible github.com/jackc/pgconn v1.8.1 github.com/jackc/pgio v1.0.0 github.com/jackc/pgx/v4 v4.11.0 - github.com/lib/pq v1.3.0 - github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc - github.com/stretchr/testify v1.5.1 + github.com/lib/pq v1.10.2 + github.com/shopspring/decimal v1.2.0 + github.com/stretchr/testify v1.7.0 ) diff --git a/go.sum b/go.sum index c053fb49..01f503c9 100644 --- a/go.sum +++ b/go.sum @@ -67,8 +67,9 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= @@ -193,8 +194,9 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU= github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= @@ -287,8 +289,9 @@ github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0 github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= -github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc h1:jUIKcSPO9MoMJBbEoyE/RJoE8vz7Mb8AjvifMMwSyvY= github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= @@ -308,8 +311,9 @@ github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoH github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= @@ -456,8 +460,9 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWD gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= From cfcd61d0cbf58dfa254d094864bcfc22c7a3e104 Mon Sep 17 00:00:00 2001 From: Sivabalan Thirunavukkarasu Date: Thu, 17 Jun 2021 20:17:10 +0800 Subject: [PATCH 0681/1158] Updating dependency versions --- go.mod | 4 ++-- go.sum | 33 ++++++++++----------------------- 2 files changed, 12 insertions(+), 25 deletions(-) diff --git a/go.mod b/go.mod index e9003cb7..233fa205 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,11 @@ go 1.12 require ( github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/pgio v1.0.0 - github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 + github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.6 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 - golang.org/x/text v0.3.3 + golang.org/x/text v0.3.6 ) diff --git a/go.sum b/go.sum index 58bb1286..14121a04 100644 --- a/go.sum +++ b/go.sum @@ -8,17 +8,18 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= -github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= -github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd h1:eDErF6V/JPJON/B7s68BxwHgfmyOntHJQ8IOaz0x4R8= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= @@ -26,22 +27,9 @@ github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M= -github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.0.2 h1:q1Hsy66zh4vuNsajBUF2PNqfAMMfxU5mk594lPE9vjY= -github.com/jackc/pgproto3/v2 v2.0.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.0.3 h1:2S4PhE00mvdvaSiCYR1ZCmR1NAxeYfTSsqqSKxE1vzo= -github.com/jackc/pgproto3/v2 v2.0.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.0.4 h1:RHkX5ZUD9bl/kn0f9dYUWs1N7Nwvo1wwUYvKiR26Zco= -github.com/jackc/pgproto3/v2 v2.0.4/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.0.5 h1:NUbEWPmCQZbMmYlTjVoNPhc0CfnYyz2bfUAh6A5ZVJM= -github.com/jackc/pgproto3/v2 v2.0.5/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.6 h1:b1105ZGEMFe7aCvrT1Cca3VoVb4ZFMaFJLJcg/3zD+8= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= -github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= @@ -81,7 +69,6 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= @@ -93,12 +80,9 @@ go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= -golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 h1:3zb4D3T4G8jdExgVU/95+vQXfpEPiMdCaZgmGVxjNHM= -golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -114,20 +98,23 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From bacf81fb4eada115882c2f4f29d0d42047902be5 Mon Sep 17 00:00:00 2001 From: Sivabalan Thirunavukkarasu Date: Thu, 17 Jun 2021 20:43:54 +0800 Subject: [PATCH 0682/1158] Bumping versions for other dependencies --- go.mod | 6 +++--- go.sum | 14 +++++++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index 233fa205..57f773b1 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,9 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.6 + github.com/jackc/pgproto3/v2 v2.1.0 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b - github.com/stretchr/testify v1.5.1 - golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 + github.com/stretchr/testify v1.7.0 + golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e golang.org/x/text v0.3.6 ) diff --git a/go.sum b/go.sum index 14121a04..eedcac1b 100644 --- a/go.sum +++ b/go.sum @@ -28,8 +28,9 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.6 h1:b1105ZGEMFe7aCvrT1Cca3VoVb4ZFMaFJLJcg/3zD+8= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.0 h1:h2yg3kjIyAGSZKDijYn1/gXHlYLCwl9ZjEh2PU0yVxE= +github.com/jackc/pgproto3/v2 v2.1.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= @@ -70,8 +71,9 @@ github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoH github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= @@ -83,8 +85,8 @@ golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaE golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= -golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -100,6 +102,7 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -119,5 +122,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From bf76d1ed51099a78209a2dc109d826cab20d286e Mon Sep 17 00:00:00 2001 From: mgoddard Date: Sat, 19 Jun 2021 07:16:00 -0400 Subject: [PATCH 0683/1158] Solve issue with 'sslmode=verify-full' when there are multiple hosts --- config.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config.go b/config.go index 16480589..172e7478 100644 --- a/config.go +++ b/config.go @@ -297,7 +297,7 @@ func ParseConfig(connString string) (*Config, error) { tlsConfigs = append(tlsConfigs, nil) } else { var err error - tlsConfigs, err = configTLS(settings) + tlsConfigs, err = configTLS(settings, host) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} } @@ -552,8 +552,8 @@ func parseServiceSettings(servicefilePath, serviceName string) (map[string]strin // configTLS uses libpq's TLS parameters to construct []*tls.Config. It is // necessary to allow returning multiple TLS configs as sslmode "allow" and // "prefer" allow fallback. -func configTLS(settings map[string]string) ([]*tls.Config, error) { - host := settings["host"] +func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, error) { + host := thisHost sslmode := settings["sslmode"] sslrootcert := settings["sslrootcert"] sslcert := settings["sslcert"] From 2ca304d4617a72491b1844eb92e9ecd86f7b84e9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jun 2021 10:49:34 -0500 Subject: [PATCH 0684/1158] pgtype.Inet preserves masked address portion fixes #111 --- inet.go | 3 ++- inet_test.go | 9 +++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/inet.go b/inet.go index 101b9ab4..43f7252a 100644 --- a/inet.go +++ b/inet.go @@ -45,10 +45,11 @@ func (dst *Inet) Set(src interface{}) error { *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Status: Present} } case string: - _, ipnet, err := net.ParseCIDR(value) + ip, ipnet, err := net.ParseCIDR(value) if err != nil { return err } + ipnet.IP = ip *dst = Inet{IPNet: ipnet, Status: Present} case *net.IPNet: if value == nil { diff --git a/inet_test.go b/inet_test.go index cb420a51..08d73e4e 100644 --- a/inet_test.go +++ b/inet_test.go @@ -7,6 +7,7 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/assert" ) func TestInetTranscode(t *testing.T) { @@ -16,6 +17,7 @@ func TestInetTranscode(t *testing.T) { &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.50/24"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, @@ -35,6 +37,7 @@ func TestInetSet(t *testing.T) { {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4"), Mask: net.CIDRMask(24, 32)}, Status: pgtype.Present}}, {source: net.ParseIP(""), result: pgtype.Inet{Status: pgtype.Null}}, } @@ -45,8 +48,10 @@ func TestInetSet(t *testing.T) { t.Errorf("%d: %v", i, err) } - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + assert.Equalf(t, tt.result.Status, r.Status, "%d: Status", i) + if tt.result.Status == pgtype.Present { + assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: IP", i) + assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: Mask", i) } } } From a123e5b4e575b5eb3c68ae4ab87c508d341242df Mon Sep 17 00:00:00 2001 From: Joshua Brindle Date: Mon, 21 Jun 2021 15:25:10 -0400 Subject: [PATCH 0685/1158] Add defaults for sslcert, sslkey, and sslrootcert per https://www.postgresql.org/docs/current/libpq-ssl.html psql will use client certs located in ~/.postgresql on posix systems or %APPDATA%\postgresql on Windows systems. --- defaults.go | 13 +++++++++++++ defaults_windows.go | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/defaults.go b/defaults.go index d3313481..f69cad31 100644 --- a/defaults.go +++ b/defaults.go @@ -22,6 +22,19 @@ func defaultSettings() map[string]string { settings["user"] = user.Username settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + if _, err := os.Stat(sslcert); err == nil { + if _, err := os.Stat(sslkey); err == nil { + // Both the cert and key must be present to use them, or do not use either + settings["sslcert"] = sslcert + settings["sslkey"] = sslkey + } + } + sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt") + if _, err := os.Stat(sslrootcert); err == nil { + settings["sslrootcert"] = sslrootcert + } } settings["target_session_attrs"] = "any" diff --git a/defaults_windows.go b/defaults_windows.go index 55243700..71eb77db 100644 --- a/defaults_windows.go +++ b/defaults_windows.go @@ -29,6 +29,19 @@ func defaultSettings() map[string]string { settings["user"] = username settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf") settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + sslcert := filepath.Join(appData, "postgresql", "postgresql.crt") + sslkey := filepath.Join(appData, "postgresql", "postgresql.key") + if _, err := os.Stat(sslcert); err == nil { + if _, err := os.Stat(sslkey); err == nil { + // Both the cert and key must be present to use them, or do not use either + settings["sslcert"] = sslcert + settings["sslkey"] = sslkey + } + } + sslrootcert := filepath.Join(appData, "postgresql", "root.crt") + if _, err := os.Stat(sslrootcert); err == nil { + settings["sslrootcert"] = sslrootcert + } } settings["target_session_attrs"] = "any" From 10c6c50ac9638764295be9bd6e897b7386ba7614 Mon Sep 17 00:00:00 2001 From: Yuli Khodorkovskiy Date: Wed, 30 Jun 2021 12:54:45 -0400 Subject: [PATCH 0686/1158] Extend handling of unexpected EOF to the backend In the original issue [1] and commit [2], support for unexpected EOF was added to the frontend to detect when a connection was closed abruptly. Additionally, this allows us to differentiate normal io.EOF errors with unexpected errors in the backend. [1] https://github.com/jackc/pgx/issues/662/ [2] https://github.com/jackc/pgproto3/commit/595780be0f9f581451a23a5151b77f782202ad72 --- backend.go | 6 +++--- backend_test.go | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/backend.go b/backend.go index 232aa11d..c9fa87ff 100644 --- a/backend.go +++ b/backend.go @@ -58,7 +58,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { buf, err = b.cr.Next(msgSize) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } code := binary.BigEndian.Uint32(buf) @@ -98,7 +98,7 @@ func (b *Backend) Receive() (FrontendMessage, error) { if !b.partialMsg { header, err := b.cr.Next(5) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } b.msgType = header[0] @@ -152,7 +152,7 @@ func (b *Backend) Receive() (FrontendMessage, error) { msgBody, err := b.cr.Next(b.bodyLen) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } b.partialMsg = false diff --git a/backend_test.go b/backend_test.go index 43a3f76c..19970c34 100644 --- a/backend_test.go +++ b/backend_test.go @@ -1,9 +1,11 @@ package pgproto3_test import ( + "io" "testing" "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/assert" ) func TestBackendReceiveInterrupted(t *testing.T) { @@ -32,3 +34,23 @@ func TestBackendReceiveInterrupted(t *testing.T) { t.Fatalf("unexpected msg: %v", msg) } } + +func TestBackendReceiveUnexpectedEOF(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Q', 0, 0, 0, 6}) + + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) + + // Receive regular msg + msg, err := backend.Receive() + assert.Nil(t, msg) + assert.Equal(t, io.ErrUnexpectedEOF, err) + + // Receive FE msg + server.push([]byte{'F', 0, 0, 0, 6}) + msg, err = backend.ReceiveStartupMessage() + assert.Nil(t, msg) + assert.Equal(t, io.ErrUnexpectedEOF, err) +} From 3eceab0f382295901243f9b43973108c36ee4d1a Mon Sep 17 00:00:00 2001 From: Cameron Daniel Date: Wed, 30 Jun 2021 14:22:26 +0200 Subject: [PATCH 0687/1158] Maintain host bits for inet types --- inet.go | 15 +++++++++++---- inet_test.go | 45 +++++++++++++++++++++++++++++---------------- pgtype_test.go | 14 ++++++++++++++ 3 files changed, 54 insertions(+), 20 deletions(-) diff --git a/inet.go b/inet.go index 43f7252a..1645334e 100644 --- a/inet.go +++ b/inet.go @@ -132,18 +132,22 @@ func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { var err error if ip := net.ParseIP(string(src)); ip != nil { - ipv4 := ip.To4() - if ipv4 != nil { + if ipv4 := ip.To4(); ipv4 != nil { ip = ipv4 } bitCount := len(ip) * 8 mask := net.CIDRMask(bitCount, bitCount) ipnet = &net.IPNet{Mask: mask, IP: ip} } else { - _, ipnet, err = net.ParseCIDR(string(src)) + ip, ipnet, err = net.ParseCIDR(string(src)) if err != nil { return err } + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + ones, _ := ipnet.Mask.Size() + *ipnet = net.IPNet{IP: ip, Mask: net.CIDRMask(ones, len(ip)*8)} } *dst = Inet{IPNet: ipnet, Status: Present} @@ -168,7 +172,10 @@ func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { var ipnet net.IPNet ipnet.IP = make(net.IP, int(addressLength)) copy(ipnet.IP, src[4:]) - ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) + if ipv4 := ipnet.IP.To4(); ipv4 != nil { + ipnet.IP = ipv4 + } + ipnet.Mask = net.CIDRMask(int(bits), len(ipnet.IP)*8) *dst = Inet{IPNet: &ipnet, Status: Present} diff --git a/inet_test.go b/inet_test.go index 08d73e4e..66fe777f 100644 --- a/inet_test.go +++ b/inet_test.go @@ -11,22 +11,35 @@ import ( ) func TestInetTranscode(t *testing.T) { - for _, pgTypeName := range []string{"inet", "cidr"} { - testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ - &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.50/24"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - &pgtype.Inet{Status: pgtype.Null}, - }) - } + testutil.TestSuccessfulTranscode(t, "inet", []interface{}{ + &pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "127.0.0.1/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "12.34.56.65/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "192.168.1.16/24"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "255.0.0.0/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "255.255.255.255/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "::1/64"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "::/0"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "::1/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), Status: pgtype.Present}, + &pgtype.Inet{Status: pgtype.Null}, + }) +} + +func TestCidrTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "cidr", []interface{}{ + &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + &pgtype.Inet{Status: pgtype.Null}, + }) } func TestInetSet(t *testing.T) { diff --git a/pgtype_test.go b/pgtype_test.go index f46ec12a..75e1909f 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -35,6 +35,20 @@ func mustParseCIDR(t testing.TB, s string) *net.IPNet { return ipnet } +func mustParseInet(t testing.TB, s string) *net.IPNet { + ip, ipnet, err := net.ParseCIDR(s) + if err != nil { + t.Fatal(err) + } + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + + ipnet.IP = ip + + return ipnet +} + func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { addr, err := net.ParseMAC(s) if err != nil { From 2d3823838e54632bd4a3c85fb174f9c6fd64b107 Mon Sep 17 00:00:00 2001 From: Yuli Khodorkovskiy Date: Thu, 23 Jan 2020 10:54:03 -0500 Subject: [PATCH 0688/1158] Perform StartupMessage length validation PG provides a maximum size for a StartupMessage: https://doxygen.postgresql.org/pqcomm_8h.html#a4c50c668c551887ac3a49872130349e3 Limiting the size ensures a malicious user doesn't send an overwhelmingly large StartupMessage which could DOS a Go binary that uses pgproto3. --- backend.go | 9 +++++++ backend_test.go | 63 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/backend.go b/backend.go index 232aa11d..6944f80d 100644 --- a/backend.go +++ b/backend.go @@ -35,6 +35,11 @@ type Backend struct { authType uint32 } +const ( + minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. + maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. +) + // NewBackend creates a new Backend. func NewBackend(cr ChunkReader, w io.Writer) *Backend { return &Backend{cr: cr, w: w} @@ -56,6 +61,10 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { } msgSize := int(binary.BigEndian.Uint32(buf) - 4) + if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen { + return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize) + } + buf, err = b.cr.Next(msgSize) if err != nil { return nil, err diff --git a/backend_test.go b/backend_test.go index 43a3f76c..3cfde003 100644 --- a/backend_test.go +++ b/backend_test.go @@ -3,7 +3,9 @@ package pgproto3_test import ( "testing" + "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/require" ) func TestBackendReceiveInterrupted(t *testing.T) { @@ -32,3 +34,64 @@ func TestBackendReceiveInterrupted(t *testing.T) { t.Fatalf("unexpected msg: %v", msg) } } + +func TestStartupMessage(t *testing.T) { + t.Parallel() + + t.Run("valid StartupMessage", func(t *testing.T) { + want := &pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersionNumber, + Parameters: map[string]string{ + "username": "tester", + }, + } + dst := []byte{} + dst = want.Encode(dst) + + server := &interruptReader{} + server.push(dst) + + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) + + msg, err := backend.ReceiveStartupMessage() + require.NoError(t, err) + require.Equal(t, want, msg) + }) + + t.Run("invalid packet length", func(t *testing.T) { + wantErr := "invalid length of startup packet" + tests := []struct { + name string + packetLen uint32 + }{ + { + name: "large packet length", + // Since the StartupMessage contains the "Length of message contents + // in bytes, including self", the max startup packet length is actually + // 10000+4. Therefore, let's go past the limit with 10005 + packetLen: 10005, + }, + { + name: "short packet length", + packetLen: 3, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := &interruptReader{} + dst := []byte{} + dst = pgio.AppendUint32(dst, tt.packetLen) + dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber) + server.push(dst) + + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) + + msg, err := backend.ReceiveStartupMessage() + require.Error(t, err) + require.Nil(t, msg) + require.Contains(t, err.Error(), wantErr) + }) + } + }) + +} From 033ca7d47f43284a1aa3754259a6ec0561b02840 Mon Sep 17 00:00:00 2001 From: Yuli Khodorkovskiy Date: Tue, 6 Jul 2021 21:35:21 -0400 Subject: [PATCH 0689/1158] Fix unexpected EOF failure for StartupMessage --- backend_test.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/backend_test.go b/backend_test.go index 5e9a2ac5..708f1280 100644 --- a/backend_test.go +++ b/backend_test.go @@ -50,8 +50,12 @@ func TestBackendReceiveUnexpectedEOF(t *testing.T) { assert.Nil(t, msg) assert.Equal(t, io.ErrUnexpectedEOF, err) - // Receive FE msg - server.push([]byte{'F', 0, 0, 0, 6}) + // Receive StartupMessage msg + dst := []byte{} + dst = pgio.AppendUint32(dst, 1000) // tell the backend we expect 1000 bytes to be read + dst = pgio.AppendUint32(dst, 1) // only send 1 byte + server.push(dst) + msg, err = backend.ReceiveStartupMessage() assert.Nil(t, msg) assert.Equal(t, io.ErrUnexpectedEOF, err) From c0b4d3bc05e51a6df4c011de058f0e4daf7e154f Mon Sep 17 00:00:00 2001 From: Michael Darr Date: Tue, 29 Jun 2021 14:24:09 -0400 Subject: [PATCH 0690/1158] Implement timeout error Signed-off-by: Michael Darr --- errors.go | 47 ++++++++++++++++++++++++++++++++++++++++------- pgconn.go | 37 +++++++++++++++++++++++++------------ 2 files changed, 65 insertions(+), 19 deletions(-) diff --git a/errors.go b/errors.go index 77adfcf0..5df851d5 100644 --- a/errors.go +++ b/errors.go @@ -18,15 +18,11 @@ func SafeToRetry(err error) bool { return false } -// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a +// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a // context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. func Timeout(err error) bool { - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return true - } - - var netErr net.Error - return errors.As(err, &netErr) && netErr.Timeout() + var timeoutErr *ErrTimeout + return errors.As(err, &timeoutErr) } // PgError represents an error reported by the PostgreSQL server. See @@ -134,6 +130,32 @@ func (e *pgconnError) Unwrap() error { return e.err } +// ErrTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is +// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true. +type ErrTimeout struct { + err error +} + +func (e *ErrTimeout) Error() string { + return fmt.Sprintf("timeout: %s", e.err.Error()) +} + +func (e *ErrTimeout) SafeToRetry() bool { + var ctxErr *contextAlreadyDoneError + if errors.As(e, &ctxErr) { + return ctxErr.SafeToRetry() + } + var netErr net.Error + if errors.As(e, &netErr) { + return netErr.Temporary() + } + return false +} + +func (e *ErrTimeout) Unwrap() error { + return e.err +} + type contextAlreadyDoneError struct { err error } @@ -150,6 +172,17 @@ func (e *contextAlreadyDoneError) Unwrap() error { return e.err } +// newContextAlreadyDoneError wraps a context error in `contextAlreadyDoneError`. If the context was cancelled or its +// deadline passed, the returned error is also wrapped by `ErrTimeout`. +func newContextAlreadyDoneError(ctx context.Context) (err error) { + ctxErr := ctx.Err() + err = &contextAlreadyDoneError{err: ctxErr} + if ctxErr != nil { + err = &ErrTimeout{err: err} + } + return err +} + type writeError struct { err error safeToRetry bool diff --git a/pgconn.go b/pgconn.go index 197aad4a..74e24257 100644 --- a/pgconn.go +++ b/pgconn.go @@ -217,6 +217,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + err = &ErrTimeout{err: err} + } return nil, &connectError{config: config, msg: "dial error", err: err} } @@ -389,7 +393,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { if ctx != context.Background() { select { case <-ctx.Done(): - return &contextAlreadyDoneError{err: ctx.Err()} + return newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -421,7 +425,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa if ctx != context.Background() { select { case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} + return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -451,7 +455,8 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { pgConn.bufferingReceive = false // If a timeout error happened in the background try the read again. - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { msg, err = pgConn.frontend.Receive() } } else { @@ -460,8 +465,12 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { if err != nil { // Close on anything other than timeout error - everything else is fatal - if err, ok := err.(net.Error); !(ok && err.Timeout()) { + var netErr net.Error + isNetErr := errors.As(err, &netErr) + if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() + } else if isNetErr && netErr.Timeout() { + err = &ErrTimeout{err: err} } return nil, err @@ -476,8 +485,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { msg, err := pgConn.peekMessage() if err != nil { // Close on anything other than timeout error - everything else is fatal - if err, ok := err.(net.Error); !(ok && err.Timeout()) { + var netErr net.Error + isNetErr := errors.As(err, &netErr) + if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() + } else if isNetErr && netErr.Timeout() { + err = &ErrTimeout{err: err} } return nil, err @@ -745,7 +758,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ if ctx != context.Background() { select { case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} + return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -918,7 +931,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + multiResult.err = newContextAlreadyDoneError(ctx) pgConn.unlock() return multiResult default: @@ -964,7 +977,7 @@ func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + multiResult.err = newContextAlreadyDoneError(ctx) pgConn.unlock() return multiResult default: @@ -1058,7 +1071,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by if ctx != context.Background() { select { case <-ctx.Done(): - result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) + result.concludeCommand(nil, newContextAlreadyDoneError(ctx)) result.closed = true pgConn.unlock() return result @@ -1098,7 +1111,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm select { case <-ctx.Done(): pgConn.unlock() - return nil, &contextAlreadyDoneError{err: ctx.Err()} + return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -1158,7 +1171,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co if ctx != context.Background() { select { case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} + return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -1601,7 +1614,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + multiResult.err = newContextAlreadyDoneError(ctx) pgConn.unlock() return multiResult default: From b3e64d3cdb6e805e32adce9c4a148c2ebf6e9cee Mon Sep 17 00:00:00 2001 From: Michael Darr Date: Tue, 6 Jul 2021 15:36:46 -0400 Subject: [PATCH 0691/1158] Simplify SafeToRetry for ErrTimeout Signed-off-by: Michael Darr --- errors.go | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/errors.go b/errors.go index 5df851d5..0bb322cd 100644 --- a/errors.go +++ b/errors.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "net/url" "regexp" "strings" @@ -141,15 +140,7 @@ func (e *ErrTimeout) Error() string { } func (e *ErrTimeout) SafeToRetry() bool { - var ctxErr *contextAlreadyDoneError - if errors.As(e, &ctxErr) { - return ctxErr.SafeToRetry() - } - var netErr net.Error - if errors.As(e, &netErr) { - return netErr.Temporary() - } - return false + return SafeToRetry(e.err) } func (e *ErrTimeout) Unwrap() error { From 9a9830c00d579aaa709b095acd2ab96162e3a564 Mon Sep 17 00:00:00 2001 From: Michael Darr Date: Tue, 6 Jul 2021 15:43:26 -0400 Subject: [PATCH 0692/1158] Always double-wrap contextAlreadyDoneError Signed-off-by: Michael Darr --- errors.go | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/errors.go b/errors.go index 0bb322cd..ab83b3a5 100644 --- a/errors.go +++ b/errors.go @@ -163,15 +163,9 @@ func (e *contextAlreadyDoneError) Unwrap() error { return e.err } -// newContextAlreadyDoneError wraps a context error in `contextAlreadyDoneError`. If the context was cancelled or its -// deadline passed, the returned error is also wrapped by `ErrTimeout`. +// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `ErrTimeout`. func newContextAlreadyDoneError(ctx context.Context) (err error) { - ctxErr := ctx.Err() - err = &contextAlreadyDoneError{err: ctxErr} - if ctxErr != nil { - err = &ErrTimeout{err: err} - } - return err + return &ErrTimeout{&contextAlreadyDoneError{err: ctx.Err()}} } type writeError struct { From a50d96d4915cae7d1a28601ce9e7a57b0ea5ae41 Mon Sep 17 00:00:00 2001 From: Michael Darr Date: Tue, 6 Jul 2021 21:44:44 -0400 Subject: [PATCH 0693/1158] Make timeout error private Signed-off-by: Michael Darr --- errors.go | 16 ++++++++-------- pgconn.go | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/errors.go b/errors.go index ab83b3a5..64401d65 100644 --- a/errors.go +++ b/errors.go @@ -20,7 +20,7 @@ func SafeToRetry(err error) bool { // Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a // context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. func Timeout(err error) bool { - var timeoutErr *ErrTimeout + var timeoutErr *errTimeout return errors.As(err, &timeoutErr) } @@ -129,21 +129,21 @@ func (e *pgconnError) Unwrap() error { return e.err } -// ErrTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is +// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is // context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true. -type ErrTimeout struct { +type errTimeout struct { err error } -func (e *ErrTimeout) Error() string { +func (e *errTimeout) Error() string { return fmt.Sprintf("timeout: %s", e.err.Error()) } -func (e *ErrTimeout) SafeToRetry() bool { +func (e *errTimeout) SafeToRetry() bool { return SafeToRetry(e.err) } -func (e *ErrTimeout) Unwrap() error { +func (e *errTimeout) Unwrap() error { return e.err } @@ -163,9 +163,9 @@ func (e *contextAlreadyDoneError) Unwrap() error { return e.err } -// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `ErrTimeout`. +// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`. func newContextAlreadyDoneError(ctx context.Context) (err error) { - return &ErrTimeout{&contextAlreadyDoneError{err: ctx.Err()}} + return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}} } type writeError struct { diff --git a/pgconn.go b/pgconn.go index 74e24257..a17a108d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -219,7 +219,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err != nil { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { - err = &ErrTimeout{err: err} + err = &errTimeout{err: err} } return nil, &connectError{config: config, msg: "dial error", err: err} } @@ -470,7 +470,7 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() } else if isNetErr && netErr.Timeout() { - err = &ErrTimeout{err: err} + err = &errTimeout{err: err} } return nil, err @@ -490,7 +490,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() } else if isNetErr && netErr.Timeout() { - err = &ErrTimeout{err: err} + err = &errTimeout{err: err} } return nil, err From 5b7c6a3c8e9f0191a7383abb3440f3656a1efdc4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 10 Jul 2021 09:54:24 -0500 Subject: [PATCH 0694/1158] Upgrade to pgproto3 v2.1.1 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 57f773b1..dad81ebe 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.1.0 + github.com/jackc/pgproto3/v2 v2.1.1 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.7.0 golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e diff --git a/go.sum b/go.sum index eedcac1b..54405c28 100644 --- a/go.sum +++ b/go.sum @@ -31,6 +31,8 @@ github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1: github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.1.0 h1:h2yg3kjIyAGSZKDijYn1/gXHlYLCwl9ZjEh2PU0yVxE= github.com/jackc/pgproto3/v2 v2.1.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= From 13d454882b790b8a8fa00e049e9dc2c0e84318fc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 10 Jul 2021 09:54:39 -0500 Subject: [PATCH 0695/1158] Release v1.9.0 --- CHANGELOG.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c377b3ed..c496ea30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,12 @@ +# 1.9.0 (July 10, 2021) + +* pgconn.Timeout only is true for errors originating in pgconn (Michael Darr) +* Add defaults for sslcert, sslkey, and sslrootcert (Joshua Brindle) +* Solve issue with 'sslmode=verify-full' when there are multiple hosts (mgoddard) +* Fix default host when parsing URL without host but with port +* Allow dbname query parameter in URL conn string +* Update underlying dependencies + # 1.8.1 (March 25, 2021) * Better connection string sanitization (ip.novikov) From dcdc3eaec79d2767b0f67b7875667378a91c4061 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 10 Jul 2021 09:58:12 -0500 Subject: [PATCH 0696/1158] Release v1.8.0 --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d89f6ddc..0c8514e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 1.8.0 (July 10, 2021) + +* Maintain host bits for inet types (Cameron Daniel) +* Support pointers of wrapping structs (Ivan Daunis) +* Register JSONBArray at NewConnInfo() (Rueian) +* CompositeTextScanner handles backslash escapes + # 1.7.0 (March 25, 2021) * Fix scanning int into **sql.Scanner implementor From 6996e8d6c546d45bab6f1e8b24c010f40f095e6e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Jul 2021 09:09:22 -0500 Subject: [PATCH 0697/1158] Context errors returned instead of net.Error The net.Error caused by using SetDeadline to implement context cancellation shouldn't leak. fixes #80 --- errors.go | 10 ++++++++++ pgconn.go | 28 ++++++++++++++-------------- pgconn_test.go | 5 ++++- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/errors.go b/errors.go index 64401d65..a32b29c9 100644 --- a/errors.go +++ b/errors.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "net/url" "regexp" "strings" @@ -105,6 +106,15 @@ func (e *parseConfigError) Unwrap() error { return e.err } +// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == +// true. Otherwise returns err. +func preferContextOverNetTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { + return &errTimeout{err: ctx.Err()} + } + return err +} + type pgconnError struct { msg string err error diff --git a/pgconn.go b/pgconn.go index a17a108d..43b13e43 100644 --- a/pgconn.go +++ b/pgconn.go @@ -271,7 +271,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err, ok := err.(*PgError); ok { return nil, err } - return nil, &connectError{config: config, msg: "failed to receive message", err: err} + return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)} } switch msg := msg.(type) { @@ -434,7 +434,10 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa msg, err := pgConn.receiveMessage() if err != nil { - err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true} + err = &pgconnError{ + msg: "receive message failed", + err: preferContextOverNetTimeoutError(ctx, err), + safeToRetry: true} } return msg, err } @@ -469,8 +472,6 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { isNetErr := errors.As(err, &netErr) if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() - } else if isNetErr && netErr.Timeout() { - err = &errTimeout{err: err} } return nil, err @@ -489,8 +490,6 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { isNetErr := errors.As(err, &netErr) if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() - } else if isNetErr && netErr.Timeout() { - err = &errTimeout{err: err} } return nil, err @@ -785,7 +784,7 @@ readloop: msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -888,7 +887,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { if ctx != context.Background() { select { case <-ctx.Done(): - return ctx.Err() + return newContextAlreadyDoneError(ctx) default: } @@ -899,7 +898,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { for { msg, err := pgConn.receiveMessage() if err != nil { - return err + return preferContextOverNetTimeoutError(ctx, err) } switch msg.(type) { @@ -1136,7 +1135,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1196,7 +1195,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1255,7 +1254,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1287,7 +1286,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1329,7 +1328,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) if err != nil { mrr.pgConn.contextWatcher.Unwatch() - mrr.err = err + mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.closed = true mrr.pgConn.asyncClose() return nil, mrr.err @@ -1536,6 +1535,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } if err != nil { + err = preferContextOverNetTimeoutError(rr.ctx, err) rr.concludeCommand(nil, err) rr.pgConn.contextWatcher.Unwatch() rr.closed = true diff --git a/pgconn_test.go b/pgconn_test.go index 7ceda791..c20b7425 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -585,6 +585,7 @@ func TestConnExecContextCanceled(t *testing.T) { } err = multiResult.Close() assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) assert.True(t, pgConn.IsClosed()) select { case <-pgConn.CleanupDone(): @@ -729,6 +730,7 @@ func TestConnExecParamsCanceled(t *testing.T) { commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) assert.True(t, pgConn.IsClosed()) select { @@ -1289,7 +1291,7 @@ func TestConnWaitForNotificationPrecanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() err = pgConn.WaitForNotification(ctx) - require.Equal(t, context.Canceled, err) + require.ErrorIs(t, err, context.Canceled) ensureConnValid(t, pgConn) } @@ -1308,6 +1310,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { err = pgConn.WaitForNotification(ctx) cancel() assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) ensureConnValid(t, pgConn) } From 7d0a620dda033ba9fc9c496bf41539c4a1a6479f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Jul 2021 09:20:54 -0500 Subject: [PATCH 0698/1158] Upgrade pgx version used for tests --- go.mod | 4 ++-- go.sum | 27 +++++++++++++++++++++------ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index dd2449e6..29e6f628 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,9 @@ go 1.13 require ( github.com/gofrs/uuid v4.0.0+incompatible - github.com/jackc/pgconn v1.8.1 + github.com/jackc/pgconn v1.9.0 github.com/jackc/pgio v1.0.0 - github.com/jackc/pgx/v4 v4.11.0 + github.com/jackc/pgx/v4 v4.12.0 github.com/lib/pq v1.10.2 github.com/shopspring/decimal v1.2.0 github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index 01f503c9..e49ce26f 100644 --- a/go.sum +++ b/go.sum @@ -133,12 +133,15 @@ github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsU github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= -github.com/jackc/pgconn v1.8.1 h1:ySBX7Q87vOMqKU2bbmKbUvtYhauDFclYbNDYIE1/h6s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgconn v1.8.1/go.mod h1:JV6m6b6jhjdmzchES0drzCcYcAHS1OPD5xu3OZ/lE2g= +github.com/jackc/pgconn v1.9.0 h1:gqibKSTJup/ahCsNKyMZAniPuZEfIqfXFc8FOWVYR+Q= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= -github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd h1:eDErF6V/JPJON/B7s68BxwHgfmyOntHJQ8IOaz0x4R8= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= @@ -148,8 +151,9 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.0.6 h1:b1105ZGEMFe7aCvrT1Cca3VoVb4ZFMaFJLJcg/3zD+8= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= @@ -160,14 +164,16 @@ github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4 github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= github.com/jackc/pgtype v1.3.1-0.20200606141011-f6355165a91c/go.mod h1:cvk9Bgu/VzJ9/lxTO5R5sf80p0DiucVtN7ZxvaC4GmQ= github.com/jackc/pgtype v1.7.0/go.mod h1:ZnHF+rMePVqDKaOfJVI4Q8IVvAQMryDlDkZnKOI75BE= +github.com/jackc/pgtype v1.8.0/go.mod h1:PqDKcEBtllAtk/2p6z6SHdXW5UB+MhE75tUol2OKexE= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= github.com/jackc/pgx/v4 v4.5.0/go.mod h1:EpAKPLdnTorwmPUUsqrPxy5fphV18j9q3wrfRXgo+kA= github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9/go.mod h1:t3/cdRQl6fOLDxqtlyhe9UWgfIi9R8+8v8GKV5TRA/o= github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904/go.mod h1:ZDaNWkt9sW1JMiNn0kdYBaLelIhw7Pg4qd+Vk6tw7Hg= -github.com/jackc/pgx/v4 v4.11.0 h1:J86tSWd3Y7nKjwT/43xZBvpi04keQWx8gNC2YkdJhZI= github.com/jackc/pgx/v4 v4.11.0/go.mod h1:i62xJgdrtVDsnL3U8ekyrQXEwGNTRoG7/8r+CIdYfcc= +github.com/jackc/pgx/v4 v4.12.0 h1:xiP3TdnkwyslWNp77yE5XAPfxAsU9RMFDe0c1SwN8h4= +github.com/jackc/pgx/v4 v4.12.0/go.mod h1:fE547h6VulLPA3kySjfnSG/e2D861g/50JlVUa/ub60= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= @@ -345,8 +351,11 @@ golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= @@ -397,15 +406,20 @@ golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -428,6 +442,7 @@ golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= From 32e20a603178b49fb189d1be971d0fb6960cabb2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Jul 2021 10:16:00 -0500 Subject: [PATCH 0699/1158] Temporarily delete tests and pgxtype to break recursive dependency with pgx --- aclitem_array_test.go | 329 ----------------- aclitem_test.go | 97 ----- array_test.go | 127 ------- array_type_test.go | 84 ----- bit_test.go | 25 -- bool_array_test.go | 283 --------------- bool_test.go | 140 -------- box_test.go | 34 -- bpchar_array_test.go | 55 --- bpchar_test.go | 51 --- bytea_array_test.go | 229 ------------ bytea_test.go | 73 ---- cid_test.go | 105 ------ cidr_array_test.go | 319 ---------------- circle_test.go | 16 - composite_bench_test.go | 192 ---------- composite_fields_test.go | 273 -------------- composite_type_test.go | 320 ----------------- custom_composite_test.go | 87 ----- date_array_test.go | 327 ----------------- date_test.go | 168 --------- daterange_test.go | 133 ------- enum_array_test.go | 281 --------------- enum_type_test.go | 148 -------- ext/gofrs-uuid/uuid_test.go | 101 ------ ext/shopspring-numeric/decimal_test.go | 330 ----------------- float4_array_test.go | 282 --------------- float4_test.go | 149 -------- float8_array_test.go | 258 ------------- float8_test.go | 149 -------- go.mod | 4 - go.sum | 480 ------------------------- hstore_array_test.go | 436 ---------------------- hstore_test.go | 111 ------ inet_array_test.go | 319 ---------------- inet_test.go | 134 ------- int2_array_test.go | 342 ------------------ int2_test.go | 144 -------- int4_array_test.go | 356 ------------------ int4_test.go | 186 ---------- int4range_test.go | 28 -- int8_array_test.go | 349 ------------------ int8_test.go | 187 ---------- int8range_test.go | 28 -- interval_test.go | 74 ---- json_test.go | 177 --------- jsonb_array_test.go | 36 -- jsonb_test.go | 142 -------- line_test.go | 38 -- lseg_test.go | 22 -- macaddr_array_test.go | 262 -------------- macaddr_test.go | 78 ---- name_test.go | 98 ----- numeric_array_test.go | 305 ---------------- numeric_test.go | 389 -------------------- numrange_test.go | 46 --- oid_value_test.go | 95 ----- path_test.go | 29 -- pgtype_test.go | 292 --------------- pgxtype/README.md | 3 - pgxtype/pgxtype.go | 145 -------- point_test.go | 150 -------- polygon_test.go | 89 ----- qchar_test.go | 143 -------- range_test.go | 177 --------- record_test.go | 186 ---------- testutil/testutil.go | 436 ---------------------- text_array_test.go | 294 --------------- text_test.go | 164 --------- tid_test.go | 63 ---- time_test.go | 131 ------- timestamp_array_test.go | 307 ---------------- timestamp_test.go | 178 --------- timestamptz_array_test.go | 343 ------------------ timestamptz_test.go | 224 ------------ tsrange_test.go | 41 --- tstzrange_test.go | 49 --- uuid_array_test.go | 368 ------------------- uuid_test.go | 245 ------------- varbit_test.go | 26 -- varchar_array_test.go | 282 --------------- xid_test.go | 105 ------ zeronull/int2_test.go | 23 -- zeronull/int4_test.go | 23 -- zeronull/int8_test.go | 23 -- zeronull/text_test.go | 23 -- zeronull/timestamp_test.go | 29 -- zeronull/timestamptz_test.go | 29 -- zeronull/uuid_test.go | 23 -- 89 files changed, 14674 deletions(-) delete mode 100644 aclitem_array_test.go delete mode 100644 aclitem_test.go delete mode 100644 array_test.go delete mode 100644 array_type_test.go delete mode 100644 bit_test.go delete mode 100644 bool_array_test.go delete mode 100644 bool_test.go delete mode 100644 box_test.go delete mode 100644 bpchar_array_test.go delete mode 100644 bpchar_test.go delete mode 100644 bytea_array_test.go delete mode 100644 bytea_test.go delete mode 100644 cid_test.go delete mode 100644 cidr_array_test.go delete mode 100644 circle_test.go delete mode 100644 composite_bench_test.go delete mode 100644 composite_fields_test.go delete mode 100644 composite_type_test.go delete mode 100644 custom_composite_test.go delete mode 100644 date_array_test.go delete mode 100644 date_test.go delete mode 100644 daterange_test.go delete mode 100644 enum_array_test.go delete mode 100644 enum_type_test.go delete mode 100644 ext/gofrs-uuid/uuid_test.go delete mode 100644 ext/shopspring-numeric/decimal_test.go delete mode 100644 float4_array_test.go delete mode 100644 float4_test.go delete mode 100644 float8_array_test.go delete mode 100644 float8_test.go delete mode 100644 hstore_array_test.go delete mode 100644 hstore_test.go delete mode 100644 inet_array_test.go delete mode 100644 inet_test.go delete mode 100644 int2_array_test.go delete mode 100644 int2_test.go delete mode 100644 int4_array_test.go delete mode 100644 int4_test.go delete mode 100644 int4range_test.go delete mode 100644 int8_array_test.go delete mode 100644 int8_test.go delete mode 100644 int8range_test.go delete mode 100644 interval_test.go delete mode 100644 json_test.go delete mode 100644 jsonb_array_test.go delete mode 100644 jsonb_test.go delete mode 100644 line_test.go delete mode 100644 lseg_test.go delete mode 100644 macaddr_array_test.go delete mode 100644 macaddr_test.go delete mode 100644 name_test.go delete mode 100644 numeric_array_test.go delete mode 100644 numeric_test.go delete mode 100644 numrange_test.go delete mode 100644 oid_value_test.go delete mode 100644 path_test.go delete mode 100644 pgtype_test.go delete mode 100644 pgxtype/README.md delete mode 100644 pgxtype/pgxtype.go delete mode 100644 point_test.go delete mode 100644 polygon_test.go delete mode 100644 qchar_test.go delete mode 100644 range_test.go delete mode 100644 record_test.go delete mode 100644 testutil/testutil.go delete mode 100644 text_array_test.go delete mode 100644 text_test.go delete mode 100644 tid_test.go delete mode 100644 time_test.go delete mode 100644 timestamp_array_test.go delete mode 100644 timestamp_test.go delete mode 100644 timestamptz_array_test.go delete mode 100644 timestamptz_test.go delete mode 100644 tsrange_test.go delete mode 100644 tstzrange_test.go delete mode 100644 uuid_array_test.go delete mode 100644 uuid_test.go delete mode 100644 varbit_test.go delete mode 100644 varchar_array_test.go delete mode 100644 xid_test.go delete mode 100644 zeronull/int2_test.go delete mode 100644 zeronull/int4_test.go delete mode 100644 zeronull/int8_test.go delete mode 100644 zeronull/text_test.go delete mode 100644 zeronull/timestamp_test.go delete mode 100644 zeronull/timestamptz_test.go delete mode 100644 zeronull/uuid_test.go diff --git a/aclitem_array_test.go b/aclitem_array_test.go deleted file mode 100644 index 8f015f40..00000000 --- a/aclitem_array_test.go +++ /dev/null @@ -1,329 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestACLItemArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "aclitem[]", []interface{}{ - &pgtype.ACLItemArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.ACLItemArray{Status: pgtype.Null}, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - //{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - {String: `postgres=arwdDxt/postgres`, Status: pgtype.Present}, // todo: remove after fixing above case - {String: "=r/postgres", Status: pgtype.Present}, - {Status: pgtype.Null}, - {String: "=r/postgres", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestACLItemArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.ACLItemArray - }{ - { - source: []string{"=r/postgres"}, - result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]string)(nil)), - result: pgtype.ACLItemArray{Status: pgtype.Null}, - }, - { - source: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, - result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]string{ - {{{ - "=r/postgres", - "postgres=arwdDxt/postgres", - "=r/postgres"}}}, - {{{ - "postgres=arwdDxt/postgres", - "=r/postgres", - "postgres=arwdDxt/postgres"}}}}, - result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, - result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]string{ - {{{ - "=r/postgres", - "postgres=arwdDxt/postgres", - "=r/postgres"}}}, - {{{ - "postgres=arwdDxt/postgres", - "=r/postgres", - "postgres=arwdDxt/postgres"}}}}, - result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.ACLItemArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestACLItemArrayAssignTo(t *testing.T) { - var stringSlice []string - type _stringSlice []string - var namedStringSlice _stringSlice - var stringSliceDim2 [][]string - var stringSliceDim4 [][][][]string - var stringArrayDim2 [2][1]string - var stringArrayDim4 [2][1][1][3]string - - simpleTests := []struct { - src pgtype.ACLItemArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - expected: []string{"=r/postgres"}, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedStringSlice, - expected: _stringSlice{"=r/postgres"}, - }, - { - src: pgtype.ACLItemArray{Status: pgtype.Null}, - dst: &stringSlice, - expected: (([]string)(nil)), - }, - { - src: pgtype.ACLItemArray{Status: pgtype.Present}, - dst: &stringSlice, - expected: []string{}, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &stringSliceDim2, - expected: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &stringSliceDim4, - expected: [][][][]string{ - {{{ - "=r/postgres", - "postgres=arwdDxt/postgres", - "=r/postgres"}}}, - {{{ - "postgres=arwdDxt/postgres", - "=r/postgres", - "postgres=arwdDxt/postgres"}}}}, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &stringArrayDim2, - expected: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &stringArrayDim4, - expected: [2][1][1][3]string{ - {{{ - "=r/postgres", - "postgres=arwdDxt/postgres", - "=r/postgres"}}}, - {{{ - "postgres=arwdDxt/postgres", - "=r/postgres", - "postgres=arwdDxt/postgres"}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.ACLItemArray - dst interface{} - }{ - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &stringArrayDim2, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &stringSlice, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &stringArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/aclitem_test.go b/aclitem_test.go deleted file mode 100644 index a37d7657..00000000 --- a/aclitem_test.go +++ /dev/null @@ -1,97 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestACLItemTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ - &pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - //&pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - &pgtype.ACLItem{Status: pgtype.Null}, - }) -} - -func TestACLItemSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.ACLItem - }{ - {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - {source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}}, - } - - for i, tt := range successfulTests { - var d pgtype.ACLItem - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if d != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } -} - -func TestACLItemAssignTo(t *testing.T) { - var s string - var ps *string - - simpleTests := []struct { - src pgtype.ACLItem - dst interface{} - expected interface{} - }{ - {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, - {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.ACLItem - dst interface{} - expected interface{} - }{ - {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.ACLItem - dst interface{} - }{ - {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/array_test.go b/array_test.go deleted file mode 100644 index d2120677..00000000 --- a/array_test.go +++ /dev/null @@ -1,127 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/stretchr/testify/require" -) - -func TestParseUntypedTextArray(t *testing.T) { - tests := []struct { - source string - result pgtype.UntypedTextArray - }{ - { - source: "{}", - result: pgtype.UntypedTextArray{ - Elements: nil, - Quoted: nil, - Dimensions: nil, - }, - }, - { - source: "{1}", - result: pgtype.UntypedTextArray{ - Elements: []string{"1"}, - Quoted: []bool{false}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, - }, - }, - { - source: "{a,b}", - result: pgtype.UntypedTextArray{ - Elements: []string{"a", "b"}, - Quoted: []bool{false, false}, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - }, - }, - { - source: `{"NULL"}`, - result: pgtype.UntypedTextArray{ - Elements: []string{"NULL"}, - Quoted: []bool{true}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, - }, - }, - { - source: `{""}`, - result: pgtype.UntypedTextArray{ - Elements: []string{""}, - Quoted: []bool{true}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, - }, - }, - { - source: `{"He said, \"Hello.\""}`, - result: pgtype.UntypedTextArray{ - Elements: []string{`He said, "Hello."`}, - Quoted: []bool{true}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, - }, - }, - { - source: "{{a,b},{c,d},{e,f}}", - result: pgtype.UntypedTextArray{ - Elements: []string{"a", "b", "c", "d", "e", "f"}, - Quoted: []bool{false, false, false, false, false, false}, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - }, - }, - { - source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}", - result: pgtype.UntypedTextArray{ - Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"}, - Quoted: []bool{false, false, false, false, false, false, false, false, false, false, false, false}, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 1}, - {Length: 3, LowerBound: 1}, - {Length: 2, LowerBound: 1}, - }, - }, - }, - { - source: "[4:4]={1}", - result: pgtype.UntypedTextArray{ - Elements: []string{"1"}, - Quoted: []bool{false}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 4}}, - }, - }, - { - source: "[4:5][2:3]={{a,b},{c,d}}", - result: pgtype.UntypedTextArray{ - Elements: []string{"a", "b", "c", "d"}, - Quoted: []bool{false, false, false, false}, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - }, - }, - } - - for i, tt := range tests { - r, err := pgtype.ParseUntypedTextArray(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - continue - } - - if !reflect.DeepEqual(*r, tt.result) { - t.Errorf("%d: expected %+v to be parsed to %+v, but it was %+v", i, tt.source, tt.result, *r) - } - } -} - -// https://github.com/jackc/pgx/issues/881 -func TestArrayAssignToEmptyToNonSlice(t *testing.T) { - var a pgtype.Int4Array - err := a.Set([]int32{}) - require.NoError(t, err) - - var iface interface{} - err = a.AssignTo(&iface) - require.EqualError(t, err, "cannot assign *pgtype.Int4Array to *interface {}") -} diff --git a/array_type_test.go b/array_type_test.go deleted file mode 100644 index 626df4dc..00000000 --- a/array_type_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package pgtype_test - -import ( - "context" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" - "github.com/stretchr/testify/require" -) - -func TestArrayTypeValue(t *testing.T) { - arrayType := pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }) - - err := arrayType.Set(nil) - require.NoError(t, err) - - gotValue := arrayType.Get() - require.Nil(t, gotValue) - - slice := []string{"foo", "bar"} - err = arrayType.AssignTo(&slice) - require.NoError(t, err) - require.Nil(t, slice) - - err = arrayType.Set([]string{}) - require.NoError(t, err) - - gotValue = arrayType.Get() - require.Len(t, gotValue, 0) - - err = arrayType.AssignTo(&slice) - require.NoError(t, err) - require.EqualValues(t, []string{}, slice) - - err = arrayType.Set([]string{"baz", "quz"}) - require.NoError(t, err) - - gotValue = arrayType.Get() - require.Len(t, gotValue, 2) - - err = arrayType.AssignTo(&slice) - require.NoError(t, err) - require.EqualValues(t, []string{"baz", "quz"}, slice) -} - -func TestArrayTypeTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), - Name: "_text", - OID: pgtype.TextArrayOID, - }) - - var dstStrings []string - err := conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) - require.NoError(t, err) - - require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) -} - -func TestArrayTypeEmptyArrayDoesNotBreakArrayType(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), - Name: "_text", - OID: pgtype.TextArrayOID, - }) - - var dstStrings []string - err := conn.QueryRow(context.Background(), "select '{}'::text[]").Scan(&dstStrings) - require.NoError(t, err) - - require.EqualValues(t, []string{}, dstStrings) - - err = conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) - require.NoError(t, err) - - require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) -} diff --git a/bit_test.go b/bit_test.go deleted file mode 100644 index 2e9c9b6e..00000000 --- a/bit_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestBitTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bit(40)", []interface{}{ - &pgtype.Varbit{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Status: pgtype.Present}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, - &pgtype.Varbit{Status: pgtype.Null}, - }) -} - -func TestBitNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select B'111111111'", - Value: &pgtype.Bit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, - }, - }) -} diff --git a/bool_array_test.go b/bool_array_test.go deleted file mode 100644 index be567e59..00000000 --- a/bool_array_test.go +++ /dev/null @@ -1,283 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestBoolArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bool[]", []interface{}{ - &pgtype.BoolArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.BoolArray{Status: pgtype.Null}, - &pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Bool: false, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestBoolArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.BoolArray - }{ - { - source: []bool{true}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]bool)(nil)), - result: pgtype.BoolArray{Status: pgtype.Null}, - }, - { - source: [][]bool{{true}, {false}}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]bool{{true}, {false}}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.BoolArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestBoolArrayAssignTo(t *testing.T) { - var boolSlice []bool - type _boolSlice []bool - var namedBoolSlice _boolSlice - var boolSliceDim2 [][]bool - var boolSliceDim4 [][][][]bool - var boolArrayDim2 [2][1]bool - var boolArrayDim4 [2][1][1][3]bool - - simpleTests := []struct { - src pgtype.BoolArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &boolSlice, - expected: []bool{true}, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedBoolSlice, - expected: _boolSlice{true}, - }, - { - src: pgtype.BoolArray{Status: pgtype.Null}, - dst: &boolSlice, - expected: (([]bool)(nil)), - }, - { - src: pgtype.BoolArray{Status: pgtype.Present}, - dst: &boolSlice, - expected: []bool{}, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - expected: [][]bool{{true}, {false}}, - dst: &boolSliceDim2, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - expected: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, - dst: &boolSliceDim4, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - expected: [2][1]bool{{true}, {false}}, - dst: &boolArrayDim2, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - expected: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, - dst: &boolArrayDim4, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.BoolArray - dst interface{} - }{ - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &boolSlice, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &boolArrayDim2, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &boolSlice, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &boolArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/bool_test.go b/bool_test.go deleted file mode 100644 index 8e7a5220..00000000 --- a/bool_test.go +++ /dev/null @@ -1,140 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestBoolTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bool", []interface{}{ - &pgtype.Bool{Bool: false, Status: pgtype.Present}, - &pgtype.Bool{Bool: true, Status: pgtype.Present}, - &pgtype.Bool{Bool: false, Status: pgtype.Null}, - }) -} - -func TestBoolSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Bool - }{ - {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, - {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, - {source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, - {source: _bool(true), result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: _bool(false), result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, - {source: nil, result: pgtype.Bool{Status: pgtype.Null}}, - } - - for i, tt := range successfulTests { - var r pgtype.Bool - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestBoolAssignTo(t *testing.T) { - var b bool - var _b _bool - var pb *bool - var _pb *_bool - - simpleTests := []struct { - src pgtype.Bool - dst interface{} - expected interface{} - }{ - {src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &b, expected: false}, - {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &b, expected: true}, - {src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &_b, expected: _bool(false)}, - {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_b, expected: _bool(true)}, - {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &pb, expected: ((*bool)(nil))}, - {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &_pb, expected: ((*_bool)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Bool - dst interface{} - expected interface{} - }{ - {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &pb, expected: true}, - {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_pb, expected: _bool(true)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} - -func TestBoolMarshalJSON(t *testing.T) { - successfulTests := []struct { - source pgtype.Bool - result string - }{ - {source: pgtype.Bool{Status: pgtype.Null}, result: "null"}, - {source: pgtype.Bool{Bool: true, Status: pgtype.Present}, result: "true"}, - {source: pgtype.Bool{Bool: false, Status: pgtype.Present}, result: "false"}, - } - for i, tt := range successfulTests { - r, err := tt.source.MarshalJSON() - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r) != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) - } - } -} - -func TestBoolUnmarshalJSON(t *testing.T) { - successfulTests := []struct { - source string - result pgtype.Bool - }{ - {source: "null", result: pgtype.Bool{Status: pgtype.Null}}, - {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, - } - for i, tt := range successfulTests { - var r pgtype.Bool - err := r.UnmarshalJSON([]byte(tt.source)) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} diff --git a/box_test.go b/box_test.go deleted file mode 100644 index 643c74ec..00000000 --- a/box_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestBoxTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "box", []interface{}{ - &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, - Status: pgtype.Present, - }, - &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Status: pgtype.Present, - }, - &pgtype.Box{Status: pgtype.Null}, - }) -} - -func TestBoxNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select '3.14, 1.678, 7.1, 5.234'::box", - Value: &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, - Status: pgtype.Present, - }, - }, - }) -} diff --git a/bpchar_array_test.go b/bpchar_array_test.go deleted file mode 100644 index af6bf09a..00000000 --- a/bpchar_array_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestBPCharArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "char(8)[]", []interface{}{ - &pgtype.BPCharArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.BPCharArray{ - Elements: []pgtype.BPChar{ - pgtype.BPChar{String: "foo ", Status: pgtype.Present}, - pgtype.BPChar{Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.BPCharArray{Status: pgtype.Null}, - &pgtype.BPCharArray{ - Elements: []pgtype.BPChar{ - pgtype.BPChar{String: "bar ", Status: pgtype.Present}, - pgtype.BPChar{String: "NuLL ", Status: pgtype.Present}, - pgtype.BPChar{String: `wow"quz\`, Status: pgtype.Present}, - pgtype.BPChar{String: "1 ", Status: pgtype.Present}, - pgtype.BPChar{String: "1 ", Status: pgtype.Present}, - pgtype.BPChar{String: "null ", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 3, LowerBound: 1}, - {Length: 2, LowerBound: 1}, - }, - Status: pgtype.Present, - }, - &pgtype.BPCharArray{ - Elements: []pgtype.BPChar{ - pgtype.BPChar{String: " bar ", Status: pgtype.Present}, - pgtype.BPChar{String: " baz ", Status: pgtype.Present}, - pgtype.BPChar{String: " quz ", Status: pgtype.Present}, - pgtype.BPChar{String: "foo ", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} diff --git a/bpchar_test.go b/bpchar_test.go deleted file mode 100644 index 7b8c1da3..00000000 --- a/bpchar_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestChar3Transcode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "char(3)", []interface{}{ - &pgtype.BPChar{String: "a ", Status: pgtype.Present}, - &pgtype.BPChar{String: " a ", Status: pgtype.Present}, - &pgtype.BPChar{String: "嗨 ", Status: pgtype.Present}, - &pgtype.BPChar{String: " ", Status: pgtype.Present}, - &pgtype.BPChar{Status: pgtype.Null}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.BPChar) - b := bb.(pgtype.BPChar) - - return a.Status == b.Status && a.String == b.String - }) -} - -func TestBPCharAssignTo(t *testing.T) { - var ( - str string - run rune - ) - simpleTests := []struct { - src pgtype.BPChar - dst interface{} - expected interface{} - }{ - {src: pgtype.BPChar{String: "simple", Status: pgtype.Present}, dst: &str, expected: "simple"}, - {src: pgtype.BPChar{String: "嗨", Status: pgtype.Present}, dst: &run, expected: '嗨'}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - -} diff --git a/bytea_array_test.go b/bytea_array_test.go deleted file mode 100644 index 27c0382e..00000000 --- a/bytea_array_test.go +++ /dev/null @@ -1,229 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestByteaArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bytea[]", []interface{}{ - &pgtype.ByteaArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.ByteaArray{Status: pgtype.Null}, - &pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{}, Status: pgtype.Present}, - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Bytes: []byte{1}, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{}, Status: pgtype.Present}, - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{1}, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestByteaArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.ByteaArray - }{ - { - source: [][]byte{{1, 2, 3}}, - result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([][]byte)(nil)), - result: pgtype.ByteaArray{Status: pgtype.Null}, - }, - { - source: [][][]byte{{{1}}, {{2}}}, - result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, - result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, - {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, - {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, - {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, - {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1][]byte{{{1}}, {{2}}}, - result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, - result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, - {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, - {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, - {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, - {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.ByteaArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestByteaArrayAssignTo(t *testing.T) { - var byteByteSlice [][]byte - var byteByteSliceDim2 [][][]byte - var byteByteSliceDim4 [][][][][]byte - var byteByteArraySliceDim2 [2][1][]byte - var byteByteArraySliceDim4 [2][1][1][3][]byte - - simpleTests := []struct { - src pgtype.ByteaArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &byteByteSlice, - expected: [][]byte{{1, 2, 3}}, - }, - { - src: pgtype.ByteaArray{Status: pgtype.Null}, - dst: &byteByteSlice, - expected: (([][]byte)(nil)), - }, - { - src: pgtype.ByteaArray{Status: pgtype.Present}, - dst: &byteByteSlice, - expected: [][]byte{}, - }, - { - src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &byteByteSliceDim2, - expected: [][][]byte{{{1}}, {{2}}}, - }, - { - src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, - {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, - {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, - {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, - {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &byteByteSliceDim4, - expected: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, - }, - { - src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &byteByteArraySliceDim2, - expected: [2][1][]byte{{{1}}, {{2}}}, - }, - { - src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, - {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, - {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, - {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, - {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &byteByteArraySliceDim4, - expected: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/bytea_test.go b/bytea_test.go deleted file mode 100644 index c8c49ff7..00000000 --- a/bytea_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestByteaTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bytea", []interface{}{ - &pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - &pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, - &pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, - }) -} - -func TestByteaSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Bytea - }{ - {source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, - {source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}}, - {source: []byte(nil), result: pgtype.Bytea{Status: pgtype.Null}}, - {source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, - {source: _byteSlice(nil), result: pgtype.Bytea{Status: pgtype.Null}}, - } - - for i, tt := range successfulTests { - var r pgtype.Bytea - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestByteaAssignTo(t *testing.T) { - var buf []byte - var _buf _byteSlice - var pbuf *[]byte - var _pbuf *_byteSlice - - simpleTests := []struct { - src pgtype.Bytea - dst interface{} - expected interface{} - }{ - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &buf, expected: []byte{1, 2, 3}}, - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_buf, expected: _byteSlice{1, 2, 3}}, - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &pbuf, expected: &[]byte{1, 2, 3}}, - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}}, - {src: pgtype.Bytea{Status: pgtype.Null}, dst: &pbuf, expected: ((*[]byte)(nil))}, - {src: pgtype.Bytea{Status: pgtype.Null}, dst: &_pbuf, expected: ((*_byteSlice)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/cid_test.go b/cid_test.go deleted file mode 100644 index 50e50cd8..00000000 --- a/cid_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestCIDTranscode(t *testing.T) { - pgTypeName := "cid" - values := []interface{}{ - &pgtype.CID{Uint: 42, Status: pgtype.Present}, - &pgtype.CID{Status: pgtype.Null}, - } - eqFunc := func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - } - - testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) - } -} - -func TestCIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.CID - }{ - {source: uint32(1), result: pgtype.CID{Uint: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.CID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestCIDAssignTo(t *testing.T) { - var ui32 uint32 - var pui32 *uint32 - - simpleTests := []struct { - src pgtype.CID - dst interface{} - expected interface{} - }{ - {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.CID - dst interface{} - expected interface{} - }{ - {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.CID - dst interface{} - }{ - {src: pgtype.CID{Status: pgtype.Null}, dst: &ui32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/cidr_array_test.go b/cidr_array_test.go deleted file mode 100644 index 74c063fa..00000000 --- a/cidr_array_test.go +++ /dev/null @@ -1,319 +0,0 @@ -package pgtype_test - -import ( - "net" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestCIDRArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "cidr[]", []interface{}{ - &pgtype.CIDRArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.CIDRArray{Status: pgtype.Null}, - &pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - {Status: pgtype.Null}, - {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestCIDRArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.CIDRArray - }{ - { - source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]*net.IPNet)(nil)), - result: pgtype.CIDRArray{Status: pgtype.Null}, - }, - { - source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]net.IP)(nil)), - result: pgtype.CIDRArray{Status: pgtype.Null}, - }, - { - source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.CIDRArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestCIDRArrayAssignTo(t *testing.T) { - var ipnetSlice []*net.IPNet - var ipSlice []net.IP - var ipSliceDim2 [][]net.IP - var ipnetSliceDim4 [][][][]*net.IPNet - var ipArrayDim2 [2][1]net.IP - var ipnetArrayDim4 [2][1][1][3]*net.IPNet - - simpleTests := []struct { - src pgtype.CIDRArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{nil}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipSlice, - expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipSlice, - expected: []net.IP{nil}, - }, - { - src: pgtype.CIDRArray{Status: pgtype.Null}, - dst: &ipnetSlice, - expected: (([]*net.IPNet)(nil)), - }, - { - src: pgtype.CIDRArray{Status: pgtype.Present}, - dst: &ipnetSlice, - expected: []*net.IPNet{}, - }, - { - src: pgtype.CIDRArray{Status: pgtype.Null}, - dst: &ipSlice, - expected: (([]net.IP)(nil)), - }, - { - src: pgtype.CIDRArray{Status: pgtype.Present}, - dst: &ipSlice, - expected: []net.IP{}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &ipSliceDim2, - expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &ipnetSliceDim4, - expected: [][][][]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &ipArrayDim2, - expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &ipnetArrayDim4, - expected: [2][1][1][3]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/circle_test.go b/circle_test.go deleted file mode 100644 index ba4f408b..00000000 --- a/circle_test.go +++ /dev/null @@ -1,16 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestCircleTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "circle", []interface{}{ - &pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Status: pgtype.Present}, - &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Status: pgtype.Present}, - &pgtype.Circle{Status: pgtype.Null}, - }) -} diff --git a/composite_bench_test.go b/composite_bench_test.go deleted file mode 100644 index 7aef8c4f..00000000 --- a/composite_bench_test.go +++ /dev/null @@ -1,192 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgio" - "github.com/jackc/pgtype" - "github.com/stretchr/testify/require" -) - -type MyCompositeRaw struct { - A int32 - B *string -} - -func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - buf = pgio.AppendUint32(buf, 2) - - buf = pgio.AppendUint32(buf, pgtype.Int4OID) - buf = pgio.AppendInt32(buf, 4) - buf = pgio.AppendInt32(buf, src.A) - - buf = pgio.AppendUint32(buf, pgtype.TextOID) - if src.B != nil { - buf = pgio.AppendInt32(buf, int32(len(*src.B))) - buf = append(buf, (*src.B)...) - } else { - buf = pgio.AppendInt32(buf, -1) - } - - return buf, nil -} - -func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - a := pgtype.Int4{} - b := pgtype.Text{} - - scanner := pgtype.NewCompositeBinaryScanner(ci, src) - scanner.ScanDecoder(&a) - scanner.ScanDecoder(&b) - - if scanner.Err() != nil { - return scanner.Err() - } - - dst.A = a.Int - if b.Status == pgtype.Present { - dst.B = &b.String - } else { - dst.B = nil - } - - return nil -} - -var x []byte - -func BenchmarkBinaryEncodingManual(b *testing.B) { - buf := make([]byte, 0, 128) - ci := pgtype.NewConnInfo() - v := MyCompositeRaw{4, ptrS("ABCDEFG")} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - buf, _ = v.EncodeBinary(ci, buf[:0]) - } - x = buf -} - -func BenchmarkBinaryEncodingHelper(b *testing.B) { - buf := make([]byte, 0, 128) - ci := pgtype.NewConnInfo() - v := MyType{4, ptrS("ABCDEFG")} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - buf, _ = v.EncodeBinary(ci, buf[:0]) - } - x = buf -} - -func BenchmarkBinaryEncodingComposite(b *testing.B) { - buf := make([]byte, 0, 128) - ci := pgtype.NewConnInfo() - f1 := 2 - f2 := ptrS("bar") - c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ - {"a", pgtype.Int4OID}, - {"b", pgtype.TextOID}, - }, ci) - require.NoError(b, err) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - c.Set([]interface{}{f1, f2}) - buf, _ = c.EncodeBinary(ci, buf[:0]) - } - x = buf -} - -func BenchmarkBinaryEncodingJSON(b *testing.B) { - buf := make([]byte, 0, 128) - ci := pgtype.NewConnInfo() - v := MyCompositeRaw{4, ptrS("ABCDEFG")} - j := pgtype.JSON{} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - j.Set(v) - buf, _ = j.EncodeBinary(ci, buf[:0]) - } - x = buf -} - -var dstRaw MyCompositeRaw - -func BenchmarkBinaryDecodingManual(b *testing.B) { - ci := pgtype.NewConnInfo() - buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) - dst := MyCompositeRaw{} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := dst.DecodeBinary(ci, buf) - E(err) - } - dstRaw = dst -} - -var dstMyType MyType - -func BenchmarkBinaryDecodingHelpers(b *testing.B) { - ci := pgtype.NewConnInfo() - buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) - dst := MyType{} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := dst.DecodeBinary(ci, buf) - E(err) - } - dstMyType = dst -} - -var gf1 int -var gf2 *string - -func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { - ci := pgtype.NewConnInfo() - buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) - var f1 int - var f2 *string - - c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ - {"a", pgtype.Int4OID}, - {"b", pgtype.TextOID}, - }, ci) - require.NoError(b, err) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := c.DecodeBinary(ci, buf) - if err != nil { - b.Fatal(err) - } - err = c.AssignTo([]interface{}{&f1, &f2}) - if err != nil { - b.Fatal(err) - } - } - gf1 = f1 - gf2 = f2 -} - -func BenchmarkBinaryDecodingJSON(b *testing.B) { - ci := pgtype.NewConnInfo() - j := pgtype.JSON{} - j.Set(MyCompositeRaw{4, ptrS("ABCDEFG")}) - buf, _ := j.EncodeBinary(ci, nil) - - j = pgtype.JSON{} - dst := MyCompositeRaw{} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := j.DecodeBinary(ci, buf) - E(err) - err = j.AssignTo(&dst) - E(err) - } - dstRaw = dst -} diff --git a/composite_fields_test.go b/composite_fields_test.go deleted file mode 100644 index dc4d4c29..00000000 --- a/composite_fields_test.go +++ /dev/null @@ -1,273 +0,0 @@ -package pgtype_test - -import ( - "context" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgx/v4" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCompositeFieldsDecode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - formats := []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} - - // Assorted values - { - var a int32 - var b string - var c float64 - - for _, format := range formats { - err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, &b, &c}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.EqualValuesf(t, 1, a, "Format: %v", format) - assert.EqualValuesf(t, "hi", b, "Format: %v", format) - assert.EqualValuesf(t, 2.1, c, "Format: %v", format) - } - } - - // nulls, string "null", and empty string fields - { - var a pgtype.Text - var b string - var c pgtype.Text - var d string - var e pgtype.Text - - for _, format := range formats { - err := conn.QueryRow(context.Background(), "select row(null,'null',null,'',null)", pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.Nilf(t, a.Get(), "Format: %v", format) - assert.EqualValuesf(t, "null", b, "Format: %v", format) - assert.Nilf(t, c.Get(), "Format: %v", format) - assert.EqualValuesf(t, "", d, "Format: %v", format) - assert.Nilf(t, e.Get(), "Format: %v", format) - } - } - - // null record - { - var a pgtype.Text - var b string - cf := pgtype.CompositeFields{&a, &b} - - for _, format := range formats { - // Cannot scan nil into - err := conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( - cf, - ) - if assert.Errorf(t, err, "Format: %v", format) { - continue - } - assert.NotNilf(t, cf, "Format: %v", format) - - // But can scan nil into *pgtype.CompositeFields - err = conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( - &cf, - ) - if assert.Errorf(t, err, "Format: %v", format) { - continue - } - assert.Nilf(t, cf, "Format: %v", format) - } - } - - // quotes and special characters - { - var a, b, c, d string - - for _, format := range formats { - err := conn.QueryRow(context.Background(), `select row('"', 'foo bar', 'foo''bar', 'baz)bar')`, pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, &b, &c, &d}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.Equalf(t, `"`, a, "Format: %v", format) - assert.Equalf(t, `foo bar`, b, "Format: %v", format) - assert.Equalf(t, `foo'bar`, c, "Format: %v", format) - assert.Equalf(t, `baz)bar`, d, "Format: %v", format) - } - } - - // arrays - { - var a []string - var b []int64 - - for _, format := range formats { - err := conn.QueryRow(context.Background(), `select row(array['foo', 'bar', 'baz'], array[1,2,3])`, pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, &b}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.EqualValuesf(t, []string{"foo", "bar", "baz"}, a, "Format: %v", format) - assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format) - } - } - - // Skip nil fields - { - var a int32 - var c float64 - - for _, format := range formats { - err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, nil, &c}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.EqualValuesf(t, 1, a, "Format: %v", format) - assert.EqualValuesf(t, 2.1, c, "Format: %v", format) - } - } -} - -func TestCompositeFieldsEncode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - _, err := conn.Exec(context.Background(), `drop type if exists cf_encode; - -create type cf_encode as ( - a text, - b int4, - c text, - d float8, - e text -);`) - require.NoError(t, err) - defer conn.Exec(context.Background(), "drop type cf_encode") - - // Use simple protocol to force text or binary encoding - simpleProtocols := []bool{true, false} - - // Assorted values - { - var a string - var b int32 - var c string - var d float64 - var e string - - for _, simpleProtocol := range simpleProtocols { - err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{"hi", int32(1), "ok", float64(2.1), "bye"}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, "ok", c, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 2.1, d, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, "bye", e, "Simple Protocol: %v", simpleProtocol) - } - } - } - - // untyped nil - { - var a pgtype.Text - var b int32 - var c string - var d pgtype.Float8 - var e pgtype.Text - - simpleProtocol := true - err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) - assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) - assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) - } - - // untyped nil cannot be represented in binary format because CompositeFields does not know the PostgreSQL schema - // of the composite type. - simpleProtocol = false - err = conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - assert.Errorf(t, err, "Simple Protocol: %v", simpleProtocol) - } - - // nulls, string "null", and empty string fields - { - var a pgtype.Text - var b int32 - var c string - var d pgtype.Float8 - var e pgtype.Text - - for _, simpleProtocol := range simpleProtocols { - err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{&pgtype.Text{Status: pgtype.Null}, int32(1), "null", &pgtype.Float8{Status: pgtype.Null}, &pgtype.Text{Status: pgtype.Null}}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) - assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) - assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) - } - } - } - - // quotes and special characters - { - var a string - var b int32 - var c string - var d float64 - var e string - - for _, simpleProtocol := range simpleProtocols { - err := conn.QueryRow( - context.Background(), - `select $1::cf_encode`, - pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{`"`, int32(42), `foo'bar`, float64(1.2), `baz)bar`}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.Equalf(t, `"`, a, "Simple Protocol: %v", simpleProtocol) - assert.Equalf(t, int32(42), b, "Simple Protocol: %v", simpleProtocol) - assert.Equalf(t, `foo'bar`, c, "Simple Protocol: %v", simpleProtocol) - assert.Equalf(t, float64(1.2), d, "Simple Protocol: %v", simpleProtocol) - assert.Equalf(t, `baz)bar`, e, "Simple Protocol: %v", simpleProtocol) - } - } - } -} diff --git a/composite_type_test.go b/composite_type_test.go deleted file mode 100644 index 2349a67d..00000000 --- a/composite_type_test.go +++ /dev/null @@ -1,320 +0,0 @@ -package pgtype_test - -import ( - "context" - "fmt" - "os" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" - pgx "github.com/jackc/pgx/v4" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCompositeTypeSetAndGet(t *testing.T) { - ci := pgtype.NewConnInfo() - ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ - {"a", pgtype.TextOID}, - {"b", pgtype.Int4OID}, - }, ci) - require.NoError(t, err) - assert.Equal(t, pgtype.Undefined, ct.Get()) - - nilTests := []struct { - src interface{} - }{ - {nil}, // nil interface - {(*[]interface{})(nil)}, // typed nil - } - - for i, tt := range nilTests { - err := ct.Set(tt.src) - assert.NoErrorf(t, err, "%d", i) - assert.Equal(t, nil, ct.Get()) - } - - compatibleValuesTests := []struct { - src []interface{} - expected map[string]interface{} - }{ - { - src: []interface{}{"foo", int32(42)}, - expected: map[string]interface{}{"a": "foo", "b": int32(42)}, - }, - { - src: []interface{}{nil, nil}, - expected: map[string]interface{}{"a": nil, "b": nil}, - }, - { - src: []interface{}{&pgtype.Text{String: "hi", Status: pgtype.Present}, &pgtype.Int4{Int: 7, Status: pgtype.Present}}, - expected: map[string]interface{}{"a": "hi", "b": int32(7)}, - }, - } - - for i, tt := range compatibleValuesTests { - err := ct.Set(tt.src) - assert.NoErrorf(t, err, "%d", i) - assert.EqualValues(t, tt.expected, ct.Get()) - } -} - -func TestCompositeTypeAssignTo(t *testing.T) { - ci := pgtype.NewConnInfo() - ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ - {"a", pgtype.TextOID}, - {"b", pgtype.Int4OID}, - }, ci) - require.NoError(t, err) - - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - var a string - var b int32 - - err = ct.AssignTo([]interface{}{&a, &b}) - assert.NoError(t, err) - - assert.Equal(t, "foo", a) - assert.Equal(t, int32(42), b) - } - - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - var a pgtype.Text - var b pgtype.Int4 - - err = ct.AssignTo([]interface{}{&a, &b}) - assert.NoError(t, err) - - assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a) - assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b) - } - - // Allow nil destination component as no-op - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - var b int32 - - err = ct.AssignTo([]interface{}{nil, &b}) - assert.NoError(t, err) - - assert.Equal(t, int32(42), b) - } - - // *[]interface{} dest when null - { - err := ct.Set(nil) - assert.NoError(t, err) - - var a pgtype.Text - var b pgtype.Int4 - dst := []interface{}{&a, &b} - - err = ct.AssignTo(&dst) - assert.NoError(t, err) - - assert.Nil(t, dst) - } - - // *[]interface{} dest when not null - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - var a pgtype.Text - var b pgtype.Int4 - dst := []interface{}{&a, &b} - - err = ct.AssignTo(&dst) - assert.NoError(t, err) - - assert.NotNil(t, dst) - assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a) - assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b) - } - - // Struct fields positionally via reflection - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - s := struct { - A string - B int32 - }{} - - err = ct.AssignTo(&s) - if assert.NoError(t, err) { - assert.Equal(t, "foo", s.A) - assert.Equal(t, int32(42), s.B) - } - } -} - -func TestCompositeTypeTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - _, err := conn.Exec(context.Background(), `drop type if exists ct_test; - -create type ct_test as ( - a text, - b int4 -);`) - require.NoError(t, err) - defer conn.Exec(context.Background(), "drop type ct_test") - - var oid uint32 - err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid) - require.NoError(t, err) - - defer conn.Exec(context.Background(), "drop type ct_test") - - ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{ - {"a", pgtype.TextOID}, - {"b", pgtype.Int4OID}, - }, conn.ConnInfo()) - require.NoError(t, err) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) - - // Use simple protocol to force text or binary encoding - simpleProtocols := []bool{true, false} - - var a string - var b int32 - - for _, simpleProtocol := range simpleProtocols { - err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{"hi", int32(42)}, - ).Scan( - []interface{}{&a, &b}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol) - } - } -} - -// https://github.com/jackc/pgx/issues/874 -func TestCompositeTypeTextDecodeNested(t *testing.T) { - newCompositeType := func(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name} - } - - rowType, err := pgtype.NewCompositeTypeValues(name, fields, vals) - require.NoError(t, err) - return rowType - } - - dimensionsType := func() pgtype.ValueTranscoder { - return newCompositeType( - "dimensions", - []string{"width", "height"}, - &pgtype.Int4{}, - &pgtype.Int4{}, - ) - } - productImageType := func() pgtype.ValueTranscoder { - return newCompositeType( - "product_image_type", - []string{"source", "dimensions"}, - &pgtype.Text{}, - dimensionsType(), - ) - } - productImageSetType := newCompositeType( - "product_image_set_type", - []string{"name", "orig_image", "images"}, - &pgtype.Text{}, - productImageType(), - pgtype.NewArrayType("product_image", 0, func() pgtype.ValueTranscoder { - return productImageType() - }), - ) - - err := productImageSetType.DecodeText(nil, []byte(`(name,"(img1,""(11,11)"")","{""(img2,\\""(22,22)\\"")"",""(img3,\\""(33,33)\\"")""}")`)) - require.NoError(t, err) -} - -func Example_composite() { - conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - fmt.Println(err) - return - } - - defer conn.Close(context.Background()) - _, err = conn.Exec(context.Background(), `drop type if exists mytype;`) - if err != nil { - fmt.Println(err) - return - } - - _, err = conn.Exec(context.Background(), `create type mytype as ( - a int4, - b text -);`) - if err != nil { - fmt.Println(err) - return - } - defer conn.Exec(context.Background(), "drop type mytype") - - var oid uint32 - err = conn.QueryRow(context.Background(), `select 'mytype'::regtype::oid`).Scan(&oid) - if err != nil { - fmt.Println(err) - return - } - - ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{ - {"a", pgtype.Int4OID}, - {"b", pgtype.TextOID}, - }, conn.ConnInfo()) - if err != nil { - fmt.Println(err) - return - } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) - - var a int - var b *string - - err = conn.QueryRow(context.Background(), "select $1::mytype", []interface{}{2, "bar"}).Scan([]interface{}{&a, &b}) - if err != nil { - fmt.Println(err) - return - } - - fmt.Printf("First: a=%d b=%s\n", a, *b) - - err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan([]interface{}{&a, &b}) - if err != nil { - fmt.Println(err) - return - } - - fmt.Printf("Second: a=%d b=%v\n", a, b) - - scanTarget := []interface{}{&a, &b} - err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(&scanTarget) - E(err) - - fmt.Printf("Third: isNull=%v\n", scanTarget == nil) - - // Output: - // First: a=2 b=bar - // Second: a=1 b= - // Third: isNull=true -} diff --git a/custom_composite_test.go b/custom_composite_test.go deleted file mode 100644 index 9ca8dd5e..00000000 --- a/custom_composite_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package pgtype_test - -import ( - "context" - "errors" - "fmt" - "os" - - "github.com/jackc/pgtype" - pgx "github.com/jackc/pgx/v4" -) - -type MyType struct { - a int32 // NULL will cause decoding error - b *string // there can be NULL in this position in SQL -} - -func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs") - } - - if err := (pgtype.CompositeFields{&dst.a, &dst.b}).DecodeBinary(ci, src); err != nil { - return err - } - - return nil -} - -func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { - a := pgtype.Int4{src.a, pgtype.Present} - var b pgtype.Text - if src.b != nil { - b = pgtype.Text{*src.b, pgtype.Present} - } else { - b = pgtype.Text{Status: pgtype.Null} - } - - return (pgtype.CompositeFields{&a, &b}).EncodeBinary(ci, buf) -} - -func ptrS(s string) *string { - return &s -} - -func E(err error) { - if err != nil { - panic(err) - } -} - -// ExampleCustomCompositeTypes demonstrates how support for custom types mappable to SQL -// composites can be added. -func Example_customCompositeTypes() { - conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - E(err) - - defer conn.Close(context.Background()) - _, err = conn.Exec(context.Background(), `drop type if exists mytype; - -create type mytype as ( - a int4, - b text -);`) - E(err) - defer conn.Exec(context.Background(), "drop type mytype") - - var result *MyType - - // Demonstrates both passing and reading back composite values - err = conn.QueryRow(context.Background(), "select $1::mytype", - pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}). - Scan(&result) - E(err) - - fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b) - - // Because we scan into &*MyType, NULLs are handled generically by assigning nil to result - err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result) - E(err) - - fmt.Printf("Second row: %v\n", result) - - // Output: - // First row: a=1 b=foo - // Second row: -} diff --git a/date_array_test.go b/date_array_test.go deleted file mode 100644 index 4458abfe..00000000 --- a/date_array_test.go +++ /dev/null @@ -1,327 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestDateArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "date[]", []interface{}{ - &pgtype.DateArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.DateArray{Status: pgtype.Null}, - &pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestDateArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.DateArray - }{ - { - source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - result: pgtype.DateArray{ - Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]time.Time)(nil)), - result: pgtype.DateArray{Status: pgtype.Null}, - }, - { - source: [][]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - result: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - result: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - result: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - result: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.DateArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestDateArrayAssignTo(t *testing.T) { - var timeSlice []time.Time - var timeSliceDim2 [][]time.Time - var timeSliceDim4 [][][][]time.Time - var timeArrayDim2 [2][1]time.Time - var timeArrayDim4 [2][1][1][3]time.Time - - simpleTests := []struct { - src pgtype.DateArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &timeSlice, - expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - }, - { - src: pgtype.DateArray{Status: pgtype.Null}, - dst: &timeSlice, - expected: (([]time.Time)(nil)), - }, - { - src: pgtype.DateArray{Status: pgtype.Present}, - dst: &timeSlice, - expected: []time.Time{}, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &timeSliceDim2, - expected: [][]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &timeSliceDim4, - expected: [][][][]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &timeArrayDim2, - expected: [2][1]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &timeArrayDim4, - expected: [2][1][1][3]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.DateArray - dst interface{} - }{ - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &timeSlice, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &timeArrayDim2, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &timeSlice, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &timeArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/date_test.go b/date_test.go deleted file mode 100644 index 5c38e7a3..00000000 --- a/date_test.go +++ /dev/null @@ -1,168 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestDateTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "date", []interface{}{ - &pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Status: pgtype.Null}, - &pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - &pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, - }, func(a, b interface{}) bool { - at := a.(pgtype.Date) - bt := b.(pgtype.Date) - - return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier - }) -} - -func TestDateSet(t *testing.T) { - type _time time.Time - - successfulTests := []struct { - source interface{} - result pgtype.Date - }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: "1999-12-31", result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var d pgtype.Date - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if d != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } -} - -func TestDateAssignTo(t *testing.T) { - var tim time.Time - var ptim *time.Time - - simpleTests := []struct { - src pgtype.Date - dst interface{} - expected interface{} - }{ - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - {src: pgtype.Date{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Date - dst interface{} - expected interface{} - }{ - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Date - dst interface{} - }{ - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - -func TestDateMarshalJSON(t *testing.T) { - successfulTests := []struct { - source pgtype.Date - result string - }{ - {source: pgtype.Date{Status: pgtype.Null}, result: "null"}, - {source: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, result: "\"2012-03-29\""}, - {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29\""}, - {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29\""}, - {source: pgtype.Date{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, result: "\"infinity\""}, - {source: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, result: "\"-infinity\""}, - } - for i, tt := range successfulTests { - r, err := tt.source.MarshalJSON() - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r) != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) - } - } -} - -func TestDateUnmarshalJSON(t *testing.T) { - successfulTests := []struct { - source string - result pgtype.Date - }{ - {source: "null", result: pgtype.Date{Status: pgtype.Null}}, - {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, - {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, - {source: "\"infinity\"", result: pgtype.Date{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, - {source: "\"-infinity\"", result: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, - } - for i, tt := range successfulTests { - var r pgtype.Date - err := r.UnmarshalJSON([]byte(tt.source)) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r.Time.Year() != tt.result.Time.Year() || r.Time.Month() != tt.result.Time.Month() || r.Time.Day() != tt.result.Time.Day() || r.Status != tt.result.Status || r.InfinityModifier != tt.result.InfinityModifier { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} diff --git a/daterange_test.go b/daterange_test.go deleted file mode 100644 index 54d51e2d..00000000 --- a/daterange_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package pgtype_test - -import ( - "testing" - "time" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestDaterangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ - &pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - &pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Daterange{Status: pgtype.Null}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Daterange) - b := bb.(pgtype.Daterange) - - return a.Status == b.Status && - a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Status == b.Lower.Status && - a.Lower.InfinityModifier == b.Lower.InfinityModifier && - a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Status == b.Upper.Status && - a.Upper.InfinityModifier == b.Upper.InfinityModifier - }) -} - -func TestDaterangeNormalize(t *testing.T) { - testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ - { - SQL: "select daterange('2010-01-01', '2010-01-11', '(]')", - Value: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(2010, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2010, 1, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - }, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Daterange) - b := bb.(pgtype.Daterange) - - return a.Status == b.Status && - a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Status == b.Lower.Status && - a.Lower.InfinityModifier == b.Lower.InfinityModifier && - a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Status == b.Upper.Status && - a.Upper.InfinityModifier == b.Upper.InfinityModifier - }) -} - -func TestDaterangeSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Daterange - }{ - { - source: nil, - result: pgtype.Daterange{Status: pgtype.Null}, - }, - { - source: &pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - result: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - }, - { - source: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - result: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - }, - { - source: "[1990-12-31,2028-01-01)", - result: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Daterange - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} diff --git a/enum_array_test.go b/enum_array_test.go deleted file mode 100644 index 659340f0..00000000 --- a/enum_array_test.go +++ /dev/null @@ -1,281 +0,0 @@ -package pgtype_test - -import ( - "context" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestEnumArrayTranscode(t *testing.T) { - setupConn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, setupConn) - - if _, err := setupConn.Exec(context.Background(), "drop type if exists color"); err != nil { - t.Fatal(err) - } - if _, err := setupConn.Exec(context.Background(), "create type color as enum ('red', 'green', 'blue')"); err != nil { - t.Fatal(err) - } - - testutil.TestSuccessfulTranscode(t, "color[]", []interface{}{ - &pgtype.EnumArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "red", Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.EnumArray{Status: pgtype.Null}, - &pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "red", Status: pgtype.Present}, - {String: "green", Status: pgtype.Present}, - {String: "blue", Status: pgtype.Present}, - {String: "red", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestEnumArrayArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.EnumArray - }{ - { - source: []string{"foo"}, - result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]string)(nil)), - result: pgtype.EnumArray{Status: pgtype.Null}, - }, - { - source: [][]string{{"foo"}, {"bar"}}, - result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]string{{"foo"}, {"bar"}}, - result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.EnumArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestEnumArrayArrayAssignTo(t *testing.T) { - var stringSlice []string - type _stringSlice []string - var namedStringSlice _stringSlice - var stringSliceDim2 [][]string - var stringSliceDim4 [][][][]string - var stringArrayDim2 [2][1]string - var stringArrayDim4 [2][1][1][3]string - - simpleTests := []struct { - src pgtype.EnumArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - expected: []string{"foo"}, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedStringSlice, - expected: _stringSlice{"bar"}, - }, - { - src: pgtype.EnumArray{Status: pgtype.Null}, - dst: &stringSlice, - expected: (([]string)(nil)), - }, - { - src: pgtype.EnumArray{Status: pgtype.Present}, - dst: &stringSlice, - expected: []string{}, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &stringSliceDim2, - expected: [][]string{{"foo"}, {"bar"}}, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &stringSliceDim4, - expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &stringArrayDim2, - expected: [2][1]string{{"foo"}, {"bar"}}, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &stringArrayDim4, - expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.EnumArray - dst interface{} - }{ - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &stringArrayDim2, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &stringSlice, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &stringArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/enum_type_test.go b/enum_type_test.go deleted file mode 100644 index 4dd88f2a..00000000 --- a/enum_type_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package pgtype_test - -import ( - "bytes" - "context" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgx/v4" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func setupEnum(t *testing.T, conn *pgx.Conn) *pgtype.EnumType { - _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") - require.NoError(t, err) - - _, err = conn.Exec(context.Background(), "create type pgtype_enum_color as enum ('blue', 'green', 'purple');") - require.NoError(t, err) - - var oid uint32 - err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "pgtype_enum_color").Scan(&oid) - require.NoError(t, err) - - et := pgtype.NewEnumType("pgtype_enum_color", []string{"blue", "green", "purple"}) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "pgtype_enum_color", OID: oid}) - - return et -} - -func cleanupEnum(t *testing.T, conn *pgx.Conn) { - _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") - require.NoError(t, err) -} - -func TestEnumTypeTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - setupEnum(t, conn) - defer cleanupEnum(t, conn) - - var dst string - err := conn.QueryRow(context.Background(), "select $1::pgtype_enum_color", "blue").Scan(&dst) - require.NoError(t, err) - require.EqualValues(t, "blue", dst) -} - -func TestEnumTypeSet(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - enumType := setupEnum(t, conn) - defer cleanupEnum(t, conn) - - successfulTests := []struct { - source interface{} - result interface{} - }{ - {source: "blue", result: "blue"}, - {source: _string("green"), result: "green"}, - {source: (*string)(nil), result: nil}, - } - - for i, tt := range successfulTests { - err := enumType.Set(tt.source) - assert.NoErrorf(t, err, "%d", i) - assert.Equalf(t, tt.result, enumType.Get(), "%d", i) - } -} - -func TestEnumTypeAssignTo(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - enumType := setupEnum(t, conn) - defer cleanupEnum(t, conn) - - { - var s string - - err := enumType.Set("blue") - require.NoError(t, err) - - err = enumType.AssignTo(&s) - require.NoError(t, err) - - assert.EqualValues(t, "blue", s) - } - - { - var ps *string - - err := enumType.Set("blue") - require.NoError(t, err) - - err = enumType.AssignTo(&ps) - require.NoError(t, err) - - assert.EqualValues(t, "blue", *ps) - } - - { - var ps *string - - err := enumType.Set(nil) - require.NoError(t, err) - - err = enumType.AssignTo(&ps) - require.NoError(t, err) - - assert.EqualValues(t, (*string)(nil), ps) - } - - var buf []byte - bytesTests := []struct { - src interface{} - dst *[]byte - expected []byte - }{ - {src: "blue", dst: &buf, expected: []byte("blue")}, - {src: nil, dst: &buf, expected: nil}, - } - - for i, tt := range bytesTests { - err := enumType.Set(tt.src) - require.NoError(t, err, "%d", i) - - err = enumType.AssignTo(tt.dst) - require.NoError(t, err, "%d", i) - - if bytes.Compare(*tt.dst, tt.expected) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) - } - } - - { - var s string - - err := enumType.Set(nil) - require.NoError(t, err) - - err = enumType.AssignTo(&s) - require.Error(t, err) - } - -} diff --git a/ext/gofrs-uuid/uuid_test.go b/ext/gofrs-uuid/uuid_test.go deleted file mode 100644 index 56814524..00000000 --- a/ext/gofrs-uuid/uuid_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package uuid_test - -import ( - "bytes" - "testing" - - "github.com/jackc/pgtype" - gofrs "github.com/jackc/pgtype/ext/gofrs-uuid" - "github.com/jackc/pgtype/testutil" -) - -func TestUUIDTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - &gofrs.UUID{Status: pgtype.Null}, - }) -} - -func TestUUIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result gofrs.UUID - }{ - { - source: &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - }, - { - source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - }, - { - source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - }, - { - source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r gofrs.UUID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestUUIDAssignTo(t *testing.T) { - { - src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst [16]byte - expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst []byte - expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if bytes.Compare(dst, expected) != 0 { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst string - expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - -} diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go deleted file mode 100644 index bf34e0dd..00000000 --- a/ext/shopspring-numeric/decimal_test.go +++ /dev/null @@ -1,330 +0,0 @@ -package numeric_test - -import ( - "fmt" - "math/big" - "math/rand" - "reflect" - "testing" - - "github.com/jackc/pgtype" - shopspring "github.com/jackc/pgtype/ext/shopspring-numeric" - "github.com/jackc/pgtype/testutil" - "github.com/shopspring/decimal" - "github.com/stretchr/testify/require" -) - -func mustParseDecimal(t *testing.T, src string) decimal.Decimal { - dec, err := decimal.NewFromString(src) - if err != nil { - t.Fatal(err) - } - return dec -} - -func TestNumericNormalize(t *testing.T) { - testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ - { - SQL: "select '0'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, - }, - { - SQL: "select '1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, - }, - { - SQL: "select '10.00'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present}, - }, - { - SQL: "select '1e-3'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, - }, - { - SQL: "select '-1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, - }, - { - SQL: "select '10000'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present}, - }, - { - SQL: "select '3.14'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, - }, - { - SQL: "select '1.1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present}, - }, - { - SQL: "select '100010001'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present}, - }, - { - SQL: "select '100010001.0001'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present}, - }, - { - SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", - Value: &shopspring.Numeric{ - Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), - Status: pgtype.Present, - }, - }, - { - SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", - Value: &shopspring.Numeric{ - Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), - Status: pgtype.Present, - }, - }, - { - SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", - Value: &shopspring.Numeric{ - Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), - Status: pgtype.Present, - }, - }, - }, func(aa, bb interface{}) bool { - a := aa.(shopspring.Numeric) - b := bb.(shopspring.Numeric) - - return a.Status == b.Status && a.Decimal.Equal(b.Decimal) - }) -} - -func TestNumericTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Status: pgtype.Present}, - - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Status: pgtype.Present}, - - &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Status: pgtype.Present}, - &shopspring.Numeric{Status: pgtype.Null}, - }, func(aa, bb interface{}) bool { - a := aa.(shopspring.Numeric) - b := bb.(shopspring.Numeric) - - return a.Status == b.Status && a.Decimal.Equal(b.Decimal) - }) - -} - -func TestNumericTranscodeFuzz(t *testing.T) { - r := rand.New(rand.NewSource(0)) - max := &big.Int{} - max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) - - values := make([]interface{}, 0, 2000) - for i := 0; i < 500; i++ { - num := fmt.Sprintf("%s.%s", (&big.Int{}).Rand(r, max).String(), (&big.Int{}).Rand(r, max).String()) - negNum := "-" + num - values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, num), Status: pgtype.Present}) - values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Status: pgtype.Present}) - } - - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, - func(aa, bb interface{}) bool { - a := aa.(shopspring.Numeric) - b := bb.(shopspring.Numeric) - - return a.Status == b.Status && a.Decimal.Equal(b.Decimal) - }) -} - -func TestNumericSet(t *testing.T) { - type _int8 int8 - - successfulTests := []struct { - source interface{} - result *shopspring.Numeric - }{ - {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, - {source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, - {source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, - {source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, - {source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Status: pgtype.Present}}, - {source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Status: pgtype.Present}}, - {source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Status: pgtype.Present}}, - {source: float64(1.25), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.25"), Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - r := &shopspring.Numeric{} - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !(r.Status == tt.result.Status && r.Decimal.Equal(tt.result.Decimal)) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestNumericAssignTo(t *testing.T) { - type _int8 int8 - - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - var f32 float32 - var f64 float64 - var pf32 *float32 - var pf64 *float64 - - simpleTests := []struct { - src *shopspring.Numeric - dst interface{} - expected interface{} - }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src *shopspring.Numeric - dst interface{} - expected interface{} - }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src *shopspring.Numeric - dst interface{} - }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Status: pgtype.Present}, dst: &i8}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Status: pgtype.Present}, dst: &i16}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui8}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui16}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui32}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui64}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui}, - {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - -func BenchmarkDecode(b *testing.B) { - benchmarks := []struct { - name string - numberStr string - }{ - {"Zero", "0"}, - {"Small", "12345"}, - {"Medium", "12345.12345"}, - {"Large", "123457890.1234567890"}, - {"Huge", "123457890123457890123457890.1234567890123457890123457890"}, - } - - for _, bm := range benchmarks { - src := &shopspring.Numeric{} - err := src.Set(bm.numberStr) - require.NoError(b, err) - textFormat, err := src.EncodeText(nil, nil) - require.NoError(b, err) - binaryFormat, err := src.EncodeBinary(nil, nil) - require.NoError(b, err) - - b.Run(fmt.Sprintf("%s-Text", bm.name), func(b *testing.B) { - dst := &shopspring.Numeric{} - for i := 0; i < b.N; i++ { - err := dst.DecodeText(nil, textFormat) - if err != nil { - b.Fatal(err) - } - } - }) - - b.Run(fmt.Sprintf("%s-Binary", bm.name), func(b *testing.B) { - dst := &shopspring.Numeric{} - for i := 0; i < b.N; i++ { - err := dst.DecodeBinary(nil, binaryFormat) - if err != nil { - b.Fatal(err) - } - } - }) - } -} diff --git a/float4_array_test.go b/float4_array_test.go deleted file mode 100644 index db438999..00000000 --- a/float4_array_test.go +++ /dev/null @@ -1,282 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestFloat4ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "float4[]", []interface{}{ - &pgtype.Float4Array{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Float4Array{Status: pgtype.Null}, - &pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Float: 6, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestFloat4ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Float4Array - }{ - { - source: []float32{1}, - result: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]float32)(nil)), - result: pgtype.Float4Array{Status: pgtype.Null}, - }, - { - source: [][]float32{{1}, {2}}, - result: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]float32{{1}, {2}}, - result: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Float4Array - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestFloat4ArrayAssignTo(t *testing.T) { - var float32Slice []float32 - var namedFloat32Slice _float32Slice - var float32SliceDim2 [][]float32 - var float32SliceDim4 [][][][]float32 - var float32ArrayDim2 [2][1]float32 - var float32ArrayDim4 [2][1][1][3]float32 - - simpleTests := []struct { - src pgtype.Float4Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float32Slice, - expected: []float32{1.23}, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedFloat32Slice, - expected: _float32Slice{1.23}, - }, - { - src: pgtype.Float4Array{Status: pgtype.Null}, - dst: &float32Slice, - expected: (([]float32)(nil)), - }, - { - src: pgtype.Float4Array{Status: pgtype.Present}, - dst: &float32Slice, - expected: []float32{}, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - expected: [][]float32{{1}, {2}}, - dst: &float32SliceDim2, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &float32SliceDim4, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - expected: [2][1]float32{{1}, {2}}, - dst: &float32ArrayDim2, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &float32ArrayDim4, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Float4Array - dst interface{} - }{ - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float32Slice, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &float32ArrayDim2, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &float32Slice, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &float32ArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/float4_test.go b/float4_test.go deleted file mode 100644 index d2524cda..00000000 --- a/float4_test.go +++ /dev/null @@ -1,149 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestFloat4Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "float4", []interface{}{ - &pgtype.Float4{Float: -1, Status: pgtype.Present}, - &pgtype.Float4{Float: 0, Status: pgtype.Present}, - &pgtype.Float4{Float: 0.00001, Status: pgtype.Present}, - &pgtype.Float4{Float: 1, Status: pgtype.Present}, - &pgtype.Float4{Float: 9999.99, Status: pgtype.Present}, - &pgtype.Float4{Float: 0, Status: pgtype.Null}, - }) -} - -func TestFloat4Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Float4 - }{ - {source: float32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: float64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.Float4 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestFloat4AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - var f32 float32 - var f64 float64 - var pf32 *float32 - var pf64 *float64 - - simpleTests := []struct { - src pgtype.Float4 - dst interface{} - expected interface{} - }{ - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Float4 - dst interface{} - expected interface{} - }{ - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Float4 - dst interface{} - }{ - {src: pgtype.Float4{Float: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Float4{Float: 40000, Status: pgtype.Present}, dst: &i16}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/float8_array_test.go b/float8_array_test.go deleted file mode 100644 index 85cb8f43..00000000 --- a/float8_array_test.go +++ /dev/null @@ -1,258 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestFloat8ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "float8[]", []interface{}{ - &pgtype.Float8Array{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Float8Array{Status: pgtype.Null}, - &pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Float: 6, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestFloat8ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Float8Array - }{ - { - source: []float64{1}, - result: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]float64)(nil)), - result: pgtype.Float8Array{Status: pgtype.Null}, - }, - { - source: [][]float64{{1}, {2}}, - result: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Float8Array - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestFloat8ArrayAssignTo(t *testing.T) { - var float64Slice []float64 - var namedFloat64Slice _float64Slice - var float64SliceDim2 [][]float64 - var float64SliceDim4 [][][][]float64 - var float64ArrayDim2 [2][1]float64 - var float64ArrayDim4 [2][1][1][3]float64 - - simpleTests := []struct { - src pgtype.Float8Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float64Slice, - expected: []float64{1.23}, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedFloat64Slice, - expected: _float64Slice{1.23}, - }, - { - src: pgtype.Float8Array{Status: pgtype.Null}, - dst: &float64Slice, - expected: (([]float64)(nil)), - }, - { - src: pgtype.Float8Array{Status: pgtype.Present}, - dst: &float64Slice, - expected: []float64{}, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - expected: [][]float64{{1}, {2}}, - dst: &float64SliceDim2, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - expected: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &float64SliceDim4, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - expected: [2][1]float64{{1}, {2}}, - dst: &float64ArrayDim2, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - expected: [2][1][1][3]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &float64ArrayDim4, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Float8Array - dst interface{} - }{ - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float64Slice, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &float64ArrayDim2, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &float64Slice, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &float64ArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/float8_test.go b/float8_test.go deleted file mode 100644 index 6bc7c652..00000000 --- a/float8_test.go +++ /dev/null @@ -1,149 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestFloat8Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ - &pgtype.Float8{Float: -1, Status: pgtype.Present}, - &pgtype.Float8{Float: 0, Status: pgtype.Present}, - &pgtype.Float8{Float: 0.00001, Status: pgtype.Present}, - &pgtype.Float8{Float: 1, Status: pgtype.Present}, - &pgtype.Float8{Float: 9999.99, Status: pgtype.Present}, - &pgtype.Float8{Float: 0, Status: pgtype.Null}, - }) -} - -func TestFloat8Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Float8 - }{ - {source: float32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: float64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.Float8 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestFloat8AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - var f32 float32 - var f64 float64 - var pf32 *float32 - var pf64 *float64 - - simpleTests := []struct { - src pgtype.Float8 - dst interface{} - expected interface{} - }{ - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Float8 - dst interface{} - expected interface{} - }{ - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Float8 - dst interface{} - }{ - {src: pgtype.Float8{Float: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Float8{Float: 40000, Status: pgtype.Present}, dst: &i16}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/go.mod b/go.mod index 29e6f628..42ee3838 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,6 @@ go 1.13 require ( github.com/gofrs/uuid v4.0.0+incompatible - github.com/jackc/pgconn v1.9.0 github.com/jackc/pgio v1.0.0 - github.com/jackc/pgx/v4 v4.12.0 - github.com/lib/pq v1.10.2 github.com/shopspring/decimal v1.2.0 - github.com/stretchr/testify v1.7.0 ) diff --git a/go.sum b/go.sum index e49ce26f..da822c7d 100644 --- a/go.sum +++ b/go.sum @@ -1,486 +1,6 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= -github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= -github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= -github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= -github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= -github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= -github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= -github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= -github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= -github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= -github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= -github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= -github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= -github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= -github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= -github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= -github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4dUb/I5gc9Hdhagfvm9+RyrPryS/auMzxE= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= -github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= -github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= -github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= -github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= -github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= -github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= -github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= -github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= -github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4safvEdbitLhGGK48rN6g= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= -github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= -github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-kit/kit v0.10.0/go.mod h1:xUsJbQ/Fp4kEt7AFgCuvyX4a71u8h9jB8tj/ORgOZ7o= -github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= -github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= -github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= -github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= -github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= -github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= -github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= -github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= -github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= -github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= -github.com/hashicorp/consul/api v1.3.0/go.mod h1:MmDNSzIMUjNpY/mQ398R4bk2FnqQLoPndWW5VkKPlCE= -github.com/hashicorp/consul/sdk v0.3.0/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= -github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= -github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= -github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= -github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= -github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= -github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= -github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= -github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= -github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= -github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= -github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= -github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= -github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= -github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= -github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= -github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= -github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= -github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= -github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= -github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= -github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= -github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= -github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= -github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= -github.com/jackc/pgconn v1.8.1/go.mod h1:JV6m6b6jhjdmzchES0drzCcYcAHS1OPD5xu3OZ/lE2g= -github.com/jackc/pgconn v1.9.0 h1:gqibKSTJup/ahCsNKyMZAniPuZEfIqfXFc8FOWVYR+Q= -github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= -github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= -github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd h1:eDErF6V/JPJON/B7s68BxwHgfmyOntHJQ8IOaz0x4R8= -github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= -github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= -github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= -github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= -github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= -github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= -github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= -github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= -github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= -github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= -github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= -github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= -github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= -github.com/jackc/pgtype v1.3.1-0.20200606141011-f6355165a91c/go.mod h1:cvk9Bgu/VzJ9/lxTO5R5sf80p0DiucVtN7ZxvaC4GmQ= -github.com/jackc/pgtype v1.7.0/go.mod h1:ZnHF+rMePVqDKaOfJVI4Q8IVvAQMryDlDkZnKOI75BE= -github.com/jackc/pgtype v1.8.0/go.mod h1:PqDKcEBtllAtk/2p6z6SHdXW5UB+MhE75tUol2OKexE= -github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= -github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= -github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= -github.com/jackc/pgx/v4 v4.5.0/go.mod h1:EpAKPLdnTorwmPUUsqrPxy5fphV18j9q3wrfRXgo+kA= -github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9/go.mod h1:t3/cdRQl6fOLDxqtlyhe9UWgfIi9R8+8v8GKV5TRA/o= -github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904/go.mod h1:ZDaNWkt9sW1JMiNn0kdYBaLelIhw7Pg4qd+Vk6tw7Hg= -github.com/jackc/pgx/v4 v4.11.0/go.mod h1:i62xJgdrtVDsnL3U8ekyrQXEwGNTRoG7/8r+CIdYfcc= -github.com/jackc/pgx/v4 v4.12.0 h1:xiP3TdnkwyslWNp77yE5XAPfxAsU9RMFDe0c1SwN8h4= -github.com/jackc/pgx/v4 v4.12.0/go.mod h1:fE547h6VulLPA3kySjfnSG/e2D861g/50JlVUa/ub60= -github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= -github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= -github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= -github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= -github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= -github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= -github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= -github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= -github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= -github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= -github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= -github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= -github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= -github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= -github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= -github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= -github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= -github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg= -github.com/nats-io/jwt v0.3.2/go.mod h1:/euKqTS1ZD+zzjYrY7pseZrTtWQSjujC7xjPc8wL6eU= -github.com/nats-io/nats-server/v2 v2.1.2/go.mod h1:Afk+wRZqkMQs/p45uXdrVLuab3gwv3Z8C4HTBu8GD/k= -github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzEE/Zbp4w= -github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= -github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= -github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= -github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs= -github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= -github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/op/go-logging v0.0.0-20160315200505-970db520ece7/go.mod h1:HzydrMdWErDVzsI23lYNej1Htcns9BCg93Dk0bBINWk= -github.com/opentracing-contrib/go-observer v0.0.0-20170622124052-a52f23424492/go.mod h1:Ngi6UdF0k5OKD5t5wlmGhe/EDKPoUM3BXZSSfIuJbis= -github.com/opentracing/basictracer-go v1.0.0/go.mod h1:QfBfYuafItcjQuMwinw9GhYKwFXS9KnPs5lxoYwgW74= -github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= -github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= -github.com/openzipkin-contrib/zipkin-go-opentracing v0.4.5/go.mod h1:/wsWhb9smxSfWAKL3wpBW7V8scJMt8N8gnaMCS9E/cA= -github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= -github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= -github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= -github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= -github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= -github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= -github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac= -github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= -github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= -github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs= -github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= -github.com/prometheus/client_golang v1.3.0/go.mod h1:hJaj2vgQTGQmVCsAACORcieXFeDPbaTKGT+JTgUa3og= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.1.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= -github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= -github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= -github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= -github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= -github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= -github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= -github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= -github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= -github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= -github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= -github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= -github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= -github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= -github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= -github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= -github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= -github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= -github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= -github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= -github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= -github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= -github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= -github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= -go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= -go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= -go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= -go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= -go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= -go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= -go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= -go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= -golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190530194941-fb225487d101/go.mod h1:z3L6/3dTEVtUr6QSP8miRzeRqwQOioJ9I66odjN4I7s= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM= -google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= -google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -google.golang.org/grpc v1.22.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/gcfg.v1 v1.2.3/go.mod h1:yesOnuUOFQAhST5vPY4nbZsb/huCgGGXlipJsBn0b3o= -gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= -gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= -gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= -sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU= diff --git a/hstore_array_test.go b/hstore_array_test.go deleted file mode 100644 index 672eca4a..00000000 --- a/hstore_array_test.go +++ /dev/null @@ -1,436 +0,0 @@ -package pgtype_test - -import ( - "context" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgx/v4" -) - -func TestHstoreArrayTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - var hstoreOID uint32 - err := conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='hstore';").Scan(&hstoreOID) - if err != nil { - t.Fatalf("did not find hstore OID, %v", err) - } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Hstore{}, Name: "hstore", OID: hstoreOID}) - - var hstoreArrayOID uint32 - err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='_hstore';").Scan(&hstoreArrayOID) - if err != nil { - t.Fatalf("did not find _hstore OID, %v", err) - } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.HstoreArray{}, Name: "_hstore", OID: hstoreArrayOID}) - - text := func(s string) pgtype.Text { - return pgtype.Text{String: s, Status: pgtype.Present} - } - - values := []pgtype.Hstore{ - {Map: map[string]pgtype.Text{}, Status: pgtype.Present}, - {Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, - {Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, - {Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, - {Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, - {Status: pgtype.Null}, - } - - specialStrings := []string{ - `"`, - `'`, - `\`, - `\\`, - `=>`, - ` `, - `\ / / \\ => " ' " '`, - } - for _, s := range specialStrings { - // Special key values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key - - // Special value values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key - } - - src := &pgtype.HstoreArray{ - Elements: values, - Dimensions: []pgtype.ArrayDimension{{Length: int32(len(values)), LowerBound: 1}}, - Status: pgtype.Present, - } - - _, err = conn.Prepare(context.Background(), "test", "select $1::hstore[]") - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for _, fc := range formats { - queryResultFormats := pgx.QueryResultFormats{fc.formatCode} - vEncoder := testutil.ForceEncoder(src, fc.formatCode) - if vEncoder == nil { - t.Logf("%#v does not implement %v", src, fc.name) - continue - } - - var result pgtype.HstoreArray - err := conn.QueryRow(context.Background(), "test", queryResultFormats, vEncoder).Scan(&result) - if err != nil { - t.Errorf("%v: %v", fc.name, err) - continue - } - - if result.Status != src.Status { - t.Errorf("%v: expected Status %v, got %v", fc.formatCode, src.Status, result.Status) - continue - } - - if len(result.Elements) != len(src.Elements) { - t.Errorf("%v: expected %v elements, got %v", fc.formatCode, len(src.Elements), len(result.Elements)) - continue - } - - for i := range result.Elements { - a := src.Elements[i] - b := result.Elements[i] - - if a.Status != b.Status { - t.Errorf("%v element idx %d: expected status %v, got %v", fc.formatCode, i, a.Status, b.Status) - } - - if len(a.Map) != len(b.Map) { - t.Errorf("%v element idx %d: expected %v pairs, got %v", fc.formatCode, i, len(a.Map), len(b.Map)) - } - - for k := range a.Map { - if a.Map[k] != b.Map[k] { - t.Errorf("%v element idx %d: expected key %v to be %v, got %v", fc.formatCode, i, k, a.Map[k], b.Map[k]) - } - } - } - } -} - -func TestHstoreArraySet(t *testing.T) { - successfulTests := []struct { - src interface{} - result pgtype.HstoreArray - }{ - { - src: []map[string]string{{"foo": "bar"}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - }, - { - src: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - }, - { - src: [][][][]map[string]string{ - {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, - {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - }, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present, - }, - }, - { - src: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - }, - { - src: [2][1][1][3]map[string]string{ - {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, - {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - }, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present, - }, - }, - } - - for i, tt := range successfulTests { - var dst pgtype.HstoreArray - err := dst.Set(tt.src) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(dst, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) - } - } -} - -func TestHstoreArrayAssignTo(t *testing.T) { - var hstoreSlice []map[string]string - var hstoreSliceDim2 [][]map[string]string - var hstoreSliceDim4 [][][][]map[string]string - var hstoreArrayDim2 [2][1]map[string]string - var hstoreArrayDim4 [2][1][1][3]map[string]string - - simpleTests := []struct { - src pgtype.HstoreArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &hstoreSlice, - expected: []map[string]string{{"foo": "bar"}}}, - { - src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), - }, - { - src: pgtype.HstoreArray{Status: pgtype.Present}, dst: &hstoreSlice, expected: []map[string]string{}, - }, - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &hstoreSliceDim2, - expected: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, - }, - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - }, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present, - }, - dst: &hstoreSliceDim4, - expected: [][][][]map[string]string{ - {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, - {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, - }, - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &hstoreArrayDim2, - expected: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, - }, - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - }, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present, - }, - dst: &hstoreArrayDim4, - expected: [2][1][1][3]map[string]string{ - {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, - {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/hstore_test.go b/hstore_test.go deleted file mode 100644 index dce8baf2..00000000 --- a/hstore_test.go +++ /dev/null @@ -1,111 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestHstoreTranscode(t *testing.T) { - text := func(s string) pgtype.Text { - return pgtype.Text{String: s, Status: pgtype.Present} - } - - values := []interface{}{ - &pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(""), "bar": text(""), "baz": text("123")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"": text("bar")}, Status: pgtype.Present}, - &pgtype.Hstore{Status: pgtype.Null}, - } - - specialStrings := []string{ - `"`, - `'`, - `\`, - `\\`, - `=>`, - ` `, - `\ / / \\ => " ' " '`, - } - for _, s := range specialStrings { - // Special key values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key - - // Special value values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key - } - - testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { - a := ai.(pgtype.Hstore) - b := bi.(pgtype.Hstore) - - if len(a.Map) != len(b.Map) || a.Status != b.Status { - return false - } - - for k := range a.Map { - if a.Map[k] != b.Map[k] { - return false - } - } - - return true - }) -} - -func TestHstoreSet(t *testing.T) { - successfulTests := []struct { - src map[string]string - result pgtype.Hstore - }{ - {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var dst pgtype.Hstore - err := dst.Set(tt.src) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(dst, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) - } - } -} - -func TestHstoreAssignTo(t *testing.T) { - var m map[string]string - - simpleTests := []struct { - src pgtype.Hstore - dst *map[string]string - expected map[string]string - }{ - {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}, dst: &m, expected: map[string]string{"foo": "bar"}}, - {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]string)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(*tt.dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } -} diff --git a/inet_array_test.go b/inet_array_test.go deleted file mode 100644 index 46dc7d12..00000000 --- a/inet_array_test.go +++ /dev/null @@ -1,319 +0,0 @@ -package pgtype_test - -import ( - "net" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestInetArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "inet[]", []interface{}{ - &pgtype.InetArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.InetArray{Status: pgtype.Null}, - &pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - {Status: pgtype.Null}, - {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestInetArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.InetArray - }{ - { - source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]*net.IPNet)(nil)), - result: pgtype.InetArray{Status: pgtype.Null}, - }, - { - source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]net.IP)(nil)), - result: pgtype.InetArray{Status: pgtype.Null}, - }, - { - source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.InetArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInetArrayAssignTo(t *testing.T) { - var ipnetSlice []*net.IPNet - var ipSlice []net.IP - var ipSliceDim2 [][]net.IP - var ipnetSliceDim4 [][][][]*net.IPNet - var ipArrayDim2 [2][1]net.IP - var ipnetArrayDim4 [2][1][1][3]*net.IPNet - - simpleTests := []struct { - src pgtype.InetArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{nil}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipSlice, - expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipSlice, - expected: []net.IP{nil}, - }, - { - src: pgtype.InetArray{Status: pgtype.Null}, - dst: &ipnetSlice, - expected: (([]*net.IPNet)(nil)), - }, - { - src: pgtype.InetArray{Status: pgtype.Present}, - dst: &ipnetSlice, - expected: []*net.IPNet{}, - }, - { - src: pgtype.InetArray{Status: pgtype.Null}, - dst: &ipSlice, - expected: (([]net.IP)(nil)), - }, - { - src: pgtype.InetArray{Status: pgtype.Present}, - dst: &ipSlice, - expected: []net.IP{}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &ipSliceDim2, - expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &ipnetSliceDim4, - expected: [][][][]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &ipArrayDim2, - expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &ipnetArrayDim4, - expected: [2][1][1][3]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/inet_test.go b/inet_test.go deleted file mode 100644 index 66fe777f..00000000 --- a/inet_test.go +++ /dev/null @@ -1,134 +0,0 @@ -package pgtype_test - -import ( - "net" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" - "github.com/stretchr/testify/assert" -) - -func TestInetTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "inet", []interface{}{ - &pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "127.0.0.1/8"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "12.34.56.65/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "192.168.1.16/24"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "255.0.0.0/8"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "255.255.255.255/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "::1/64"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "::/0"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "::1/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), Status: pgtype.Present}, - &pgtype.Inet{Status: pgtype.Null}, - }) -} - -func TestCidrTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "cidr", []interface{}{ - &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - &pgtype.Inet{Status: pgtype.Null}, - }) -} - -func TestInetSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Inet - }{ - {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4"), Mask: net.CIDRMask(24, 32)}, Status: pgtype.Present}}, - {source: net.ParseIP(""), result: pgtype.Inet{Status: pgtype.Null}}, - } - - for i, tt := range successfulTests { - var r pgtype.Inet - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - assert.Equalf(t, tt.result.Status, r.Status, "%d: Status", i) - if tt.result.Status == pgtype.Present { - assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: IP", i) - assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: Mask", i) - } - } -} - -func TestInetAssignTo(t *testing.T) { - var ipnet net.IPNet - var pipnet *net.IPNet - var ip net.IP - var pip *net.IP - - simpleTests := []struct { - src pgtype.Inet - dst interface{} - expected interface{} - }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, - {src: pgtype.Inet{Status: pgtype.Null}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, - {src: pgtype.Inet{Status: pgtype.Null}, dst: &pip, expected: ((*net.IP)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %#v, but result was %#v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Inet - dst interface{} - expected interface{} - }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Inet - dst interface{} - }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, - {src: pgtype.Inet{Status: pgtype.Null}, dst: &ipnet}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/int2_array_test.go b/int2_array_test.go deleted file mode 100644 index 17c37360..00000000 --- a/int2_array_test.go +++ /dev/null @@ -1,342 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestInt2ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int2[]", []interface{}{ - &pgtype.Int2Array{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int2Array{Status: pgtype.Null}, - &pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 6, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestInt2ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int2Array - }{ - { - source: []int64{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []int32{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []int16{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []int{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []uint64{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []uint32{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []uint16{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]int16)(nil)), - result: pgtype.Int2Array{Status: pgtype.Null}, - }, - { - source: [][]int16{{1}, {2}}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]int16{{1}, {2}}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Int2Array - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt2ArrayAssignTo(t *testing.T) { - var int16Slice []int16 - var uint16Slice []uint16 - var namedInt16Slice _int16Slice - var int16SliceDim2 [][]int16 - var int16SliceDim4 [][][][]int16 - var int16ArrayDim2 [2][1]int16 - var int16ArrayDim4 [2][1][1][3]int16 - - simpleTests := []struct { - src pgtype.Int2Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &int16Slice, - expected: []int16{1}, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &uint16Slice, - expected: []uint16{1}, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedInt16Slice, - expected: _int16Slice{1}, - }, - { - src: pgtype.Int2Array{Status: pgtype.Null}, - dst: &int16Slice, - expected: (([]int16)(nil)), - }, - { - src: pgtype.Int2Array{Status: pgtype.Present}, - dst: &int16Slice, - expected: []int16{}, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - expected: [][]int16{{1}, {2}}, - dst: &int16SliceDim2, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - expected: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &int16SliceDim4, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - expected: [2][1]int16{{1}, {2}}, - dst: &int16ArrayDim2, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - expected: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &int16ArrayDim4, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int2Array - dst interface{} - }{ - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &int16Slice, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: -1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &uint16Slice, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &int16ArrayDim2, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &int16Slice, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &int16ArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/int2_test.go b/int2_test.go deleted file mode 100644 index 178eb278..00000000 --- a/int2_test.go +++ /dev/null @@ -1,144 +0,0 @@ -package pgtype_test - -import ( - "math" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestInt2Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ - &pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, - &pgtype.Int2{Int: -1, Status: pgtype.Present}, - &pgtype.Int2{Int: 0, Status: pgtype.Present}, - &pgtype.Int2{Int: 1, Status: pgtype.Present}, - &pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}, - &pgtype.Int2{Int: 0, Status: pgtype.Null}, - }) -} - -func TestInt2Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int2 - }{ - {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: float32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: float64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.Int2 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt2AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - - simpleTests := []struct { - src pgtype.Int2 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Int2 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int2 - dst interface{} - }{ - {src: pgtype.Int2{Int: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &i16}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/int4_array_test.go b/int4_array_test.go deleted file mode 100644 index 110512a9..00000000 --- a/int4_array_test.go +++ /dev/null @@ -1,356 +0,0 @@ -package pgtype_test - -import ( - "math" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestInt4ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int4[]", []interface{}{ - &pgtype.Int4Array{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int4Array{Status: pgtype.Null}, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 6, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestInt4ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int4Array - expectedError bool - }{ - { - source: []int64{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []int32{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []int16{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []int{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []int{1, math.MaxInt32 + 1, 2}, - expectedError: true, - }, - { - source: []uint64{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []uint32{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []uint16{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]int32)(nil)), - result: pgtype.Int4Array{Status: pgtype.Null}, - }, - { - source: [][]int32{{1}, {2}}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]int32{{1}, {2}}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Int4Array - err := r.Set(tt.source) - if err != nil { - if tt.expectedError { - continue - } - t.Errorf("%d: %v", i, err) - } - - if tt.expectedError { - t.Errorf("%d: an error was expected, %v", i, tt) - continue - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt4ArrayAssignTo(t *testing.T) { - var int32Slice []int32 - var uint32Slice []uint32 - var namedInt32Slice _int32Slice - var int32SliceDim2 [][]int32 - var int32SliceDim4 [][][][]int32 - var int32ArrayDim2 [2][1]int32 - var int32ArrayDim4 [2][1][1][3]int32 - - simpleTests := []struct { - src pgtype.Int4Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &int32Slice, - expected: []int32{1}, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &uint32Slice, - expected: []uint32{1}, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedInt32Slice, - expected: _int32Slice{1}, - }, - { - src: pgtype.Int4Array{Status: pgtype.Null}, - dst: &int32Slice, - expected: (([]int32)(nil)), - }, - { - src: pgtype.Int4Array{Status: pgtype.Present}, - dst: &int32Slice, - expected: []int32{}, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - expected: [][]int32{{1}, {2}}, - dst: &int32SliceDim2, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - expected: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &int32SliceDim4, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - expected: [2][1]int32{{1}, {2}}, - dst: &int32ArrayDim2, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - expected: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &int32ArrayDim4, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int4Array - dst interface{} - }{ - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &int32Slice, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: -1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &uint32Slice, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &int32ArrayDim2, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &int32Slice, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &int32ArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/int4_test.go b/int4_test.go deleted file mode 100644 index ae01114f..00000000 --- a/int4_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package pgtype_test - -import ( - "math" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestInt4Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ - &pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, - &pgtype.Int4{Int: -1, Status: pgtype.Present}, - &pgtype.Int4{Int: 0, Status: pgtype.Present}, - &pgtype.Int4{Int: 1, Status: pgtype.Present}, - &pgtype.Int4{Int: math.MaxInt32, Status: pgtype.Present}, - &pgtype.Int4{Int: 0, Status: pgtype.Null}, - }) -} - -func TestInt4Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int4 - }{ - {source: int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: float32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: float64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.Int4 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt4AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - - simpleTests := []struct { - src pgtype.Int4 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Int4 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int4 - dst interface{} - }{ - {src: pgtype.Int4{Int: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Int4{Int: 40000, Status: pgtype.Present}, dst: &i16}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - -func TestInt4MarshalJSON(t *testing.T) { - successfulTests := []struct { - source pgtype.Int4 - result string - }{ - {source: pgtype.Int4{Int: 0, Status: pgtype.Null}, result: "null"}, - {source: pgtype.Int4{Int: 1, Status: pgtype.Present}, result: "1"}, - } - for i, tt := range successfulTests { - r, err := tt.source.MarshalJSON() - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r) != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) - } - } -} - -func TestInt4UnmarshalJSON(t *testing.T) { - successfulTests := []struct { - source string - result pgtype.Int4 - }{ - {source: "null", result: pgtype.Int4{Int: 0, Status: pgtype.Null}}, - {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - } - for i, tt := range successfulTests { - var r pgtype.Int4 - err := r.UnmarshalJSON([]byte(tt.source)) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} diff --git a/int4range_test.go b/int4range_test.go deleted file mode 100644 index 43626189..00000000 --- a/int4range_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestInt4rangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int4range", []interface{}{ - &pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Status: pgtype.Present}, - &pgtype.Int4range{Upper: pgtype.Int4{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int4range{Status: pgtype.Null}, - }) -} - -func TestInt4rangeNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select int4range(1, 10, '(]')", - Value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - }, - }) -} diff --git a/int8_array_test.go b/int8_array_test.go deleted file mode 100644 index 1d42a278..00000000 --- a/int8_array_test.go +++ /dev/null @@ -1,349 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestInt8ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int8[]", []interface{}{ - &pgtype.Int8Array{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int8Array{Status: pgtype.Null}, - &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 6, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestInt8ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int8Array - }{ - { - source: []int64{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []int32{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []int16{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []int{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []uint64{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []uint32{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []uint16{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []uint{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]int64)(nil)), - result: pgtype.Int8Array{Status: pgtype.Null}, - }, - { - source: [][]int64{{1}, {2}}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]int64{{1}, {2}}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Int8Array - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt8ArrayAssignTo(t *testing.T) { - var int64Slice []int64 - var uint64Slice []uint64 - var namedInt64Slice _int64Slice - var int64SliceDim2 [][]int64 - var int64SliceDim4 [][][][]int64 - var int64ArrayDim2 [2][1]int64 - var int64ArrayDim4 [2][1][1][3]int64 - - simpleTests := []struct { - src pgtype.Int8Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &int64Slice, - expected: []int64{1}, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &uint64Slice, - expected: []uint64{1}, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedInt64Slice, - expected: _int64Slice{1}, - }, - { - src: pgtype.Int8Array{Status: pgtype.Null}, - dst: &int64Slice, - expected: (([]int64)(nil)), - }, - { - src: pgtype.Int8Array{Status: pgtype.Present}, - dst: &int64Slice, - expected: []int64{}, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - expected: [][]int64{{1}, {2}}, - dst: &int64SliceDim2, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - expected: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &int64SliceDim4, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - expected: [2][1]int64{{1}, {2}}, - dst: &int64ArrayDim2, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - expected: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &int64ArrayDim4, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int8Array - dst interface{} - }{ - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &int64Slice, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: -1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &uint64Slice, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &int64ArrayDim2, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &int64Slice, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &int64ArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/int8_test.go b/int8_test.go deleted file mode 100644 index 4e28e374..00000000 --- a/int8_test.go +++ /dev/null @@ -1,187 +0,0 @@ -package pgtype_test - -import ( - "math" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestInt8Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ - &pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, - &pgtype.Int8{Int: -1, Status: pgtype.Present}, - &pgtype.Int8{Int: 0, Status: pgtype.Present}, - &pgtype.Int8{Int: 1, Status: pgtype.Present}, - &pgtype.Int8{Int: math.MaxInt64, Status: pgtype.Present}, - &pgtype.Int8{Int: 0, Status: pgtype.Null}, - }) -} - -func TestInt8Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int8 - }{ - {source: int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: float32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: float64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.Int8 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt8AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - - simpleTests := []struct { - src pgtype.Int8 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Int8 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int8 - dst interface{} - }{ - {src: pgtype.Int8{Int: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Int8{Int: 40000, Status: pgtype.Present}, dst: &i16}, - {src: pgtype.Int8{Int: 5000000000, Status: pgtype.Present}, dst: &i32}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &i64}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - -func TestInt8MarshalJSON(t *testing.T) { - successfulTests := []struct { - source pgtype.Int8 - result string - }{ - {source: pgtype.Int8{Int: 0, Status: pgtype.Null}, result: "null"}, - {source: pgtype.Int8{Int: 1, Status: pgtype.Present}, result: "1"}, - } - for i, tt := range successfulTests { - r, err := tt.source.MarshalJSON() - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r) != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) - } - } -} - -func TestInt8UnmarshalJSON(t *testing.T) { - successfulTests := []struct { - source string - result pgtype.Int8 - }{ - {source: "null", result: pgtype.Int8{Int: 0, Status: pgtype.Null}}, - {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - } - for i, tt := range successfulTests { - var r pgtype.Int8 - err := r.UnmarshalJSON([]byte(tt.source)) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} diff --git a/int8range_test.go b/int8range_test.go deleted file mode 100644 index 99d4e8a3..00000000 --- a/int8range_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestInt8rangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "Int8range", []interface{}{ - &pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Status: pgtype.Present}, - &pgtype.Int8range{Upper: pgtype.Int8{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int8range{Status: pgtype.Null}, - }) -} - -func TestInt8rangeNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select Int8range(1, 10, '(]')", - Value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - }, - }) -} diff --git a/interval_test.go b/interval_test.go deleted file mode 100644 index 1ee094d7..00000000 --- a/interval_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package pgtype_test - -import ( - "testing" - "time" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestIntervalTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "interval", []interface{}{ - &pgtype.Interval{Microseconds: 1, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, - &pgtype.Interval{Days: 1, Status: pgtype.Present}, - &pgtype.Interval{Months: 1, Status: pgtype.Present}, - &pgtype.Interval{Months: 12, Status: pgtype.Present}, - &pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: -1, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: -1000000, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: -1000001, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: -123202800000000, Status: pgtype.Present}, - &pgtype.Interval{Days: -1, Status: pgtype.Present}, - &pgtype.Interval{Months: -1, Status: pgtype.Present}, - &pgtype.Interval{Months: -12, Status: pgtype.Present}, - &pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Status: pgtype.Present}, - &pgtype.Interval{Status: pgtype.Null}, - }) -} - -func TestIntervalNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select '1 second'::interval", - Value: &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, - }, - { - SQL: "select '1.000001 second'::interval", - Value: &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, - }, - { - SQL: "select '34223 hours'::interval", - Value: &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, - }, - { - SQL: "select '1 day'::interval", - Value: &pgtype.Interval{Days: 1, Status: pgtype.Present}, - }, - { - SQL: "select '1 month'::interval", - Value: &pgtype.Interval{Months: 1, Status: pgtype.Present}, - }, - { - SQL: "select '1 year'::interval", - Value: &pgtype.Interval{Months: 12, Status: pgtype.Present}, - }, - { - SQL: "select '-13 mon'::interval", - Value: &pgtype.Interval{Months: -13, Status: pgtype.Present}, - }, - }) -} - -func TestIntervalLossyConversionToDuration(t *testing.T) { - interval := &pgtype.Interval{Months: 1, Days: 1, Status: pgtype.Present} - var d time.Duration - err := interval.AssignTo(&d) - require.NoError(t, err) - assert.EqualValues(t, int64(2678400000000000), d.Nanoseconds()) -} diff --git a/json_test.go b/json_test.go deleted file mode 100644 index bbd3959e..00000000 --- a/json_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package pgtype_test - -import ( - "bytes" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestJSONTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "json", []interface{}{ - &pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, - &pgtype.JSON{Bytes: []byte("null"), Status: pgtype.Present}, - &pgtype.JSON{Bytes: []byte("42"), Status: pgtype.Present}, - &pgtype.JSON{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - &pgtype.JSON{Status: pgtype.Null}, - }) -} - -func TestJSONSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.JSON - }{ - {source: "{}", result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: []byte("{}"), result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: ([]byte)(nil), result: pgtype.JSON{Status: pgtype.Null}}, - {source: (*string)(nil), result: pgtype.JSON{Status: pgtype.Null}}, - {source: []int{1, 2, 3}, result: pgtype.JSON{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var d pgtype.JSON - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(d, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } -} - -func TestJSONAssignTo(t *testing.T) { - var s string - var ps *string - var b []byte - - rawStringTests := []struct { - src pgtype.JSON - dst *string - expected string - }{ - {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, - } - - for i, tt := range rawStringTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - rawBytesTests := []struct { - src pgtype.JSON - dst *[]byte - expected []byte - }{ - {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, - {src: pgtype.JSON{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, - } - - for i, tt := range rawBytesTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if bytes.Compare(tt.expected, *tt.dst) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - var mapDst map[string]interface{} - type structDst struct { - Name string `json:"name"` - Age int `json:"age"` - } - var strDst structDst - - unmarshalTests := []struct { - src pgtype.JSON - dst interface{} - expected interface{} - }{ - {src: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.JSON{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, - } - for i, tt := range unmarshalTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.JSON - dst **string - expected *string - }{ - {src: pgtype.JSON{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } -} - -func TestJSONMarshalJSON(t *testing.T) { - successfulTests := []struct { - source pgtype.JSON - result string - }{ - {source: pgtype.JSON{Status: pgtype.Null}, result: "null"}, - {source: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Status: pgtype.Present}, result: "{\"a\": 1}"}, - } - for i, tt := range successfulTests { - r, err := tt.source.MarshalJSON() - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r) != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) - } - } -} - -func TestJSONUnmarshalJSON(t *testing.T) { - successfulTests := []struct { - source string - result pgtype.JSON - }{ - {source: "null", result: pgtype.JSON{Status: pgtype.Null}}, - {source: "{\"a\": 1}", result: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Status: pgtype.Present}}, - } - for i, tt := range successfulTests { - var r pgtype.JSON - err := r.UnmarshalJSON([]byte(tt.source)) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r.Bytes) != string(tt.result.Bytes) || r.Status != tt.result.Status { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} diff --git a/jsonb_array_test.go b/jsonb_array_test.go deleted file mode 100644 index 65f1777a..00000000 --- a/jsonb_array_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestJSONBArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "jsonb[]", []interface{}{ - &pgtype.JSONBArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.JSONBArray{ - Elements: []pgtype.JSONB{ - {Bytes: []byte(`"foo"`), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.JSONBArray{Status: pgtype.Null}, - &pgtype.JSONBArray{ - Elements: []pgtype.JSONB{ - {Bytes: []byte(`"foo"`), Status: pgtype.Present}, - {Bytes: []byte("null"), Status: pgtype.Present}, - {Bytes: []byte("42"), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, - Status: pgtype.Present, - }, - }) -} diff --git a/jsonb_test.go b/jsonb_test.go deleted file mode 100644 index 9ce80d42..00000000 --- a/jsonb_test.go +++ /dev/null @@ -1,142 +0,0 @@ -package pgtype_test - -import ( - "bytes" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestJSONBTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - if _, ok := conn.ConnInfo().DataTypeForName("jsonb"); !ok { - t.Skip("Skipping due to no jsonb type") - } - - testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ - &pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, - &pgtype.JSONB{Bytes: []byte("null"), Status: pgtype.Present}, - &pgtype.JSONB{Bytes: []byte("42"), Status: pgtype.Present}, - &pgtype.JSONB{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - &pgtype.JSONB{Status: pgtype.Null}, - }) -} - -func TestJSONBSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.JSONB - }{ - {source: "{}", result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: []byte("{}"), result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: ([]byte)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, - {source: (*string)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, - {source: []int{1, 2, 3}, result: pgtype.JSONB{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var d pgtype.JSONB - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(d, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } -} - -func TestJSONBAssignTo(t *testing.T) { - var s string - var ps *string - var b []byte - - rawStringTests := []struct { - src pgtype.JSONB - dst *string - expected string - }{ - {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, - } - - for i, tt := range rawStringTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - rawBytesTests := []struct { - src pgtype.JSONB - dst *[]byte - expected []byte - }{ - {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, - {src: pgtype.JSONB{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, - } - - for i, tt := range rawBytesTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if bytes.Compare(tt.expected, *tt.dst) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - var mapDst map[string]interface{} - type structDst struct { - Name string `json:"name"` - Age int `json:"age"` - } - var strDst structDst - - unmarshalTests := []struct { - src pgtype.JSONB - dst interface{} - expected interface{} - }{ - {src: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.JSONB{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, - } - for i, tt := range unmarshalTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.JSONB - dst **string - expected *string - }{ - {src: pgtype.JSONB{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } -} diff --git a/line_test.go b/line_test.go deleted file mode 100644 index f697ac43..00000000 --- a/line_test.go +++ /dev/null @@ -1,38 +0,0 @@ -package pgtype_test - -import ( - "context" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestLineTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { - t.Skip("Skipping due to no line type") - } - - // line may exist but not be usable on 9.3 :( - var isPG93 bool - err := conn.QueryRow(context.Background(), "select version() ~ '9.3'").Scan(&isPG93) - if err != nil { - t.Fatal(err) - } - if isPG93 { - t.Skip("Skipping due to unimplemented line type in PG 9.3") - } - - testutil.TestSuccessfulTranscode(t, "line", []interface{}{ - &pgtype.Line{ - A: 1.23, B: 4.56, C: 7.89012345, - Status: pgtype.Present, - }, - &pgtype.Line{ - A: -1.23, B: -4.56, C: -7.89, - Status: pgtype.Present, - }, - &pgtype.Line{Status: pgtype.Null}, - }) -} diff --git a/lseg_test.go b/lseg_test.go deleted file mode 100644 index b75297cc..00000000 --- a/lseg_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestLsegTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "lseg", []interface{}{ - &pgtype.Lseg{ - P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, - Status: pgtype.Present, - }, - &pgtype.Lseg{ - P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Status: pgtype.Present, - }, - &pgtype.Lseg{Status: pgtype.Null}, - }) -} diff --git a/macaddr_array_test.go b/macaddr_array_test.go deleted file mode 100644 index c1a8b72d..00000000 --- a/macaddr_array_test.go +++ /dev/null @@ -1,262 +0,0 @@ -package pgtype_test - -import ( - "net" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestMacaddrArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "macaddr[]", []interface{}{ - &pgtype.MacaddrArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.MacaddrArray{Status: pgtype.Null}, - }) -} - -func TestMacaddrArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.MacaddrArray - }{ - { - source: []net.HardwareAddr{mustParseMacaddr(t, "01:23:45:67:89:ab")}, - result: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]net.HardwareAddr)(nil)), - result: pgtype.MacaddrArray{Status: pgtype.Null}, - }, - { - source: [][]net.HardwareAddr{ - {mustParseMacaddr(t, "01:23:45:67:89:ab")}, - {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, - result: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]net.HardwareAddr{ - {{{ - mustParseMacaddr(t, "01:23:45:67:89:ab"), - mustParseMacaddr(t, "cd:ef:01:23:45:67"), - mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, - {{{ - mustParseMacaddr(t, "45:67:89:ab:cd:ef"), - mustParseMacaddr(t, "fe:dc:ba:98:76:54"), - mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, - result: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]net.HardwareAddr{ - {mustParseMacaddr(t, "01:23:45:67:89:ab")}, - {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, - result: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]net.HardwareAddr{ - {{{ - mustParseMacaddr(t, "01:23:45:67:89:ab"), - mustParseMacaddr(t, "cd:ef:01:23:45:67"), - mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, - {{{ - mustParseMacaddr(t, "45:67:89:ab:cd:ef"), - mustParseMacaddr(t, "fe:dc:ba:98:76:54"), - mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, - result: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.MacaddrArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestMacaddrArrayAssignTo(t *testing.T) { - var macaddrSlice []net.HardwareAddr - var macaddrSliceDim2 [][]net.HardwareAddr - var macaddrSliceDim4 [][][][]net.HardwareAddr - var macaddrArrayDim2 [2][1]net.HardwareAddr - var macaddrArrayDim4 [2][1][1][3]net.HardwareAddr - - simpleTests := []struct { - src pgtype.MacaddrArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &macaddrSlice, - expected: []net.HardwareAddr{mustParseMacaddr(t, "01:23:45:67:89:ab")}, - }, - { - src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &macaddrSlice, - expected: []net.HardwareAddr{nil}, - }, - { - src: pgtype.MacaddrArray{Status: pgtype.Null}, - dst: &macaddrSlice, - expected: (([]net.HardwareAddr)(nil)), - }, - { - src: pgtype.MacaddrArray{Status: pgtype.Present}, - dst: &macaddrSlice, - expected: []net.HardwareAddr{}, - }, - { - src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &macaddrSliceDim2, - expected: [][]net.HardwareAddr{ - {mustParseMacaddr(t, "01:23:45:67:89:ab")}, - {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, - }, - { - src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &macaddrSliceDim4, - expected: [][][][]net.HardwareAddr{ - {{{ - mustParseMacaddr(t, "01:23:45:67:89:ab"), - mustParseMacaddr(t, "cd:ef:01:23:45:67"), - mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, - {{{ - mustParseMacaddr(t, "45:67:89:ab:cd:ef"), - mustParseMacaddr(t, "fe:dc:ba:98:76:54"), - mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, - }, - { - src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &macaddrArrayDim2, - expected: [2][1]net.HardwareAddr{ - {mustParseMacaddr(t, "01:23:45:67:89:ab")}, - {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, - }, - { - src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &macaddrArrayDim4, - expected: [2][1][1][3]net.HardwareAddr{ - {{{ - mustParseMacaddr(t, "01:23:45:67:89:ab"), - mustParseMacaddr(t, "cd:ef:01:23:45:67"), - mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, - {{{ - mustParseMacaddr(t, "45:67:89:ab:cd:ef"), - mustParseMacaddr(t, "fe:dc:ba:98:76:54"), - mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/macaddr_test.go b/macaddr_test.go deleted file mode 100644 index 364a8914..00000000 --- a/macaddr_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package pgtype_test - -import ( - "bytes" - "net" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestMacaddrTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "macaddr", []interface{}{ - &pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - &pgtype.Macaddr{Status: pgtype.Null}, - }) -} - -func TestMacaddrSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Macaddr - }{ - { - source: mustParseMacaddr(t, "01:23:45:67:89:ab"), - result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - }, - { - source: "01:23:45:67:89:ab", - result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Macaddr - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestMacaddrAssignTo(t *testing.T) { - { - src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} - var dst net.HardwareAddr - expected := mustParseMacaddr(t, "01:23:45:67:89:ab") - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if bytes.Compare([]byte(dst), []byte(expected)) != 0 { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} - var dst string - expected := "01:23:45:67:89:ab" - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } -} diff --git a/name_test.go b/name_test.go deleted file mode 100644 index 75329b01..00000000 --- a/name_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestNameTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "name", []interface{}{ - &pgtype.Name{String: "", Status: pgtype.Present}, - &pgtype.Name{String: "foo", Status: pgtype.Present}, - &pgtype.Name{Status: pgtype.Null}, - }) -} - -func TestNameSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Name - }{ - {source: "foo", result: pgtype.Name{String: "foo", Status: pgtype.Present}}, - {source: _string("bar"), result: pgtype.Name{String: "bar", Status: pgtype.Present}}, - {source: (*string)(nil), result: pgtype.Name{Status: pgtype.Null}}, - } - - for i, tt := range successfulTests { - var d pgtype.Name - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if d != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } -} - -func TestNameAssignTo(t *testing.T) { - var s string - var ps *string - - simpleTests := []struct { - src pgtype.Name - dst interface{} - expected interface{} - }{ - {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, - {src: pgtype.Name{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Name - dst interface{} - expected interface{} - }{ - {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Name - dst interface{} - }{ - {src: pgtype.Name{Status: pgtype.Null}, dst: &s}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/numeric_array_test.go b/numeric_array_test.go deleted file mode 100644 index 7c1e8c3b..00000000 --- a/numeric_array_test.go +++ /dev/null @@ -1,305 +0,0 @@ -package pgtype_test - -import ( - "math" - "math/big" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestNumericArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "numeric[]", []interface{}{ - &pgtype.NumericArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.NumericArray{Status: pgtype.Null}, - &pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: big.NewInt(6), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestNumericArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.NumericArray - }{ - { - source: []float32{1}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []float32{float32(math.Copysign(0, -1))}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(0), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []float64{1}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []float64{math.Copysign(0, -1)}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(0), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]float32)(nil)), - result: pgtype.NumericArray{Status: pgtype.Null}, - }, - { - source: [][]float32{{1}, {2}}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, - {Int: big.NewInt(5), Status: pgtype.Present}, - {Int: big.NewInt(6), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]float32{{1}, {2}}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, - {Int: big.NewInt(5), Status: pgtype.Present}, - {Int: big.NewInt(6), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.NumericArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestNumericArrayAssignTo(t *testing.T) { - var float32Slice []float32 - var float64Slice []float64 - var float32SliceDim2 [][]float32 - var float32SliceDim4 [][][][]float32 - var float32ArrayDim2 [2][1]float32 - var float32ArrayDim4 [2][1][1][3]float32 - - simpleTests := []struct { - src pgtype.NumericArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float32Slice, - expected: []float32{1}, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float64Slice, - expected: []float64{1}, - }, - { - src: pgtype.NumericArray{Status: pgtype.Null}, - dst: &float32Slice, - expected: (([]float32)(nil)), - }, - { - src: pgtype.NumericArray{Status: pgtype.Present}, - dst: &float32Slice, - expected: []float32{}, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &float32SliceDim2, - expected: [][]float32{{1}, {2}}, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, - {Int: big.NewInt(5), Status: pgtype.Present}, - {Int: big.NewInt(6), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &float32SliceDim4, - expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &float32ArrayDim2, - expected: [2][1]float32{{1}, {2}}, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, - {Int: big.NewInt(5), Status: pgtype.Present}, - {Int: big.NewInt(6), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &float32ArrayDim4, - expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.NumericArray - dst interface{} - }{ - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float32Slice, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &float32ArrayDim2, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &float32Slice, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &float32ArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/numeric_test.go b/numeric_test.go deleted file mode 100644 index 81595cb3..00000000 --- a/numeric_test.go +++ /dev/null @@ -1,389 +0,0 @@ -package pgtype_test - -import ( - "math" - "math/big" - "math/rand" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -// For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) -func numericEqual(left, right *pgtype.Numeric) bool { - return left.Status == right.Status && - left.Exp == right.Exp && - ((left.Int == nil && right.Int == nil) || (left.Int != nil && right.Int != nil && left.Int.Cmp(right.Int) == 0)) && - left.NaN == right.NaN -} - -// For test purposes only. -func numericNormalizedEqual(left, right *pgtype.Numeric) bool { - if left.Status != right.Status { - return false - } - - normLeft := &pgtype.Numeric{Int: (&big.Int{}).Set(left.Int), Status: left.Status} - normRight := &pgtype.Numeric{Int: (&big.Int{}).Set(right.Int), Status: right.Status} - - if left.Exp < right.Exp { - mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(right.Exp-left.Exp)), nil) - normRight.Int.Mul(normRight.Int, mul) - } else if left.Exp > right.Exp { - mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(left.Exp-right.Exp)), nil) - normLeft.Int.Mul(normLeft.Int, mul) - } - - return normLeft.Int.Cmp(normRight.Int) == 0 -} - -func mustParseBigInt(t *testing.T, src string) *big.Int { - i := &big.Int{} - if _, ok := i.SetString(src, 10); !ok { - t.Fatalf("could not parse big.Int: %s", src) - } - return i -} - -func TestNumericNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select '0'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, - }, - { - SQL: "select '1'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, - }, - { - SQL: "select '10.00'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, - }, - { - SQL: "select '1e-3'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, - }, - { - SQL: "select '-1'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, - }, - { - SQL: "select '10000'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, - }, - { - SQL: "select '3.14'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, - }, - { - SQL: "select '1.1'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, - }, - { - SQL: "select '100010001'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, - }, - { - SQL: "select '100010001.0001'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, - }, - { - SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", - Value: &pgtype.Numeric{ - Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), - Exp: -41, - Status: pgtype.Present, - }, - }, - { - SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", - Value: &pgtype.Numeric{ - Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), - Exp: -196, - Status: pgtype.Present, - }, - }, - { - SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", - Value: &pgtype.Numeric{ - Int: mustParseBigInt(t, "123"), - Exp: -186, - Status: pgtype.Present, - }, - }, - }) -} - -func TestNumericTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ - &pgtype.Numeric{NaN: true, Status: pgtype.Present}, - - &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(1), Exp: 6, Status: pgtype.Present}, - - // preserves significant zeroes - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -1, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -2, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -3, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -4, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -5, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -6, Status: pgtype.Present}, - - &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -7, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -8, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -9, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -1500, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3723409723490243842378942378901237502734019231380123"), Exp: 81, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "723409723490243842378942378901237502734019231380123"), Exp: 82, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "23409723490243842378942378901237502734019231380123"), Exp: 83, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3409723490243842378942378901237502734019231380123"), Exp: 84, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3423409823409243892349028349023482934092340892390101"), Exp: -91, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "423409823409243892349028349023482934092340892390101"), Exp: -92, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "23409823409243892349028349023482934092340892390101"), Exp: -93, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3409823409243892349028349023482934092340892390101"), Exp: -94, Status: pgtype.Present}, - &pgtype.Numeric{Status: pgtype.Null}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Numeric) - b := bb.(pgtype.Numeric) - - return numericEqual(&a, &b) - }) - -} - -func TestNumericTranscodeFuzz(t *testing.T) { - r := rand.New(rand.NewSource(0)) - max := &big.Int{} - max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) - - values := make([]interface{}, 0, 2000) - for i := 0; i < 10; i++ { - for j := -50; j < 50; j++ { - num := (&big.Int{}).Rand(r, max) - negNum := &big.Int{} - negNum.Neg(num) - values = append(values, &pgtype.Numeric{Int: num, Exp: int32(j), Status: pgtype.Present}) - values = append(values, &pgtype.Numeric{Int: negNum, Exp: int32(j), Status: pgtype.Present}) - } - } - - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, - func(aa, bb interface{}) bool { - a := aa.(pgtype.Numeric) - b := bb.(pgtype.Numeric) - - return numericNormalizedEqual(&a, &b) - }) -} - -func TestNumericSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result *pgtype.Numeric - }{ - {source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: float32(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Present}}, - {source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: float64(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Present}}, - {source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: int64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: int8(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, - {source: int16(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, - {source: int32(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, - {source: int64(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, - {source: uint8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: uint16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: uint32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: uint64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: "1", result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: _int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: float64(1000), result: &pgtype.Numeric{Int: big.NewInt(1), Exp: 3, Status: pgtype.Present}}, - {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}}, - {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}}, - {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, - {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, NaN: true}}, - {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, NaN: true}}, - } - - for i, tt := range successfulTests { - r := &pgtype.Numeric{} - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !numericEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestNumericAssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - var f32 float32 - var f64 float64 - var pf32 *float32 - var pf64 *float64 - - simpleTests := []struct { - src *pgtype.Numeric - dst interface{} - expected interface{} - }{ - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: 3, Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Status: pgtype.Present}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27 - {src: &pgtype.Numeric{Status: pgtype.Present, NaN: true}, dst: &f64, expected: math.NaN()}, - {src: &pgtype.Numeric{Status: pgtype.Present, NaN: true}, dst: &f32, expected: float32(math.NaN())}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - dst := reflect.ValueOf(tt.dst).Elem().Interface() - switch dstTyped := dst.(type) { - case float32: - nanExpected := math.IsNaN(float64(tt.expected.(float32))) - if nanExpected && !math.IsNaN(float64(dstTyped)) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } else if !nanExpected && dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - case float64: - nanExpected := math.IsNaN(tt.expected.(float64)) - if nanExpected && !math.IsNaN(dstTyped) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } else if !nanExpected && dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - default: - if dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - } - - pointerAllocTests := []struct { - src *pgtype.Numeric - dst interface{} - expected interface{} - }{ - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src *pgtype.Numeric - dst interface{} - }{ - {src: &pgtype.Numeric{Int: big.NewInt(150), Status: pgtype.Present}, dst: &i8}, - {src: &pgtype.Numeric{Int: big.NewInt(40000), Status: pgtype.Present}, dst: &i16}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui8}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui16}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui32}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui64}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui}, - {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - -func TestNumericEncodeDecodeBinary(t *testing.T) { - ci := pgtype.NewConnInfo() - tests := []interface{}{ - 123, - 0.000012345, - 1.00002345, - math.NaN(), - float32(math.NaN()), - } - - for i, tt := range tests { - toString := func(n *pgtype.Numeric) string { - ci := pgtype.NewConnInfo() - text, err := n.EncodeText(ci, nil) - if err != nil { - t.Errorf("%d (EncodeText): %v", i, err) - } - return string(text) - } - numeric := &pgtype.Numeric{} - numeric.Set(tt) - - encoded, err := numeric.EncodeBinary(ci, nil) - if err != nil { - t.Errorf("%d (EncodeBinary): %v", i, err) - } - decoded := &pgtype.Numeric{} - err = decoded.DecodeBinary(ci, encoded) - if err != nil { - t.Errorf("%d (DecodeBinary): %v", i, err) - } - - text0 := toString(numeric) - text1 := toString(decoded) - - if text0 != text1 { - t.Errorf("%d: expected %v to equal to %v, but doesn't", i, text0, text1) - } - } -} diff --git a/numrange_test.go b/numrange_test.go deleted file mode 100644 index 0bbb26f0..00000000 --- a/numrange_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package pgtype_test - -import ( - "math/big" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestNumrangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "numrange", []interface{}{ - &pgtype.Numrange{ - LowerType: pgtype.Empty, - UpperType: pgtype.Empty, - Status: pgtype.Present, - }, - &pgtype.Numrange{ - Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Status: pgtype.Present}, - Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Numrange{ - Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, - Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Numrange{ - Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Unbounded, - Status: pgtype.Present, - }, - &pgtype.Numrange{ - Upper: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, - LowerType: pgtype.Unbounded, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Numrange{Status: pgtype.Null}, - }) -} diff --git a/oid_value_test.go b/oid_value_test.go deleted file mode 100644 index 69742dd7..00000000 --- a/oid_value_test.go +++ /dev/null @@ -1,95 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestOIDValueTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ - &pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, - &pgtype.OIDValue{Status: pgtype.Null}, - }) -} - -func TestOIDValueSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.OIDValue - }{ - {source: uint32(1), result: pgtype.OIDValue{Uint: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.OIDValue - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestOIDValueAssignTo(t *testing.T) { - var ui32 uint32 - var pui32 *uint32 - - simpleTests := []struct { - src pgtype.OIDValue - dst interface{} - expected interface{} - }{ - {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.OIDValue - dst interface{} - expected interface{} - }{ - {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.OIDValue - dst interface{} - }{ - {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &ui32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/path_test.go b/path_test.go deleted file mode 100644 index 969a89ec..00000000 --- a/path_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestPathTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "path", []interface{}{ - &pgtype.Path{ - P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, - Closed: false, - Status: pgtype.Present, - }, - &pgtype.Path{ - P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, - Closed: true, - Status: pgtype.Present, - }, - &pgtype.Path{ - P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Closed: true, - Status: pgtype.Present, - }, - &pgtype.Path{Status: pgtype.Null}, - }) -} diff --git a/pgtype_test.go b/pgtype_test.go deleted file mode 100644 index 75e1909f..00000000 --- a/pgtype_test.go +++ /dev/null @@ -1,292 +0,0 @@ -package pgtype_test - -import ( - "bytes" - "errors" - "net" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgx/v4" - _ "github.com/jackc/pgx/v4/stdlib" - _ "github.com/lib/pq" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// Test for renamed types -type _string string -type _bool bool -type _int8 int8 -type _int16 int16 -type _int16Slice []int16 -type _int32Slice []int32 -type _int64Slice []int64 -type _float32Slice []float32 -type _float64Slice []float64 -type _byteSlice []byte - -func mustParseCIDR(t testing.TB, s string) *net.IPNet { - _, ipnet, err := net.ParseCIDR(s) - if err != nil { - t.Fatal(err) - } - - return ipnet -} - -func mustParseInet(t testing.TB, s string) *net.IPNet { - ip, ipnet, err := net.ParseCIDR(s) - if err != nil { - t.Fatal(err) - } - if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 - } - - ipnet.IP = ip - - return ipnet -} - -func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { - addr, err := net.ParseMAC(s) - if err != nil { - t.Fatal(err) - } - - return addr -} - -func TestConnInfoResultFormatCodeForOID(t *testing.T) { - ci := pgtype.NewConnInfo() - - // pgtype.JSONB implements BinaryDecoder but also implements ResultFormatPreferrer to override it to text. - assert.Equal(t, int16(pgtype.TextFormatCode), ci.ResultFormatCodeForOID(pgtype.JSONBOID)) - - // pgtype.Int4 implements BinaryDecoder but does not implement ResultFormatPreferrer so it should be binary. - assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ResultFormatCodeForOID(pgtype.Int4OID)) -} - -func TestConnInfoParamFormatCodeForOID(t *testing.T) { - ci := pgtype.NewConnInfo() - - // pgtype.JSONB implements BinaryEncoder but also implements ParamFormatPreferrer to override it to text. - assert.Equal(t, int16(pgtype.TextFormatCode), ci.ParamFormatCodeForOID(pgtype.JSONBOID)) - - // pgtype.Int4 implements BinaryEncoder but does not implement ParamFormatPreferrer so it should be binary. - assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ParamFormatCodeForOID(pgtype.Int4OID)) -} - -func TestConnInfoScanNilIsNoOp(t *testing.T) { - ci := pgtype.NewConnInfo() - - err := ci.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), nil) - assert.NoError(t, err) -} - -func TestConnInfoScanTextFormatInterfacePtr(t *testing.T) { - ci := pgtype.NewConnInfo() - var got interface{} - err := ci.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &got) - require.NoError(t, err) - assert.Equal(t, "foo", got) -} - -func TestConnInfoScanTextFormatNonByteaIntoByteSlice(t *testing.T) { - ci := pgtype.NewConnInfo() - var got []byte - err := ci.Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &got) - require.NoError(t, err) - assert.Equal(t, []byte("{}"), got) -} - -func TestConnInfoScanBinaryFormatInterfacePtr(t *testing.T) { - ci := pgtype.NewConnInfo() - var got interface{} - err := ci.Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &got) - require.NoError(t, err) - assert.Equal(t, "foo", got) -} - -func TestConnInfoScanUnknownOIDToStringsAndBytes(t *testing.T) { - unknownOID := uint32(999999) - srcBuf := []byte("foo") - ci := pgtype.NewConnInfo() - - var s string - err := ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &s) - assert.NoError(t, err) - assert.Equal(t, "foo", s) - - var rs _string - err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rs) - assert.NoError(t, err) - assert.Equal(t, "foo", string(rs)) - - var b []byte - err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &b) - assert.NoError(t, err) - assert.Equal(t, []byte("foo"), b) - - err = ci.Scan(unknownOID, pgx.BinaryFormatCode, srcBuf, &b) - assert.NoError(t, err) - assert.Equal(t, []byte("foo"), b) - - var rb _byteSlice - err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rb) - assert.NoError(t, err) - assert.Equal(t, []byte("foo"), []byte(rb)) - - err = ci.Scan(unknownOID, pgx.BinaryFormatCode, srcBuf, &b) - assert.NoError(t, err) - assert.Equal(t, []byte("foo"), []byte(rb)) -} - -type pgCustomType struct { - a string - b string -} - -func (ct *pgCustomType) DecodeText(ci *pgtype.ConnInfo, buf []byte) error { - // This is not a complete parser for the text format of composite types. This is just for test purposes. - if buf == nil { - return errors.New("cannot parse null") - } - - if len(buf) < 2 { - return errors.New("invalid text format") - } - - parts := bytes.Split(buf[1:len(buf)-1], []byte(",")) - if len(parts) != 2 { - return errors.New("wrong number of parts") - } - - ct.a = string(parts[0]) - ct.b = string(parts[1]) - - return nil -} - -func TestConnInfoScanUnregisteredOIDToCustomType(t *testing.T) { - unregisteredOID := uint32(999999) - ci := pgtype.NewConnInfo() - - var ct pgCustomType - err := ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct) - assert.NoError(t, err) - assert.Equal(t, "foo", ct.a) - assert.Equal(t, "bar", ct.b) - - // Scan value into pointer to custom type - var pCt *pgCustomType - err = ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt) - assert.NoError(t, err) - require.NotNil(t, pCt) - assert.Equal(t, "foo", pCt.a) - assert.Equal(t, "bar", pCt.b) - - // Scan null into pointer to custom type - err = ci.Scan(unregisteredOID, pgx.TextFormatCode, nil, &pCt) - assert.NoError(t, err) - assert.Nil(t, pCt) -} - -func TestConnInfoScanUnknownOIDTextFormat(t *testing.T) { - ci := pgtype.NewConnInfo() - - var n int32 - err := ci.Scan(0, pgx.TextFormatCode, []byte("123"), &n) - assert.NoError(t, err) - assert.EqualValues(t, 123, n) -} - -func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) { - ci := pgtype.NewConnInfo() - src := []byte{0, 0, 0, 42} - var v pgtype.Int4 - - for i := 0; i < b.N; i++ { - v = pgtype.Int4{} - err := ci.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) - if err != nil { - b.Fatal(err) - } - if v != (pgtype.Int4{Int: 42, Status: pgtype.Present}) { - b.Fatal("scan failed due to bad value") - } - } -} - -func TestScanPlanBinaryInt32ScanChangedType(t *testing.T) { - ci := pgtype.NewConnInfo() - src := []byte{0, 0, 0, 42} - var v int32 - - plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) - err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) - require.NoError(t, err) - require.EqualValues(t, 42, v) - - var d pgtype.Int4 - err = plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &d) - require.NoError(t, err) - require.EqualValues(t, 42, d.Int) - require.EqualValues(t, pgtype.Present, d.Status) -} - -func BenchmarkConnInfoScanInt4IntoGoInt32(b *testing.B) { - ci := pgtype.NewConnInfo() - src := []byte{0, 0, 0, 42} - var v int32 - - for i := 0; i < b.N; i++ { - v = 0 - err := ci.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) - if err != nil { - b.Fatal(err) - } - if v != 42 { - b.Fatal("scan failed due to bad value") - } - } -} - -func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) { - ci := pgtype.NewConnInfo() - src := []byte{0, 0, 0, 42} - var v pgtype.Int4 - - plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) - - for i := 0; i < b.N; i++ { - v = pgtype.Int4{} - err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) - if err != nil { - b.Fatal(err) - } - if v != (pgtype.Int4{Int: 42, Status: pgtype.Present}) { - b.Fatal("scan failed due to bad value") - } - } -} - -func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { - ci := pgtype.NewConnInfo() - src := []byte{0, 0, 0, 42} - var v int32 - - plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) - - for i := 0; i < b.N; i++ { - v = 0 - err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) - if err != nil { - b.Fatal(err) - } - if v != 42 { - b.Fatal("scan failed due to bad value") - } - } -} diff --git a/pgxtype/README.md b/pgxtype/README.md deleted file mode 100644 index a070111f..00000000 --- a/pgxtype/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# pgxtype - -pgxtype is a helper module that connects pgx and pgtype. This package is not currently covered by semantic version guarantees. i.e. The interfaces may change without a major version release of pgtype. diff --git a/pgxtype/pgxtype.go b/pgxtype/pgxtype.go deleted file mode 100644 index 041f2545..00000000 --- a/pgxtype/pgxtype.go +++ /dev/null @@ -1,145 +0,0 @@ -package pgxtype - -import ( - "context" - "errors" - - "github.com/jackc/pgconn" - "github.com/jackc/pgtype" - "github.com/jackc/pgx/v4" -) - -type Querier interface { - Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) - Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) - QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row -} - -// LoadDataType uses conn to inspect the database for typeName and produces a pgtype.DataType suitable for -// registration on ci. -func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeName string) (pgtype.DataType, error) { - var oid uint32 - - err := conn.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid) - if err != nil { - return pgtype.DataType{}, err - } - - var typtype string - - err = conn.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype) - if err != nil { - return pgtype.DataType{}, err - } - - switch typtype { - case "b": // array - elementOID, err := GetArrayElementOID(ctx, conn, oid) - if err != nil { - return pgtype.DataType{}, err - } - - var element pgtype.ValueTranscoder - if dt, ok := ci.DataTypeForOID(elementOID); ok { - if element, ok = dt.Value.(pgtype.ValueTranscoder); !ok { - return pgtype.DataType{}, errors.New("array element OID not registered as ValueTranscoder") - } - } - - newElement := func() pgtype.ValueTranscoder { - return pgtype.NewValue(element).(pgtype.ValueTranscoder) - } - - at := pgtype.NewArrayType(typeName, elementOID, newElement) - return pgtype.DataType{Value: at, Name: typeName, OID: oid}, nil - case "c": // composite - fields, err := GetCompositeFields(ctx, conn, oid) - if err != nil { - return pgtype.DataType{}, err - } - ct, err := pgtype.NewCompositeType(typeName, fields, ci) - if err != nil { - return pgtype.DataType{}, err - } - return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil - case "e": // enum - members, err := GetEnumMembers(ctx, conn, oid) - if err != nil { - return pgtype.DataType{}, err - } - return pgtype.DataType{Value: pgtype.NewEnumType(typeName, members), Name: typeName, OID: oid}, nil - default: - return pgtype.DataType{}, errors.New("unknown typtype") - } -} - -func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, error) { - var typelem uint32 - - err := conn.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem) - if err != nil { - return 0, err - } - - return typelem, nil -} - -// GetCompositeFields gets the fields of a composite type. -func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) { - var typrelid uint32 - - err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid) - if err != nil { - return nil, err - } - - var fields []pgtype.CompositeTypeField - - rows, err := conn.Query(ctx, `select attname, atttypid -from pg_attribute -where attrelid=$1 -order by attnum`, typrelid) - if err != nil { - return nil, err - } - - for rows.Next() { - var f pgtype.CompositeTypeField - err := rows.Scan(&f.Name, &f.OID) - if err != nil { - return nil, err - } - fields = append(fields, f) - } - - if rows.Err() != nil { - return nil, rows.Err() - } - - return fields, nil -} - -// GetEnumMembers gets the possible values of the enum by oid. -func GetEnumMembers(ctx context.Context, conn Querier, oid uint32) ([]string, error) { - members := []string{} - - rows, err := conn.Query(ctx, "select enumlabel from pg_enum where enumtypid=$1 order by enumsortorder", oid) - if err != nil { - return nil, err - } - - for rows.Next() { - var m string - err := rows.Scan(&m) - if err != nil { - return nil, err - } - members = append(members, m) - } - - if rows.Err() != nil { - return nil, rows.Err() - } - - return members, nil -} diff --git a/point_test.go b/point_test.go deleted file mode 100644 index 63f8df07..00000000 --- a/point_test.go +++ /dev/null @@ -1,150 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestPointTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "point", []interface{}{ - &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Status: pgtype.Present}, - &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, - &pgtype.Point{Status: pgtype.Null}, - }) -} - -func TestPoint_Set(t *testing.T) { - tests := []struct { - name string - arg interface{} - status pgtype.Status - wantErr bool - }{ - { - name: "first", - arg: "(12312.123123,123123.123123)", - status: pgtype.Present, - wantErr: false, - }, - { - name: "second", - arg: "(1231s2.123123,123123.123123)", - status: pgtype.Undefined, - wantErr: true, - }, - { - name: "third", - arg: []byte("(122.123123,123.123123)"), - status: pgtype.Present, - wantErr: false, - }, - { - name: "third", - arg: nil, - status: pgtype.Null, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - dst := &pgtype.Point{} - if err := dst.Set(tt.arg); (err != nil) != tt.wantErr { - t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) - } - if dst.Status != tt.status { - t.Errorf("Expected status: %v; got: %v", tt.status, dst.Status) - } - }) - } -} - -func TestPoint_MarshalJSON(t *testing.T) { - tests := []struct { - name string - point pgtype.Point - want []byte - wantErr bool - }{ - { - name: "first", - point: pgtype.Point{ - P: pgtype.Vec2{}, - Status: pgtype.Undefined, - }, - want: nil, - wantErr: true, - }, - { - name: "second", - point: pgtype.Point{ - P: pgtype.Vec2{X: 12.245, Y: 432.12}, - Status: pgtype.Present, - }, - want: []byte(`"(12.245,432.12)"`), - wantErr: false, - }, - { - name: "third", - point: pgtype.Point{ - P: pgtype.Vec2{}, - Status: pgtype.Null, - }, - want: []byte("null"), - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.point.MarshalJSON() - if (err != nil) != tt.wantErr { - t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestPoint_UnmarshalJSON(t *testing.T) { - tests := []struct { - name string - status pgtype.Status - arg []byte - wantErr bool - }{ - { - name: "first", - status: pgtype.Present, - arg: []byte(`"(123.123,54.12)"`), - wantErr: false, - }, - { - name: "second", - status: pgtype.Undefined, - arg: []byte(`"(123.123,54.1sad2)"`), - wantErr: true, - }, - { - name: "third", - status: pgtype.Null, - arg: []byte("null"), - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - dst := &pgtype.Point{} - if err := dst.UnmarshalJSON(tt.arg); (err != nil) != tt.wantErr { - t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) - } - if dst.Status != tt.status { - t.Errorf("Status mismatch: %v != %v", dst.Status, tt.status) - } - }) - } -} diff --git a/polygon_test.go b/polygon_test.go deleted file mode 100644 index 1a139444..00000000 --- a/polygon_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestPolygonTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "polygon", []interface{}{ - &pgtype.Polygon{ - P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, - Status: pgtype.Present, - }, - &pgtype.Polygon{ - P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, - Status: pgtype.Present, - }, - &pgtype.Polygon{Status: pgtype.Null}, - }) -} - -func TestPolygon_Set(t *testing.T) { - tests := []struct { - name string - arg interface{} - status pgtype.Status - wantErr bool - }{ - { - name: "string", - arg: "((3.14,1.678901234),(7.1,5.234),(5.0,3.234))", - status: pgtype.Present, - wantErr: false, - }, { - name: "[]float64", - arg: []float64{1, 2, 3.45, 6.78, 1.23, 4.567, 8.9, 1.0}, - status: pgtype.Present, - wantErr: false, - }, { - name: "[]Vec2", - arg: []pgtype.Vec2{{1, 2}, {2.3, 4.5}, {6.78, 9.123}}, - status: pgtype.Present, - wantErr: false, - }, { - name: "null", - arg: nil, - status: pgtype.Null, - wantErr: false, - }, { - name: "invalid_string_1", - arg: "((3.14,1.678901234),(7.1,5.234),(5.0,3.234x))", - status: pgtype.Undefined, - wantErr: true, - }, { - name: "invalid_string_2", - arg: "(3,4)", - status: pgtype.Undefined, - wantErr: true, - }, { - name: "invalid_[]float64", - arg: []float64{1, 2, 3.45, 6.78, 1.23, 4.567, 8.9}, - status: pgtype.Undefined, - wantErr: true, - }, { - name: "invalid_type", - arg: []int{1, 2, 3, 6}, - status: pgtype.Undefined, - wantErr: true, - }, { - name: "empty_[]float64", - arg: []float64{}, - status: pgtype.Null, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - dst := &pgtype.Polygon{} - if err := dst.Set(tt.arg); (err != nil) != tt.wantErr { - t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) - } - if dst.Status != tt.status { - t.Errorf("Expected status: %v; got: %v", tt.status, dst.Status) - } - }) - } -} diff --git a/qchar_test.go b/qchar_test.go deleted file mode 100644 index 4b60339c..00000000 --- a/qchar_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package pgtype_test - -import ( - "math" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestQCharTranscode(t *testing.T) { - testutil.TestPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ - &pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, - &pgtype.QChar{Int: -1, Status: pgtype.Present}, - &pgtype.QChar{Int: 0, Status: pgtype.Present}, - &pgtype.QChar{Int: 1, Status: pgtype.Present}, - &pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, - &pgtype.QChar{Int: 0, Status: pgtype.Null}, - }, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func TestQCharSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.QChar - }{ - {source: int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.QChar - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestQCharAssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - - simpleTests := []struct { - src pgtype.QChar - dst interface{} - expected interface{} - }{ - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.QChar - dst interface{} - expected interface{} - }{ - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.QChar - dst interface{} - }{ - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &i16}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/range_test.go b/range_test.go deleted file mode 100644 index 9e16df59..00000000 --- a/range_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package pgtype - -import ( - "bytes" - "testing" -) - -func TestParseUntypedTextRange(t *testing.T) { - tests := []struct { - src string - result UntypedTextRange - err error - }{ - { - src: `[1,2)`, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `[1,2]`, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, - err: nil, - }, - { - src: `(1,3)`, - result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: ` [1,2) `, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `[ foo , bar )`, - result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `["foo","bar")`, - result: UntypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `["f""oo","b""ar")`, - result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `["f""oo","b""ar")`, - result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `["","bar")`, - result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `[f\"oo\,,b\\ar\))`, - result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `empty`, - result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, - err: nil, - }, - } - - for i, tt := range tests { - r, err := ParseUntypedTextRange(tt.src) - if err != tt.err { - t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) - continue - } - - if r.LowerType != tt.result.LowerType { - t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) - } - - if r.UpperType != tt.result.UpperType { - t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) - } - - if r.Lower != tt.result.Lower { - t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) - } - - if r.Upper != tt.result.Upper { - t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) - } - } -} - -func TestParseUntypedBinaryRange(t *testing.T) { - tests := []struct { - src []byte - result UntypedBinaryRange - err error - }{ - { - src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: []byte{1}, - result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, - err: nil, - }, - { - src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, - err: nil, - }, - { - src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, - err: nil, - }, - { - src: []byte{8, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, - err: nil, - }, - { - src: []byte{12, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, - err: nil, - }, - { - src: []byte{16, 0, 0, 0, 2, 0, 4}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, - err: nil, - }, - { - src: []byte{18, 0, 0, 0, 2, 0, 4}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, - err: nil, - }, - { - src: []byte{24}, - result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, - err: nil, - }, - } - - for i, tt := range tests { - r, err := ParseUntypedBinaryRange(tt.src) - if err != tt.err { - t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) - continue - } - - if r.LowerType != tt.result.LowerType { - t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) - } - - if r.UpperType != tt.result.UpperType { - t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) - } - - if bytes.Compare(r.Lower, tt.result.Lower) != 0 { - t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) - } - - if bytes.Compare(r.Upper, tt.result.Upper) != 0 { - t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) - } - } -} diff --git a/record_test.go b/record_test.go deleted file mode 100644 index 240812a6..00000000 --- a/record_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package pgtype_test - -import ( - "context" - "fmt" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgx/v4" -) - -var recordTests = []struct { - sql string - expected pgtype.Record -}{ - { - sql: `select row()`, - expected: pgtype.Record{ - Fields: []pgtype.Value{}, - Status: pgtype.Present, - }, - }, - { - sql: `select row('foo'::text, 42::int4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select row(100.0::float4, 1.09::float4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Float4{Float: 100, Status: pgtype.Present}, - &pgtype.Float4{Float: 1.09, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select row(null)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Unknown{Status: pgtype.Null}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select null::record`, - expected: pgtype.Record{ - Status: pgtype.Null, - }, - }, -} - -func TestRecordTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - for i, tt := range recordTests { - psName := fmt.Sprintf("test%d", i) - _, err := conn.Prepare(context.Background(), psName, tt.sql) - if err != nil { - t.Fatal(err) - } - - t.Run(tt.sql, func(t *testing.T) { - var result pgtype.Record - if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { - t.Errorf("%v", err) - return - } - - if !reflect.DeepEqual(tt.expected, result) { - t.Errorf("expected %#v, got %#v", tt.expected, result) - } - }) - - } -} - -func TestRecordWithUnknownOID(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - _, err := conn.Exec(context.Background(), `drop type if exists floatrange; - -create type floatrange as range ( - subtype = float8, - subtype_diff = float8mi -);`) - if err != nil { - t.Fatal(err) - } - defer conn.Exec(context.Background(), "drop type floatrange") - - var result pgtype.Record - err = conn.QueryRow(context.Background(), "select row('foo'::text, floatrange(1, 10), 'bar'::text)").Scan(&result) - if err == nil { - t.Errorf("expected error but none") - } -} - -func TestRecordAssignTo(t *testing.T) { - var valueSlice []pgtype.Value - var interfaceSlice []interface{} - - simpleTests := []struct { - src pgtype.Record - dst interface{} - expected interface{} - }{ - { - src: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - dst: &valueSlice, - expected: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - }, - { - src: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - dst: &interfaceSlice, - expected: []interface{}{"foo", int32(42)}, - }, - { - src: pgtype.Record{Status: pgtype.Null}, - dst: &valueSlice, - expected: (([]pgtype.Value)(nil)), - }, - { - src: pgtype.Record{Status: pgtype.Null}, - dst: &interfaceSlice, - expected: (([]interface{})(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/testutil/testutil.go b/testutil/testutil.go deleted file mode 100644 index e7b64b58..00000000 --- a/testutil/testutil.go +++ /dev/null @@ -1,436 +0,0 @@ -package testutil - -import ( - "context" - "database/sql" - "fmt" - "os" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgx/v4" - _ "github.com/jackc/pgx/v4/stdlib" - _ "github.com/lib/pq" -) - -func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { - var sqlDriverName string - switch driverName { - case "github.com/lib/pq": - sqlDriverName = "postgres" - case "github.com/jackc/pgx/stdlib": - sqlDriverName = "pgx" - default: - t.Fatalf("Unknown driver %v", driverName) - } - - db, err := sql.Open(sqlDriverName, os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - t.Fatal(err) - } - - return db -} - -func MustConnectPgx(t testing.TB) *pgx.Conn { - conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - t.Fatal(err) - } - - return conn -} - -func MustClose(t testing.TB, conn interface { - Close() error -}) { - err := conn.Close() - if err != nil { - t.Fatal(err) - } -} - -func MustCloseContext(t testing.TB, conn interface { - Close(context.Context) error -}) { - err := conn.Close(context.Background()) - if err != nil { - t.Fatal(err) - } -} - -type forceTextEncoder struct { - e pgtype.TextEncoder -} - -func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - return f.e.EncodeText(ci, buf) -} - -type forceBinaryEncoder struct { - e pgtype.BinaryEncoder -} - -func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - return f.e.EncodeBinary(ci, buf) -} - -func ForceEncoder(e interface{}, formatCode int16) interface{} { - switch formatCode { - case pgx.TextFormatCode: - if e, ok := e.(pgtype.TextEncoder); ok { - return forceTextEncoder{e: e} - } - case pgx.BinaryFormatCode: - if e, ok := e.(pgtype.BinaryEncoder); ok { - return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} - } - } - return nil -} - -func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { - TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) - } -} - -func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := MustConnectPgx(t) - defer MustCloseContext(t, conn) - - _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for i, v := range values { - for _, paramFormat := range formats { - for _, resultFormat := range formats { - vEncoder := ForceEncoder(v, paramFormat.formatCode) - if vEncoder == nil { - t.Logf("Skipping Param %s Result %s: %#v does not implement %v for encoding", paramFormat.name, resultFormat.name, v, paramFormat.name) - continue - } - switch resultFormat.formatCode { - case pgx.TextFormatCode: - if _, ok := v.(pgtype.TextEncoder); !ok { - t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name) - continue - } - case pgx.BinaryFormatCode: - if _, ok := v.(pgtype.BinaryEncoder); !ok { - t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name) - continue - } - } - - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - - err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{resultFormat.formatCode}, vEncoder).Scan(result.Interface()) - if err != nil { - t.Errorf("Param %s Result %s %d: %v", paramFormat.name, resultFormat.name, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("Param %s Result %s %d: expected %v, got %v", paramFormat.name, resultFormat.name, i, derefV, result.Elem().Interface()) - } - } - } - } -} - -func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := MustConnectDatabaseSQL(t, driverName) - defer MustClose(t, conn) - - ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - for i, v := range values { - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err := ps.QueryRow(v).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", driverName, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) - } - } -} - -type NormalizeTest struct { - SQL string - Value interface{} -} - -func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) { - TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { - TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc) - } -} - -func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { - conn := MustConnectPgx(t) - defer MustCloseContext(t, conn) - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for i, tt := range tests { - for _, fc := range formats { - psName := fmt.Sprintf("test%d", i) - _, err := conn.Prepare(context.Background(), psName, tt.SQL) - if err != nil { - t.Fatal(err) - } - - queryResultFormats := pgx.QueryResultFormats{fc.formatCode} - if ForceEncoder(tt.Value, fc.formatCode) == nil { - t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name) - continue - } - // Derefence value if it is a pointer - derefV := tt.Value - refVal := reflect.ValueOf(tt.Value) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err = conn.QueryRow(context.Background(), psName, queryResultFormats).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", fc.name, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) - } - } - } -} - -func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { - conn := MustConnectDatabaseSQL(t, driverName) - defer MustClose(t, conn) - - for i, tt := range tests { - ps, err := conn.Prepare(tt.SQL) - if err != nil { - t.Errorf("%d. %v", i, err) - continue - } - - // Derefence value if it is a pointer - derefV := tt.Value - refVal := reflect.ValueOf(tt.Value) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err = ps.QueryRow().Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", driverName, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) - } - } -} - -func TestGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { - TestPgxGoZeroToNullConversion(t, pgTypeName, zero) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - TestDatabaseSQLGoZeroToNullConversion(t, driverName, pgTypeName, zero) - } -} - -func TestNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) { - TestPgxNullToGoZeroConversion(t, pgTypeName, zero) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - TestDatabaseSQLNullToGoZeroConversion(t, driverName, pgTypeName, zero) - } -} - -func TestPgxGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { - conn := MustConnectPgx(t) - defer MustCloseContext(t, conn) - - _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s is null", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for _, paramFormat := range formats { - vEncoder := ForceEncoder(zero, paramFormat.formatCode) - if vEncoder == nil { - t.Logf("Skipping Param %s: %#v does not implement %v for encoding", paramFormat.name, zero, paramFormat.name) - continue - } - - var result bool - err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(&result) - if err != nil { - t.Errorf("Param %s: %v", paramFormat.name, err) - } - - if !result { - t.Errorf("Param %s: did not convert zero to null", paramFormat.name) - } - } -} - -func TestPgxNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) { - conn := MustConnectPgx(t) - defer MustCloseContext(t, conn) - - _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select null::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for _, resultFormat := range formats { - - switch resultFormat.formatCode { - case pgx.TextFormatCode: - if _, ok := zero.(pgtype.TextEncoder); !ok { - t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name) - continue - } - case pgx.BinaryFormatCode: - if _, ok := zero.(pgtype.BinaryEncoder); !ok { - t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name) - continue - } - } - - // Derefence value if it is a pointer - derefZero := zero - refVal := reflect.ValueOf(zero) - if refVal.Kind() == reflect.Ptr { - derefZero = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefZero)) - - err := conn.QueryRow(context.Background(), "test").Scan(result.Interface()) - if err != nil { - t.Errorf("Result %s: %v", resultFormat.name, err) - } - - if !reflect.DeepEqual(result.Elem().Interface(), derefZero) { - t.Errorf("Result %s: did not convert null to zero", resultFormat.name) - } - } -} - -func TestDatabaseSQLGoZeroToNullConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) { - conn := MustConnectDatabaseSQL(t, driverName) - defer MustClose(t, conn) - - ps, err := conn.Prepare(fmt.Sprintf("select $1::%s is null", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - var result bool - err = ps.QueryRow(zero).Scan(&result) - if err != nil { - t.Errorf("%v %v", driverName, err) - } - - if !result { - t.Errorf("%v: did not convert zero to null", driverName) - } -} - -func TestDatabaseSQLNullToGoZeroConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) { - conn := MustConnectDatabaseSQL(t, driverName) - defer MustClose(t, conn) - - ps, err := conn.Prepare(fmt.Sprintf("select null::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - // Derefence value if it is a pointer - derefZero := zero - refVal := reflect.ValueOf(zero) - if refVal.Kind() == reflect.Ptr { - derefZero = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefZero)) - - err = ps.QueryRow().Scan(result.Interface()) - if err != nil { - t.Errorf("%v %v", driverName, err) - } - - if !reflect.DeepEqual(result.Elem().Interface(), derefZero) { - t.Errorf("%s: did not convert null to zero", driverName) - } -} diff --git a/text_array_test.go b/text_array_test.go deleted file mode 100644 index a5d050f6..00000000 --- a/text_array_test.go +++ /dev/null @@ -1,294 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// https://github.com/jackc/pgtype/issues/78 -func TestTextArrayDecodeTextNull(t *testing.T) { - textArray := &pgtype.TextArray{} - err := textArray.DecodeText(nil, []byte(`{abc,"NULL",NULL,def}`)) - require.NoError(t, err) - require.Len(t, textArray.Elements, 4) - assert.Equal(t, pgtype.Present, textArray.Elements[1].Status) - assert.Equal(t, pgtype.Null, textArray.Elements[2].Status) -} - -func TestTextArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "text[]", []interface{}{ - &pgtype.TextArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "foo", Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.TextArray{Status: pgtype.Null}, - &pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "bar ", Status: pgtype.Present}, - {String: "NuLL", Status: pgtype.Present}, - {String: `wow"quz\`, Status: pgtype.Present}, - {String: "", Status: pgtype.Present}, - {Status: pgtype.Null}, - {String: "null", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "quz", Status: pgtype.Present}, - {String: "foo", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestTextArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.TextArray - }{ - { - source: []string{"foo"}, - result: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]string)(nil)), - result: pgtype.TextArray{Status: pgtype.Null}, - }, - { - source: [][]string{{"foo"}, {"bar"}}, - result: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - result: pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]string{{"foo"}, {"bar"}}, - result: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - result: pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.TextArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTextArrayAssignTo(t *testing.T) { - var stringSlice []string - type _stringSlice []string - var namedStringSlice _stringSlice - var stringSliceDim2 [][]string - var stringSliceDim4 [][][][]string - var stringArrayDim2 [2][1]string - var stringArrayDim4 [2][1][1][3]string - - simpleTests := []struct { - src pgtype.TextArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - expected: []string{"foo"}, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedStringSlice, - expected: _stringSlice{"bar"}, - }, - { - src: pgtype.TextArray{Status: pgtype.Null}, - dst: &stringSlice, - expected: (([]string)(nil)), - }, - { - src: pgtype.TextArray{Status: pgtype.Present}, - dst: &stringSlice, - expected: []string{}, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &stringSliceDim2, - expected: [][]string{{"foo"}, {"bar"}}, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &stringSliceDim4, - expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &stringArrayDim2, - expected: [2][1]string{{"foo"}, {"bar"}}, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &stringArrayDim4, - expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.TextArray - dst interface{} - }{ - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &stringArrayDim2, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &stringSlice, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &stringArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/text_test.go b/text_test.go deleted file mode 100644 index cca3a05d..00000000 --- a/text_test.go +++ /dev/null @@ -1,164 +0,0 @@ -package pgtype_test - -import ( - "bytes" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestTextTranscode(t *testing.T) { - for _, pgTypeName := range []string{"text", "varchar"} { - testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ - &pgtype.Text{String: "", Status: pgtype.Present}, - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Text{Status: pgtype.Null}, - }) - } -} - -func TestTextSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Text - }{ - {source: "foo", result: pgtype.Text{String: "foo", Status: pgtype.Present}}, - {source: _string("bar"), result: pgtype.Text{String: "bar", Status: pgtype.Present}}, - {source: (*string)(nil), result: pgtype.Text{Status: pgtype.Null}}, - } - - for i, tt := range successfulTests { - var d pgtype.Text - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if d != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } -} - -func TestTextAssignTo(t *testing.T) { - var s string - var ps *string - - stringTests := []struct { - src pgtype.Text - dst interface{} - expected interface{} - }{ - {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, - {src: pgtype.Text{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range stringTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - var buf []byte - - bytesTests := []struct { - src pgtype.Text - dst *[]byte - expected []byte - }{ - {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &buf, expected: []byte("foo")}, - {src: pgtype.Text{Status: pgtype.Null}, dst: &buf, expected: nil}, - } - - for i, tt := range bytesTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if bytes.Compare(*tt.dst, tt.expected) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Text - dst interface{} - expected interface{} - }{ - {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Text - dst interface{} - }{ - {src: pgtype.Text{Status: pgtype.Null}, dst: &s}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - -func TestTextMarshalJSON(t *testing.T) { - successfulTests := []struct { - source pgtype.Text - result string - }{ - {source: pgtype.Text{String: "", Status: pgtype.Null}, result: "null"}, - {source: pgtype.Text{String: "a", Status: pgtype.Present}, result: "\"a\""}, - } - for i, tt := range successfulTests { - r, err := tt.source.MarshalJSON() - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r) != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) - } - } -} - -func TestTextUnmarshalJSON(t *testing.T) { - successfulTests := []struct { - source string - result pgtype.Text - }{ - {source: "null", result: pgtype.Text{String: "", Status: pgtype.Null}}, - {source: "\"a\"", result: pgtype.Text{String: "a", Status: pgtype.Present}}, - } - for i, tt := range successfulTests { - var r pgtype.Text - err := r.UnmarshalJSON([]byte(tt.source)) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} diff --git a/tid_test.go b/tid_test.go deleted file mode 100644 index 818be8af..00000000 --- a/tid_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestTIDTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "tid", []interface{}{ - &pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, - &pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, - &pgtype.TID{Status: pgtype.Null}, - }) -} - -func TestTIDAssignTo(t *testing.T) { - var s string - var sp *string - - simpleTests := []struct { - src pgtype.TID - dst interface{} - expected interface{} - }{ - {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, dst: &s, expected: "(42,43)"}, - {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, dst: &s, expected: "(4294967295,65535)"}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.TID - dst interface{} - expected interface{} - }{ - {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, dst: &sp, expected: "(42,43)"}, - {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, dst: &sp, expected: "(4294967295,65535)"}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} - diff --git a/time_test.go b/time_test.go deleted file mode 100644 index 0af42b1e..00000000 --- a/time_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestTimeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "time", []interface{}{ - &pgtype.Time{Microseconds: 0, Status: pgtype.Present}, - &pgtype.Time{Microseconds: 1, Status: pgtype.Present}, - &pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}, - &pgtype.Time{Status: pgtype.Null}, - }) -} - -// Test for transcoding 24:00:00 separately as github.com/lib/pq doesn't seem to support it. -func TestTimeTranscode24HH(t *testing.T) { - pgTypeName := "time" - values := []interface{}{ - &pgtype.Time{Microseconds: 86400000000, Status: pgtype.Present}, - } - - eqFunc := func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - } - - testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) -} - -func TestTimeSet(t *testing.T) { - type _time time.Time - - successfulTests := []struct { - source interface{} - result pgtype.Time - }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 0, Status: pgtype.Present}}, - {source: time.Date(1900, 1, 1, 1, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 3600000000, Status: pgtype.Present}}, - {source: time.Date(1900, 1, 1, 0, 1, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 60000000, Status: pgtype.Present}}, - {source: time.Date(1900, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Time{Microseconds: 1000000, Status: pgtype.Present}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 1, time.UTC), result: pgtype.Time{Microseconds: 0, Status: pgtype.Present}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 1000, time.UTC), result: pgtype.Time{Microseconds: 1, Status: pgtype.Present}}, - {source: time.Date(1999, 12, 31, 23, 59, 59, 999999999, time.UTC), result: pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}}, - {source: time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local), result: pgtype.Time{Microseconds: 2, Status: pgtype.Present}}, - {source: func(t time.Time) *time.Time { return &t }(time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local)), result: pgtype.Time{Microseconds: 2, Status: pgtype.Present}}, - {source: nil, result: pgtype.Time{Status: pgtype.Null}}, - {source: (*time.Time)(nil), result: pgtype.Time{Status: pgtype.Null}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 3000, time.UTC)), result: pgtype.Time{Microseconds: 3, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.Time - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTimeAssignTo(t *testing.T) { - var tim time.Time - var ptim *time.Time - - simpleTests := []struct { - src pgtype.Time - dst interface{} - expected interface{} - }{ - {src: pgtype.Time{Microseconds: 0, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 3600000000, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 60000000, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 1, 0, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 1000000, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 1, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 1000, time.UTC)}, - {src: pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 23, 59, 59, 999999000, time.UTC)}, - {src: pgtype.Time{Microseconds: 0, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Time - dst interface{} - expected interface{} - }{ - {src: pgtype.Time{Microseconds: 0, Status: pgtype.Present}, dst: &ptim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Time - dst interface{} - }{ - {src: pgtype.Time{Microseconds: 86400000000, Status: pgtype.Present}, dst: &tim}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/timestamp_array_test.go b/timestamp_array_test.go deleted file mode 100644 index 54d15b24..00000000 --- a/timestamp_array_test.go +++ /dev/null @@ -1,307 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestTimestampArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp[]", []interface{}{ - &pgtype.TimestampArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.TimestampArray{Status: pgtype.Null}, - &pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }, func(a, b interface{}) bool { - ata := a.(pgtype.TimestampArray) - bta := b.(pgtype.TimestampArray) - - if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { - return false - } - - for i := range ata.Elements { - ae, be := ata.Elements[i], bta.Elements[i] - if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { - return false - } - } - - return true - }) -} - -func TestTimestampArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.TimestampArray - }{ - { - source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - result: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]time.Time)(nil)), - result: pgtype.TimestampArray{Status: pgtype.Null}, - }, - { - source: [][]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - result: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - result: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.TimestampArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTimestampArrayAssignTo(t *testing.T) { - var timeSlice []time.Time - var timeSliceDim2 [][]time.Time - var timeSliceDim4 [][][][]time.Time - var timeArrayDim2 [2][1]time.Time - var timeArrayDim4 [2][1][1][3]time.Time - - simpleTests := []struct { - src pgtype.TimestampArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &timeSlice, - expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - }, - { - src: pgtype.TimestampArray{Status: pgtype.Null}, - dst: &timeSlice, - expected: (([]time.Time)(nil)), - }, - { - src: pgtype.TimestampArray{Status: pgtype.Present}, - dst: &timeSlice, - expected: []time.Time{}, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &timeSliceDim2, - expected: [][]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &timeSliceDim4, - expected: [][][][]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &timeArrayDim2, - expected: [2][1]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &timeArrayDim4, - expected: [2][1][1][3]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.TimestampArray - dst interface{} - }{ - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &timeSlice, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &timeArrayDim2, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &timeSlice, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &timeArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/timestamp_test.go b/timestamp_test.go deleted file mode 100644 index 74cb1221..00000000 --- a/timestamp_test.go +++ /dev/null @@ -1,178 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" - "github.com/stretchr/testify/require" -) - -func TestTimestampTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ - &pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Status: pgtype.Null}, - &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, - }, func(a, b interface{}) bool { - at := a.(pgtype.Timestamp) - bt := b.(pgtype.Timestamp) - - return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier - }) -} - -func TestTimestampNanosecondsTruncated(t *testing.T) { - tests := []struct { - input time.Time - expected time.Time - }{ - {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.UTC), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC)}, - {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.UTC), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC)}, - } - for i, tt := range tests { - { - ts := pgtype.Timestamp{Time: tt.input, Status: pgtype.Present} - buf, err := ts.EncodeText(nil, nil) - if err != nil { - t.Errorf("%d. EncodeText failed - %v", i, err) - } - - ts.DecodeText(nil, buf) - if err != nil { - t.Errorf("%d. DecodeText failed - %v", i, err) - } - - if !(ts.Status == pgtype.Present && ts.Time.Equal(tt.expected)) { - t.Errorf("%d. EncodeText did not truncate nanoseconds", i) - } - } - - { - ts := pgtype.Timestamp{Time: tt.input, Status: pgtype.Present} - buf, err := ts.EncodeBinary(nil, nil) - if err != nil { - t.Errorf("%d. EncodeBinary failed - %v", i, err) - } - - ts.DecodeBinary(nil, buf) - if err != nil { - t.Errorf("%d. DecodeBinary failed - %v", i, err) - } - - if !(ts.Status == pgtype.Present && ts.Time.Equal(tt.expected)) { - t.Errorf("%d. EncodeBinary did not truncate nanoseconds", i) - } - } - } -} - -// https://github.com/jackc/pgtype/issues/74 -func TestTimestampDecodeTextInvalid(t *testing.T) { - tstz := &pgtype.Timestamp{} - err := tstz.DecodeText(nil, []byte(`eeeee`)) - require.Error(t, err) -} - -func TestTimestampSet(t *testing.T) { - type _time time.Time - - successfulTests := []struct { - source interface{} - result pgtype.Timestamp - }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: pgtype.Infinity, result: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, - {source: pgtype.NegativeInfinity, result: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.Timestamp - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTimestampAssignTo(t *testing.T) { - var tim time.Time - var ptim *time.Time - - simpleTests := []struct { - src pgtype.Timestamp - dst interface{} - expected interface{} - }{ - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, - {src: pgtype.Timestamp{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Timestamp - dst interface{} - expected interface{} - }{ - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Timestamp - dst interface{} - }{ - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/timestamptz_array_test.go b/timestamptz_array_test.go deleted file mode 100644 index 9856e4e7..00000000 --- a/timestamptz_array_test.go +++ /dev/null @@ -1,343 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestTimestamptzArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz[]", []interface{}{ - &pgtype.TimestamptzArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.TimestamptzArray{Status: pgtype.Null}, - &pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }, func(a, b interface{}) bool { - ata := a.(pgtype.TimestamptzArray) - bta := b.(pgtype.TimestamptzArray) - - if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { - return false - } - - for i := range ata.Elements { - ae, be := ata.Elements[i], bta.Elements[i] - if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { - return false - } - } - - return true - }) -} - -func TestTimestamptzArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.TimestamptzArray - }{ - { - source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - result: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]time.Time)(nil)), - result: pgtype.TimestamptzArray{Status: pgtype.Null}, - }, - { - source: [][]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - result: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - result: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - result: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - result: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.TimestamptzArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTimestamptzArrayAssignTo(t *testing.T) { - var timeSlice []time.Time - var timeSliceDim2 [][]time.Time - var timeSliceDim4 [][][][]time.Time - var timeArrayDim2 [2][1]time.Time - var timeArrayDim4 [2][1][1][3]time.Time - - simpleTests := []struct { - src pgtype.TimestamptzArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &timeSlice, - expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - }, - { - src: pgtype.TimestamptzArray{Status: pgtype.Null}, - dst: &timeSlice, - expected: (([]time.Time)(nil)), - }, - { - src: pgtype.TimestamptzArray{Status: pgtype.Present}, - dst: &timeSlice, - expected: []time.Time{}, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &timeSliceDim2, - expected: [][]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &timeSliceDim4, - expected: [][][][]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &timeArrayDim2, - expected: [2][1]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &timeArrayDim4, - expected: [2][1][1][3]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.TimestamptzArray - dst interface{} - }{ - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &timeSlice, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &timeArrayDim2, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &timeSlice, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &timeArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/timestamptz_test.go b/timestamptz_test.go deleted file mode 100644 index 769c9239..00000000 --- a/timestamptz_test.go +++ /dev/null @@ -1,224 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" - "github.com/stretchr/testify/require" -) - -func TestTimestamptzTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ - &pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Status: pgtype.Null}, - &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, - }, func(a, b interface{}) bool { - at := a.(pgtype.Timestamptz) - bt := b.(pgtype.Timestamptz) - - return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier - }) -} - -func TestTimestamptzNanosecondsTruncated(t *testing.T) { - tests := []struct { - input time.Time - expected time.Time - }{ - {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.Local), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local)}, - {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.Local), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local)}, - } - for i, tt := range tests { - { - tstz := pgtype.Timestamptz{Time: tt.input, Status: pgtype.Present} - buf, err := tstz.EncodeText(nil, nil) - if err != nil { - t.Errorf("%d. EncodeText failed - %v", i, err) - } - - tstz.DecodeText(nil, buf) - if err != nil { - t.Errorf("%d. DecodeText failed - %v", i, err) - } - - if !(tstz.Status == pgtype.Present && tstz.Time.Equal(tt.expected)) { - t.Errorf("%d. EncodeText did not truncate nanoseconds", i) - } - } - - { - tstz := pgtype.Timestamptz{Time: tt.input, Status: pgtype.Present} - buf, err := tstz.EncodeBinary(nil, nil) - if err != nil { - t.Errorf("%d. EncodeBinary failed - %v", i, err) - } - - tstz.DecodeBinary(nil, buf) - if err != nil { - t.Errorf("%d. DecodeBinary failed - %v", i, err) - } - - if !(tstz.Status == pgtype.Present && tstz.Time.Equal(tt.expected)) { - t.Errorf("%d. EncodeBinary did not truncate nanoseconds", i) - } - } - } -} - -// https://github.com/jackc/pgtype/issues/74 -func TestTimestamptzDecodeTextInvalid(t *testing.T) { - tstz := &pgtype.Timestamptz{} - err := tstz.DecodeText(nil, []byte(`eeeee`)) - require.Error(t, err) -} - -func TestTimestamptzSet(t *testing.T) { - type _time time.Time - - successfulTests := []struct { - source interface{} - result pgtype.Timestamptz - }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: pgtype.Infinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, - {source: pgtype.NegativeInfinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.Timestamptz - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTimestamptzAssignTo(t *testing.T) { - var tim time.Time - var ptim *time.Time - - simpleTests := []struct { - src pgtype.Timestamptz - dst interface{} - expected interface{} - }{ - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - {src: pgtype.Timestamptz{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Timestamptz - dst interface{} - expected interface{} - }{ - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Timestamptz - dst interface{} - }{ - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - -func TestTimestamptzMarshalJSON(t *testing.T) { - successfulTests := []struct { - source pgtype.Timestamptz - result string - }{ - {source: pgtype.Timestamptz{Status: pgtype.Null}, result: "null"}, - {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29T10:05:45-06:00\""}, - {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29T10:05:45.555-06:00\""}, - {source: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, result: "\"infinity\""}, - {source: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, result: "\"-infinity\""}, - } - for i, tt := range successfulTests { - r, err := tt.source.MarshalJSON() - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r) != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) - } - } -} - -func TestTimestamptzUnmarshalJSON(t *testing.T) { - successfulTests := []struct { - source string - result pgtype.Timestamptz - }{ - {source: "null", result: pgtype.Timestamptz{Status: pgtype.Null}}, - {source: "\"2012-03-29T10:05:45-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, - {source: "\"2012-03-29T10:05:45.555-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, - {source: "\"infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, - {source: "\"-infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, - } - for i, tt := range successfulTests { - var r pgtype.Timestamptz - err := r.UnmarshalJSON([]byte(tt.source)) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !r.Time.Equal(tt.result.Time) || r.Status != tt.result.Status || r.InfinityModifier != tt.result.InfinityModifier { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} diff --git a/tsrange_test.go b/tsrange_test.go deleted file mode 100644 index 1be0c7d2..00000000 --- a/tsrange_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package pgtype_test - -import ( - "testing" - "time" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestTsrangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ - &pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - &pgtype.Tsrange{ - Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Timestamp{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Tsrange{ - Lower: pgtype.Timestamp{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Tsrange{Status: pgtype.Null}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Tsrange) - b := bb.(pgtype.Tsrange) - - return a.Status == b.Status && - a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Status == b.Lower.Status && - a.Lower.InfinityModifier == b.Lower.InfinityModifier && - a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Status == b.Upper.Status && - a.Upper.InfinityModifier == b.Upper.InfinityModifier - }) -} diff --git a/tstzrange_test.go b/tstzrange_test.go deleted file mode 100644 index f8e2c2c5..00000000 --- a/tstzrange_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package pgtype_test - -import ( - "testing" - "time" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" - "github.com/stretchr/testify/require" -) - -func TestTstzrangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ - &pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - &pgtype.Tstzrange{ - Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Timestamptz{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Tstzrange{ - Lower: pgtype.Timestamptz{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Tstzrange{Status: pgtype.Null}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Tstzrange) - b := bb.(pgtype.Tstzrange) - - return a.Status == b.Status && - a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Status == b.Lower.Status && - a.Lower.InfinityModifier == b.Lower.InfinityModifier && - a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Status == b.Upper.Status && - a.Upper.InfinityModifier == b.Upper.InfinityModifier - }) -} - -// https://github.com/jackc/pgtype/issues/74 -func TestTstzRangeDecodeTextInvalid(t *testing.T) { - tstzrange := &pgtype.Tstzrange{} - err := tstzrange.DecodeText(nil, []byte(`[eeee,)`)) - require.Error(t, err) -} diff --git a/uuid_array_test.go b/uuid_array_test.go deleted file mode 100644 index 7d822e7a..00000000 --- a/uuid_array_test.go +++ /dev/null @@ -1,368 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestUUIDArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "uuid[]", []interface{}{ - &pgtype.UUIDArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.UUIDArray{Status: pgtype.Null}, - &pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestUUIDArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.UUIDArray - }{ - { - source: nil, - result: pgtype.UUIDArray{Status: pgtype.Null}, - }, - { - source: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][16]byte{}, - result: pgtype.UUIDArray{Status: pgtype.Present}, - }, - { - source: ([][16]byte)(nil), - result: pgtype.UUIDArray{Status: pgtype.Null}, - }, - { - source: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][]byte{ - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, - nil, - {32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, - }, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 4}}, - Status: pgtype.Present}, - }, - { - source: [][]byte{}, - result: pgtype.UUIDArray{Status: pgtype.Present}, - }, - { - source: ([][]byte)(nil), - result: pgtype.UUIDArray{Status: pgtype.Null}, - }, - { - source: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []string{}, - result: pgtype.UUIDArray{Status: pgtype.Present}, - }, - { - source: ([]string)(nil), - result: pgtype.UUIDArray{Status: pgtype.Null}, - }, - { - source: [][][16]byte{{ - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]string{ - {{{ - "00010203-0405-0607-0809-0a0b0c0d0e0f", - "10111213-1415-1617-1819-1a1b1c1d1e1f", - "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, - {{{ - "30313233-3435-3637-3839-3a3b3c3d3e3f", - "40414243-4445-4647-4849-4a4b4c4d4e4f", - "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, - {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1][16]byte{{ - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]string{ - {{{ - "00010203-0405-0607-0809-0a0b0c0d0e0f", - "10111213-1415-1617-1819-1a1b1c1d1e1f", - "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, - {{{ - "30313233-3435-3637-3839-3a3b3c3d3e3f", - "40414243-4445-4647-4849-4a4b4c4d4e4f", - "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, - {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.UUIDArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestUUIDArrayAssignTo(t *testing.T) { - var byteArraySlice [][16]byte - var byteSliceSlice [][]byte - var stringSlice []string - var byteSlice []byte - var byteArraySliceDim2 [][][16]byte - var stringSliceDim4 [][][][]string - var byteArrayDim2 [2][1][16]byte - var stringArrayDim4 [2][1][1][3]string - - simpleTests := []struct { - src pgtype.UUIDArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &byteArraySlice, - expected: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - }, - { - src: pgtype.UUIDArray{Status: pgtype.Null}, - dst: &byteArraySlice, - expected: ([][16]byte)(nil), - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &byteSliceSlice, - expected: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - }, - { - src: pgtype.UUIDArray{Status: pgtype.Null}, - dst: &byteSliceSlice, - expected: ([][]byte)(nil), - }, - { - src: pgtype.UUIDArray{Status: pgtype.Present}, - dst: &byteSlice, - expected: []byte{}, - }, - { - src: pgtype.UUIDArray{Status: pgtype.Present}, - dst: &stringSlice, - expected: []string{}, - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - expected: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, - }, - { - src: pgtype.UUIDArray{Status: pgtype.Null}, - dst: &stringSlice, - expected: ([]string)(nil), - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &byteArraySliceDim2, - expected: [][][16]byte{{ - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, - {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &stringSliceDim4, - expected: [][][][]string{ - {{{ - "00010203-0405-0607-0809-0a0b0c0d0e0f", - "10111213-1415-1617-1819-1a1b1c1d1e1f", - "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, - {{{ - "30313233-3435-3637-3839-3a3b3c3d3e3f", - "40414243-4445-4647-4849-4a4b4c4d4e4f", - "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &byteArrayDim2, - expected: [2][1][16]byte{{ - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, - {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &stringArrayDim4, - expected: [2][1][1][3]string{ - {{{ - "00010203-0405-0607-0809-0a0b0c0d0e0f", - "10111213-1415-1617-1819-1a1b1c1d1e1f", - "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, - {{{ - "30313233-3435-3637-3839-3a3b3c3d3e3f", - "40414243-4445-4647-4849-4a4b4c4d4e4f", - "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/uuid_test.go b/uuid_test.go deleted file mode 100644 index 5a93ea8d..00000000 --- a/uuid_test.go +++ /dev/null @@ -1,245 +0,0 @@ -package pgtype_test - -import ( - "bytes" - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestUUIDTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - &pgtype.UUID{Status: pgtype.Null}, - }) -} - -type SomeUUIDWrapper struct { - SomeUUIDType -} - -type SomeUUIDType [16]byte - -func TestUUIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.UUID - }{ - { - source: nil, - result: pgtype.UUID{Status: pgtype.Null}, - }, - { - source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - }, - { - source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - }, - { - source: SomeUUIDType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - }, - { - source: ([]byte)(nil), - result: pgtype.UUID{Status: pgtype.Null}, - }, - { - source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - }, - { - source: "000102030405060708090a0b0c0d0e0f", - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.UUID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestUUIDAssignTo(t *testing.T) { - { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst [16]byte - expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst []byte - expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if bytes.Compare(dst, expected) != 0 { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst SomeUUIDType - expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst string - expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst SomeUUIDWrapper - expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst.SomeUUIDType != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } -} - -func TestUUID_MarshalJSON(t *testing.T) { - tests := []struct { - name string - src pgtype.UUID - want []byte - wantErr bool - }{ - { - name: "first", - src: pgtype.UUID{ - Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, - Status: pgtype.Present, - }, - want: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), - wantErr: false, - }, - { - name: "second", - src: pgtype.UUID{ - Bytes: [16]byte{}, - Status: pgtype.Undefined, - }, - want: nil, - wantErr: true, - }, - { - name: "third", - src: pgtype.UUID{ - Bytes: [16]byte{}, - Status: pgtype.Null, - }, - want: []byte("null"), - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.src.MarshalJSON() - if (err != nil) != tt.wantErr { - t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestUUID_UnmarshalJSON(t *testing.T) { - tests := []struct { - name string - want *pgtype.UUID - src []byte - wantErr bool - }{ - { - name: "first", - want: &pgtype.UUID{ - Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, - Status: pgtype.Present, - }, - src: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), - wantErr: false, - }, - { - name: "second", - want: &pgtype.UUID{ - Bytes: [16]byte{}, - Status: pgtype.Null, - }, - src: []byte("null"), - wantErr: false, - }, - { - name: "third", - want: &pgtype.UUID{ - Bytes: [16]byte{}, - Status: pgtype.Undefined, - }, - src: []byte("1d485a7a-6d18-4599-8c6c-34425616887a"), - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := &pgtype.UUID{} - if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr { - t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/varbit_test.go b/varbit_test.go deleted file mode 100644 index 3c5aea1e..00000000 --- a/varbit_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestVarbitTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "varbit", []interface{}{ - &pgtype.Varbit{Bytes: []byte{}, Len: 0, Status: pgtype.Present}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Status: pgtype.Present}, - &pgtype.Varbit{Status: pgtype.Null}, - }) -} - -func TestVarbitNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select B'111111111'", - Value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, - }, - }) -} diff --git a/varchar_array_test.go b/varchar_array_test.go deleted file mode 100644 index 5fb7326d..00000000 --- a/varchar_array_test.go +++ /dev/null @@ -1,282 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestVarcharArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "varchar[]", []interface{}{ - &pgtype.VarcharArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "foo", Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.VarcharArray{Status: pgtype.Null}, - &pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "bar ", Status: pgtype.Present}, - {String: "NuLL", Status: pgtype.Present}, - {String: `wow"quz\`, Status: pgtype.Present}, - {String: "", Status: pgtype.Present}, - {Status: pgtype.Null}, - {String: "null", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "quz", Status: pgtype.Present}, - {String: "foo", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestVarcharArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.VarcharArray - }{ - { - source: []string{"foo"}, - result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]string)(nil)), - result: pgtype.VarcharArray{Status: pgtype.Null}, - }, - { - source: [][]string{{"foo"}, {"bar"}}, - result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - { - source: [2][1]string{{"foo"}, {"bar"}}, - result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.VarcharArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestVarcharArrayAssignTo(t *testing.T) { - var stringSlice []string - type _stringSlice []string - var namedStringSlice _stringSlice - var stringSliceDim2 [][]string - var stringSliceDim4 [][][][]string - var stringArrayDim2 [2][1]string - var stringArrayDim4 [2][1][1][3]string - - simpleTests := []struct { - src pgtype.VarcharArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - expected: []string{"foo"}, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedStringSlice, - expected: _stringSlice{"bar"}, - }, - { - src: pgtype.VarcharArray{Status: pgtype.Null}, - dst: &stringSlice, - expected: (([]string)(nil)), - }, - { - src: pgtype.VarcharArray{Status: pgtype.Present}, - dst: &stringSlice, - expected: []string{}, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &stringSliceDim2, - expected: [][]string{{"foo"}, {"bar"}}, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &stringSliceDim4, - expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &stringArrayDim2, - expected: [2][1]string{{"foo"}, {"bar"}}, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, - dst: &stringArrayDim4, - expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.VarcharArray - dst interface{} - }{ - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &stringArrayDim2, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, - dst: &stringSlice, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - dst: &stringArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/xid_test.go b/xid_test.go deleted file mode 100644 index 563ce96e..00000000 --- a/xid_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" -) - -func TestXIDTranscode(t *testing.T) { - pgTypeName := "xid" - values := []interface{}{ - &pgtype.XID{Uint: 42, Status: pgtype.Present}, - &pgtype.XID{Status: pgtype.Null}, - } - eqFunc := func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - } - - testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) - } -} - -func TestXIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.XID - }{ - {source: uint32(1), result: pgtype.XID{Uint: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.XID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestXIDAssignTo(t *testing.T) { - var ui32 uint32 - var pui32 *uint32 - - simpleTests := []struct { - src pgtype.XID - dst interface{} - expected interface{} - }{ - {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.XID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.XID - dst interface{} - expected interface{} - }{ - {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.XID - dst interface{} - }{ - {src: pgtype.XID{Status: pgtype.Null}, dst: &ui32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/zeronull/int2_test.go b/zeronull/int2_test.go deleted file mode 100644 index 2dcb4e79..00000000 --- a/zeronull/int2_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package zeronull_test - -import ( - "testing" - - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" -) - -func TestInt2Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ - (zeronull.Int2)(1), - (zeronull.Int2)(0), - }) -} - -func TestInt2ConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "int2", (zeronull.Int2)(0)) -} - -func TestInt2ConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "int2", (zeronull.Int2)(0)) -} diff --git a/zeronull/int4_test.go b/zeronull/int4_test.go deleted file mode 100644 index 309e4125..00000000 --- a/zeronull/int4_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package zeronull_test - -import ( - "testing" - - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" -) - -func TestInt4Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ - (zeronull.Int4)(1), - (zeronull.Int4)(0), - }) -} - -func TestInt4ConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "int4", (zeronull.Int4)(0)) -} - -func TestInt4ConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "int4", (zeronull.Int4)(0)) -} diff --git a/zeronull/int8_test.go b/zeronull/int8_test.go deleted file mode 100644 index ae80bc0a..00000000 --- a/zeronull/int8_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package zeronull_test - -import ( - "testing" - - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" -) - -func TestInt8Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ - (zeronull.Int8)(1), - (zeronull.Int8)(0), - }) -} - -func TestInt8ConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "int8", (zeronull.Int8)(0)) -} - -func TestInt8ConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "int8", (zeronull.Int8)(0)) -} diff --git a/zeronull/text_test.go b/zeronull/text_test.go deleted file mode 100644 index f08a0d2a..00000000 --- a/zeronull/text_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package zeronull_test - -import ( - "testing" - - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" -) - -func TestTextTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "text", []interface{}{ - (zeronull.Text)("foo"), - (zeronull.Text)(""), - }) -} - -func TestTextConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "text", (zeronull.Text)("")) -} - -func TestTextConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "text", (zeronull.Text)("")) -} diff --git a/zeronull/timestamp_test.go b/zeronull/timestamp_test.go deleted file mode 100644 index ec96ff07..00000000 --- a/zeronull/timestamp_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package zeronull_test - -import ( - "testing" - "time" - - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" -) - -func TestTimestampTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ - (zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), - (zeronull.Timestamp)(time.Time{}), - }, func(a, b interface{}) bool { - at := a.(zeronull.Timestamp) - bt := b.(zeronull.Timestamp) - - return time.Time(at).Equal(time.Time(bt)) - }) -} - -func TestTimestampConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "timestamp", (zeronull.Timestamp)(time.Time{})) -} - -func TestTimestampConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "timestamp", (zeronull.Timestamp)(time.Time{})) -} diff --git a/zeronull/timestamptz_test.go b/zeronull/timestamptz_test.go deleted file mode 100644 index 3a401c49..00000000 --- a/zeronull/timestamptz_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package zeronull_test - -import ( - "testing" - "time" - - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" -) - -func TestTimestamptzTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ - (zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), - (zeronull.Timestamptz)(time.Time{}), - }, func(a, b interface{}) bool { - at := a.(zeronull.Timestamptz) - bt := b.(zeronull.Timestamptz) - - return time.Time(at).Equal(time.Time(bt)) - }) -} - -func TestTimestamptzConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "timestamptz", (zeronull.Timestamptz)(time.Time{})) -} - -func TestTimestamptzConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "timestamptz", (zeronull.Timestamptz)(time.Time{})) -} diff --git a/zeronull/uuid_test.go b/zeronull/uuid_test.go deleted file mode 100644 index 162bdf1f..00000000 --- a/zeronull/uuid_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package zeronull_test - -import ( - "testing" - - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" -) - -func TestUUIDTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - (*zeronull.UUID)(&[16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), - (*zeronull.UUID)(&[16]byte{}), - }) -} - -func TestUUIDConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "uuid", (*zeronull.UUID)(&[16]byte{})) -} - -func TestUUIDConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "uuid", (*zeronull.UUID)(&[16]byte{})) -} From d89c8390a530599c1ba1b6f68bbb0de092cbd6cb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Jul 2021 10:25:08 -0500 Subject: [PATCH 0700/1158] Update dependencies and go mod tidy --- go.mod | 4 ++-- go.sum | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index dad81ebe..6fdd0e97 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,11 @@ go 1.12 require ( github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/pgio v1.0.0 - github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd + github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.1.1 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.7.0 - golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e + golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 golang.org/x/text v0.3.6 ) diff --git a/go.sum b/go.sum index 54405c28..3c77ee21 100644 --- a/go.sum +++ b/go.sum @@ -15,11 +15,13 @@ github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9 github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= -github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd h1:eDErF6V/JPJON/B7s68BxwHgfmyOntHJQ8IOaz0x4R8= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= @@ -29,8 +31,6 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.1.0 h1:h2yg3kjIyAGSZKDijYn1/gXHlYLCwl9ZjEh2PU0yVxE= -github.com/jackc/pgproto3/v2 v2.1.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= @@ -87,8 +87,9 @@ golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaE golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= From c16a4f7d6a7cfac78cfa7e927264e21346bbdc20 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Jul 2021 10:40:30 -0500 Subject: [PATCH 0701/1158] Revert "Temporarily delete tests and pgxtype to break recursive dependency with pgx" This reverts commit 32e20a603178b49fb189d1be971d0fb6960cabb2. --- aclitem_array_test.go | 329 +++++++++++++++++ aclitem_test.go | 97 +++++ array_test.go | 127 +++++++ array_type_test.go | 84 +++++ bit_test.go | 25 ++ bool_array_test.go | 283 +++++++++++++++ bool_test.go | 140 ++++++++ box_test.go | 34 ++ bpchar_array_test.go | 55 +++ bpchar_test.go | 51 +++ bytea_array_test.go | 229 ++++++++++++ bytea_test.go | 73 ++++ cid_test.go | 105 ++++++ cidr_array_test.go | 319 ++++++++++++++++ circle_test.go | 16 + composite_bench_test.go | 192 ++++++++++ composite_fields_test.go | 273 ++++++++++++++ composite_type_test.go | 320 +++++++++++++++++ custom_composite_test.go | 87 +++++ date_array_test.go | 327 +++++++++++++++++ date_test.go | 168 +++++++++ daterange_test.go | 133 +++++++ enum_array_test.go | 281 +++++++++++++++ enum_type_test.go | 148 ++++++++ ext/gofrs-uuid/uuid_test.go | 101 ++++++ ext/shopspring-numeric/decimal_test.go | 330 +++++++++++++++++ float4_array_test.go | 282 +++++++++++++++ float4_test.go | 149 ++++++++ float8_array_test.go | 258 +++++++++++++ float8_test.go | 149 ++++++++ go.mod | 4 + go.sum | 480 +++++++++++++++++++++++++ hstore_array_test.go | 436 ++++++++++++++++++++++ hstore_test.go | 111 ++++++ inet_array_test.go | 319 ++++++++++++++++ inet_test.go | 134 +++++++ int2_array_test.go | 342 ++++++++++++++++++ int2_test.go | 144 ++++++++ int4_array_test.go | 356 ++++++++++++++++++ int4_test.go | 186 ++++++++++ int4range_test.go | 28 ++ int8_array_test.go | 349 ++++++++++++++++++ int8_test.go | 187 ++++++++++ int8range_test.go | 28 ++ interval_test.go | 74 ++++ json_test.go | 177 +++++++++ jsonb_array_test.go | 36 ++ jsonb_test.go | 142 ++++++++ line_test.go | 38 ++ lseg_test.go | 22 ++ macaddr_array_test.go | 262 ++++++++++++++ macaddr_test.go | 78 ++++ name_test.go | 98 +++++ numeric_array_test.go | 305 ++++++++++++++++ numeric_test.go | 389 ++++++++++++++++++++ numrange_test.go | 46 +++ oid_value_test.go | 95 +++++ path_test.go | 29 ++ pgtype_test.go | 292 +++++++++++++++ pgxtype/README.md | 3 + pgxtype/pgxtype.go | 145 ++++++++ point_test.go | 150 ++++++++ polygon_test.go | 89 +++++ qchar_test.go | 143 ++++++++ range_test.go | 177 +++++++++ record_test.go | 186 ++++++++++ testutil/testutil.go | 436 ++++++++++++++++++++++ text_array_test.go | 294 +++++++++++++++ text_test.go | 164 +++++++++ tid_test.go | 63 ++++ time_test.go | 131 +++++++ timestamp_array_test.go | 307 ++++++++++++++++ timestamp_test.go | 178 +++++++++ timestamptz_array_test.go | 343 ++++++++++++++++++ timestamptz_test.go | 224 ++++++++++++ tsrange_test.go | 41 +++ tstzrange_test.go | 49 +++ uuid_array_test.go | 368 +++++++++++++++++++ uuid_test.go | 245 +++++++++++++ varbit_test.go | 26 ++ varchar_array_test.go | 282 +++++++++++++++ xid_test.go | 105 ++++++ zeronull/int2_test.go | 23 ++ zeronull/int4_test.go | 23 ++ zeronull/int8_test.go | 23 ++ zeronull/text_test.go | 23 ++ zeronull/timestamp_test.go | 29 ++ zeronull/timestamptz_test.go | 29 ++ zeronull/uuid_test.go | 23 ++ 89 files changed, 14674 insertions(+) create mode 100644 aclitem_array_test.go create mode 100644 aclitem_test.go create mode 100644 array_test.go create mode 100644 array_type_test.go create mode 100644 bit_test.go create mode 100644 bool_array_test.go create mode 100644 bool_test.go create mode 100644 box_test.go create mode 100644 bpchar_array_test.go create mode 100644 bpchar_test.go create mode 100644 bytea_array_test.go create mode 100644 bytea_test.go create mode 100644 cid_test.go create mode 100644 cidr_array_test.go create mode 100644 circle_test.go create mode 100644 composite_bench_test.go create mode 100644 composite_fields_test.go create mode 100644 composite_type_test.go create mode 100644 custom_composite_test.go create mode 100644 date_array_test.go create mode 100644 date_test.go create mode 100644 daterange_test.go create mode 100644 enum_array_test.go create mode 100644 enum_type_test.go create mode 100644 ext/gofrs-uuid/uuid_test.go create mode 100644 ext/shopspring-numeric/decimal_test.go create mode 100644 float4_array_test.go create mode 100644 float4_test.go create mode 100644 float8_array_test.go create mode 100644 float8_test.go create mode 100644 hstore_array_test.go create mode 100644 hstore_test.go create mode 100644 inet_array_test.go create mode 100644 inet_test.go create mode 100644 int2_array_test.go create mode 100644 int2_test.go create mode 100644 int4_array_test.go create mode 100644 int4_test.go create mode 100644 int4range_test.go create mode 100644 int8_array_test.go create mode 100644 int8_test.go create mode 100644 int8range_test.go create mode 100644 interval_test.go create mode 100644 json_test.go create mode 100644 jsonb_array_test.go create mode 100644 jsonb_test.go create mode 100644 line_test.go create mode 100644 lseg_test.go create mode 100644 macaddr_array_test.go create mode 100644 macaddr_test.go create mode 100644 name_test.go create mode 100644 numeric_array_test.go create mode 100644 numeric_test.go create mode 100644 numrange_test.go create mode 100644 oid_value_test.go create mode 100644 path_test.go create mode 100644 pgtype_test.go create mode 100644 pgxtype/README.md create mode 100644 pgxtype/pgxtype.go create mode 100644 point_test.go create mode 100644 polygon_test.go create mode 100644 qchar_test.go create mode 100644 range_test.go create mode 100644 record_test.go create mode 100644 testutil/testutil.go create mode 100644 text_array_test.go create mode 100644 text_test.go create mode 100644 tid_test.go create mode 100644 time_test.go create mode 100644 timestamp_array_test.go create mode 100644 timestamp_test.go create mode 100644 timestamptz_array_test.go create mode 100644 timestamptz_test.go create mode 100644 tsrange_test.go create mode 100644 tstzrange_test.go create mode 100644 uuid_array_test.go create mode 100644 uuid_test.go create mode 100644 varbit_test.go create mode 100644 varchar_array_test.go create mode 100644 xid_test.go create mode 100644 zeronull/int2_test.go create mode 100644 zeronull/int4_test.go create mode 100644 zeronull/int8_test.go create mode 100644 zeronull/text_test.go create mode 100644 zeronull/timestamp_test.go create mode 100644 zeronull/timestamptz_test.go create mode 100644 zeronull/uuid_test.go diff --git a/aclitem_array_test.go b/aclitem_array_test.go new file mode 100644 index 00000000..8f015f40 --- /dev/null +++ b/aclitem_array_test.go @@ -0,0 +1,329 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestACLItemArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "aclitem[]", []interface{}{ + &pgtype.ACLItemArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{Status: pgtype.Null}, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + //{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + {String: `postgres=arwdDxt/postgres`, Status: pgtype.Present}, // todo: remove after fixing above case + {String: "=r/postgres", Status: pgtype.Present}, + {Status: pgtype.Null}, + {String: "=r/postgres", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestACLItemArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ACLItemArray + }{ + { + source: []string{"=r/postgres"}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.ACLItemArray{Status: pgtype.Null}, + }, + { + source: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.ACLItemArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestACLItemArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string + + simpleTests := []struct { + src pgtype.ACLItemArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"=r/postgres"}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"=r/postgres"}, + }, + { + src: pgtype.ACLItemArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + { + src: pgtype.ACLItemArray{Status: pgtype.Present}, + dst: &stringSlice, + expected: []string{}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.ACLItemArray + dst interface{} + }{ + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/aclitem_test.go b/aclitem_test.go new file mode 100644 index 00000000..a37d7657 --- /dev/null +++ b/aclitem_test.go @@ -0,0 +1,97 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestACLItemTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ + &pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + //&pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + &pgtype.ACLItem{Status: pgtype.Null}, + }) +} + +func TestACLItemSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ACLItem + }{ + {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.ACLItem + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestACLItemAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.ACLItem + dst interface{} + expected interface{} + }{ + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.ACLItem + dst interface{} + expected interface{} + }{ + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.ACLItem + dst interface{} + }{ + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/array_test.go b/array_test.go new file mode 100644 index 00000000..d2120677 --- /dev/null +++ b/array_test.go @@ -0,0 +1,127 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/stretchr/testify/require" +) + +func TestParseUntypedTextArray(t *testing.T) { + tests := []struct { + source string + result pgtype.UntypedTextArray + }{ + { + source: "{}", + result: pgtype.UntypedTextArray{ + Elements: nil, + Quoted: nil, + Dimensions: nil, + }, + }, + { + source: "{1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Quoted: []bool{false}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{a,b}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b"}, + Quoted: []bool{false, false}, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + }, + }, + { + source: `{"NULL"}`, + result: pgtype.UntypedTextArray{ + Elements: []string{"NULL"}, + Quoted: []bool{true}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: `{""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{""}, + Quoted: []bool{true}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: `{"He said, \"Hello.\""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{`He said, "Hello."`}, + Quoted: []bool{true}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{{a,b},{c,d},{e,f}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f"}, + Quoted: []bool{false, false, false, false, false, false}, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + }, + }, + { + source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"}, + Quoted: []bool{false, false, false, false, false, false, false, false, false, false, false, false}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 1}, + {Length: 3, LowerBound: 1}, + {Length: 2, LowerBound: 1}, + }, + }, + }, + { + source: "[4:4]={1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Quoted: []bool{false}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 4}}, + }, + }, + { + source: "[4:5][2:3]={{a,b},{c,d}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d"}, + Quoted: []bool{false, false, false, false}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + }, + }, + } + + for i, tt := range tests { + r, err := pgtype.ParseUntypedTextArray(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(*r, tt.result) { + t.Errorf("%d: expected %+v to be parsed to %+v, but it was %+v", i, tt.source, tt.result, *r) + } + } +} + +// https://github.com/jackc/pgx/issues/881 +func TestArrayAssignToEmptyToNonSlice(t *testing.T) { + var a pgtype.Int4Array + err := a.Set([]int32{}) + require.NoError(t, err) + + var iface interface{} + err = a.AssignTo(&iface) + require.EqualError(t, err, "cannot assign *pgtype.Int4Array to *interface {}") +} diff --git a/array_type_test.go b/array_type_test.go new file mode 100644 index 00000000..626df4dc --- /dev/null +++ b/array_type_test.go @@ -0,0 +1,84 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestArrayTypeValue(t *testing.T) { + arrayType := pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }) + + err := arrayType.Set(nil) + require.NoError(t, err) + + gotValue := arrayType.Get() + require.Nil(t, gotValue) + + slice := []string{"foo", "bar"} + err = arrayType.AssignTo(&slice) + require.NoError(t, err) + require.Nil(t, slice) + + err = arrayType.Set([]string{}) + require.NoError(t, err) + + gotValue = arrayType.Get() + require.Len(t, gotValue, 0) + + err = arrayType.AssignTo(&slice) + require.NoError(t, err) + require.EqualValues(t, []string{}, slice) + + err = arrayType.Set([]string{"baz", "quz"}) + require.NoError(t, err) + + gotValue = arrayType.Get() + require.Len(t, gotValue, 2) + + err = arrayType.AssignTo(&slice) + require.NoError(t, err) + require.EqualValues(t, []string{"baz", "quz"}, slice) +} + +func TestArrayTypeTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), + Name: "_text", + OID: pgtype.TextArrayOID, + }) + + var dstStrings []string + err := conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) + require.NoError(t, err) + + require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) +} + +func TestArrayTypeEmptyArrayDoesNotBreakArrayType(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), + Name: "_text", + OID: pgtype.TextArrayOID, + }) + + var dstStrings []string + err := conn.QueryRow(context.Background(), "select '{}'::text[]").Scan(&dstStrings) + require.NoError(t, err) + + require.EqualValues(t, []string{}, dstStrings) + + err = conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) + require.NoError(t, err) + + require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) +} diff --git a/bit_test.go b/bit_test.go new file mode 100644 index 00000000..2e9c9b6e --- /dev/null +++ b/bit_test.go @@ -0,0 +1,25 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestBitTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bit(40)", []interface{}{ + &pgtype.Varbit{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Status: pgtype.Present}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, + &pgtype.Varbit{Status: pgtype.Null}, + }) +} + +func TestBitNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select B'111111111'", + Value: &pgtype.Bit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, + }, + }) +} diff --git a/bool_array_test.go b/bool_array_test.go new file mode 100644 index 00000000..be567e59 --- /dev/null +++ b/bool_array_test.go @@ -0,0 +1,283 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestBoolArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bool[]", []interface{}{ + &pgtype.BoolArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.BoolArray{Status: pgtype.Null}, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Bool: false, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestBoolArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.BoolArray + }{ + { + source: []bool{true}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]bool)(nil)), + result: pgtype.BoolArray{Status: pgtype.Null}, + }, + { + source: [][]bool{{true}, {false}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]bool{{true}, {false}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.BoolArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestBoolArrayAssignTo(t *testing.T) { + var boolSlice []bool + type _boolSlice []bool + var namedBoolSlice _boolSlice + var boolSliceDim2 [][]bool + var boolSliceDim4 [][][][]bool + var boolArrayDim2 [2][1]bool + var boolArrayDim4 [2][1][1][3]bool + + simpleTests := []struct { + src pgtype.BoolArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &boolSlice, + expected: []bool{true}, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedBoolSlice, + expected: _boolSlice{true}, + }, + { + src: pgtype.BoolArray{Status: pgtype.Null}, + dst: &boolSlice, + expected: (([]bool)(nil)), + }, + { + src: pgtype.BoolArray{Status: pgtype.Present}, + dst: &boolSlice, + expected: []bool{}, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]bool{{true}, {false}}, + dst: &boolSliceDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + dst: &boolSliceDim4, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]bool{{true}, {false}}, + dst: &boolArrayDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + dst: &boolArrayDim4, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.BoolArray + dst interface{} + }{ + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &boolSlice, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &boolArrayDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &boolSlice, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &boolArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/bool_test.go b/bool_test.go new file mode 100644 index 00000000..8e7a5220 --- /dev/null +++ b/bool_test.go @@ -0,0 +1,140 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestBoolTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bool", []interface{}{ + &pgtype.Bool{Bool: false, Status: pgtype.Present}, + &pgtype.Bool{Bool: true, Status: pgtype.Present}, + &pgtype.Bool{Bool: false, Status: pgtype.Null}, + }) +} + +func TestBoolSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Bool + }{ + {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: _bool(true), result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: _bool(false), result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: nil, result: pgtype.Bool{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var r pgtype.Bool + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestBoolAssignTo(t *testing.T) { + var b bool + var _b _bool + var pb *bool + var _pb *_bool + + simpleTests := []struct { + src pgtype.Bool + dst interface{} + expected interface{} + }{ + {src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &b, expected: false}, + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &b, expected: true}, + {src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &_b, expected: _bool(false)}, + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_b, expected: _bool(true)}, + {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &pb, expected: ((*bool)(nil))}, + {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &_pb, expected: ((*_bool)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Bool + dst interface{} + expected interface{} + }{ + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &pb, expected: true}, + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_pb, expected: _bool(true)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} + +func TestBoolMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Bool + result string + }{ + {source: pgtype.Bool{Status: pgtype.Null}, result: "null"}, + {source: pgtype.Bool{Bool: true, Status: pgtype.Present}, result: "true"}, + {source: pgtype.Bool{Bool: false, Status: pgtype.Present}, result: "false"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestBoolUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Bool + }{ + {source: "null", result: pgtype.Bool{Status: pgtype.Null}}, + {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Bool + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/box_test.go b/box_test.go new file mode 100644 index 00000000..643c74ec --- /dev/null +++ b/box_test.go @@ -0,0 +1,34 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestBoxTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "box", []interface{}{ + &pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, + Status: pgtype.Present, + }, + &pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Status: pgtype.Present, + }, + &pgtype.Box{Status: pgtype.Null}, + }) +} + +func TestBoxNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select '3.14, 1.678, 7.1, 5.234'::box", + Value: &pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + Status: pgtype.Present, + }, + }, + }) +} diff --git a/bpchar_array_test.go b/bpchar_array_test.go new file mode 100644 index 00000000..af6bf09a --- /dev/null +++ b/bpchar_array_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestBPCharArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "char(8)[]", []interface{}{ + &pgtype.BPCharArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.BPCharArray{ + Elements: []pgtype.BPChar{ + pgtype.BPChar{String: "foo ", Status: pgtype.Present}, + pgtype.BPChar{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.BPCharArray{Status: pgtype.Null}, + &pgtype.BPCharArray{ + Elements: []pgtype.BPChar{ + pgtype.BPChar{String: "bar ", Status: pgtype.Present}, + pgtype.BPChar{String: "NuLL ", Status: pgtype.Present}, + pgtype.BPChar{String: `wow"quz\`, Status: pgtype.Present}, + pgtype.BPChar{String: "1 ", Status: pgtype.Present}, + pgtype.BPChar{String: "1 ", Status: pgtype.Present}, + pgtype.BPChar{String: "null ", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 3, LowerBound: 1}, + {Length: 2, LowerBound: 1}, + }, + Status: pgtype.Present, + }, + &pgtype.BPCharArray{ + Elements: []pgtype.BPChar{ + pgtype.BPChar{String: " bar ", Status: pgtype.Present}, + pgtype.BPChar{String: " baz ", Status: pgtype.Present}, + pgtype.BPChar{String: " quz ", Status: pgtype.Present}, + pgtype.BPChar{String: "foo ", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} diff --git a/bpchar_test.go b/bpchar_test.go new file mode 100644 index 00000000..7b8c1da3 --- /dev/null +++ b/bpchar_test.go @@ -0,0 +1,51 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestChar3Transcode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "char(3)", []interface{}{ + &pgtype.BPChar{String: "a ", Status: pgtype.Present}, + &pgtype.BPChar{String: " a ", Status: pgtype.Present}, + &pgtype.BPChar{String: "å—¨ ", Status: pgtype.Present}, + &pgtype.BPChar{String: " ", Status: pgtype.Present}, + &pgtype.BPChar{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.BPChar) + b := bb.(pgtype.BPChar) + + return a.Status == b.Status && a.String == b.String + }) +} + +func TestBPCharAssignTo(t *testing.T) { + var ( + str string + run rune + ) + simpleTests := []struct { + src pgtype.BPChar + dst interface{} + expected interface{} + }{ + {src: pgtype.BPChar{String: "simple", Status: pgtype.Present}, dst: &str, expected: "simple"}, + {src: pgtype.BPChar{String: "å—¨", Status: pgtype.Present}, dst: &run, expected: 'å—¨'}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + +} diff --git a/bytea_array_test.go b/bytea_array_test.go new file mode 100644 index 00000000..27c0382e --- /dev/null +++ b/bytea_array_test.go @@ -0,0 +1,229 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestByteaArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bytea[]", []interface{}{ + &pgtype.ByteaArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{Status: pgtype.Null}, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{}, Status: pgtype.Present}, + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Bytes: []byte{1}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{}, Status: pgtype.Present}, + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{1}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestByteaArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ByteaArray + }{ + { + source: [][]byte{{1, 2, 3}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([][]byte)(nil)), + result: pgtype.ByteaArray{Status: pgtype.Null}, + }, + { + source: [][][]byte{{{1}}, {{2}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1][]byte{{{1}}, {{2}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.ByteaArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestByteaArrayAssignTo(t *testing.T) { + var byteByteSlice [][]byte + var byteByteSliceDim2 [][][]byte + var byteByteSliceDim4 [][][][][]byte + var byteByteArraySliceDim2 [2][1][]byte + var byteByteArraySliceDim4 [2][1][1][3][]byte + + simpleTests := []struct { + src pgtype.ByteaArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &byteByteSlice, + expected: [][]byte{{1, 2, 3}}, + }, + { + src: pgtype.ByteaArray{Status: pgtype.Null}, + dst: &byteByteSlice, + expected: (([][]byte)(nil)), + }, + { + src: pgtype.ByteaArray{Status: pgtype.Present}, + dst: &byteByteSlice, + expected: [][]byte{}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteByteSliceDim2, + expected: [][][]byte{{{1}}, {{2}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &byteByteSliceDim4, + expected: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteByteArraySliceDim2, + expected: [2][1][]byte{{{1}}, {{2}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &byteByteArraySliceDim4, + expected: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/bytea_test.go b/bytea_test.go new file mode 100644 index 00000000..c8c49ff7 --- /dev/null +++ b/bytea_test.go @@ -0,0 +1,73 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestByteaTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bytea", []interface{}{ + &pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + &pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + &pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, + }) +} + +func TestByteaSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Bytea + }{ + {source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + {source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}}, + {source: []byte(nil), result: pgtype.Bytea{Status: pgtype.Null}}, + {source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + {source: _byteSlice(nil), result: pgtype.Bytea{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var r pgtype.Bytea + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestByteaAssignTo(t *testing.T) { + var buf []byte + var _buf _byteSlice + var pbuf *[]byte + var _pbuf *_byteSlice + + simpleTests := []struct { + src pgtype.Bytea + dst interface{} + expected interface{} + }{ + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &buf, expected: []byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_buf, expected: _byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &pbuf, expected: &[]byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{Status: pgtype.Null}, dst: &pbuf, expected: ((*[]byte)(nil))}, + {src: pgtype.Bytea{Status: pgtype.Null}, dst: &_pbuf, expected: ((*_byteSlice)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/cid_test.go b/cid_test.go new file mode 100644 index 00000000..50e50cd8 --- /dev/null +++ b/cid_test.go @@ -0,0 +1,105 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestCIDTranscode(t *testing.T) { + pgTypeName := "cid" + values := []interface{}{ + &pgtype.CID{Uint: 42, Status: pgtype.Present}, + &pgtype.CID{Status: pgtype.Null}, + } + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } +} + +func TestCIDSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CID + }{ + {source: uint32(1), result: pgtype.CID{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.CID + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.CID + dst interface{} + expected interface{} + }{ + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.CID + dst interface{} + expected interface{} + }{ + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.CID + dst interface{} + }{ + {src: pgtype.CID{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/cidr_array_test.go b/cidr_array_test.go new file mode 100644 index 00000000..74c063fa --- /dev/null +++ b/cidr_array_test.go @@ -0,0 +1,319 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestCIDRArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "cidr[]", []interface{}{ + &pgtype.CIDRArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CIDRArray{Status: pgtype.Null}, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + {Status: pgtype.Null}, + {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestCIDRArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CIDRArray + }{ + { + source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.CIDRArray{Status: pgtype.Null}, + }, + { + source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.CIDRArray{Status: pgtype.Null}, + }, + { + source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.CIDRArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCIDRArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + var ipSliceDim2 [][]net.IP + var ipnetSliceDim4 [][][][]*net.IPNet + var ipArrayDim2 [2][1]net.IP + var ipnetArrayDim4 [2][1][1][3]*net.IPNet + + simpleTests := []struct { + src pgtype.CIDRArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.CIDRArray{Status: pgtype.Null}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.CIDRArray{Status: pgtype.Present}, + dst: &ipnetSlice, + expected: []*net.IPNet{}, + }, + { + src: pgtype.CIDRArray{Status: pgtype.Null}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + { + src: pgtype.CIDRArray{Status: pgtype.Present}, + dst: &ipSlice, + expected: []net.IP{}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipSliceDim2, + expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetSliceDim4, + expected: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipArrayDim2, + expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetArrayDim4, + expected: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/circle_test.go b/circle_test.go new file mode 100644 index 00000000..ba4f408b --- /dev/null +++ b/circle_test.go @@ -0,0 +1,16 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestCircleTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "circle", []interface{}{ + &pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Status: pgtype.Present}, + &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Status: pgtype.Present}, + &pgtype.Circle{Status: pgtype.Null}, + }) +} diff --git a/composite_bench_test.go b/composite_bench_test.go new file mode 100644 index 00000000..7aef8c4f --- /dev/null +++ b/composite_bench_test.go @@ -0,0 +1,192 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgio" + "github.com/jackc/pgtype" + "github.com/stretchr/testify/require" +) + +type MyCompositeRaw struct { + A int32 + B *string +} + +func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + buf = pgio.AppendUint32(buf, 2) + + buf = pgio.AppendUint32(buf, pgtype.Int4OID) + buf = pgio.AppendInt32(buf, 4) + buf = pgio.AppendInt32(buf, src.A) + + buf = pgio.AppendUint32(buf, pgtype.TextOID) + if src.B != nil { + buf = pgio.AppendInt32(buf, int32(len(*src.B))) + buf = append(buf, (*src.B)...) + } else { + buf = pgio.AppendInt32(buf, -1) + } + + return buf, nil +} + +func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + a := pgtype.Int4{} + b := pgtype.Text{} + + scanner := pgtype.NewCompositeBinaryScanner(ci, src) + scanner.ScanDecoder(&a) + scanner.ScanDecoder(&b) + + if scanner.Err() != nil { + return scanner.Err() + } + + dst.A = a.Int + if b.Status == pgtype.Present { + dst.B = &b.String + } else { + dst.B = nil + } + + return nil +} + +var x []byte + +func BenchmarkBinaryEncodingManual(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + v := MyCompositeRaw{4, ptrS("ABCDEFG")} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + buf, _ = v.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +func BenchmarkBinaryEncodingHelper(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + v := MyType{4, ptrS("ABCDEFG")} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + buf, _ = v.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +func BenchmarkBinaryEncodingComposite(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + f1 := 2 + f2 := ptrS("bar") + c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.Int4OID}, + {"b", pgtype.TextOID}, + }, ci) + require.NoError(b, err) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + c.Set([]interface{}{f1, f2}) + buf, _ = c.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +func BenchmarkBinaryEncodingJSON(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + v := MyCompositeRaw{4, ptrS("ABCDEFG")} + j := pgtype.JSON{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + j.Set(v) + buf, _ = j.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +var dstRaw MyCompositeRaw + +func BenchmarkBinaryDecodingManual(b *testing.B) { + ci := pgtype.NewConnInfo() + buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) + dst := MyCompositeRaw{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := dst.DecodeBinary(ci, buf) + E(err) + } + dstRaw = dst +} + +var dstMyType MyType + +func BenchmarkBinaryDecodingHelpers(b *testing.B) { + ci := pgtype.NewConnInfo() + buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) + dst := MyType{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := dst.DecodeBinary(ci, buf) + E(err) + } + dstMyType = dst +} + +var gf1 int +var gf2 *string + +func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { + ci := pgtype.NewConnInfo() + buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) + var f1 int + var f2 *string + + c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.Int4OID}, + {"b", pgtype.TextOID}, + }, ci) + require.NoError(b, err) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := c.DecodeBinary(ci, buf) + if err != nil { + b.Fatal(err) + } + err = c.AssignTo([]interface{}{&f1, &f2}) + if err != nil { + b.Fatal(err) + } + } + gf1 = f1 + gf2 = f2 +} + +func BenchmarkBinaryDecodingJSON(b *testing.B) { + ci := pgtype.NewConnInfo() + j := pgtype.JSON{} + j.Set(MyCompositeRaw{4, ptrS("ABCDEFG")}) + buf, _ := j.EncodeBinary(ci, nil) + + j = pgtype.JSON{} + dst := MyCompositeRaw{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := j.DecodeBinary(ci, buf) + E(err) + err = j.AssignTo(&dst) + E(err) + } + dstRaw = dst +} diff --git a/composite_fields_test.go b/composite_fields_test.go new file mode 100644 index 00000000..dc4d4c29 --- /dev/null +++ b/composite_fields_test.go @@ -0,0 +1,273 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCompositeFieldsDecode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + formats := []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} + + // Assorted values + { + var a int32 + var b string + var c float64 + + for _, format := range formats { + err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b, &c}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.EqualValuesf(t, 1, a, "Format: %v", format) + assert.EqualValuesf(t, "hi", b, "Format: %v", format) + assert.EqualValuesf(t, 2.1, c, "Format: %v", format) + } + } + + // nulls, string "null", and empty string fields + { + var a pgtype.Text + var b string + var c pgtype.Text + var d string + var e pgtype.Text + + for _, format := range formats { + err := conn.QueryRow(context.Background(), "select row(null,'null',null,'',null)", pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.Nilf(t, a.Get(), "Format: %v", format) + assert.EqualValuesf(t, "null", b, "Format: %v", format) + assert.Nilf(t, c.Get(), "Format: %v", format) + assert.EqualValuesf(t, "", d, "Format: %v", format) + assert.Nilf(t, e.Get(), "Format: %v", format) + } + } + + // null record + { + var a pgtype.Text + var b string + cf := pgtype.CompositeFields{&a, &b} + + for _, format := range formats { + // Cannot scan nil into + err := conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( + cf, + ) + if assert.Errorf(t, err, "Format: %v", format) { + continue + } + assert.NotNilf(t, cf, "Format: %v", format) + + // But can scan nil into *pgtype.CompositeFields + err = conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( + &cf, + ) + if assert.Errorf(t, err, "Format: %v", format) { + continue + } + assert.Nilf(t, cf, "Format: %v", format) + } + } + + // quotes and special characters + { + var a, b, c, d string + + for _, format := range formats { + err := conn.QueryRow(context.Background(), `select row('"', 'foo bar', 'foo''bar', 'baz)bar')`, pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b, &c, &d}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.Equalf(t, `"`, a, "Format: %v", format) + assert.Equalf(t, `foo bar`, b, "Format: %v", format) + assert.Equalf(t, `foo'bar`, c, "Format: %v", format) + assert.Equalf(t, `baz)bar`, d, "Format: %v", format) + } + } + + // arrays + { + var a []string + var b []int64 + + for _, format := range formats { + err := conn.QueryRow(context.Background(), `select row(array['foo', 'bar', 'baz'], array[1,2,3])`, pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.EqualValuesf(t, []string{"foo", "bar", "baz"}, a, "Format: %v", format) + assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format) + } + } + + // Skip nil fields + { + var a int32 + var c float64 + + for _, format := range formats { + err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, nil, &c}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.EqualValuesf(t, 1, a, "Format: %v", format) + assert.EqualValuesf(t, 2.1, c, "Format: %v", format) + } + } +} + +func TestCompositeFieldsEncode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists cf_encode; + +create type cf_encode as ( + a text, + b int4, + c text, + d float8, + e text +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type cf_encode") + + // Use simple protocol to force text or binary encoding + simpleProtocols := []bool{true, false} + + // Assorted values + { + var a string + var b int32 + var c string + var d float64 + var e string + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{"hi", int32(1), "ok", float64(2.1), "bye"}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "ok", c, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 2.1, d, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "bye", e, "Simple Protocol: %v", simpleProtocol) + } + } + } + + // untyped nil + { + var a pgtype.Text + var b int32 + var c string + var d pgtype.Float8 + var e pgtype.Text + + simpleProtocol := true + err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) + } + + // untyped nil cannot be represented in binary format because CompositeFields does not know the PostgreSQL schema + // of the composite type. + simpleProtocol = false + err = conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + assert.Errorf(t, err, "Simple Protocol: %v", simpleProtocol) + } + + // nulls, string "null", and empty string fields + { + var a pgtype.Text + var b int32 + var c string + var d pgtype.Float8 + var e pgtype.Text + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{&pgtype.Text{Status: pgtype.Null}, int32(1), "null", &pgtype.Float8{Status: pgtype.Null}, &pgtype.Text{Status: pgtype.Null}}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) + } + } + } + + // quotes and special characters + { + var a string + var b int32 + var c string + var d float64 + var e string + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow( + context.Background(), + `select $1::cf_encode`, + pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{`"`, int32(42), `foo'bar`, float64(1.2), `baz)bar`}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.Equalf(t, `"`, a, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, int32(42), b, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, `foo'bar`, c, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, float64(1.2), d, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, `baz)bar`, e, "Simple Protocol: %v", simpleProtocol) + } + } + } +} diff --git a/composite_type_test.go b/composite_type_test.go new file mode 100644 index 00000000..2349a67d --- /dev/null +++ b/composite_type_test.go @@ -0,0 +1,320 @@ +package pgtype_test + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + pgx "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCompositeTypeSetAndGet(t *testing.T) { + ci := pgtype.NewConnInfo() + ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.TextOID}, + {"b", pgtype.Int4OID}, + }, ci) + require.NoError(t, err) + assert.Equal(t, pgtype.Undefined, ct.Get()) + + nilTests := []struct { + src interface{} + }{ + {nil}, // nil interface + {(*[]interface{})(nil)}, // typed nil + } + + for i, tt := range nilTests { + err := ct.Set(tt.src) + assert.NoErrorf(t, err, "%d", i) + assert.Equal(t, nil, ct.Get()) + } + + compatibleValuesTests := []struct { + src []interface{} + expected map[string]interface{} + }{ + { + src: []interface{}{"foo", int32(42)}, + expected: map[string]interface{}{"a": "foo", "b": int32(42)}, + }, + { + src: []interface{}{nil, nil}, + expected: map[string]interface{}{"a": nil, "b": nil}, + }, + { + src: []interface{}{&pgtype.Text{String: "hi", Status: pgtype.Present}, &pgtype.Int4{Int: 7, Status: pgtype.Present}}, + expected: map[string]interface{}{"a": "hi", "b": int32(7)}, + }, + } + + for i, tt := range compatibleValuesTests { + err := ct.Set(tt.src) + assert.NoErrorf(t, err, "%d", i) + assert.EqualValues(t, tt.expected, ct.Get()) + } +} + +func TestCompositeTypeAssignTo(t *testing.T) { + ci := pgtype.NewConnInfo() + ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.TextOID}, + {"b", pgtype.Int4OID}, + }, ci) + require.NoError(t, err) + + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var a string + var b int32 + + err = ct.AssignTo([]interface{}{&a, &b}) + assert.NoError(t, err) + + assert.Equal(t, "foo", a) + assert.Equal(t, int32(42), b) + } + + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var a pgtype.Text + var b pgtype.Int4 + + err = ct.AssignTo([]interface{}{&a, &b}) + assert.NoError(t, err) + + assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a) + assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b) + } + + // Allow nil destination component as no-op + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var b int32 + + err = ct.AssignTo([]interface{}{nil, &b}) + assert.NoError(t, err) + + assert.Equal(t, int32(42), b) + } + + // *[]interface{} dest when null + { + err := ct.Set(nil) + assert.NoError(t, err) + + var a pgtype.Text + var b pgtype.Int4 + dst := []interface{}{&a, &b} + + err = ct.AssignTo(&dst) + assert.NoError(t, err) + + assert.Nil(t, dst) + } + + // *[]interface{} dest when not null + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var a pgtype.Text + var b pgtype.Int4 + dst := []interface{}{&a, &b} + + err = ct.AssignTo(&dst) + assert.NoError(t, err) + + assert.NotNil(t, dst) + assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a) + assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b) + } + + // Struct fields positionally via reflection + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + s := struct { + A string + B int32 + }{} + + err = ct.AssignTo(&s) + if assert.NoError(t, err) { + assert.Equal(t, "foo", s.A) + assert.Equal(t, int32(42), s.B) + } + } +} + +func TestCompositeTypeTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists ct_test; + +create type ct_test as ( + a text, + b int4 +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type ct_test") + + var oid uint32 + err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid) + require.NoError(t, err) + + defer conn.Exec(context.Background(), "drop type ct_test") + + ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{ + {"a", pgtype.TextOID}, + {"b", pgtype.Int4OID}, + }, conn.ConnInfo()) + require.NoError(t, err) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) + + // Use simple protocol to force text or binary encoding + simpleProtocols := []bool{true, false} + + var a string + var b int32 + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{"hi", int32(42)}, + ).Scan( + []interface{}{&a, &b}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol) + } + } +} + +// https://github.com/jackc/pgx/issues/874 +func TestCompositeTypeTextDecodeNested(t *testing.T) { + newCompositeType := func(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { + fields := make([]pgtype.CompositeTypeField, len(fieldNames)) + for i, name := range fieldNames { + fields[i] = pgtype.CompositeTypeField{Name: name} + } + + rowType, err := pgtype.NewCompositeTypeValues(name, fields, vals) + require.NoError(t, err) + return rowType + } + + dimensionsType := func() pgtype.ValueTranscoder { + return newCompositeType( + "dimensions", + []string{"width", "height"}, + &pgtype.Int4{}, + &pgtype.Int4{}, + ) + } + productImageType := func() pgtype.ValueTranscoder { + return newCompositeType( + "product_image_type", + []string{"source", "dimensions"}, + &pgtype.Text{}, + dimensionsType(), + ) + } + productImageSetType := newCompositeType( + "product_image_set_type", + []string{"name", "orig_image", "images"}, + &pgtype.Text{}, + productImageType(), + pgtype.NewArrayType("product_image", 0, func() pgtype.ValueTranscoder { + return productImageType() + }), + ) + + err := productImageSetType.DecodeText(nil, []byte(`(name,"(img1,""(11,11)"")","{""(img2,\\""(22,22)\\"")"",""(img3,\\""(33,33)\\"")""}")`)) + require.NoError(t, err) +} + +func Example_composite() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Println(err) + return + } + + defer conn.Close(context.Background()) + _, err = conn.Exec(context.Background(), `drop type if exists mytype;`) + if err != nil { + fmt.Println(err) + return + } + + _, err = conn.Exec(context.Background(), `create type mytype as ( + a int4, + b text +);`) + if err != nil { + fmt.Println(err) + return + } + defer conn.Exec(context.Background(), "drop type mytype") + + var oid uint32 + err = conn.QueryRow(context.Background(), `select 'mytype'::regtype::oid`).Scan(&oid) + if err != nil { + fmt.Println(err) + return + } + + ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{ + {"a", pgtype.Int4OID}, + {"b", pgtype.TextOID}, + }, conn.ConnInfo()) + if err != nil { + fmt.Println(err) + return + } + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) + + var a int + var b *string + + err = conn.QueryRow(context.Background(), "select $1::mytype", []interface{}{2, "bar"}).Scan([]interface{}{&a, &b}) + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("First: a=%d b=%s\n", a, *b) + + err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan([]interface{}{&a, &b}) + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("Second: a=%d b=%v\n", a, b) + + scanTarget := []interface{}{&a, &b} + err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(&scanTarget) + E(err) + + fmt.Printf("Third: isNull=%v\n", scanTarget == nil) + + // Output: + // First: a=2 b=bar + // Second: a=1 b= + // Third: isNull=true +} diff --git a/custom_composite_test.go b/custom_composite_test.go new file mode 100644 index 00000000..9ca8dd5e --- /dev/null +++ b/custom_composite_test.go @@ -0,0 +1,87 @@ +package pgtype_test + +import ( + "context" + "errors" + "fmt" + "os" + + "github.com/jackc/pgtype" + pgx "github.com/jackc/pgx/v4" +) + +type MyType struct { + a int32 // NULL will cause decoding error + b *string // there can be NULL in this position in SQL +} + +func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs") + } + + if err := (pgtype.CompositeFields{&dst.a, &dst.b}).DecodeBinary(ci, src); err != nil { + return err + } + + return nil +} + +func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { + a := pgtype.Int4{src.a, pgtype.Present} + var b pgtype.Text + if src.b != nil { + b = pgtype.Text{*src.b, pgtype.Present} + } else { + b = pgtype.Text{Status: pgtype.Null} + } + + return (pgtype.CompositeFields{&a, &b}).EncodeBinary(ci, buf) +} + +func ptrS(s string) *string { + return &s +} + +func E(err error) { + if err != nil { + panic(err) + } +} + +// ExampleCustomCompositeTypes demonstrates how support for custom types mappable to SQL +// composites can be added. +func Example_customCompositeTypes() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + E(err) + + defer conn.Close(context.Background()) + _, err = conn.Exec(context.Background(), `drop type if exists mytype; + +create type mytype as ( + a int4, + b text +);`) + E(err) + defer conn.Exec(context.Background(), "drop type mytype") + + var result *MyType + + // Demonstrates both passing and reading back composite values + err = conn.QueryRow(context.Background(), "select $1::mytype", + pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}). + Scan(&result) + E(err) + + fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b) + + // Because we scan into &*MyType, NULLs are handled generically by assigning nil to result + err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result) + E(err) + + fmt.Printf("Second row: %v\n", result) + + // Output: + // First row: a=1 b=foo + // Second row: +} diff --git a/date_array_test.go b/date_array_test.go new file mode 100644 index 00000000..4458abfe --- /dev/null +++ b/date_array_test.go @@ -0,0 +1,327 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestDateArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "date[]", []interface{}{ + &pgtype.DateArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.DateArray{Status: pgtype.Null}, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Status: pgtype.Null}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestDateArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.DateArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.DateArray{Status: pgtype.Null}, + }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.DateArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestDateArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time + + simpleTests := []struct { + src pgtype.DateArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.DateArray{Status: pgtype.Null}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + { + src: pgtype.DateArray{Status: pgtype.Present}, + dst: &timeSlice, + expected: []time.Time{}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.DateArray + dst interface{} + }{ + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeSlice, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/date_test.go b/date_test.go new file mode 100644 index 00000000..5c38e7a3 --- /dev/null +++ b/date_test.go @@ -0,0 +1,168 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestDateTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "date", []interface{}{ + &pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Status: pgtype.Null}, + &pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + &pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Date) + bt := b.(pgtype.Date) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + }) +} + +func TestDateSet(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Date + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: "1999-12-31", result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.Date + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestDateAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Date + dst interface{} + expected interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Date{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Date + dst interface{} + expected interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Date + dst interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} + +func TestDateMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Date + result string + }{ + {source: pgtype.Date{Status: pgtype.Null}, result: "null"}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29\""}, + {source: pgtype.Date{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, result: "\"infinity\""}, + {source: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, result: "\"-infinity\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestDateUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Date + }{ + {source: "null", result: pgtype.Date{Status: pgtype.Null}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, + {source: "\"infinity\"", result: pgtype.Date{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, + {source: "\"-infinity\"", result: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Date + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r.Time.Year() != tt.result.Time.Year() || r.Time.Month() != tt.result.Time.Month() || r.Time.Day() != tt.result.Time.Day() || r.Status != tt.result.Status || r.InfinityModifier != tt.result.InfinityModifier { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/daterange_test.go b/daterange_test.go new file mode 100644 index 00000000..54d51e2d --- /dev/null +++ b/daterange_test.go @@ -0,0 +1,133 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestDaterangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ + &pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Daterange{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Daterange) + b := bb.(pgtype.Daterange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} + +func TestDaterangeNormalize(t *testing.T) { + testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ + { + SQL: "select daterange('2010-01-01', '2010-01-11', '(]')", + Value: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(2010, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2010, 1, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Daterange) + b := bb.(pgtype.Daterange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} + +func TestDaterangeSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Daterange + }{ + { + source: nil, + result: pgtype.Daterange{Status: pgtype.Null}, + }, + { + source: &pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + result: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + { + source: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + result: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + { + source: "[1990-12-31,2028-01-01)", + result: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Daterange + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/enum_array_test.go b/enum_array_test.go new file mode 100644 index 00000000..659340f0 --- /dev/null +++ b/enum_array_test.go @@ -0,0 +1,281 @@ +package pgtype_test + +import ( + "context" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestEnumArrayTranscode(t *testing.T) { + setupConn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, setupConn) + + if _, err := setupConn.Exec(context.Background(), "drop type if exists color"); err != nil { + t.Fatal(err) + } + if _, err := setupConn.Exec(context.Background(), "create type color as enum ('red', 'green', 'blue')"); err != nil { + t.Fatal(err) + } + + testutil.TestSuccessfulTranscode(t, "color[]", []interface{}{ + &pgtype.EnumArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "red", Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.EnumArray{Status: pgtype.Null}, + &pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "red", Status: pgtype.Present}, + {String: "green", Status: pgtype.Present}, + {String: "blue", Status: pgtype.Present}, + {String: "red", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestEnumArrayArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.EnumArray + }{ + { + source: []string{"foo"}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.EnumArray{Status: pgtype.Null}, + }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.EnumArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestEnumArrayArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string + + simpleTests := []struct { + src pgtype.EnumArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.EnumArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + { + src: pgtype.EnumArray{Status: pgtype.Present}, + dst: &stringSlice, + expected: []string{}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.EnumArray + dst interface{} + }{ + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/enum_type_test.go b/enum_type_test.go new file mode 100644 index 00000000..4dd88f2a --- /dev/null +++ b/enum_type_test.go @@ -0,0 +1,148 @@ +package pgtype_test + +import ( + "bytes" + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupEnum(t *testing.T, conn *pgx.Conn) *pgtype.EnumType { + _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") + require.NoError(t, err) + + _, err = conn.Exec(context.Background(), "create type pgtype_enum_color as enum ('blue', 'green', 'purple');") + require.NoError(t, err) + + var oid uint32 + err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "pgtype_enum_color").Scan(&oid) + require.NoError(t, err) + + et := pgtype.NewEnumType("pgtype_enum_color", []string{"blue", "green", "purple"}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "pgtype_enum_color", OID: oid}) + + return et +} + +func cleanupEnum(t *testing.T, conn *pgx.Conn) { + _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") + require.NoError(t, err) +} + +func TestEnumTypeTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + setupEnum(t, conn) + defer cleanupEnum(t, conn) + + var dst string + err := conn.QueryRow(context.Background(), "select $1::pgtype_enum_color", "blue").Scan(&dst) + require.NoError(t, err) + require.EqualValues(t, "blue", dst) +} + +func TestEnumTypeSet(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + enumType := setupEnum(t, conn) + defer cleanupEnum(t, conn) + + successfulTests := []struct { + source interface{} + result interface{} + }{ + {source: "blue", result: "blue"}, + {source: _string("green"), result: "green"}, + {source: (*string)(nil), result: nil}, + } + + for i, tt := range successfulTests { + err := enumType.Set(tt.source) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.result, enumType.Get(), "%d", i) + } +} + +func TestEnumTypeAssignTo(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + enumType := setupEnum(t, conn) + defer cleanupEnum(t, conn) + + { + var s string + + err := enumType.Set("blue") + require.NoError(t, err) + + err = enumType.AssignTo(&s) + require.NoError(t, err) + + assert.EqualValues(t, "blue", s) + } + + { + var ps *string + + err := enumType.Set("blue") + require.NoError(t, err) + + err = enumType.AssignTo(&ps) + require.NoError(t, err) + + assert.EqualValues(t, "blue", *ps) + } + + { + var ps *string + + err := enumType.Set(nil) + require.NoError(t, err) + + err = enumType.AssignTo(&ps) + require.NoError(t, err) + + assert.EqualValues(t, (*string)(nil), ps) + } + + var buf []byte + bytesTests := []struct { + src interface{} + dst *[]byte + expected []byte + }{ + {src: "blue", dst: &buf, expected: []byte("blue")}, + {src: nil, dst: &buf, expected: nil}, + } + + for i, tt := range bytesTests { + err := enumType.Set(tt.src) + require.NoError(t, err, "%d", i) + + err = enumType.AssignTo(tt.dst) + require.NoError(t, err, "%d", i) + + if bytes.Compare(*tt.dst, tt.expected) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) + } + } + + { + var s string + + err := enumType.Set(nil) + require.NoError(t, err) + + err = enumType.AssignTo(&s) + require.Error(t, err) + } + +} diff --git a/ext/gofrs-uuid/uuid_test.go b/ext/gofrs-uuid/uuid_test.go new file mode 100644 index 00000000..56814524 --- /dev/null +++ b/ext/gofrs-uuid/uuid_test.go @@ -0,0 +1,101 @@ +package uuid_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgtype" + gofrs "github.com/jackc/pgtype/ext/gofrs-uuid" + "github.com/jackc/pgtype/testutil" +) + +func TestUUIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &gofrs.UUID{Status: pgtype.Null}, + }) +} + +func TestUUIDSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result gofrs.UUID + }{ + { + source: &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: "00010203-0405-0607-0809-0a0b0c0d0e0f", + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r gofrs.UUID + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUUIDAssignTo(t *testing.T) { + { + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst [16]byte + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst []byte + expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare(dst, expected) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst string + expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + +} diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go new file mode 100644 index 00000000..bf34e0dd --- /dev/null +++ b/ext/shopspring-numeric/decimal_test.go @@ -0,0 +1,330 @@ +package numeric_test + +import ( + "fmt" + "math/big" + "math/rand" + "reflect" + "testing" + + "github.com/jackc/pgtype" + shopspring "github.com/jackc/pgtype/ext/shopspring-numeric" + "github.com/jackc/pgtype/testutil" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" +) + +func mustParseDecimal(t *testing.T, src string) decimal.Decimal { + dec, err := decimal.NewFromString(src) + if err != nil { + t.Fatal(err) + } + return dec +} + +func TestNumericNormalize(t *testing.T) { + testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ + { + SQL: "select '0'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, + }, + { + SQL: "select '1'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, + }, + { + SQL: "select '10.00'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present}, + }, + { + SQL: "select '1e-3'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, + }, + { + SQL: "select '-1'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, + }, + { + SQL: "select '10000'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present}, + }, + { + SQL: "select '3.14'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, + }, + { + SQL: "select '1.1'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present}, + }, + { + SQL: "select '100010001'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present}, + }, + { + SQL: "select '100010001.0001'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present}, + }, + { + SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", + Value: &shopspring.Numeric{ + Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", + Value: &shopspring.Numeric{ + Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", + Value: &shopspring.Numeric{ + Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), + Status: pgtype.Present, + }, + }, + }, func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + }) +} + +func TestNumericTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Status: pgtype.Present}, + + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Status: pgtype.Present}, + + &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Status: pgtype.Present}, + &shopspring.Numeric{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + }) + +} + +func TestNumericTranscodeFuzz(t *testing.T) { + r := rand.New(rand.NewSource(0)) + max := &big.Int{} + max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) + + values := make([]interface{}, 0, 2000) + for i := 0; i < 500; i++ { + num := fmt.Sprintf("%s.%s", (&big.Int{}).Rand(r, max).String(), (&big.Int{}).Rand(r, max).String()) + negNum := "-" + num + values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, num), Status: pgtype.Present}) + values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Status: pgtype.Present}) + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, + func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + }) +} + +func TestNumericSet(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result *shopspring.Numeric + }{ + {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Status: pgtype.Present}}, + {source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Status: pgtype.Present}}, + {source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Status: pgtype.Present}}, + {source: float64(1.25), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.25"), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + r := &shopspring.Numeric{} + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !(r.Status == tt.result.Status && r.Decimal.Equal(tt.result.Decimal)) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericAssignTo(t *testing.T) { + type _int8 int8 + + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src *shopspring.Numeric + dst interface{} + expected interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src *shopspring.Numeric + dst interface{} + expected interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src *shopspring.Numeric + dst interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Status: pgtype.Present}, dst: &i8}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Status: pgtype.Present}, dst: &i16}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui8}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui16}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui32}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui64}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} + +func BenchmarkDecode(b *testing.B) { + benchmarks := []struct { + name string + numberStr string + }{ + {"Zero", "0"}, + {"Small", "12345"}, + {"Medium", "12345.12345"}, + {"Large", "123457890.1234567890"}, + {"Huge", "123457890123457890123457890.1234567890123457890123457890"}, + } + + for _, bm := range benchmarks { + src := &shopspring.Numeric{} + err := src.Set(bm.numberStr) + require.NoError(b, err) + textFormat, err := src.EncodeText(nil, nil) + require.NoError(b, err) + binaryFormat, err := src.EncodeBinary(nil, nil) + require.NoError(b, err) + + b.Run(fmt.Sprintf("%s-Text", bm.name), func(b *testing.B) { + dst := &shopspring.Numeric{} + for i := 0; i < b.N; i++ { + err := dst.DecodeText(nil, textFormat) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run(fmt.Sprintf("%s-Binary", bm.name), func(b *testing.B) { + dst := &shopspring.Numeric{} + for i := 0; i < b.N; i++ { + err := dst.DecodeBinary(nil, binaryFormat) + if err != nil { + b.Fatal(err) + } + } + }) + } +} diff --git a/float4_array_test.go b/float4_array_test.go new file mode 100644 index 00000000..db438999 --- /dev/null +++ b/float4_array_test.go @@ -0,0 +1,282 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestFloat4ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float4[]", []interface{}{ + &pgtype.Float4Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float4Array{Status: pgtype.Null}, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Float: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestFloat4ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float4Array + }{ + { + source: []float32{1}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]float32)(nil)), + result: pgtype.Float4Array{Status: pgtype.Null}, + }, + { + source: [][]float32{{1}, {2}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]float32{{1}, {2}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Float4Array + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat4ArrayAssignTo(t *testing.T) { + var float32Slice []float32 + var namedFloat32Slice _float32Slice + var float32SliceDim2 [][]float32 + var float32SliceDim4 [][][][]float32 + var float32ArrayDim2 [2][1]float32 + var float32ArrayDim4 [2][1][1][3]float32 + + simpleTests := []struct { + src pgtype.Float4Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + expected: []float32{1.23}, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedFloat32Slice, + expected: _float32Slice{1.23}, + }, + { + src: pgtype.Float4Array{Status: pgtype.Null}, + dst: &float32Slice, + expected: (([]float32)(nil)), + }, + { + src: pgtype.Float4Array{Status: pgtype.Present}, + dst: &float32Slice, + expected: []float32{}, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]float32{{1}, {2}}, + dst: &float32SliceDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float32SliceDim4, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]float32{{1}, {2}}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float32ArrayDim4, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float4Array + dst interface{} + }{ + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32Slice, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32ArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/float4_test.go b/float4_test.go new file mode 100644 index 00000000..d2524cda --- /dev/null +++ b/float4_test.go @@ -0,0 +1,149 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestFloat4Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float4", []interface{}{ + &pgtype.Float4{Float: -1, Status: pgtype.Present}, + &pgtype.Float4{Float: 0, Status: pgtype.Present}, + &pgtype.Float4{Float: 0.00001, Status: pgtype.Present}, + &pgtype.Float4{Float: 1, Status: pgtype.Present}, + &pgtype.Float4{Float: 9999.99, Status: pgtype.Present}, + &pgtype.Float4{Float: 0, Status: pgtype.Null}, + }) +} + +func TestFloat4Set(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float4 + }{ + {source: float32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Float4 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat4AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src pgtype.Float4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Float4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float4 + dst interface{} + }{ + {src: pgtype.Float4{Float: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Float4{Float: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/float8_array_test.go b/float8_array_test.go new file mode 100644 index 00000000..85cb8f43 --- /dev/null +++ b/float8_array_test.go @@ -0,0 +1,258 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestFloat8ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float8[]", []interface{}{ + &pgtype.Float8Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float8Array{Status: pgtype.Null}, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Float: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestFloat8ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float8Array + }{ + { + source: []float64{1}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]float64)(nil)), + result: pgtype.Float8Array{Status: pgtype.Null}, + }, + { + source: [][]float64{{1}, {2}}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Float8Array + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat8ArrayAssignTo(t *testing.T) { + var float64Slice []float64 + var namedFloat64Slice _float64Slice + var float64SliceDim2 [][]float64 + var float64SliceDim4 [][][][]float64 + var float64ArrayDim2 [2][1]float64 + var float64ArrayDim4 [2][1][1][3]float64 + + simpleTests := []struct { + src pgtype.Float8Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float64Slice, + expected: []float64{1.23}, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedFloat64Slice, + expected: _float64Slice{1.23}, + }, + { + src: pgtype.Float8Array{Status: pgtype.Null}, + dst: &float64Slice, + expected: (([]float64)(nil)), + }, + { + src: pgtype.Float8Array{Status: pgtype.Present}, + dst: &float64Slice, + expected: []float64{}, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]float64{{1}, {2}}, + dst: &float64SliceDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float64SliceDim4, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]float64{{1}, {2}}, + dst: &float64ArrayDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float64ArrayDim4, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float8Array + dst interface{} + }{ + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float64Slice, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float64ArrayDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float64Slice, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float64ArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/float8_test.go b/float8_test.go new file mode 100644 index 00000000..6bc7c652 --- /dev/null +++ b/float8_test.go @@ -0,0 +1,149 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestFloat8Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ + &pgtype.Float8{Float: -1, Status: pgtype.Present}, + &pgtype.Float8{Float: 0, Status: pgtype.Present}, + &pgtype.Float8{Float: 0.00001, Status: pgtype.Present}, + &pgtype.Float8{Float: 1, Status: pgtype.Present}, + &pgtype.Float8{Float: 9999.99, Status: pgtype.Present}, + &pgtype.Float8{Float: 0, Status: pgtype.Null}, + }) +} + +func TestFloat8Set(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float8 + }{ + {source: float32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Float8 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat8AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src pgtype.Float8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Float8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float8 + dst interface{} + }{ + {src: pgtype.Float8{Float: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Float8{Float: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/go.mod b/go.mod index 42ee3838..29e6f628 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,10 @@ go 1.13 require ( github.com/gofrs/uuid v4.0.0+incompatible + github.com/jackc/pgconn v1.9.0 github.com/jackc/pgio v1.0.0 + github.com/jackc/pgx/v4 v4.12.0 + github.com/lib/pq v1.10.2 github.com/shopspring/decimal v1.2.0 + github.com/stretchr/testify v1.7.0 ) diff --git a/go.sum b/go.sum index da822c7d..e49ce26f 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,486 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= +github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= +github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= +github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= +github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= +github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= +github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= +github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= +github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= +github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= +github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= +github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4dUb/I5gc9Hdhagfvm9+RyrPryS/auMzxE= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= +github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= +github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= +github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= +github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= +github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= +github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= +github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4safvEdbitLhGGK48rN6g= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= +github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.10.0/go.mod h1:xUsJbQ/Fp4kEt7AFgCuvyX4a71u8h9jB8tj/ORgOZ7o= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= +github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= +github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= +github.com/hashicorp/consul/api v1.3.0/go.mod h1:MmDNSzIMUjNpY/mQ398R4bk2FnqQLoPndWW5VkKPlCE= +github.com/hashicorp/consul/sdk v0.3.0/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= +github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= +github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= +github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= +github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= +github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= +github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= +github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= +github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= +github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.8.1/go.mod h1:JV6m6b6jhjdmzchES0drzCcYcAHS1OPD5xu3OZ/lE2g= +github.com/jackc/pgconn v1.9.0 h1:gqibKSTJup/ahCsNKyMZAniPuZEfIqfXFc8FOWVYR+Q= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd h1:eDErF6V/JPJON/B7s68BxwHgfmyOntHJQ8IOaz0x4R8= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= +github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= +github.com/jackc/pgtype v1.3.1-0.20200606141011-f6355165a91c/go.mod h1:cvk9Bgu/VzJ9/lxTO5R5sf80p0DiucVtN7ZxvaC4GmQ= +github.com/jackc/pgtype v1.7.0/go.mod h1:ZnHF+rMePVqDKaOfJVI4Q8IVvAQMryDlDkZnKOI75BE= +github.com/jackc/pgtype v1.8.0/go.mod h1:PqDKcEBtllAtk/2p6z6SHdXW5UB+MhE75tUol2OKexE= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/pgx/v4 v4.5.0/go.mod h1:EpAKPLdnTorwmPUUsqrPxy5fphV18j9q3wrfRXgo+kA= +github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9/go.mod h1:t3/cdRQl6fOLDxqtlyhe9UWgfIi9R8+8v8GKV5TRA/o= +github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904/go.mod h1:ZDaNWkt9sW1JMiNn0kdYBaLelIhw7Pg4qd+Vk6tw7Hg= +github.com/jackc/pgx/v4 v4.11.0/go.mod h1:i62xJgdrtVDsnL3U8ekyrQXEwGNTRoG7/8r+CIdYfcc= +github.com/jackc/pgx/v4 v4.12.0 h1:xiP3TdnkwyslWNp77yE5XAPfxAsU9RMFDe0c1SwN8h4= +github.com/jackc/pgx/v4 v4.12.0/go.mod h1:fE547h6VulLPA3kySjfnSG/e2D861g/50JlVUa/ub60= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= +github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= +github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= +github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= +github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= +github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= +github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= +github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= +github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg= +github.com/nats-io/jwt v0.3.2/go.mod h1:/euKqTS1ZD+zzjYrY7pseZrTtWQSjujC7xjPc8wL6eU= +github.com/nats-io/nats-server/v2 v2.1.2/go.mod h1:Afk+wRZqkMQs/p45uXdrVLuab3gwv3Z8C4HTBu8GD/k= +github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzEE/Zbp4w= +github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= +github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs= +github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= +github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/op/go-logging v0.0.0-20160315200505-970db520ece7/go.mod h1:HzydrMdWErDVzsI23lYNej1Htcns9BCg93Dk0bBINWk= +github.com/opentracing-contrib/go-observer v0.0.0-20170622124052-a52f23424492/go.mod h1:Ngi6UdF0k5OKD5t5wlmGhe/EDKPoUM3BXZSSfIuJbis= +github.com/opentracing/basictracer-go v1.0.0/go.mod h1:QfBfYuafItcjQuMwinw9GhYKwFXS9KnPs5lxoYwgW74= +github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/openzipkin-contrib/zipkin-go-opentracing v0.4.5/go.mod h1:/wsWhb9smxSfWAKL3wpBW7V8scJMt8N8gnaMCS9E/cA= +github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= +github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= +github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= +github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= +github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= +github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= +github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac= +github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= +github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs= +github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.3.0/go.mod h1:hJaj2vgQTGQmVCsAACORcieXFeDPbaTKGT+JTgUa3og= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.1.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= +github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= +github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= +github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= +github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= +github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= +github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= +go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= +go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= +go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= +go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190530194941-fb225487d101/go.mod h1:z3L6/3dTEVtUr6QSP8miRzeRqwQOioJ9I66odjN4I7s= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM= +google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= +google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= +google.golang.org/grpc v1.22.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/gcfg.v1 v1.2.3/go.mod h1:yesOnuUOFQAhST5vPY4nbZsb/huCgGGXlipJsBn0b3o= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= +gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= +sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU= diff --git a/hstore_array_test.go b/hstore_array_test.go new file mode 100644 index 00000000..672eca4a --- /dev/null +++ b/hstore_array_test.go @@ -0,0 +1,436 @@ +package pgtype_test + +import ( + "context" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" +) + +func TestHstoreArrayTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + var hstoreOID uint32 + err := conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='hstore';").Scan(&hstoreOID) + if err != nil { + t.Fatalf("did not find hstore OID, %v", err) + } + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Hstore{}, Name: "hstore", OID: hstoreOID}) + + var hstoreArrayOID uint32 + err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='_hstore';").Scan(&hstoreArrayOID) + if err != nil { + t.Fatalf("did not find _hstore OID, %v", err) + } + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.HstoreArray{}, Name: "_hstore", OID: hstoreArrayOID}) + + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Status: pgtype.Present} + } + + values := []pgtype.Hstore{ + {Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + {Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + {Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + {Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + {Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + {Status: pgtype.Null}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + + // Special value values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + } + + src := &pgtype.HstoreArray{ + Elements: values, + Dimensions: []pgtype.ArrayDimension{{Length: int32(len(values)), LowerBound: 1}}, + Status: pgtype.Present, + } + + _, err = conn.Prepare(context.Background(), "test", "select $1::hstore[]") + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, fc := range formats { + queryResultFormats := pgx.QueryResultFormats{fc.formatCode} + vEncoder := testutil.ForceEncoder(src, fc.formatCode) + if vEncoder == nil { + t.Logf("%#v does not implement %v", src, fc.name) + continue + } + + var result pgtype.HstoreArray + err := conn.QueryRow(context.Background(), "test", queryResultFormats, vEncoder).Scan(&result) + if err != nil { + t.Errorf("%v: %v", fc.name, err) + continue + } + + if result.Status != src.Status { + t.Errorf("%v: expected Status %v, got %v", fc.formatCode, src.Status, result.Status) + continue + } + + if len(result.Elements) != len(src.Elements) { + t.Errorf("%v: expected %v elements, got %v", fc.formatCode, len(src.Elements), len(result.Elements)) + continue + } + + for i := range result.Elements { + a := src.Elements[i] + b := result.Elements[i] + + if a.Status != b.Status { + t.Errorf("%v element idx %d: expected status %v, got %v", fc.formatCode, i, a.Status, b.Status) + } + + if len(a.Map) != len(b.Map) { + t.Errorf("%v element idx %d: expected %v pairs, got %v", fc.formatCode, i, len(a.Map), len(b.Map)) + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + t.Errorf("%v element idx %d: expected key %v to be %v, got %v", fc.formatCode, i, k, a.Map[k], b.Map[k]) + } + } + } + } +} + +func TestHstoreArraySet(t *testing.T) { + successfulTests := []struct { + src interface{} + result pgtype.HstoreArray + }{ + { + src: []map[string]string{{"foo": "bar"}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + }, + { + src: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + }, + { + src: [][][][]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + }, + { + src: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + }, + { + src: [2][1][1][3]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + }, + } + + for i, tt := range successfulTests { + var dst pgtype.HstoreArray + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + +func TestHstoreArrayAssignTo(t *testing.T) { + var hstoreSlice []map[string]string + var hstoreSliceDim2 [][]map[string]string + var hstoreSliceDim4 [][][][]map[string]string + var hstoreArrayDim2 [2][1]map[string]string + var hstoreArrayDim4 [2][1][1][3]map[string]string + + simpleTests := []struct { + src pgtype.HstoreArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &hstoreSlice, + expected: []map[string]string{{"foo": "bar"}}}, + { + src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), + }, + { + src: pgtype.HstoreArray{Status: pgtype.Present}, dst: &hstoreSlice, expected: []map[string]string{}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &hstoreSliceDim2, + expected: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + dst: &hstoreSliceDim4, + expected: [][][][]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &hstoreArrayDim2, + expected: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + dst: &hstoreArrayDim4, + expected: [2][1][1][3]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/hstore_test.go b/hstore_test.go new file mode 100644 index 00000000..dce8baf2 --- /dev/null +++ b/hstore_test.go @@ -0,0 +1,111 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestHstoreTranscode(t *testing.T) { + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Status: pgtype.Present} + } + + values := []interface{}{ + &pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(""), "bar": text(""), "baz": text("123")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Status: pgtype.Null}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + + // Special value values + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { + a := ai.(pgtype.Hstore) + b := bi.(pgtype.Hstore) + + if len(a.Map) != len(b.Map) || a.Status != b.Status { + return false + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + return false + } + } + + return true + }) +} + +func TestHstoreSet(t *testing.T) { + successfulTests := []struct { + src map[string]string + result pgtype.Hstore + }{ + {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var dst pgtype.Hstore + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + +func TestHstoreAssignTo(t *testing.T) { + var m map[string]string + + simpleTests := []struct { + src pgtype.Hstore + dst *map[string]string + expected map[string]string + }{ + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}, dst: &m, expected: map[string]string{"foo": "bar"}}, + {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(*tt.dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/inet_array_test.go b/inet_array_test.go new file mode 100644 index 00000000..46dc7d12 --- /dev/null +++ b/inet_array_test.go @@ -0,0 +1,319 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInetArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "inet[]", []interface{}{ + &pgtype.InetArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.InetArray{Status: pgtype.Null}, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + {Status: pgtype.Null}, + {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInetArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.InetArray + }{ + { + source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.InetArray{Status: pgtype.Null}, + }, + { + source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.InetArray{Status: pgtype.Null}, + }, + { + source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.InetArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInetArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + var ipSliceDim2 [][]net.IP + var ipnetSliceDim4 [][][][]*net.IPNet + var ipArrayDim2 [2][1]net.IP + var ipnetArrayDim4 [2][1][1][3]*net.IPNet + + simpleTests := []struct { + src pgtype.InetArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.InetArray{Status: pgtype.Null}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.InetArray{Status: pgtype.Present}, + dst: &ipnetSlice, + expected: []*net.IPNet{}, + }, + { + src: pgtype.InetArray{Status: pgtype.Null}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + { + src: pgtype.InetArray{Status: pgtype.Present}, + dst: &ipSlice, + expected: []net.IP{}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipSliceDim2, + expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetSliceDim4, + expected: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipArrayDim2, + expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetArrayDim4, + expected: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/inet_test.go b/inet_test.go new file mode 100644 index 00000000..66fe777f --- /dev/null +++ b/inet_test.go @@ -0,0 +1,134 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/assert" +) + +func TestInetTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "inet", []interface{}{ + &pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "127.0.0.1/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "12.34.56.65/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "192.168.1.16/24"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "255.0.0.0/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "255.255.255.255/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "::1/64"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "::/0"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "::1/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), Status: pgtype.Present}, + &pgtype.Inet{Status: pgtype.Null}, + }) +} + +func TestCidrTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "cidr", []interface{}{ + &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + &pgtype.Inet{Status: pgtype.Null}, + }) +} + +func TestInetSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Inet + }{ + {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4"), Mask: net.CIDRMask(24, 32)}, Status: pgtype.Present}}, + {source: net.ParseIP(""), result: pgtype.Inet{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var r pgtype.Inet + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + assert.Equalf(t, tt.result.Status, r.Status, "%d: Status", i) + if tt.result.Status == pgtype.Present { + assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: IP", i) + assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: Mask", i) + } + } +} + +func TestInetAssignTo(t *testing.T) { + var ipnet net.IPNet + var pipnet *net.IPNet + var ip net.IP + var pip *net.IP + + simpleTests := []struct { + src pgtype.Inet + dst interface{} + expected interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &pip, expected: ((*net.IP)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %#v, but result was %#v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Inet + dst interface{} + expected interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Inet + dst interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &ipnet}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/int2_array_test.go b/int2_array_test.go new file mode 100644 index 00000000..17c37360 --- /dev/null +++ b/int2_array_test.go @@ -0,0 +1,342 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt2ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int2[]", []interface{}{ + &pgtype.Int2Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int2Array{Status: pgtype.Null}, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInt2ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int2Array + }{ + { + source: []int64{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int32{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int16{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint64{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint32{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint16{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]int16)(nil)), + result: pgtype.Int2Array{Status: pgtype.Null}, + }, + { + source: [][]int16{{1}, {2}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]int16{{1}, {2}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Int2Array + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt2ArrayAssignTo(t *testing.T) { + var int16Slice []int16 + var uint16Slice []uint16 + var namedInt16Slice _int16Slice + var int16SliceDim2 [][]int16 + var int16SliceDim4 [][][][]int16 + var int16ArrayDim2 [2][1]int16 + var int16ArrayDim4 [2][1][1][3]int16 + + simpleTests := []struct { + src pgtype.Int2Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int16Slice, + expected: []int16{1}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint16Slice, + expected: []uint16{1}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedInt16Slice, + expected: _int16Slice{1}, + }, + { + src: pgtype.Int2Array{Status: pgtype.Null}, + dst: &int16Slice, + expected: (([]int16)(nil)), + }, + { + src: pgtype.Int2Array{Status: pgtype.Present}, + dst: &int16Slice, + expected: []int16{}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]int16{{1}, {2}}, + dst: &int16SliceDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int16SliceDim4, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]int16{{1}, {2}}, + dst: &int16ArrayDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int16ArrayDim4, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int2Array + dst interface{} + }{ + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int16Slice, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: -1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint16Slice, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int16ArrayDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int16Slice, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &int16ArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/int2_test.go b/int2_test.go new file mode 100644 index 00000000..178eb278 --- /dev/null +++ b/int2_test.go @@ -0,0 +1,144 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt2Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ + &pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, + &pgtype.Int2{Int: -1, Status: pgtype.Present}, + &pgtype.Int2{Int: 0, Status: pgtype.Present}, + &pgtype.Int2{Int: 1, Status: pgtype.Present}, + &pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}, + &pgtype.Int2{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt2Set(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int2 + }{ + {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: float32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int2 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt2AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int2 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int2 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int2 + dst interface{} + }{ + {src: pgtype.Int2{Int: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &i16}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/int4_array_test.go b/int4_array_test.go new file mode 100644 index 00000000..110512a9 --- /dev/null +++ b/int4_array_test.go @@ -0,0 +1,356 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt4ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4[]", []interface{}{ + &pgtype.Int4Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4Array{Status: pgtype.Null}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInt4ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int4Array + expectedError bool + }{ + { + source: []int64{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int32{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int16{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int{1, math.MaxInt32 + 1, 2}, + expectedError: true, + }, + { + source: []uint64{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint32{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint16{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]int32)(nil)), + result: pgtype.Int4Array{Status: pgtype.Null}, + }, + { + source: [][]int32{{1}, {2}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]int32{{1}, {2}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Int4Array + err := r.Set(tt.source) + if err != nil { + if tt.expectedError { + continue + } + t.Errorf("%d: %v", i, err) + } + + if tt.expectedError { + t.Errorf("%d: an error was expected, %v", i, tt) + continue + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt4ArrayAssignTo(t *testing.T) { + var int32Slice []int32 + var uint32Slice []uint32 + var namedInt32Slice _int32Slice + var int32SliceDim2 [][]int32 + var int32SliceDim4 [][][][]int32 + var int32ArrayDim2 [2][1]int32 + var int32ArrayDim4 [2][1][1][3]int32 + + simpleTests := []struct { + src pgtype.Int4Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int32Slice, + expected: []int32{1}, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint32Slice, + expected: []uint32{1}, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedInt32Slice, + expected: _int32Slice{1}, + }, + { + src: pgtype.Int4Array{Status: pgtype.Null}, + dst: &int32Slice, + expected: (([]int32)(nil)), + }, + { + src: pgtype.Int4Array{Status: pgtype.Present}, + dst: &int32Slice, + expected: []int32{}, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]int32{{1}, {2}}, + dst: &int32SliceDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int32SliceDim4, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]int32{{1}, {2}}, + dst: &int32ArrayDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int32ArrayDim4, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int4Array + dst interface{} + }{ + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int32Slice, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: -1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint32Slice, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int32ArrayDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int32Slice, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &int32ArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/int4_test.go b/int4_test.go new file mode 100644 index 00000000..ae01114f --- /dev/null +++ b/int4_test.go @@ -0,0 +1,186 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt4Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ + &pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, + &pgtype.Int4{Int: -1, Status: pgtype.Present}, + &pgtype.Int4{Int: 0, Status: pgtype.Present}, + &pgtype.Int4{Int: 1, Status: pgtype.Present}, + &pgtype.Int4{Int: math.MaxInt32, Status: pgtype.Present}, + &pgtype.Int4{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt4Set(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int4 + }{ + {source: int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: float32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int4 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt4AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int4 + dst interface{} + }{ + {src: pgtype.Int4{Int: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Int4{Int: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} + +func TestInt4MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int4 + result string + }{ + {source: pgtype.Int4{Int: 0, Status: pgtype.Null}, result: "null"}, + {source: pgtype.Int4{Int: 1, Status: pgtype.Present}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt4UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int4 + }{ + {source: "null", result: pgtype.Int4{Int: 0, Status: pgtype.Null}}, + {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Int4 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/int4range_test.go b/int4range_test.go new file mode 100644 index 00000000..43626189 --- /dev/null +++ b/int4range_test.go @@ -0,0 +1,28 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt4rangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4range", []interface{}{ + &pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Status: pgtype.Present}, + &pgtype.Int4range{Upper: pgtype.Int4{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int4range{Status: pgtype.Null}, + }) +} + +func TestInt4rangeNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select int4range(1, 10, '(]')", + Value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + }, + }) +} diff --git a/int8_array_test.go b/int8_array_test.go new file mode 100644 index 00000000..1d42a278 --- /dev/null +++ b/int8_array_test.go @@ -0,0 +1,349 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt8ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int8[]", []interface{}{ + &pgtype.Int8Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int8Array{Status: pgtype.Null}, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInt8ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int8Array + }{ + { + source: []int64{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int32{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int16{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint64{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint32{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint16{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]int64)(nil)), + result: pgtype.Int8Array{Status: pgtype.Null}, + }, + { + source: [][]int64{{1}, {2}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]int64{{1}, {2}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Int8Array + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt8ArrayAssignTo(t *testing.T) { + var int64Slice []int64 + var uint64Slice []uint64 + var namedInt64Slice _int64Slice + var int64SliceDim2 [][]int64 + var int64SliceDim4 [][][][]int64 + var int64ArrayDim2 [2][1]int64 + var int64ArrayDim4 [2][1][1][3]int64 + + simpleTests := []struct { + src pgtype.Int8Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int64Slice, + expected: []int64{1}, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint64Slice, + expected: []uint64{1}, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedInt64Slice, + expected: _int64Slice{1}, + }, + { + src: pgtype.Int8Array{Status: pgtype.Null}, + dst: &int64Slice, + expected: (([]int64)(nil)), + }, + { + src: pgtype.Int8Array{Status: pgtype.Present}, + dst: &int64Slice, + expected: []int64{}, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]int64{{1}, {2}}, + dst: &int64SliceDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int64SliceDim4, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]int64{{1}, {2}}, + dst: &int64ArrayDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int64ArrayDim4, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int8Array + dst interface{} + }{ + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int64Slice, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: -1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint64Slice, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int64ArrayDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int64Slice, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &int64ArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/int8_test.go b/int8_test.go new file mode 100644 index 00000000..4e28e374 --- /dev/null +++ b/int8_test.go @@ -0,0 +1,187 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt8Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ + &pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, + &pgtype.Int8{Int: -1, Status: pgtype.Present}, + &pgtype.Int8{Int: 0, Status: pgtype.Present}, + &pgtype.Int8{Int: 1, Status: pgtype.Present}, + &pgtype.Int8{Int: math.MaxInt64, Status: pgtype.Present}, + &pgtype.Int8{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt8Set(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int8 + }{ + {source: int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: float32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int8 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt8AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int8 + dst interface{} + }{ + {src: pgtype.Int8{Int: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Int8{Int: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Int8{Int: 5000000000, Status: pgtype.Present}, dst: &i32}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &i64}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} + +func TestInt8MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int8 + result string + }{ + {source: pgtype.Int8{Int: 0, Status: pgtype.Null}, result: "null"}, + {source: pgtype.Int8{Int: 1, Status: pgtype.Present}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt8UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int8 + }{ + {source: "null", result: pgtype.Int8{Int: 0, Status: pgtype.Null}}, + {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Int8 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/int8range_test.go b/int8range_test.go new file mode 100644 index 00000000..99d4e8a3 --- /dev/null +++ b/int8range_test.go @@ -0,0 +1,28 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt8rangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "Int8range", []interface{}{ + &pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Status: pgtype.Present}, + &pgtype.Int8range{Upper: pgtype.Int8{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int8range{Status: pgtype.Null}, + }) +} + +func TestInt8rangeNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select Int8range(1, 10, '(]')", + Value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + }, + }) +} diff --git a/interval_test.go b/interval_test.go new file mode 100644 index 00000000..1ee094d7 --- /dev/null +++ b/interval_test.go @@ -0,0 +1,74 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIntervalTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "interval", []interface{}{ + &pgtype.Interval{Microseconds: 1, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + &pgtype.Interval{Days: 1, Status: pgtype.Present}, + &pgtype.Interval{Months: 1, Status: pgtype.Present}, + &pgtype.Interval{Months: 12, Status: pgtype.Present}, + &pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -1, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -1000000, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -1000001, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -123202800000000, Status: pgtype.Present}, + &pgtype.Interval{Days: -1, Status: pgtype.Present}, + &pgtype.Interval{Months: -1, Status: pgtype.Present}, + &pgtype.Interval{Months: -12, Status: pgtype.Present}, + &pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Status: pgtype.Present}, + &pgtype.Interval{Status: pgtype.Null}, + }) +} + +func TestIntervalNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select '1 second'::interval", + Value: &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + }, + { + SQL: "select '1.000001 second'::interval", + Value: &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + }, + { + SQL: "select '34223 hours'::interval", + Value: &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + }, + { + SQL: "select '1 day'::interval", + Value: &pgtype.Interval{Days: 1, Status: pgtype.Present}, + }, + { + SQL: "select '1 month'::interval", + Value: &pgtype.Interval{Months: 1, Status: pgtype.Present}, + }, + { + SQL: "select '1 year'::interval", + Value: &pgtype.Interval{Months: 12, Status: pgtype.Present}, + }, + { + SQL: "select '-13 mon'::interval", + Value: &pgtype.Interval{Months: -13, Status: pgtype.Present}, + }, + }) +} + +func TestIntervalLossyConversionToDuration(t *testing.T) { + interval := &pgtype.Interval{Months: 1, Days: 1, Status: pgtype.Present} + var d time.Duration + err := interval.AssignTo(&d) + require.NoError(t, err) + assert.EqualValues(t, int64(2678400000000000), d.Nanoseconds()) +} diff --git a/json_test.go b/json_test.go new file mode 100644 index 00000000..bbd3959e --- /dev/null +++ b/json_test.go @@ -0,0 +1,177 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestJSONTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "json", []interface{}{ + &pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, + &pgtype.JSON{Bytes: []byte("null"), Status: pgtype.Present}, + &pgtype.JSON{Bytes: []byte("42"), Status: pgtype.Present}, + &pgtype.JSON{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + &pgtype.JSON{Status: pgtype.Null}, + }) +} + +func TestJSONSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.JSON + }{ + {source: "{}", result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: []byte("{}"), result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: ([]byte)(nil), result: pgtype.JSON{Status: pgtype.Null}}, + {source: (*string)(nil), result: pgtype.JSON{Status: pgtype.Null}}, + {source: []int{1, 2, 3}, result: pgtype.JSON{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.JSON + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(d, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestJSONAssignTo(t *testing.T) { + var s string + var ps *string + var b []byte + + rawStringTests := []struct { + src pgtype.JSON + dst *string + expected string + }{ + {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + } + + for i, tt := range rawStringTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + rawBytesTests := []struct { + src pgtype.JSON + dst *[]byte + expected []byte + }{ + {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, + {src: pgtype.JSON{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + } + + for i, tt := range rawBytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(tt.expected, *tt.dst) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + var mapDst map[string]interface{} + type structDst struct { + Name string `json:"name"` + Age int `json:"age"` + } + var strDst structDst + + unmarshalTests := []struct { + src pgtype.JSON + dst interface{} + expected interface{} + }{ + {src: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.JSON{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + } + for i, tt := range unmarshalTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.JSON + dst **string + expected *string + }{ + {src: pgtype.JSON{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} + +func TestJSONMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.JSON + result string + }{ + {source: pgtype.JSON{Status: pgtype.Null}, result: "null"}, + {source: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Status: pgtype.Present}, result: "{\"a\": 1}"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestJSONUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.JSON + }{ + {source: "null", result: pgtype.JSON{Status: pgtype.Null}}, + {source: "{\"a\": 1}", result: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.JSON + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r.Bytes) != string(tt.result.Bytes) || r.Status != tt.result.Status { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/jsonb_array_test.go b/jsonb_array_test.go new file mode 100644 index 00000000..65f1777a --- /dev/null +++ b/jsonb_array_test.go @@ -0,0 +1,36 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestJSONBArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "jsonb[]", []interface{}{ + &pgtype.JSONBArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.JSONBArray{ + Elements: []pgtype.JSONB{ + {Bytes: []byte(`"foo"`), Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.JSONBArray{Status: pgtype.Null}, + &pgtype.JSONBArray{ + Elements: []pgtype.JSONB{ + {Bytes: []byte(`"foo"`), Status: pgtype.Present}, + {Bytes: []byte("null"), Status: pgtype.Present}, + {Bytes: []byte("42"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, + Status: pgtype.Present, + }, + }) +} diff --git a/jsonb_test.go b/jsonb_test.go new file mode 100644 index 00000000..9ce80d42 --- /dev/null +++ b/jsonb_test.go @@ -0,0 +1,142 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestJSONBTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + if _, ok := conn.ConnInfo().DataTypeForName("jsonb"); !ok { + t.Skip("Skipping due to no jsonb type") + } + + testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ + &pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, + &pgtype.JSONB{Bytes: []byte("null"), Status: pgtype.Present}, + &pgtype.JSONB{Bytes: []byte("42"), Status: pgtype.Present}, + &pgtype.JSONB{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + &pgtype.JSONB{Status: pgtype.Null}, + }) +} + +func TestJSONBSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.JSONB + }{ + {source: "{}", result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: []byte("{}"), result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: ([]byte)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, + {source: (*string)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, + {source: []int{1, 2, 3}, result: pgtype.JSONB{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.JSONB + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(d, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestJSONBAssignTo(t *testing.T) { + var s string + var ps *string + var b []byte + + rawStringTests := []struct { + src pgtype.JSONB + dst *string + expected string + }{ + {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + } + + for i, tt := range rawStringTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + rawBytesTests := []struct { + src pgtype.JSONB + dst *[]byte + expected []byte + }{ + {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, + {src: pgtype.JSONB{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + } + + for i, tt := range rawBytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(tt.expected, *tt.dst) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + var mapDst map[string]interface{} + type structDst struct { + Name string `json:"name"` + Age int `json:"age"` + } + var strDst structDst + + unmarshalTests := []struct { + src pgtype.JSONB + dst interface{} + expected interface{} + }{ + {src: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.JSONB{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + } + for i, tt := range unmarshalTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.JSONB + dst **string + expected *string + }{ + {src: pgtype.JSONB{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/line_test.go b/line_test.go new file mode 100644 index 00000000..f697ac43 --- /dev/null +++ b/line_test.go @@ -0,0 +1,38 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestLineTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { + t.Skip("Skipping due to no line type") + } + + // line may exist but not be usable on 9.3 :( + var isPG93 bool + err := conn.QueryRow(context.Background(), "select version() ~ '9.3'").Scan(&isPG93) + if err != nil { + t.Fatal(err) + } + if isPG93 { + t.Skip("Skipping due to unimplemented line type in PG 9.3") + } + + testutil.TestSuccessfulTranscode(t, "line", []interface{}{ + &pgtype.Line{ + A: 1.23, B: 4.56, C: 7.89012345, + Status: pgtype.Present, + }, + &pgtype.Line{ + A: -1.23, B: -4.56, C: -7.89, + Status: pgtype.Present, + }, + &pgtype.Line{Status: pgtype.Null}, + }) +} diff --git a/lseg_test.go b/lseg_test.go new file mode 100644 index 00000000..b75297cc --- /dev/null +++ b/lseg_test.go @@ -0,0 +1,22 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestLsegTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "lseg", []interface{}{ + &pgtype.Lseg{ + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, + Status: pgtype.Present, + }, + &pgtype.Lseg{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Status: pgtype.Present, + }, + &pgtype.Lseg{Status: pgtype.Null}, + }) +} diff --git a/macaddr_array_test.go b/macaddr_array_test.go new file mode 100644 index 00000000..c1a8b72d --- /dev/null +++ b/macaddr_array_test.go @@ -0,0 +1,262 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestMacaddrArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "macaddr[]", []interface{}{ + &pgtype.MacaddrArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.MacaddrArray{Status: pgtype.Null}, + }) +} + +func TestMacaddrArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.MacaddrArray + }{ + { + source: []net.HardwareAddr{mustParseMacaddr(t, "01:23:45:67:89:ab")}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.HardwareAddr)(nil)), + result: pgtype.MacaddrArray{Status: pgtype.Null}, + }, + { + source: [][]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.MacaddrArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestMacaddrArrayAssignTo(t *testing.T) { + var macaddrSlice []net.HardwareAddr + var macaddrSliceDim2 [][]net.HardwareAddr + var macaddrSliceDim4 [][][][]net.HardwareAddr + var macaddrArrayDim2 [2][1]net.HardwareAddr + var macaddrArrayDim4 [2][1][1][3]net.HardwareAddr + + simpleTests := []struct { + src pgtype.MacaddrArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &macaddrSlice, + expected: []net.HardwareAddr{mustParseMacaddr(t, "01:23:45:67:89:ab")}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &macaddrSlice, + expected: []net.HardwareAddr{nil}, + }, + { + src: pgtype.MacaddrArray{Status: pgtype.Null}, + dst: &macaddrSlice, + expected: (([]net.HardwareAddr)(nil)), + }, + { + src: pgtype.MacaddrArray{Status: pgtype.Present}, + dst: &macaddrSlice, + expected: []net.HardwareAddr{}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &macaddrSliceDim2, + expected: [][]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &macaddrSliceDim4, + expected: [][][][]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &macaddrArrayDim2, + expected: [2][1]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &macaddrArrayDim4, + expected: [2][1][1][3]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/macaddr_test.go b/macaddr_test.go new file mode 100644 index 00000000..364a8914 --- /dev/null +++ b/macaddr_test.go @@ -0,0 +1,78 @@ +package pgtype_test + +import ( + "bytes" + "net" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestMacaddrTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "macaddr", []interface{}{ + &pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + &pgtype.Macaddr{Status: pgtype.Null}, + }) +} + +func TestMacaddrSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Macaddr + }{ + { + source: mustParseMacaddr(t, "01:23:45:67:89:ab"), + result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + }, + { + source: "01:23:45:67:89:ab", + result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Macaddr + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestMacaddrAssignTo(t *testing.T) { + { + src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} + var dst net.HardwareAddr + expected := mustParseMacaddr(t, "01:23:45:67:89:ab") + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare([]byte(dst), []byte(expected)) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} + var dst string + expected := "01:23:45:67:89:ab" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } +} diff --git a/name_test.go b/name_test.go new file mode 100644 index 00000000..75329b01 --- /dev/null +++ b/name_test.go @@ -0,0 +1,98 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestNameTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "name", []interface{}{ + &pgtype.Name{String: "", Status: pgtype.Present}, + &pgtype.Name{String: "foo", Status: pgtype.Present}, + &pgtype.Name{Status: pgtype.Null}, + }) +} + +func TestNameSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Name + }{ + {source: "foo", result: pgtype.Name{String: "foo", Status: pgtype.Present}}, + {source: _string("bar"), result: pgtype.Name{String: "bar", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.Name{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.Name + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestNameAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.Name + dst interface{} + expected interface{} + }{ + {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, + {src: pgtype.Name{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Name + dst interface{} + expected interface{} + }{ + {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Name + dst interface{} + }{ + {src: pgtype.Name{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/numeric_array_test.go b/numeric_array_test.go new file mode 100644 index 00000000..7c1e8c3b --- /dev/null +++ b/numeric_array_test.go @@ -0,0 +1,305 @@ +package pgtype_test + +import ( + "math" + "math/big" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestNumericArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "numeric[]", []interface{}{ + &pgtype.NumericArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.NumericArray{Status: pgtype.Null}, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Status: pgtype.Null}, + {Int: big.NewInt(6), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestNumericArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.NumericArray + }{ + { + source: []float32{1}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []float32{float32(math.Copysign(0, -1))}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(0), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []float64{1}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []float64{math.Copysign(0, -1)}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(0), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]float32)(nil)), + result: pgtype.NumericArray{Status: pgtype.Null}, + }, + { + source: [][]float32{{1}, {2}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]float32{{1}, {2}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.NumericArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericArrayAssignTo(t *testing.T) { + var float32Slice []float32 + var float64Slice []float64 + var float32SliceDim2 [][]float32 + var float32SliceDim4 [][][][]float32 + var float32ArrayDim2 [2][1]float32 + var float32ArrayDim4 [2][1][1][3]float32 + + simpleTests := []struct { + src pgtype.NumericArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + expected: []float32{1}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float64Slice, + expected: []float64{1}, + }, + { + src: pgtype.NumericArray{Status: pgtype.Null}, + dst: &float32Slice, + expected: (([]float32)(nil)), + }, + { + src: pgtype.NumericArray{Status: pgtype.Present}, + dst: &float32Slice, + expected: []float32{}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32SliceDim2, + expected: [][]float32{{1}, {2}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &float32SliceDim4, + expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32ArrayDim2, + expected: [2][1]float32{{1}, {2}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &float32ArrayDim4, + expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.NumericArray + dst interface{} + }{ + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32Slice, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32ArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/numeric_test.go b/numeric_test.go new file mode 100644 index 00000000..81595cb3 --- /dev/null +++ b/numeric_test.go @@ -0,0 +1,389 @@ +package pgtype_test + +import ( + "math" + "math/big" + "math/rand" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +// For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) +func numericEqual(left, right *pgtype.Numeric) bool { + return left.Status == right.Status && + left.Exp == right.Exp && + ((left.Int == nil && right.Int == nil) || (left.Int != nil && right.Int != nil && left.Int.Cmp(right.Int) == 0)) && + left.NaN == right.NaN +} + +// For test purposes only. +func numericNormalizedEqual(left, right *pgtype.Numeric) bool { + if left.Status != right.Status { + return false + } + + normLeft := &pgtype.Numeric{Int: (&big.Int{}).Set(left.Int), Status: left.Status} + normRight := &pgtype.Numeric{Int: (&big.Int{}).Set(right.Int), Status: right.Status} + + if left.Exp < right.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(right.Exp-left.Exp)), nil) + normRight.Int.Mul(normRight.Int, mul) + } else if left.Exp > right.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(left.Exp-right.Exp)), nil) + normLeft.Int.Mul(normLeft.Int, mul) + } + + return normLeft.Int.Cmp(normRight.Int) == 0 +} + +func mustParseBigInt(t *testing.T, src string) *big.Int { + i := &big.Int{} + if _, ok := i.SetString(src, 10); !ok { + t.Fatalf("could not parse big.Int: %s", src) + } + return i +} + +func TestNumericNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select '0'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + }, + { + SQL: "select '1'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + }, + { + SQL: "select '10.00'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, + }, + { + SQL: "select '1e-3'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, + }, + { + SQL: "select '-1'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + }, + { + SQL: "select '10000'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, + }, + { + SQL: "select '3.14'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + }, + { + SQL: "select '1.1'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, + }, + { + SQL: "select '100010001'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, + }, + { + SQL: "select '100010001.0001'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, + }, + { + SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", + Value: &pgtype.Numeric{ + Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), + Exp: -41, + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", + Value: &pgtype.Numeric{ + Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), + Exp: -196, + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", + Value: &pgtype.Numeric{ + Int: mustParseBigInt(t, "123"), + Exp: -186, + Status: pgtype.Present, + }, + }, + }) +} + +func TestNumericTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + &pgtype.Numeric{NaN: true, Status: pgtype.Present}, + + &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 6, Status: pgtype.Present}, + + // preserves significant zeroes + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -1, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -2, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -3, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -4, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -5, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -6, Status: pgtype.Present}, + + &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -7, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -8, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -9, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -1500, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3723409723490243842378942378901237502734019231380123"), Exp: 81, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "723409723490243842378942378901237502734019231380123"), Exp: 82, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409723490243842378942378901237502734019231380123"), Exp: 83, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409723490243842378942378901237502734019231380123"), Exp: 84, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3423409823409243892349028349023482934092340892390101"), Exp: -91, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "423409823409243892349028349023482934092340892390101"), Exp: -92, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409823409243892349028349023482934092340892390101"), Exp: -93, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409823409243892349028349023482934092340892390101"), Exp: -94, Status: pgtype.Present}, + &pgtype.Numeric{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Numeric) + b := bb.(pgtype.Numeric) + + return numericEqual(&a, &b) + }) + +} + +func TestNumericTranscodeFuzz(t *testing.T) { + r := rand.New(rand.NewSource(0)) + max := &big.Int{} + max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) + + values := make([]interface{}, 0, 2000) + for i := 0; i < 10; i++ { + for j := -50; j < 50; j++ { + num := (&big.Int{}).Rand(r, max) + negNum := &big.Int{} + negNum.Neg(num) + values = append(values, &pgtype.Numeric{Int: num, Exp: int32(j), Status: pgtype.Present}) + values = append(values, &pgtype.Numeric{Int: negNum, Exp: int32(j), Status: pgtype.Present}) + } + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, + func(aa, bb interface{}) bool { + a := aa.(pgtype.Numeric) + b := bb.(pgtype.Numeric) + + return numericNormalizedEqual(&a, &b) + }) +} + +func TestNumericSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result *pgtype.Numeric + }{ + {source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float32(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Present}}, + {source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float64(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Present}}, + {source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int8(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int16(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int32(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int64(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: uint8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: "1", result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: _int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float64(1000), result: &pgtype.Numeric{Int: big.NewInt(1), Exp: 3, Status: pgtype.Present}}, + {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}}, + {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}}, + {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, + {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, NaN: true}}, + {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, NaN: true}}, + } + + for i, tt := range successfulTests { + r := &pgtype.Numeric{} + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !numericEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericAssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src *pgtype.Numeric + dst interface{} + expected interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: 3, Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Status: pgtype.Present}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27 + {src: &pgtype.Numeric{Status: pgtype.Present, NaN: true}, dst: &f64, expected: math.NaN()}, + {src: &pgtype.Numeric{Status: pgtype.Present, NaN: true}, dst: &f32, expected: float32(math.NaN())}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + dst := reflect.ValueOf(tt.dst).Elem().Interface() + switch dstTyped := dst.(type) { + case float32: + nanExpected := math.IsNaN(float64(tt.expected.(float32))) + if nanExpected && !math.IsNaN(float64(dstTyped)) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } else if !nanExpected && dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + case float64: + nanExpected := math.IsNaN(tt.expected.(float64)) + if nanExpected && !math.IsNaN(dstTyped) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } else if !nanExpected && dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + default: + if dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + } + + pointerAllocTests := []struct { + src *pgtype.Numeric + dst interface{} + expected interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src *pgtype.Numeric + dst interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(150), Status: pgtype.Present}, dst: &i8}, + {src: &pgtype.Numeric{Int: big.NewInt(40000), Status: pgtype.Present}, dst: &i16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui8}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui32}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui64}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} + +func TestNumericEncodeDecodeBinary(t *testing.T) { + ci := pgtype.NewConnInfo() + tests := []interface{}{ + 123, + 0.000012345, + 1.00002345, + math.NaN(), + float32(math.NaN()), + } + + for i, tt := range tests { + toString := func(n *pgtype.Numeric) string { + ci := pgtype.NewConnInfo() + text, err := n.EncodeText(ci, nil) + if err != nil { + t.Errorf("%d (EncodeText): %v", i, err) + } + return string(text) + } + numeric := &pgtype.Numeric{} + numeric.Set(tt) + + encoded, err := numeric.EncodeBinary(ci, nil) + if err != nil { + t.Errorf("%d (EncodeBinary): %v", i, err) + } + decoded := &pgtype.Numeric{} + err = decoded.DecodeBinary(ci, encoded) + if err != nil { + t.Errorf("%d (DecodeBinary): %v", i, err) + } + + text0 := toString(numeric) + text1 := toString(decoded) + + if text0 != text1 { + t.Errorf("%d: expected %v to equal to %v, but doesn't", i, text0, text1) + } + } +} diff --git a/numrange_test.go b/numrange_test.go new file mode 100644 index 00000000..0bbb26f0 --- /dev/null +++ b/numrange_test.go @@ -0,0 +1,46 @@ +package pgtype_test + +import ( + "math/big" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestNumrangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "numrange", []interface{}{ + &pgtype.Numrange{ + LowerType: pgtype.Empty, + UpperType: pgtype.Empty, + Status: pgtype.Present, + }, + &pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Unbounded, + Status: pgtype.Present, + }, + &pgtype.Numrange{ + Upper: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, + LowerType: pgtype.Unbounded, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Numrange{Status: pgtype.Null}, + }) +} diff --git a/oid_value_test.go b/oid_value_test.go new file mode 100644 index 00000000..69742dd7 --- /dev/null +++ b/oid_value_test.go @@ -0,0 +1,95 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestOIDValueTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ + &pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, + &pgtype.OIDValue{Status: pgtype.Null}, + }) +} + +func TestOIDValueSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.OIDValue + }{ + {source: uint32(1), result: pgtype.OIDValue{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.OIDValue + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestOIDValueAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.OIDValue + dst interface{} + expected interface{} + }{ + {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.OIDValue + dst interface{} + expected interface{} + }{ + {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.OIDValue + dst interface{} + }{ + {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/path_test.go b/path_test.go new file mode 100644 index 00000000..969a89ec --- /dev/null +++ b/path_test.go @@ -0,0 +1,29 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestPathTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "path", []interface{}{ + &pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, + Closed: false, + Status: pgtype.Present, + }, + &pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, + Closed: true, + Status: pgtype.Present, + }, + &pgtype.Path{ + P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Closed: true, + Status: pgtype.Present, + }, + &pgtype.Path{Status: pgtype.Null}, + }) +} diff --git a/pgtype_test.go b/pgtype_test.go new file mode 100644 index 00000000..75e1909f --- /dev/null +++ b/pgtype_test.go @@ -0,0 +1,292 @@ +package pgtype_test + +import ( + "bytes" + "errors" + "net" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" + _ "github.com/jackc/pgx/v4/stdlib" + _ "github.com/lib/pq" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test for renamed types +type _string string +type _bool bool +type _int8 int8 +type _int16 int16 +type _int16Slice []int16 +type _int32Slice []int32 +type _int64Slice []int64 +type _float32Slice []float32 +type _float64Slice []float64 +type _byteSlice []byte + +func mustParseCIDR(t testing.TB, s string) *net.IPNet { + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + t.Fatal(err) + } + + return ipnet +} + +func mustParseInet(t testing.TB, s string) *net.IPNet { + ip, ipnet, err := net.ParseCIDR(s) + if err != nil { + t.Fatal(err) + } + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + + ipnet.IP = ip + + return ipnet +} + +func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { + addr, err := net.ParseMAC(s) + if err != nil { + t.Fatal(err) + } + + return addr +} + +func TestConnInfoResultFormatCodeForOID(t *testing.T) { + ci := pgtype.NewConnInfo() + + // pgtype.JSONB implements BinaryDecoder but also implements ResultFormatPreferrer to override it to text. + assert.Equal(t, int16(pgtype.TextFormatCode), ci.ResultFormatCodeForOID(pgtype.JSONBOID)) + + // pgtype.Int4 implements BinaryDecoder but does not implement ResultFormatPreferrer so it should be binary. + assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ResultFormatCodeForOID(pgtype.Int4OID)) +} + +func TestConnInfoParamFormatCodeForOID(t *testing.T) { + ci := pgtype.NewConnInfo() + + // pgtype.JSONB implements BinaryEncoder but also implements ParamFormatPreferrer to override it to text. + assert.Equal(t, int16(pgtype.TextFormatCode), ci.ParamFormatCodeForOID(pgtype.JSONBOID)) + + // pgtype.Int4 implements BinaryEncoder but does not implement ParamFormatPreferrer so it should be binary. + assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ParamFormatCodeForOID(pgtype.Int4OID)) +} + +func TestConnInfoScanNilIsNoOp(t *testing.T) { + ci := pgtype.NewConnInfo() + + err := ci.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), nil) + assert.NoError(t, err) +} + +func TestConnInfoScanTextFormatInterfacePtr(t *testing.T) { + ci := pgtype.NewConnInfo() + var got interface{} + err := ci.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &got) + require.NoError(t, err) + assert.Equal(t, "foo", got) +} + +func TestConnInfoScanTextFormatNonByteaIntoByteSlice(t *testing.T) { + ci := pgtype.NewConnInfo() + var got []byte + err := ci.Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &got) + require.NoError(t, err) + assert.Equal(t, []byte("{}"), got) +} + +func TestConnInfoScanBinaryFormatInterfacePtr(t *testing.T) { + ci := pgtype.NewConnInfo() + var got interface{} + err := ci.Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &got) + require.NoError(t, err) + assert.Equal(t, "foo", got) +} + +func TestConnInfoScanUnknownOIDToStringsAndBytes(t *testing.T) { + unknownOID := uint32(999999) + srcBuf := []byte("foo") + ci := pgtype.NewConnInfo() + + var s string + err := ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &s) + assert.NoError(t, err) + assert.Equal(t, "foo", s) + + var rs _string + err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rs) + assert.NoError(t, err) + assert.Equal(t, "foo", string(rs)) + + var b []byte + err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &b) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), b) + + err = ci.Scan(unknownOID, pgx.BinaryFormatCode, srcBuf, &b) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), b) + + var rb _byteSlice + err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rb) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), []byte(rb)) + + err = ci.Scan(unknownOID, pgx.BinaryFormatCode, srcBuf, &b) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), []byte(rb)) +} + +type pgCustomType struct { + a string + b string +} + +func (ct *pgCustomType) DecodeText(ci *pgtype.ConnInfo, buf []byte) error { + // This is not a complete parser for the text format of composite types. This is just for test purposes. + if buf == nil { + return errors.New("cannot parse null") + } + + if len(buf) < 2 { + return errors.New("invalid text format") + } + + parts := bytes.Split(buf[1:len(buf)-1], []byte(",")) + if len(parts) != 2 { + return errors.New("wrong number of parts") + } + + ct.a = string(parts[0]) + ct.b = string(parts[1]) + + return nil +} + +func TestConnInfoScanUnregisteredOIDToCustomType(t *testing.T) { + unregisteredOID := uint32(999999) + ci := pgtype.NewConnInfo() + + var ct pgCustomType + err := ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct) + assert.NoError(t, err) + assert.Equal(t, "foo", ct.a) + assert.Equal(t, "bar", ct.b) + + // Scan value into pointer to custom type + var pCt *pgCustomType + err = ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt) + assert.NoError(t, err) + require.NotNil(t, pCt) + assert.Equal(t, "foo", pCt.a) + assert.Equal(t, "bar", pCt.b) + + // Scan null into pointer to custom type + err = ci.Scan(unregisteredOID, pgx.TextFormatCode, nil, &pCt) + assert.NoError(t, err) + assert.Nil(t, pCt) +} + +func TestConnInfoScanUnknownOIDTextFormat(t *testing.T) { + ci := pgtype.NewConnInfo() + + var n int32 + err := ci.Scan(0, pgx.TextFormatCode, []byte("123"), &n) + assert.NoError(t, err) + assert.EqualValues(t, 123, n) +} + +func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v pgtype.Int4 + + for i := 0; i < b.N; i++ { + v = pgtype.Int4{} + err := ci.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != (pgtype.Int4{Int: 42, Status: pgtype.Present}) { + b.Fatal("scan failed due to bad value") + } + } +} + +func TestScanPlanBinaryInt32ScanChangedType(t *testing.T) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v int32 + + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) + err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + require.NoError(t, err) + require.EqualValues(t, 42, v) + + var d pgtype.Int4 + err = plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &d) + require.NoError(t, err) + require.EqualValues(t, 42, d.Int) + require.EqualValues(t, pgtype.Present, d.Status) +} + +func BenchmarkConnInfoScanInt4IntoGoInt32(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v int32 + + for i := 0; i < b.N; i++ { + v = 0 + err := ci.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != 42 { + b.Fatal("scan failed due to bad value") + } + } +} + +func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v pgtype.Int4 + + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) + + for i := 0; i < b.N; i++ { + v = pgtype.Int4{} + err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != (pgtype.Int4{Int: 42, Status: pgtype.Present}) { + b.Fatal("scan failed due to bad value") + } + } +} + +func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v int32 + + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) + + for i := 0; i < b.N; i++ { + v = 0 + err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != 42 { + b.Fatal("scan failed due to bad value") + } + } +} diff --git a/pgxtype/README.md b/pgxtype/README.md new file mode 100644 index 00000000..a070111f --- /dev/null +++ b/pgxtype/README.md @@ -0,0 +1,3 @@ +# pgxtype + +pgxtype is a helper module that connects pgx and pgtype. This package is not currently covered by semantic version guarantees. i.e. The interfaces may change without a major version release of pgtype. diff --git a/pgxtype/pgxtype.go b/pgxtype/pgxtype.go new file mode 100644 index 00000000..041f2545 --- /dev/null +++ b/pgxtype/pgxtype.go @@ -0,0 +1,145 @@ +package pgxtype + +import ( + "context" + "errors" + + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" +) + +type Querier interface { + Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) + Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row +} + +// LoadDataType uses conn to inspect the database for typeName and produces a pgtype.DataType suitable for +// registration on ci. +func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeName string) (pgtype.DataType, error) { + var oid uint32 + + err := conn.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid) + if err != nil { + return pgtype.DataType{}, err + } + + var typtype string + + err = conn.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype) + if err != nil { + return pgtype.DataType{}, err + } + + switch typtype { + case "b": // array + elementOID, err := GetArrayElementOID(ctx, conn, oid) + if err != nil { + return pgtype.DataType{}, err + } + + var element pgtype.ValueTranscoder + if dt, ok := ci.DataTypeForOID(elementOID); ok { + if element, ok = dt.Value.(pgtype.ValueTranscoder); !ok { + return pgtype.DataType{}, errors.New("array element OID not registered as ValueTranscoder") + } + } + + newElement := func() pgtype.ValueTranscoder { + return pgtype.NewValue(element).(pgtype.ValueTranscoder) + } + + at := pgtype.NewArrayType(typeName, elementOID, newElement) + return pgtype.DataType{Value: at, Name: typeName, OID: oid}, nil + case "c": // composite + fields, err := GetCompositeFields(ctx, conn, oid) + if err != nil { + return pgtype.DataType{}, err + } + ct, err := pgtype.NewCompositeType(typeName, fields, ci) + if err != nil { + return pgtype.DataType{}, err + } + return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil + case "e": // enum + members, err := GetEnumMembers(ctx, conn, oid) + if err != nil { + return pgtype.DataType{}, err + } + return pgtype.DataType{Value: pgtype.NewEnumType(typeName, members), Name: typeName, OID: oid}, nil + default: + return pgtype.DataType{}, errors.New("unknown typtype") + } +} + +func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, error) { + var typelem uint32 + + err := conn.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem) + if err != nil { + return 0, err + } + + return typelem, nil +} + +// GetCompositeFields gets the fields of a composite type. +func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) { + var typrelid uint32 + + err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid) + if err != nil { + return nil, err + } + + var fields []pgtype.CompositeTypeField + + rows, err := conn.Query(ctx, `select attname, atttypid +from pg_attribute +where attrelid=$1 +order by attnum`, typrelid) + if err != nil { + return nil, err + } + + for rows.Next() { + var f pgtype.CompositeTypeField + err := rows.Scan(&f.Name, &f.OID) + if err != nil { + return nil, err + } + fields = append(fields, f) + } + + if rows.Err() != nil { + return nil, rows.Err() + } + + return fields, nil +} + +// GetEnumMembers gets the possible values of the enum by oid. +func GetEnumMembers(ctx context.Context, conn Querier, oid uint32) ([]string, error) { + members := []string{} + + rows, err := conn.Query(ctx, "select enumlabel from pg_enum where enumtypid=$1 order by enumsortorder", oid) + if err != nil { + return nil, err + } + + for rows.Next() { + var m string + err := rows.Scan(&m) + if err != nil { + return nil, err + } + members = append(members, m) + } + + if rows.Err() != nil { + return nil, rows.Err() + } + + return members, nil +} diff --git a/point_test.go b/point_test.go new file mode 100644 index 00000000..63f8df07 --- /dev/null +++ b/point_test.go @@ -0,0 +1,150 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestPointTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "point", []interface{}{ + &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Status: pgtype.Present}, + &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, + &pgtype.Point{Status: pgtype.Null}, + }) +} + +func TestPoint_Set(t *testing.T) { + tests := []struct { + name string + arg interface{} + status pgtype.Status + wantErr bool + }{ + { + name: "first", + arg: "(12312.123123,123123.123123)", + status: pgtype.Present, + wantErr: false, + }, + { + name: "second", + arg: "(1231s2.123123,123123.123123)", + status: pgtype.Undefined, + wantErr: true, + }, + { + name: "third", + arg: []byte("(122.123123,123.123123)"), + status: pgtype.Present, + wantErr: false, + }, + { + name: "third", + arg: nil, + status: pgtype.Null, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &pgtype.Point{} + if err := dst.Set(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) + } + if dst.Status != tt.status { + t.Errorf("Expected status: %v; got: %v", tt.status, dst.Status) + } + }) + } +} + +func TestPoint_MarshalJSON(t *testing.T) { + tests := []struct { + name string + point pgtype.Point + want []byte + wantErr bool + }{ + { + name: "first", + point: pgtype.Point{ + P: pgtype.Vec2{}, + Status: pgtype.Undefined, + }, + want: nil, + wantErr: true, + }, + { + name: "second", + point: pgtype.Point{ + P: pgtype.Vec2{X: 12.245, Y: 432.12}, + Status: pgtype.Present, + }, + want: []byte(`"(12.245,432.12)"`), + wantErr: false, + }, + { + name: "third", + point: pgtype.Point{ + P: pgtype.Vec2{}, + Status: pgtype.Null, + }, + want: []byte("null"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.point.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPoint_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + status pgtype.Status + arg []byte + wantErr bool + }{ + { + name: "first", + status: pgtype.Present, + arg: []byte(`"(123.123,54.12)"`), + wantErr: false, + }, + { + name: "second", + status: pgtype.Undefined, + arg: []byte(`"(123.123,54.1sad2)"`), + wantErr: true, + }, + { + name: "third", + status: pgtype.Null, + arg: []byte("null"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &pgtype.Point{} + if err := dst.UnmarshalJSON(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if dst.Status != tt.status { + t.Errorf("Status mismatch: %v != %v", dst.Status, tt.status) + } + }) + } +} diff --git a/polygon_test.go b/polygon_test.go new file mode 100644 index 00000000..1a139444 --- /dev/null +++ b/polygon_test.go @@ -0,0 +1,89 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestPolygonTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "polygon", []interface{}{ + &pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, + Status: pgtype.Present, + }, + &pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, + Status: pgtype.Present, + }, + &pgtype.Polygon{Status: pgtype.Null}, + }) +} + +func TestPolygon_Set(t *testing.T) { + tests := []struct { + name string + arg interface{} + status pgtype.Status + wantErr bool + }{ + { + name: "string", + arg: "((3.14,1.678901234),(7.1,5.234),(5.0,3.234))", + status: pgtype.Present, + wantErr: false, + }, { + name: "[]float64", + arg: []float64{1, 2, 3.45, 6.78, 1.23, 4.567, 8.9, 1.0}, + status: pgtype.Present, + wantErr: false, + }, { + name: "[]Vec2", + arg: []pgtype.Vec2{{1, 2}, {2.3, 4.5}, {6.78, 9.123}}, + status: pgtype.Present, + wantErr: false, + }, { + name: "null", + arg: nil, + status: pgtype.Null, + wantErr: false, + }, { + name: "invalid_string_1", + arg: "((3.14,1.678901234),(7.1,5.234),(5.0,3.234x))", + status: pgtype.Undefined, + wantErr: true, + }, { + name: "invalid_string_2", + arg: "(3,4)", + status: pgtype.Undefined, + wantErr: true, + }, { + name: "invalid_[]float64", + arg: []float64{1, 2, 3.45, 6.78, 1.23, 4.567, 8.9}, + status: pgtype.Undefined, + wantErr: true, + }, { + name: "invalid_type", + arg: []int{1, 2, 3, 6}, + status: pgtype.Undefined, + wantErr: true, + }, { + name: "empty_[]float64", + arg: []float64{}, + status: pgtype.Null, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &pgtype.Polygon{} + if err := dst.Set(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) + } + if dst.Status != tt.status { + t.Errorf("Expected status: %v; got: %v", tt.status, dst.Status) + } + }) + } +} diff --git a/qchar_test.go b/qchar_test.go new file mode 100644 index 00000000..4b60339c --- /dev/null +++ b/qchar_test.go @@ -0,0 +1,143 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestQCharTranscode(t *testing.T) { + testutil.TestPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ + &pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, + &pgtype.QChar{Int: -1, Status: pgtype.Present}, + &pgtype.QChar{Int: 0, Status: pgtype.Present}, + &pgtype.QChar{Int: 1, Status: pgtype.Present}, + &pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, + &pgtype.QChar{Int: 0, Status: pgtype.Null}, + }, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestQCharSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.QChar + }{ + {source: int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.QChar + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestQCharAssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.QChar + dst interface{} + expected interface{} + }{ + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.QChar + dst interface{} + expected interface{} + }{ + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.QChar + dst interface{} + }{ + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &i16}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/range_test.go b/range_test.go new file mode 100644 index 00000000..9e16df59 --- /dev/null +++ b/range_test.go @@ -0,0 +1,177 @@ +package pgtype + +import ( + "bytes" + "testing" +) + +func TestParseUntypedTextRange(t *testing.T) { + tests := []struct { + src string + result UntypedTextRange + err error + }{ + { + src: `[1,2)`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[1,2]`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: `(1,3)`, + result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: ` [1,2) `, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[ foo , bar )`, + result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["foo","bar")`, + result: UntypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["","bar")`, + result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[f\"oo\,,b\\ar\))`, + result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `empty`, + result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedTextRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if r.Lower != tt.result.Lower { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if r.Upper != tt.result.Upper { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} + +func TestParseUntypedBinaryRange(t *testing.T) { + tests := []struct { + src []byte + result UntypedBinaryRange + err error + }{ + { + src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{1}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, + err: nil, + }, + { + src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{8, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{12, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{16, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{18, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{24}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedBinaryRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if bytes.Compare(r.Lower, tt.result.Lower) != 0 { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if bytes.Compare(r.Upper, tt.result.Upper) != 0 { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} diff --git a/record_test.go b/record_test.go new file mode 100644 index 00000000..240812a6 --- /dev/null +++ b/record_test.go @@ -0,0 +1,186 @@ +package pgtype_test + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" +) + +var recordTests = []struct { + sql string + expected pgtype.Record +}{ + { + sql: `select row()`, + expected: pgtype.Record{ + Fields: []pgtype.Value{}, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row(100.0::float4, 1.09::float4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Float4{Float: 100, Status: pgtype.Present}, + &pgtype.Float4{Float: 1.09, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row(null)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Unknown{Status: pgtype.Null}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select null::record`, + expected: pgtype.Record{ + Status: pgtype.Null, + }, + }, +} + +func TestRecordTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for i, tt := range recordTests { + psName := fmt.Sprintf("test%d", i) + _, err := conn.Prepare(context.Background(), psName, tt.sql) + if err != nil { + t.Fatal(err) + } + + t.Run(tt.sql, func(t *testing.T) { + var result pgtype.Record + if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { + t.Errorf("%v", err) + return + } + + if !reflect.DeepEqual(tt.expected, result) { + t.Errorf("expected %#v, got %#v", tt.expected, result) + } + }) + + } +} + +func TestRecordWithUnknownOID(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists floatrange; + +create type floatrange as range ( + subtype = float8, + subtype_diff = float8mi +);`) + if err != nil { + t.Fatal(err) + } + defer conn.Exec(context.Background(), "drop type floatrange") + + var result pgtype.Record + err = conn.QueryRow(context.Background(), "select row('foo'::text, floatrange(1, 10), 'bar'::text)").Scan(&result) + if err == nil { + t.Errorf("expected error but none") + } +} + +func TestRecordAssignTo(t *testing.T) { + var valueSlice []pgtype.Value + var interfaceSlice []interface{} + + simpleTests := []struct { + src pgtype.Record + dst interface{} + expected interface{} + }{ + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &valueSlice, + expected: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + }, + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &interfaceSlice, + expected: []interface{}{"foo", int32(42)}, + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &valueSlice, + expected: (([]pgtype.Value)(nil)), + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &interfaceSlice, + expected: (([]interface{})(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/testutil/testutil.go b/testutil/testutil.go new file mode 100644 index 00000000..e7b64b58 --- /dev/null +++ b/testutil/testutil.go @@ -0,0 +1,436 @@ +package testutil + +import ( + "context" + "database/sql" + "fmt" + "os" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" + _ "github.com/jackc/pgx/v4/stdlib" + _ "github.com/lib/pq" +) + +func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { + var sqlDriverName string + switch driverName { + case "github.com/lib/pq": + sqlDriverName = "postgres" + case "github.com/jackc/pgx/stdlib": + sqlDriverName = "pgx" + default: + t.Fatalf("Unknown driver %v", driverName) + } + + db, err := sql.Open(sqlDriverName, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + + return db +} + +func MustConnectPgx(t testing.TB) *pgx.Conn { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + + return conn +} + +func MustClose(t testing.TB, conn interface { + Close() error +}) { + err := conn.Close() + if err != nil { + t.Fatal(err) + } +} + +func MustCloseContext(t testing.TB, conn interface { + Close(context.Context) error +}) { + err := conn.Close(context.Background()) + if err != nil { + t.Fatal(err) + } +} + +type forceTextEncoder struct { + e pgtype.TextEncoder +} + +func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + return f.e.EncodeText(ci, buf) +} + +type forceBinaryEncoder struct { + e pgtype.BinaryEncoder +} + +func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + return f.e.EncodeBinary(ci, buf) +} + +func ForceEncoder(e interface{}, formatCode int16) interface{} { + switch formatCode { + case pgx.TextFormatCode: + if e, ok := e.(pgtype.TextEncoder); ok { + return forceTextEncoder{e: e} + } + case pgx.BinaryFormatCode: + if e, ok := e.(pgtype.BinaryEncoder); ok { + return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} + } + } + return nil +} + +func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { + TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } +} + +func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := MustConnectPgx(t) + defer MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, v := range values { + for _, paramFormat := range formats { + for _, resultFormat := range formats { + vEncoder := ForceEncoder(v, paramFormat.formatCode) + if vEncoder == nil { + t.Logf("Skipping Param %s Result %s: %#v does not implement %v for encoding", paramFormat.name, resultFormat.name, v, paramFormat.name) + continue + } + switch resultFormat.formatCode { + case pgx.TextFormatCode: + if _, ok := v.(pgtype.TextEncoder); !ok { + t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name) + continue + } + case pgx.BinaryFormatCode: + if _, ok := v.(pgtype.BinaryEncoder); !ok { + t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name) + continue + } + } + + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + + err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{resultFormat.formatCode}, vEncoder).Scan(result.Interface()) + if err != nil { + t.Errorf("Param %s Result %s %d: %v", paramFormat.name, resultFormat.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("Param %s Result %s %d: expected %v, got %v", paramFormat.name, resultFormat.name, i, derefV, result.Elem().Interface()) + } + } + } + } +} + +func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := ps.QueryRow(v).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} + +type NormalizeTest struct { + SQL string + Value interface{} +} + +func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) { + TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc) + } +} + +func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + conn := MustConnectPgx(t) + defer MustCloseContext(t, conn) + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, tt := range tests { + for _, fc := range formats { + psName := fmt.Sprintf("test%d", i) + _, err := conn.Prepare(context.Background(), psName, tt.SQL) + if err != nil { + t.Fatal(err) + } + + queryResultFormats := pgx.QueryResultFormats{fc.formatCode} + if ForceEncoder(tt.Value, fc.formatCode) == nil { + t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name) + continue + } + // Derefence value if it is a pointer + derefV := tt.Value + refVal := reflect.ValueOf(tt.Value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = conn.QueryRow(context.Background(), psName, queryResultFormats).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} + +func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + for i, tt := range tests { + ps, err := conn.Prepare(tt.SQL) + if err != nil { + t.Errorf("%d. %v", i, err) + continue + } + + // Derefence value if it is a pointer + derefV := tt.Value + refVal := reflect.ValueOf(tt.Value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = ps.QueryRow().Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} + +func TestGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { + TestPgxGoZeroToNullConversion(t, pgTypeName, zero) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLGoZeroToNullConversion(t, driverName, pgTypeName, zero) + } +} + +func TestNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) { + TestPgxNullToGoZeroConversion(t, pgTypeName, zero) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLNullToGoZeroConversion(t, driverName, pgTypeName, zero) + } +} + +func TestPgxGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { + conn := MustConnectPgx(t) + defer MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s is null", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, paramFormat := range formats { + vEncoder := ForceEncoder(zero, paramFormat.formatCode) + if vEncoder == nil { + t.Logf("Skipping Param %s: %#v does not implement %v for encoding", paramFormat.name, zero, paramFormat.name) + continue + } + + var result bool + err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(&result) + if err != nil { + t.Errorf("Param %s: %v", paramFormat.name, err) + } + + if !result { + t.Errorf("Param %s: did not convert zero to null", paramFormat.name) + } + } +} + +func TestPgxNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) { + conn := MustConnectPgx(t) + defer MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select null::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, resultFormat := range formats { + + switch resultFormat.formatCode { + case pgx.TextFormatCode: + if _, ok := zero.(pgtype.TextEncoder); !ok { + t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name) + continue + } + case pgx.BinaryFormatCode: + if _, ok := zero.(pgtype.BinaryEncoder); !ok { + t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name) + continue + } + } + + // Derefence value if it is a pointer + derefZero := zero + refVal := reflect.ValueOf(zero) + if refVal.Kind() == reflect.Ptr { + derefZero = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefZero)) + + err := conn.QueryRow(context.Background(), "test").Scan(result.Interface()) + if err != nil { + t.Errorf("Result %s: %v", resultFormat.name, err) + } + + if !reflect.DeepEqual(result.Elem().Interface(), derefZero) { + t.Errorf("Result %s: did not convert null to zero", resultFormat.name) + } + } +} + +func TestDatabaseSQLGoZeroToNullConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s is null", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + var result bool + err = ps.QueryRow(zero).Scan(&result) + if err != nil { + t.Errorf("%v %v", driverName, err) + } + + if !result { + t.Errorf("%v: did not convert zero to null", driverName) + } +} + +func TestDatabaseSQLNullToGoZeroConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select null::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + // Derefence value if it is a pointer + derefZero := zero + refVal := reflect.ValueOf(zero) + if refVal.Kind() == reflect.Ptr { + derefZero = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefZero)) + + err = ps.QueryRow().Scan(result.Interface()) + if err != nil { + t.Errorf("%v %v", driverName, err) + } + + if !reflect.DeepEqual(result.Elem().Interface(), derefZero) { + t.Errorf("%s: did not convert null to zero", driverName) + } +} diff --git a/text_array_test.go b/text_array_test.go new file mode 100644 index 00000000..a5d050f6 --- /dev/null +++ b/text_array_test.go @@ -0,0 +1,294 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// https://github.com/jackc/pgtype/issues/78 +func TestTextArrayDecodeTextNull(t *testing.T) { + textArray := &pgtype.TextArray{} + err := textArray.DecodeText(nil, []byte(`{abc,"NULL",NULL,def}`)) + require.NoError(t, err) + require.Len(t, textArray.Elements, 4) + assert.Equal(t, pgtype.Present, textArray.Elements[1].Status) + assert.Equal(t, pgtype.Null, textArray.Elements[2].Status) +} + +func TestTextArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "text[]", []interface{}{ + &pgtype.TextArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TextArray{Status: pgtype.Null}, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "bar ", Status: pgtype.Present}, + {String: "NuLL", Status: pgtype.Present}, + {String: `wow"quz\`, Status: pgtype.Present}, + {String: "", Status: pgtype.Present}, + {Status: pgtype.Null}, + {String: "null", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "quz", Status: pgtype.Present}, + {String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestTextArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TextArray + }{ + { + source: []string{"foo"}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.TextArray{Status: pgtype.Null}, + }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TextArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTextArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string + + simpleTests := []struct { + src pgtype.TextArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.TextArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + { + src: pgtype.TextArray{Status: pgtype.Present}, + dst: &stringSlice, + expected: []string{}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TextArray + dst interface{} + }{ + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/text_test.go b/text_test.go new file mode 100644 index 00000000..cca3a05d --- /dev/null +++ b/text_test.go @@ -0,0 +1,164 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestTextTranscode(t *testing.T) { + for _, pgTypeName := range []string{"text", "varchar"} { + testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ + &pgtype.Text{String: "", Status: pgtype.Present}, + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Text{Status: pgtype.Null}, + }) + } +} + +func TestTextSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Text + }{ + {source: "foo", result: pgtype.Text{String: "foo", Status: pgtype.Present}}, + {source: _string("bar"), result: pgtype.Text{String: "bar", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.Text{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.Text + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestTextAssignTo(t *testing.T) { + var s string + var ps *string + + stringTests := []struct { + src pgtype.Text + dst interface{} + expected interface{} + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, + {src: pgtype.Text{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range stringTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + var buf []byte + + bytesTests := []struct { + src pgtype.Text + dst *[]byte + expected []byte + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &buf, expected: []byte("foo")}, + {src: pgtype.Text{Status: pgtype.Null}, dst: &buf, expected: nil}, + } + + for i, tt := range bytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(*tt.dst, tt.expected) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Text + dst interface{} + expected interface{} + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Text + dst interface{} + }{ + {src: pgtype.Text{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} + +func TestTextMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Text + result string + }{ + {source: pgtype.Text{String: "", Status: pgtype.Null}, result: "null"}, + {source: pgtype.Text{String: "a", Status: pgtype.Present}, result: "\"a\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestTextUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Text + }{ + {source: "null", result: pgtype.Text{String: "", Status: pgtype.Null}}, + {source: "\"a\"", result: pgtype.Text{String: "a", Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Text + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/tid_test.go b/tid_test.go new file mode 100644 index 00000000..818be8af --- /dev/null +++ b/tid_test.go @@ -0,0 +1,63 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestTIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "tid", []interface{}{ + &pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, + &pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, + &pgtype.TID{Status: pgtype.Null}, + }) +} + +func TestTIDAssignTo(t *testing.T) { + var s string + var sp *string + + simpleTests := []struct { + src pgtype.TID + dst interface{} + expected interface{} + }{ + {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, dst: &s, expected: "(42,43)"}, + {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, dst: &s, expected: "(4294967295,65535)"}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.TID + dst interface{} + expected interface{} + }{ + {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, dst: &sp, expected: "(42,43)"}, + {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, dst: &sp, expected: "(4294967295,65535)"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} + diff --git a/time_test.go b/time_test.go new file mode 100644 index 00000000..0af42b1e --- /dev/null +++ b/time_test.go @@ -0,0 +1,131 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestTimeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "time", []interface{}{ + &pgtype.Time{Microseconds: 0, Status: pgtype.Present}, + &pgtype.Time{Microseconds: 1, Status: pgtype.Present}, + &pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}, + &pgtype.Time{Status: pgtype.Null}, + }) +} + +// Test for transcoding 24:00:00 separately as github.com/lib/pq doesn't seem to support it. +func TestTimeTranscode24HH(t *testing.T) { + pgTypeName := "time" + values := []interface{}{ + &pgtype.Time{Microseconds: 86400000000, Status: pgtype.Present}, + } + + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) +} + +func TestTimeSet(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Time + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 0, Status: pgtype.Present}}, + {source: time.Date(1900, 1, 1, 1, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 3600000000, Status: pgtype.Present}}, + {source: time.Date(1900, 1, 1, 0, 1, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 60000000, Status: pgtype.Present}}, + {source: time.Date(1900, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Time{Microseconds: 1000000, Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 1, time.UTC), result: pgtype.Time{Microseconds: 0, Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 1000, time.UTC), result: pgtype.Time{Microseconds: 1, Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 23, 59, 59, 999999999, time.UTC), result: pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}}, + {source: time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local), result: pgtype.Time{Microseconds: 2, Status: pgtype.Present}}, + {source: func(t time.Time) *time.Time { return &t }(time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local)), result: pgtype.Time{Microseconds: 2, Status: pgtype.Present}}, + {source: nil, result: pgtype.Time{Status: pgtype.Null}}, + {source: (*time.Time)(nil), result: pgtype.Time{Status: pgtype.Null}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 3000, time.UTC)), result: pgtype.Time{Microseconds: 3, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Time + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimeAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Time + dst interface{} + expected interface{} + }{ + {src: pgtype.Time{Microseconds: 0, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 3600000000, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 60000000, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 1, 0, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 1000000, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 1, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 1000, time.UTC)}, + {src: pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 23, 59, 59, 999999000, time.UTC)}, + {src: pgtype.Time{Microseconds: 0, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Time + dst interface{} + expected interface{} + }{ + {src: pgtype.Time{Microseconds: 0, Status: pgtype.Present}, dst: &ptim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Time + dst interface{} + }{ + {src: pgtype.Time{Microseconds: 86400000000, Status: pgtype.Present}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/timestamp_array_test.go b/timestamp_array_test.go new file mode 100644 index 00000000..54d15b24 --- /dev/null +++ b/timestamp_array_test.go @@ -0,0 +1,307 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestTimestampArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp[]", []interface{}{ + &pgtype.TimestampArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestampArray{Status: pgtype.Null}, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Status: pgtype.Null}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }, func(a, b interface{}) bool { + ata := a.(pgtype.TimestampArray) + bta := b.(pgtype.TimestampArray) + + if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { + return false + } + + for i := range ata.Elements { + ae, be := ata.Elements[i], bta.Elements[i] + if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { + return false + } + } + + return true + }) +} + +func TestTimestampArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TimestampArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.TimestampArray{Status: pgtype.Null}, + }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TimestampArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestampArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time + + simpleTests := []struct { + src pgtype.TimestampArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.TimestampArray{Status: pgtype.Null}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + { + src: pgtype.TimestampArray{Status: pgtype.Present}, + dst: &timeSlice, + expected: []time.Time{}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TimestampArray + dst interface{} + }{ + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeSlice, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/timestamp_test.go b/timestamp_test.go new file mode 100644 index 00000000..74cb1221 --- /dev/null +++ b/timestamp_test.go @@ -0,0 +1,178 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestTimestampTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ + &pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Status: pgtype.Null}, + &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Timestamp) + bt := b.(pgtype.Timestamp) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + }) +} + +func TestTimestampNanosecondsTruncated(t *testing.T) { + tests := []struct { + input time.Time + expected time.Time + }{ + {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.UTC), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC)}, + {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.UTC), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC)}, + } + for i, tt := range tests { + { + ts := pgtype.Timestamp{Time: tt.input, Status: pgtype.Present} + buf, err := ts.EncodeText(nil, nil) + if err != nil { + t.Errorf("%d. EncodeText failed - %v", i, err) + } + + ts.DecodeText(nil, buf) + if err != nil { + t.Errorf("%d. DecodeText failed - %v", i, err) + } + + if !(ts.Status == pgtype.Present && ts.Time.Equal(tt.expected)) { + t.Errorf("%d. EncodeText did not truncate nanoseconds", i) + } + } + + { + ts := pgtype.Timestamp{Time: tt.input, Status: pgtype.Present} + buf, err := ts.EncodeBinary(nil, nil) + if err != nil { + t.Errorf("%d. EncodeBinary failed - %v", i, err) + } + + ts.DecodeBinary(nil, buf) + if err != nil { + t.Errorf("%d. DecodeBinary failed - %v", i, err) + } + + if !(ts.Status == pgtype.Present && ts.Time.Equal(tt.expected)) { + t.Errorf("%d. EncodeBinary did not truncate nanoseconds", i) + } + } + } +} + +// https://github.com/jackc/pgtype/issues/74 +func TestTimestampDecodeTextInvalid(t *testing.T) { + tstz := &pgtype.Timestamp{} + err := tstz.DecodeText(nil, []byte(`eeeee`)) + require.Error(t, err) +} + +func TestTimestampSet(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Timestamp + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: pgtype.Infinity, result: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, + {source: pgtype.NegativeInfinity, result: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Timestamp + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestampAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Timestamp + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, + {src: pgtype.Timestamp{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Timestamp + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Timestamp + dst interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/timestamptz_array_test.go b/timestamptz_array_test.go new file mode 100644 index 00000000..9856e4e7 --- /dev/null +++ b/timestamptz_array_test.go @@ -0,0 +1,343 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestTimestamptzArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz[]", []interface{}{ + &pgtype.TimestamptzArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestamptzArray{Status: pgtype.Null}, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Status: pgtype.Null}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }, func(a, b interface{}) bool { + ata := a.(pgtype.TimestamptzArray) + bta := b.(pgtype.TimestamptzArray) + + if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { + return false + } + + for i := range ata.Elements { + ae, be := ata.Elements[i], bta.Elements[i] + if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { + return false + } + } + + return true + }) +} + +func TestTimestamptzArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TimestamptzArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.TimestamptzArray{Status: pgtype.Null}, + }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TimestamptzArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestamptzArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time + + simpleTests := []struct { + src pgtype.TimestamptzArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.TimestamptzArray{Status: pgtype.Null}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + { + src: pgtype.TimestamptzArray{Status: pgtype.Present}, + dst: &timeSlice, + expected: []time.Time{}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TimestamptzArray + dst interface{} + }{ + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeSlice, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/timestamptz_test.go b/timestamptz_test.go new file mode 100644 index 00000000..769c9239 --- /dev/null +++ b/timestamptz_test.go @@ -0,0 +1,224 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestTimestamptzTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ + &pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Status: pgtype.Null}, + &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Timestamptz) + bt := b.(pgtype.Timestamptz) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + }) +} + +func TestTimestamptzNanosecondsTruncated(t *testing.T) { + tests := []struct { + input time.Time + expected time.Time + }{ + {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.Local), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local)}, + {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.Local), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local)}, + } + for i, tt := range tests { + { + tstz := pgtype.Timestamptz{Time: tt.input, Status: pgtype.Present} + buf, err := tstz.EncodeText(nil, nil) + if err != nil { + t.Errorf("%d. EncodeText failed - %v", i, err) + } + + tstz.DecodeText(nil, buf) + if err != nil { + t.Errorf("%d. DecodeText failed - %v", i, err) + } + + if !(tstz.Status == pgtype.Present && tstz.Time.Equal(tt.expected)) { + t.Errorf("%d. EncodeText did not truncate nanoseconds", i) + } + } + + { + tstz := pgtype.Timestamptz{Time: tt.input, Status: pgtype.Present} + buf, err := tstz.EncodeBinary(nil, nil) + if err != nil { + t.Errorf("%d. EncodeBinary failed - %v", i, err) + } + + tstz.DecodeBinary(nil, buf) + if err != nil { + t.Errorf("%d. DecodeBinary failed - %v", i, err) + } + + if !(tstz.Status == pgtype.Present && tstz.Time.Equal(tt.expected)) { + t.Errorf("%d. EncodeBinary did not truncate nanoseconds", i) + } + } + } +} + +// https://github.com/jackc/pgtype/issues/74 +func TestTimestamptzDecodeTextInvalid(t *testing.T) { + tstz := &pgtype.Timestamptz{} + err := tstz.DecodeText(nil, []byte(`eeeee`)) + require.Error(t, err) +} + +func TestTimestamptzSet(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Timestamptz + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: pgtype.Infinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, + {source: pgtype.NegativeInfinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Timestamptz + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestamptzAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Timestamptz + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Timestamptz{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Timestamptz + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Timestamptz + dst interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} + +func TestTimestamptzMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Timestamptz + result string + }{ + {source: pgtype.Timestamptz{Status: pgtype.Null}, result: "null"}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29T10:05:45-06:00\""}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29T10:05:45.555-06:00\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, result: "\"infinity\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, result: "\"-infinity\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestTimestamptzUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Timestamptz + }{ + {source: "null", result: pgtype.Timestamptz{Status: pgtype.Null}}, + {source: "\"2012-03-29T10:05:45-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, + {source: "\"2012-03-29T10:05:45.555-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, + {source: "\"infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, + {source: "\"-infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Timestamptz + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !r.Time.Equal(tt.result.Time) || r.Status != tt.result.Status || r.InfinityModifier != tt.result.InfinityModifier { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/tsrange_test.go b/tsrange_test.go new file mode 100644 index 00000000..1be0c7d2 --- /dev/null +++ b/tsrange_test.go @@ -0,0 +1,41 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestTsrangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ + &pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Tsrange{ + Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamp{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Tsrange{ + Lower: pgtype.Timestamp{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Tsrange{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Tsrange) + b := bb.(pgtype.Tsrange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} diff --git a/tstzrange_test.go b/tstzrange_test.go new file mode 100644 index 00000000..f8e2c2c5 --- /dev/null +++ b/tstzrange_test.go @@ -0,0 +1,49 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestTstzrangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ + &pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamptz{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Tstzrange{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Tstzrange) + b := bb.(pgtype.Tstzrange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} + +// https://github.com/jackc/pgtype/issues/74 +func TestTstzRangeDecodeTextInvalid(t *testing.T) { + tstzrange := &pgtype.Tstzrange{} + err := tstzrange.DecodeText(nil, []byte(`[eeee,)`)) + require.Error(t, err) +} diff --git a/uuid_array_test.go b/uuid_array_test.go new file mode 100644 index 00000000..7d822e7a --- /dev/null +++ b/uuid_array_test.go @@ -0,0 +1,368 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestUUIDArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid[]", []interface{}{ + &pgtype.UUIDArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.UUIDArray{Status: pgtype.Null}, + &pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestUUIDArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.UUIDArray + }{ + { + source: nil, + result: pgtype.UUIDArray{Status: pgtype.Null}, + }, + { + source: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][16]byte{}, + result: pgtype.UUIDArray{Status: pgtype.Present}, + }, + { + source: ([][16]byte)(nil), + result: pgtype.UUIDArray{Status: pgtype.Null}, + }, + { + source: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][]byte{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + nil, + {32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, + }, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 4}}, + Status: pgtype.Present}, + }, + { + source: [][]byte{}, + result: pgtype.UUIDArray{Status: pgtype.Present}, + }, + { + source: ([][]byte)(nil), + result: pgtype.UUIDArray{Status: pgtype.Null}, + }, + { + source: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []string{}, + result: pgtype.UUIDArray{Status: pgtype.Present}, + }, + { + source: ([]string)(nil), + result: pgtype.UUIDArray{Status: pgtype.Null}, + }, + { + source: [][][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.UUIDArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUUIDArrayAssignTo(t *testing.T) { + var byteArraySlice [][16]byte + var byteSliceSlice [][]byte + var stringSlice []string + var byteSlice []byte + var byteArraySliceDim2 [][][16]byte + var stringSliceDim4 [][][][]string + var byteArrayDim2 [2][1][16]byte + var stringArrayDim4 [2][1][1][3]string + + simpleTests := []struct { + src pgtype.UUIDArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &byteArraySlice, + expected: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + }, + { + src: pgtype.UUIDArray{Status: pgtype.Null}, + dst: &byteArraySlice, + expected: ([][16]byte)(nil), + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &byteSliceSlice, + expected: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + }, + { + src: pgtype.UUIDArray{Status: pgtype.Null}, + dst: &byteSliceSlice, + expected: ([][]byte)(nil), + }, + { + src: pgtype.UUIDArray{Status: pgtype.Present}, + dst: &byteSlice, + expected: []byte{}, + }, + { + src: pgtype.UUIDArray{Status: pgtype.Present}, + dst: &stringSlice, + expected: []string{}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, + }, + { + src: pgtype.UUIDArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: ([]string)(nil), + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteArraySliceDim2, + expected: [][][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteArrayDim2, + expected: [2][1][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/uuid_test.go b/uuid_test.go new file mode 100644 index 00000000..5a93ea8d --- /dev/null +++ b/uuid_test.go @@ -0,0 +1,245 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestUUIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + &pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &pgtype.UUID{Status: pgtype.Null}, + }) +} + +type SomeUUIDWrapper struct { + SomeUUIDType +} + +type SomeUUIDType [16]byte + +func TestUUIDSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.UUID + }{ + { + source: nil, + result: pgtype.UUID{Status: pgtype.Null}, + }, + { + source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: SomeUUIDType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: ([]byte)(nil), + result: pgtype.UUID{Status: pgtype.Null}, + }, + { + source: "00010203-0405-0607-0809-0a0b0c0d0e0f", + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: "000102030405060708090a0b0c0d0e0f", + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.UUID + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUUIDAssignTo(t *testing.T) { + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst [16]byte + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst []byte + expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare(dst, expected) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst SomeUUIDType + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst string + expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst SomeUUIDWrapper + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst.SomeUUIDType != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } +} + +func TestUUID_MarshalJSON(t *testing.T) { + tests := []struct { + name string + src pgtype.UUID + want []byte + wantErr bool + }{ + { + name: "first", + src: pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Status: pgtype.Present, + }, + want: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), + wantErr: false, + }, + { + name: "second", + src: pgtype.UUID{ + Bytes: [16]byte{}, + Status: pgtype.Undefined, + }, + want: nil, + wantErr: true, + }, + { + name: "third", + src: pgtype.UUID{ + Bytes: [16]byte{}, + Status: pgtype.Null, + }, + want: []byte("null"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.src.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUUID_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + want *pgtype.UUID + src []byte + wantErr bool + }{ + { + name: "first", + want: &pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Status: pgtype.Present, + }, + src: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), + wantErr: false, + }, + { + name: "second", + want: &pgtype.UUID{ + Bytes: [16]byte{}, + Status: pgtype.Null, + }, + src: []byte("null"), + wantErr: false, + }, + { + name: "third", + want: &pgtype.UUID{ + Bytes: [16]byte{}, + Status: pgtype.Undefined, + }, + src: []byte("1d485a7a-6d18-4599-8c6c-34425616887a"), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &pgtype.UUID{} + if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/varbit_test.go b/varbit_test.go new file mode 100644 index 00000000..3c5aea1e --- /dev/null +++ b/varbit_test.go @@ -0,0 +1,26 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestVarbitTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "varbit", []interface{}{ + &pgtype.Varbit{Bytes: []byte{}, Len: 0, Status: pgtype.Present}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Status: pgtype.Present}, + &pgtype.Varbit{Status: pgtype.Null}, + }) +} + +func TestVarbitNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select B'111111111'", + Value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, + }, + }) +} diff --git a/varchar_array_test.go b/varchar_array_test.go new file mode 100644 index 00000000..5fb7326d --- /dev/null +++ b/varchar_array_test.go @@ -0,0 +1,282 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestVarcharArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "varchar[]", []interface{}{ + &pgtype.VarcharArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{Status: pgtype.Null}, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "bar ", Status: pgtype.Present}, + {String: "NuLL", Status: pgtype.Present}, + {String: `wow"quz\`, Status: pgtype.Present}, + {String: "", Status: pgtype.Present}, + {Status: pgtype.Null}, + {String: "null", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "quz", Status: pgtype.Present}, + {String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestVarcharArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.VarcharArray + }{ + { + source: []string{"foo"}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.VarcharArray{Status: pgtype.Null}, + }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.VarcharArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestVarcharArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string + + simpleTests := []struct { + src pgtype.VarcharArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.VarcharArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + { + src: pgtype.VarcharArray{Status: pgtype.Present}, + dst: &stringSlice, + expected: []string{}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.VarcharArray + dst interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/xid_test.go b/xid_test.go new file mode 100644 index 00000000..563ce96e --- /dev/null +++ b/xid_test.go @@ -0,0 +1,105 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestXIDTranscode(t *testing.T) { + pgTypeName := "xid" + values := []interface{}{ + &pgtype.XID{Uint: 42, Status: pgtype.Present}, + &pgtype.XID{Status: pgtype.Null}, + } + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } +} + +func TestXIDSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.XID + }{ + {source: uint32(1), result: pgtype.XID{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.XID + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestXIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.XID + dst interface{} + expected interface{} + }{ + {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.XID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.XID + dst interface{} + expected interface{} + }{ + {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.XID + dst interface{} + }{ + {src: pgtype.XID{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/zeronull/int2_test.go b/zeronull/int2_test.go new file mode 100644 index 00000000..2dcb4e79 --- /dev/null +++ b/zeronull/int2_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestInt2Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ + (zeronull.Int2)(1), + (zeronull.Int2)(0), + }) +} + +func TestInt2ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int2", (zeronull.Int2)(0)) +} + +func TestInt2ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int2", (zeronull.Int2)(0)) +} diff --git a/zeronull/int4_test.go b/zeronull/int4_test.go new file mode 100644 index 00000000..309e4125 --- /dev/null +++ b/zeronull/int4_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestInt4Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ + (zeronull.Int4)(1), + (zeronull.Int4)(0), + }) +} + +func TestInt4ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int4", (zeronull.Int4)(0)) +} + +func TestInt4ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int4", (zeronull.Int4)(0)) +} diff --git a/zeronull/int8_test.go b/zeronull/int8_test.go new file mode 100644 index 00000000..ae80bc0a --- /dev/null +++ b/zeronull/int8_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestInt8Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ + (zeronull.Int8)(1), + (zeronull.Int8)(0), + }) +} + +func TestInt8ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int8", (zeronull.Int8)(0)) +} + +func TestInt8ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int8", (zeronull.Int8)(0)) +} diff --git a/zeronull/text_test.go b/zeronull/text_test.go new file mode 100644 index 00000000..f08a0d2a --- /dev/null +++ b/zeronull/text_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestTextTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "text", []interface{}{ + (zeronull.Text)("foo"), + (zeronull.Text)(""), + }) +} + +func TestTextConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "text", (zeronull.Text)("")) +} + +func TestTextConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "text", (zeronull.Text)("")) +} diff --git a/zeronull/timestamp_test.go b/zeronull/timestamp_test.go new file mode 100644 index 00000000..ec96ff07 --- /dev/null +++ b/zeronull/timestamp_test.go @@ -0,0 +1,29 @@ +package zeronull_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestTimestampTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ + (zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + (zeronull.Timestamp)(time.Time{}), + }, func(a, b interface{}) bool { + at := a.(zeronull.Timestamp) + bt := b.(zeronull.Timestamp) + + return time.Time(at).Equal(time.Time(bt)) + }) +} + +func TestTimestampConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "timestamp", (zeronull.Timestamp)(time.Time{})) +} + +func TestTimestampConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "timestamp", (zeronull.Timestamp)(time.Time{})) +} diff --git a/zeronull/timestamptz_test.go b/zeronull/timestamptz_test.go new file mode 100644 index 00000000..3a401c49 --- /dev/null +++ b/zeronull/timestamptz_test.go @@ -0,0 +1,29 @@ +package zeronull_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestTimestamptzTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ + (zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + (zeronull.Timestamptz)(time.Time{}), + }, func(a, b interface{}) bool { + at := a.(zeronull.Timestamptz) + bt := b.(zeronull.Timestamptz) + + return time.Time(at).Equal(time.Time(bt)) + }) +} + +func TestTimestamptzConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "timestamptz", (zeronull.Timestamptz)(time.Time{})) +} + +func TestTimestamptzConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "timestamptz", (zeronull.Timestamptz)(time.Time{})) +} diff --git a/zeronull/uuid_test.go b/zeronull/uuid_test.go new file mode 100644 index 00000000..162bdf1f --- /dev/null +++ b/zeronull/uuid_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestUUIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + (*zeronull.UUID)(&[16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), + (*zeronull.UUID)(&[16]byte{}), + }) +} + +func TestUUIDConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "uuid", (*zeronull.UUID)(&[16]byte{})) +} + +func TestUUIDConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "uuid", (*zeronull.UUID)(&[16]byte{})) +} From 377eed5d2f44de254357e923b1a59fef7cf02089 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Jul 2021 10:48:07 -0500 Subject: [PATCH 0702/1158] Cleaning go.sum --- go.mod | 4 +- go.sum | 331 ++------------------------------------------------------- 2 files changed, 12 insertions(+), 323 deletions(-) diff --git a/go.mod b/go.mod index 29e6f628..63bae879 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,9 @@ go 1.13 require ( github.com/gofrs/uuid v4.0.0+incompatible - github.com/jackc/pgconn v1.9.0 + github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530 github.com/jackc/pgio v1.0.0 - github.com/jackc/pgx/v4 v4.12.0 + github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c github.com/lib/pq v1.10.2 github.com/shopspring/decimal v1.2.0 github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index e49ce26f..8f2d760e 100644 --- a/go.sum +++ b/go.sum @@ -1,127 +1,20 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= -github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= -github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= -github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= -github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= -github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= -github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= -github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= -github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= -github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= -github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= -github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= -github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= -github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= -github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4dUb/I5gc9Hdhagfvm9+RyrPryS/auMzxE= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= -github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= -github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= -github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= -github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= -github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= -github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= -github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4safvEdbitLhGGK48rN6g= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= -github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= -github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-kit/kit v0.10.0/go.mod h1:xUsJbQ/Fp4kEt7AFgCuvyX4a71u8h9jB8tj/ORgOZ7o= -github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= -github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= -github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= -github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= -github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= -github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= -github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= -github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= -github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= -github.com/hashicorp/consul/api v1.3.0/go.mod h1:MmDNSzIMUjNpY/mQ398R4bk2FnqQLoPndWW5VkKPlCE= -github.com/hashicorp/consul/sdk v0.3.0/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= -github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= -github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= -github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= -github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= -github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= -github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= -github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= -github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= -github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= -github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= -github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= -github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= -github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= @@ -130,18 +23,16 @@ github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgO github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= -github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= -github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= -github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= -github.com/jackc/pgconn v1.8.1/go.mod h1:JV6m6b6jhjdmzchES0drzCcYcAHS1OPD5xu3OZ/lE2g= -github.com/jackc/pgconn v1.9.0 h1:gqibKSTJup/ahCsNKyMZAniPuZEfIqfXFc8FOWVYR+Q= github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530 h1:dUJ578zuPEsXjtzOfEF0q9zDAfljJ9oFnTHcQaNkccw= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= -github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd h1:eDErF6V/JPJON/B7s68BxwHgfmyOntHJQ8IOaz0x4R8= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= @@ -150,47 +41,26 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= -github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= -github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= -github.com/jackc/pgtype v1.3.1-0.20200606141011-f6355165a91c/go.mod h1:cvk9Bgu/VzJ9/lxTO5R5sf80p0DiucVtN7ZxvaC4GmQ= -github.com/jackc/pgtype v1.7.0/go.mod h1:ZnHF+rMePVqDKaOfJVI4Q8IVvAQMryDlDkZnKOI75BE= -github.com/jackc/pgtype v1.8.0/go.mod h1:PqDKcEBtllAtk/2p6z6SHdXW5UB+MhE75tUol2OKexE= +github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= -github.com/jackc/pgx/v4 v4.5.0/go.mod h1:EpAKPLdnTorwmPUUsqrPxy5fphV18j9q3wrfRXgo+kA= -github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9/go.mod h1:t3/cdRQl6fOLDxqtlyhe9UWgfIi9R8+8v8GKV5TRA/o= -github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904/go.mod h1:ZDaNWkt9sW1JMiNn0kdYBaLelIhw7Pg4qd+Vk6tw7Hg= -github.com/jackc/pgx/v4 v4.11.0/go.mod h1:i62xJgdrtVDsnL3U8ekyrQXEwGNTRoG7/8r+CIdYfcc= -github.com/jackc/pgx/v4 v4.12.0 h1:xiP3TdnkwyslWNp77yE5XAPfxAsU9RMFDe0c1SwN8h4= -github.com/jackc/pgx/v4 v4.12.0/go.mod h1:fE547h6VulLPA3kySjfnSG/e2D861g/50JlVUa/ub60= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c h1:Dznn52SgVIVst9UyOT9brctYUgxs+CvVfPaC3jKrA50= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= -github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= -github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= -github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= -github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= -github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -200,117 +70,27 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= -github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= -github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= -github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= -github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= -github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= -github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= -github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= -github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= -github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg= -github.com/nats-io/jwt v0.3.2/go.mod h1:/euKqTS1ZD+zzjYrY7pseZrTtWQSjujC7xjPc8wL6eU= -github.com/nats-io/nats-server/v2 v2.1.2/go.mod h1:Afk+wRZqkMQs/p45uXdrVLuab3gwv3Z8C4HTBu8GD/k= -github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzEE/Zbp4w= -github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= -github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= -github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= -github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs= -github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= -github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/op/go-logging v0.0.0-20160315200505-970db520ece7/go.mod h1:HzydrMdWErDVzsI23lYNej1Htcns9BCg93Dk0bBINWk= -github.com/opentracing-contrib/go-observer v0.0.0-20170622124052-a52f23424492/go.mod h1:Ngi6UdF0k5OKD5t5wlmGhe/EDKPoUM3BXZSSfIuJbis= -github.com/opentracing/basictracer-go v1.0.0/go.mod h1:QfBfYuafItcjQuMwinw9GhYKwFXS9KnPs5lxoYwgW74= -github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= -github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= -github.com/openzipkin-contrib/zipkin-go-opentracing v0.4.5/go.mod h1:/wsWhb9smxSfWAKL3wpBW7V8scJMt8N8gnaMCS9E/cA= -github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= -github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= -github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= -github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= -github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= -github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= -github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac= -github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= -github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= -github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs= -github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= -github.com/prometheus/client_golang v1.3.0/go.mod h1:hJaj2vgQTGQmVCsAACORcieXFeDPbaTKGT+JTgUa3og= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.1.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= -github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= -github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= -github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= -github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= -github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= -github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= -github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= -github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= -github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= -github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= -github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= -github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= -github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= @@ -320,16 +100,7 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= -github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= -github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= -go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= -go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= -go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= -go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= @@ -341,73 +112,33 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= -golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -420,18 +151,9 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -441,46 +163,13 @@ golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190530194941-fb225487d101/go.mod h1:z3L6/3dTEVtUr6QSP8miRzeRqwQOioJ9I66odjN4I7s= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM= -google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= -google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -google.golang.org/grpc v1.22.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/gcfg.v1 v1.2.3/go.mod h1:yesOnuUOFQAhST5vPY4nbZsb/huCgGGXlipJsBn0b3o= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= -gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= -gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= -sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU= From e26c6b4e3d1c1d25a086e5da106165f1819d62e2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Jul 2021 10:50:22 -0500 Subject: [PATCH 0703/1158] Release v1.8.1 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c8514e3..64d96fa0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.8.1 (July 24, 2021) + +* Cleaned up Go module dependency chain + # 1.8.0 (July 10, 2021) * Maintain host bits for inet types (Cameron Daniel) From 53f5fed36c570f0b5c98d6ec2415658c7b9bd11c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Jul 2021 10:52:26 -0500 Subject: [PATCH 0704/1158] Release v1.10.0 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c496ea30..45c02f1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.10.0 (July 24, 2021) + +* net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned. + # 1.9.0 (July 10, 2021) * pgconn.Timeout only is true for errors originating in pgconn (Michael Darr) From 6bda09691dca413fe365dff0960d6cf9fc57071d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Jul 2021 11:06:03 -0500 Subject: [PATCH 0705/1158] Fix hstore binary null decoding Bug was advancing the read pointer by the length of the value even if it was a NULL value. Since NULL is indicated by a -1 length it actually decremented the read pointer. --- hstore.go | 2 +- hstore_test.go | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/hstore.go b/hstore.go index 18b413c6..c2de1ccf 100644 --- a/hstore.go +++ b/hstore.go @@ -142,8 +142,8 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { var valueBuf []byte if valueLen >= 0 { valueBuf = src[rp : rp+valueLen] + rp += valueLen } - rp += valueLen var value Text err := value.DecodeBinary(ci, valueBuf) diff --git a/hstore_test.go b/hstore_test.go index dce8baf2..48b4b42e 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -21,6 +21,10 @@ func TestHstoreTranscode(t *testing.T) { &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, &pgtype.Hstore{Map: map[string]pgtype.Text{"": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{ + Map: map[string]pgtype.Text{"a": text("a"), "b": {Status: pgtype.Null}, "c": text("c"), "d": {Status: pgtype.Null}, "e": text("e")}, + Status: pgtype.Present, + }, &pgtype.Hstore{Status: pgtype.Null}, } From db84905b7f4388608bbc5362c4bf9b9423b4cdd6 Mon Sep 17 00:00:00 2001 From: Eli Treuherz Date: Mon, 2 Aug 2021 09:26:24 +0100 Subject: [PATCH 0706/1158] Add NullDecimal to shopspring-numeric The shopspring/decimal package provides a NullDecimal struct intended for use with nullable SQL NUMERICs and numbers. It has Scanner and Valuer implementations already, but adding it to this package allows it to be used with the binary encoding as well. The implementation is very straightforward, but the tests have been made slightly more complicated. The previous version wasn't testing the decimal.Decimal cases, and this change adds those as well as new NullDecimal cases. I've added some logic to the test harness to catch these as you need to use the Equals method to properly compare Decimals. --- ext/shopspring-numeric/decimal.go | 15 ++++++++++- ext/shopspring-numeric/decimal_test.go | 35 ++++++++++++++++++++++++-- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index e8694111..ef3ce201 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -34,6 +34,12 @@ func (dst *Numeric) Set(src interface{}) error { switch value := src.(type) { case decimal.Decimal: *dst = Numeric{Decimal: value, Status: pgtype.Present} + case decimal.NullDecimal: + if value.Valid { + *dst = Numeric{Decimal: value.Decimal, Status: pgtype.Present} + } else { + *dst = Numeric{Status: pgtype.Null} + } case float32: *dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present} case float64: @@ -113,6 +119,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { switch v := dst.(type) { case *decimal.Decimal: *v = src.Decimal + case *decimal.NullDecimal: + (*v).Valid = true + (*v).Decimal = src.Decimal case *float32: f, _ := src.Decimal.Float64() *v = float32(f) @@ -216,7 +225,11 @@ func (src *Numeric) AssignTo(dst interface{}) error { return fmt.Errorf("unable to assign to %T", dst) } case pgtype.Null: - return pgtype.NullAssignTo(dst) + if v, ok := dst.(*decimal.NullDecimal); ok { + (*v).Valid = false + } else { + return pgtype.NullAssignTo(dst) + } } return nil diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index bf34e0dd..e635da41 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -153,6 +153,9 @@ func TestNumericSet(t *testing.T) { source interface{} result *shopspring.Numeric }{ + {source: decimal.New(1, 0), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: decimal.NullDecimal{Valid: true, Decimal: decimal.New(1, 0)}, result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: decimal.NullDecimal{Valid: false}, result: &shopspring.Numeric{Status: pgtype.Null}}, {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, @@ -208,6 +211,8 @@ func TestNumericAssignTo(t *testing.T) { var f64 float64 var pf32 *float32 var pf64 *float64 + var d decimal.Decimal + var nd decimal.NullDecimal simpleTests := []struct { src *shopspring.Numeric @@ -231,16 +236,42 @@ func TestNumericAssignTo(t *testing.T) { {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 0)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 3)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.042"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, -3)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &nd, expected: decimal.NullDecimal{Valid: true, Decimal: decimal.New(42, 0)}}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &nd, expected: decimal.NullDecimal{Valid: false}}, } for i, tt := range simpleTests { + // Zero out the destination variable + reflect.ValueOf(tt.dst).Elem().Set(reflect.Zero(reflect.TypeOf(tt.dst).Elem())) + err := tt.src.AssignTo(tt.dst) if err != nil { t.Errorf("%d: %v", i, err) } - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + // Need to specially handle Decimal or NullDecimal methods so we can use their Equal method. Without this + // we end up checking reference equality on the *big.Int they contain. + switch dst := tt.dst.(type) { + case *decimal.Decimal: + if !dst.Equal(tt.expected.(decimal.Decimal)) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, d) + } + case *decimal.NullDecimal: + expected := tt.expected.(decimal.NullDecimal) + + if dst.Valid != expected.Valid { + t.Errorf("%d: expected %v to assign NullDecimal.Valid = %v, but result was NullDecimal.Valid = %v", i, tt.src, expected.Valid, dst.Valid) + } + if !dst.Decimal.Equal(expected.Decimal) { + t.Errorf("%d: expected %v to assign NullDecimal.Decimal = %v, but result was NullDecimal.Decimal = %v", i, tt.src, expected.Decimal, dst.Decimal) + } + default: + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } } } From 94f8441f4eac38759a52b3d60bdee88978a6bede Mon Sep 17 00:00:00 2001 From: Carl Dunham Date: Wed, 11 Aug 2021 20:38:44 -0700 Subject: [PATCH 0707/1158] Fix #119: add support for bare IP address as input for Inet --- inet.go | 10 +++++++++- inet_test.go | 5 +++++ pgtype_test.go | 21 +++++++++++++++------ 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/inet.go b/inet.go index 1645334e..f35f88ba 100644 --- a/inet.go +++ b/inet.go @@ -47,7 +47,15 @@ func (dst *Inet) Set(src interface{}) error { case string: ip, ipnet, err := net.ParseCIDR(value) if err != nil { - return err + ip = net.ParseIP(value) + if ip == nil { + return fmt.Errorf("unable to parse inet address: %s", value) + } + ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + ipnet.Mask = net.CIDRMask(32, 32) + } } ipnet.IP = ip *dst = Inet{IPNet: ipnet, Status: Present} diff --git a/inet_test.go b/inet_test.go index 66fe777f..09c6b21f 100644 --- a/inet_test.go +++ b/inet_test.go @@ -18,6 +18,8 @@ func TestInetTranscode(t *testing.T) { &pgtype.Inet{IPNet: mustParseInet(t, "192.168.1.16/24"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseInet(t, "255.0.0.0/8"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseInet(t, "255.255.255.255/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseInet(t, "::1/64"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseInet(t, "::/0"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseInet(t, "::1/128"), Status: pgtype.Present}, @@ -51,6 +53,8 @@ func TestInetSet(t *testing.T) { {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4"), Mask: net.CIDRMask(24, 32)}, Status: pgtype.Present}}, + {source: "10.0.0.1", result: pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Status: pgtype.Present}}, + {source: "2607:f8b0:4009:80b::200e", result: pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Status: pgtype.Present}}, {source: net.ParseIP(""), result: pgtype.Inet{Status: pgtype.Null}}, } @@ -59,6 +63,7 @@ func TestInetSet(t *testing.T) { err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) + continue } assert.Equalf(t, tt.result.Status, r.Status, "%d: Status", i) diff --git a/pgtype_test.go b/pgtype_test.go index 75e1909f..2506e0a3 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -37,15 +37,24 @@ func mustParseCIDR(t testing.TB, s string) *net.IPNet { func mustParseInet(t testing.TB, s string) *net.IPNet { ip, ipnet, err := net.ParseCIDR(s) - if err != nil { - t.Fatal(err) + if err == nil { + if ipv4 := ip.To4(); ipv4 != nil { + ipnet.IP = ipv4 + } + return ipnet } + + // May be bare IP address. + // + ip = net.ParseIP(s) + if ip == nil { + t.Fatal(errors.New("unable to parse inet address")) + } + ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 + ipnet.IP = ipv4 + ipnet.Mask = net.CIDRMask(32, 32) } - - ipnet.IP = ip - return ipnet } From 39aa071b15abedba83a4ea4d4075eef2fc8cba8f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 26 Aug 2021 13:21:02 -0500 Subject: [PATCH 0708/1158] Add zeronull float8 --- zeronull/float8.go | 90 +++++++++++++++++++++++++++++++++++++++++ zeronull/float8_test.go | 23 +++++++++++ 2 files changed, 113 insertions(+) create mode 100644 zeronull/float8.go create mode 100644 zeronull/float8_test.go diff --git a/zeronull/float8.go b/zeronull/float8.go new file mode 100644 index 00000000..effd0a11 --- /dev/null +++ b/zeronull/float8.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Float8 int64 + +func (dst *Float8) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Float8 + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Float8(nullable.Float) + } else { + *dst = 0 + } + + return nil +} + +func (dst *Float8) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Float8 + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Float8(nullable.Float) + } else { + *dst = 0 + } + + return nil +} + +func (src Float8) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Float8{ + Float: float64(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Float8) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Float8{ + Float: float64(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Float8) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Float8 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Float8(nullable.Float) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float8) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/float8_test.go b/zeronull/float8_test.go new file mode 100644 index 00000000..27fb785e --- /dev/null +++ b/zeronull/float8_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestFloat8Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ + (zeronull.Float8)(1), + (zeronull.Float8)(0), + }) +} + +func TestFloat8ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "float8", (zeronull.Float8)(0)) +} + +func TestFloat8ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "float8", (zeronull.Float8)(0)) +} From 30d763829680e38a3cbc602b74788321cdac05b1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 26 Aug 2021 15:42:47 -0500 Subject: [PATCH 0709/1158] Fix zeronull.Float8 --- zeronull/float8.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeronull/float8.go b/zeronull/float8.go index effd0a11..ebc86ac3 100644 --- a/zeronull/float8.go +++ b/zeronull/float8.go @@ -6,7 +6,7 @@ import ( "github.com/jackc/pgtype" ) -type Float8 int64 +type Float8 float64 func (dst *Float8) DecodeText(ci *pgtype.ConnInfo, src []byte) error { var nullable pgtype.Float8 From 90af821478c74fd8f917b7301d79886f933e5fe3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 26 Aug 2021 21:09:46 -0500 Subject: [PATCH 0710/1158] Remove old Travis CI code --- .travis.yml | 34 ------------------------------- travis/before_install.bash | 41 -------------------------------------- travis/before_script.bash | 11 ---------- travis/script.bash | 5 ----- 4 files changed, 91 deletions(-) delete mode 100644 .travis.yml delete mode 100755 travis/before_install.bash delete mode 100755 travis/before_script.bash delete mode 100755 travis/script.bash diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index d6762735..00000000 --- a/.travis.yml +++ /dev/null @@ -1,34 +0,0 @@ -# source: https://github.com/jackc/pgx/blob/master/.travis.yml - -language: go - -go: - - 1.14.x - - 1.13.x - - tip - -# Derived from https://github.com/lib/pq/blob/master/.travis.yml -before_install: - - ./travis/before_install.bash - -env: - global: - - GO111MODULE=on - - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test - - matrix: - - PGVERSION=12 - - PGVERSION=11 - - PGVERSION=10 - - PGVERSION=9.6 - - PGVERSION=9.5 - -before_script: - - ./travis/before_script.bash - -script: - - ./travis/script.bash - -matrix: - allow_failures: - - go: tip \ No newline at end of file diff --git a/travis/before_install.bash b/travis/before_install.bash deleted file mode 100755 index c95969f9..00000000 --- a/travis/before_install.bash +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env bash -# source: https://github.com/jackc/pgx/blob/master/travis/before_install.bash - -set -eux - -if [ "${PGVERSION-}" != "" ] -then - sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common - sudo rm -rf /var/lib/postgresql - wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - - sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" - sudo apt-get update -qq - sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION - sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf - if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then - echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf - echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf - echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf - fi - sudo /etc/init.d/postgresql restart -fi - -if [ "${CRATEVERSION-}" != "" ] -then - docker run \ - -p "6543:5432" \ - -d \ - crate:"$CRATEVERSION" \ - crate \ - -Cnetwork.host=0.0.0.0 \ - -Ctransport.host=localhost \ - -Clicense.enterprise=false -fi diff --git a/travis/before_script.bash b/travis/before_script.bash deleted file mode 100755 index 13147ab0..00000000 --- a/travis/before_script.bash +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env bash -# source: https://github.com/jackc/pgx/blob/master/travis/before_script.bash -set -eux - -if [ "${PGVERSION-}" != "" ] -then - psql -U postgres -c 'create database pgx_test' - psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' - psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" - psql -U postgres pgx_test -c 'create extension if not exists hstore;' -fi diff --git a/travis/script.bash b/travis/script.bash deleted file mode 100755 index 1dfa2c20..00000000 --- a/travis/script.bash +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env bash -# source: https://github.com/jackc/pgx/blob/master/travis/script.bash -set -eux - -go test -v -race ./... From 3bee0c6398156fb4c1c302a0ce7b0b5bd6108ce9 Mon Sep 17 00:00:00 2001 From: Kei Kamikawa Date: Fri, 13 Aug 2021 12:53:24 +0900 Subject: [PATCH 0711/1158] removed lines to read conn --- pgconn.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index 43b13e43..a1d22394 100644 --- a/pgconn.go +++ b/pgconn.go @@ -578,7 +578,6 @@ func (pgConn *PgConn) Close(ctx context.Context) error { // // See https://github.com/jackc/pgx/issues/637 pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) - pgConn.conn.Read(make([]byte, 1)) return pgConn.conn.Close() } @@ -605,7 +604,6 @@ func (pgConn *PgConn) asyncClose() { pgConn.conn.SetDeadline(deadline) pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) - pgConn.conn.Read(make([]byte, 1)) }() } From 693c7c7f7d4fa9d306af6c9ba911c26a0e62bf3c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Sep 2021 10:59:26 -0500 Subject: [PATCH 0712/1158] Fix NULL being lost when scanning unknown OID into sql.Scanner https://github.com/jackc/pgx/issues/1078 --- pgtype.go | 6 +++++- pgtype_test.go | 11 +++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pgtype.go b/pgtype.go index 4a680844..200fb562 100644 --- a/pgtype.go +++ b/pgtype.go @@ -588,7 +588,11 @@ type scanPlanSQLScanner struct{} func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { scanner := dst.(sql.Scanner) - if formatCode == BinaryFormatCode { + if src == nil { + // This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the + // text format path would be converted to empty string. + return scanner.Scan(nil) + } else if formatCode == BinaryFormatCode { return scanner.Scan(src) } else { return scanner.Scan(string(src)) diff --git a/pgtype_test.go b/pgtype_test.go index 2506e0a3..85ca55e9 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -2,6 +2,7 @@ package pgtype_test import ( "bytes" + "database/sql" "errors" "net" "testing" @@ -211,6 +212,16 @@ func TestConnInfoScanUnknownOIDTextFormat(t *testing.T) { assert.EqualValues(t, 123, n) } +func TestConnInfoScanUnknownOIDIntoSQLScanner(t *testing.T) { + ci := pgtype.NewConnInfo() + + var s sql.NullString + err := ci.Scan(0, pgx.TextFormatCode, []byte(nil), &s) + assert.NoError(t, err) + assert.Equal(t, "", s.String) + assert.False(t, s.Valid) +} + func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) { ci := pgtype.NewConnInfo() src := []byte{0, 0, 0, 42} From 0b5b7c0d1e785eabcf41fee94c24ee39ca4c9a90 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Sep 2021 09:25:01 -0500 Subject: [PATCH 0713/1158] Fix BPChar.AssignTo **rune https://github.com/jackc/pgtype/issues/123 --- bpchar.go | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/bpchar.go b/bpchar.go index e4d058e9..c5fa42ea 100644 --- a/bpchar.go +++ b/bpchar.go @@ -2,6 +2,7 @@ package pgtype import ( "database/sql/driver" + "fmt" ) // BPChar is fixed-length, blank padded char type @@ -20,7 +21,8 @@ func (dst BPChar) Get() interface{} { // AssignTo assigns from src to dst. func (src *BPChar) AssignTo(dst interface{}) error { - if src.Status == Present { + switch src.Status { + case Present: switch v := dst.(type) { case *rune: runes := []rune(src.String) @@ -28,9 +30,24 @@ func (src *BPChar) AssignTo(dst interface{}) error { *v = runes[0] return nil } + case *string: + *v = src.String + return nil + case *[]byte: + *v = make([]byte, len(src.String)) + copy(*v, src.String) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) } + case Null: + return NullAssignTo(dst) } - return (*Text)(src).AssignTo(dst) + + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (BPChar) PreferredResultFormat() int16 { From e53b7aebaba1c7183facf23b10ae535d604002f7 Mon Sep 17 00:00:00 2001 From: Jan Dubsky Date: Mon, 20 Sep 2021 13:34:45 +0200 Subject: [PATCH 0714/1158] Add support for fmt.Stringer and driver.Valuer in String fields encoding --- text.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/text.go b/text.go index 6b01d1b4..a01815d9 100644 --- a/text.go +++ b/text.go @@ -39,7 +39,37 @@ func (dst *Text) Set(src interface{}) error { } else { *dst = Text{String: string(value), Status: Present} } + case fmt.Stringer: + if value == fmt.Stringer(nil) { + *dst = Text{Status: Null} + } else { + *dst = Text{String: value.String(), Status: Present} + } default: + // Cannot be part of the switch: If Value() returns nil on + // non-string, we should still try to checks the underlying type + // using reflection. + // + // For example the struct might implement driver.Valuer with + // pointer receiver and fmt.Stringer with value receiver. + if value, ok := src.(driver.Valuer); ok { + if value == driver.Valuer(nil) { + *dst = Text{Status: Null} + return nil + } else { + v, err := value.Value() + if err != nil { + return fmt.Errorf("driver.Valuer Value() method failed: %w", err) + } + + // Handles also v == nil case. + if s, ok := v.(string); ok { + *dst = Text{String: s, Status: Present} + return nil + } + } + } + if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) } From 290ee79d1e8d48c3ff1c1381e01ba76d6b71985a Mon Sep 17 00:00:00 2001 From: Rueian Date: Mon, 27 Sep 2021 14:29:53 +0800 Subject: [PATCH 0715/1158] feat: remove unnecessary pending for CopyInResponse --- pgconn.go | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/pgconn.go b/pgconn.go index a1d22394..382ad33c 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1185,27 +1185,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co return nil, &writeError{err: err, safeToRetry: n == 0} } - // Read until copy in response or error. - var commandTag CommandTag - var pgErr error - pendingCopyInResponse := true - for pendingCopyInResponse { - msg, err := pgConn.receiveMessage() - if err != nil { - pgConn.asyncClose() - return nil, preferContextOverNetTimeoutError(ctx, err) - } - - switch msg := msg.(type) { - case *pgproto3.CopyInResponse: - pendingCopyInResponse = false - case *pgproto3.ErrorResponse: - pgErr = ErrorResponseToPgError(msg) - case *pgproto3.ReadyForQuery: - return commandTag, pgErr - } - } - // Send copy data abortCopyChan := make(chan struct{}) copyErrChan := make(chan error, 1) @@ -1244,6 +1223,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } }() + var pgErr error var copyErr error for copyErr == nil && pgErr == nil { select { @@ -1280,6 +1260,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Read results + var commandTag CommandTag for { msg, err := pgConn.receiveMessage() if err != nil { From e28459e9d1773a5b033bafca691a655e5a1f24cd Mon Sep 17 00:00:00 2001 From: Jim Tsao Date: Fri, 8 Oct 2021 14:45:10 +0200 Subject: [PATCH 0716/1158] Fix int64 overflow error --- timestamptz.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/timestamptz.go b/timestamptz.go index e0743060..299a8668 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -148,8 +148,10 @@ func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { case negativeInfinityMicrosecondOffset: *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} default: - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + tim := time.Unix( + microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, + (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + ) *dst = Timestamptz{Time: tim, Status: Present} } From 5cb98120c10c0ed3125786007d6c0861d34b8f79 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Oct 2021 09:55:56 -0500 Subject: [PATCH 0717/1158] Add tests for big time and port fix to Timestamp.DecodeBinary https://github.com/jackc/pgtype/pull/128 --- timestamp.go | 6 ++++-- timestamp_test.go | 21 +++++++++++++++++++++ timestamptz_test.go | 21 +++++++++++++++++++++ 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/timestamp.go b/timestamp.go index 46644115..a184d232 100644 --- a/timestamp.go +++ b/timestamp.go @@ -141,8 +141,10 @@ func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { case negativeInfinityMicrosecondOffset: *dst = Timestamp{Status: Present, InfinityModifier: -Infinity} default: - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000).UTC() + tim := time.Unix( + microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, + (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + ) *dst = Timestamp{Time: tim, Status: Present} } diff --git a/timestamp_test.go b/timestamp_test.go index 74cb1221..ea7ef57a 100644 --- a/timestamp_test.go +++ b/timestamp_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "context" "reflect" "testing" "time" @@ -33,6 +34,26 @@ func TestTimestampTranscode(t *testing.T) { }) } +// https://github.com/jackc/pgtype/pull/128 +func TestTimestampTranscodeBigTimeBinary(t *testing.T) { + conn := testutil.MustConnectPgx(t) + if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { + t.Skip("Skipping due to no line type") + } + defer testutil.MustCloseContext(t, conn) + + in := &pgtype.Timestamp{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Status: pgtype.Present} + var out pgtype.Timestamp + + err := conn.QueryRow(context.Background(), "select $1::timestamptz", in).Scan(&out) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, in.Status, out.Status) + require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) +} + func TestTimestampNanosecondsTruncated(t *testing.T) { tests := []struct { input time.Time diff --git a/timestamptz_test.go b/timestamptz_test.go index 769c9239..c3f63967 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "context" "reflect" "testing" "time" @@ -33,6 +34,26 @@ func TestTimestamptzTranscode(t *testing.T) { }) } +// https://github.com/jackc/pgtype/pull/128 +func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) { + conn := testutil.MustConnectPgx(t) + if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { + t.Skip("Skipping due to no line type") + } + defer testutil.MustCloseContext(t, conn) + + in := &pgtype.Timestamptz{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Status: pgtype.Present} + var out pgtype.Timestamptz + + err := conn.QueryRow(context.Background(), "select $1::timestamptz", in).Scan(&out) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, in.Status, out.Status) + require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) +} + func TestTimestamptzNanosecondsTruncated(t *testing.T) { tests := []struct { input time.Time From b72f8084b57d2af6e4e1d8aa350a4d3f052c1e2b Mon Sep 17 00:00:00 2001 From: Adrian Sieger Date: Mon, 25 Oct 2021 10:01:21 +0200 Subject: [PATCH 0718/1158] implement nullable values for hstore maps --- hstore.go | 23 +++++++++++++ hstore_test.go | 89 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/hstore.go b/hstore.go index c2de1ccf..f46eeaf6 100644 --- a/hstore.go +++ b/hstore.go @@ -40,6 +40,16 @@ func (dst *Hstore) Set(src interface{}) error { m[k] = Text{String: v, Status: Present} } *dst = Hstore{Map: m, Status: Present} + case map[string]*string: + m := make(map[string]Text, len(value)) + for k, v := range value { + if v == nil { + m[k] = Text{Status: Null} + } else { + m[k] = Text{String: *v, Status: Present} + } + } + *dst = Hstore{Map: m, Status: Present} default: return fmt.Errorf("cannot convert %v to Hstore", src) } @@ -71,6 +81,19 @@ func (src *Hstore) AssignTo(dst interface{}) error { (*v)[k] = val.String } return nil + case *map[string]*string: + *v = make(map[string]*string, len(src.Map)) + for k, val := range src.Map { + switch val.Status { + case Null: + (*v)[k] = nil + case Present: + (*v)[k] = &val.String + default: + return fmt.Errorf("cannot decode %#v into %T", src, dst) + } + } + return nil default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) diff --git a/hstore_test.go b/hstore_test.go index 48b4b42e..73ee0612 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -69,6 +69,50 @@ func TestHstoreTranscode(t *testing.T) { }) } +func TestHstoreTranscodeNullable(t *testing.T) { + text := func(s string, status pgtype.Status) pgtype.Text { + return pgtype.Text{String: s, Status: status} + } + + values := []interface{}{ + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("", pgtype.Null)}, Status: pgtype.Present}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("", pgtype.Null)}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("", pgtype.Null)}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("", pgtype.Null)}, Status: pgtype.Present}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("", pgtype.Null)}, Status: pgtype.Present}) // is key + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { + a := ai.(pgtype.Hstore) + b := bi.(pgtype.Hstore) + + if len(a.Map) != len(b.Map) || a.Status != b.Status { + return false + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + return false + } + } + + return true + }) +} + func TestHstoreSet(t *testing.T) { successfulTests := []struct { src map[string]string @@ -90,6 +134,27 @@ func TestHstoreSet(t *testing.T) { } } +func TestHstoreSetNullable(t *testing.T) { + successfulTests := []struct { + src map[string]*string + result pgtype.Hstore + }{ + {src: map[string]*string{"foo": nil}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {Status: pgtype.Null}}, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var dst pgtype.Hstore + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + func TestHstoreAssignTo(t *testing.T) { var m map[string]string @@ -113,3 +178,27 @@ func TestHstoreAssignTo(t *testing.T) { } } } + +func TestHstoreAssignToNullable(t *testing.T) { + var m map[string]*string + + simpleTests := []struct { + src pgtype.Hstore + dst *map[string]*string + expected map[string]*string + }{ + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {Status: pgtype.Null}}, Status: pgtype.Present}, dst: &m, expected: map[string]*string{"foo": nil}}, + {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(*tt.dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} From 2caf113f1b6c824375c8b7750168c7b7d2d158f1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Oct 2021 09:00:48 -0500 Subject: [PATCH 0719/1158] Fix parsing text array with negative bounds e.g. '[-4:-2]={1,2,3}' fixes #132 --- array.go | 2 +- array_test.go | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/array.go b/array.go index 3d5930c1..174007c1 100644 --- a/array.go +++ b/array.go @@ -305,7 +305,7 @@ func arrayParseInteger(buf *bytes.Buffer) (int32, error) { return 0, err } - if '0' <= r && r <= '9' { + if ('0' <= r && r <= '9') || r == '-' { s.WriteRune(r) } else { buf.UnreadRune() diff --git a/array_test.go b/array_test.go index d2120677..f1fe90f4 100644 --- a/array_test.go +++ b/array_test.go @@ -100,6 +100,14 @@ func TestParseUntypedTextArray(t *testing.T) { }, }, }, + { + source: "[-4:-2]={1,2,3}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1", "2", "3"}, + Quoted: []bool{false, false, false}, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: -4}}, + }, + }, } for i, tt := range tests { From 5c447ff35d2974a97d4cc8ba0b49d09b5b2c6394 Mon Sep 17 00:00:00 2001 From: Yuli Khodorkovskiy Date: Thu, 28 Oct 2021 11:03:28 -0400 Subject: [PATCH 0720/1158] Fix JSON output for SASL{Response,InitialResponse} Hex encoding the Data field in the SASL responses made debugging SCRAM more difficult than actually helping. Before: F{"Type":"SASLResponse","Data":"633d655377732c723d4d4d4e4e6d666b536f5862694a68385833466d324f2b4d77787354692f4550753052414157484b7a306b7376336c5747392f4d4a5267504d2c703d616742664b533164383937674b4f4a6d4c7171626c49326b6b4a506f2b58354359516c63473458357657343d"} F{"Type":"SASLInitialResponse","AuthMechanism":"SCRAM-SHA-256","Data":"792c2c6e3d2c723d4d4d4e4e6d666b536f5862694a68385833466d324f2b4d77"} After: F{"Type":"SASLResponse","Data":"c=eSws,r=9dR43UQLL1KbrKKl4/QbxjqgVjZYR9mqnx3rFBiI7R/1pp5oeVYMGhXj,p=b2hmuvTvWn2xN0fclm+O4TwLAarRM8xoHSN7jsKDHAU="} F{"Type":"SASLInitialResponse","AuthMechanism":"SCRAM-SHA-256","Data":"y,,n=,r=9dR43UQLL1KbrKKl4/Qbxjqg"} --- sasl_initial_response.go | 2 +- sasl_response.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sasl_initial_response.go b/sasl_initial_response.go index f7e5f36a..f862f2a8 100644 --- a/sasl_initial_response.go +++ b/sasl_initial_response.go @@ -64,7 +64,7 @@ func (src SASLInitialResponse) MarshalJSON() ([]byte, error) { }{ Type: "SASLInitialResponse", AuthMechanism: src.AuthMechanism, - Data: hex.EncodeToString(src.Data), + Data: string(src.Data), }) } diff --git a/sasl_response.go b/sasl_response.go index 41fb4c39..d402759a 100644 --- a/sasl_response.go +++ b/sasl_response.go @@ -38,7 +38,7 @@ func (src SASLResponse) MarshalJSON() ([]byte, error) { Data string }{ Type: "SASLResponse", - Data: hex.EncodeToString(src.Data), + Data: string(src.Data), }) } From a29019de9d6dc5a3fdefe67f29a36d9ecf5a943f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Oct 2021 10:17:58 -0500 Subject: [PATCH 0721/1158] Fix binary decoding of very large numerics. fixes #133 --- numeric.go | 6 +++--- numeric_test.go | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/numeric.go b/numeric.go index a7efa704..f5260548 100644 --- a/numeric.go +++ b/numeric.go @@ -452,11 +452,11 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { } rp := 0 - ndigits := int16(binary.BigEndian.Uint16(src[rp:])) + ndigits := binary.BigEndian.Uint16(src[rp:]) rp += 2 weight := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 - sign := uint16(binary.BigEndian.Uint16(src[rp:])) + sign := binary.BigEndian.Uint16(src[rp:]) rp += 2 dscale := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 @@ -504,7 +504,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { exp := (int32(weight) - int32(ndigits) + 1) * 4 if dscale > 0 { - fracNBaseDigits := ndigits - weight - 1 + fracNBaseDigits := int16(int32(ndigits) - int32(weight) - 1) fracDecimalDigits := fracNBaseDigits * 4 if dscale > fracDecimalDigits { diff --git a/numeric_test.go b/numeric_test.go index 81595cb3..fff5a2e0 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -117,6 +117,10 @@ func TestNumericNormalize(t *testing.T) { } func TestNumericTranscode(t *testing.T) { + max := new(big.Int).Exp(big.NewInt(10), big.NewInt(147454), nil) + max.Add(max, big.NewInt(1)) + longestNumeric := &pgtype.Numeric{Int: max, Exp: -16383, Status: pgtype.Present} + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ &pgtype.Numeric{NaN: true, Status: pgtype.Present}, @@ -151,6 +155,9 @@ func TestNumericTranscode(t *testing.T) { &pgtype.Numeric{Int: mustParseBigInt(t, "423409823409243892349028349023482934092340892390101"), Exp: -92, Status: pgtype.Present}, &pgtype.Numeric{Int: mustParseBigInt(t, "23409823409243892349028349023482934092340892390101"), Exp: -93, Status: pgtype.Present}, &pgtype.Numeric{Int: mustParseBigInt(t, "3409823409243892349028349023482934092340892390101"), Exp: -94, Status: pgtype.Present}, + + longestNumeric, + &pgtype.Numeric{Status: pgtype.Null}, }, func(aa, bb interface{}) bool { a := aa.(pgtype.Numeric) From e0f9fc52122e55c8d8981d6eff21f879c9fffd1b Mon Sep 17 00:00:00 2001 From: Jim Tsao Date: Sun, 31 Oct 2021 12:35:35 +0100 Subject: [PATCH 0722/1158] Add infinity support for Numeric Set/Get --- numeric.go | 26 ++++++++++++++++++++++---- numeric_test.go | 6 ++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/numeric.go b/numeric.go index f5260548..3f2dc9ae 100644 --- a/numeric.go +++ b/numeric.go @@ -49,10 +49,11 @@ var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) type Numeric struct { - Int *big.Int - Exp int32 - Status Status - NaN bool + Int *big.Int + Exp int32 + Status Status + NaN bool + InfinityModifier InfinityModifier } func (dst *Numeric) Set(src interface{}) error { @@ -73,6 +74,12 @@ func (dst *Numeric) Set(src interface{}) error { if math.IsNaN(float64(value)) { *dst = Numeric{Status: Present, NaN: true} return nil + } else if math.IsInf(float64(value), 1) { + *dst = Numeric{Status: Present, InfinityModifier: Infinity} + return nil + } else if math.IsInf(float64(value), -1) { + *dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity} + return nil } num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) if err != nil { @@ -83,6 +90,12 @@ func (dst *Numeric) Set(src interface{}) error { if math.IsNaN(value) { *dst = Numeric{Status: Present, NaN: true} return nil + } else if math.IsInf(value, 1) { + *dst = Numeric{Status: Present, InfinityModifier: Infinity} + return nil + } else if math.IsInf(value, -1) { + *dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity} + return nil } num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) if err != nil { @@ -193,6 +206,8 @@ func (dst *Numeric) Set(src interface{}) error { } else { return dst.Set(*value) } + case InfinityModifier: + *dst = Numeric{InfinityModifier: value, Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) @@ -206,6 +221,9 @@ func (dst *Numeric) Set(src interface{}) error { func (dst Numeric) Get() interface{} { switch dst.Status { case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } return dst case Null: return nil diff --git a/numeric_test.go b/numeric_test.go index fff5a2e0..f14cf960 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -222,6 +222,12 @@ func TestNumericSet(t *testing.T) { {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, NaN: true}}, {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, NaN: true}}, + {source: pgtype.Infinity, result: &pgtype.Numeric{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, + {source: math.Inf(1), result: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}}, + {source: float32(math.Inf(1)), result: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}}, + {source: pgtype.NegativeInfinity, result: &pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + {source: math.Inf(-1), result: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.NegativeInfinity}}, + {source: float32(math.Inf(1)), result: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}}, } for i, tt := range successfulTests { From 001b3166b9b675c48aae339a0d8d78f52a599056 Mon Sep 17 00:00:00 2001 From: Jim Tsao Date: Sun, 31 Oct 2021 13:32:23 +0100 Subject: [PATCH 0723/1158] Add infinity support for Numeric AssignTo --- numeric.go | 4 ++++ numeric_test.go | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/numeric.go b/numeric.go index 3f2dc9ae..72e59b69 100644 --- a/numeric.go +++ b/numeric.go @@ -403,6 +403,10 @@ func (dst *Numeric) toBigInt() (*big.Int, error) { func (src *Numeric) toFloat64() (float64, error) { if src.NaN { return math.NaN(), nil + } else if src.InfinityModifier == Infinity { + return math.Inf(1), nil + } else if src.InfinityModifier == NegativeInfinity { + return math.Inf(-1), nil } buf := make([]byte, 0, 32) diff --git a/numeric_test.go b/numeric_test.go index f14cf960..ecd2d95e 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -287,6 +287,10 @@ func TestNumericAssignTo(t *testing.T) { {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Status: pgtype.Present}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27 {src: &pgtype.Numeric{Status: pgtype.Present, NaN: true}, dst: &f64, expected: math.NaN()}, {src: &pgtype.Numeric{Status: pgtype.Present, NaN: true}, dst: &f32, expected: float32(math.NaN())}, + {src: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, dst: &f64, expected: math.Inf(1)}, + {src: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, dst: &f32, expected: float32(math.Inf(1))}, + {src: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.NegativeInfinity}, dst: &f64, expected: math.Inf(-1)}, + {src: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.NegativeInfinity}, dst: &f32, expected: float32(math.Inf(-1))}, } for i, tt := range simpleTests { From 8890a746d79e2ef181eaf126f4f4420ec65db309 Mon Sep 17 00:00:00 2001 From: Jim Tsao Date: Sun, 31 Oct 2021 13:47:09 +0100 Subject: [PATCH 0724/1158] Add infinity support for Numeric Text Encode/Decode --- numeric.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/numeric.go b/numeric.go index 72e59b69..cf8770fa 100644 --- a/numeric.go +++ b/numeric.go @@ -431,6 +431,12 @@ func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { if string(src) == "NaN" { *dst = Numeric{Status: Present, NaN: true} return nil + } else if string(src) == "Infinity" { + *dst = Numeric{Status: Present, InfinityModifier: Infinity} + return nil + } else if string(src) == "-Infinity" { + *dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity} + return nil } num, exp, err := parseNumericString(string(src)) @@ -597,6 +603,12 @@ func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if src.NaN { buf = append(buf, "NaN"...) return buf, nil + } else if src.InfinityModifier == Infinity { + buf = append(buf, "Infinity"...) + return buf, nil + } else if src.InfinityModifier == NegativeInfinity { + buf = append(buf, "-Infinity"...) + return buf, nil } buf = append(buf, src.Int.String()...) From 14c515db82228a7138d7e3d26b8e5b723a621007 Mon Sep 17 00:00:00 2001 From: Jim Tsao Date: Sun, 31 Oct 2021 14:02:41 +0100 Subject: [PATCH 0725/1158] Add infinity support for Numeric Binary Encode/Decode --- numeric.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/numeric.go b/numeric.go index cf8770fa..a939625b 100644 --- a/numeric.go +++ b/numeric.go @@ -18,6 +18,12 @@ const nbase = 10000 const ( pgNumericNaN = 0x00000000c0000000 pgNumericNaNSign = 0xc000 + + pgNumericPosInf = 0x00000000d0000000 + pgNumericPosInfSign = 0xd000 + + pgNumericNegInf = 0x00000000f0000000 + pgNumericNegInfSign = 0xf000 ) var big0 *big.Int = big.NewInt(0) @@ -492,6 +498,12 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { if sign == pgNumericNaNSign { *dst = Numeric{Status: Present, NaN: true} return nil + } else if sign == pgNumericPosInfSign { + *dst = Numeric{Status: Present, InfinityModifier: Infinity} + return nil + } else if sign == pgNumericNegInfSign { + *dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity} + return nil } if ndigits == 0 { @@ -628,6 +640,12 @@ func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if src.NaN { buf = pgio.AppendUint64(buf, pgNumericNaN) return buf, nil + } else if src.InfinityModifier == Infinity { + buf = pgio.AppendUint64(buf, pgNumericPosInf) + return buf, nil + } else if src.InfinityModifier == NegativeInfinity { + buf = pgio.AppendUint64(buf, pgNumericNegInf) + return buf, nil } var sign int16 From decb75f242b2be04fc75f9adc8c2bd739856eb31 Mon Sep 17 00:00:00 2001 From: Jim Tsao Date: Sun, 31 Oct 2021 14:03:40 +0100 Subject: [PATCH 0726/1158] Add numeric tests for infinity encoding/decoding --- numeric_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/numeric_test.go b/numeric_test.go index ecd2d95e..455c3ac3 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -123,6 +123,8 @@ func TestNumericTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ &pgtype.Numeric{NaN: true, Status: pgtype.Present}, + &pgtype.Numeric{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, + &pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, @@ -372,6 +374,10 @@ func TestNumericEncodeDecodeBinary(t *testing.T) { 1.00002345, math.NaN(), float32(math.NaN()), + math.Inf(1), + float32(math.Inf(1)), + math.Inf(-1), + float32(math.Inf(-1)), } for i, tt := range tests { From 162dc65eff6f037c98baa36f1f4c75658408d65a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Nov 2021 08:57:49 -0500 Subject: [PATCH 0727/1158] Make ContextWatcher concurrency safe fixes #94 --- internal/ctxwatch/context_watcher.go | 13 +++++++++++-- internal/ctxwatch/context_watcher_test.go | 15 ++++++++++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/internal/ctxwatch/context_watcher.go b/internal/ctxwatch/context_watcher.go index 391f0b79..b39cb3ee 100644 --- a/internal/ctxwatch/context_watcher.go +++ b/internal/ctxwatch/context_watcher.go @@ -2,6 +2,7 @@ package ctxwatch import ( "context" + "sync" ) // ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a @@ -10,8 +11,10 @@ type ContextWatcher struct { onCancel func() onUnwatchAfterCancel func() unwatchChan chan struct{} - watchInProgress bool - onCancelWasCalled bool + + lock sync.Mutex + watchInProgress bool + onCancelWasCalled bool } // NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. @@ -29,6 +32,9 @@ func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWat // Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called. func (cw *ContextWatcher) Watch(ctx context.Context) { + cw.lock.Lock() + defer cw.lock.Unlock() + if cw.watchInProgress { panic("Watch already in progress") } @@ -54,6 +60,9 @@ func (cw *ContextWatcher) Watch(ctx context.Context) { // Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was // called then onUnwatchAfterCancel will also be called. func (cw *ContextWatcher) Unwatch() { + cw.lock.Lock() + defer cw.lock.Unlock() + if cw.watchInProgress { cw.unwatchChan <- struct{}{} if cw.onCancelWasCalled { diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go index 6348b729..289606c3 100644 --- a/internal/ctxwatch/context_watcher_test.go +++ b/internal/ctxwatch/context_watcher_test.go @@ -59,7 +59,7 @@ func TestContextWatcherMultipleWatchPanics(t *testing.T) { require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") } -func TestContextWatcherUnwatchIsAlwaysSafe(t *testing.T) { +func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) { cw := ctxwatch.NewContextWatcher(func() {}, func() {}) cw.Unwatch() // unwatch when not / never watching @@ -70,6 +70,19 @@ func TestContextWatcherUnwatchIsAlwaysSafe(t *testing.T) { cw.Unwatch() // double unwatch } +func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + cw.Watch(ctx) + + go cw.Unwatch() + go cw.Unwatch() + + <-ctx.Done() +} + func TestContextWatcherStress(t *testing.T) { var cancelFuncCalls int64 var cleanupFuncCalls int64 From 9275da562f356273288268731252a8f1be60d5f4 Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Wed, 3 Nov 2021 19:43:46 +0000 Subject: [PATCH 0728/1158] Added FunctionCall support Added support for FunctionCall message as per https://www.postgresql.org/docs/11/protocol-message-formats.html Adds unit test for Encode / Decode cycle and invalid message format errors. Fixes https://github.com/jackc/pgproto3/issues/23 --- backend.go | 12 +++-- function_call.go | 104 ++++++++++++++++++++++++++++++++++++++++++ function_call_test.go | 44 ++++++++++++++++++ go.mod | 1 + go.sum | 3 ++ 5 files changed, 160 insertions(+), 4 deletions(-) create mode 100644 function_call.go create mode 100644 function_call_test.go diff --git a/backend.go b/backend.go index e9ba38fc..9c42ad02 100644 --- a/backend.go +++ b/backend.go @@ -21,6 +21,7 @@ type Backend struct { describe Describe execute Execute flush Flush + functionCall FunctionCall gssEncRequest GSSEncRequest parse Parse query Query @@ -29,10 +30,11 @@ type Backend struct { sync Sync terminate Terminate - bodyLen int - msgType byte - partialMsg bool - authType uint32 + bodyLen int + msgType byte + partialMsg bool + authType uint32 + } const ( @@ -125,6 +127,8 @@ func (b *Backend) Receive() (FrontendMessage, error) { msg = &b.describe case 'E': msg = &b.execute + case 'F': + msg = &b.functionCall case 'f': msg = &b.copyFail case 'd': diff --git a/function_call.go b/function_call.go new file mode 100644 index 00000000..74d3c3c7 --- /dev/null +++ b/function_call.go @@ -0,0 +1,104 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "github.com/jackc/pgio" +) + +type FunctionCall struct{ + Function uint32 + ArgFormatCodes []uint16 + Arguments [][]byte + ResultFormatCode uint16 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*FunctionCall) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *FunctionCall) Decode(src []byte) error { + *dst = FunctionCall{} + rp := 0 + // Specifies the object ID of the function to call. + dst.Function = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + // The number of argument format codes that follow (denoted C below). + // This can be zero to indicate that there are no arguments or that the arguments all use the default format (text); + // or one, in which case the specified format code is applied to all arguments; + // or it can equal the actual number of arguments. + nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + argumentCodes := make([]uint16, nArgumentCodes) + for i := 0; i < nArgumentCodes; i++ { + // The argument format codes. Each must presently be zero (text) or one (binary). + ac := binary.BigEndian.Uint16(src[rp:]) + if ac != 0 && ac != 1 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + argumentCodes[i] = ac + rp += 2 + } + dst.ArgFormatCodes = argumentCodes + + // Specifies the number of arguments being supplied to the function. + nArguments := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + arguments := make([][]byte, nArguments) + for i := 0; i < nArguments; i++ { + // The length of the argument value, in bytes (this count does not include itself). Can be zero. + // As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case. + argumentLength := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + if argumentLength == -1 { + arguments[i] = nil + } else { + // The value of the argument, in the format indicated by the associated format code. n is the above length. + argumentValue := src[rp:rp+argumentLength] + rp += argumentLength + arguments[i] = argumentValue + } + } + dst.Arguments = arguments + // The format code for the function result. Must presently be zero (text) or one (binary). + resultFormatCode := binary.BigEndian.Uint16(src[rp:]) + if resultFormatCode != 0 && resultFormatCode != 1 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + dst.ResultFormatCode = resultFormatCode + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *FunctionCall) Encode(dst []byte) []byte { + dst = append(dst, 'F') + sp := len(dst) + dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end + dst = pgio.AppendUint32(dst, src.Function) + dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) + for _, argFormatCode := range src.ArgFormatCodes { + dst = pgio.AppendUint16(dst, argFormatCode) + } + dst = pgio.AppendUint16(dst, uint16(len(src.Arguments))) + for _, argument := range src.Arguments { + if argument == nil { + dst = pgio.AppendInt32(dst, -1) + } else { + dst = pgio.AppendInt32(dst, int32(len(argument))) + dst = append(dst, argument...) + } + } + dst = pgio.AppendUint16(dst, src.ResultFormatCode) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src FunctionCall) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "FunctionCall", + }) +} diff --git a/function_call_test.go b/function_call_test.go new file mode 100644 index 00000000..f1586560 --- /dev/null +++ b/function_call_test.go @@ -0,0 +1,44 @@ +package pgproto3 + +import ( + "github.com/go-test/deep" + "testing" +) + +func TestFunctionCall_EncodeDecode(t *testing.T) { + type fields struct { + Function uint32 + ArgFormatCodes []uint16 + Arguments [][]byte + ResultFormatCode uint16 + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"foo", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, false}, + {"invalid format code", fields{uint32(123), []uint16{2, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, true}, + {"invalid result format code", fields{uint32(123), []uint16{1, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(2)}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + src := &FunctionCall{ + Function: tt.fields.Function, + ArgFormatCodes: tt.fields.ArgFormatCodes, + Arguments: tt.fields.Arguments, + ResultFormatCode: tt.fields.ResultFormatCode, + } + encoded := src.Encode([]byte{}) + decoded := &FunctionCall{} + err := decoded.Decode(encoded[5:]) + if (err != nil) != tt.wantErr { + t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) + return + } + if diff := deep.Equal(src, decoded); diff != nil { + t.Error(diff) + } + }) + } +} \ No newline at end of file diff --git a/go.mod b/go.mod index 36041a94..030953b5 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/jackc/pgproto3/v2 go 1.12 require ( + github.com/go-test/deep v1.0.8 github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index dd9cd044..765190c1 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= +github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= @@ -9,6 +11,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= From 3d9a54f092879f0356034cea00157ef08ebbac71 Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Sat, 6 Nov 2021 16:17:26 +0000 Subject: [PATCH 0729/1158] Fix unit test, it should return after any error is returned from Decode function whether expected or not, rather than continue and try to compare invalid decoded results. Extend the unit test slightly to check the header. Remove go-test/deep dependency in favour of standard library reflect package. --- function_call.go | 60 +++++++++++++++++++++---------------------- function_call_test.go | 40 +++++++++++++++++++++-------- go.mod | 1 - go.sum | 3 --- 4 files changed, 59 insertions(+), 45 deletions(-) diff --git a/function_call.go b/function_call.go index 74d3c3c7..11cccb3e 100644 --- a/function_call.go +++ b/function_call.go @@ -6,10 +6,10 @@ import ( "github.com/jackc/pgio" ) -type FunctionCall struct{ - Function uint32 - ArgFormatCodes []uint16 - Arguments [][]byte +type FunctionCall struct { + Function uint32 + ArgFormatCodes []uint16 + Arguments [][]byte ResultFormatCode uint16 } @@ -24,9 +24,9 @@ func (dst *FunctionCall) Decode(src []byte) error { // Specifies the object ID of the function to call. dst.Function = binary.BigEndian.Uint32(src[rp:]) rp += 4 - // The number of argument format codes that follow (denoted C below). - // This can be zero to indicate that there are no arguments or that the arguments all use the default format (text); - // or one, in which case the specified format code is applied to all arguments; + // The number of argument format codes that follow (denoted C below). + // This can be zero to indicate that there are no arguments or that the arguments all use the default format (text); + // or one, in which case the specified format code is applied to all arguments; // or it can equal the actual number of arguments. nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 @@ -37,36 +37,36 @@ func (dst *FunctionCall) Decode(src []byte) error { if ac != 0 && ac != 1 { return &invalidMessageFormatErr{messageType: "FunctionCall"} } - argumentCodes[i] = ac - rp += 2 - } + argumentCodes[i] = ac + rp += 2 + } dst.ArgFormatCodes = argumentCodes - + // Specifies the number of arguments being supplied to the function. nArguments := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 arguments := make([][]byte, nArguments) for i := 0; i < nArguments; i++ { - // The length of the argument value, in bytes (this count does not include itself). Can be zero. + // The length of the argument value, in bytes (this count does not include itself). Can be zero. // As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case. - argumentLength := int(binary.BigEndian.Uint32(src[rp:])) - rp += 4 - if argumentLength == -1 { + argumentLength := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + if argumentLength == -1 { arguments[i] = nil } else { // The value of the argument, in the format indicated by the associated format code. n is the above length. - argumentValue := src[rp:rp+argumentLength] + argumentValue := src[rp : rp+argumentLength] rp += argumentLength arguments[i] = argumentValue - } - } + } + } dst.Arguments = arguments // The format code for the function result. Must presently be zero (text) or one (binary). resultFormatCode := binary.BigEndian.Uint16(src[rp:]) if resultFormatCode != 0 && resultFormatCode != 1 { - return &invalidMessageFormatErr{messageType: "FunctionCall"} - } - dst.ResultFormatCode = resultFormatCode + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + dst.ResultFormatCode = resultFormatCode return nil } @@ -78,17 +78,17 @@ func (src *FunctionCall) Encode(dst []byte) []byte { dst = pgio.AppendUint32(dst, src.Function) dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) for _, argFormatCode := range src.ArgFormatCodes { - dst = pgio.AppendUint16(dst, argFormatCode) - } + dst = pgio.AppendUint16(dst, argFormatCode) + } dst = pgio.AppendUint16(dst, uint16(len(src.Arguments))) for _, argument := range src.Arguments { - if argument == nil { - dst = pgio.AppendInt32(dst, -1) - } else { - dst = pgio.AppendInt32(dst, int32(len(argument))) - dst = append(dst, argument...) - } - } + if argument == nil { + dst = pgio.AppendInt32(dst, -1) + } else { + dst = pgio.AppendInt32(dst, int32(len(argument))) + dst = append(dst, argument...) + } + } dst = pgio.AppendUint16(dst, src.ResultFormatCode) pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return dst diff --git a/function_call_test.go b/function_call_test.go index f1586560..8c08bb24 100644 --- a/function_call_test.go +++ b/function_call_test.go @@ -1,7 +1,8 @@ package pgproto3 import ( - "github.com/go-test/deep" + "encoding/binary" + "reflect" "testing" ) @@ -17,7 +18,7 @@ func TestFunctionCall_EncodeDecode(t *testing.T) { fields fields wantErr bool }{ - {"foo", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, false}, + {"valid", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(1)}, false}, {"invalid format code", fields{uint32(123), []uint16{2, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, true}, {"invalid result format code", fields{uint32(123), []uint16{1, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(2)}, true}, } @@ -30,15 +31,32 @@ func TestFunctionCall_EncodeDecode(t *testing.T) { ResultFormatCode: tt.fields.ResultFormatCode, } encoded := src.Encode([]byte{}) - decoded := &FunctionCall{} - err := decoded.Decode(encoded[5:]) - if (err != nil) != tt.wantErr { - t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) - return - } - if diff := deep.Equal(src, decoded); diff != nil { - t.Error(diff) + dst := &FunctionCall{} + // Check the header + msgTypeCode := encoded[0] + if msgTypeCode != 'F' { + t.Errorf("msgTypeCode %v should be 'F'", msgTypeCode) + return + } + // Check length, does not include type code character + l := binary.BigEndian.Uint32(encoded[1:5]) + if int(l) != (len(encoded) - 1) { + t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded)) + } + // Check decoding works as expected + err := dst.Decode(encoded[5:]) + if err != nil { + if !tt.wantErr { + t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) + } + return + } + + if !reflect.DeepEqual(src, dst) { + t.Error("difference after encode / decode cycle") + t.Errorf("src = %v", src) + t.Errorf("dst = %v", dst) } }) } -} \ No newline at end of file +} diff --git a/go.mod b/go.mod index 030953b5..36041a94 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/jackc/pgproto3/v2 go 1.12 require ( - github.com/go-test/deep v1.0.8 github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index 765190c1..dd9cd044 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= -github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= @@ -11,7 +9,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= From 40ecac487c46292c2139a1fc7a7365be2a287ba6 Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Sat, 6 Nov 2021 16:33:51 +0000 Subject: [PATCH 0730/1158] Remove unimplemented JSON marshalling for FunctionCall type. --- function_call.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/function_call.go b/function_call.go index 11cccb3e..b3a22c4f 100644 --- a/function_call.go +++ b/function_call.go @@ -2,7 +2,6 @@ package pgproto3 import ( "encoding/binary" - "encoding/json" "github.com/jackc/pgio" ) @@ -93,12 +92,3 @@ func (src *FunctionCall) Encode(dst []byte) []byte { pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return dst } - -// MarshalJSON implements encoding/json.Marshaler. -func (src FunctionCall) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - }{ - Type: "FunctionCall", - }) -} From 141f132ae7e1428ba7bcf519ff618c39a6d07fea Mon Sep 17 00:00:00 2001 From: Georges Varouchas Date: Mon, 8 Nov 2021 21:00:05 +0100 Subject: [PATCH 0731/1158] add a unit test on LRU context check TestLRUContext highlights the lack of context check when querying for a cached value --- stmtcache/lru_test.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index a4108155..f594ceac 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -235,6 +235,40 @@ func TestLRUModeDescribe(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } +func TestLRUContext(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) + + // test 1 : getting a value for the first time with a cancelled context returns an error + ctx1, cancel1 := context.WithCancel(ctx) + cancel1() + + desc, err := cache.Get(ctx1, "SELECT 1") + require.Error(t, err) + require.Nil(t, desc) + + // test 2 : when querying for the 2nd time a cached value, if the context is canceled return an error + ctx2, cancel2 := context.WithCancel(ctx) + + desc, err = cache.Get(ctx2, "SELECT 2") + require.NoError(t, err) + require.NotNil(t, desc) + + cancel2() + + desc, err = cache.Get(ctx2, "SELECT 2") + require.Error(t, err) + require.Nil(t, desc) +} + func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string { result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read() require.NoError(t, result.Err) From cd7dcd58025f5936f76170cf8d9d2fa467b3c189 Mon Sep 17 00:00:00 2001 From: Georges Varouchas Date: Mon, 8 Nov 2021 21:00:24 +0100 Subject: [PATCH 0732/1158] have lru.Get() always check if context is already expired --- stmtcache/lru.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/stmtcache/lru.go b/stmtcache/lru.go index f58f2ac3..a4106457 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -53,6 +53,14 @@ func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription } } + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + } + if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) return el.Value.(*pgconn.StatementDescription), nil From 146268e829bdea59e5381b3199e4e3b5f5388b0b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Nov 2021 04:12:35 -0600 Subject: [PATCH 0733/1158] Move context test above bad statement cleanup --- stmtcache/lru.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/stmtcache/lru.go b/stmtcache/lru.go index a4106457..90fb76c2 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -42,6 +42,14 @@ func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + } + // flush an outstanding bad statements txStatus := c.conn.TxStatus() if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 { @@ -53,14 +61,6 @@ func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription } } - if ctx != context.Background() { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - } - if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) return el.Value.(*pgconn.StatementDescription), nil From 662ecb496ffc8c64f7bfa156694e0fe525a97685 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Nov 2021 09:56:46 -0600 Subject: [PATCH 0734/1158] Release v1.10.1 --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 45c02f1e..63933a3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 1.10.1 (November 20, 2021) + +* Close without waiting for response (Kei Kamikawa) +* Save waiting for network round-trip in CopyFrom (Rueian) +* Fix concurrency issue with ContextWatcher +* LRU.Get always checks context for cancellation / expiration (Georges Varouchas) + # 1.10.0 (July 24, 2021) * net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned. From e80bc75409a12d505de49b9a7d7c5aab1ff3dfde Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Nov 2021 10:08:33 -0600 Subject: [PATCH 0735/1158] Release v1.9.0 --- CHANGELOG.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 64d96fa0..84173f18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,17 @@ +# 1.9.0 (November 20, 2021) + +* Fix binary hstore null decoding +* Add shopspring/decimal.NullDecimal support to integration (Eli Treuherz) +* Inet.Set supports bare IP address (Carl Dunham) +* Add zeronull.Float8 +* Fix NULL being lost when scanning unknown OID into sql.Scanner +* Fix BPChar.AssignTo **rune +* Add support for fmt.Stringer and driver.Valuer in String fields encoding (Jan Dubsky) +* Fix really big timestamp(tz)s binary format parsing (e.g. year 294276) (Jim Tsao) +* Support `map[string]*string` as hstore (Adrian Sieger) +* Fix parsing text array with negative bounds +* Add infinity support for numeric (Jim Tsao) + # 1.8.1 (July 24, 2021) * Cleaned up Go module dependency chain From 84bb47fb26e23e8859ed208ccf78dd1816d56a55 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 24 Nov 2021 07:57:51 -0600 Subject: [PATCH 0736/1158] Fix: Timestamp DecodeBinary is in UTC Preserve previously existing behavior. fixes #138 --- timestamp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timestamp.go b/timestamp.go index a184d232..5517acb1 100644 --- a/timestamp.go +++ b/timestamp.go @@ -144,7 +144,7 @@ func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { tim := time.Unix( microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), - ) + ).UTC() *dst = Timestamp{Time: tim, Status: Present} } From e95ebc02d9bcbf851891deeb7566df84218dc44f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 28 Nov 2021 16:29:42 -0600 Subject: [PATCH 0737/1158] Release v1.9.1 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 84173f18..e34c7979 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.9.1 (November 28, 2021) + +* Fix: binary timestamp is assumed to be in UTC (restored behavior changed in v1.9.0) + # 1.9.0 (November 20, 2021) * Fix binary hstore null decoding From 37044f47f541879ff4577238795c39722fde0eb5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 26 Aug 2021 21:08:18 -0500 Subject: [PATCH 0738/1158] Remove tests against github.com/lib/pq --- README.md | 2 +- cid_test.go | 5 +---- pgtype_test.go | 1 - testutil/testutil.go | 19 ++++--------------- time_test.go | 16 +--------------- xid_test.go | 5 +---- 6 files changed, 8 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index 77d59b31..bc4e72f9 100644 --- a/README.md +++ b/README.md @@ -5,4 +5,4 @@ pgtype implements Go types for over 70 PostgreSQL types. pgtype is the type system underlying the https://github.com/jackc/pgx PostgreSQL driver. These types support the binary format for enhanced performance with pgx. -They also support the database/sql `Scan` and `Value` interfaces and can be used with https://github.com/lib/pq. +They also support the database/sql `Scan` and `Value` interfaces. diff --git a/cid_test.go b/cid_test.go index 50e50cd8..5b1150eb 100644 --- a/cid_test.go +++ b/cid_test.go @@ -19,10 +19,7 @@ func TestCIDTranscode(t *testing.T) { } testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) - } + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) } func TestCIDSet(t *testing.T) { diff --git a/pgtype_test.go b/pgtype_test.go index 85ca55e9..5fd89dcb 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -10,7 +10,6 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" _ "github.com/jackc/pgx/v4/stdlib" - _ "github.com/lib/pq" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/testutil/testutil.go b/testutil/testutil.go index e7b64b58..5dded2b9 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -11,14 +11,11 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" _ "github.com/jackc/pgx/v4/stdlib" - _ "github.com/lib/pq" ) func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { var sqlDriverName string switch driverName { - case "github.com/lib/pq": - sqlDriverName = "postgres" case "github.com/jackc/pgx/stdlib": sqlDriverName = "pgx" default: @@ -98,9 +95,7 @@ func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) - } + TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) } func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { @@ -205,9 +200,7 @@ func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) { func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc) - } + TestDatabaseSQLSuccessfulNormalizeEqFunc(t, "github.com/jackc/pgx/stdlib", tests, eqFunc) } func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { @@ -287,16 +280,12 @@ func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, t func TestGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { TestPgxGoZeroToNullConversion(t, pgTypeName, zero) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - TestDatabaseSQLGoZeroToNullConversion(t, driverName, pgTypeName, zero) - } + TestDatabaseSQLGoZeroToNullConversion(t, "github.com/jackc/pgx/stdlib", pgTypeName, zero) } func TestNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) { TestPgxNullToGoZeroConversion(t, pgTypeName, zero) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - TestDatabaseSQLNullToGoZeroConversion(t, driverName, pgTypeName, zero) - } + TestDatabaseSQLNullToGoZeroConversion(t, "github.com/jackc/pgx/stdlib", pgTypeName, zero) } func TestPgxGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { diff --git a/time_test.go b/time_test.go index 0af42b1e..09ca3c4d 100644 --- a/time_test.go +++ b/time_test.go @@ -14,25 +14,11 @@ func TestTimeTranscode(t *testing.T) { &pgtype.Time{Microseconds: 0, Status: pgtype.Present}, &pgtype.Time{Microseconds: 1, Status: pgtype.Present}, &pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}, + &pgtype.Time{Microseconds: 86400000000, Status: pgtype.Present}, &pgtype.Time{Status: pgtype.Null}, }) } -// Test for transcoding 24:00:00 separately as github.com/lib/pq doesn't seem to support it. -func TestTimeTranscode24HH(t *testing.T) { - pgTypeName := "time" - values := []interface{}{ - &pgtype.Time{Microseconds: 86400000000, Status: pgtype.Present}, - } - - eqFunc := func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - } - - testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) -} - func TestTimeSet(t *testing.T) { type _time time.Time diff --git a/xid_test.go b/xid_test.go index 563ce96e..531867f6 100644 --- a/xid_test.go +++ b/xid_test.go @@ -19,10 +19,7 @@ func TestXIDTranscode(t *testing.T) { } testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) - } + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) } func TestXIDSet(t *testing.T) { From 11d351dd75d4ee7c9e81255d903e3cdb8880cf9b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 26 Aug 2021 22:46:13 -0500 Subject: [PATCH 0739/1158] Replace Status with Valid to conform to database/sql style https://github.com/jackc/pgx/issues/1060 --- aclitem.go | 63 ++- aclitem_array.go | 170 ++++---- aclitem_array_test.go | 156 ++++---- aclitem_test.go | 18 +- array_type.go | 50 +-- bit_test.go | 8 +- bool.go | 100 ++--- bool_array.go | 183 ++++----- bool_array_test.go | 140 +++---- bool_test.go | 52 +-- box.go | 32 +- box_test.go | 14 +- bpchar.go | 45 ++- bpchar_array.go | 183 ++++----- bpchar_array_test.go | 34 +- bpchar_test.go | 16 +- bytea.go | 77 ++-- bytea_array.go | 159 ++++---- bytea_array_test.go | 120 +++--- bytea_test.go | 28 +- cid_test.go | 14 +- cidr_array.go | 207 +++++----- cidr_array_test.go | 150 +++---- circle.go | 38 +- circle_test.go | 6 +- composite_bench_test.go | 2 +- composite_fields_test.go | 2 +- composite_type.go | 105 +++-- composite_type_test.go | 12 +- convert.go | 22 +- custom_composite_test.go | 6 +- date.go | 117 +++--- date_array.go | 183 ++++----- date_array_test.go | 150 +++---- date_test.go | 74 ++-- daterange.go | 34 +- daterange_test.go | 66 ++-- enum_array.go | 170 ++++---- enum_array_test.go | 126 +++--- enum_type.go | 70 ++-- ext/gofrs-uuid/uuid.go | 111 ++---- ext/gofrs-uuid/uuid_test.go | 21 +- ext/shopspring-numeric/decimal.go | 323 +++++++-------- ext/shopspring-numeric/decimal_test.go | 185 +++++---- float4.go | 98 ++--- float4_array.go | 183 ++++----- float4_array_test.go | 140 +++---- float4_test.go | 92 ++--- float8.go | 98 ++--- float8_array.go | 183 ++++----- float8_array_test.go | 122 +++--- float8_test.go | 92 ++--- go.mod | 1 - hstore.go | 108 +++-- hstore_array.go | 159 ++++---- hstore_array_test.go | 200 +++++----- hstore_test.go | 68 ++-- inet.go | 97 ++--- inet_array.go | 207 +++++----- inet_array_test.go | 150 +++---- inet_test.go | 82 ++-- int2.go | 108 +++-- int2_array.go | 519 ++++++++++++------------- int2_array_test.go | 172 ++++---- int2_test.go | 88 ++--- int4.go | 112 +++--- int4_array.go | 519 ++++++++++++------------- int4_array_test.go | 172 ++++---- int4_test.go | 98 ++--- int4range.go | 34 +- int4range_test.go | 14 +- int8.go | 112 +++--- int8_array.go | 519 ++++++++++++------------- int8_array_test.go | 176 ++++----- int8_test.go | 100 ++--- int8range.go | 34 +- int8range_test.go | 14 +- interval.go | 61 ++- interval_test.go | 50 +-- json.go | 80 ++-- json_test.go | 44 +-- jsonb.go | 9 +- jsonb_array.go | 183 ++++----- jsonb_array_test.go | 18 +- jsonb_test.go | 34 +- line.go | 36 +- line_test.go | 6 +- lseg.go | 32 +- lseg_test.go | 10 +- macaddr.go | 75 ++-- macaddr_array.go | 183 ++++----- macaddr_array_test.go | 108 ++--- macaddr_test.go | 12 +- name_test.go | 20 +- numeric.go | 411 ++++++++++---------- numeric_array.go | 327 ++++++++-------- numeric_array_test.go | 152 ++++---- numeric_test.go | 244 ++++++------ numrange.go | 34 +- numrange_test.go | 24 +- oid_value_test.go | 14 +- path.go | 30 +- path_test.go | 8 +- pgtype.go | 12 - pgtype_test.go | 6 +- pguint32.go | 52 +-- point.go | 58 ++- point_test.go | 67 ++-- polygon.go | 43 +- polygon_test.go | 34 +- qchar.go | 47 +-- qchar_test.go | 82 ++-- record.go | 57 ++- record_test.go | 60 ++- text.go | 95 ++--- text_array.go | 183 ++++----- text_array_test.go | 144 +++---- text_test.go | 32 +- tid.go | 54 ++- tid_test.go | 15 +- time.go | 89 ++--- time_test.go | 52 +-- timestamp.go | 100 ++--- timestamp_array.go | 183 ++++----- timestamp_array_test.go | 134 +++---- timestamp_test.go | 72 ++-- timestamptz.go | 117 +++--- timestamptz_array.go | 183 ++++----- timestamptz_array_test.go | 154 ++++---- timestamptz_test.go | 92 ++--- tsrange.go | 34 +- tsrange_array.go | 153 ++++---- tsrange_test.go | 22 +- tstzrange.go | 34 +- tstzrange_array.go | 153 ++++---- tstzrange_test.go | 22 +- typed_array.go.erb | 159 ++++---- typed_range.go.erb | 38 +- unknown.go | 2 +- uuid.go | 101 ++--- uuid_array.go | 231 ++++++----- uuid_array_test.go | 172 ++++---- uuid_test.go | 74 ++-- varbit.go | 34 +- varbit_test.go | 10 +- varchar_array.go | 183 ++++----- varchar_array_test.go | 140 +++---- xid_test.go | 14 +- zeronull/float8.go | 12 +- zeronull/int2.go | 12 +- zeronull/int4.go | 12 +- zeronull/int8.go | 12 +- zeronull/text.go | 8 +- zeronull/timestamp.go | 12 +- zeronull/timestamptz.go | 12 +- zeronull/uuid.go | 12 +- 156 files changed, 6909 insertions(+), 7894 deletions(-) diff --git a/aclitem.go b/aclitem.go index 9f6587be..0c1f23b5 100644 --- a/aclitem.go +++ b/aclitem.go @@ -19,12 +19,12 @@ import ( // type ACLItem struct { String string - Status Status + Valid bool } func (dst *ACLItem) Set(src interface{}) error { if src == nil { - *dst = ACLItem{Status: Null} + *dst = ACLItem{} return nil } @@ -37,12 +37,12 @@ func (dst *ACLItem) Set(src interface{}) error { switch value := src.(type) { case string: - *dst = ACLItem{String: value, Status: Present} + *dst = ACLItem{String: value, Valid: true} case *string: if value == nil { - *dst = ACLItem{Status: Null} + *dst = ACLItem{} } else { - *dst = ACLItem{String: *value, Status: Present} + *dst = ACLItem{String: *value, Valid: true} } default: if originalSrc, ok := underlyingStringType(src); ok { @@ -55,52 +55,44 @@ func (dst *ACLItem) Set(src interface{}) error { } func (dst ACLItem) Get() interface{} { - switch dst.Status { - case Present: - return dst.String - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.String } func (src *ACLItem) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *string: - *v = src.String - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } + switch v := dst.(type) { + case *string: + *v = src.String + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = ACLItem{Status: Null} + *dst = ACLItem{} return nil } - *dst = ACLItem{String: string(src), Status: Present} + *dst = ACLItem{String: string(src), Valid: true} return nil } func (src ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, src.String...), nil @@ -109,7 +101,7 @@ func (src ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *ACLItem) Scan(src interface{}) error { if src == nil { - *dst = ACLItem{Status: Null} + *dst = ACLItem{} return nil } @@ -127,12 +119,9 @@ func (dst *ACLItem) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src ACLItem) Value() (driver.Value, error) { - switch src.Status { - case Present: - return src.String, nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + + return src.String, nil } diff --git a/aclitem_array.go b/aclitem_array.go index 4e3be3bd..fc1128b7 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -11,13 +11,13 @@ import ( type ACLItemArray struct { Elements []ACLItem Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *ACLItemArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = ACLItemArray{Status: Null} + *dst = ACLItemArray{} return nil } @@ -33,9 +33,9 @@ func (dst *ACLItemArray) Set(src interface{}) error { case []string: if value == nil { - *dst = ACLItemArray{Status: Null} + *dst = ACLItemArray{} } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} + *dst = ACLItemArray{Valid: true} } else { elements := make([]ACLItem, len(value)) for i := range value { @@ -46,15 +46,15 @@ func (dst *ACLItemArray) Set(src interface{}) error { *dst = ACLItemArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*string: if value == nil { - *dst = ACLItemArray{Status: Null} + *dst = ACLItemArray{} } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} + *dst = ACLItemArray{Valid: true} } else { elements := make([]ACLItem, len(value)) for i := range value { @@ -65,20 +65,20 @@ func (dst *ACLItemArray) Set(src interface{}) error { *dst = ACLItemArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []ACLItem: if value == nil { - *dst = ACLItemArray{Status: Null} + *dst = ACLItemArray{} } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} + *dst = ACLItemArray{Valid: true} } else { *dst = ACLItemArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -87,7 +87,7 @@ func (dst *ACLItemArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = ACLItemArray{Status: Null} + *dst = ACLItemArray{} return nil } @@ -96,7 +96,7 @@ func (dst *ACLItemArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for ACLItemArray", src) } if elementsLength == 0 { - *dst = ACLItemArray{Status: Present} + *dst = ACLItemArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -109,7 +109,7 @@ func (dst *ACLItemArray) Set(src interface{}) error { *dst = ACLItemArray{ Elements: make([]ACLItem, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -176,84 +176,77 @@ func (dst *ACLItemArray) setRecursive(value reflect.Value, index, dimension int) } func (dst ACLItemArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *ACLItemArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -305,7 +298,7 @@ func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = ACLItemArray{Status: Null} + *dst = ACLItemArray{} return nil } @@ -334,17 +327,14 @@ func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (src ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { diff --git a/aclitem_array_test.go b/aclitem_array_test.go index 8f015f40..0d6adb1d 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -13,42 +13,42 @@ func TestACLItemArrayTranscode(t *testing.T) { &pgtype.ACLItemArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {Status: pgtype.Null}, + {String: "=r/postgres", Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.ACLItemArray{Status: pgtype.Null}, + &pgtype.ACLItemArray{}, &pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - //{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - {String: `postgres=arwdDxt/postgres`, Status: pgtype.Present}, // todo: remove after fixing above case - {String: "=r/postgres", Status: pgtype.Present}, - {Status: pgtype.Null}, - {String: "=r/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + //{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}, + {String: `postgres=arwdDxt/postgres`, Valid: true}, // todo: remove after fixing above case + {String: "=r/postgres", Valid: true}, + {}, + {String: "=r/postgres", Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -61,22 +61,22 @@ func TestACLItemArraySet(t *testing.T) { { source: []string{"=r/postgres"}, result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Elements: []pgtype.ACLItem{{String: "=r/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]string)(nil)), - result: pgtype.ACLItemArray{Status: pgtype.Null}, + result: pgtype.ACLItemArray{}, }, { source: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, result: pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]string{ @@ -90,27 +90,27 @@ func TestACLItemArraySet(t *testing.T) { "postgres=arwdDxt/postgres"}}}}, result: pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, result: pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]string{ @@ -124,18 +124,18 @@ func TestACLItemArraySet(t *testing.T) { "postgres=arwdDxt/postgres"}}}}, result: pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -168,57 +168,57 @@ func TestACLItemArrayAssignTo(t *testing.T) { }{ { src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Elements: []pgtype.ACLItem{{String: "=r/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &stringSlice, expected: []string{"=r/postgres"}, }, { src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Elements: []pgtype.ACLItem{{String: "=r/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &namedStringSlice, expected: _stringSlice{"=r/postgres"}, }, { - src: pgtype.ACLItemArray{Status: pgtype.Null}, + src: pgtype.ACLItemArray{}, dst: &stringSlice, expected: (([]string)(nil)), }, { - src: pgtype.ACLItemArray{Status: pgtype.Present}, + src: pgtype.ACLItemArray{Valid: true}, dst: &stringSlice, expected: []string{}, }, { src: pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringSliceDim2, expected: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, }, { src: pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringSliceDim4, expected: [][][][]string{ {{{ @@ -233,28 +233,28 @@ func TestACLItemArrayAssignTo(t *testing.T) { { src: pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim2, expected: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, }, { src: pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim4, expected: [2][1][1][3]string{ {{{ @@ -285,37 +285,37 @@ func TestACLItemArrayAssignTo(t *testing.T) { }{ { src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{Status: pgtype.Null}}, + Elements: []pgtype.ACLItem{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &stringSlice, }, { src: pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim2, }, { src: pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringSlice, }, { src: pgtype.ACLItemArray{ Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim4, }, } diff --git a/aclitem_test.go b/aclitem_test.go index a37d7657..4e9bc5b0 100644 --- a/aclitem_test.go +++ b/aclitem_test.go @@ -10,9 +10,9 @@ import ( func TestACLItemTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ - &pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - //&pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - &pgtype.ACLItem{Status: pgtype.Null}, + &pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Valid: true}, + //&pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}, + &pgtype.ACLItem{}, }) } @@ -21,8 +21,8 @@ func TestACLItemSet(t *testing.T) { source interface{} result pgtype.ACLItem }{ - {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - {source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}}, + {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Valid: true}}, + {source: (*string)(nil), result: pgtype.ACLItem{}}, } for i, tt := range successfulTests { @@ -47,8 +47,8 @@ func TestACLItemAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, - {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Valid: true}, dst: &s, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.ACLItem{}, dst: &ps, expected: ((*string)(nil))}, } for i, tt := range simpleTests { @@ -67,7 +67,7 @@ func TestACLItemAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Valid: true}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, } for i, tt := range pointerAllocTests { @@ -85,7 +85,7 @@ func TestACLItemAssignTo(t *testing.T) { src pgtype.ACLItem dst interface{} }{ - {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s}, + {src: pgtype.ACLItem{}, dst: &s}, } for i, tt := range errorTests { diff --git a/array_type.go b/array_type.go index 1bd0244b..1df1689f 100644 --- a/array_type.go +++ b/array_type.go @@ -20,7 +20,7 @@ type ArrayType struct { newElement func() ValueTranscoder elementOID uint32 - status Status + valid bool } func NewArrayType(typeName string, elementOID uint32, newElement func() ValueTranscoder) *ArrayType { @@ -31,7 +31,7 @@ func (at *ArrayType) NewTypeValue() Value { return &ArrayType{ elements: at.elements, dimensions: at.dimensions, - status: at.status, + valid: at.valid, typeName: at.typeName, elementOID: at.elementOID, @@ -46,7 +46,7 @@ func (at *ArrayType) TypeName() string { func (dst *ArrayType) setNil() { dst.elements = nil dst.dimensions = nil - dst.status = Null + dst.valid = false } func (dst *ArrayType) Set(src interface{}) error { @@ -77,24 +77,21 @@ func (dst *ArrayType) Set(src interface{}) error { dst.elements[i] = v } dst.dimensions = []ArrayDimension{{Length: int32(len(dst.elements)), LowerBound: 1}} - dst.status = Present + dst.valid = true return nil } -func (dst ArrayType) Get() interface{} { - switch dst.status { - case Present: - elementValues := make([]interface{}, len(dst.elements)) - for i := range dst.elements { - elementValues[i] = dst.elements[i].Get() - } - return elementValues - case Null: +func (src ArrayType) Get() interface{} { + if !src.valid { return nil - default: - return dst.status } + + elementValues := make([]interface{}, len(src.elements)) + for i := range src.elements { + elementValues[i] = src.elements[i].Get() + } + return elementValues } func (src *ArrayType) AssignTo(dst interface{}) error { @@ -110,8 +107,7 @@ func (src *ArrayType) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign to pointer to non-slice") } - switch src.status { - case Present: + if src.valid { slice := reflect.MakeSlice(sliceType, len(src.elements), len(src.elements)) elemType := sliceType.Elem() @@ -127,12 +123,10 @@ func (src *ArrayType) AssignTo(dst interface{}) error { sliceVal.Set(slice) return nil - case Null: + } else { sliceVal.Set(reflect.Zero(sliceType)) return nil } - - return fmt.Errorf("cannot decode %#v into %T", src, dst) } func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { @@ -168,7 +162,7 @@ func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { dst.elements = elements dst.dimensions = uta.Dimensions - dst.status = Present + dst.valid = true return nil } @@ -190,7 +184,7 @@ func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error { if len(arrayHeader.Dimensions) == 0 { dst.elements = elements dst.dimensions = arrayHeader.Dimensions - dst.status = Present + dst.valid = true return nil } @@ -220,17 +214,14 @@ func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error { dst.elements = elements dst.dimensions = arrayHeader.Dimensions - dst.status = Present + dst.valid = true return nil } func (src ArrayType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.status { - case Null: + if !src.valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.dimensions) == 0 { @@ -283,11 +274,8 @@ func (src ArrayType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src ArrayType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.status { - case Null: + if !src.valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ diff --git a/bit_test.go b/bit_test.go index 2e9c9b6e..df5fe4cb 100644 --- a/bit_test.go +++ b/bit_test.go @@ -9,9 +9,9 @@ import ( func TestBitTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "bit(40)", []interface{}{ - &pgtype.Varbit{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Status: pgtype.Present}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, - &pgtype.Varbit{Status: pgtype.Null}, + &pgtype.Varbit{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, + &pgtype.Varbit{}, }) } @@ -19,7 +19,7 @@ func TestBitNormalize(t *testing.T) { testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { SQL: "select B'111111111'", - Value: &pgtype.Bit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, + Value: &pgtype.Bit{Bytes: []byte{255, 128}, Len: 9, Valid: true}, }, }) } diff --git a/bool.go b/bool.go index 676c8e5d..4fcc67e3 100644 --- a/bool.go +++ b/bool.go @@ -8,13 +8,13 @@ import ( ) type Bool struct { - Bool bool - Status Status + Bool bool + Valid bool } func (dst *Bool) Set(src interface{}) error { if src == nil { - *dst = Bool{Status: Null} + *dst = Bool{} return nil } @@ -27,22 +27,22 @@ func (dst *Bool) Set(src interface{}) error { switch value := src.(type) { case bool: - *dst = Bool{Bool: value, Status: Present} + *dst = Bool{Bool: value, Valid: true} case string: bb, err := strconv.ParseBool(value) if err != nil { return err } - *dst = Bool{Bool: bb, Status: Present} + *dst = Bool{Bool: bb, Valid: true} case *bool: if value == nil { - *dst = Bool{Status: Null} + *dst = Bool{} } else { return dst.Set(*value) } case *string: if value == nil { - *dst = Bool{Status: Null} + *dst = Bool{} } else { return dst.Set(*value) } @@ -57,39 +57,33 @@ func (dst *Bool) Set(src interface{}) error { } func (dst Bool) Get() interface{} { - switch dst.Status { - case Present: - return dst.Bool - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + + return dst.Bool } func (src *Bool) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *bool: - *v = src.Bool - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + switch v := dst.(type) { + case *bool: + *v = src.Bool + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Bool{Status: Null} + *dst = Bool{} return nil } @@ -97,13 +91,13 @@ func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { return fmt.Errorf("invalid length for bool: %v", len(src)) } - *dst = Bool{Bool: src[0] == 't', Status: Present} + *dst = Bool{Bool: src[0] == 't', Valid: true} return nil } func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Bool{Status: Null} + *dst = Bool{} return nil } @@ -111,16 +105,13 @@ func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { return fmt.Errorf("invalid length for bool: %v", len(src)) } - *dst = Bool{Bool: src[0] == 1, Status: Present} + *dst = Bool{Bool: src[0] == 1, Valid: true} return nil } func (src Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if src.Bool { @@ -133,11 +124,8 @@ func (src Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if src.Bool { @@ -152,13 +140,13 @@ func (src Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Bool) Scan(src interface{}) error { if src == nil { - *dst = Bool{Status: Null} + *dst = Bool{} return nil } switch src := src.(type) { case bool: - *dst = Bool{Bool: src, Status: Present} + *dst = Bool{Bool: src, Valid: true} return nil case string: return dst.DecodeText(nil, []byte(src)) @@ -173,31 +161,23 @@ func (dst *Bool) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Bool) Value() (driver.Value, error) { - switch src.Status { - case Present: - return src.Bool, nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + + return src.Bool, nil } func (src Bool) MarshalJSON() ([]byte, error) { - switch src.Status { - case Present: - if src.Bool { - return []byte("true"), nil - } else { - return []byte("false"), nil - } - case Null: + if !src.Valid { return []byte("null"), nil - case Undefined: - return nil, errUndefined } - return nil, errBadStatus + if src.Bool { + return []byte("true"), nil + } else { + return []byte("false"), nil + } } func (dst *Bool) UnmarshalJSON(b []byte) error { @@ -208,9 +188,9 @@ func (dst *Bool) UnmarshalJSON(b []byte) error { } if v == nil { - *dst = Bool{Status: Null} + *dst = Bool{} } else { - *dst = Bool{Bool: *v, Status: Present} + *dst = Bool{Bool: *v, Valid: true} } return nil diff --git a/bool_array.go b/bool_array.go index 6558d971..a282fd6b 100644 --- a/bool_array.go +++ b/bool_array.go @@ -14,13 +14,13 @@ import ( type BoolArray struct { Elements []Bool Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *BoolArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = BoolArray{Status: Null} + *dst = BoolArray{} return nil } @@ -36,9 +36,9 @@ func (dst *BoolArray) Set(src interface{}) error { case []bool: if value == nil { - *dst = BoolArray{Status: Null} + *dst = BoolArray{} } else if len(value) == 0 { - *dst = BoolArray{Status: Present} + *dst = BoolArray{Valid: true} } else { elements := make([]Bool, len(value)) for i := range value { @@ -49,15 +49,15 @@ func (dst *BoolArray) Set(src interface{}) error { *dst = BoolArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*bool: if value == nil { - *dst = BoolArray{Status: Null} + *dst = BoolArray{} } else if len(value) == 0 { - *dst = BoolArray{Status: Present} + *dst = BoolArray{Valid: true} } else { elements := make([]Bool, len(value)) for i := range value { @@ -68,20 +68,20 @@ func (dst *BoolArray) Set(src interface{}) error { *dst = BoolArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Bool: if value == nil { - *dst = BoolArray{Status: Null} + *dst = BoolArray{} } else if len(value) == 0 { - *dst = BoolArray{Status: Present} + *dst = BoolArray{Valid: true} } else { *dst = BoolArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -90,7 +90,7 @@ func (dst *BoolArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = BoolArray{Status: Null} + *dst = BoolArray{} return nil } @@ -99,7 +99,7 @@ func (dst *BoolArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for BoolArray", src) } if elementsLength == 0 { - *dst = BoolArray{Status: Present} + *dst = BoolArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -112,7 +112,7 @@ func (dst *BoolArray) Set(src interface{}) error { *dst = BoolArray{ Elements: make([]Bool, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -179,84 +179,77 @@ func (dst *BoolArray) setRecursive(value reflect.Value, index, dimension int) (i } func (dst BoolArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *BoolArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]bool: - *v = make([]bool, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*bool: - *v = make([]*bool, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]bool: + *v = make([]bool, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*bool: + *v = make([]*bool, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -308,7 +301,7 @@ func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension in func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = BoolArray{Status: Null} + *dst = BoolArray{} return nil } @@ -337,14 +330,14 @@ func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = BoolArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = BoolArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = BoolArray{Status: Null} + *dst = BoolArray{} return nil } @@ -355,7 +348,7 @@ func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = BoolArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = BoolArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -380,16 +373,13 @@ func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = BoolArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = BoolArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -442,11 +432,8 @@ func (src BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -460,7 +447,7 @@ func (src BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/bool_array_test.go b/bool_array_test.go index be567e59..cfb9ad79 100644 --- a/bool_array_test.go +++ b/bool_array_test.go @@ -13,41 +13,41 @@ func TestBoolArrayTranscode(t *testing.T) { &pgtype.BoolArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.BoolArray{ Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Status: pgtype.Null}, + {Bool: true, Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.BoolArray{Status: pgtype.Null}, + &pgtype.BoolArray{}, &pgtype.BoolArray{ Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Bool: false, Status: pgtype.Present}, + {Bool: true, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {}, + {Bool: false, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.BoolArray{ Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -60,61 +60,61 @@ func TestBoolArraySet(t *testing.T) { { source: []bool{true}, result: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Elements: []pgtype.Bool{{Bool: true, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]bool)(nil)), - result: pgtype.BoolArray{Status: pgtype.Null}, + result: pgtype.BoolArray{}, }, { source: [][]bool{{true}, {false}}, result: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, result: pgtype.BoolArray{ Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]bool{{true}, {false}}, result: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, result: pgtype.BoolArray{ Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -147,81 +147,81 @@ func TestBoolArrayAssignTo(t *testing.T) { }{ { src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Elements: []pgtype.Bool{{Bool: true, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &boolSlice, expected: []bool{true}, }, { src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Elements: []pgtype.Bool{{Bool: true, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &namedBoolSlice, expected: _boolSlice{true}, }, { - src: pgtype.BoolArray{Status: pgtype.Null}, + src: pgtype.BoolArray{}, dst: &boolSlice, expected: (([]bool)(nil)), }, { - src: pgtype.BoolArray{Status: pgtype.Present}, + src: pgtype.BoolArray{Valid: true}, dst: &boolSlice, expected: []bool{}, }, { src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, expected: [][]bool{{true}, {false}}, dst: &boolSliceDim2, }, { src: pgtype.BoolArray{ Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, expected: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, dst: &boolSliceDim4, }, { src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, expected: [2][1]bool{{true}, {false}}, dst: &boolArrayDim2, }, { src: pgtype.BoolArray{ Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, expected: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, dst: &boolArrayDim4, }, @@ -244,31 +244,31 @@ func TestBoolArrayAssignTo(t *testing.T) { }{ { src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Status: pgtype.Null}}, + Elements: []pgtype.Bool{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &boolSlice, }, { src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &boolArrayDim2, }, { src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &boolSlice, }, { src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &boolArrayDim4, }, } diff --git a/bool_test.go b/bool_test.go index 8e7a5220..a1ba9bb0 100644 --- a/bool_test.go +++ b/bool_test.go @@ -10,9 +10,9 @@ import ( func TestBoolTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "bool", []interface{}{ - &pgtype.Bool{Bool: false, Status: pgtype.Present}, - &pgtype.Bool{Bool: true, Status: pgtype.Present}, - &pgtype.Bool{Bool: false, Status: pgtype.Null}, + &pgtype.Bool{Bool: false, Valid: true}, + &pgtype.Bool{Bool: true, Valid: true}, + &pgtype.Bool{Bool: false}, }) } @@ -21,15 +21,15 @@ func TestBoolSet(t *testing.T) { source interface{} result pgtype.Bool }{ - {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, - {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, - {source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, - {source: _bool(true), result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: _bool(false), result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, - {source: nil, result: pgtype.Bool{Status: pgtype.Null}}, + {source: true, result: pgtype.Bool{Bool: true, Valid: true}}, + {source: false, result: pgtype.Bool{Bool: false, Valid: true}}, + {source: "true", result: pgtype.Bool{Bool: true, Valid: true}}, + {source: "false", result: pgtype.Bool{Bool: false, Valid: true}}, + {source: "t", result: pgtype.Bool{Bool: true, Valid: true}}, + {source: "f", result: pgtype.Bool{Bool: false, Valid: true}}, + {source: _bool(true), result: pgtype.Bool{Bool: true, Valid: true}}, + {source: _bool(false), result: pgtype.Bool{Bool: false, Valid: true}}, + {source: nil, result: pgtype.Bool{}}, } for i, tt := range successfulTests { @@ -56,12 +56,12 @@ func TestBoolAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &b, expected: false}, - {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &b, expected: true}, - {src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &_b, expected: _bool(false)}, - {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_b, expected: _bool(true)}, - {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &pb, expected: ((*bool)(nil))}, - {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &_pb, expected: ((*_bool)(nil))}, + {src: pgtype.Bool{Bool: false, Valid: true}, dst: &b, expected: false}, + {src: pgtype.Bool{Bool: true, Valid: true}, dst: &b, expected: true}, + {src: pgtype.Bool{Bool: false, Valid: true}, dst: &_b, expected: _bool(false)}, + {src: pgtype.Bool{Bool: true, Valid: true}, dst: &_b, expected: _bool(true)}, + {src: pgtype.Bool{Bool: false}, dst: &pb, expected: ((*bool)(nil))}, + {src: pgtype.Bool{Bool: false}, dst: &_pb, expected: ((*_bool)(nil))}, } for i, tt := range simpleTests { @@ -80,8 +80,8 @@ func TestBoolAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &pb, expected: true}, - {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_pb, expected: _bool(true)}, + {src: pgtype.Bool{Bool: true, Valid: true}, dst: &pb, expected: true}, + {src: pgtype.Bool{Bool: true, Valid: true}, dst: &_pb, expected: _bool(true)}, } for i, tt := range pointerAllocTests { @@ -101,9 +101,9 @@ func TestBoolMarshalJSON(t *testing.T) { source pgtype.Bool result string }{ - {source: pgtype.Bool{Status: pgtype.Null}, result: "null"}, - {source: pgtype.Bool{Bool: true, Status: pgtype.Present}, result: "true"}, - {source: pgtype.Bool{Bool: false, Status: pgtype.Present}, result: "false"}, + {source: pgtype.Bool{}, result: "null"}, + {source: pgtype.Bool{Bool: true, Valid: true}, result: "true"}, + {source: pgtype.Bool{Bool: false, Valid: true}, result: "false"}, } for i, tt := range successfulTests { r, err := tt.source.MarshalJSON() @@ -122,9 +122,9 @@ func TestBoolUnmarshalJSON(t *testing.T) { source string result pgtype.Bool }{ - {source: "null", result: pgtype.Bool{Status: pgtype.Null}}, - {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: "null", result: pgtype.Bool{}}, + {source: "true", result: pgtype.Bool{Bool: true, Valid: true}}, + {source: "false", result: pgtype.Bool{Bool: false, Valid: true}}, } for i, tt := range successfulTests { var r pgtype.Bool diff --git a/box.go b/box.go index 27fb829e..868b40a2 100644 --- a/box.go +++ b/box.go @@ -12,8 +12,8 @@ import ( ) type Box struct { - P [2]Vec2 - Status Status + P [2]Vec2 + Valid bool } func (dst *Box) Set(src interface{}) error { @@ -21,14 +21,10 @@ func (dst *Box) Set(src interface{}) error { } func (dst Box) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Box) AssignTo(dst interface{}) error { @@ -37,7 +33,7 @@ func (src *Box) AssignTo(dst interface{}) error { func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Box{Status: Null} + *dst = Box{} return nil } @@ -78,13 +74,13 @@ func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} + *dst = Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true} return nil } func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Box{Status: Null} + *dst = Box{} return nil } @@ -102,17 +98,14 @@ func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { {math.Float64frombits(x1), math.Float64frombits(y1)}, {math.Float64frombits(x2), math.Float64frombits(y2)}, }, - Status: Present, + Valid: true, } return nil } func (src Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, @@ -125,11 +118,8 @@ func (src Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) @@ -143,7 +133,7 @@ func (src Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Box) Scan(src interface{}) error { if src == nil { - *dst = Box{Status: Null} + *dst = Box{} return nil } diff --git a/box_test.go b/box_test.go index 643c74ec..c7e00553 100644 --- a/box_test.go +++ b/box_test.go @@ -10,14 +10,14 @@ import ( func TestBoxTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "box", []interface{}{ &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, - Status: pgtype.Present, + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, + Valid: true, }, &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Status: pgtype.Present, + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Valid: true, }, - &pgtype.Box{Status: pgtype.Null}, + &pgtype.Box{}, }) } @@ -26,8 +26,8 @@ func TestBoxNormalize(t *testing.T) { { SQL: "select '3.14, 1.678, 7.1, 5.234'::box", Value: &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, - Status: pgtype.Present, + P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + Valid: true, }, }, }) diff --git a/bpchar.go b/bpchar.go index c5fa42ea..2e899ea8 100644 --- a/bpchar.go +++ b/bpchar.go @@ -21,32 +21,31 @@ func (dst BPChar) Get() interface{} { // AssignTo assigns from src to dst. func (src *BPChar) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *rune: - runes := []rune(src.String) - if len(runes) == 1 { - *v = runes[0] - return nil - } - case *string: - *v = src.String - return nil - case *[]byte: - *v = make([]byte, len(src.String)) - copy(*v, src.String) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } + switch v := dst.(type) { + case *rune: + runes := []rune(src.String) + if len(runes) == 1 { + *v = runes[0] + return nil + } + case *string: + *v = src.String + return nil + case *[]byte: + *v = make([]byte, len(src.String)) + copy(*v, src.String) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + return fmt.Errorf("cannot decode %#v into %T", src, dst) } diff --git a/bpchar_array.go b/bpchar_array.go index 8e792214..c73c78a3 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -14,13 +14,13 @@ import ( type BPCharArray struct { Elements []BPChar Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *BPCharArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = BPCharArray{Status: Null} + *dst = BPCharArray{} return nil } @@ -36,9 +36,9 @@ func (dst *BPCharArray) Set(src interface{}) error { case []string: if value == nil { - *dst = BPCharArray{Status: Null} + *dst = BPCharArray{} } else if len(value) == 0 { - *dst = BPCharArray{Status: Present} + *dst = BPCharArray{Valid: true} } else { elements := make([]BPChar, len(value)) for i := range value { @@ -49,15 +49,15 @@ func (dst *BPCharArray) Set(src interface{}) error { *dst = BPCharArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*string: if value == nil { - *dst = BPCharArray{Status: Null} + *dst = BPCharArray{} } else if len(value) == 0 { - *dst = BPCharArray{Status: Present} + *dst = BPCharArray{Valid: true} } else { elements := make([]BPChar, len(value)) for i := range value { @@ -68,20 +68,20 @@ func (dst *BPCharArray) Set(src interface{}) error { *dst = BPCharArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []BPChar: if value == nil { - *dst = BPCharArray{Status: Null} + *dst = BPCharArray{} } else if len(value) == 0 { - *dst = BPCharArray{Status: Present} + *dst = BPCharArray{Valid: true} } else { *dst = BPCharArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -90,7 +90,7 @@ func (dst *BPCharArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = BPCharArray{Status: Null} + *dst = BPCharArray{} return nil } @@ -99,7 +99,7 @@ func (dst *BPCharArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for BPCharArray", src) } if elementsLength == 0 { - *dst = BPCharArray{Status: Present} + *dst = BPCharArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -112,7 +112,7 @@ func (dst *BPCharArray) Set(src interface{}) error { *dst = BPCharArray{ Elements: make([]BPChar, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -179,84 +179,77 @@ func (dst *BPCharArray) setRecursive(value reflect.Value, index, dimension int) } func (dst BPCharArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *BPCharArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -308,7 +301,7 @@ func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = BPCharArray{Status: Null} + *dst = BPCharArray{} return nil } @@ -337,14 +330,14 @@ func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = BPCharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = BPCharArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *BPCharArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = BPCharArray{Status: Null} + *dst = BPCharArray{} return nil } @@ -355,7 +348,7 @@ func (dst *BPCharArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = BPCharArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = BPCharArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -380,16 +373,13 @@ func (dst *BPCharArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = BPCharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = BPCharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -442,11 +432,8 @@ func (src BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -460,7 +447,7 @@ func (src BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/bpchar_array_test.go b/bpchar_array_test.go index af6bf09a..277f6e3c 100644 --- a/bpchar_array_test.go +++ b/bpchar_array_test.go @@ -12,44 +12,44 @@ func TestBPCharArrayTranscode(t *testing.T) { &pgtype.BPCharArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.BPCharArray{ Elements: []pgtype.BPChar{ - pgtype.BPChar{String: "foo ", Status: pgtype.Present}, - pgtype.BPChar{Status: pgtype.Null}, + pgtype.BPChar{String: "foo ", Valid: true}, + pgtype.BPChar{}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.BPCharArray{Status: pgtype.Null}, + &pgtype.BPCharArray{}, &pgtype.BPCharArray{ Elements: []pgtype.BPChar{ - pgtype.BPChar{String: "bar ", Status: pgtype.Present}, - pgtype.BPChar{String: "NuLL ", Status: pgtype.Present}, - pgtype.BPChar{String: `wow"quz\`, Status: pgtype.Present}, - pgtype.BPChar{String: "1 ", Status: pgtype.Present}, - pgtype.BPChar{String: "1 ", Status: pgtype.Present}, - pgtype.BPChar{String: "null ", Status: pgtype.Present}, + pgtype.BPChar{String: "bar ", Valid: true}, + pgtype.BPChar{String: "NuLL ", Valid: true}, + pgtype.BPChar{String: `wow"quz\`, Valid: true}, + pgtype.BPChar{String: "1 ", Valid: true}, + pgtype.BPChar{String: "1 ", Valid: true}, + pgtype.BPChar{String: "null ", Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}, }, - Status: pgtype.Present, + Valid: true, }, &pgtype.BPCharArray{ Elements: []pgtype.BPChar{ - pgtype.BPChar{String: " bar ", Status: pgtype.Present}, - pgtype.BPChar{String: " baz ", Status: pgtype.Present}, - pgtype.BPChar{String: " quz ", Status: pgtype.Present}, - pgtype.BPChar{String: "foo ", Status: pgtype.Present}, + pgtype.BPChar{String: " bar ", Valid: true}, + pgtype.BPChar{String: " baz ", Valid: true}, + pgtype.BPChar{String: " quz ", Valid: true}, + pgtype.BPChar{String: "foo ", Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } diff --git a/bpchar_test.go b/bpchar_test.go index 7b8c1da3..fe7e651c 100644 --- a/bpchar_test.go +++ b/bpchar_test.go @@ -10,16 +10,16 @@ import ( func TestChar3Transcode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "char(3)", []interface{}{ - &pgtype.BPChar{String: "a ", Status: pgtype.Present}, - &pgtype.BPChar{String: " a ", Status: pgtype.Present}, - &pgtype.BPChar{String: "å—¨ ", Status: pgtype.Present}, - &pgtype.BPChar{String: " ", Status: pgtype.Present}, - &pgtype.BPChar{Status: pgtype.Null}, + &pgtype.BPChar{String: "a ", Valid: true}, + &pgtype.BPChar{String: " a ", Valid: true}, + &pgtype.BPChar{String: "å—¨ ", Valid: true}, + &pgtype.BPChar{String: " ", Valid: true}, + &pgtype.BPChar{}, }, func(aa, bb interface{}) bool { a := aa.(pgtype.BPChar) b := bb.(pgtype.BPChar) - return a.Status == b.Status && a.String == b.String + return a.Valid == b.Valid && a.String == b.String }) } @@ -33,8 +33,8 @@ func TestBPCharAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.BPChar{String: "simple", Status: pgtype.Present}, dst: &str, expected: "simple"}, - {src: pgtype.BPChar{String: "å—¨", Status: pgtype.Present}, dst: &run, expected: 'å—¨'}, + {src: pgtype.BPChar{String: "simple", Valid: true}, dst: &str, expected: "simple"}, + {src: pgtype.BPChar{String: "å—¨", Valid: true}, dst: &run, expected: 'å—¨'}, } for i, tt := range simpleTests { diff --git a/bytea.go b/bytea.go index 67eba350..d4c4e436 100644 --- a/bytea.go +++ b/bytea.go @@ -7,13 +7,13 @@ import ( ) type Bytea struct { - Bytes []byte - Status Status + Bytes []byte + Valid bool } func (dst *Bytea) Set(src interface{}) error { if src == nil { - *dst = Bytea{Status: Null} + *dst = Bytea{} return nil } @@ -27,9 +27,9 @@ func (dst *Bytea) Set(src interface{}) error { switch value := src.(type) { case []byte: if value != nil { - *dst = Bytea{Bytes: value, Status: Present} + *dst = Bytea{Bytes: value, Valid: true} } else { - *dst = Bytea{Status: Null} + *dst = Bytea{} } default: if originalSrc, ok := underlyingBytesType(src); ok { @@ -42,43 +42,36 @@ func (dst *Bytea) Set(src interface{}) error { } func (dst Bytea) Get() interface{} { - switch dst.Status { - case Present: - return dst.Bytes - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Bytes } func (src *Bytea) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *[]byte: - buf := make([]byte, len(src.Bytes)) - copy(buf, src.Bytes) - *v = buf - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + switch v := dst.(type) { + case *[]byte: + buf := make([]byte, len(src.Bytes)) + copy(buf, src.Bytes) + *v = buf + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } // DecodeText only supports the hex format. This has been the default since // PostgreSQL 9.0. func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Bytea{Status: Null} + *dst = Bytea{} return nil } @@ -92,26 +85,23 @@ func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Bytea{Bytes: buf, Status: Present} + *dst = Bytea{Bytes: buf, Valid: true} return nil } func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Bytea{Status: Null} + *dst = Bytea{} return nil } - *dst = Bytea{Bytes: src, Status: Present} + *dst = Bytea{Bytes: src, Valid: true} return nil } func (src Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = append(buf, `\x`...) @@ -120,11 +110,8 @@ func (src Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Bytea) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, src.Bytes...), nil @@ -133,7 +120,7 @@ func (src Bytea) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Bytea) Scan(src interface{}) error { if src == nil { - *dst = Bytea{Status: Null} + *dst = Bytea{} return nil } @@ -143,7 +130,7 @@ func (dst *Bytea) Scan(src interface{}) error { case []byte: buf := make([]byte, len(src)) copy(buf, src) - *dst = Bytea{Bytes: buf, Status: Present} + *dst = Bytea{Bytes: buf, Valid: true} return nil } @@ -152,12 +139,8 @@ func (dst *Bytea) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Bytea) Value() (driver.Value, error) { - switch src.Status { - case Present: - return src.Bytes, nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + return src.Bytes, nil } diff --git a/bytea_array.go b/bytea_array.go index 69d1ceb9..7c539e21 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -14,13 +14,13 @@ import ( type ByteaArray struct { Elements []Bytea Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *ByteaArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = ByteaArray{Status: Null} + *dst = ByteaArray{} return nil } @@ -36,9 +36,9 @@ func (dst *ByteaArray) Set(src interface{}) error { case [][]byte: if value == nil { - *dst = ByteaArray{Status: Null} + *dst = ByteaArray{} } else if len(value) == 0 { - *dst = ByteaArray{Status: Present} + *dst = ByteaArray{Valid: true} } else { elements := make([]Bytea, len(value)) for i := range value { @@ -49,20 +49,20 @@ func (dst *ByteaArray) Set(src interface{}) error { *dst = ByteaArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Bytea: if value == nil { - *dst = ByteaArray{Status: Null} + *dst = ByteaArray{} } else if len(value) == 0 { - *dst = ByteaArray{Status: Present} + *dst = ByteaArray{Valid: true} } else { *dst = ByteaArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -71,7 +71,7 @@ func (dst *ByteaArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = ByteaArray{Status: Null} + *dst = ByteaArray{} return nil } @@ -80,7 +80,7 @@ func (dst *ByteaArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for ByteaArray", src) } if elementsLength == 0 { - *dst = ByteaArray{Status: Present} + *dst = ByteaArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -93,7 +93,7 @@ func (dst *ByteaArray) Set(src interface{}) error { *dst = ByteaArray{ Elements: make([]Bytea, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -160,75 +160,68 @@ func (dst *ByteaArray) setRecursive(value reflect.Value, index, dimension int) ( } func (dst ByteaArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *ByteaArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -280,7 +273,7 @@ func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension i func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = ByteaArray{Status: Null} + *dst = ByteaArray{} return nil } @@ -309,14 +302,14 @@ func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = ByteaArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = ByteaArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = ByteaArray{Status: Null} + *dst = ByteaArray{} return nil } @@ -327,7 +320,7 @@ func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = ByteaArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = ByteaArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -352,16 +345,13 @@ func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = ByteaArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = ByteaArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -414,11 +404,8 @@ func (src ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -432,7 +419,7 @@ func (src ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/bytea_array_test.go b/bytea_array_test.go index 27c0382e..1473eb9c 100644 --- a/bytea_array_test.go +++ b/bytea_array_test.go @@ -13,41 +13,41 @@ func TestByteaArrayTranscode(t *testing.T) { &pgtype.ByteaArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.ByteaArray{ Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Status: pgtype.Null}, + {Bytes: []byte{1, 2, 3}, Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.ByteaArray{Status: pgtype.Null}, + &pgtype.ByteaArray{}, &pgtype.ByteaArray{ Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{}, Status: pgtype.Present}, - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Bytes: []byte{1}, Status: pgtype.Present}, + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{}, Valid: true}, + {Bytes: []byte{1, 2, 3}, Valid: true}, + {}, + {Bytes: []byte{1}, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.ByteaArray{ Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{}, Status: pgtype.Present}, - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{1}, Status: pgtype.Present}, + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{}, Valid: true}, + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{1}, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -60,61 +60,61 @@ func TestByteaArraySet(t *testing.T) { { source: [][]byte{{1, 2, 3}}, result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([][]byte)(nil)), - result: pgtype.ByteaArray{Status: pgtype.Null}, + result: pgtype.ByteaArray{}, }, { source: [][][]byte{{{1}}, {{2}}}, result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Valid: true}, {Bytes: []byte{2}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, result: pgtype.ByteaArray{ Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, - {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, - {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, - {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, - {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{4, 5, 6}, Valid: true}, + {Bytes: []byte{7, 8, 9}, Valid: true}, + {Bytes: []byte{10, 11, 12}, Valid: true}, + {Bytes: []byte{13, 14, 15}, Valid: true}, + {Bytes: []byte{16, 17, 18}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][]byte{{{1}}, {{2}}}, result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Valid: true}, {Bytes: []byte{2}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, result: pgtype.ByteaArray{ Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, - {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, - {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, - {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, - {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{4, 5, 6}, Valid: true}, + {Bytes: []byte{7, 8, 9}, Valid: true}, + {Bytes: []byte{10, 11, 12}, Valid: true}, + {Bytes: []byte{13, 14, 15}, Valid: true}, + {Bytes: []byte{16, 17, 18}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -145,72 +145,72 @@ func TestByteaArrayAssignTo(t *testing.T) { }{ { src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &byteByteSlice, expected: [][]byte{{1, 2, 3}}, }, { - src: pgtype.ByteaArray{Status: pgtype.Null}, + src: pgtype.ByteaArray{}, dst: &byteByteSlice, expected: (([][]byte)(nil)), }, { - src: pgtype.ByteaArray{Status: pgtype.Present}, + src: pgtype.ByteaArray{Valid: true}, dst: &byteByteSlice, expected: [][]byte{}, }, { src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Valid: true}, {Bytes: []byte{2}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &byteByteSliceDim2, expected: [][][]byte{{{1}}, {{2}}}, }, { src: pgtype.ByteaArray{ Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, - {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, - {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, - {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, - {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{4, 5, 6}, Valid: true}, + {Bytes: []byte{7, 8, 9}, Valid: true}, + {Bytes: []byte{10, 11, 12}, Valid: true}, + {Bytes: []byte{13, 14, 15}, Valid: true}, + {Bytes: []byte{16, 17, 18}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &byteByteSliceDim4, expected: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, }, { src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Valid: true}, {Bytes: []byte{2}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &byteByteArraySliceDim2, expected: [2][1][]byte{{{1}}, {{2}}}, }, { src: pgtype.ByteaArray{ Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, - {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, - {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, - {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, - {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{4, 5, 6}, Valid: true}, + {Bytes: []byte{7, 8, 9}, Valid: true}, + {Bytes: []byte{10, 11, 12}, Valid: true}, + {Bytes: []byte{13, 14, 15}, Valid: true}, + {Bytes: []byte{16, 17, 18}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &byteByteArraySliceDim4, expected: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, }, diff --git a/bytea_test.go b/bytea_test.go index c8c49ff7..0f47cb7f 100644 --- a/bytea_test.go +++ b/bytea_test.go @@ -10,9 +10,9 @@ import ( func TestByteaTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "bytea", []interface{}{ - &pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - &pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, - &pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, + &pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, + &pgtype.Bytea{Bytes: []byte{}, Valid: true}, + &pgtype.Bytea{Bytes: nil}, }) } @@ -21,11 +21,11 @@ func TestByteaSet(t *testing.T) { source interface{} result pgtype.Bytea }{ - {source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, - {source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}}, - {source: []byte(nil), result: pgtype.Bytea{Status: pgtype.Null}}, - {source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, - {source: _byteSlice(nil), result: pgtype.Bytea{Status: pgtype.Null}}, + {source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}}, + {source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Valid: true}}, + {source: []byte(nil), result: pgtype.Bytea{}}, + {source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}}, + {source: _byteSlice(nil), result: pgtype.Bytea{}}, } for i, tt := range successfulTests { @@ -52,12 +52,12 @@ func TestByteaAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &buf, expected: []byte{1, 2, 3}}, - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_buf, expected: _byteSlice{1, 2, 3}}, - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &pbuf, expected: &[]byte{1, 2, 3}}, - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}}, - {src: pgtype.Bytea{Status: pgtype.Null}, dst: &pbuf, expected: ((*[]byte)(nil))}, - {src: pgtype.Bytea{Status: pgtype.Null}, dst: &_pbuf, expected: ((*_byteSlice)(nil))}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, dst: &buf, expected: []byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, dst: &_buf, expected: _byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, dst: &pbuf, expected: &[]byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{}, dst: &pbuf, expected: ((*[]byte)(nil))}, + {src: pgtype.Bytea{}, dst: &_pbuf, expected: ((*_byteSlice)(nil))}, } for i, tt := range simpleTests { diff --git a/cid_test.go b/cid_test.go index 5b1150eb..041cb805 100644 --- a/cid_test.go +++ b/cid_test.go @@ -11,8 +11,8 @@ import ( func TestCIDTranscode(t *testing.T) { pgTypeName := "cid" values := []interface{}{ - &pgtype.CID{Uint: 42, Status: pgtype.Present}, - &pgtype.CID{Status: pgtype.Null}, + &pgtype.CID{Uint: 42, Valid: true}, + &pgtype.CID{}, } eqFunc := func(a, b interface{}) bool { return reflect.DeepEqual(a, b) @@ -27,7 +27,7 @@ func TestCIDSet(t *testing.T) { source interface{} result pgtype.CID }{ - {source: uint32(1), result: pgtype.CID{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.CID{Uint: 1, Valid: true}}, } for i, tt := range successfulTests { @@ -52,8 +52,8 @@ func TestCIDAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.CID{Uint: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.CID{}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -72,7 +72,7 @@ func TestCIDAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.CID{Uint: 42, Valid: true}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -90,7 +90,7 @@ func TestCIDAssignTo(t *testing.T) { src pgtype.CID dst interface{} }{ - {src: pgtype.CID{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.CID{}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/cidr_array.go b/cidr_array.go index 783c599c..48a6a4c1 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -15,13 +15,13 @@ import ( type CIDRArray struct { Elements []CIDR Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *CIDRArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = CIDRArray{Status: Null} + *dst = CIDRArray{} return nil } @@ -37,9 +37,9 @@ func (dst *CIDRArray) Set(src interface{}) error { case []*net.IPNet: if value == nil { - *dst = CIDRArray{Status: Null} + *dst = CIDRArray{} } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} + *dst = CIDRArray{Valid: true} } else { elements := make([]CIDR, len(value)) for i := range value { @@ -50,15 +50,15 @@ func (dst *CIDRArray) Set(src interface{}) error { *dst = CIDRArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []net.IP: if value == nil { - *dst = CIDRArray{Status: Null} + *dst = CIDRArray{} } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} + *dst = CIDRArray{Valid: true} } else { elements := make([]CIDR, len(value)) for i := range value { @@ -69,15 +69,15 @@ func (dst *CIDRArray) Set(src interface{}) error { *dst = CIDRArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*net.IP: if value == nil { - *dst = CIDRArray{Status: Null} + *dst = CIDRArray{} } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} + *dst = CIDRArray{Valid: true} } else { elements := make([]CIDR, len(value)) for i := range value { @@ -88,20 +88,20 @@ func (dst *CIDRArray) Set(src interface{}) error { *dst = CIDRArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []CIDR: if value == nil { - *dst = CIDRArray{Status: Null} + *dst = CIDRArray{} } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} + *dst = CIDRArray{Valid: true} } else { *dst = CIDRArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -110,7 +110,7 @@ func (dst *CIDRArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = CIDRArray{Status: Null} + *dst = CIDRArray{} return nil } @@ -119,7 +119,7 @@ func (dst *CIDRArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for CIDRArray", src) } if elementsLength == 0 { - *dst = CIDRArray{Status: Present} + *dst = CIDRArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -132,7 +132,7 @@ func (dst *CIDRArray) Set(src interface{}) error { *dst = CIDRArray{ Elements: make([]CIDR, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -199,93 +199,86 @@ func (dst *CIDRArray) setRecursive(value reflect.Value, index, dimension int) (i } func (dst CIDRArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *CIDRArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.IP: - *v = make([]*net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]*net.IPNet: + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]net.IP: + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*net.IP: + *v = make([]*net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -337,7 +330,7 @@ func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension in func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = CIDRArray{Status: Null} + *dst = CIDRArray{} return nil } @@ -366,14 +359,14 @@ func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = CIDRArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = CIDRArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = CIDRArray{Status: Null} + *dst = CIDRArray{} return nil } @@ -384,7 +377,7 @@ func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = CIDRArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = CIDRArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -409,16 +402,13 @@ func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = CIDRArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = CIDRArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -471,11 +461,8 @@ func (src CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -489,7 +476,7 @@ func (src CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/cidr_array_test.go b/cidr_array_test.go index 74c063fa..7821cf44 100644 --- a/cidr_array_test.go +++ b/cidr_array_test.go @@ -14,41 +14,41 @@ func TestCIDRArrayTranscode(t *testing.T) { &pgtype.CIDRArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {Status: pgtype.Null}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.CIDRArray{Status: pgtype.Null}, + &pgtype.CIDRArray{}, &pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - {Status: pgtype.Null}, - {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, + {}, + {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -61,33 +61,33 @@ func TestCIDRArraySet(t *testing.T) { { source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]*net.IPNet)(nil)), - result: pgtype.CIDRArray{Status: pgtype.Null}, + result: pgtype.CIDRArray{}, }, { source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]net.IP)(nil)), - result: pgtype.CIDRArray{Status: pgtype.Null}, + result: pgtype.CIDRArray{}, }, { source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, result: pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]*net.IPNet{ @@ -101,27 +101,27 @@ func TestCIDRArraySet(t *testing.T) { mustParseCIDR(t, "169.168.0.1/16")}}}}, result: pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, result: pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]*net.IPNet{ @@ -135,18 +135,18 @@ func TestCIDRArraySet(t *testing.T) { mustParseCIDR(t, "169.168.0.1/16")}}}}, result: pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -178,85 +178,85 @@ func TestCIDRArrayAssignTo(t *testing.T) { }{ { src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &ipnetSlice, expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, }, { src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{Status: pgtype.Null}}, + Elements: []pgtype.CIDR{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &ipnetSlice, expected: []*net.IPNet{nil}, }, { src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &ipSlice, expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, }, { src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{Status: pgtype.Null}}, + Elements: []pgtype.CIDR{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &ipSlice, expected: []net.IP{nil}, }, { - src: pgtype.CIDRArray{Status: pgtype.Null}, + src: pgtype.CIDRArray{}, dst: &ipnetSlice, expected: (([]*net.IPNet)(nil)), }, { - src: pgtype.CIDRArray{Status: pgtype.Present}, + src: pgtype.CIDRArray{Valid: true}, dst: &ipnetSlice, expected: []*net.IPNet{}, }, { - src: pgtype.CIDRArray{Status: pgtype.Null}, + src: pgtype.CIDRArray{}, dst: &ipSlice, expected: (([]net.IP)(nil)), }, { - src: pgtype.CIDRArray{Status: pgtype.Present}, + src: pgtype.CIDRArray{Valid: true}, dst: &ipSlice, expected: []net.IP{}, }, { src: pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &ipSliceDim2, expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, }, { src: pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &ipnetSliceDim4, expected: [][][][]*net.IPNet{ {{{ @@ -271,28 +271,28 @@ func TestCIDRArrayAssignTo(t *testing.T) { { src: pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &ipArrayDim2, expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, }, { src: pgtype.CIDRArray{ Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &ipnetArrayDim4, expected: [2][1][1][3]*net.IPNet{ {{{ diff --git a/circle.go b/circle.go index 4279650e..7524d7b9 100644 --- a/circle.go +++ b/circle.go @@ -12,9 +12,9 @@ import ( ) type Circle struct { - P Vec2 - R float64 - Status Status + P Vec2 + R float64 + Valid bool } func (dst *Circle) Set(src interface{}) error { @@ -22,14 +22,10 @@ func (dst *Circle) Set(src interface{}) error { } func (dst Circle) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Circle) AssignTo(dst interface{}) error { @@ -38,7 +34,7 @@ func (src *Circle) AssignTo(dst interface{}) error { func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Circle{Status: Null} + *dst = Circle{} return nil } @@ -68,13 +64,13 @@ func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Circle{P: Vec2{x, y}, R: r, Status: Present} + *dst = Circle{P: Vec2{x, y}, R: r, Valid: true} return nil } func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Circle{Status: Null} + *dst = Circle{} return nil } @@ -87,19 +83,16 @@ func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { r := binary.BigEndian.Uint64(src[16:]) *dst = Circle{ - P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, - R: math.Float64frombits(r), - Status: Present, + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + R: math.Float64frombits(r), + Valid: true, } return nil } func (src Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, @@ -112,11 +105,8 @@ func (src Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) @@ -128,7 +118,7 @@ func (src Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Circle) Scan(src interface{}) error { if src == nil { - *dst = Circle{Status: Null} + *dst = Circle{} return nil } diff --git a/circle_test.go b/circle_test.go index ba4f408b..416a1a41 100644 --- a/circle_test.go +++ b/circle_test.go @@ -9,8 +9,8 @@ import ( func TestCircleTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "circle", []interface{}{ - &pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Status: pgtype.Present}, - &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Status: pgtype.Present}, - &pgtype.Circle{Status: pgtype.Null}, + &pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, + &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Valid: true}, + &pgtype.Circle{}, }) } diff --git a/composite_bench_test.go b/composite_bench_test.go index 7aef8c4f..a1d91f8e 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -44,7 +44,7 @@ func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { } dst.A = a.Int - if b.Status == pgtype.Present { + if b.Valid { dst.B = &b.String } else { dst.B = nil diff --git a/composite_fields_test.go b/composite_fields_test.go index dc4d4c29..be0b8125 100644 --- a/composite_fields_test.go +++ b/composite_fields_test.go @@ -230,7 +230,7 @@ create type cf_encode as ( for _, simpleProtocol := range simpleProtocols { err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{&pgtype.Text{Status: pgtype.Null}, int32(1), "null", &pgtype.Float8{Status: pgtype.Null}, &pgtype.Text{Status: pgtype.Null}}, + pgtype.CompositeFields{&pgtype.Text{}, int32(1), "null", &pgtype.Float8{}, &pgtype.Text{}}, ).Scan( pgtype.CompositeFields{&a, &b, &c, &d, &e}, ) diff --git a/composite_type.go b/composite_type.go index 32e0aa26..90b7b6ff 100644 --- a/composite_type.go +++ b/composite_type.go @@ -16,7 +16,7 @@ type CompositeTypeField struct { } type CompositeType struct { - status Status + valid bool typeName string @@ -58,18 +58,15 @@ func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values } func (src CompositeType) Get() interface{} { - switch src.status { - case Present: - results := make(map[string]interface{}, len(src.valueTranscoders)) - for i := range src.valueTranscoders { - results[src.fields[i].Name] = src.valueTranscoders[i].Get() - } - return results - case Null: + if !src.valid { return nil - default: - return src.status } + + results := make(map[string]interface{}, len(src.valueTranscoders)) + for i := range src.valueTranscoders { + results[src.fields[i].Name] = src.valueTranscoders[i].Get() + } + return results } func (ct *CompositeType) NewTypeValue() Value { @@ -96,7 +93,7 @@ func (ct *CompositeType) Fields() []CompositeTypeField { func (dst *CompositeType) Set(src interface{}) error { if src == nil { - dst.status = Null + dst.valid = false return nil } @@ -110,10 +107,10 @@ func (dst *CompositeType) Set(src interface{}) error { return err } } - dst.status = Present + dst.valid = true case *[]interface{}: if value == nil { - dst.status = Null + dst.valid = false return nil } return dst.Set(*value) @@ -126,40 +123,38 @@ func (dst *CompositeType) Set(src interface{}) error { // AssignTo should never be called on composite value directly func (src CompositeType) AssignTo(dst interface{}) error { - switch src.status { - case Present: - switch v := dst.(type) { - case []interface{}: - if len(v) != len(src.valueTranscoders) { - return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders)) - } - for i := range src.valueTranscoders { - if v[i] == nil { - continue - } - - err := assignToOrSet(src.valueTranscoders[i], v[i]) - if err != nil { - return fmt.Errorf("unable to assign to dst[%d]: %v", i, err) - } - } - return nil - case *[]interface{}: - return src.AssignTo(*v) - default: - if isPtrStruct, err := src.assignToPtrStruct(dst); isPtrStruct { - return err - } - - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + + switch v := dst.(type) { + case []interface{}: + if len(v) != len(src.valueTranscoders) { + return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders)) + } + for i := range src.valueTranscoders { + if v[i] == nil { + continue + } + + err := assignToOrSet(src.valueTranscoders[i], v[i]) + if err != nil { + return fmt.Errorf("unable to assign to dst[%d]: %v", i, err) + } + } + return nil + case *[]interface{}: + return src.AssignTo(*v) + default: + if isPtrStruct, err := src.assignToPtrStruct(dst); isPtrStruct { + return err + } + + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } func assignToOrSet(src Value, dst interface{}) error { @@ -219,11 +214,8 @@ func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) { } func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { - switch src.status { - case Null: + if !src.valid { return nil, nil - case Undefined: - return nil, errUndefined } b := NewCompositeBinaryBuilder(ci, buf) @@ -240,7 +232,7 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, // type mismatch func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { if buf == nil { - dst.status = Null + dst.valid = false return nil } @@ -254,14 +246,14 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { return scanner.Err() } - dst.status = Present + dst.valid = true return nil } func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { if buf == nil { - dst.status = Null + dst.valid = false return nil } @@ -275,17 +267,14 @@ func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { return scanner.Err() } - dst.status = Present + dst.valid = true return nil } func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { - switch src.status { - case Null: + if !src.valid { return nil, nil - case Undefined: - return nil, errUndefined } b := NewCompositeTextBuilder(ci, buf) diff --git a/composite_type_test.go b/composite_type_test.go index 2349a67d..e06927fa 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -20,7 +20,7 @@ func TestCompositeTypeSetAndGet(t *testing.T) { {"b", pgtype.Int4OID}, }, ci) require.NoError(t, err) - assert.Equal(t, pgtype.Undefined, ct.Get()) + assert.Equal(t, nil, ct.Get()) nilTests := []struct { src interface{} @@ -48,7 +48,7 @@ func TestCompositeTypeSetAndGet(t *testing.T) { expected: map[string]interface{}{"a": nil, "b": nil}, }, { - src: []interface{}{&pgtype.Text{String: "hi", Status: pgtype.Present}, &pgtype.Int4{Int: 7, Status: pgtype.Present}}, + src: []interface{}{&pgtype.Text{String: "hi", Valid: true}, &pgtype.Int4{Int: 7, Valid: true}}, expected: map[string]interface{}{"a": "hi", "b": int32(7)}, }, } @@ -92,8 +92,8 @@ func TestCompositeTypeAssignTo(t *testing.T) { err = ct.AssignTo([]interface{}{&a, &b}) assert.NoError(t, err) - assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a) - assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b) + assert.Equal(t, pgtype.Text{String: "foo", Valid: true}, a) + assert.Equal(t, pgtype.Int4{Int: 42, Valid: true}, b) } // Allow nil destination component as no-op @@ -137,8 +137,8 @@ func TestCompositeTypeAssignTo(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, dst) - assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a) - assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b) + assert.Equal(t, pgtype.Text{String: "foo", Valid: true}, a) + assert.Equal(t, pgtype.Int4{Int: 42, Valid: true}, b) } // Struct fields positionally via reflection diff --git a/convert.go b/convert.go index de9ba9ba..21e208f5 100644 --- a/convert.go +++ b/convert.go @@ -208,8 +208,8 @@ func underlyingSliceType(val interface{}) (interface{}, bool) { return nil, false } -func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { - if srcStatus == Present { +func int64AssignTo(srcVal int64, srcValid bool, dst interface{}) error { + if srcValid { switch v := dst.(type) { case *int: if srcVal < int64(minInt) { @@ -291,7 +291,7 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { // allocate destination el.Set(reflect.New(el.Type().Elem())) } - return int64AssignTo(srcVal, srcStatus, el.Interface()) + return int64AssignTo(srcVal, srcValid, el.Interface()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if el.OverflowInt(int64(srcVal)) { return fmt.Errorf("cannot put %d into %T", srcVal, dst) @@ -314,7 +314,7 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { return nil } - // if dst is a pointer to pointer and srcStatus is not Present, nil it out + // if dst is a pointer to pointer and srcStatus is not Valid, nil it out if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { el := v.Elem() if el.Kind() == reflect.Ptr { @@ -323,11 +323,11 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { } } - return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst) } -func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { - if srcStatus == Present { +func float64AssignTo(srcVal float64, srcValid bool, dst interface{}) error { + if srcValid { switch v := dst.(type) { case *float32: *v = float32(srcVal) @@ -343,11 +343,11 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { // allocate destination el.Set(reflect.New(el.Type().Elem())) } - return float64AssignTo(srcVal, srcStatus, el.Interface()) + return float64AssignTo(srcVal, srcValid, el.Interface()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: i64 := int64(srcVal) if float64(i64) == srcVal { - return int64AssignTo(i64, srcStatus, dst) + return int64AssignTo(i64, srcValid, dst) } } } @@ -356,7 +356,7 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { return nil } - // if dst is a pointer to pointer and srcStatus is not Present, nil it out + // if dst is a pointer to pointer and srcStatus is not Valid, nil it out if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { el := v.Elem() if el.Kind() == reflect.Ptr { @@ -365,7 +365,7 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { } } - return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst) } func NullAssignTo(dst interface{}) error { diff --git a/custom_composite_test.go b/custom_composite_test.go index 9ca8dd5e..86203828 100644 --- a/custom_composite_test.go +++ b/custom_composite_test.go @@ -28,12 +28,12 @@ func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { } func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { - a := pgtype.Int4{src.a, pgtype.Present} + a := pgtype.Int4{src.a, true} var b pgtype.Text if src.b != nil { - b = pgtype.Text{*src.b, pgtype.Present} + b = pgtype.Text{*src.b, true} } else { - b = pgtype.Text{Status: pgtype.Null} + b = pgtype.Text{} } return (pgtype.CompositeFields{&a, &b}).EncodeBinary(ci, buf) diff --git a/date.go b/date.go index e8d21a78..5b7f47e6 100644 --- a/date.go +++ b/date.go @@ -12,7 +12,7 @@ import ( type Date struct { Time time.Time - Status Status + Valid bool InfinityModifier InfinityModifier } @@ -23,7 +23,7 @@ const ( func (dst *Date) Set(src interface{}) error { if src == nil { - *dst = Date{Status: Null} + *dst = Date{} return nil } @@ -36,18 +36,18 @@ func (dst *Date) Set(src interface{}) error { switch value := src.(type) { case time.Time: - *dst = Date{Time: value, Status: Present} + *dst = Date{Time: value, Valid: true} case string: return dst.DecodeText(nil, []byte(value)) case *time.Time: if value == nil { - *dst = Date{Status: Null} + *dst = Date{} } else { return dst.Set(*value) } case *string: if value == nil { - *dst = Date{Status: Null} + *dst = Date{} } else { return dst.Set(*value) } @@ -62,61 +62,54 @@ func (dst *Date) Set(src interface{}) error { } func (dst Date) Get() interface{} { - switch dst.Status { - case Present: - if dst.InfinityModifier != None { - return dst.InfinityModifier - } - return dst.Time - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time } func (src *Date) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *time.Time: - if src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Date{Status: Null} + *dst = Date{} return nil } sbuf := string(src) switch sbuf { case "infinity": - *dst = Date{Status: Present, InfinityModifier: Infinity} + *dst = Date{Valid: true, InfinityModifier: Infinity} case "-infinity": - *dst = Date{Status: Present, InfinityModifier: -Infinity} + *dst = Date{Valid: true, InfinityModifier: -Infinity} default: t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) if err != nil { return err } - *dst = Date{Time: t, Status: Present} + *dst = Date{Time: t, Valid: true} } return nil @@ -124,7 +117,7 @@ func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Date{Status: Null} + *dst = Date{} return nil } @@ -136,23 +129,20 @@ func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { switch dayOffset { case infinityDayOffset: - *dst = Date{Status: Present, InfinityModifier: Infinity} + *dst = Date{Valid: true, InfinityModifier: Infinity} case negativeInfinityDayOffset: - *dst = Date{Status: Present, InfinityModifier: -Infinity} + *dst = Date{Valid: true, InfinityModifier: -Infinity} default: t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) - *dst = Date{Time: t, Status: Present} + *dst = Date{Time: t, Valid: true} } return nil } func (src Date) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var s string @@ -170,11 +160,8 @@ func (src Date) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Date) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var daysSinceDateEpoch int32 @@ -197,7 +184,7 @@ func (src Date) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Date) Scan(src interface{}) error { if src == nil { - *dst = Date{Status: Null} + *dst = Date{} return nil } @@ -209,7 +196,7 @@ func (dst *Date) Scan(src interface{}) error { copy(srcCopy, src) return dst.DecodeText(nil, srcCopy) case time.Time: - *dst = Date{Time: src, Status: Present} + *dst = Date{Time: src, Valid: true} return nil } @@ -218,29 +205,19 @@ func (dst *Date) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Date) Value() (driver.Value, error) { - switch src.Status { - case Present: - if src.InfinityModifier != None { - return src.InfinityModifier.String(), nil - } - return src.Time, nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil } func (src Date) MarshalJSON() ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return []byte("null"), nil - case Undefined: - return nil, errUndefined - } - - if src.Status != Present { - return nil, errBadStatus } var s string @@ -265,22 +242,22 @@ func (dst *Date) UnmarshalJSON(b []byte) error { } if s == nil { - *dst = Date{Status: Null} + *dst = Date{} return nil } switch *s { case "infinity": - *dst = Date{Status: Present, InfinityModifier: Infinity} + *dst = Date{Valid: true, InfinityModifier: Infinity} case "-infinity": - *dst = Date{Status: Present, InfinityModifier: -Infinity} + *dst = Date{Valid: true, InfinityModifier: -Infinity} default: t, err := time.ParseInLocation("2006-01-02", *s, time.UTC) if err != nil { return err } - *dst = Date{Time: t, Status: Present} + *dst = Date{Time: t, Valid: true} } return nil diff --git a/date_array.go b/date_array.go index 24152fa0..9d3b32e2 100644 --- a/date_array.go +++ b/date_array.go @@ -15,13 +15,13 @@ import ( type DateArray struct { Elements []Date Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *DateArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = DateArray{Status: Null} + *dst = DateArray{} return nil } @@ -37,9 +37,9 @@ func (dst *DateArray) Set(src interface{}) error { case []time.Time: if value == nil { - *dst = DateArray{Status: Null} + *dst = DateArray{} } else if len(value) == 0 { - *dst = DateArray{Status: Present} + *dst = DateArray{Valid: true} } else { elements := make([]Date, len(value)) for i := range value { @@ -50,15 +50,15 @@ func (dst *DateArray) Set(src interface{}) error { *dst = DateArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*time.Time: if value == nil { - *dst = DateArray{Status: Null} + *dst = DateArray{} } else if len(value) == 0 { - *dst = DateArray{Status: Present} + *dst = DateArray{Valid: true} } else { elements := make([]Date, len(value)) for i := range value { @@ -69,20 +69,20 @@ func (dst *DateArray) Set(src interface{}) error { *dst = DateArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Date: if value == nil { - *dst = DateArray{Status: Null} + *dst = DateArray{} } else if len(value) == 0 { - *dst = DateArray{Status: Present} + *dst = DateArray{Valid: true} } else { *dst = DateArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -91,7 +91,7 @@ func (dst *DateArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = DateArray{Status: Null} + *dst = DateArray{} return nil } @@ -100,7 +100,7 @@ func (dst *DateArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for DateArray", src) } if elementsLength == 0 { - *dst = DateArray{Status: Present} + *dst = DateArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -113,7 +113,7 @@ func (dst *DateArray) Set(src interface{}) error { *dst = DateArray{ Elements: make([]Date, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -180,84 +180,77 @@ func (dst *DateArray) setRecursive(value reflect.Value, index, dimension int) (i } func (dst DateArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *DateArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -309,7 +302,7 @@ func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension in func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = DateArray{Status: Null} + *dst = DateArray{} return nil } @@ -338,14 +331,14 @@ func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = DateArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = DateArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = DateArray{Status: Null} + *dst = DateArray{} return nil } @@ -356,7 +349,7 @@ func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = DateArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = DateArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -381,16 +374,13 @@ func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = DateArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = DateArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -443,11 +433,8 @@ func (src DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -461,7 +448,7 @@ func (src DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/date_array_test.go b/date_array_test.go index 4458abfe..421427cd 100644 --- a/date_array_test.go +++ b/date_array_test.go @@ -14,41 +14,41 @@ func TestDateArrayTranscode(t *testing.T) { &pgtype.DateArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.DateArray{Status: pgtype.Null}, + &pgtype.DateArray{}, &pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -61,13 +61,13 @@ func TestDateArraySet(t *testing.T) { { source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, result: pgtype.DateArray{ - Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]time.Time)(nil)), - result: pgtype.DateArray{Status: pgtype.Null}, + result: pgtype.DateArray{}, }, { source: [][]time.Time{ @@ -75,10 +75,10 @@ func TestDateArraySet(t *testing.T) { {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, result: pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]time.Time{ @@ -92,18 +92,18 @@ func TestDateArraySet(t *testing.T) { time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, result: pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]time.Time{ @@ -111,10 +111,10 @@ func TestDateArraySet(t *testing.T) { {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, result: pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]time.Time{ @@ -128,18 +128,18 @@ func TestDateArraySet(t *testing.T) { time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, result: pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -170,30 +170,30 @@ func TestDateArrayAssignTo(t *testing.T) { }{ { src: pgtype.DateArray{ - Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &timeSlice, expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, }, { - src: pgtype.DateArray{Status: pgtype.Null}, + src: pgtype.DateArray{}, dst: &timeSlice, expected: (([]time.Time)(nil)), }, { - src: pgtype.DateArray{Status: pgtype.Present}, + src: pgtype.DateArray{Valid: true}, dst: &timeSlice, expected: []time.Time{}, }, { src: pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeSliceDim2, expected: [][]time.Time{ {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, @@ -202,18 +202,18 @@ func TestDateArrayAssignTo(t *testing.T) { { src: pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeSliceDim4, expected: [][][][]time.Time{ {{{ @@ -228,10 +228,10 @@ func TestDateArrayAssignTo(t *testing.T) { { src: pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeArrayDim2, expected: [2][1]time.Time{ {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, @@ -240,18 +240,18 @@ func TestDateArrayAssignTo(t *testing.T) { { src: pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeArrayDim4, expected: [2][1][1][3]time.Time{ {{{ @@ -282,37 +282,37 @@ func TestDateArrayAssignTo(t *testing.T) { }{ { src: pgtype.DateArray{ - Elements: []pgtype.Date{{Status: pgtype.Null}}, + Elements: []pgtype.Date{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &timeSlice, }, { src: pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeArrayDim2, }, { src: pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeSlice, }, { src: pgtype.DateArray{ Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeArrayDim4, }, } diff --git a/date_test.go b/date_test.go index 5c38e7a3..87425540 100644 --- a/date_test.go +++ b/date_test.go @@ -11,20 +11,20 @@ import ( func TestDateTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "date", []interface{}{ - &pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Status: pgtype.Null}, - &pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - &pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + &pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Date{}, + &pgtype.Date{Valid: true, InfinityModifier: pgtype.Infinity}, + &pgtype.Date{Valid: true, InfinityModifier: -pgtype.Infinity}, }, func(a, b interface{}) bool { at := a.(pgtype.Date) bt := b.(pgtype.Date) - return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + return at.Time.Equal(bt.Time) && at.Valid == bt.Valid && at.InfinityModifier == bt.InfinityModifier }) } @@ -35,14 +35,14 @@ func TestDateSet(t *testing.T) { source interface{} result pgtype.Date }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: "1999-12-31", result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: "1999-12-31", result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}}, } for i, tt := range successfulTests { @@ -67,8 +67,8 @@ func TestDateAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - {src: pgtype.Date{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Date{Time: time.Time{}}, dst: &ptim, expected: ((*time.Time)(nil))}, } for i, tt := range simpleTests { @@ -87,7 +87,7 @@ func TestDateAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, } for i, tt := range pointerAllocTests { @@ -105,9 +105,9 @@ func TestDateAssignTo(t *testing.T) { src pgtype.Date dst interface{} }{ - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Valid: true}, dst: &tim}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Valid: true}, dst: &tim}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, dst: &tim}, } for i, tt := range errorTests { @@ -123,12 +123,12 @@ func TestDateMarshalJSON(t *testing.T) { source pgtype.Date result string }{ - {source: pgtype.Date{Status: pgtype.Null}, result: "null"}, - {source: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, result: "\"2012-03-29\""}, - {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29\""}, - {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29\""}, - {source: pgtype.Date{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, result: "\"infinity\""}, - {source: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, result: "\"-infinity\""}, + {source: pgtype.Date{}, result: "null"}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Valid: true}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29\""}, + {source: pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""}, + {source: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""}, } for i, tt := range successfulTests { r, err := tt.source.MarshalJSON() @@ -147,12 +147,12 @@ func TestDateUnmarshalJSON(t *testing.T) { source string result pgtype.Date }{ - {source: "null", result: pgtype.Date{Status: pgtype.Null}}, - {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, - {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, - {source: "\"infinity\"", result: pgtype.Date{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, - {source: "\"-infinity\"", result: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + {source: "null", result: pgtype.Date{}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"infinity\"", result: pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: "\"-infinity\"", result: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, } for i, tt := range successfulTests { var r pgtype.Date @@ -161,7 +161,7 @@ func TestDateUnmarshalJSON(t *testing.T) { t.Errorf("%d: %v", i, err) } - if r.Time.Year() != tt.result.Time.Year() || r.Time.Month() != tt.result.Time.Month() || r.Time.Day() != tt.result.Time.Day() || r.Status != tt.result.Status || r.InfinityModifier != tt.result.InfinityModifier { + if r.Time.Year() != tt.result.Time.Year() || r.Time.Month() != tt.result.Time.Month() || r.Time.Day() != tt.result.Time.Day() || r.Valid != tt.result.Valid || r.InfinityModifier != tt.result.InfinityModifier { t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } } diff --git a/daterange.go b/daterange.go index 63164a5a..8b0c03f1 100644 --- a/daterange.go +++ b/daterange.go @@ -12,13 +12,13 @@ type Daterange struct { Upper Date LowerType BoundType UpperType BoundType - Status Status + Valid bool } func (dst *Daterange) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = Daterange{Status: Null} + *dst = Daterange{} return nil } @@ -36,15 +36,11 @@ func (dst *Daterange) Set(src interface{}) error { return nil } -func (dst Daterange) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: +func (src Daterange) Get() interface{} { + if !src.Valid { return nil - default: - return dst.Status } + return src } func (src *Daterange) AssignTo(dst interface{}) error { @@ -53,7 +49,7 @@ func (src *Daterange) AssignTo(dst interface{}) error { func (dst *Daterange) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Daterange{Status: Null} + *dst = Daterange{} return nil } @@ -62,7 +58,7 @@ func (dst *Daterange) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Daterange{Status: Present} + *dst = Daterange{Valid: true} dst.LowerType = utr.LowerType dst.UpperType = utr.UpperType @@ -88,7 +84,7 @@ func (dst *Daterange) DecodeText(ci *ConnInfo, src []byte) error { func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Daterange{Status: Null} + *dst = Daterange{} return nil } @@ -97,7 +93,7 @@ func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { return err } - *dst = Daterange{Status: Present} + *dst = Daterange{Valid: true} dst.LowerType = ubr.LowerType dst.UpperType = ubr.UpperType @@ -122,11 +118,8 @@ func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { } func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } switch src.LowerType { @@ -175,11 +168,8 @@ func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var rangeType byte @@ -245,7 +235,7 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Daterange) Scan(src interface{}) error { if src == nil { - *dst = Daterange{Status: Null} + *dst = Daterange{} return nil } diff --git a/daterange_test.go b/daterange_test.go index 54d51e2d..830942d0 100644 --- a/daterange_test.go +++ b/daterange_test.go @@ -10,32 +10,32 @@ import ( func TestDaterangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ - &pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, &pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, &pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Lower: pgtype.Date{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, - &pgtype.Daterange{Status: pgtype.Null}, + &pgtype.Daterange{}, }, func(aa, bb interface{}) bool { a := aa.(pgtype.Daterange) b := bb.(pgtype.Daterange) - return a.Status == b.Status && + return a.Valid == b.Valid && a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Status == b.Lower.Status && + a.Lower.Valid == b.Lower.Valid && a.Lower.InfinityModifier == b.Lower.InfinityModifier && a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Status == b.Upper.Status && + a.Upper.Valid == b.Upper.Valid && a.Upper.InfinityModifier == b.Upper.InfinityModifier }) } @@ -45,23 +45,23 @@ func TestDaterangeNormalize(t *testing.T) { { SQL: "select daterange('2010-01-01', '2010-01-11', '(]')", Value: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(2010, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2010, 1, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Lower: pgtype.Date{Time: time.Date(2010, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2010, 1, 12, 0, 0, 0, 0, time.UTC), Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, }, }, func(aa, bb interface{}) bool { a := aa.(pgtype.Daterange) b := bb.(pgtype.Daterange) - return a.Status == b.Status && + return a.Valid == b.Valid && a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Status == b.Lower.Status && + a.Lower.Valid == b.Lower.Valid && a.Lower.InfinityModifier == b.Lower.InfinityModifier && a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Status == b.Upper.Status && + a.Upper.Valid == b.Upper.Valid && a.Upper.InfinityModifier == b.Upper.InfinityModifier }) } @@ -73,48 +73,48 @@ func TestDaterangeSet(t *testing.T) { }{ { source: nil, - result: pgtype.Daterange{Status: pgtype.Null}, + result: pgtype.Daterange{}, }, { source: &pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, result: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, }, { source: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, result: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, }, { source: "[1990-12-31,2028-01-01)", result: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, }, } diff --git a/enum_array.go b/enum_array.go index 59b5a3ed..dbfb211d 100644 --- a/enum_array.go +++ b/enum_array.go @@ -11,13 +11,13 @@ import ( type EnumArray struct { Elements []GenericText Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *EnumArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = EnumArray{Status: Null} + *dst = EnumArray{} return nil } @@ -33,9 +33,9 @@ func (dst *EnumArray) Set(src interface{}) error { case []string: if value == nil { - *dst = EnumArray{Status: Null} + *dst = EnumArray{} } else if len(value) == 0 { - *dst = EnumArray{Status: Present} + *dst = EnumArray{Valid: true} } else { elements := make([]GenericText, len(value)) for i := range value { @@ -46,15 +46,15 @@ func (dst *EnumArray) Set(src interface{}) error { *dst = EnumArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*string: if value == nil { - *dst = EnumArray{Status: Null} + *dst = EnumArray{} } else if len(value) == 0 { - *dst = EnumArray{Status: Present} + *dst = EnumArray{Valid: true} } else { elements := make([]GenericText, len(value)) for i := range value { @@ -65,20 +65,20 @@ func (dst *EnumArray) Set(src interface{}) error { *dst = EnumArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []GenericText: if value == nil { - *dst = EnumArray{Status: Null} + *dst = EnumArray{} } else if len(value) == 0 { - *dst = EnumArray{Status: Present} + *dst = EnumArray{Valid: true} } else { *dst = EnumArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -87,7 +87,7 @@ func (dst *EnumArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = EnumArray{Status: Null} + *dst = EnumArray{} return nil } @@ -96,7 +96,7 @@ func (dst *EnumArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for EnumArray", src) } if elementsLength == 0 { - *dst = EnumArray{Status: Present} + *dst = EnumArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -109,7 +109,7 @@ func (dst *EnumArray) Set(src interface{}) error { *dst = EnumArray{ Elements: make([]GenericText, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -176,84 +176,77 @@ func (dst *EnumArray) setRecursive(value reflect.Value, index, dimension int) (i } func (dst EnumArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *EnumArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -305,7 +298,7 @@ func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension in func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = EnumArray{Status: Null} + *dst = EnumArray{} return nil } @@ -334,17 +327,14 @@ func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = EnumArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = EnumArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (src EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { diff --git a/enum_array_test.go b/enum_array_test.go index 659340f0..7d0ff864 100644 --- a/enum_array_test.go +++ b/enum_array_test.go @@ -24,29 +24,29 @@ func TestEnumArrayTranscode(t *testing.T) { &pgtype.EnumArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.EnumArray{ Elements: []pgtype.GenericText{ - {String: "red", Status: pgtype.Present}, - {Status: pgtype.Null}, + {String: "red", Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.EnumArray{Status: pgtype.Null}, + &pgtype.EnumArray{}, &pgtype.EnumArray{ Elements: []pgtype.GenericText{ - {String: "red", Status: pgtype.Present}, - {String: "green", Status: pgtype.Present}, - {String: "blue", Status: pgtype.Present}, - {String: "red", Status: pgtype.Present}, + {String: "red", Valid: true}, + {String: "green", Valid: true}, + {String: "blue", Valid: true}, + {String: "red", Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -59,61 +59,61 @@ func TestEnumArrayArraySet(t *testing.T) { { source: []string{"foo"}, result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}}, + Elements: []pgtype.GenericText{{String: "foo", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]string)(nil)), - result: pgtype.EnumArray{Status: pgtype.Null}, + result: pgtype.EnumArray{}, }, { source: [][]string{{"foo"}, {"bar"}}, result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, result: pgtype.EnumArray{ Elements: []pgtype.GenericText{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]string{{"foo"}, {"bar"}}, result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, result: pgtype.EnumArray{ Elements: []pgtype.GenericText{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -146,81 +146,81 @@ func TestEnumArrayArrayAssignTo(t *testing.T) { }{ { src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}}, + Elements: []pgtype.GenericText{{String: "foo", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &stringSlice, expected: []string{"foo"}, }, { src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.GenericText{{String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &namedStringSlice, expected: _stringSlice{"bar"}, }, { - src: pgtype.EnumArray{Status: pgtype.Null}, + src: pgtype.EnumArray{}, dst: &stringSlice, expected: (([]string)(nil)), }, { - src: pgtype.EnumArray{Status: pgtype.Present}, + src: pgtype.EnumArray{Valid: true}, dst: &stringSlice, expected: []string{}, }, { src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringSliceDim2, expected: [][]string{{"foo"}, {"bar"}}, }, { src: pgtype.EnumArray{ Elements: []pgtype.GenericText{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringSliceDim4, expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, }, { src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim2, expected: [2][1]string{{"foo"}, {"bar"}}, }, { src: pgtype.EnumArray{ Elements: []pgtype.GenericText{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim4, expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, }, @@ -243,31 +243,31 @@ func TestEnumArrayArrayAssignTo(t *testing.T) { }{ { src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{Status: pgtype.Null}}, + Elements: []pgtype.GenericText{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &stringSlice, }, { src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim2, }, { src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringSlice, }, { src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim4, }, } diff --git a/enum_type.go b/enum_type.go index d340320f..73ee3823 100644 --- a/enum_type.go +++ b/enum_type.go @@ -5,8 +5,8 @@ import "fmt" // EnumType represents a enum type. While it implements Value, this is only in service of its type conversion duties // when registered as a data type in a ConnType. It should not be used directly as a Value. type EnumType struct { - value string - status Status + value string + valid bool typeName string // PostgreSQL type name members []string // enum members @@ -25,8 +25,8 @@ func NewEnumType(typeName string, members []string) *EnumType { func (et *EnumType) NewTypeValue() Value { return &EnumType{ - value: et.value, - status: et.status, + value: et.value, + valid: et.valid, typeName: et.typeName, members: et.members, @@ -46,7 +46,7 @@ func (et *EnumType) Members() []string { // operation in the event the PostgreSQL enum type is modified during a connection. func (dst *EnumType) Set(src interface{}) error { if src == nil { - dst.status = Null + dst.valid = false return nil } @@ -60,20 +60,20 @@ func (dst *EnumType) Set(src interface{}) error { switch value := src.(type) { case string: dst.value = value - dst.status = Present + dst.valid = true case *string: if value == nil { - dst.status = Null + dst.valid = false } else { dst.value = *value - dst.status = Present + dst.valid = true } case []byte: if value == nil { - dst.status = Null + dst.valid = false } else { dst.value = string(value) - dst.status = Present + dst.valid = true } default: if originalSrc, ok := underlyingStringType(src); ok { @@ -86,38 +86,31 @@ func (dst *EnumType) Set(src interface{}) error { } func (dst EnumType) Get() interface{} { - switch dst.status { - case Present: - return dst.value - case Null: + if !dst.valid { return nil - default: - return dst.status } + return dst.value } func (src *EnumType) AssignTo(dst interface{}) error { - switch src.status { - case Present: - switch v := dst.(type) { - case *string: - *v = src.value - return nil - case *[]byte: - *v = make([]byte, len(src.value)) - copy(*v, src.value) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + switch v := dst.(type) { + case *string: + *v = src.value + return nil + case *[]byte: + *v = make([]byte, len(src.value)) + copy(*v, src.value) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } func (EnumType) PreferredResultFormat() int16 { @@ -126,7 +119,7 @@ func (EnumType) PreferredResultFormat() int16 { func (dst *EnumType) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - dst.status = Null + dst.valid = false return nil } @@ -139,7 +132,7 @@ func (dst *EnumType) DecodeText(ci *ConnInfo, src []byte) error { // and membersMap between connections. dst.value = string(src) } - dst.status = Present + dst.valid = true return nil } @@ -153,11 +146,8 @@ func (EnumType) PreferredParamFormat() int16 { } func (src EnumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.status { - case Null: + if !src.valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, src.value...), nil diff --git a/ext/gofrs-uuid/uuid.go b/ext/gofrs-uuid/uuid.go index a5e0a3c3..0e0ebed3 100644 --- a/ext/gofrs-uuid/uuid.go +++ b/ext/gofrs-uuid/uuid.go @@ -2,24 +2,20 @@ package uuid import ( "database/sql/driver" - "errors" "fmt" "github.com/gofrs/uuid" "github.com/jackc/pgtype" ) -var errUndefined = errors.New("cannot encode status undefined") -var errBadStatus = errors.New("invalid status") - type UUID struct { - UUID uuid.UUID - Status pgtype.Status + UUID uuid.UUID + Valid bool } func (dst *UUID) Set(src interface{}) error { if src == nil { - *dst = UUID{Status: pgtype.Null} + *dst = UUID{} return nil } @@ -32,21 +28,21 @@ func (dst *UUID) Set(src interface{}) error { switch value := src.(type) { case uuid.UUID: - *dst = UUID{UUID: value, Status: pgtype.Present} + *dst = UUID{UUID: value, Valid: true} case [16]byte: - *dst = UUID{UUID: uuid.UUID(value), Status: pgtype.Present} + *dst = UUID{UUID: uuid.UUID(value), Valid: true} case []byte: if len(value) != 16 { return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) } - *dst = UUID{Status: pgtype.Present} + *dst = UUID{Valid: true} copy(dst.UUID[:], value) case string: uuid, err := uuid.FromString(value) if err != nil { return err } - *dst = UUID{UUID: uuid, Status: pgtype.Present} + *dst = UUID{UUID: uuid, Valid: true} default: // If all else fails see if pgtype.UUID can handle it. If so, translate through that. pgUUID := &pgtype.UUID{} @@ -54,56 +50,49 @@ func (dst *UUID) Set(src interface{}) error { return fmt.Errorf("cannot convert %v to UUID", value) } - *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Status: pgUUID.Status} + *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Valid: pgUUID.Valid} } return nil } func (dst UUID) Get() interface{} { - switch dst.Status { - case pgtype.Present: - return dst.UUID - case pgtype.Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.UUID } func (src *UUID) AssignTo(dst interface{}) error { - switch src.Status { - case pgtype.Present: - switch v := dst.(type) { - case *uuid.UUID: - *v = src.UUID - return nil - case *[16]byte: - *v = [16]byte(src.UUID) - return nil - case *[]byte: - *v = make([]byte, 16) - copy(*v, src.UUID[:]) - return nil - case *string: - *v = src.UUID.String() - return nil - default: - if nextDst, retry := pgtype.GetAssignToDstType(v); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case pgtype.Null: + if !src.Valid { return pgtype.NullAssignTo(dst) } - return fmt.Errorf("cannot assign %v into %T", src, dst) + switch v := dst.(type) { + case *uuid.UUID: + *v = src.UUID + return nil + case *[16]byte: + *v = [16]byte(src.UUID) + return nil + case *[]byte: + *v = make([]byte, 16) + copy(*v, src.UUID[:]) + return nil + case *string: + *v = src.UUID.String() + return nil + default: + if nextDst, retry := pgtype.GetAssignToDstType(v); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { if src == nil { - *dst = UUID{Status: pgtype.Null} + *dst = UUID{} return nil } @@ -112,13 +101,13 @@ func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return err } - *dst = UUID{UUID: u, Status: pgtype.Present} + *dst = UUID{UUID: u, Valid: true} return nil } func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { if src == nil { - *dst = UUID{Status: pgtype.Null} + *dst = UUID{} return nil } @@ -126,37 +115,29 @@ func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return fmt.Errorf("invalid length for UUID: %v", len(src)) } - *dst = UUID{Status: pgtype.Present} + *dst = UUID{Valid: true} copy(dst.UUID[:], src) return nil } func (src UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case pgtype.Null: + if !src.Valid { return nil, nil - case pgtype.Undefined: - return nil, errUndefined } - return append(buf, src.UUID.String()...), nil } func (src UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case pgtype.Null: + if !src.Valid { return nil, nil - case pgtype.Undefined: - return nil, errUndefined } - return append(buf, src.UUID[:]...), nil } // Scan implements the database/sql Scanner interface. func (dst *UUID) Scan(src interface{}) error { if src == nil { - *dst = UUID{Status: pgtype.Null} + *dst = UUID{} return nil } @@ -176,16 +157,10 @@ func (src UUID) Value() (driver.Value, error) { } func (src UUID) MarshalJSON() ([]byte, error) { - switch src.Status { - case pgtype.Present: - return []byte(`"` + src.UUID.String() + `"`), nil - case pgtype.Null: + if !src.Valid { return []byte("null"), nil - case pgtype.Undefined: - return nil, errUndefined } - - return nil, errBadStatus + return []byte(`"` + src.UUID.String() + `"`), nil } func (dst *UUID) UnmarshalJSON(b []byte) error { @@ -195,11 +170,7 @@ func (dst *UUID) UnmarshalJSON(b []byte) error { return err } - status := pgtype.Null - if u.Valid { - status = pgtype.Present - } - *dst = UUID{UUID: u.UUID, Status: status} + *dst = UUID{UUID: u.UUID, Valid: u.Valid} return nil } diff --git a/ext/gofrs-uuid/uuid_test.go b/ext/gofrs-uuid/uuid_test.go index 56814524..3e5e4d82 100644 --- a/ext/gofrs-uuid/uuid_test.go +++ b/ext/gofrs-uuid/uuid_test.go @@ -4,15 +4,14 @@ import ( "bytes" "testing" - "github.com/jackc/pgtype" gofrs "github.com/jackc/pgtype/ext/gofrs-uuid" "github.com/jackc/pgtype/testutil" ) func TestUUIDTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - &gofrs.UUID{Status: pgtype.Null}, + &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + &gofrs.UUID{}, }) } @@ -22,20 +21,20 @@ func TestUUIDSet(t *testing.T) { result gofrs.UUID }{ { - source: &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + source: &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, }, { source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, }, { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, }, { source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, }, } @@ -54,7 +53,7 @@ func TestUUIDSet(t *testing.T) { func TestUUIDAssignTo(t *testing.T) { { - src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} var dst [16]byte expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -69,7 +68,7 @@ func TestUUIDAssignTo(t *testing.T) { } { - src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} var dst []byte expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -84,7 +83,7 @@ func TestUUIDAssignTo(t *testing.T) { } { - src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} var dst string expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index ef3ce201..3a9d99ba 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -2,7 +2,6 @@ package numeric import ( "database/sql/driver" - "errors" "fmt" "strconv" @@ -10,17 +9,14 @@ import ( "github.com/shopspring/decimal" ) -var errUndefined = errors.New("cannot encode status undefined") -var errBadStatus = errors.New("invalid status") - type Numeric struct { Decimal decimal.Decimal - Status pgtype.Status + Valid bool } func (dst *Numeric) Set(src interface{}) error { if src == nil { - *dst = Numeric{Status: pgtype.Null} + *dst = Numeric{} return nil } @@ -33,53 +29,53 @@ func (dst *Numeric) Set(src interface{}) error { switch value := src.(type) { case decimal.Decimal: - *dst = Numeric{Decimal: value, Status: pgtype.Present} + *dst = Numeric{Decimal: value, Valid: true} case decimal.NullDecimal: if value.Valid { - *dst = Numeric{Decimal: value.Decimal, Status: pgtype.Present} + *dst = Numeric{Decimal: value.Decimal, Valid: true} } else { - *dst = Numeric{Status: pgtype.Null} + *dst = Numeric{} } case float32: - *dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Valid: true} case float64: - *dst = Numeric{Decimal: decimal.NewFromFloat(value), Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.NewFromFloat(value), Valid: true} case int8: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} case uint8: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} case int16: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} case uint16: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} case int32: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} case uint32: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} case int64: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} case uint64: // uint64 could be greater than int64 so convert to string then to decimal dec, err := decimal.NewFromString(strconv.FormatUint(value, 10)) if err != nil { return err } - *dst = Numeric{Decimal: dec, Status: pgtype.Present} + *dst = Numeric{Decimal: dec, Valid: true} case int: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} case uint: // uint could be greater than int64 so convert to string then to decimal dec, err := decimal.NewFromString(strconv.FormatUint(uint64(value), 10)) if err != nil { return err } - *dst = Numeric{Decimal: dec, Status: pgtype.Present} + *dst = Numeric{Decimal: dec, Valid: true} case string: dec, err := decimal.NewFromString(value) if err != nil { return err } - *dst = Numeric{Decimal: dec, Status: pgtype.Present} + *dst = Numeric{Decimal: dec, Valid: true} default: // If all else fails see if pgtype.Numeric can handle it. If so, translate through that. num := &pgtype.Numeric{} @@ -96,140 +92,136 @@ func (dst *Numeric) Set(src interface{}) error { if err != nil { return fmt.Errorf("cannot convert %v to Numeric", value) } - *dst = Numeric{Decimal: dec, Status: pgtype.Present} + *dst = Numeric{Decimal: dec, Valid: true} } return nil } func (dst Numeric) Get() interface{} { - switch dst.Status { - case pgtype.Present: - return dst.Decimal - case pgtype.Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Decimal } func (src *Numeric) AssignTo(dst interface{}) error { - switch src.Status { - case pgtype.Present: - switch v := dst.(type) { - case *decimal.Decimal: - *v = src.Decimal - case *decimal.NullDecimal: - (*v).Valid = true - (*v).Decimal = src.Decimal - case *float32: - f, _ := src.Decimal.Float64() - *v = float32(f) - case *float64: - f, _ := src.Decimal.Float64() - *v = f - case *int: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int(n) - case *int8: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 8) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int8(n) - case *int16: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 16) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int16(n) - case *int32: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 32) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int32(n) - case *int64: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 64) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int64(n) - case *uint: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint(n) - case *uint8: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 8) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint8(n) - case *uint16: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 16) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint16(n) - case *uint32: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 32) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint32(n) - case *uint64: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 64) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint64(n) - default: - if nextDst, retry := pgtype.GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case pgtype.Null: + if !src.Valid { if v, ok := dst.(*decimal.NullDecimal); ok { (*v).Valid = false - } else { - return pgtype.NullAssignTo(dst) + (*v).Decimal = src.Decimal + return nil } + return pgtype.NullAssignTo(dst) + } + + switch v := dst.(type) { + case *decimal.Decimal: + *v = src.Decimal + case *decimal.NullDecimal: + (*v).Valid = true + (*v).Decimal = src.Decimal + case *float32: + f, _ := src.Decimal.Float64() + *v = float32(f) + case *float64: + f, _ := src.Decimal.Float64() + *v = f + case *int: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int(n) + case *int8: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 8) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int8(n) + case *int16: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 16) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int16(n) + case *int32: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 32) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int32(n) + case *int64: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 64) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int64(n) + case *uint: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint(n) + case *uint8: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 8) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint8(n) + case *uint16: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 16) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint16(n) + case *uint32: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 32) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint32(n) + case *uint64: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 64) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint64(n) + default: + if nextDst, retry := pgtype.GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) } return nil @@ -237,7 +229,7 @@ func (src *Numeric) AssignTo(dst interface{}) error { func (dst *Numeric) DecodeText(ci *pgtype.ConnInfo, src []byte) error { if src == nil { - *dst = Numeric{Status: pgtype.Null} + *dst = Numeric{} return nil } @@ -246,13 +238,13 @@ func (dst *Numeric) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return err } - *dst = Numeric{Decimal: dec, Status: pgtype.Present} + *dst = Numeric{Decimal: dec, Valid: true} return nil } func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { if src == nil { - *dst = Numeric{Status: pgtype.Null} + *dst = Numeric{} return nil } @@ -263,28 +255,21 @@ func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - *dst = Numeric{Decimal: decimal.NewFromBigInt(num.Int, num.Exp), Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.NewFromBigInt(num.Int, num.Exp), Valid: true} return nil } func (src Numeric) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case pgtype.Null: + if !src.Valid { return nil, nil - case pgtype.Undefined: - return nil, errUndefined } - return append(buf, src.Decimal.String()...), nil } func (src Numeric) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case pgtype.Null: + if !src.Valid { return nil, nil - case pgtype.Undefined: - return nil, errUndefined } // For now at least, implement this in terms of pgtype.Numeric @@ -299,13 +284,13 @@ func (src Numeric) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) // Scan implements the database/sql Scanner interface. func (dst *Numeric) Scan(src interface{}) error { if src == nil { - *dst = Numeric{Status: pgtype.Null} + *dst = Numeric{} return nil } switch src := src.(type) { case float64: - *dst = Numeric{Decimal: decimal.NewFromFloat(src), Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.NewFromFloat(src), Valid: true} return nil case string: return dst.DecodeText(nil, []byte(src)) @@ -318,27 +303,17 @@ func (dst *Numeric) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Numeric) Value() (driver.Value, error) { - switch src.Status { - case pgtype.Present: - return src.Decimal.Value() - case pgtype.Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + return src.Decimal.Value() } func (src Numeric) MarshalJSON() ([]byte, error) { - switch src.Status { - case pgtype.Present: - return src.Decimal.MarshalJSON() - case pgtype.Null: + if !src.Valid { return []byte("null"), nil - case pgtype.Undefined: - return nil, errUndefined } - - return nil, errBadStatus + return src.Decimal.MarshalJSON() } func (dst *Numeric) UnmarshalJSON(b []byte) error { @@ -348,11 +323,7 @@ func (dst *Numeric) UnmarshalJSON(b []byte) error { return err } - status := pgtype.Null - if d.Valid { - status = pgtype.Present - } - *dst = Numeric{Decimal: d.Decimal, Status: status} + *dst = Numeric{Decimal: d.Decimal, Valid: d.Valid} return nil } diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index e635da41..d130a69a 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -7,7 +7,6 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" shopspring "github.com/jackc/pgtype/ext/shopspring-numeric" "github.com/jackc/pgtype/testutil" "github.com/shopspring/decimal" @@ -26,100 +25,100 @@ func TestNumericNormalize(t *testing.T) { testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ { SQL: "select '0'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Valid: true}, }, { SQL: "select '1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}, }, { SQL: "select '10.00'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Valid: true}, }, { SQL: "select '1e-3'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Valid: true}, }, { SQL: "select '-1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, }, { SQL: "select '10000'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Valid: true}, }, { SQL: "select '3.14'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Valid: true}, }, { SQL: "select '1.1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Valid: true}, }, { SQL: "select '100010001'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Valid: true}, }, { SQL: "select '100010001.0001'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Valid: true}, }, { SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", Value: &shopspring.Numeric{ Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), - Status: pgtype.Present, + Valid: true, }, }, { SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", Value: &shopspring.Numeric{ Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), - Status: pgtype.Present, + Valid: true, }, }, { SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", Value: &shopspring.Numeric{ Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), - Status: pgtype.Present, + Valid: true, }, }, }, func(aa, bb interface{}) bool { a := aa.(shopspring.Numeric) b := bb.(shopspring.Numeric) - return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + return a.Valid == b.Valid && a.Decimal.Equal(b.Decimal) }) } func TestNumericTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Status: pgtype.Present}, - &shopspring.Numeric{Status: pgtype.Null}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Valid: true}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Valid: true}, + &shopspring.Numeric{}, }, func(aa, bb interface{}) bool { a := aa.(shopspring.Numeric) b := bb.(shopspring.Numeric) - return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + return a.Valid == b.Valid && a.Decimal.Equal(b.Decimal) }) } @@ -133,8 +132,8 @@ func TestNumericTranscodeFuzz(t *testing.T) { for i := 0; i < 500; i++ { num := fmt.Sprintf("%s.%s", (&big.Int{}).Rand(r, max).String(), (&big.Int{}).Rand(r, max).String()) negNum := "-" + num - values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, num), Status: pgtype.Present}) - values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Status: pgtype.Present}) + values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, num), Valid: true}) + values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Valid: true}) } testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, @@ -142,7 +141,7 @@ func TestNumericTranscodeFuzz(t *testing.T) { a := aa.(shopspring.Numeric) b := bb.(shopspring.Numeric) - return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + return a.Valid == b.Valid && a.Decimal.Equal(b.Decimal) }) } @@ -153,29 +152,29 @@ func TestNumericSet(t *testing.T) { source interface{} result *shopspring.Numeric }{ - {source: decimal.New(1, 0), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: decimal.NullDecimal{Valid: true, Decimal: decimal.New(1, 0)}, result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: decimal.NullDecimal{Valid: false}, result: &shopspring.Numeric{Status: pgtype.Null}}, - {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, - {source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, - {source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, - {source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, - {source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Status: pgtype.Present}}, - {source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Status: pgtype.Present}}, - {source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Status: pgtype.Present}}, - {source: float64(1.25), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.25"), Status: pgtype.Present}}, + {source: decimal.New(1, 0), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: decimal.NullDecimal{Valid: true, Decimal: decimal.New(1, 0)}, result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: decimal.NullDecimal{Valid: false}, result: &shopspring.Numeric{}}, + {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}}, + {source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}}, + {source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}}, + {source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}}, + {source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, + {source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Valid: true}}, + {source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Valid: true}}, + {source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Valid: true}}, + {source: float64(1.25), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.25"), Valid: true}}, } for i, tt := range successfulTests { @@ -185,7 +184,7 @@ func TestNumericSet(t *testing.T) { t.Errorf("%d: %v", i, err) } - if !(r.Status == tt.result.Status && r.Decimal.Equal(tt.result.Decimal)) { + if !(r.Valid == tt.result.Valid && r.Decimal.Equal(tt.result.Decimal)) { t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } } @@ -219,28 +218,28 @@ func TestNumericAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 0)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 3)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.042"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, -3)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &nd, expected: decimal.NullDecimal{Valid: true, Decimal: decimal.New(42, 0)}}, - {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &nd, expected: decimal.NullDecimal{Valid: false}}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &f32, expected: float32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &f64, expected: float64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Valid: true}, dst: &f32, expected: float32(4.2)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Valid: true}, dst: &f64, expected: float64(4.2)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i16, expected: int16(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i32, expected: int32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i64, expected: int64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Valid: true}, dst: &i64, expected: int64(42000)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i, expected: int(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui, expected: uint(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: &shopspring.Numeric{}, dst: &pi8, expected: ((*int8)(nil))}, + {src: &shopspring.Numeric{}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &d, expected: decimal.New(42, 0)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Valid: true}, dst: &d, expected: decimal.New(42, 3)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.042"), Valid: true}, dst: &d, expected: decimal.New(42, -3)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &nd, expected: decimal.NullDecimal{Valid: true, Decimal: decimal.New(42, 0)}}, + {src: &shopspring.Numeric{}, dst: &nd, expected: decimal.NullDecimal{Valid: false}}, } for i, tt := range simpleTests { @@ -280,8 +279,8 @@ func TestNumericAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &pf32, expected: float32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &pf64, expected: float64(42)}, } for i, tt := range pointerAllocTests { @@ -299,14 +298,14 @@ func TestNumericAssignTo(t *testing.T) { src *shopspring.Numeric dst interface{} }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Status: pgtype.Present}, dst: &i8}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Status: pgtype.Present}, dst: &i16}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui8}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui16}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui32}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui64}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui}, - {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &i32}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Valid: true}, dst: &i8}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Valid: true}, dst: &i16}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui8}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui16}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui32}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui64}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui}, + {src: &shopspring.Numeric{}, dst: &i32}, } for i, tt := range errorTests { diff --git a/float4.go b/float4.go index 89b9e8fa..36c46346 100644 --- a/float4.go +++ b/float4.go @@ -11,13 +11,13 @@ import ( ) type Float4 struct { - Float float32 - Status Status + Float float32 + Valid bool } func (dst *Float4) Set(src interface{}) error { if src == nil { - *dst = Float4{Status: Null} + *dst = Float4{} return nil } @@ -30,56 +30,56 @@ func (dst *Float4) Set(src interface{}) error { switch value := src.(type) { case float32: - *dst = Float4{Float: value, Status: Present} + *dst = Float4{Float: value, Valid: true} case float64: - *dst = Float4{Float: float32(value), Status: Present} + *dst = Float4{Float: float32(value), Valid: true} case int8: - *dst = Float4{Float: float32(value), Status: Present} + *dst = Float4{Float: float32(value), Valid: true} case uint8: - *dst = Float4{Float: float32(value), Status: Present} + *dst = Float4{Float: float32(value), Valid: true} case int16: - *dst = Float4{Float: float32(value), Status: Present} + *dst = Float4{Float: float32(value), Valid: true} case uint16: - *dst = Float4{Float: float32(value), Status: Present} + *dst = Float4{Float: float32(value), Valid: true} case int32: f32 := float32(value) if int32(f32) == value { - *dst = Float4{Float: f32, Status: Present} + *dst = Float4{Float: f32, Valid: true} } else { return fmt.Errorf("%v cannot be exactly represented as float32", value) } case uint32: f32 := float32(value) if uint32(f32) == value { - *dst = Float4{Float: f32, Status: Present} + *dst = Float4{Float: f32, Valid: true} } else { return fmt.Errorf("%v cannot be exactly represented as float32", value) } case int64: f32 := float32(value) if int64(f32) == value { - *dst = Float4{Float: f32, Status: Present} + *dst = Float4{Float: f32, Valid: true} } else { return fmt.Errorf("%v cannot be exactly represented as float32", value) } case uint64: f32 := float32(value) if uint64(f32) == value { - *dst = Float4{Float: f32, Status: Present} + *dst = Float4{Float: f32, Valid: true} } else { return fmt.Errorf("%v cannot be exactly represented as float32", value) } case int: f32 := float32(value) if int(f32) == value { - *dst = Float4{Float: f32, Status: Present} + *dst = Float4{Float: f32, Valid: true} } else { return fmt.Errorf("%v cannot be exactly represented as float32", value) } case uint: f32 := float32(value) if uint(f32) == value { - *dst = Float4{Float: f32, Status: Present} + *dst = Float4{Float: f32, Valid: true} } else { return fmt.Errorf("%v cannot be exactly represented as float32", value) } @@ -88,82 +88,82 @@ func (dst *Float4) Set(src interface{}) error { if err != nil { return err } - *dst = Float4{Float: float32(num), Status: Present} + *dst = Float4{Float: float32(num), Valid: true} case *float64: if value == nil { - *dst = Float4{Status: Null} + *dst = Float4{} } else { return dst.Set(*value) } case *float32: if value == nil { - *dst = Float4{Status: Null} + *dst = Float4{} } else { return dst.Set(*value) } case *int8: if value == nil { - *dst = Float4{Status: Null} + *dst = Float4{} } else { return dst.Set(*value) } case *uint8: if value == nil { - *dst = Float4{Status: Null} + *dst = Float4{} } else { return dst.Set(*value) } case *int16: if value == nil { - *dst = Float4{Status: Null} + *dst = Float4{} } else { return dst.Set(*value) } case *uint16: if value == nil { - *dst = Float4{Status: Null} + *dst = Float4{} } else { return dst.Set(*value) } case *int32: if value == nil { - *dst = Float4{Status: Null} + *dst = Float4{} } else { return dst.Set(*value) } case *uint32: if value == nil { - *dst = Float4{Status: Null} + *dst = Float4{} } else { return dst.Set(*value) } case *int64: if value == nil { - *dst = Float4{Status: Null} + *dst = Float4{} } else { return dst.Set(*value) } case *uint64: if value == nil { - *dst = Float4{Status: Null} + *dst = Float4{} } else { return dst.Set(*value) } case *int: if value == nil { - *dst = Float4{Status: Null} + *dst = Float4{} } else { return dst.Set(*value) } case *uint: if value == nil { - *dst = Float4{Status: Null} + *dst = Float4{} } else { return dst.Set(*value) } case *string: if value == nil { - *dst = Float4{Status: Null} + *dst = Float4{} } else { return dst.Set(*value) } @@ -178,23 +178,19 @@ func (dst *Float4) Set(src interface{}) error { } func (dst Float4) Get() interface{} { - switch dst.Status { - case Present: - return dst.Float - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Float } func (src *Float4) AssignTo(dst interface{}) error { - return float64AssignTo(float64(src.Float), src.Status, dst) + return float64AssignTo(float64(src.Float), src.Valid, dst) } func (dst *Float4) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Float4{Status: Null} + *dst = Float4{} return nil } @@ -203,13 +199,13 @@ func (dst *Float4) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Float4{Float: float32(n), Status: Present} + *dst = Float4{Float: float32(n), Valid: true} return nil } func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Float4{Status: Null} + *dst = Float4{} return nil } @@ -219,16 +215,13 @@ func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { n := int32(binary.BigEndian.Uint32(src)) - *dst = Float4{Float: math.Float32frombits(uint32(n)), Status: Present} + *dst = Float4{Float: math.Float32frombits(uint32(n)), Valid: true} return nil } func (src Float4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 32)...) @@ -236,11 +229,8 @@ func (src Float4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Float4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = pgio.AppendUint32(buf, math.Float32bits(src.Float)) @@ -250,13 +240,13 @@ func (src Float4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Float4) Scan(src interface{}) error { if src == nil { - *dst = Float4{Status: Null} + *dst = Float4{} return nil } switch src := src.(type) { case float64: - *dst = Float4{Float: float32(src), Status: Present} + *dst = Float4{Float: float32(src), Valid: true} return nil case string: return dst.DecodeText(nil, []byte(src)) @@ -271,12 +261,8 @@ func (dst *Float4) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Float4) Value() (driver.Value, error) { - switch src.Status { - case Present: - return float64(src.Float), nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + return float64(src.Float), nil } diff --git a/float4_array.go b/float4_array.go index 41f2ec8f..dcf6c1f7 100644 --- a/float4_array.go +++ b/float4_array.go @@ -14,13 +14,13 @@ import ( type Float4Array struct { Elements []Float4 Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *Float4Array) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = Float4Array{Status: Null} + *dst = Float4Array{} return nil } @@ -36,9 +36,9 @@ func (dst *Float4Array) Set(src interface{}) error { case []float32: if value == nil { - *dst = Float4Array{Status: Null} + *dst = Float4Array{} } else if len(value) == 0 { - *dst = Float4Array{Status: Present} + *dst = Float4Array{Valid: true} } else { elements := make([]Float4, len(value)) for i := range value { @@ -49,15 +49,15 @@ func (dst *Float4Array) Set(src interface{}) error { *dst = Float4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*float32: if value == nil { - *dst = Float4Array{Status: Null} + *dst = Float4Array{} } else if len(value) == 0 { - *dst = Float4Array{Status: Present} + *dst = Float4Array{Valid: true} } else { elements := make([]Float4, len(value)) for i := range value { @@ -68,20 +68,20 @@ func (dst *Float4Array) Set(src interface{}) error { *dst = Float4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Float4: if value == nil { - *dst = Float4Array{Status: Null} + *dst = Float4Array{} } else if len(value) == 0 { - *dst = Float4Array{Status: Present} + *dst = Float4Array{Valid: true} } else { *dst = Float4Array{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -90,7 +90,7 @@ func (dst *Float4Array) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = Float4Array{Status: Null} + *dst = Float4Array{} return nil } @@ -99,7 +99,7 @@ func (dst *Float4Array) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for Float4Array", src) } if elementsLength == 0 { - *dst = Float4Array{Status: Present} + *dst = Float4Array{Valid: true} return nil } if len(dimensions) == 0 { @@ -112,7 +112,7 @@ func (dst *Float4Array) Set(src interface{}) error { *dst = Float4Array{ Elements: make([]Float4, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -179,84 +179,77 @@ func (dst *Float4Array) setRecursive(value reflect.Value, index, dimension int) } func (dst Float4Array) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Float4Array) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]float32: - *v = make([]float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float32: - *v = make([]*float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]float32: + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float32: + *v = make([]*float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -308,7 +301,7 @@ func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Float4Array{Status: Null} + *dst = Float4Array{} return nil } @@ -337,14 +330,14 @@ func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Float4Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = Float4Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Float4Array{Status: Null} + *dst = Float4Array{} return nil } @@ -355,7 +348,7 @@ func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = Float4Array{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Float4Array{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -380,16 +373,13 @@ func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = Float4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Float4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -442,11 +432,8 @@ func (src Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -460,7 +447,7 @@ func (src Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/float4_array_test.go b/float4_array_test.go index db438999..9b401ac8 100644 --- a/float4_array_test.go +++ b/float4_array_test.go @@ -13,41 +13,41 @@ func TestFloat4ArrayTranscode(t *testing.T) { &pgtype.Float4Array{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.Float4Array{ Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, + {Float: 1, Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.Float4Array{Status: pgtype.Null}, + &pgtype.Float4Array{}, &pgtype.Float4Array{ Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Float: 6, Status: pgtype.Present}, + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {}, + {Float: 6, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.Float4Array{ Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -60,61 +60,61 @@ func TestFloat4ArraySet(t *testing.T) { { source: []float32{1}, result: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}}, + Elements: []pgtype.Float4{{Float: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]float32)(nil)), - result: pgtype.Float4Array{Status: pgtype.Null}, + result: pgtype.Float4Array{}, }, { source: [][]float32{{1}, {2}}, result: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, result: pgtype.Float4Array{ Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]float32{{1}, {2}}, result: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, result: pgtype.Float4Array{ Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -146,81 +146,81 @@ func TestFloat4ArrayAssignTo(t *testing.T) { }{ { src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, + Elements: []pgtype.Float4{{Float: 1.23, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &float32Slice, expected: []float32{1.23}, }, { src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, + Elements: []pgtype.Float4{{Float: 1.23, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &namedFloat32Slice, expected: _float32Slice{1.23}, }, { - src: pgtype.Float4Array{Status: pgtype.Null}, + src: pgtype.Float4Array{}, dst: &float32Slice, expected: (([]float32)(nil)), }, { - src: pgtype.Float4Array{Status: pgtype.Present}, + src: pgtype.Float4Array{Valid: true}, dst: &float32Slice, expected: []float32{}, }, { src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, expected: [][]float32{{1}, {2}}, dst: &float32SliceDim2, }, { src: pgtype.Float4Array{ Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, dst: &float32SliceDim4, }, { src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, expected: [2][1]float32{{1}, {2}}, dst: &float32ArrayDim2, }, { src: pgtype.Float4Array{ Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, dst: &float32ArrayDim4, }, @@ -243,31 +243,31 @@ func TestFloat4ArrayAssignTo(t *testing.T) { }{ { src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Status: pgtype.Null}}, + Elements: []pgtype.Float4{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &float32Slice, }, { src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &float32ArrayDim2, }, { src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &float32Slice, }, { src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &float32ArrayDim4, }, } diff --git a/float4_test.go b/float4_test.go index d2524cda..191df65e 100644 --- a/float4_test.go +++ b/float4_test.go @@ -10,12 +10,12 @@ import ( func TestFloat4Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "float4", []interface{}{ - &pgtype.Float4{Float: -1, Status: pgtype.Present}, - &pgtype.Float4{Float: 0, Status: pgtype.Present}, - &pgtype.Float4{Float: 0.00001, Status: pgtype.Present}, - &pgtype.Float4{Float: 1, Status: pgtype.Present}, - &pgtype.Float4{Float: 9999.99, Status: pgtype.Present}, - &pgtype.Float4{Float: 0, Status: pgtype.Null}, + &pgtype.Float4{Float: -1, Valid: true}, + &pgtype.Float4{Float: 0, Valid: true}, + &pgtype.Float4{Float: 0.00001, Valid: true}, + &pgtype.Float4{Float: 1, Valid: true}, + &pgtype.Float4{Float: 9999.99, Valid: true}, + &pgtype.Float4{Float: 0}, }) } @@ -24,22 +24,22 @@ func TestFloat4Set(t *testing.T) { source interface{} result pgtype.Float4 }{ - {source: float32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: float64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: float32(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: float64(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: int8(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: int16(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: int32(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: int64(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: int8(-1), result: pgtype.Float4{Float: -1, Valid: true}}, + {source: int16(-1), result: pgtype.Float4{Float: -1, Valid: true}}, + {source: int32(-1), result: pgtype.Float4{Float: -1, Valid: true}}, + {source: int64(-1), result: pgtype.Float4{Float: -1, Valid: true}}, + {source: uint8(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: uint16(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: uint32(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: uint64(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: "1", result: pgtype.Float4{Float: 1, Valid: true}}, + {source: _int8(1), result: pgtype.Float4{Float: 1, Valid: true}}, } for i, tt := range successfulTests { @@ -79,20 +79,20 @@ func TestFloat4AssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &f32, expected: float32(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &f64, expected: float64(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &i16, expected: int16(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &i32, expected: int32(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &i64, expected: int64(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &i, expected: int(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui, expected: uint(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float4{Float: 0}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Float4{Float: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, } for i, tt := range simpleTests { @@ -111,8 +111,8 @@ func TestFloat4AssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &pf32, expected: float32(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &pf64, expected: float64(42)}, } for i, tt := range pointerAllocTests { @@ -130,14 +130,14 @@ func TestFloat4AssignTo(t *testing.T) { src pgtype.Float4 dst interface{} }{ - {src: pgtype.Float4{Float: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Float4{Float: 40000, Status: pgtype.Present}, dst: &i16}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &i32}, + {src: pgtype.Float4{Float: 150, Valid: true}, dst: &i8}, + {src: pgtype.Float4{Float: 40000, Valid: true}, dst: &i16}, + {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui8}, + {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui16}, + {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui32}, + {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui64}, + {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui}, + {src: pgtype.Float4{Float: 0}, dst: &i32}, } for i, tt := range errorTests { diff --git a/float8.go b/float8.go index 4d9e7116..1038d283 100644 --- a/float8.go +++ b/float8.go @@ -11,13 +11,13 @@ import ( ) type Float8 struct { - Float float64 - Status Status + Float float64 + Valid bool } func (dst *Float8) Set(src interface{}) error { if src == nil { - *dst = Float8{Status: Null} + *dst = Float8{} return nil } @@ -30,46 +30,46 @@ func (dst *Float8) Set(src interface{}) error { switch value := src.(type) { case float32: - *dst = Float8{Float: float64(value), Status: Present} + *dst = Float8{Float: float64(value), Valid: true} case float64: - *dst = Float8{Float: value, Status: Present} + *dst = Float8{Float: value, Valid: true} case int8: - *dst = Float8{Float: float64(value), Status: Present} + *dst = Float8{Float: float64(value), Valid: true} case uint8: - *dst = Float8{Float: float64(value), Status: Present} + *dst = Float8{Float: float64(value), Valid: true} case int16: - *dst = Float8{Float: float64(value), Status: Present} + *dst = Float8{Float: float64(value), Valid: true} case uint16: - *dst = Float8{Float: float64(value), Status: Present} + *dst = Float8{Float: float64(value), Valid: true} case int32: - *dst = Float8{Float: float64(value), Status: Present} + *dst = Float8{Float: float64(value), Valid: true} case uint32: - *dst = Float8{Float: float64(value), Status: Present} + *dst = Float8{Float: float64(value), Valid: true} case int64: f64 := float64(value) if int64(f64) == value { - *dst = Float8{Float: f64, Status: Present} + *dst = Float8{Float: f64, Valid: true} } else { return fmt.Errorf("%v cannot be exactly represented as float64", value) } case uint64: f64 := float64(value) if uint64(f64) == value { - *dst = Float8{Float: f64, Status: Present} + *dst = Float8{Float: f64, Valid: true} } else { return fmt.Errorf("%v cannot be exactly represented as float64", value) } case int: f64 := float64(value) if int(f64) == value { - *dst = Float8{Float: f64, Status: Present} + *dst = Float8{Float: f64, Valid: true} } else { return fmt.Errorf("%v cannot be exactly represented as float64", value) } case uint: f64 := float64(value) if uint(f64) == value { - *dst = Float8{Float: f64, Status: Present} + *dst = Float8{Float: f64, Valid: true} } else { return fmt.Errorf("%v cannot be exactly represented as float64", value) } @@ -78,82 +78,82 @@ func (dst *Float8) Set(src interface{}) error { if err != nil { return err } - *dst = Float8{Float: float64(num), Status: Present} + *dst = Float8{Float: float64(num), Valid: true} case *float64: if value == nil { - *dst = Float8{Status: Null} + *dst = Float8{} } else { return dst.Set(*value) } case *float32: if value == nil { - *dst = Float8{Status: Null} + *dst = Float8{} } else { return dst.Set(*value) } case *int8: if value == nil { - *dst = Float8{Status: Null} + *dst = Float8{} } else { return dst.Set(*value) } case *uint8: if value == nil { - *dst = Float8{Status: Null} + *dst = Float8{} } else { return dst.Set(*value) } case *int16: if value == nil { - *dst = Float8{Status: Null} + *dst = Float8{} } else { return dst.Set(*value) } case *uint16: if value == nil { - *dst = Float8{Status: Null} + *dst = Float8{} } else { return dst.Set(*value) } case *int32: if value == nil { - *dst = Float8{Status: Null} + *dst = Float8{} } else { return dst.Set(*value) } case *uint32: if value == nil { - *dst = Float8{Status: Null} + *dst = Float8{} } else { return dst.Set(*value) } case *int64: if value == nil { - *dst = Float8{Status: Null} + *dst = Float8{} } else { return dst.Set(*value) } case *uint64: if value == nil { - *dst = Float8{Status: Null} + *dst = Float8{} } else { return dst.Set(*value) } case *int: if value == nil { - *dst = Float8{Status: Null} + *dst = Float8{} } else { return dst.Set(*value) } case *uint: if value == nil { - *dst = Float8{Status: Null} + *dst = Float8{} } else { return dst.Set(*value) } case *string: if value == nil { - *dst = Float8{Status: Null} + *dst = Float8{} } else { return dst.Set(*value) } @@ -168,23 +168,19 @@ func (dst *Float8) Set(src interface{}) error { } func (dst Float8) Get() interface{} { - switch dst.Status { - case Present: - return dst.Float - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Float } func (src *Float8) AssignTo(dst interface{}) error { - return float64AssignTo(src.Float, src.Status, dst) + return float64AssignTo(src.Float, src.Valid, dst) } func (dst *Float8) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Float8{Status: Null} + *dst = Float8{} return nil } @@ -193,13 +189,13 @@ func (dst *Float8) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Float8{Float: n, Status: Present} + *dst = Float8{Float: n, Valid: true} return nil } func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Float8{Status: Null} + *dst = Float8{} return nil } @@ -209,16 +205,13 @@ func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { n := int64(binary.BigEndian.Uint64(src)) - *dst = Float8{Float: math.Float64frombits(uint64(n)), Status: Present} + *dst = Float8{Float: math.Float64frombits(uint64(n)), Valid: true} return nil } func (src Float8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 64)...) @@ -226,11 +219,8 @@ func (src Float8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Float8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = pgio.AppendUint64(buf, math.Float64bits(src.Float)) @@ -240,13 +230,13 @@ func (src Float8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Float8) Scan(src interface{}) error { if src == nil { - *dst = Float8{Status: Null} + *dst = Float8{} return nil } switch src := src.(type) { case float64: - *dst = Float8{Float: src, Status: Present} + *dst = Float8{Float: src, Valid: true} return nil case string: return dst.DecodeText(nil, []byte(src)) @@ -261,12 +251,8 @@ func (dst *Float8) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Float8) Value() (driver.Value, error) { - switch src.Status { - case Present: - return src.Float, nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + return src.Float, nil } diff --git a/float8_array.go b/float8_array.go index 836ee19d..5e85e236 100644 --- a/float8_array.go +++ b/float8_array.go @@ -14,13 +14,13 @@ import ( type Float8Array struct { Elements []Float8 Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *Float8Array) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = Float8Array{Status: Null} + *dst = Float8Array{} return nil } @@ -36,9 +36,9 @@ func (dst *Float8Array) Set(src interface{}) error { case []float64: if value == nil { - *dst = Float8Array{Status: Null} + *dst = Float8Array{} } else if len(value) == 0 { - *dst = Float8Array{Status: Present} + *dst = Float8Array{Valid: true} } else { elements := make([]Float8, len(value)) for i := range value { @@ -49,15 +49,15 @@ func (dst *Float8Array) Set(src interface{}) error { *dst = Float8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*float64: if value == nil { - *dst = Float8Array{Status: Null} + *dst = Float8Array{} } else if len(value) == 0 { - *dst = Float8Array{Status: Present} + *dst = Float8Array{Valid: true} } else { elements := make([]Float8, len(value)) for i := range value { @@ -68,20 +68,20 @@ func (dst *Float8Array) Set(src interface{}) error { *dst = Float8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Float8: if value == nil { - *dst = Float8Array{Status: Null} + *dst = Float8Array{} } else if len(value) == 0 { - *dst = Float8Array{Status: Present} + *dst = Float8Array{Valid: true} } else { *dst = Float8Array{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -90,7 +90,7 @@ func (dst *Float8Array) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = Float8Array{Status: Null} + *dst = Float8Array{} return nil } @@ -99,7 +99,7 @@ func (dst *Float8Array) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for Float8Array", src) } if elementsLength == 0 { - *dst = Float8Array{Status: Present} + *dst = Float8Array{Valid: true} return nil } if len(dimensions) == 0 { @@ -112,7 +112,7 @@ func (dst *Float8Array) Set(src interface{}) error { *dst = Float8Array{ Elements: make([]Float8, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -179,84 +179,77 @@ func (dst *Float8Array) setRecursive(value reflect.Value, index, dimension int) } func (dst Float8Array) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Float8Array) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]float64: - *v = make([]float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float64: - *v = make([]*float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]float64: + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float64: + *v = make([]*float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -308,7 +301,7 @@ func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Float8Array{Status: Null} + *dst = Float8Array{} return nil } @@ -337,14 +330,14 @@ func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Float8Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = Float8Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Float8Array{Status: Null} + *dst = Float8Array{} return nil } @@ -355,7 +348,7 @@ func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = Float8Array{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Float8Array{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -380,16 +373,13 @@ func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = Float8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Float8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -442,11 +432,8 @@ func (src Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -460,7 +447,7 @@ func (src Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/float8_array_test.go b/float8_array_test.go index 85cb8f43..52209238 100644 --- a/float8_array_test.go +++ b/float8_array_test.go @@ -13,41 +13,41 @@ func TestFloat8ArrayTranscode(t *testing.T) { &pgtype.Float8Array{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.Float8Array{ Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, + {Float: 1, Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.Float8Array{Status: pgtype.Null}, + &pgtype.Float8Array{}, &pgtype.Float8Array{ Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Float: 6, Status: pgtype.Present}, + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {}, + {Float: 6, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.Float8Array{ Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -60,37 +60,37 @@ func TestFloat8ArraySet(t *testing.T) { { source: []float64{1}, result: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}}, + Elements: []pgtype.Float8{{Float: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]float64)(nil)), - result: pgtype.Float8Array{Status: pgtype.Null}, + result: pgtype.Float8Array{}, }, { source: [][]float64{{1}, {2}}, result: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, result: pgtype.Float8Array{ Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -122,81 +122,81 @@ func TestFloat8ArrayAssignTo(t *testing.T) { }{ { src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, + Elements: []pgtype.Float8{{Float: 1.23, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &float64Slice, expected: []float64{1.23}, }, { src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, + Elements: []pgtype.Float8{{Float: 1.23, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &namedFloat64Slice, expected: _float64Slice{1.23}, }, { - src: pgtype.Float8Array{Status: pgtype.Null}, + src: pgtype.Float8Array{}, dst: &float64Slice, expected: (([]float64)(nil)), }, { - src: pgtype.Float8Array{Status: pgtype.Present}, + src: pgtype.Float8Array{Valid: true}, dst: &float64Slice, expected: []float64{}, }, { src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, expected: [][]float64{{1}, {2}}, dst: &float64SliceDim2, }, { src: pgtype.Float8Array{ Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, expected: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, dst: &float64SliceDim4, }, { src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, expected: [2][1]float64{{1}, {2}}, dst: &float64ArrayDim2, }, { src: pgtype.Float8Array{ Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Float: 5, Status: pgtype.Present}, - {Float: 6, Status: pgtype.Present}}, + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, expected: [2][1][1][3]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, dst: &float64ArrayDim4, }, @@ -219,31 +219,31 @@ func TestFloat8ArrayAssignTo(t *testing.T) { }{ { src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Status: pgtype.Null}}, + Elements: []pgtype.Float8{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &float64Slice, }, { src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &float64ArrayDim2, }, { src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &float64Slice, }, { src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &float64ArrayDim4, }, } diff --git a/float8_test.go b/float8_test.go index 6bc7c652..dcc45879 100644 --- a/float8_test.go +++ b/float8_test.go @@ -10,12 +10,12 @@ import ( func TestFloat8Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ - &pgtype.Float8{Float: -1, Status: pgtype.Present}, - &pgtype.Float8{Float: 0, Status: pgtype.Present}, - &pgtype.Float8{Float: 0.00001, Status: pgtype.Present}, - &pgtype.Float8{Float: 1, Status: pgtype.Present}, - &pgtype.Float8{Float: 9999.99, Status: pgtype.Present}, - &pgtype.Float8{Float: 0, Status: pgtype.Null}, + &pgtype.Float8{Float: -1, Valid: true}, + &pgtype.Float8{Float: 0, Valid: true}, + &pgtype.Float8{Float: 0.00001, Valid: true}, + &pgtype.Float8{Float: 1, Valid: true}, + &pgtype.Float8{Float: 9999.99, Valid: true}, + &pgtype.Float8{Float: 0}, }) } @@ -24,22 +24,22 @@ func TestFloat8Set(t *testing.T) { source interface{} result pgtype.Float8 }{ - {source: float32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: float64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: float32(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: float64(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: int8(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: int16(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: int32(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: int64(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: int8(-1), result: pgtype.Float8{Float: -1, Valid: true}}, + {source: int16(-1), result: pgtype.Float8{Float: -1, Valid: true}}, + {source: int32(-1), result: pgtype.Float8{Float: -1, Valid: true}}, + {source: int64(-1), result: pgtype.Float8{Float: -1, Valid: true}}, + {source: uint8(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: uint16(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: uint32(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: uint64(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: "1", result: pgtype.Float8{Float: 1, Valid: true}}, + {source: _int8(1), result: pgtype.Float8{Float: 1, Valid: true}}, } for i, tt := range successfulTests { @@ -79,20 +79,20 @@ func TestFloat8AssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &f32, expected: float32(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &f64, expected: float64(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &i16, expected: int16(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &i32, expected: int32(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &i64, expected: int64(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &i, expected: int(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui, expected: uint(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float8{Float: 0}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Float8{Float: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, } for i, tt := range simpleTests { @@ -111,8 +111,8 @@ func TestFloat8AssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &pf32, expected: float32(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &pf64, expected: float64(42)}, } for i, tt := range pointerAllocTests { @@ -130,14 +130,14 @@ func TestFloat8AssignTo(t *testing.T) { src pgtype.Float8 dst interface{} }{ - {src: pgtype.Float8{Float: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Float8{Float: 40000, Status: pgtype.Present}, dst: &i16}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &i32}, + {src: pgtype.Float8{Float: 150, Valid: true}, dst: &i8}, + {src: pgtype.Float8{Float: 40000, Valid: true}, dst: &i16}, + {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui8}, + {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui16}, + {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui32}, + {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui64}, + {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui}, + {src: pgtype.Float8{Float: 0}, dst: &i32}, } for i, tt := range errorTests { diff --git a/go.mod b/go.mod index 63bae879..99c5b26e 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530 github.com/jackc/pgio v1.0.0 github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c - github.com/lib/pq v1.10.2 github.com/shopspring/decimal v1.2.0 github.com/stretchr/testify v1.7.0 ) diff --git a/hstore.go b/hstore.go index f46eeaf6..25406a74 100644 --- a/hstore.go +++ b/hstore.go @@ -16,13 +16,13 @@ import ( // Hstore represents an hstore column that can be null or have null values // associated with its keys. type Hstore struct { - Map map[string]Text - Status Status + Map map[string]Text + Valid bool } func (dst *Hstore) Set(src interface{}) error { if src == nil { - *dst = Hstore{Status: Null} + *dst = Hstore{} return nil } @@ -37,19 +37,19 @@ func (dst *Hstore) Set(src interface{}) error { case map[string]string: m := make(map[string]Text, len(value)) for k, v := range value { - m[k] = Text{String: v, Status: Present} + m[k] = Text{String: v, Valid: true} } - *dst = Hstore{Map: m, Status: Present} + *dst = Hstore{Map: m, Valid: true} case map[string]*string: m := make(map[string]Text, len(value)) for k, v := range value { if v == nil { - m[k] = Text{Status: Null} + m[k] = Text{} } else { - m[k] = Text{String: *v, Status: Present} + m[k] = Text{String: *v, Valid: true} } } - *dst = Hstore{Map: m, Status: Present} + *dst = Hstore{Map: m, Valid: true} default: return fmt.Errorf("cannot convert %v to Hstore", src) } @@ -58,58 +58,48 @@ func (dst *Hstore) Set(src interface{}) error { } func (dst Hstore) Get() interface{} { - switch dst.Status { - case Present: - return dst.Map - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Map } func (src *Hstore) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *map[string]string: - *v = make(map[string]string, len(src.Map)) - for k, val := range src.Map { - if val.Status != Present { - return fmt.Errorf("cannot decode %#v into %T", src, dst) - } - (*v)[k] = val.String - } - return nil - case *map[string]*string: - *v = make(map[string]*string, len(src.Map)) - for k, val := range src.Map { - switch val.Status { - case Null: - (*v)[k] = nil - case Present: - (*v)[k] = &val.String - default: - return fmt.Errorf("cannot decode %#v into %T", src, dst) - } - } - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + switch v := dst.(type) { + case *map[string]string: + *v = make(map[string]string, len(src.Map)) + for k, val := range src.Map { + if !val.Valid { + return fmt.Errorf("cannot decode %#v into %T", src, dst) + } + (*v)[k] = val.String + } + return nil + case *map[string]*string: + *v = make(map[string]*string, len(src.Map)) + for k, val := range src.Map { + if val.Valid { + (*v)[k] = &val.String + } else { + (*v)[k] = nil + } + } + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Hstore{Status: Null} + *dst = Hstore{} return nil } @@ -123,13 +113,13 @@ func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { m[keys[i]] = values[i] } - *dst = Hstore{Map: m, Status: Present} + *dst = Hstore{Map: m, Valid: true} return nil } func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Hstore{Status: Null} + *dst = Hstore{} return nil } @@ -176,17 +166,14 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { m[key] = value } - *dst = Hstore{Map: m, Status: Present} + *dst = Hstore{Map: m, Valid: true} return nil } func (src Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } firstPair := true @@ -218,11 +205,8 @@ func (src Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = pgio.AppendInt32(buf, int32(len(src.Map))) @@ -372,7 +356,7 @@ func parseHstore(s string) (k []string, v []Text, err error) { case hsVal: switch r { case '"': //End of the value - values = append(values, Text{String: buf.String(), Status: Present}) + values = append(values, Text{String: buf.String(), Valid: true}) buf = bytes.Buffer{} state = hsNext case '\\': //Potential escaped character @@ -401,7 +385,7 @@ func parseHstore(s string) (k []string, v []Text, err error) { nulBuf[i] = r } if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { - values = append(values, Text{Status: Null}) + values = append(values, Text{}) state = hsNext } else { err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) @@ -440,7 +424,7 @@ func parseHstore(s string) (k []string, v []Text, err error) { // Scan implements the database/sql Scanner interface. func (dst *Hstore) Scan(src interface{}) error { if src == nil { - *dst = Hstore{Status: Null} + *dst = Hstore{} return nil } diff --git a/hstore_array.go b/hstore_array.go index 47b4b3ff..0ca5d4fb 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -14,13 +14,13 @@ import ( type HstoreArray struct { Elements []Hstore Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *HstoreArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = HstoreArray{Status: Null} + *dst = HstoreArray{} return nil } @@ -36,9 +36,9 @@ func (dst *HstoreArray) Set(src interface{}) error { case []map[string]string: if value == nil { - *dst = HstoreArray{Status: Null} + *dst = HstoreArray{} } else if len(value) == 0 { - *dst = HstoreArray{Status: Present} + *dst = HstoreArray{Valid: true} } else { elements := make([]Hstore, len(value)) for i := range value { @@ -49,20 +49,20 @@ func (dst *HstoreArray) Set(src interface{}) error { *dst = HstoreArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Hstore: if value == nil { - *dst = HstoreArray{Status: Null} + *dst = HstoreArray{} } else if len(value) == 0 { - *dst = HstoreArray{Status: Present} + *dst = HstoreArray{Valid: true} } else { *dst = HstoreArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -71,7 +71,7 @@ func (dst *HstoreArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = HstoreArray{Status: Null} + *dst = HstoreArray{} return nil } @@ -80,7 +80,7 @@ func (dst *HstoreArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for HstoreArray", src) } if elementsLength == 0 { - *dst = HstoreArray{Status: Present} + *dst = HstoreArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -93,7 +93,7 @@ func (dst *HstoreArray) Set(src interface{}) error { *dst = HstoreArray{ Elements: make([]Hstore, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -160,75 +160,68 @@ func (dst *HstoreArray) setRecursive(value reflect.Value, index, dimension int) } func (dst HstoreArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *HstoreArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]map[string]string: - *v = make([]map[string]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]map[string]string: + *v = make([]map[string]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -280,7 +273,7 @@ func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = HstoreArray{Status: Null} + *dst = HstoreArray{} return nil } @@ -309,14 +302,14 @@ func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = HstoreArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = HstoreArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = HstoreArray{Status: Null} + *dst = HstoreArray{} return nil } @@ -327,7 +320,7 @@ func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = HstoreArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = HstoreArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -352,16 +345,13 @@ func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = HstoreArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = HstoreArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -414,11 +404,8 @@ func (src HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -432,7 +419,7 @@ func (src HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/hstore_array_test.go b/hstore_array_test.go index 672eca4a..11290fb1 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -29,16 +29,16 @@ func TestHstoreArrayTranscode(t *testing.T) { conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.HstoreArray{}, Name: "_hstore", OID: hstoreArrayOID}) text := func(s string) pgtype.Text { - return pgtype.Text{String: s, Status: pgtype.Present} + return pgtype.Text{String: s, Valid: true} } values := []pgtype.Hstore{ - {Map: map[string]pgtype.Text{}, Status: pgtype.Present}, - {Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, - {Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, - {Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, - {Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, - {Status: pgtype.Null}, + {Map: map[string]pgtype.Text{}, Valid: true}, + {Map: map[string]pgtype.Text{"foo": text("bar")}, Valid: true}, + {Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Valid: true}, + {Map: map[string]pgtype.Text{"NULL": text("bar")}, Valid: true}, + {Map: map[string]pgtype.Text{"foo": text("NULL")}, Valid: true}, + {}, } specialStrings := []string{ @@ -52,22 +52,22 @@ func TestHstoreArrayTranscode(t *testing.T) { } for _, s := range specialStrings { // Special key values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Valid: true}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Valid: true}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Valid: true}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Valid: true}) // is key // Special value values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Valid: true}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Valid: true}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Valid: true}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Valid: true}) // is key } src := &pgtype.HstoreArray{ Elements: values, Dimensions: []pgtype.ArrayDimension{{Length: int32(len(values)), LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, } _, err = conn.Prepare(context.Background(), "test", "select $1::hstore[]") @@ -98,8 +98,8 @@ func TestHstoreArrayTranscode(t *testing.T) { continue } - if result.Status != src.Status { - t.Errorf("%v: expected Status %v, got %v", fc.formatCode, src.Status, result.Status) + if result.Valid != src.Valid { + t.Errorf("%v: expected Valid %v, got %v", fc.formatCode, src.Valid, result.Valid) continue } @@ -112,8 +112,8 @@ func TestHstoreArrayTranscode(t *testing.T) { a := src.Elements[i] b := result.Elements[i] - if a.Status != b.Status { - t.Errorf("%v element idx %d: expected status %v, got %v", fc.formatCode, i, a.Status, b.Status) + if a.Valid != b.Valid { + t.Errorf("%v element idx %d: expected Valid %v, got %v", fc.formatCode, i, a.Valid, b.Valid) } if len(a.Map) != len(b.Map) { @@ -139,12 +139,12 @@ func TestHstoreArraySet(t *testing.T) { result: pgtype.HstoreArray{ Elements: []pgtype.Hstore{ { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, }, }, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, }, { @@ -152,16 +152,16 @@ func TestHstoreArraySet(t *testing.T) { result: pgtype.HstoreArray{ Elements: []pgtype.Hstore{ { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, }, }, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, }, { @@ -171,28 +171,28 @@ func TestHstoreArraySet(t *testing.T) { result: pgtype.HstoreArray{ Elements: []pgtype.Hstore{ { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, + Valid: true, }, }, Dimensions: []pgtype.ArrayDimension{ @@ -200,7 +200,7 @@ func TestHstoreArraySet(t *testing.T) { {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present, + Valid: true, }, }, { @@ -208,16 +208,16 @@ func TestHstoreArraySet(t *testing.T) { result: pgtype.HstoreArray{ Elements: []pgtype.Hstore{ { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, }, }, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, }, { @@ -227,28 +227,28 @@ func TestHstoreArraySet(t *testing.T) { result: pgtype.HstoreArray{ Elements: []pgtype.Hstore{ { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, + Valid: true, }, }, Dimensions: []pgtype.ArrayDimension{ @@ -256,7 +256,7 @@ func TestHstoreArraySet(t *testing.T) { {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present, + Valid: true, }, }, } @@ -290,35 +290,35 @@ func TestHstoreArrayAssignTo(t *testing.T) { src: pgtype.HstoreArray{ Elements: []pgtype.Hstore{ { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, }, }, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &hstoreSlice, expected: []map[string]string{{"foo": "bar"}}}, { - src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), + src: pgtype.HstoreArray{}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), }, { - src: pgtype.HstoreArray{Status: pgtype.Present}, dst: &hstoreSlice, expected: []map[string]string{}, + src: pgtype.HstoreArray{Valid: true}, dst: &hstoreSlice, expected: []map[string]string{}, }, { src: pgtype.HstoreArray{ Elements: []pgtype.Hstore{ { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, }, }, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &hstoreSliceDim2, expected: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, @@ -327,28 +327,28 @@ func TestHstoreArrayAssignTo(t *testing.T) { src: pgtype.HstoreArray{ Elements: []pgtype.Hstore{ { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, + Valid: true, }, }, Dimensions: []pgtype.ArrayDimension{ @@ -356,7 +356,7 @@ func TestHstoreArrayAssignTo(t *testing.T) { {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present, + Valid: true, }, dst: &hstoreSliceDim4, expected: [][][][]map[string]string{ @@ -367,16 +367,16 @@ func TestHstoreArrayAssignTo(t *testing.T) { src: pgtype.HstoreArray{ Elements: []pgtype.Hstore{ { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, }, }, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &hstoreArrayDim2, expected: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, @@ -385,28 +385,28 @@ func TestHstoreArrayAssignTo(t *testing.T) { src: pgtype.HstoreArray{ Elements: []pgtype.Hstore{ { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, + Valid: true, }, { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, + Valid: true, }, }, Dimensions: []pgtype.ArrayDimension{ @@ -414,7 +414,7 @@ func TestHstoreArrayAssignTo(t *testing.T) { {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present, + Valid: true, }, dst: &hstoreArrayDim4, expected: [2][1][1][3]map[string]string{ diff --git a/hstore_test.go b/hstore_test.go index 73ee0612..9c26a3df 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -10,22 +10,22 @@ import ( func TestHstoreTranscode(t *testing.T) { text := func(s string) pgtype.Text { - return pgtype.Text{String: s, Status: pgtype.Present} + return pgtype.Text{String: s, Valid: true} } values := []interface{}{ - &pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(""), "bar": text(""), "baz": text("123")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{}, Valid: true}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(""), "bar": text(""), "baz": text("123")}, Valid: true}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Valid: true}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Valid: true}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Valid: true}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Valid: true}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"": text("bar")}, Valid: true}, &pgtype.Hstore{ - Map: map[string]pgtype.Text{"a": text("a"), "b": {Status: pgtype.Null}, "c": text("c"), "d": {Status: pgtype.Null}, "e": text("e")}, - Status: pgtype.Present, + Map: map[string]pgtype.Text{"a": text("a"), "b": {}, "c": text("c"), "d": {}, "e": text("e")}, + Valid: true, }, - &pgtype.Hstore{Status: pgtype.Null}, + &pgtype.Hstore{}, } specialStrings := []string{ @@ -39,23 +39,23 @@ func TestHstoreTranscode(t *testing.T) { } for _, s := range specialStrings { // Special key values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Valid: true}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Valid: true}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Valid: true}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Valid: true}) // is key // Special value values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Valid: true}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Valid: true}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Valid: true}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Valid: true}) // is key } testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { a := ai.(pgtype.Hstore) b := bi.(pgtype.Hstore) - if len(a.Map) != len(b.Map) || a.Status != b.Status { + if len(a.Map) != len(b.Map) || a.Valid != b.Valid { return false } @@ -70,12 +70,12 @@ func TestHstoreTranscode(t *testing.T) { } func TestHstoreTranscodeNullable(t *testing.T) { - text := func(s string, status pgtype.Status) pgtype.Text { - return pgtype.Text{String: s, Status: status} + text := func(s string, valid bool) pgtype.Text { + return pgtype.Text{String: s, Valid: valid} } values := []interface{}{ - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("", pgtype.Null)}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("", false)}, Valid: true}, } specialStrings := []string{ @@ -89,17 +89,17 @@ func TestHstoreTranscodeNullable(t *testing.T) { } for _, s := range specialStrings { // Special key values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("", pgtype.Null)}, Status: pgtype.Present}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("", pgtype.Null)}, Status: pgtype.Present}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("", pgtype.Null)}, Status: pgtype.Present}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("", pgtype.Null)}, Status: pgtype.Present}) // is key + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("", false)}, Valid: true}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("", false)}, Valid: true}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("", false)}, Valid: true}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("", false)}, Valid: true}) // is key } testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { a := ai.(pgtype.Hstore) b := bi.(pgtype.Hstore) - if len(a.Map) != len(b.Map) || a.Status != b.Status { + if len(a.Map) != len(b.Map) || a.Valid != b.Valid { return false } @@ -118,7 +118,7 @@ func TestHstoreSet(t *testing.T) { src map[string]string result pgtype.Hstore }{ - {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}}, + {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, Valid: true}}, } for i, tt := range successfulTests { @@ -139,7 +139,7 @@ func TestHstoreSetNullable(t *testing.T) { src map[string]*string result pgtype.Hstore }{ - {src: map[string]*string{"foo": nil}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {Status: pgtype.Null}}, Status: pgtype.Present}}, + {src: map[string]*string{"foo": nil}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {}}, Valid: true}}, } for i, tt := range successfulTests { @@ -163,8 +163,8 @@ func TestHstoreAssignTo(t *testing.T) { dst *map[string]string expected map[string]string }{ - {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}, dst: &m, expected: map[string]string{"foo": "bar"}}, - {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]string)(nil))}, + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, Valid: true}, dst: &m, expected: map[string]string{"foo": "bar"}}, + {src: pgtype.Hstore{}, dst: &m, expected: ((map[string]string)(nil))}, } for i, tt := range simpleTests { @@ -187,8 +187,8 @@ func TestHstoreAssignToNullable(t *testing.T) { dst *map[string]*string expected map[string]*string }{ - {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {Status: pgtype.Null}}, Status: pgtype.Present}, dst: &m, expected: map[string]*string{"foo": nil}}, - {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]*string)(nil))}, + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {}}, Valid: true}, dst: &m, expected: map[string]*string{"foo": nil}}, + {src: pgtype.Hstore{}, dst: &m, expected: ((map[string]*string)(nil))}, } for i, tt := range simpleTests { diff --git a/inet.go b/inet.go index f35f88ba..4b3217a9 100644 --- a/inet.go +++ b/inet.go @@ -16,13 +16,13 @@ const ( // Inet represents both inet and cidr PostgreSQL types. type Inet struct { - IPNet *net.IPNet - Status Status + IPNet *net.IPNet + Valid bool } func (dst *Inet) Set(src interface{}) error { if src == nil { - *dst = Inet{Status: Null} + *dst = Inet{} return nil } @@ -35,14 +35,14 @@ func (dst *Inet) Set(src interface{}) error { switch value := src.(type) { case net.IPNet: - *dst = Inet{IPNet: &value, Status: Present} + *dst = Inet{IPNet: &value, Valid: true} case net.IP: if len(value) == 0 { - *dst = Inet{Status: Null} + *dst = Inet{} } else { bitCount := len(value) * 8 mask := net.CIDRMask(bitCount, bitCount) - *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Status: Present} + *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Valid: true} } case string: ip, ipnet, err := net.ParseCIDR(value) @@ -58,22 +58,22 @@ func (dst *Inet) Set(src interface{}) error { } } ipnet.IP = ip - *dst = Inet{IPNet: ipnet, Status: Present} + *dst = Inet{IPNet: ipnet, Valid: true} case *net.IPNet: if value == nil { - *dst = Inet{Status: Null} + *dst = Inet{} } else { return dst.Set(*value) } case *net.IP: if value == nil { - *dst = Inet{Status: Null} + *dst = Inet{} } else { return dst.Set(*value) } case *string: if value == nil { - *dst = Inet{Status: Null} + *dst = Inet{} } else { return dst.Set(*value) } @@ -88,51 +88,44 @@ func (dst *Inet) Set(src interface{}) error { } func (dst Inet) Get() interface{} { - switch dst.Status { - case Present: - return dst.IPNet - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.IPNet } func (src *Inet) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *net.IPNet: - *v = net.IPNet{ - IP: make(net.IP, len(src.IPNet.IP)), - Mask: make(net.IPMask, len(src.IPNet.Mask)), - } - copy(v.IP, src.IPNet.IP) - copy(v.Mask, src.IPNet.Mask) - return nil - case *net.IP: - if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = make(net.IP, len(src.IPNet.IP)) - copy(*v, src.IPNet.IP) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + switch v := dst.(type) { + case *net.IPNet: + *v = net.IPNet{ + IP: make(net.IP, len(src.IPNet.IP)), + Mask: make(net.IPMask, len(src.IPNet.Mask)), + } + copy(v.IP, src.IPNet.IP) + copy(v.Mask, src.IPNet.Mask) + return nil + case *net.IP: + if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = make(net.IP, len(src.IPNet.IP)) + copy(*v, src.IPNet.IP) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Inet{Status: Null} + *dst = Inet{} return nil } @@ -158,13 +151,13 @@ func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { *ipnet = net.IPNet{IP: ip, Mask: net.CIDRMask(ones, len(ip)*8)} } - *dst = Inet{IPNet: ipnet, Status: Present} + *dst = Inet{IPNet: ipnet, Valid: true} return nil } func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Inet{Status: Null} + *dst = Inet{} return nil } @@ -185,17 +178,14 @@ func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { } ipnet.Mask = net.CIDRMask(int(bits), len(ipnet.IP)*8) - *dst = Inet{IPNet: &ipnet, Status: Present} + *dst = Inet{IPNet: &ipnet, Valid: true} return nil } func (src Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, src.IPNet.String()...), nil @@ -203,11 +193,8 @@ func (src Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { // EncodeBinary encodes src into w. func (src Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var family byte @@ -236,7 +223,7 @@ func (src Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Inet) Scan(src interface{}) error { if src == nil { - *dst = Inet{Status: Null} + *dst = Inet{} return nil } diff --git a/inet_array.go b/inet_array.go index 2460a1c4..7f41c4e5 100644 --- a/inet_array.go +++ b/inet_array.go @@ -15,13 +15,13 @@ import ( type InetArray struct { Elements []Inet Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *InetArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = InetArray{Status: Null} + *dst = InetArray{} return nil } @@ -37,9 +37,9 @@ func (dst *InetArray) Set(src interface{}) error { case []*net.IPNet: if value == nil { - *dst = InetArray{Status: Null} + *dst = InetArray{} } else if len(value) == 0 { - *dst = InetArray{Status: Present} + *dst = InetArray{Valid: true} } else { elements := make([]Inet, len(value)) for i := range value { @@ -50,15 +50,15 @@ func (dst *InetArray) Set(src interface{}) error { *dst = InetArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []net.IP: if value == nil { - *dst = InetArray{Status: Null} + *dst = InetArray{} } else if len(value) == 0 { - *dst = InetArray{Status: Present} + *dst = InetArray{Valid: true} } else { elements := make([]Inet, len(value)) for i := range value { @@ -69,15 +69,15 @@ func (dst *InetArray) Set(src interface{}) error { *dst = InetArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*net.IP: if value == nil { - *dst = InetArray{Status: Null} + *dst = InetArray{} } else if len(value) == 0 { - *dst = InetArray{Status: Present} + *dst = InetArray{Valid: true} } else { elements := make([]Inet, len(value)) for i := range value { @@ -88,20 +88,20 @@ func (dst *InetArray) Set(src interface{}) error { *dst = InetArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Inet: if value == nil { - *dst = InetArray{Status: Null} + *dst = InetArray{} } else if len(value) == 0 { - *dst = InetArray{Status: Present} + *dst = InetArray{Valid: true} } else { *dst = InetArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -110,7 +110,7 @@ func (dst *InetArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = InetArray{Status: Null} + *dst = InetArray{} return nil } @@ -119,7 +119,7 @@ func (dst *InetArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for InetArray", src) } if elementsLength == 0 { - *dst = InetArray{Status: Present} + *dst = InetArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -132,7 +132,7 @@ func (dst *InetArray) Set(src interface{}) error { *dst = InetArray{ Elements: make([]Inet, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -199,93 +199,86 @@ func (dst *InetArray) setRecursive(value reflect.Value, index, dimension int) (i } func (dst InetArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *InetArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.IP: - *v = make([]*net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]*net.IPNet: + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]net.IP: + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*net.IP: + *v = make([]*net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -337,7 +330,7 @@ func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension in func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = InetArray{Status: Null} + *dst = InetArray{} return nil } @@ -366,14 +359,14 @@ func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = InetArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = InetArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = InetArray{Status: Null} + *dst = InetArray{} return nil } @@ -384,7 +377,7 @@ func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = InetArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = InetArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -409,16 +402,13 @@ func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = InetArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = InetArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -471,11 +461,8 @@ func (src InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -489,7 +476,7 @@ func (src InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/inet_array_test.go b/inet_array_test.go index 46dc7d12..1019c7eb 100644 --- a/inet_array_test.go +++ b/inet_array_test.go @@ -14,41 +14,41 @@ func TestInetArrayTranscode(t *testing.T) { &pgtype.InetArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.InetArray{ Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {Status: pgtype.Null}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.InetArray{Status: pgtype.Null}, + &pgtype.InetArray{}, &pgtype.InetArray{ Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - {Status: pgtype.Null}, - {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, + {}, + {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.InetArray{ Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -61,33 +61,33 @@ func TestInetArraySet(t *testing.T) { { source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]*net.IPNet)(nil)), - result: pgtype.InetArray{Status: pgtype.Null}, + result: pgtype.InetArray{}, }, { source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]net.IP)(nil)), - result: pgtype.InetArray{Status: pgtype.Null}, + result: pgtype.InetArray{}, }, { source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, result: pgtype.InetArray{ Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]*net.IPNet{ @@ -101,27 +101,27 @@ func TestInetArraySet(t *testing.T) { mustParseCIDR(t, "169.168.0.1/16")}}}}, result: pgtype.InetArray{ Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, result: pgtype.InetArray{ Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]*net.IPNet{ @@ -135,18 +135,18 @@ func TestInetArraySet(t *testing.T) { mustParseCIDR(t, "169.168.0.1/16")}}}}, result: pgtype.InetArray{ Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -178,85 +178,85 @@ func TestInetArrayAssignTo(t *testing.T) { }{ { src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &ipnetSlice, expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, }, { src: pgtype.InetArray{ - Elements: []pgtype.Inet{{Status: pgtype.Null}}, + Elements: []pgtype.Inet{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &ipnetSlice, expected: []*net.IPNet{nil}, }, { src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &ipSlice, expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, }, { src: pgtype.InetArray{ - Elements: []pgtype.Inet{{Status: pgtype.Null}}, + Elements: []pgtype.Inet{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &ipSlice, expected: []net.IP{nil}, }, { - src: pgtype.InetArray{Status: pgtype.Null}, + src: pgtype.InetArray{}, dst: &ipnetSlice, expected: (([]*net.IPNet)(nil)), }, { - src: pgtype.InetArray{Status: pgtype.Present}, + src: pgtype.InetArray{Valid: true}, dst: &ipnetSlice, expected: []*net.IPNet{}, }, { - src: pgtype.InetArray{Status: pgtype.Null}, + src: pgtype.InetArray{}, dst: &ipSlice, expected: (([]net.IP)(nil)), }, { - src: pgtype.InetArray{Status: pgtype.Present}, + src: pgtype.InetArray{Valid: true}, dst: &ipSlice, expected: []net.IP{}, }, { src: pgtype.InetArray{ Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &ipSliceDim2, expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, }, { src: pgtype.InetArray{ Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &ipnetSliceDim4, expected: [][][][]*net.IPNet{ {{{ @@ -271,28 +271,28 @@ func TestInetArrayAssignTo(t *testing.T) { { src: pgtype.InetArray{ Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &ipArrayDim2, expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, }, { src: pgtype.InetArray{ Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &ipnetArrayDim4, expected: [2][1][1][3]*net.IPNet{ {{{ diff --git a/inet_test.go b/inet_test.go index 09c6b21f..c2a5dc28 100644 --- a/inet_test.go +++ b/inet_test.go @@ -12,35 +12,35 @@ import ( func TestInetTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "inet", []interface{}{ - &pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "127.0.0.1/8"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "12.34.56.65/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "192.168.1.16/24"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "255.0.0.0/8"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "255.255.255.255/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "::1/64"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "::/0"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "::1/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), Status: pgtype.Present}, - &pgtype.Inet{Status: pgtype.Null}, + &pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "127.0.0.1/8"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "12.34.56.65/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "192.168.1.16/24"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "255.0.0.0/8"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "255.255.255.255/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "::1/64"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "::/0"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "::1/128"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), Valid: true}, + &pgtype.Inet{}, }) } func TestCidrTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "cidr", []interface{}{ - &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - &pgtype.Inet{Status: pgtype.Null}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, + &pgtype.Inet{}, }) } @@ -49,13 +49,13 @@ func TestInetSet(t *testing.T) { source interface{} result pgtype.Inet }{ - {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4"), Mask: net.CIDRMask(24, 32)}, Status: pgtype.Present}}, - {source: "10.0.0.1", result: pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Status: pgtype.Present}}, - {source: "2607:f8b0:4009:80b::200e", result: pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Status: pgtype.Present}}, - {source: net.ParseIP(""), result: pgtype.Inet{Status: pgtype.Null}}, + {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4"), Mask: net.CIDRMask(24, 32)}, Valid: true}}, + {source: "10.0.0.1", result: pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Valid: true}}, + {source: "2607:f8b0:4009:80b::200e", result: pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Valid: true}}, + {source: net.ParseIP(""), result: pgtype.Inet{}}, } for i, tt := range successfulTests { @@ -66,8 +66,8 @@ func TestInetSet(t *testing.T) { continue } - assert.Equalf(t, tt.result.Status, r.Status, "%d: Status", i) - if tt.result.Status == pgtype.Present { + assert.Equalf(t, tt.result.Valid, r.Valid, "%d: Status", i) + if tt.result.Valid { assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: IP", i) assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: Mask", i) } @@ -85,10 +85,10 @@ func TestInetAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, - {src: pgtype.Inet{Status: pgtype.Null}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, - {src: pgtype.Inet{Status: pgtype.Null}, dst: &pip, expected: ((*net.IP)(nil))}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, + {src: pgtype.Inet{}, dst: &pip, expected: ((*net.IP)(nil))}, } for i, tt := range simpleTests { @@ -107,8 +107,8 @@ func TestInetAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, } for i, tt := range pointerAllocTests { @@ -126,8 +126,8 @@ func TestInetAssignTo(t *testing.T) { src pgtype.Inet dst interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, - {src: pgtype.Inet{Status: pgtype.Null}, dst: &ipnet}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Valid: true}, dst: &ip}, + {src: pgtype.Inet{}, dst: &ipnet}, } for i, tt := range errorTests { diff --git a/int2.go b/int2.go index 3eb5aeb5..bbfee1cf 100644 --- a/int2.go +++ b/int2.go @@ -11,13 +11,13 @@ import ( ) type Int2 struct { - Int int16 - Status Status + Int int16 + Valid bool } func (dst *Int2) Set(src interface{}) error { if src == nil { - *dst = Int2{Status: Null} + *dst = Int2{} return nil } @@ -30,16 +30,16 @@ func (dst *Int2) Set(src interface{}) error { switch value := src.(type) { case int8: - *dst = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Valid: true} case uint8: - *dst = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Valid: true} case int16: - *dst = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Valid: true} case uint16: if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *dst = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Valid: true} case int32: if value < math.MinInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) @@ -47,12 +47,12 @@ func (dst *Int2) Set(src interface{}) error { if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *dst = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Valid: true} case uint32: if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *dst = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Valid: true} case int64: if value < math.MinInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) @@ -60,12 +60,12 @@ func (dst *Int2) Set(src interface{}) error { if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *dst = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Valid: true} case uint64: if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *dst = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Valid: true} case int: if value < math.MinInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) @@ -73,103 +73,103 @@ func (dst *Int2) Set(src interface{}) error { if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *dst = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Valid: true} case uint: if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *dst = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Valid: true} case string: num, err := strconv.ParseInt(value, 10, 16) if err != nil { return err } - *dst = Int2{Int: int16(num), Status: Present} + *dst = Int2{Int: int16(num), Valid: true} case float32: if value > math.MaxInt16 { return fmt.Errorf("%f is greater than maximum value for Int2", value) } - *dst = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Valid: true} case float64: if value > math.MaxInt16 { return fmt.Errorf("%f is greater than maximum value for Int2", value) } - *dst = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Valid: true} case *int8: if value == nil { - *dst = Int2{Status: Null} + *dst = Int2{} } else { return dst.Set(*value) } case *uint8: if value == nil { - *dst = Int2{Status: Null} + *dst = Int2{} } else { return dst.Set(*value) } case *int16: if value == nil { - *dst = Int2{Status: Null} + *dst = Int2{} } else { return dst.Set(*value) } case *uint16: if value == nil { - *dst = Int2{Status: Null} + *dst = Int2{} } else { return dst.Set(*value) } case *int32: if value == nil { - *dst = Int2{Status: Null} + *dst = Int2{} } else { return dst.Set(*value) } case *uint32: if value == nil { - *dst = Int2{Status: Null} + *dst = Int2{} } else { return dst.Set(*value) } case *int64: if value == nil { - *dst = Int2{Status: Null} + *dst = Int2{} } else { return dst.Set(*value) } case *uint64: if value == nil { - *dst = Int2{Status: Null} + *dst = Int2{} } else { return dst.Set(*value) } case *int: if value == nil { - *dst = Int2{Status: Null} + *dst = Int2{} } else { return dst.Set(*value) } case *uint: if value == nil { - *dst = Int2{Status: Null} + *dst = Int2{} } else { return dst.Set(*value) } case *string: if value == nil { - *dst = Int2{Status: Null} + *dst = Int2{} } else { return dst.Set(*value) } case *float32: if value == nil { - *dst = Int2{Status: Null} + *dst = Int2{} } else { return dst.Set(*value) } case *float64: if value == nil { - *dst = Int2{Status: Null} + *dst = Int2{} } else { return dst.Set(*value) } @@ -184,23 +184,19 @@ func (dst *Int2) Set(src interface{}) error { } func (dst Int2) Get() interface{} { - switch dst.Status { - case Present: - return dst.Int - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Int } func (src *Int2) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Status, dst) + return int64AssignTo(int64(src.Int), src.Valid, dst) } func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int2{Status: Null} + *dst = Int2{} return nil } @@ -209,13 +205,13 @@ func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Int2{Int: int16(n), Status: Present} + *dst = Int2{Int: int16(n), Valid: true} return nil } func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int2{Status: Null} + *dst = Int2{} return nil } @@ -224,27 +220,21 @@ func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { } n := int16(binary.BigEndian.Uint16(src)) - *dst = Int2{Int: n, Status: Present} + *dst = Int2{Int: n, Valid: true} return nil } func (src Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil } func (src Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return pgio.AppendInt16(buf, src.Int), nil @@ -253,7 +243,7 @@ func (src Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Int2) Scan(src interface{}) error { if src == nil { - *dst = Int2{Status: Null} + *dst = Int2{} return nil } @@ -265,7 +255,7 @@ func (dst *Int2) Scan(src interface{}) error { if src > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", src) } - *dst = Int2{Int: int16(src), Status: Present} + *dst = Int2{Int: int16(src), Valid: true} return nil case string: return dst.DecodeText(nil, []byte(src)) @@ -280,25 +270,15 @@ func (dst *Int2) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Int2) Value() (driver.Value, error) { - switch src.Status { - case Present: - return int64(src.Int), nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + return int64(src.Int), nil } func (src Int2) MarshalJSON() ([]byte, error) { - switch src.Status { - case Present: - return []byte(strconv.FormatInt(int64(src.Int), 10)), nil - case Null: + if !src.Valid { return []byte("null"), nil - case Undefined: - return nil, errUndefined } - - return nil, errBadStatus + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil } diff --git a/int2_array.go b/int2_array.go index a5133845..d96240dc 100644 --- a/int2_array.go +++ b/int2_array.go @@ -14,13 +14,13 @@ import ( type Int2Array struct { Elements []Int2 Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *Int2Array) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} return nil } @@ -36,9 +36,9 @@ func (dst *Int2Array) Set(src interface{}) error { case []int16: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -49,15 +49,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*int16: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -68,15 +68,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []uint16: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -87,15 +87,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*uint16: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -106,15 +106,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []int32: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -125,15 +125,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*int32: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -144,15 +144,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []uint32: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -163,15 +163,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*uint32: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -182,15 +182,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []int64: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -201,15 +201,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*int64: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -220,15 +220,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []uint64: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -239,15 +239,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*uint64: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -258,15 +258,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []int: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -277,15 +277,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*int: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -296,15 +296,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []uint: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -315,15 +315,15 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*uint: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { elements := make([]Int2, len(value)) for i := range value { @@ -334,20 +334,20 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Int2: if value == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} } else if len(value) == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} } else { *dst = Int2Array{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -356,7 +356,7 @@ func (dst *Int2Array) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} return nil } @@ -365,7 +365,7 @@ func (dst *Int2Array) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for Int2Array", src) } if elementsLength == 0 { - *dst = Int2Array{Status: Present} + *dst = Int2Array{Valid: true} return nil } if len(dimensions) == 0 { @@ -378,7 +378,7 @@ func (dst *Int2Array) Set(src interface{}) error { *dst = Int2Array{ Elements: make([]Int2, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -445,210 +445,203 @@ func (dst *Int2Array) setRecursive(value reflect.Value, index, dimension int) (i } func (dst Int2Array) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Int2Array) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -700,7 +693,7 @@ func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension in func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} return nil } @@ -729,14 +722,14 @@ func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int2Array{Status: Null} + *dst = Int2Array{} return nil } @@ -747,7 +740,7 @@ func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = Int2Array{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Int2Array{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -772,16 +765,13 @@ func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -834,11 +824,8 @@ func (src Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -852,7 +839,7 @@ func (src Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/int2_array_test.go b/int2_array_test.go index 17c37360..78dc532a 100644 --- a/int2_array_test.go +++ b/int2_array_test.go @@ -13,41 +13,41 @@ func TestInt2ArrayTranscode(t *testing.T) { &pgtype.Int2Array{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.Int2Array{ Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, + {Int: 1, Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.Int2Array{Status: pgtype.Null}, + &pgtype.Int2Array{}, &pgtype.Int2Array{ Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 6, Status: pgtype.Present}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {}, + {Int: 6, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.Int2Array{ Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -60,103 +60,103 @@ func TestInt2ArraySet(t *testing.T) { { source: []int64{1}, result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []int32{1}, result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []int16{1}, result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []int{1}, result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []uint64{1}, result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []uint32{1}, result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []uint16{1}, result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]int16)(nil)), - result: pgtype.Int2Array{Status: pgtype.Null}, + result: pgtype.Int2Array{}, }, { source: [][]int16{{1}, {2}}, result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, result: pgtype.Int2Array{ Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]int16{{1}, {2}}, result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, result: pgtype.Int2Array{ Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -189,90 +189,90 @@ func TestInt2ArrayAssignTo(t *testing.T) { }{ { src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &int16Slice, expected: []int16{1}, }, { src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &uint16Slice, expected: []uint16{1}, }, { src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &namedInt16Slice, expected: _int16Slice{1}, }, { - src: pgtype.Int2Array{Status: pgtype.Null}, + src: pgtype.Int2Array{}, dst: &int16Slice, expected: (([]int16)(nil)), }, { - src: pgtype.Int2Array{Status: pgtype.Present}, + src: pgtype.Int2Array{Valid: true}, dst: &int16Slice, expected: []int16{}, }, { src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, expected: [][]int16{{1}, {2}}, dst: &int16SliceDim2, }, { src: pgtype.Int2Array{ Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, expected: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, dst: &int16SliceDim4, }, { src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, expected: [2][1]int16{{1}, {2}}, dst: &int16ArrayDim2, }, { src: pgtype.Int2Array{ Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, expected: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, dst: &int16ArrayDim4, }, @@ -295,39 +295,39 @@ func TestInt2ArrayAssignTo(t *testing.T) { }{ { src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Status: pgtype.Null}}, + Elements: []pgtype.Int2{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &int16Slice, }, { src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: -1, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: -1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &uint16Slice, }, { src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &int16ArrayDim2, }, { src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &int16Slice, }, { src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &int16ArrayDim4, }, } diff --git a/int2_test.go b/int2_test.go index 178eb278..6ed8fe90 100644 --- a/int2_test.go +++ b/int2_test.go @@ -11,12 +11,12 @@ import ( func TestInt2Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ - &pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, - &pgtype.Int2{Int: -1, Status: pgtype.Present}, - &pgtype.Int2{Int: 0, Status: pgtype.Present}, - &pgtype.Int2{Int: 1, Status: pgtype.Present}, - &pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}, - &pgtype.Int2{Int: 0, Status: pgtype.Null}, + &pgtype.Int2{Int: math.MinInt16, Valid: true}, + &pgtype.Int2{Int: -1, Valid: true}, + &pgtype.Int2{Int: 0, Valid: true}, + &pgtype.Int2{Int: 1, Valid: true}, + &pgtype.Int2{Int: math.MaxInt16, Valid: true}, + &pgtype.Int2{Int: 0}, }) } @@ -25,22 +25,22 @@ func TestInt2Set(t *testing.T) { source interface{} result pgtype.Int2 }{ - {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: float32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: float64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int8(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: int16(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: int32(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: int64(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: int8(-1), result: pgtype.Int2{Int: -1, Valid: true}}, + {source: int16(-1), result: pgtype.Int2{Int: -1, Valid: true}}, + {source: int32(-1), result: pgtype.Int2{Int: -1, Valid: true}}, + {source: int64(-1), result: pgtype.Int2{Int: -1, Valid: true}}, + {source: uint8(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: uint16(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: uint32(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: uint64(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: float32(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: float64(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: "1", result: pgtype.Int2{Int: 1, Valid: true}}, + {source: _int8(1), result: pgtype.Int2{Int: 1, Valid: true}}, } for i, tt := range successfulTests { @@ -76,19 +76,19 @@ func TestInt2AssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i, expected: int(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int2{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int2{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, } for i, tt := range simpleTests { @@ -107,8 +107,8 @@ func TestInt2AssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, } for i, tt := range pointerAllocTests { @@ -126,13 +126,13 @@ func TestInt2AssignTo(t *testing.T) { src pgtype.Int2 dst interface{} }{ - {src: pgtype.Int2{Int: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &i16}, + {src: pgtype.Int2{Int: 150, Valid: true}, dst: &i8}, + {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui8}, + {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui16}, + {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui32}, + {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui64}, + {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui}, + {src: pgtype.Int2{Int: 0}, dst: &i16}, } for i, tt := range errorTests { diff --git a/int4.go b/int4.go index 22b48e5e..6f1e61f3 100644 --- a/int4.go +++ b/int4.go @@ -12,13 +12,13 @@ import ( ) type Int4 struct { - Int int32 - Status Status + Int int32 + Valid bool } func (dst *Int4) Set(src interface{}) error { if src == nil { - *dst = Int4{Status: Null} + *dst = Int4{} return nil } @@ -31,20 +31,20 @@ func (dst *Int4) Set(src interface{}) error { switch value := src.(type) { case int8: - *dst = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Valid: true} case uint8: - *dst = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Valid: true} case int16: - *dst = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Valid: true} case uint16: - *dst = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Valid: true} case int32: - *dst = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Valid: true} case uint32: if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *dst = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Valid: true} case int64: if value < math.MinInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) @@ -52,12 +52,12 @@ func (dst *Int4) Set(src interface{}) error { if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *dst = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Valid: true} case uint64: if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *dst = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Valid: true} case int: if value < math.MinInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) @@ -65,103 +65,103 @@ func (dst *Int4) Set(src interface{}) error { if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *dst = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Valid: true} case uint: if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *dst = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Valid: true} case string: num, err := strconv.ParseInt(value, 10, 32) if err != nil { return err } - *dst = Int4{Int: int32(num), Status: Present} + *dst = Int4{Int: int32(num), Valid: true} case float32: if value > math.MaxInt32 { return fmt.Errorf("%f is greater than maximum value for Int4", value) } - *dst = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Valid: true} case float64: if value > math.MaxInt32 { return fmt.Errorf("%f is greater than maximum value for Int4", value) } - *dst = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Valid: true} case *int8: if value == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { return dst.Set(*value) } case *uint8: if value == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { return dst.Set(*value) } case *int16: if value == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { return dst.Set(*value) } case *uint16: if value == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { return dst.Set(*value) } case *int32: if value == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { return dst.Set(*value) } case *uint32: if value == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { return dst.Set(*value) } case *int64: if value == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { return dst.Set(*value) } case *uint64: if value == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { return dst.Set(*value) } case *int: if value == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { return dst.Set(*value) } case *uint: if value == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { return dst.Set(*value) } case *string: if value == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { return dst.Set(*value) } case *float32: if value == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { return dst.Set(*value) } case *float64: if value == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { return dst.Set(*value) } @@ -176,23 +176,19 @@ func (dst *Int4) Set(src interface{}) error { } func (dst Int4) Get() interface{} { - switch dst.Status { - case Present: - return dst.Int - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Int } func (src *Int4) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Status, dst) + return int64AssignTo(int64(src.Int), src.Valid, dst) } func (dst *Int4) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int4{Status: Null} + *dst = Int4{} return nil } @@ -201,13 +197,13 @@ func (dst *Int4) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Int4{Int: int32(n), Status: Present} + *dst = Int4{Int: int32(n), Valid: true} return nil } func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int4{Status: Null} + *dst = Int4{} return nil } @@ -216,27 +212,21 @@ func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { } n := int32(binary.BigEndian.Uint32(src)) - *dst = Int4{Int: n, Status: Present} + *dst = Int4{Int: n, Valid: true} return nil } func (src Int4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil } func (src Int4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return pgio.AppendInt32(buf, src.Int), nil @@ -245,7 +235,7 @@ func (src Int4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Int4) Scan(src interface{}) error { if src == nil { - *dst = Int4{Status: Null} + *dst = Int4{} return nil } @@ -257,7 +247,7 @@ func (dst *Int4) Scan(src interface{}) error { if src > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", src) } - *dst = Int4{Int: int32(src), Status: Present} + *dst = Int4{Int: int32(src), Valid: true} return nil case string: return dst.DecodeText(nil, []byte(src)) @@ -272,27 +262,17 @@ func (dst *Int4) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Int4) Value() (driver.Value, error) { - switch src.Status { - case Present: - return int64(src.Int), nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + return int64(src.Int), nil } func (src Int4) MarshalJSON() ([]byte, error) { - switch src.Status { - case Present: - return []byte(strconv.FormatInt(int64(src.Int), 10)), nil - case Null: + if !src.Valid { return []byte("null"), nil - case Undefined: - return nil, errUndefined } - - return nil, errBadStatus + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil } func (dst *Int4) UnmarshalJSON(b []byte) error { @@ -303,9 +283,9 @@ func (dst *Int4) UnmarshalJSON(b []byte) error { } if n == nil { - *dst = Int4{Status: Null} + *dst = Int4{} } else { - *dst = Int4{Int: *n, Status: Present} + *dst = Int4{Int: *n, Valid: true} } return nil diff --git a/int4_array.go b/int4_array.go index de26236f..e725e7a8 100644 --- a/int4_array.go +++ b/int4_array.go @@ -14,13 +14,13 @@ import ( type Int4Array struct { Elements []Int4 Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *Int4Array) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} return nil } @@ -36,9 +36,9 @@ func (dst *Int4Array) Set(src interface{}) error { case []int16: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -49,15 +49,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*int16: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -68,15 +68,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []uint16: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -87,15 +87,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*uint16: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -106,15 +106,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []int32: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -125,15 +125,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*int32: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -144,15 +144,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []uint32: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -163,15 +163,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*uint32: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -182,15 +182,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []int64: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -201,15 +201,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*int64: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -220,15 +220,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []uint64: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -239,15 +239,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*uint64: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -258,15 +258,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []int: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -277,15 +277,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*int: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -296,15 +296,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []uint: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -315,15 +315,15 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*uint: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { elements := make([]Int4, len(value)) for i := range value { @@ -334,20 +334,20 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Int4: if value == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} } else if len(value) == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} } else { *dst = Int4Array{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -356,7 +356,7 @@ func (dst *Int4Array) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} return nil } @@ -365,7 +365,7 @@ func (dst *Int4Array) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for Int4Array", src) } if elementsLength == 0 { - *dst = Int4Array{Status: Present} + *dst = Int4Array{Valid: true} return nil } if len(dimensions) == 0 { @@ -378,7 +378,7 @@ func (dst *Int4Array) Set(src interface{}) error { *dst = Int4Array{ Elements: make([]Int4, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -445,210 +445,203 @@ func (dst *Int4Array) setRecursive(value reflect.Value, index, dimension int) (i } func (dst Int4Array) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Int4Array) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -700,7 +693,7 @@ func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension in func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} return nil } @@ -729,14 +722,14 @@ func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Int4Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = Int4Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int4Array{Status: Null} + *dst = Int4Array{} return nil } @@ -747,7 +740,7 @@ func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = Int4Array{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Int4Array{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -772,16 +765,13 @@ func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = Int4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Int4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -834,11 +824,8 @@ func (src Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -852,7 +839,7 @@ func (src Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/int4_array_test.go b/int4_array_test.go index 110512a9..a9c9acd9 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -14,41 +14,41 @@ func TestInt4ArrayTranscode(t *testing.T) { &pgtype.Int4Array{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.Int4Array{ Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, + {Int: 1, Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.Int4Array{Status: pgtype.Null}, + &pgtype.Int4Array{}, &pgtype.Int4Array{ Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 6, Status: pgtype.Present}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {}, + {Int: 6, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.Int4Array{ Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -62,30 +62,30 @@ func TestInt4ArraySet(t *testing.T) { { source: []int64{1}, result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []int32{1}, result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []int16{1}, result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []int{1}, result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []int{1, math.MaxInt32 + 1, 2}, @@ -94,75 +94,75 @@ func TestInt4ArraySet(t *testing.T) { { source: []uint64{1}, result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []uint32{1}, result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []uint16{1}, result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]int32)(nil)), - result: pgtype.Int4Array{Status: pgtype.Null}, + result: pgtype.Int4Array{}, }, { source: [][]int32{{1}, {2}}, result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, result: pgtype.Int4Array{ Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]int32{{1}, {2}}, result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, result: pgtype.Int4Array{ Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -203,90 +203,90 @@ func TestInt4ArrayAssignTo(t *testing.T) { }{ { src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &int32Slice, expected: []int32{1}, }, { src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &uint32Slice, expected: []uint32{1}, }, { src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &namedInt32Slice, expected: _int32Slice{1}, }, { - src: pgtype.Int4Array{Status: pgtype.Null}, + src: pgtype.Int4Array{}, dst: &int32Slice, expected: (([]int32)(nil)), }, { - src: pgtype.Int4Array{Status: pgtype.Present}, + src: pgtype.Int4Array{Valid: true}, dst: &int32Slice, expected: []int32{}, }, { src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, expected: [][]int32{{1}, {2}}, dst: &int32SliceDim2, }, { src: pgtype.Int4Array{ Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, expected: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, dst: &int32SliceDim4, }, { src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, expected: [2][1]int32{{1}, {2}}, dst: &int32ArrayDim2, }, { src: pgtype.Int4Array{ Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, expected: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, dst: &int32ArrayDim4, }, @@ -309,39 +309,39 @@ func TestInt4ArrayAssignTo(t *testing.T) { }{ { src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Status: pgtype.Null}}, + Elements: []pgtype.Int4{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &int32Slice, }, { src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: -1, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: -1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &uint32Slice, }, { src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &int32ArrayDim2, }, { src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &int32Slice, }, { src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &int32ArrayDim4, }, } diff --git a/int4_test.go b/int4_test.go index ae01114f..3085babd 100644 --- a/int4_test.go +++ b/int4_test.go @@ -11,12 +11,12 @@ import ( func TestInt4Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ - &pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, - &pgtype.Int4{Int: -1, Status: pgtype.Present}, - &pgtype.Int4{Int: 0, Status: pgtype.Present}, - &pgtype.Int4{Int: 1, Status: pgtype.Present}, - &pgtype.Int4{Int: math.MaxInt32, Status: pgtype.Present}, - &pgtype.Int4{Int: 0, Status: pgtype.Null}, + &pgtype.Int4{Int: math.MinInt32, Valid: true}, + &pgtype.Int4{Int: -1, Valid: true}, + &pgtype.Int4{Int: 0, Valid: true}, + &pgtype.Int4{Int: 1, Valid: true}, + &pgtype.Int4{Int: math.MaxInt32, Valid: true}, + &pgtype.Int4{Int: 0}, }) } @@ -25,22 +25,22 @@ func TestInt4Set(t *testing.T) { source interface{} result pgtype.Int4 }{ - {source: int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: float32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: float64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int8(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: int16(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: int32(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: int64(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: int8(-1), result: pgtype.Int4{Int: -1, Valid: true}}, + {source: int16(-1), result: pgtype.Int4{Int: -1, Valid: true}}, + {source: int32(-1), result: pgtype.Int4{Int: -1, Valid: true}}, + {source: int64(-1), result: pgtype.Int4{Int: -1, Valid: true}}, + {source: uint8(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: uint16(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: uint32(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: uint64(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: float32(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: float64(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: "1", result: pgtype.Int4{Int: 1, Valid: true}}, + {source: _int8(1), result: pgtype.Int4{Int: 1, Valid: true}}, } for i, tt := range successfulTests { @@ -76,19 +76,19 @@ func TestInt4AssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i, expected: int(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int4{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int4{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, } for i, tt := range simpleTests { @@ -107,8 +107,8 @@ func TestInt4AssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, } for i, tt := range pointerAllocTests { @@ -126,14 +126,14 @@ func TestInt4AssignTo(t *testing.T) { src pgtype.Int4 dst interface{} }{ - {src: pgtype.Int4{Int: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Int4{Int: 40000, Status: pgtype.Present}, dst: &i16}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &i32}, + {src: pgtype.Int4{Int: 150, Valid: true}, dst: &i8}, + {src: pgtype.Int4{Int: 40000, Valid: true}, dst: &i16}, + {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui8}, + {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui16}, + {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui32}, + {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui64}, + {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui}, + {src: pgtype.Int4{Int: 0}, dst: &i32}, } for i, tt := range errorTests { @@ -149,8 +149,8 @@ func TestInt4MarshalJSON(t *testing.T) { source pgtype.Int4 result string }{ - {source: pgtype.Int4{Int: 0, Status: pgtype.Null}, result: "null"}, - {source: pgtype.Int4{Int: 1, Status: pgtype.Present}, result: "1"}, + {source: pgtype.Int4{Int: 0}, result: "null"}, + {source: pgtype.Int4{Int: 1, Valid: true}, result: "1"}, } for i, tt := range successfulTests { r, err := tt.source.MarshalJSON() @@ -169,8 +169,8 @@ func TestInt4UnmarshalJSON(t *testing.T) { source string result pgtype.Int4 }{ - {source: "null", result: pgtype.Int4{Int: 0, Status: pgtype.Null}}, - {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: "null", result: pgtype.Int4{Int: 0}}, + {source: "1", result: pgtype.Int4{Int: 1, Valid: true}}, } for i, tt := range successfulTests { var r pgtype.Int4 diff --git a/int4range.go b/int4range.go index c7f51fa6..49503c0d 100644 --- a/int4range.go +++ b/int4range.go @@ -12,13 +12,13 @@ type Int4range struct { Upper Int4 LowerType BoundType UpperType BoundType - Status Status + Valid bool } func (dst *Int4range) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = Int4range{Status: Null} + *dst = Int4range{} return nil } @@ -36,15 +36,11 @@ func (dst *Int4range) Set(src interface{}) error { return nil } -func (dst Int4range) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: +func (src Int4range) Get() interface{} { + if !src.Valid { return nil - default: - return dst.Status } + return src } func (src *Int4range) AssignTo(dst interface{}) error { @@ -53,7 +49,7 @@ func (src *Int4range) AssignTo(dst interface{}) error { func (dst *Int4range) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int4range{Status: Null} + *dst = Int4range{} return nil } @@ -62,7 +58,7 @@ func (dst *Int4range) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Int4range{Status: Present} + *dst = Int4range{Valid: true} dst.LowerType = utr.LowerType dst.UpperType = utr.UpperType @@ -88,7 +84,7 @@ func (dst *Int4range) DecodeText(ci *ConnInfo, src []byte) error { func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int4range{Status: Null} + *dst = Int4range{} return nil } @@ -97,7 +93,7 @@ func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { return err } - *dst = Int4range{Status: Present} + *dst = Int4range{Valid: true} dst.LowerType = ubr.LowerType dst.UpperType = ubr.UpperType @@ -122,11 +118,8 @@ func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { } func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } switch src.LowerType { @@ -175,11 +168,8 @@ func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var rangeType byte @@ -245,7 +235,7 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Int4range) Scan(src interface{}) error { if src == nil { - *dst = Int4range{Status: Null} + *dst = Int4range{} return nil } diff --git a/int4range_test.go b/int4range_test.go index 43626189..8b990036 100644 --- a/int4range_test.go +++ b/int4range_test.go @@ -9,12 +9,12 @@ import ( func TestInt4rangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "int4range", []interface{}{ - &pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Status: pgtype.Present}, - &pgtype.Int4range{Upper: pgtype.Int4{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int4range{Status: pgtype.Null}, + &pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Valid: true}, Upper: pgtype.Int4{Int: 10, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Valid: true}, Upper: pgtype.Int4{Int: -5, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Valid: true}, + &pgtype.Int4range{Upper: pgtype.Int4{Int: 1, Valid: true}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Valid: true}, + &pgtype.Int4range{}, }) } @@ -22,7 +22,7 @@ func TestInt4rangeNormalize(t *testing.T) { testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { SQL: "select int4range(1, 10, '(]')", - Value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + Value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Valid: true}, Upper: pgtype.Int4{Int: 11, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, }, }) } diff --git a/int8.go b/int8.go index 0e089979..794f92c6 100644 --- a/int8.go +++ b/int8.go @@ -12,13 +12,13 @@ import ( ) type Int8 struct { - Int int64 - Status Status + Int int64 + Valid bool } func (dst *Int8) Set(src interface{}) error { if src == nil { - *dst = Int8{Status: Null} + *dst = Int8{} return nil } @@ -31,24 +31,24 @@ func (dst *Int8) Set(src interface{}) error { switch value := src.(type) { case int8: - *dst = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Valid: true} case uint8: - *dst = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Valid: true} case int16: - *dst = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Valid: true} case uint16: - *dst = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Valid: true} case int32: - *dst = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Valid: true} case uint32: - *dst = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Valid: true} case int64: - *dst = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Valid: true} case uint64: if value > math.MaxInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", value) } - *dst = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Valid: true} case int: if int64(value) < math.MinInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", value) @@ -56,103 +56,103 @@ func (dst *Int8) Set(src interface{}) error { if int64(value) > math.MaxInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", value) } - *dst = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Valid: true} case uint: if uint64(value) > math.MaxInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", value) } - *dst = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Valid: true} case string: num, err := strconv.ParseInt(value, 10, 64) if err != nil { return err } - *dst = Int8{Int: num, Status: Present} + *dst = Int8{Int: num, Valid: true} case float32: if value > math.MaxInt64 { return fmt.Errorf("%f is greater than maximum value for Int8", value) } - *dst = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Valid: true} case float64: if value > math.MaxInt64 { return fmt.Errorf("%f is greater than maximum value for Int8", value) } - *dst = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Valid: true} case *int8: if value == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { return dst.Set(*value) } case *uint8: if value == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { return dst.Set(*value) } case *int16: if value == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { return dst.Set(*value) } case *uint16: if value == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { return dst.Set(*value) } case *int32: if value == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { return dst.Set(*value) } case *uint32: if value == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { return dst.Set(*value) } case *int64: if value == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { return dst.Set(*value) } case *uint64: if value == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { return dst.Set(*value) } case *int: if value == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { return dst.Set(*value) } case *uint: if value == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { return dst.Set(*value) } case *string: if value == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { return dst.Set(*value) } case *float32: if value == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { return dst.Set(*value) } case *float64: if value == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { return dst.Set(*value) } @@ -167,23 +167,19 @@ func (dst *Int8) Set(src interface{}) error { } func (dst Int8) Get() interface{} { - switch dst.Status { - case Present: - return dst.Int - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Int } func (src *Int8) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Status, dst) + return int64AssignTo(int64(src.Int), src.Valid, dst) } func (dst *Int8) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int8{Status: Null} + *dst = Int8{} return nil } @@ -192,13 +188,13 @@ func (dst *Int8) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Int8{Int: n, Status: Present} + *dst = Int8{Int: n, Valid: true} return nil } func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int8{Status: Null} + *dst = Int8{} return nil } @@ -208,27 +204,21 @@ func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { n := int64(binary.BigEndian.Uint64(src)) - *dst = Int8{Int: n, Status: Present} + *dst = Int8{Int: n, Valid: true} return nil } func (src Int8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, strconv.FormatInt(src.Int, 10)...), nil } func (src Int8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return pgio.AppendInt64(buf, src.Int), nil @@ -237,13 +227,13 @@ func (src Int8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Int8) Scan(src interface{}) error { if src == nil { - *dst = Int8{Status: Null} + *dst = Int8{} return nil } switch src := src.(type) { case int64: - *dst = Int8{Int: src, Status: Present} + *dst = Int8{Int: src, Valid: true} return nil case string: return dst.DecodeText(nil, []byte(src)) @@ -258,27 +248,17 @@ func (dst *Int8) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Int8) Value() (driver.Value, error) { - switch src.Status { - case Present: - return int64(src.Int), nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + return int64(src.Int), nil } func (src Int8) MarshalJSON() ([]byte, error) { - switch src.Status { - case Present: - return []byte(strconv.FormatInt(src.Int, 10)), nil - case Null: + if !src.Valid { return []byte("null"), nil - case Undefined: - return nil, errUndefined } - - return nil, errBadStatus + return []byte(strconv.FormatInt(src.Int, 10)), nil } func (dst *Int8) UnmarshalJSON(b []byte) error { @@ -289,9 +269,9 @@ func (dst *Int8) UnmarshalJSON(b []byte) error { } if n == nil { - *dst = Int8{Status: Null} + *dst = Int8{} } else { - *dst = Int8{Int: *n, Status: Present} + *dst = Int8{Int: *n, Valid: true} } return nil diff --git a/int8_array.go b/int8_array.go index e405b326..d6f38994 100644 --- a/int8_array.go +++ b/int8_array.go @@ -14,13 +14,13 @@ import ( type Int8Array struct { Elements []Int8 Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *Int8Array) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} return nil } @@ -36,9 +36,9 @@ func (dst *Int8Array) Set(src interface{}) error { case []int16: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -49,15 +49,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*int16: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -68,15 +68,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []uint16: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -87,15 +87,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*uint16: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -106,15 +106,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []int32: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -125,15 +125,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*int32: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -144,15 +144,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []uint32: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -163,15 +163,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*uint32: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -182,15 +182,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []int64: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -201,15 +201,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*int64: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -220,15 +220,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []uint64: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -239,15 +239,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*uint64: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -258,15 +258,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []int: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -277,15 +277,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*int: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -296,15 +296,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []uint: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -315,15 +315,15 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*uint: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { elements := make([]Int8, len(value)) for i := range value { @@ -334,20 +334,20 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Int8: if value == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} } else if len(value) == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} } else { *dst = Int8Array{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -356,7 +356,7 @@ func (dst *Int8Array) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} return nil } @@ -365,7 +365,7 @@ func (dst *Int8Array) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for Int8Array", src) } if elementsLength == 0 { - *dst = Int8Array{Status: Present} + *dst = Int8Array{Valid: true} return nil } if len(dimensions) == 0 { @@ -378,7 +378,7 @@ func (dst *Int8Array) Set(src interface{}) error { *dst = Int8Array{ Elements: make([]Int8, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -445,210 +445,203 @@ func (dst *Int8Array) setRecursive(value reflect.Value, index, dimension int) (i } func (dst Int8Array) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Int8Array) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -700,7 +693,7 @@ func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension in func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} return nil } @@ -729,14 +722,14 @@ func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Int8Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = Int8Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int8Array{Status: Null} + *dst = Int8Array{} return nil } @@ -747,7 +740,7 @@ func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = Int8Array{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Int8Array{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -772,16 +765,13 @@ func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = Int8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Int8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -834,11 +824,8 @@ func (src Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -852,7 +839,7 @@ func (src Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/int8_array_test.go b/int8_array_test.go index 1d42a278..29eaf8cb 100644 --- a/int8_array_test.go +++ b/int8_array_test.go @@ -13,41 +13,41 @@ func TestInt8ArrayTranscode(t *testing.T) { &pgtype.Int8Array{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.Int8Array{ Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, + {Int: 1, Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.Int8Array{Status: pgtype.Null}, + &pgtype.Int8Array{}, &pgtype.Int8Array{ Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 6, Status: pgtype.Present}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {}, + {Int: 6, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.Int8Array{ Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -60,110 +60,110 @@ func TestInt8ArraySet(t *testing.T) { { source: []int64{1}, result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []int32{1}, result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []int16{1}, result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []int{1}, result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []uint64{1}, result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []uint32{1}, result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []uint16{1}, result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []uint{1}, result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]int64)(nil)), - result: pgtype.Int8Array{Status: pgtype.Null}, + result: pgtype.Int8Array{}, }, { source: [][]int64{{1}, {2}}, result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, result: pgtype.Int8Array{ Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]int64{{1}, {2}}, result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, result: pgtype.Int8Array{ Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -196,90 +196,90 @@ func TestInt8ArrayAssignTo(t *testing.T) { }{ { src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &int64Slice, expected: []int64{1}, }, { src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &uint64Slice, expected: []uint64{1}, }, { src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &namedInt64Slice, expected: _int64Slice{1}, }, { - src: pgtype.Int8Array{Status: pgtype.Null}, + src: pgtype.Int8Array{}, dst: &int64Slice, expected: (([]int64)(nil)), }, { - src: pgtype.Int8Array{Status: pgtype.Present}, + src: pgtype.Int8Array{Valid: true}, dst: &int64Slice, expected: []int64{}, }, { src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, expected: [][]int64{{1}, {2}}, dst: &int64SliceDim2, }, { src: pgtype.Int8Array{ Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, expected: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, dst: &int64SliceDim4, }, { src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, expected: [2][1]int64{{1}, {2}}, dst: &int64ArrayDim2, }, { src: pgtype.Int8Array{ Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, expected: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, dst: &int64ArrayDim4, }, @@ -302,39 +302,39 @@ func TestInt8ArrayAssignTo(t *testing.T) { }{ { src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Status: pgtype.Null}}, + Elements: []pgtype.Int8{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &int64Slice, }, { src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: -1, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: -1, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &uint64Slice, }, { src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &int64ArrayDim2, }, { src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &int64Slice, }, { src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &int64ArrayDim4, }, } diff --git a/int8_test.go b/int8_test.go index 4e28e374..8aca741d 100644 --- a/int8_test.go +++ b/int8_test.go @@ -11,12 +11,12 @@ import ( func TestInt8Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ - &pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, - &pgtype.Int8{Int: -1, Status: pgtype.Present}, - &pgtype.Int8{Int: 0, Status: pgtype.Present}, - &pgtype.Int8{Int: 1, Status: pgtype.Present}, - &pgtype.Int8{Int: math.MaxInt64, Status: pgtype.Present}, - &pgtype.Int8{Int: 0, Status: pgtype.Null}, + &pgtype.Int8{Int: math.MinInt64, Valid: true}, + &pgtype.Int8{Int: -1, Valid: true}, + &pgtype.Int8{Int: 0, Valid: true}, + &pgtype.Int8{Int: 1, Valid: true}, + &pgtype.Int8{Int: math.MaxInt64, Valid: true}, + &pgtype.Int8{Int: 0}, }) } @@ -25,22 +25,22 @@ func TestInt8Set(t *testing.T) { source interface{} result pgtype.Int8 }{ - {source: int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: float32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: float64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int8(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: int16(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: int32(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: int64(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: int8(-1), result: pgtype.Int8{Int: -1, Valid: true}}, + {source: int16(-1), result: pgtype.Int8{Int: -1, Valid: true}}, + {source: int32(-1), result: pgtype.Int8{Int: -1, Valid: true}}, + {source: int64(-1), result: pgtype.Int8{Int: -1, Valid: true}}, + {source: uint8(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: uint16(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: uint32(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: uint64(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: float32(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: float64(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: "1", result: pgtype.Int8{Int: 1, Valid: true}}, + {source: _int8(1), result: pgtype.Int8{Int: 1, Valid: true}}, } for i, tt := range successfulTests { @@ -76,19 +76,19 @@ func TestInt8AssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i, expected: int(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int8{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int8{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, } for i, tt := range simpleTests { @@ -107,8 +107,8 @@ func TestInt8AssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, } for i, tt := range pointerAllocTests { @@ -126,15 +126,15 @@ func TestInt8AssignTo(t *testing.T) { src pgtype.Int8 dst interface{} }{ - {src: pgtype.Int8{Int: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Int8{Int: 40000, Status: pgtype.Present}, dst: &i16}, - {src: pgtype.Int8{Int: 5000000000, Status: pgtype.Present}, dst: &i32}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &i64}, + {src: pgtype.Int8{Int: 150, Valid: true}, dst: &i8}, + {src: pgtype.Int8{Int: 40000, Valid: true}, dst: &i16}, + {src: pgtype.Int8{Int: 5000000000, Valid: true}, dst: &i32}, + {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui8}, + {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui16}, + {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui32}, + {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui64}, + {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui}, + {src: pgtype.Int8{Int: 0}, dst: &i64}, } for i, tt := range errorTests { @@ -150,8 +150,8 @@ func TestInt8MarshalJSON(t *testing.T) { source pgtype.Int8 result string }{ - {source: pgtype.Int8{Int: 0, Status: pgtype.Null}, result: "null"}, - {source: pgtype.Int8{Int: 1, Status: pgtype.Present}, result: "1"}, + {source: pgtype.Int8{Int: 0}, result: "null"}, + {source: pgtype.Int8{Int: 1, Valid: true}, result: "1"}, } for i, tt := range successfulTests { r, err := tt.source.MarshalJSON() @@ -170,8 +170,8 @@ func TestInt8UnmarshalJSON(t *testing.T) { source string result pgtype.Int8 }{ - {source: "null", result: pgtype.Int8{Int: 0, Status: pgtype.Null}}, - {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: "null", result: pgtype.Int8{Int: 0}}, + {source: "1", result: pgtype.Int8{Int: 1, Valid: true}}, } for i, tt := range successfulTests { var r pgtype.Int8 diff --git a/int8range.go b/int8range.go index 71369373..a7cbcd12 100644 --- a/int8range.go +++ b/int8range.go @@ -12,13 +12,13 @@ type Int8range struct { Upper Int8 LowerType BoundType UpperType BoundType - Status Status + Valid bool } func (dst *Int8range) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = Int8range{Status: Null} + *dst = Int8range{} return nil } @@ -36,15 +36,11 @@ func (dst *Int8range) Set(src interface{}) error { return nil } -func (dst Int8range) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: +func (src Int8range) Get() interface{} { + if !src.Valid { return nil - default: - return dst.Status } + return src } func (src *Int8range) AssignTo(dst interface{}) error { @@ -53,7 +49,7 @@ func (src *Int8range) AssignTo(dst interface{}) error { func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int8range{Status: Null} + *dst = Int8range{} return nil } @@ -62,7 +58,7 @@ func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Int8range{Status: Present} + *dst = Int8range{Valid: true} dst.LowerType = utr.LowerType dst.UpperType = utr.UpperType @@ -88,7 +84,7 @@ func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Int8range{Status: Null} + *dst = Int8range{} return nil } @@ -97,7 +93,7 @@ func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { return err } - *dst = Int8range{Status: Present} + *dst = Int8range{Valid: true} dst.LowerType = ubr.LowerType dst.UpperType = ubr.UpperType @@ -122,11 +118,8 @@ func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { } func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } switch src.LowerType { @@ -175,11 +168,8 @@ func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var rangeType byte @@ -245,7 +235,7 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Int8range) Scan(src interface{}) error { if src == nil { - *dst = Int8range{Status: Null} + *dst = Int8range{} return nil } diff --git a/int8range_test.go b/int8range_test.go index 99d4e8a3..f2e4098d 100644 --- a/int8range_test.go +++ b/int8range_test.go @@ -9,12 +9,12 @@ import ( func TestInt8rangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "Int8range", []interface{}{ - &pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Status: pgtype.Present}, - &pgtype.Int8range{Upper: pgtype.Int8{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int8range{Status: pgtype.Null}, + &pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Valid: true}, Upper: pgtype.Int8{Int: 10, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Valid: true}, Upper: pgtype.Int8{Int: -5, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Valid: true}, + &pgtype.Int8range{Upper: pgtype.Int8{Int: 1, Valid: true}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Valid: true}, + &pgtype.Int8range{}, }) } @@ -22,7 +22,7 @@ func TestInt8rangeNormalize(t *testing.T) { testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { SQL: "select Int8range(1, 10, '(]')", - Value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + Value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Valid: true}, Upper: pgtype.Int8{Int: 11, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, }, }) } diff --git a/interval.go b/interval.go index b01fbb7c..a92cd41f 100644 --- a/interval.go +++ b/interval.go @@ -23,12 +23,12 @@ type Interval struct { Microseconds int64 Days int32 Months int32 - Status Status + Valid bool } func (dst *Interval) Set(src interface{}) error { if src == nil { - *dst = Interval{Status: Null} + *dst = Interval{} return nil } @@ -41,7 +41,7 @@ func (dst *Interval) Set(src interface{}) error { switch value := src.(type) { case time.Duration: - *dst = Interval{Microseconds: int64(value) / 1000, Status: Present} + *dst = Interval{Microseconds: int64(value) / 1000, Valid: true} default: if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) @@ -53,40 +53,33 @@ func (dst *Interval) Set(src interface{}) error { } func (dst Interval) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Interval) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *time.Duration: - us := int64(src.Months)*microsecondsPerMonth + int64(src.Days)*microsecondsPerDay + src.Microseconds - *v = time.Duration(us) * time.Microsecond - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + switch v := dst.(type) { + case *time.Duration: + us := int64(src.Months)*microsecondsPerMonth + int64(src.Days)*microsecondsPerDay + src.Microseconds + *v = time.Duration(us) * time.Microsecond + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Interval{Status: Null} + *dst = Interval{} return nil } @@ -163,13 +156,13 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Interval{Months: months, Days: days, Microseconds: microseconds, Status: Present} + *dst = Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true} return nil } func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Interval{Status: Null} + *dst = Interval{} return nil } @@ -181,16 +174,13 @@ func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { days := int32(binary.BigEndian.Uint32(src[8:])) months := int32(binary.BigEndian.Uint32(src[12:])) - *dst = Interval{Microseconds: microseconds, Days: days, Months: months, Status: Present} + *dst = Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true} return nil } func (src Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if src.Months != 0 { @@ -220,11 +210,8 @@ func (src Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { // EncodeBinary encodes src into w. func (src Interval) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = pgio.AppendInt64(buf, src.Microseconds) @@ -235,7 +222,7 @@ func (src Interval) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Interval) Scan(src interface{}) error { if src == nil { - *dst = Interval{Status: Null} + *dst = Interval{} return nil } diff --git a/interval_test.go b/interval_test.go index 1ee094d7..844f3866 100644 --- a/interval_test.go +++ b/interval_test.go @@ -12,23 +12,23 @@ import ( func TestIntervalTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "interval", []interface{}{ - &pgtype.Interval{Microseconds: 1, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, - &pgtype.Interval{Days: 1, Status: pgtype.Present}, - &pgtype.Interval{Months: 1, Status: pgtype.Present}, - &pgtype.Interval{Months: 12, Status: pgtype.Present}, - &pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: -1, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: -1000000, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: -1000001, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: -123202800000000, Status: pgtype.Present}, - &pgtype.Interval{Days: -1, Status: pgtype.Present}, - &pgtype.Interval{Months: -1, Status: pgtype.Present}, - &pgtype.Interval{Months: -12, Status: pgtype.Present}, - &pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Status: pgtype.Present}, - &pgtype.Interval{Status: pgtype.Null}, + &pgtype.Interval{Microseconds: 1, Valid: true}, + &pgtype.Interval{Microseconds: 1000000, Valid: true}, + &pgtype.Interval{Microseconds: 1000001, Valid: true}, + &pgtype.Interval{Microseconds: 123202800000000, Valid: true}, + &pgtype.Interval{Days: 1, Valid: true}, + &pgtype.Interval{Months: 1, Valid: true}, + &pgtype.Interval{Months: 12, Valid: true}, + &pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Valid: true}, + &pgtype.Interval{Microseconds: -1, Valid: true}, + &pgtype.Interval{Microseconds: -1000000, Valid: true}, + &pgtype.Interval{Microseconds: -1000001, Valid: true}, + &pgtype.Interval{Microseconds: -123202800000000, Valid: true}, + &pgtype.Interval{Days: -1, Valid: true}, + &pgtype.Interval{Months: -1, Valid: true}, + &pgtype.Interval{Months: -12, Valid: true}, + &pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Valid: true}, + &pgtype.Interval{}, }) } @@ -36,37 +36,37 @@ func TestIntervalNormalize(t *testing.T) { testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { SQL: "select '1 second'::interval", - Value: &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + Value: &pgtype.Interval{Microseconds: 1000000, Valid: true}, }, { SQL: "select '1.000001 second'::interval", - Value: &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + Value: &pgtype.Interval{Microseconds: 1000001, Valid: true}, }, { SQL: "select '34223 hours'::interval", - Value: &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + Value: &pgtype.Interval{Microseconds: 123202800000000, Valid: true}, }, { SQL: "select '1 day'::interval", - Value: &pgtype.Interval{Days: 1, Status: pgtype.Present}, + Value: &pgtype.Interval{Days: 1, Valid: true}, }, { SQL: "select '1 month'::interval", - Value: &pgtype.Interval{Months: 1, Status: pgtype.Present}, + Value: &pgtype.Interval{Months: 1, Valid: true}, }, { SQL: "select '1 year'::interval", - Value: &pgtype.Interval{Months: 12, Status: pgtype.Present}, + Value: &pgtype.Interval{Months: 12, Valid: true}, }, { SQL: "select '-13 mon'::interval", - Value: &pgtype.Interval{Months: -13, Status: pgtype.Present}, + Value: &pgtype.Interval{Months: -13, Valid: true}, }, }) } func TestIntervalLossyConversionToDuration(t *testing.T) { - interval := &pgtype.Interval{Months: 1, Days: 1, Status: pgtype.Present} + interval := &pgtype.Interval{Months: 1, Days: 1, Valid: true} var d time.Duration err := interval.AssignTo(&d) require.NoError(t, err) diff --git a/json.go b/json.go index 32bef5e7..580e8505 100644 --- a/json.go +++ b/json.go @@ -8,13 +8,13 @@ import ( ) type JSON struct { - Bytes []byte - Status Status + Bytes []byte + Valid bool } func (dst *JSON) Set(src interface{}) error { if src == nil { - *dst = JSON{Status: Null} + *dst = JSON{} return nil } @@ -27,18 +27,18 @@ func (dst *JSON) Set(src interface{}) error { switch value := src.(type) { case string: - *dst = JSON{Bytes: []byte(value), Status: Present} + *dst = JSON{Bytes: []byte(value), Valid: true} case *string: if value == nil { - *dst = JSON{Status: Null} + *dst = JSON{} } else { - *dst = JSON{Bytes: []byte(*value), Status: Present} + *dst = JSON{Bytes: []byte(*value), Valid: true} } case []byte: if value == nil { - *dst = JSON{Status: Null} + *dst = JSON{} } else { - *dst = JSON{Bytes: value, Status: Present} + *dst = JSON{Bytes: value, Valid: true} } // Encode* methods are defined on *JSON. If JSON is passed directly then the // struct itself would be encoded instead of Bytes. This is clearly a footgun @@ -54,38 +54,35 @@ func (dst *JSON) Set(src interface{}) error { if err != nil { return err } - *dst = JSON{Bytes: buf, Status: Present} + *dst = JSON{Bytes: buf, Valid: true} } return nil } func (dst JSON) Get() interface{} { - switch dst.Status { - case Present: - var i interface{} - err := json.Unmarshal(dst.Bytes, &i) - if err != nil { - return dst - } - return i - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + + var i interface{} + err := json.Unmarshal(dst.Bytes, &i) + if err != nil { + return dst + } + return i } func (src *JSON) AssignTo(dst interface{}) error { switch v := dst.(type) { case *string: - if src.Status == Present { + if src.Valid { *v = string(src.Bytes) } else { - return fmt.Errorf("cannot assign non-present status to %T", dst) + return fmt.Errorf("cannot assign non-valid to %T", dst) } case **string: - if src.Status == Present { + if src.Valid { s := string(src.Bytes) *v = &s return nil @@ -94,7 +91,7 @@ func (src *JSON) AssignTo(dst interface{}) error { return nil } case *[]byte: - if src.Status != Present { + if !src.Valid { *v = nil } else { buf := make([]byte, len(src.Bytes)) @@ -103,7 +100,7 @@ func (src *JSON) AssignTo(dst interface{}) error { } default: data := src.Bytes - if data == nil || src.Status != Present { + if data == nil || !src.Valid { data = []byte("null") } @@ -119,11 +116,11 @@ func (JSON) PreferredResultFormat() int16 { func (dst *JSON) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = JSON{Status: Null} + *dst = JSON{} return nil } - *dst = JSON{Bytes: src, Status: Present} + *dst = JSON{Bytes: src, Valid: true} return nil } @@ -136,11 +133,8 @@ func (JSON) PreferredParamFormat() int16 { } func (src JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, src.Bytes...), nil @@ -153,7 +147,7 @@ func (src JSON) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *JSON) Scan(src interface{}) error { if src == nil { - *dst = JSON{Status: Null} + *dst = JSON{} return nil } @@ -171,34 +165,24 @@ func (dst *JSON) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src JSON) Value() (driver.Value, error) { - switch src.Status { - case Present: - return src.Bytes, nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + return src.Bytes, nil } func (src JSON) MarshalJSON() ([]byte, error) { - switch src.Status { - case Present: - return src.Bytes, nil - case Null: + if !src.Valid { return []byte("null"), nil - case Undefined: - return nil, errUndefined } - - return nil, errBadStatus + return src.Bytes, nil } func (dst *JSON) UnmarshalJSON(b []byte) error { if b == nil || string(b) == "null" { - *dst = JSON{Status: Null} + *dst = JSON{} } else { - *dst = JSON{Bytes: b, Status: Present} + *dst = JSON{Bytes: b, Valid: true} } return nil diff --git a/json_test.go b/json_test.go index bbd3959e..c56f403f 100644 --- a/json_test.go +++ b/json_test.go @@ -11,11 +11,11 @@ import ( func TestJSONTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "json", []interface{}{ - &pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, - &pgtype.JSON{Bytes: []byte("null"), Status: pgtype.Present}, - &pgtype.JSON{Bytes: []byte("42"), Status: pgtype.Present}, - &pgtype.JSON{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - &pgtype.JSON{Status: pgtype.Null}, + &pgtype.JSON{Bytes: []byte("{}"), Valid: true}, + &pgtype.JSON{Bytes: []byte("null"), Valid: true}, + &pgtype.JSON{Bytes: []byte("42"), Valid: true}, + &pgtype.JSON{Bytes: []byte(`"hello"`), Valid: true}, + &pgtype.JSON{}, }) } @@ -24,12 +24,12 @@ func TestJSONSet(t *testing.T) { source interface{} result pgtype.JSON }{ - {source: "{}", result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: []byte("{}"), result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: ([]byte)(nil), result: pgtype.JSON{Status: pgtype.Null}}, - {source: (*string)(nil), result: pgtype.JSON{Status: pgtype.Null}}, - {source: []int{1, 2, 3}, result: pgtype.JSON{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + {source: "{}", result: pgtype.JSON{Bytes: []byte("{}"), Valid: true}}, + {source: []byte("{}"), result: pgtype.JSON{Bytes: []byte("{}"), Valid: true}}, + {source: ([]byte)(nil), result: pgtype.JSON{}}, + {source: (*string)(nil), result: pgtype.JSON{}}, + {source: []int{1, 2, 3}, result: pgtype.JSON{Bytes: []byte("[1,2,3]"), Valid: true}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Valid: true}}, } for i, tt := range successfulTests { @@ -55,7 +55,7 @@ func TestJSONAssignTo(t *testing.T) { dst *string expected string }{ - {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + {src: pgtype.JSON{Bytes: []byte("{}"), Valid: true}, dst: &s, expected: "{}"}, } for i, tt := range rawStringTests { @@ -74,8 +74,8 @@ func TestJSONAssignTo(t *testing.T) { dst *[]byte expected []byte }{ - {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, - {src: pgtype.JSON{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + {src: pgtype.JSON{Bytes: []byte("{}"), Valid: true}, dst: &b, expected: []byte("{}")}, + {src: pgtype.JSON{}, dst: &b, expected: (([]byte)(nil))}, } for i, tt := range rawBytesTests { @@ -101,8 +101,8 @@ func TestJSONAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.JSON{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + {src: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Valid: true}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.JSON{Bytes: []byte(`{"name":"John","age":42}`), Valid: true}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, } for i, tt := range unmarshalTests { err := tt.src.AssignTo(tt.dst) @@ -120,7 +120,7 @@ func TestJSONAssignTo(t *testing.T) { dst **string expected *string }{ - {src: pgtype.JSON{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + {src: pgtype.JSON{}, dst: &ps, expected: ((*string)(nil))}, } for i, tt := range pointerAllocTests { @@ -140,8 +140,8 @@ func TestJSONMarshalJSON(t *testing.T) { source pgtype.JSON result string }{ - {source: pgtype.JSON{Status: pgtype.Null}, result: "null"}, - {source: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Status: pgtype.Present}, result: "{\"a\": 1}"}, + {source: pgtype.JSON{}, result: "null"}, + {source: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Valid: true}, result: "{\"a\": 1}"}, } for i, tt := range successfulTests { r, err := tt.source.MarshalJSON() @@ -160,8 +160,8 @@ func TestJSONUnmarshalJSON(t *testing.T) { source string result pgtype.JSON }{ - {source: "null", result: pgtype.JSON{Status: pgtype.Null}}, - {source: "{\"a\": 1}", result: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Status: pgtype.Present}}, + {source: "null", result: pgtype.JSON{}}, + {source: "{\"a\": 1}", result: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Valid: true}}, } for i, tt := range successfulTests { var r pgtype.JSON @@ -170,7 +170,7 @@ func TestJSONUnmarshalJSON(t *testing.T) { t.Errorf("%d: %v", i, err) } - if string(r.Bytes) != string(tt.result.Bytes) || r.Status != tt.result.Status { + if string(r.Bytes) != string(tt.result.Bytes) || r.Valid != tt.result.Valid { t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } } diff --git a/jsonb.go b/jsonb.go index c9dafc93..38d56499 100644 --- a/jsonb.go +++ b/jsonb.go @@ -29,7 +29,7 @@ func (dst *JSONB) DecodeText(ci *ConnInfo, src []byte) error { func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = JSONB{Status: Null} + *dst = JSONB{} return nil } @@ -41,7 +41,7 @@ func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { return fmt.Errorf("unknown jsonb version number %d", src[0]) } - *dst = JSONB{Bytes: src[1:], Status: Present} + *dst = JSONB{Bytes: src[1:], Valid: true} return nil } @@ -55,11 +55,8 @@ func (src JSONB) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src JSONB) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = append(buf, 1) diff --git a/jsonb_array.go b/jsonb_array.go index c4b7cd3d..81ed9f29 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -14,13 +14,13 @@ import ( type JSONBArray struct { Elements []JSONB Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *JSONBArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = JSONBArray{Status: Null} + *dst = JSONBArray{} return nil } @@ -36,9 +36,9 @@ func (dst *JSONBArray) Set(src interface{}) error { case []string: if value == nil { - *dst = JSONBArray{Status: Null} + *dst = JSONBArray{} } else if len(value) == 0 { - *dst = JSONBArray{Status: Present} + *dst = JSONBArray{Valid: true} } else { elements := make([]JSONB, len(value)) for i := range value { @@ -49,15 +49,15 @@ func (dst *JSONBArray) Set(src interface{}) error { *dst = JSONBArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case [][]byte: if value == nil { - *dst = JSONBArray{Status: Null} + *dst = JSONBArray{} } else if len(value) == 0 { - *dst = JSONBArray{Status: Present} + *dst = JSONBArray{Valid: true} } else { elements := make([]JSONB, len(value)) for i := range value { @@ -68,20 +68,20 @@ func (dst *JSONBArray) Set(src interface{}) error { *dst = JSONBArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []JSONB: if value == nil { - *dst = JSONBArray{Status: Null} + *dst = JSONBArray{} } else if len(value) == 0 { - *dst = JSONBArray{Status: Present} + *dst = JSONBArray{Valid: true} } else { *dst = JSONBArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -90,7 +90,7 @@ func (dst *JSONBArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = JSONBArray{Status: Null} + *dst = JSONBArray{} return nil } @@ -99,7 +99,7 @@ func (dst *JSONBArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for JSONBArray", src) } if elementsLength == 0 { - *dst = JSONBArray{Status: Present} + *dst = JSONBArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -112,7 +112,7 @@ func (dst *JSONBArray) Set(src interface{}) error { *dst = JSONBArray{ Elements: make([]JSONB, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -179,84 +179,77 @@ func (dst *JSONBArray) setRecursive(value reflect.Value, index, dimension int) ( } func (dst JSONBArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *JSONBArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -308,7 +301,7 @@ func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension i func (dst *JSONBArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = JSONBArray{Status: Null} + *dst = JSONBArray{} return nil } @@ -337,14 +330,14 @@ func (dst *JSONBArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = JSONBArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = JSONBArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *JSONBArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = JSONBArray{Status: Null} + *dst = JSONBArray{} return nil } @@ -355,7 +348,7 @@ func (dst *JSONBArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = JSONBArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = JSONBArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -380,16 +373,13 @@ func (dst *JSONBArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = JSONBArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = JSONBArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src JSONBArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -442,11 +432,8 @@ func (src JSONBArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src JSONBArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -460,7 +447,7 @@ func (src JSONBArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/jsonb_array_test.go b/jsonb_array_test.go index 65f1777a..4f293e9e 100644 --- a/jsonb_array_test.go +++ b/jsonb_array_test.go @@ -12,25 +12,25 @@ func TestJSONBArrayTranscode(t *testing.T) { &pgtype.JSONBArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.JSONBArray{ Elements: []pgtype.JSONB{ - {Bytes: []byte(`"foo"`), Status: pgtype.Present}, - {Status: pgtype.Null}, + {Bytes: []byte(`"foo"`), Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.JSONBArray{Status: pgtype.Null}, + &pgtype.JSONBArray{}, &pgtype.JSONBArray{ Elements: []pgtype.JSONB{ - {Bytes: []byte(`"foo"`), Status: pgtype.Present}, - {Bytes: []byte("null"), Status: pgtype.Present}, - {Bytes: []byte("42"), Status: pgtype.Present}, + {Bytes: []byte(`"foo"`), Valid: true}, + {Bytes: []byte("null"), Valid: true}, + {Bytes: []byte("42"), Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, }) } diff --git a/jsonb_test.go b/jsonb_test.go index 9ce80d42..41df18fa 100644 --- a/jsonb_test.go +++ b/jsonb_test.go @@ -17,11 +17,11 @@ func TestJSONBTranscode(t *testing.T) { } testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ - &pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, - &pgtype.JSONB{Bytes: []byte("null"), Status: pgtype.Present}, - &pgtype.JSONB{Bytes: []byte("42"), Status: pgtype.Present}, - &pgtype.JSONB{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - &pgtype.JSONB{Status: pgtype.Null}, + &pgtype.JSONB{Bytes: []byte("{}"), Valid: true}, + &pgtype.JSONB{Bytes: []byte("null"), Valid: true}, + &pgtype.JSONB{Bytes: []byte("42"), Valid: true}, + &pgtype.JSONB{Bytes: []byte(`"hello"`), Valid: true}, + &pgtype.JSONB{}, }) } @@ -30,12 +30,12 @@ func TestJSONBSet(t *testing.T) { source interface{} result pgtype.JSONB }{ - {source: "{}", result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: []byte("{}"), result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: ([]byte)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, - {source: (*string)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, - {source: []int{1, 2, 3}, result: pgtype.JSONB{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + {source: "{}", result: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}}, + {source: []byte("{}"), result: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}}, + {source: ([]byte)(nil), result: pgtype.JSONB{}}, + {source: (*string)(nil), result: pgtype.JSONB{}}, + {source: []int{1, 2, 3}, result: pgtype.JSONB{Bytes: []byte("[1,2,3]"), Valid: true}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Valid: true}}, } for i, tt := range successfulTests { @@ -61,7 +61,7 @@ func TestJSONBAssignTo(t *testing.T) { dst *string expected string }{ - {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + {src: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}, dst: &s, expected: "{}"}, } for i, tt := range rawStringTests { @@ -80,8 +80,8 @@ func TestJSONBAssignTo(t *testing.T) { dst *[]byte expected []byte }{ - {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, - {src: pgtype.JSONB{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + {src: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}, dst: &b, expected: []byte("{}")}, + {src: pgtype.JSONB{}, dst: &b, expected: (([]byte)(nil))}, } for i, tt := range rawBytesTests { @@ -107,8 +107,8 @@ func TestJSONBAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.JSONB{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + {src: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Valid: true}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.JSONB{Bytes: []byte(`{"name":"John","age":42}`), Valid: true}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, } for i, tt := range unmarshalTests { err := tt.src.AssignTo(tt.dst) @@ -126,7 +126,7 @@ func TestJSONBAssignTo(t *testing.T) { dst **string expected *string }{ - {src: pgtype.JSONB{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + {src: pgtype.JSONB{}, dst: &ps, expected: ((*string)(nil))}, } for i, tt := range pointerAllocTests { diff --git a/line.go b/line.go index 3564b174..c3192b2a 100644 --- a/line.go +++ b/line.go @@ -13,7 +13,7 @@ import ( type Line struct { A, B, C float64 - Status Status + Valid bool } func (dst *Line) Set(src interface{}) error { @@ -21,14 +21,10 @@ func (dst *Line) Set(src interface{}) error { } func (dst Line) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Line) AssignTo(dst interface{}) error { @@ -37,7 +33,7 @@ func (src *Line) AssignTo(dst interface{}) error { func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Line{Status: Null} + *dst = Line{} return nil } @@ -65,13 +61,13 @@ func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Line{A: a, B: b, C: c, Status: Present} + *dst = Line{A: a, B: b, C: c, Valid: true} return nil } func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Line{Status: Null} + *dst = Line{} return nil } @@ -84,20 +80,17 @@ func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { c := binary.BigEndian.Uint64(src[16:]) *dst = Line{ - A: math.Float64frombits(a), - B: math.Float64frombits(b), - C: math.Float64frombits(c), - Status: Present, + A: math.Float64frombits(a), + B: math.Float64frombits(b), + C: math.Float64frombits(c), + Valid: true, } return nil } func (src Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = append(buf, fmt.Sprintf(`{%s,%s,%s}`, @@ -110,11 +103,8 @@ func (src Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Line) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = pgio.AppendUint64(buf, math.Float64bits(src.A)) @@ -126,7 +116,7 @@ func (src Line) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Line) Scan(src interface{}) error { if src == nil { - *dst = Line{Status: Null} + *dst = Line{} return nil } diff --git a/line_test.go b/line_test.go index f697ac43..c47f6512 100644 --- a/line_test.go +++ b/line_test.go @@ -27,12 +27,12 @@ func TestLineTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "line", []interface{}{ &pgtype.Line{ A: 1.23, B: 4.56, C: 7.89012345, - Status: pgtype.Present, + Valid: true, }, &pgtype.Line{ A: -1.23, B: -4.56, C: -7.89, - Status: pgtype.Present, + Valid: true, }, - &pgtype.Line{Status: pgtype.Null}, + &pgtype.Line{}, }) } diff --git a/lseg.go b/lseg.go index 5c4babb6..649863ca 100644 --- a/lseg.go +++ b/lseg.go @@ -12,8 +12,8 @@ import ( ) type Lseg struct { - P [2]Vec2 - Status Status + P [2]Vec2 + Valid bool } func (dst *Lseg) Set(src interface{}) error { @@ -21,14 +21,10 @@ func (dst *Lseg) Set(src interface{}) error { } func (dst Lseg) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Lseg) AssignTo(dst interface{}) error { @@ -37,7 +33,7 @@ func (src *Lseg) AssignTo(dst interface{}) error { func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Lseg{Status: Null} + *dst = Lseg{} return nil } @@ -78,13 +74,13 @@ func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} + *dst = Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true} return nil } func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Lseg{Status: Null} + *dst = Lseg{} return nil } @@ -102,17 +98,14 @@ func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { {math.Float64frombits(x1), math.Float64frombits(y1)}, {math.Float64frombits(x2), math.Float64frombits(y2)}, }, - Status: Present, + Valid: true, } return nil } func (src Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, @@ -126,11 +119,8 @@ func (src Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Lseg) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) @@ -143,7 +133,7 @@ func (src Lseg) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Lseg) Scan(src interface{}) error { if src == nil { - *dst = Lseg{Status: Null} + *dst = Lseg{} return nil } diff --git a/lseg_test.go b/lseg_test.go index b75297cc..af2faf3f 100644 --- a/lseg_test.go +++ b/lseg_test.go @@ -10,13 +10,13 @@ import ( func TestLsegTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "lseg", []interface{}{ &pgtype.Lseg{ - P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, - Status: pgtype.Present, + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, + Valid: true, }, &pgtype.Lseg{ - P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Status: pgtype.Present, + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Valid: true, }, - &pgtype.Lseg{Status: pgtype.Null}, + &pgtype.Lseg{}, }) } diff --git a/macaddr.go b/macaddr.go index 1d3cfe7b..8d6ab720 100644 --- a/macaddr.go +++ b/macaddr.go @@ -7,13 +7,13 @@ import ( ) type Macaddr struct { - Addr net.HardwareAddr - Status Status + Addr net.HardwareAddr + Valid bool } func (dst *Macaddr) Set(src interface{}) error { if src == nil { - *dst = Macaddr{Status: Null} + *dst = Macaddr{} return nil } @@ -28,22 +28,22 @@ func (dst *Macaddr) Set(src interface{}) error { case net.HardwareAddr: addr := make(net.HardwareAddr, len(value)) copy(addr, value) - *dst = Macaddr{Addr: addr, Status: Present} + *dst = Macaddr{Addr: addr, Valid: true} case string: addr, err := net.ParseMAC(value) if err != nil { return err } - *dst = Macaddr{Addr: addr, Status: Present} + *dst = Macaddr{Addr: addr, Valid: true} case *net.HardwareAddr: if value == nil { - *dst = Macaddr{Status: Null} + *dst = Macaddr{} } else { return dst.Set(*value) } case *string: if value == nil { - *dst = Macaddr{Status: Null} + *dst = Macaddr{} } else { return dst.Set(*value) } @@ -58,43 +58,36 @@ func (dst *Macaddr) Set(src interface{}) error { } func (dst Macaddr) Get() interface{} { - switch dst.Status { - case Present: - return dst.Addr - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Addr } func (src *Macaddr) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *net.HardwareAddr: - *v = make(net.HardwareAddr, len(src.Addr)) - copy(*v, src.Addr) - return nil - case *string: - *v = src.Addr.String() - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + switch v := dst.(type) { + case *net.HardwareAddr: + *v = make(net.HardwareAddr, len(src.Addr)) + copy(*v, src.Addr) + return nil + case *string: + *v = src.Addr.String() + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Macaddr{Status: Null} + *dst = Macaddr{} return nil } @@ -103,13 +96,13 @@ func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Macaddr{Addr: addr, Status: Present} + *dst = Macaddr{Addr: addr, Valid: true} return nil } func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Macaddr{Status: Null} + *dst = Macaddr{} return nil } @@ -120,17 +113,14 @@ func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { addr := make(net.HardwareAddr, 6) copy(addr, src) - *dst = Macaddr{Addr: addr, Status: Present} + *dst = Macaddr{Addr: addr, Valid: true} return nil } func (src Macaddr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, src.Addr.String()...), nil @@ -138,11 +128,8 @@ func (src Macaddr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { // EncodeBinary encodes src into w. func (src Macaddr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, src.Addr...), nil @@ -151,7 +138,7 @@ func (src Macaddr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Macaddr) Scan(src interface{}) error { if src == nil { - *dst = Macaddr{Status: Null} + *dst = Macaddr{} return nil } diff --git a/macaddr_array.go b/macaddr_array.go index bdb1f203..78a93a2d 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -15,13 +15,13 @@ import ( type MacaddrArray struct { Elements []Macaddr Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *MacaddrArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = MacaddrArray{Status: Null} + *dst = MacaddrArray{} return nil } @@ -37,9 +37,9 @@ func (dst *MacaddrArray) Set(src interface{}) error { case []net.HardwareAddr: if value == nil { - *dst = MacaddrArray{Status: Null} + *dst = MacaddrArray{} } else if len(value) == 0 { - *dst = MacaddrArray{Status: Present} + *dst = MacaddrArray{Valid: true} } else { elements := make([]Macaddr, len(value)) for i := range value { @@ -50,15 +50,15 @@ func (dst *MacaddrArray) Set(src interface{}) error { *dst = MacaddrArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*net.HardwareAddr: if value == nil { - *dst = MacaddrArray{Status: Null} + *dst = MacaddrArray{} } else if len(value) == 0 { - *dst = MacaddrArray{Status: Present} + *dst = MacaddrArray{Valid: true} } else { elements := make([]Macaddr, len(value)) for i := range value { @@ -69,20 +69,20 @@ func (dst *MacaddrArray) Set(src interface{}) error { *dst = MacaddrArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Macaddr: if value == nil { - *dst = MacaddrArray{Status: Null} + *dst = MacaddrArray{} } else if len(value) == 0 { - *dst = MacaddrArray{Status: Present} + *dst = MacaddrArray{Valid: true} } else { *dst = MacaddrArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -91,7 +91,7 @@ func (dst *MacaddrArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = MacaddrArray{Status: Null} + *dst = MacaddrArray{} return nil } @@ -100,7 +100,7 @@ func (dst *MacaddrArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for MacaddrArray", src) } if elementsLength == 0 { - *dst = MacaddrArray{Status: Present} + *dst = MacaddrArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -113,7 +113,7 @@ func (dst *MacaddrArray) Set(src interface{}) error { *dst = MacaddrArray{ Elements: make([]Macaddr, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -180,84 +180,77 @@ func (dst *MacaddrArray) setRecursive(value reflect.Value, index, dimension int) } func (dst MacaddrArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *MacaddrArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]net.HardwareAddr: - *v = make([]net.HardwareAddr, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.HardwareAddr: - *v = make([]*net.HardwareAddr, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]net.HardwareAddr: + *v = make([]net.HardwareAddr, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*net.HardwareAddr: + *v = make([]*net.HardwareAddr, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -309,7 +302,7 @@ func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension func (dst *MacaddrArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = MacaddrArray{Status: Null} + *dst = MacaddrArray{} return nil } @@ -338,14 +331,14 @@ func (dst *MacaddrArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = MacaddrArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = MacaddrArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *MacaddrArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = MacaddrArray{Status: Null} + *dst = MacaddrArray{} return nil } @@ -356,7 +349,7 @@ func (dst *MacaddrArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = MacaddrArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = MacaddrArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -381,16 +374,13 @@ func (dst *MacaddrArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = MacaddrArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = MacaddrArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -443,11 +433,8 @@ func (src MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src MacaddrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -461,7 +448,7 @@ func (src MacaddrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/macaddr_array_test.go b/macaddr_array_test.go index c1a8b72d..a4a55cb0 100644 --- a/macaddr_array_test.go +++ b/macaddr_array_test.go @@ -14,17 +14,17 @@ func TestMacaddrArrayTranscode(t *testing.T) { &pgtype.MacaddrArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.MacaddrArray{ Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Status: pgtype.Null}, + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.MacaddrArray{Status: pgtype.Null}, + &pgtype.MacaddrArray{}, }) } @@ -36,13 +36,13 @@ func TestMacaddrArraySet(t *testing.T) { { source: []net.HardwareAddr{mustParseMacaddr(t, "01:23:45:67:89:ab")}, result: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}}, + Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]net.HardwareAddr)(nil)), - result: pgtype.MacaddrArray{Status: pgtype.Null}, + result: pgtype.MacaddrArray{}, }, { source: [][]net.HardwareAddr{ @@ -50,10 +50,10 @@ func TestMacaddrArraySet(t *testing.T) { {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, result: pgtype.MacaddrArray{ Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]net.HardwareAddr{ @@ -67,18 +67,18 @@ func TestMacaddrArraySet(t *testing.T) { mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, result: pgtype.MacaddrArray{ Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Valid: true}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Valid: true}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Valid: true}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]net.HardwareAddr{ @@ -86,10 +86,10 @@ func TestMacaddrArraySet(t *testing.T) { {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, result: pgtype.MacaddrArray{ Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]net.HardwareAddr{ @@ -103,18 +103,18 @@ func TestMacaddrArraySet(t *testing.T) { mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, result: pgtype.MacaddrArray{ Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Valid: true}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Valid: true}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Valid: true}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -145,39 +145,39 @@ func TestMacaddrArrayAssignTo(t *testing.T) { }{ { src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}}, + Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &macaddrSlice, expected: []net.HardwareAddr{mustParseMacaddr(t, "01:23:45:67:89:ab")}, }, { src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{{Status: pgtype.Null}}, + Elements: []pgtype.Macaddr{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &macaddrSlice, expected: []net.HardwareAddr{nil}, }, { - src: pgtype.MacaddrArray{Status: pgtype.Null}, + src: pgtype.MacaddrArray{}, dst: &macaddrSlice, expected: (([]net.HardwareAddr)(nil)), }, { - src: pgtype.MacaddrArray{Status: pgtype.Present}, + src: pgtype.MacaddrArray{Valid: true}, dst: &macaddrSlice, expected: []net.HardwareAddr{}, }, { src: pgtype.MacaddrArray{ Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &macaddrSliceDim2, expected: [][]net.HardwareAddr{ {mustParseMacaddr(t, "01:23:45:67:89:ab")}, @@ -186,18 +186,18 @@ func TestMacaddrArrayAssignTo(t *testing.T) { { src: pgtype.MacaddrArray{ Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Valid: true}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Valid: true}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Valid: true}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &macaddrSliceDim4, expected: [][][][]net.HardwareAddr{ {{{ @@ -212,10 +212,10 @@ func TestMacaddrArrayAssignTo(t *testing.T) { { src: pgtype.MacaddrArray{ Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &macaddrArrayDim2, expected: [2][1]net.HardwareAddr{ {mustParseMacaddr(t, "01:23:45:67:89:ab")}, @@ -224,18 +224,18 @@ func TestMacaddrArrayAssignTo(t *testing.T) { { src: pgtype.MacaddrArray{ Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, - {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Valid: true}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Valid: true}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Valid: true}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &macaddrArrayDim4, expected: [2][1][1][3]net.HardwareAddr{ {{{ diff --git a/macaddr_test.go b/macaddr_test.go index 364a8914..dc475c41 100644 --- a/macaddr_test.go +++ b/macaddr_test.go @@ -12,8 +12,8 @@ import ( func TestMacaddrTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "macaddr", []interface{}{ - &pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - &pgtype.Macaddr{Status: pgtype.Null}, + &pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + &pgtype.Macaddr{}, }) } @@ -24,11 +24,11 @@ func TestMacaddrSet(t *testing.T) { }{ { source: mustParseMacaddr(t, "01:23:45:67:89:ab"), - result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, }, { source: "01:23:45:67:89:ab", - result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, }, } @@ -47,7 +47,7 @@ func TestMacaddrSet(t *testing.T) { func TestMacaddrAssignTo(t *testing.T) { { - src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} + src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true} var dst net.HardwareAddr expected := mustParseMacaddr(t, "01:23:45:67:89:ab") @@ -62,7 +62,7 @@ func TestMacaddrAssignTo(t *testing.T) { } { - src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} + src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true} var dst string expected := "01:23:45:67:89:ab" diff --git a/name_test.go b/name_test.go index 75329b01..5f429d83 100644 --- a/name_test.go +++ b/name_test.go @@ -10,9 +10,9 @@ import ( func TestNameTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "name", []interface{}{ - &pgtype.Name{String: "", Status: pgtype.Present}, - &pgtype.Name{String: "foo", Status: pgtype.Present}, - &pgtype.Name{Status: pgtype.Null}, + &pgtype.Name{String: "", Valid: true}, + &pgtype.Name{String: "foo", Valid: true}, + &pgtype.Name{}, }) } @@ -21,9 +21,9 @@ func TestNameSet(t *testing.T) { source interface{} result pgtype.Name }{ - {source: "foo", result: pgtype.Name{String: "foo", Status: pgtype.Present}}, - {source: _string("bar"), result: pgtype.Name{String: "bar", Status: pgtype.Present}}, - {source: (*string)(nil), result: pgtype.Name{Status: pgtype.Null}}, + {source: "foo", result: pgtype.Name{String: "foo", Valid: true}}, + {source: _string("bar"), result: pgtype.Name{String: "bar", Valid: true}}, + {source: (*string)(nil), result: pgtype.Name{}}, } for i, tt := range successfulTests { @@ -48,8 +48,8 @@ func TestNameAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, - {src: pgtype.Name{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + {src: pgtype.Name{String: "foo", Valid: true}, dst: &s, expected: "foo"}, + {src: pgtype.Name{}, dst: &ps, expected: ((*string)(nil))}, } for i, tt := range simpleTests { @@ -68,7 +68,7 @@ func TestNameAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, + {src: pgtype.Name{String: "foo", Valid: true}, dst: &ps, expected: "foo"}, } for i, tt := range pointerAllocTests { @@ -86,7 +86,7 @@ func TestNameAssignTo(t *testing.T) { src pgtype.Name dst interface{} }{ - {src: pgtype.Name{Status: pgtype.Null}, dst: &s}, + {src: pgtype.Name{}, dst: &s}, } for i, tt := range errorTests { diff --git a/numeric.go b/numeric.go index a939625b..3d209ff2 100644 --- a/numeric.go +++ b/numeric.go @@ -57,14 +57,14 @@ var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) type Numeric struct { Int *big.Int Exp int32 - Status Status + Valid bool NaN bool InfinityModifier InfinityModifier } func (dst *Numeric) Set(src interface{}) error { if src == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} return nil } @@ -78,142 +78,142 @@ func (dst *Numeric) Set(src interface{}) error { switch value := src.(type) { case float32: if math.IsNaN(float64(value)) { - *dst = Numeric{Status: Present, NaN: true} + *dst = Numeric{Valid: true, NaN: true} return nil } else if math.IsInf(float64(value), 1) { - *dst = Numeric{Status: Present, InfinityModifier: Infinity} + *dst = Numeric{Valid: true, InfinityModifier: Infinity} return nil } else if math.IsInf(float64(value), -1) { - *dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity} + *dst = Numeric{Valid: true, InfinityModifier: NegativeInfinity} return nil } num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) if err != nil { return err } - *dst = Numeric{Int: num, Exp: exp, Status: Present} + *dst = Numeric{Int: num, Exp: exp, Valid: true} case float64: if math.IsNaN(value) { - *dst = Numeric{Status: Present, NaN: true} + *dst = Numeric{Valid: true, NaN: true} return nil } else if math.IsInf(value, 1) { - *dst = Numeric{Status: Present, InfinityModifier: Infinity} + *dst = Numeric{Valid: true, InfinityModifier: Infinity} return nil } else if math.IsInf(value, -1) { - *dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity} + *dst = Numeric{Valid: true, InfinityModifier: NegativeInfinity} return nil } num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) if err != nil { return err } - *dst = Numeric{Int: num, Exp: exp, Status: Present} + *dst = Numeric{Int: num, Exp: exp, Valid: true} case int8: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} case uint8: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} case int16: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} case uint16: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} case int32: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} case uint32: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} case int64: - *dst = Numeric{Int: big.NewInt(value), Status: Present} + *dst = Numeric{Int: big.NewInt(value), Valid: true} case uint64: - *dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present} + *dst = Numeric{Int: (&big.Int{}).SetUint64(value), Valid: true} case int: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} case uint: - *dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present} + *dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Valid: true} case string: num, exp, err := parseNumericString(value) if err != nil { return err } - *dst = Numeric{Int: num, Exp: exp, Status: Present} + *dst = Numeric{Int: num, Exp: exp, Valid: true} case *float64: if value == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} } else { return dst.Set(*value) } case *float32: if value == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} } else { return dst.Set(*value) } case *int8: if value == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} } else { return dst.Set(*value) } case *uint8: if value == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} } else { return dst.Set(*value) } case *int16: if value == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} } else { return dst.Set(*value) } case *uint16: if value == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} } else { return dst.Set(*value) } case *int32: if value == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} } else { return dst.Set(*value) } case *uint32: if value == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} } else { return dst.Set(*value) } case *int64: if value == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} } else { return dst.Set(*value) } case *uint64: if value == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} } else { return dst.Set(*value) } case *int: if value == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} } else { return dst.Set(*value) } case *uint: if value == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} } else { return dst.Set(*value) } case *string: if value == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} } else { return dst.Set(*value) } case InfinityModifier: - *dst = Numeric{InfinityModifier: value, Status: Present} + *dst = Numeric{InfinityModifier: value, Valid: true} default: if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) @@ -225,160 +225,156 @@ func (dst *Numeric) Set(src interface{}) error { } func (dst Numeric) Get() interface{} { - switch dst.Status { - case Present: - if dst.InfinityModifier != None { - return dst.InfinityModifier - } - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst } func (src *Numeric) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *float32: - f, err := src.toFloat64() - if err != nil { - return err - } - return float64AssignTo(f, src.Status, dst) - case *float64: - f, err := src.toFloat64() - if err != nil { - return err - } - return float64AssignTo(f, src.Status, dst) - case *int: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = int(normalizedInt.Int64()) - case *int8: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt8) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt8) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = int8(normalizedInt.Int64()) - case *int16: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt16) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt16) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = int16(normalizedInt.Int64()) - case *int32: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt32) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt32) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = int32(normalizedInt.Int64()) - case *int64: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt64) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt64) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = normalizedInt.Int64() - case *uint: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = uint(normalizedInt.Uint64()) - case *uint8: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint8) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = uint8(normalizedInt.Uint64()) - case *uint16: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint16) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = uint16(normalizedInt.Uint64()) - case *uint32: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint32) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = uint32(normalizedInt.Uint64()) - case *uint64: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint64) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = normalizedInt.Uint64() - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } + switch v := dst.(type) { + case *float32: + f, err := src.toFloat64() + if err != nil { + return err + } + return float64AssignTo(f, src.Valid, dst) + case *float64: + f, err := src.toFloat64() + if err != nil { + return err + } + return float64AssignTo(f, src.Valid, dst) + case *int: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int(normalizedInt.Int64()) + case *int8: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt8) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt8) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int8(normalizedInt.Int64()) + case *int16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt16) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt16) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int16(normalizedInt.Int64()) + case *int32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt32) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt32) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int32(normalizedInt.Int64()) + case *int64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt64) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt64) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Int64() + case *uint: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint(normalizedInt.Uint64()) + case *uint8: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint8) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint8(normalizedInt.Uint64()) + case *uint16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint16) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint16(normalizedInt.Uint64()) + case *uint32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint32) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint32(normalizedInt.Uint64()) + case *uint64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint64) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Uint64() + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + return nil } @@ -430,18 +426,18 @@ func (src *Numeric) toFloat64() (float64, error) { func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} return nil } if string(src) == "NaN" { - *dst = Numeric{Status: Present, NaN: true} + *dst = Numeric{Valid: true, NaN: true} return nil } else if string(src) == "Infinity" { - *dst = Numeric{Status: Present, InfinityModifier: Infinity} + *dst = Numeric{Valid: true, InfinityModifier: Infinity} return nil } else if string(src) == "-Infinity" { - *dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity} + *dst = Numeric{Valid: true, InfinityModifier: NegativeInfinity} return nil } @@ -450,7 +446,7 @@ func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Numeric{Int: num, Exp: exp, Status: Present} + *dst = Numeric{Int: num, Exp: exp, Valid: true} return nil } @@ -477,7 +473,7 @@ func parseNumericString(str string) (n *big.Int, exp int32, err error) { func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} return nil } @@ -496,18 +492,18 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { rp += 2 if sign == pgNumericNaNSign { - *dst = Numeric{Status: Present, NaN: true} + *dst = Numeric{Valid: true, NaN: true} return nil } else if sign == pgNumericPosInfSign { - *dst = Numeric{Status: Present, InfinityModifier: Infinity} + *dst = Numeric{Valid: true, InfinityModifier: Infinity} return nil } else if sign == pgNumericNegInfSign { - *dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity} + *dst = Numeric{Valid: true, InfinityModifier: NegativeInfinity} return nil } if ndigits == 0 { - *dst = Numeric{Int: big.NewInt(0), Status: Present} + *dst = Numeric{Int: big.NewInt(0), Valid: true} return nil } @@ -579,7 +575,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { accum.Neg(accum) } - *dst = Numeric{Int: accum, Exp: exp, Status: Present} + *dst = Numeric{Int: accum, Exp: exp, Valid: true} return nil @@ -605,11 +601,8 @@ func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { } func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if src.NaN { @@ -630,11 +623,8 @@ func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if src.NaN { @@ -734,7 +724,7 @@ func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Numeric) Scan(src interface{}) error { if src == nil { - *dst = Numeric{Status: Null} + *dst = Numeric{} return nil } @@ -752,17 +742,14 @@ func (dst *Numeric) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Numeric) Value() (driver.Value, error) { - switch src.Status { - case Present: - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - - return string(buf), nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + + return string(buf), nil } diff --git a/numeric_array.go b/numeric_array.go index 31899dec..3e9298b6 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -14,13 +14,13 @@ import ( type NumericArray struct { Elements []Numeric Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *NumericArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = NumericArray{Status: Null} + *dst = NumericArray{} return nil } @@ -36,9 +36,9 @@ func (dst *NumericArray) Set(src interface{}) error { case []float32: if value == nil { - *dst = NumericArray{Status: Null} + *dst = NumericArray{} } else if len(value) == 0 { - *dst = NumericArray{Status: Present} + *dst = NumericArray{Valid: true} } else { elements := make([]Numeric, len(value)) for i := range value { @@ -49,15 +49,15 @@ func (dst *NumericArray) Set(src interface{}) error { *dst = NumericArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*float32: if value == nil { - *dst = NumericArray{Status: Null} + *dst = NumericArray{} } else if len(value) == 0 { - *dst = NumericArray{Status: Present} + *dst = NumericArray{Valid: true} } else { elements := make([]Numeric, len(value)) for i := range value { @@ -68,15 +68,15 @@ func (dst *NumericArray) Set(src interface{}) error { *dst = NumericArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []float64: if value == nil { - *dst = NumericArray{Status: Null} + *dst = NumericArray{} } else if len(value) == 0 { - *dst = NumericArray{Status: Present} + *dst = NumericArray{Valid: true} } else { elements := make([]Numeric, len(value)) for i := range value { @@ -87,15 +87,15 @@ func (dst *NumericArray) Set(src interface{}) error { *dst = NumericArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*float64: if value == nil { - *dst = NumericArray{Status: Null} + *dst = NumericArray{} } else if len(value) == 0 { - *dst = NumericArray{Status: Present} + *dst = NumericArray{Valid: true} } else { elements := make([]Numeric, len(value)) for i := range value { @@ -106,15 +106,15 @@ func (dst *NumericArray) Set(src interface{}) error { *dst = NumericArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []int64: if value == nil { - *dst = NumericArray{Status: Null} + *dst = NumericArray{} } else if len(value) == 0 { - *dst = NumericArray{Status: Present} + *dst = NumericArray{Valid: true} } else { elements := make([]Numeric, len(value)) for i := range value { @@ -125,15 +125,15 @@ func (dst *NumericArray) Set(src interface{}) error { *dst = NumericArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*int64: if value == nil { - *dst = NumericArray{Status: Null} + *dst = NumericArray{} } else if len(value) == 0 { - *dst = NumericArray{Status: Present} + *dst = NumericArray{Valid: true} } else { elements := make([]Numeric, len(value)) for i := range value { @@ -144,15 +144,15 @@ func (dst *NumericArray) Set(src interface{}) error { *dst = NumericArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []uint64: if value == nil { - *dst = NumericArray{Status: Null} + *dst = NumericArray{} } else if len(value) == 0 { - *dst = NumericArray{Status: Present} + *dst = NumericArray{Valid: true} } else { elements := make([]Numeric, len(value)) for i := range value { @@ -163,15 +163,15 @@ func (dst *NumericArray) Set(src interface{}) error { *dst = NumericArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*uint64: if value == nil { - *dst = NumericArray{Status: Null} + *dst = NumericArray{} } else if len(value) == 0 { - *dst = NumericArray{Status: Present} + *dst = NumericArray{Valid: true} } else { elements := make([]Numeric, len(value)) for i := range value { @@ -182,20 +182,20 @@ func (dst *NumericArray) Set(src interface{}) error { *dst = NumericArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Numeric: if value == nil { - *dst = NumericArray{Status: Null} + *dst = NumericArray{} } else if len(value) == 0 { - *dst = NumericArray{Status: Present} + *dst = NumericArray{Valid: true} } else { *dst = NumericArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -204,7 +204,7 @@ func (dst *NumericArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = NumericArray{Status: Null} + *dst = NumericArray{} return nil } @@ -213,7 +213,7 @@ func (dst *NumericArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for NumericArray", src) } if elementsLength == 0 { - *dst = NumericArray{Status: Present} + *dst = NumericArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -226,7 +226,7 @@ func (dst *NumericArray) Set(src interface{}) error { *dst = NumericArray{ Elements: make([]Numeric, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -293,138 +293,131 @@ func (dst *NumericArray) setRecursive(value reflect.Value, index, dimension int) } func (dst NumericArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *NumericArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]float32: - *v = make([]float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float32: - *v = make([]*float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]float64: - *v = make([]float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float64: - *v = make([]*float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]float32: + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float32: + *v = make([]*float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]float64: + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float64: + *v = make([]*float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -476,7 +469,7 @@ func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = NumericArray{Status: Null} + *dst = NumericArray{} return nil } @@ -505,14 +498,14 @@ func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = NumericArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = NumericArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = NumericArray{Status: Null} + *dst = NumericArray{} return nil } @@ -523,7 +516,7 @@ func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = NumericArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = NumericArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -548,16 +541,13 @@ func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = NumericArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = NumericArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -610,11 +600,8 @@ func (src NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -628,7 +615,7 @@ func (src NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/numeric_array_test.go b/numeric_array_test.go index 7c1e8c3b..ee36d1a7 100644 --- a/numeric_array_test.go +++ b/numeric_array_test.go @@ -15,41 +15,41 @@ func TestNumericArrayTranscode(t *testing.T) { &pgtype.NumericArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.NumericArray{ Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Status: pgtype.Null}, + {Int: big.NewInt(1), Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.NumericArray{Status: pgtype.Null}, + &pgtype.NumericArray{}, &pgtype.NumericArray{ Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: big.NewInt(6), Status: pgtype.Present}, + {Int: big.NewInt(1), Valid: true}, + {Int: big.NewInt(2), Valid: true}, + {Int: big.NewInt(3), Valid: true}, + {Int: big.NewInt(4), Valid: true}, + {}, + {Int: big.NewInt(6), Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.NumericArray{ Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(1), Valid: true}, + {Int: big.NewInt(2), Valid: true}, + {Int: big.NewInt(3), Valid: true}, + {Int: big.NewInt(4), Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -62,82 +62,82 @@ func TestNumericArraySet(t *testing.T) { { source: []float32{1}, result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []float32{float32(math.Copysign(0, -1))}, result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(0), Status: pgtype.Present}}, + Elements: []pgtype.Numeric{{Int: big.NewInt(0), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []float64{1}, result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []float64{math.Copysign(0, -1)}, result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(0), Status: pgtype.Present}}, + Elements: []pgtype.Numeric{{Int: big.NewInt(0), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]float32)(nil)), - result: pgtype.NumericArray{Status: pgtype.Null}, + result: pgtype.NumericArray{}, }, { source: [][]float32{{1}, {2}}, result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, result: pgtype.NumericArray{ Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, - {Int: big.NewInt(5), Status: pgtype.Present}, - {Int: big.NewInt(6), Status: pgtype.Present}}, + {Int: big.NewInt(1), Valid: true}, + {Int: big.NewInt(2), Valid: true}, + {Int: big.NewInt(3), Valid: true}, + {Int: big.NewInt(4), Valid: true}, + {Int: big.NewInt(5), Valid: true}, + {Int: big.NewInt(6), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]float32{{1}, {2}}, result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, result: pgtype.NumericArray{ Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, - {Int: big.NewInt(5), Status: pgtype.Present}, - {Int: big.NewInt(6), Status: pgtype.Present}}, + {Int: big.NewInt(1), Valid: true}, + {Int: big.NewInt(2), Valid: true}, + {Int: big.NewInt(3), Valid: true}, + {Int: big.NewInt(4), Valid: true}, + {Int: big.NewInt(5), Valid: true}, + {Int: big.NewInt(6), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -169,81 +169,81 @@ func TestNumericArrayAssignTo(t *testing.T) { }{ { src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &float32Slice, expected: []float32{1}, }, { src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &float64Slice, expected: []float64{1}, }, { - src: pgtype.NumericArray{Status: pgtype.Null}, + src: pgtype.NumericArray{}, dst: &float32Slice, expected: (([]float32)(nil)), }, { - src: pgtype.NumericArray{Status: pgtype.Present}, + src: pgtype.NumericArray{Valid: true}, dst: &float32Slice, expected: []float32{}, }, { src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &float32SliceDim2, expected: [][]float32{{1}, {2}}, }, { src: pgtype.NumericArray{ Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, - {Int: big.NewInt(5), Status: pgtype.Present}, - {Int: big.NewInt(6), Status: pgtype.Present}}, + {Int: big.NewInt(1), Valid: true}, + {Int: big.NewInt(2), Valid: true}, + {Int: big.NewInt(3), Valid: true}, + {Int: big.NewInt(4), Valid: true}, + {Int: big.NewInt(5), Valid: true}, + {Int: big.NewInt(6), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &float32SliceDim4, expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, }, { src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &float32ArrayDim2, expected: [2][1]float32{{1}, {2}}, }, { src: pgtype.NumericArray{ Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, - {Int: big.NewInt(5), Status: pgtype.Present}, - {Int: big.NewInt(6), Status: pgtype.Present}}, + {Int: big.NewInt(1), Valid: true}, + {Int: big.NewInt(2), Valid: true}, + {Int: big.NewInt(3), Valid: true}, + {Int: big.NewInt(4), Valid: true}, + {Int: big.NewInt(5), Valid: true}, + {Int: big.NewInt(6), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &float32ArrayDim4, expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, }, @@ -266,31 +266,31 @@ func TestNumericArrayAssignTo(t *testing.T) { }{ { src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Status: pgtype.Null}}, + Elements: []pgtype.Numeric{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &float32Slice, }, { src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &float32ArrayDim2, }, { src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &float32Slice, }, { src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &float32ArrayDim4, }, } diff --git a/numeric_test.go b/numeric_test.go index 455c3ac3..58ce5c0f 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -13,7 +13,7 @@ import ( // For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) func numericEqual(left, right *pgtype.Numeric) bool { - return left.Status == right.Status && + return left.Valid == right.Valid && left.Exp == right.Exp && ((left.Int == nil && right.Int == nil) || (left.Int != nil && right.Int != nil && left.Int.Cmp(right.Int) == 0)) && left.NaN == right.NaN @@ -21,12 +21,12 @@ func numericEqual(left, right *pgtype.Numeric) bool { // For test purposes only. func numericNormalizedEqual(left, right *pgtype.Numeric) bool { - if left.Status != right.Status { + if left.Valid != right.Valid { return false } - normLeft := &pgtype.Numeric{Int: (&big.Int{}).Set(left.Int), Status: left.Status} - normRight := &pgtype.Numeric{Int: (&big.Int{}).Set(right.Int), Status: right.Status} + normLeft := &pgtype.Numeric{Int: (&big.Int{}).Set(left.Int), Valid: left.Valid} + normRight := &pgtype.Numeric{Int: (&big.Int{}).Set(right.Int), Valid: right.Valid} if left.Exp < right.Exp { mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(right.Exp-left.Exp)), nil) @@ -51,66 +51,66 @@ func TestNumericNormalize(t *testing.T) { testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { SQL: "select '0'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Valid: true}, }, { SQL: "select '1'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Valid: true}, }, { SQL: "select '10.00'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Valid: true}, }, { SQL: "select '1e-3'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Valid: true}, }, { SQL: "select '-1'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Valid: true}, }, { SQL: "select '10000'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Valid: true}, }, { SQL: "select '3.14'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Valid: true}, }, { SQL: "select '1.1'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Valid: true}, }, { SQL: "select '100010001'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Valid: true}, }, { SQL: "select '100010001.0001'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Valid: true}, }, { SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", Value: &pgtype.Numeric{ - Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), - Exp: -41, - Status: pgtype.Present, + Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), + Exp: -41, + Valid: true, }, }, { SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", Value: &pgtype.Numeric{ - Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), - Exp: -196, - Status: pgtype.Present, + Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), + Exp: -196, + Valid: true, }, }, { SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", Value: &pgtype.Numeric{ - Int: mustParseBigInt(t, "123"), - Exp: -186, - Status: pgtype.Present, + Int: mustParseBigInt(t, "123"), + Exp: -186, + Valid: true, }, }, }) @@ -119,48 +119,48 @@ func TestNumericNormalize(t *testing.T) { func TestNumericTranscode(t *testing.T) { max := new(big.Int).Exp(big.NewInt(10), big.NewInt(147454), nil) max.Add(max, big.NewInt(1)) - longestNumeric := &pgtype.Numeric{Int: max, Exp: -16383, Status: pgtype.Present} + longestNumeric := &pgtype.Numeric{Int: max, Exp: -16383, Valid: true} testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ - &pgtype.Numeric{NaN: true, Status: pgtype.Present}, - &pgtype.Numeric{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, - &pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, + &pgtype.Numeric{NaN: true, Valid: true}, + &pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, + &pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(1), Exp: 6, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 6, Valid: true}, // preserves significant zeroes - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -1, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -2, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -3, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -4, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -5, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -6, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -1, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -2, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -3, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -4, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -5, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -6, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -7, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -8, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -9, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -1500, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3723409723490243842378942378901237502734019231380123"), Exp: 81, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "723409723490243842378942378901237502734019231380123"), Exp: 82, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "23409723490243842378942378901237502734019231380123"), Exp: 83, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3409723490243842378942378901237502734019231380123"), Exp: 84, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3423409823409243892349028349023482934092340892390101"), Exp: -91, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "423409823409243892349028349023482934092340892390101"), Exp: -92, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "23409823409243892349028349023482934092340892390101"), Exp: -93, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3409823409243892349028349023482934092340892390101"), Exp: -94, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -7, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -8, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -9, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -1500, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3723409723490243842378942378901237502734019231380123"), Exp: 81, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "723409723490243842378942378901237502734019231380123"), Exp: 82, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409723490243842378942378901237502734019231380123"), Exp: 83, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409723490243842378942378901237502734019231380123"), Exp: 84, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3423409823409243892349028349023482934092340892390101"), Exp: -91, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "423409823409243892349028349023482934092340892390101"), Exp: -92, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409823409243892349028349023482934092340892390101"), Exp: -93, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409823409243892349028349023482934092340892390101"), Exp: -94, Valid: true}, longestNumeric, - &pgtype.Numeric{Status: pgtype.Null}, + &pgtype.Numeric{}, }, func(aa, bb interface{}) bool { a := aa.(pgtype.Numeric) b := bb.(pgtype.Numeric) @@ -181,8 +181,8 @@ func TestNumericTranscodeFuzz(t *testing.T) { num := (&big.Int{}).Rand(r, max) negNum := &big.Int{} negNum.Neg(num) - values = append(values, &pgtype.Numeric{Int: num, Exp: int32(j), Status: pgtype.Present}) - values = append(values, &pgtype.Numeric{Int: negNum, Exp: int32(j), Status: pgtype.Present}) + values = append(values, &pgtype.Numeric{Int: num, Exp: int32(j), Valid: true}) + values = append(values, &pgtype.Numeric{Int: negNum, Exp: int32(j), Valid: true}) } } @@ -200,36 +200,36 @@ func TestNumericSet(t *testing.T) { source interface{} result *pgtype.Numeric }{ - {source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: float32(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Present}}, - {source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: float64(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Present}}, - {source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: int64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: int8(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, - {source: int16(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, - {source: int32(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, - {source: int64(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, - {source: uint8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: uint16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: uint32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: uint64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: "1", result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: _int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: float64(1000), result: &pgtype.Numeric{Int: big.NewInt(1), Exp: 3, Status: pgtype.Present}}, - {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}}, - {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}}, - {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, - {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, NaN: true}}, - {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, NaN: true}}, - {source: pgtype.Infinity, result: &pgtype.Numeric{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, - {source: math.Inf(1), result: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}}, - {source: float32(math.Inf(1)), result: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}}, - {source: pgtype.NegativeInfinity, result: &pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, - {source: math.Inf(-1), result: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.NegativeInfinity}}, - {source: float32(math.Inf(1)), result: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}}, + {source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: float32(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Valid: true}}, + {source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: float64(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Valid: true}}, + {source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: int64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: int8(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}}, + {source: int16(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}}, + {source: int32(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}}, + {source: int64(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}}, + {source: uint8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: uint16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: uint32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: uint64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: "1", result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: _int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: float64(1000), result: &pgtype.Numeric{Int: big.NewInt(1), Exp: 3, Valid: true}}, + {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Valid: true}}, + {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Valid: true}}, + {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Valid: true}}, + {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Valid: true, NaN: true}}, + {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Valid: true, NaN: true}}, + {source: pgtype.Infinity, result: &pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: math.Inf(1), result: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}}, + {source: float32(math.Inf(1)), result: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}}, + {source: pgtype.NegativeInfinity, result: &pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, + {source: math.Inf(-1), result: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.NegativeInfinity}}, + {source: float32(math.Inf(1)), result: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}}, } for i, tt := range successfulTests { @@ -269,30 +269,30 @@ func TestNumericAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: 3, Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Status: pgtype.Present}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27 - {src: &pgtype.Numeric{Status: pgtype.Present, NaN: true}, dst: &f64, expected: math.NaN()}, - {src: &pgtype.Numeric{Status: pgtype.Present, NaN: true}, dst: &f32, expected: float32(math.NaN())}, - {src: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, dst: &f64, expected: math.Inf(1)}, - {src: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, dst: &f32, expected: float32(math.Inf(1))}, - {src: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.NegativeInfinity}, dst: &f64, expected: math.Inf(-1)}, - {src: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.NegativeInfinity}, dst: &f32, expected: float32(math.Inf(-1))}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &f32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &f64, expected: float64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Valid: true}, dst: &f32, expected: float32(4.2)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Valid: true}, dst: &f64, expected: float64(4.2)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &i16, expected: int16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &i32, expected: int32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &i64, expected: int64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: 3, Valid: true}, dst: &i64, expected: int64(42000)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &i, expected: int(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui, expected: uint(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(0)}, dst: &pi8, expected: ((*int8)(nil))}, + {src: &pgtype.Numeric{Int: big.NewInt(0)}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Valid: true}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27 + {src: &pgtype.Numeric{Valid: true, NaN: true}, dst: &f64, expected: math.NaN()}, + {src: &pgtype.Numeric{Valid: true, NaN: true}, dst: &f32, expected: float32(math.NaN())}, + {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}, dst: &f64, expected: math.Inf(1)}, + {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}, dst: &f32, expected: float32(math.Inf(1))}, + {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.NegativeInfinity}, dst: &f64, expected: math.Inf(-1)}, + {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.NegativeInfinity}, dst: &f32, expected: float32(math.Inf(-1))}, } for i, tt := range simpleTests { @@ -329,8 +329,8 @@ func TestNumericAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &pf32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &pf64, expected: float64(42)}, } for i, tt := range pointerAllocTests { @@ -348,14 +348,14 @@ func TestNumericAssignTo(t *testing.T) { src *pgtype.Numeric dst interface{} }{ - {src: &pgtype.Numeric{Int: big.NewInt(150), Status: pgtype.Present}, dst: &i8}, - {src: &pgtype.Numeric{Int: big.NewInt(40000), Status: pgtype.Present}, dst: &i16}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui8}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui16}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui32}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui64}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui}, - {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &i32}, + {src: &pgtype.Numeric{Int: big.NewInt(150), Valid: true}, dst: &i8}, + {src: &pgtype.Numeric{Int: big.NewInt(40000), Valid: true}, dst: &i16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui8}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui32}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui64}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui}, + {src: &pgtype.Numeric{Int: big.NewInt(0)}, dst: &i32}, } for i, tt := range errorTests { diff --git a/numrange.go b/numrange.go index 3d5951a2..f1118d83 100644 --- a/numrange.go +++ b/numrange.go @@ -12,13 +12,13 @@ type Numrange struct { Upper Numeric LowerType BoundType UpperType BoundType - Status Status + Valid bool } func (dst *Numrange) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = Numrange{Status: Null} + *dst = Numrange{} return nil } @@ -36,15 +36,11 @@ func (dst *Numrange) Set(src interface{}) error { return nil } -func (dst Numrange) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: +func (src Numrange) Get() interface{} { + if !src.Valid { return nil - default: - return dst.Status } + return src } func (src *Numrange) AssignTo(dst interface{}) error { @@ -53,7 +49,7 @@ func (src *Numrange) AssignTo(dst interface{}) error { func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Numrange{Status: Null} + *dst = Numrange{} return nil } @@ -62,7 +58,7 @@ func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Numrange{Status: Present} + *dst = Numrange{Valid: true} dst.LowerType = utr.LowerType dst.UpperType = utr.UpperType @@ -88,7 +84,7 @@ func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Numrange{Status: Null} + *dst = Numrange{} return nil } @@ -97,7 +93,7 @@ func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { return err } - *dst = Numrange{Status: Present} + *dst = Numrange{Valid: true} dst.LowerType = ubr.LowerType dst.UpperType = ubr.UpperType @@ -122,11 +118,8 @@ func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { } func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } switch src.LowerType { @@ -175,11 +168,8 @@ func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var rangeType byte @@ -245,7 +235,7 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Numrange) Scan(src interface{}) error { if src == nil { - *dst = Numrange{Status: Null} + *dst = Numrange{} return nil } diff --git a/numrange_test.go b/numrange_test.go index 0bbb26f0..b9ea7658 100644 --- a/numrange_test.go +++ b/numrange_test.go @@ -13,34 +13,34 @@ func TestNumrangeTranscode(t *testing.T) { &pgtype.Numrange{ LowerType: pgtype.Empty, UpperType: pgtype.Empty, - Status: pgtype.Present, + Valid: true, }, &pgtype.Numrange{ - Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Status: pgtype.Present}, - Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Status: pgtype.Present}, + Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Valid: true}, + Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, &pgtype.Numrange{ - Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, - Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Status: pgtype.Present}, + Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Valid: true}, + Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, &pgtype.Numrange{ - Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, + Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, - Status: pgtype.Present, + Valid: true, }, &pgtype.Numrange{ - Upper: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Valid: true}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, - &pgtype.Numrange{Status: pgtype.Null}, + &pgtype.Numrange{}, }) } diff --git a/oid_value_test.go b/oid_value_test.go index 69742dd7..021f81d3 100644 --- a/oid_value_test.go +++ b/oid_value_test.go @@ -10,8 +10,8 @@ import ( func TestOIDValueTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ - &pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, - &pgtype.OIDValue{Status: pgtype.Null}, + &pgtype.OIDValue{Uint: 42, Valid: true}, + &pgtype.OIDValue{}, }) } @@ -20,7 +20,7 @@ func TestOIDValueSet(t *testing.T) { source interface{} result pgtype.OIDValue }{ - {source: uint32(1), result: pgtype.OIDValue{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.OIDValue{Uint: 1, Valid: true}}, } for i, tt := range successfulTests { @@ -45,8 +45,8 @@ func TestOIDValueAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.OIDValue{Uint: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.OIDValue{}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -65,7 +65,7 @@ func TestOIDValueAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.OIDValue{Uint: 42, Valid: true}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -83,7 +83,7 @@ func TestOIDValueAssignTo(t *testing.T) { src pgtype.OIDValue dst interface{} }{ - {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.OIDValue{}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/path.go b/path.go index 9f89969e..7ac38c68 100644 --- a/path.go +++ b/path.go @@ -14,7 +14,7 @@ import ( type Path struct { P []Vec2 Closed bool - Status Status + Valid bool } func (dst *Path) Set(src interface{}) error { @@ -22,14 +22,10 @@ func (dst *Path) Set(src interface{}) error { } func (dst Path) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Path) AssignTo(dst interface{}) error { @@ -38,7 +34,7 @@ func (src *Path) AssignTo(dst interface{}) error { func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Path{Status: Null} + *dst = Path{} return nil } @@ -75,13 +71,13 @@ func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Path{P: points, Closed: closed, Status: Present} + *dst = Path{P: points, Closed: closed, Valid: true} return nil } func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Path{Status: Null} + *dst = Path{} return nil } @@ -110,17 +106,14 @@ func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { *dst = Path{ P: points, Closed: closed, - Status: Present, + Valid: true, } return nil } func (src Path) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var startByte, endByte byte @@ -147,11 +140,8 @@ func (src Path) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Path) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var closeByte byte @@ -173,7 +163,7 @@ func (src Path) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Path) Scan(src interface{}) error { if src == nil { - *dst = Path{Status: Null} + *dst = Path{} return nil } diff --git a/path_test.go b/path_test.go index 969a89ec..9a66996e 100644 --- a/path_test.go +++ b/path_test.go @@ -12,18 +12,18 @@ func TestPathTranscode(t *testing.T) { &pgtype.Path{ P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, Closed: false, - Status: pgtype.Present, + Valid: true, }, &pgtype.Path{ P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, Closed: true, - Status: pgtype.Present, + Valid: true, }, &pgtype.Path{ P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, Closed: true, - Status: pgtype.Present, + Valid: true, }, - &pgtype.Path{Status: pgtype.Null}, + &pgtype.Path{}, }) } diff --git a/pgtype.go b/pgtype.go index 200fb562..c4fe870d 100644 --- a/pgtype.go +++ b/pgtype.go @@ -3,7 +3,6 @@ package pgtype import ( "database/sql" "encoding/binary" - "errors" "fmt" "math" "net" @@ -82,14 +81,6 @@ const ( Int8rangeOID = 3926 ) -type Status byte - -const ( - Undefined Status = iota - Null - Present -) - type InfinityModifier int8 const ( @@ -208,9 +199,6 @@ type TextEncoder interface { EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) } -var errUndefined = errors.New("cannot encode status undefined") -var errBadStatus = errors.New("invalid status") - type nullAssignmentError struct { dst interface{} } diff --git a/pgtype_test.go b/pgtype_test.go index 5fd89dcb..7ae756e5 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -232,7 +232,7 @@ func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) { if err != nil { b.Fatal(err) } - if v != (pgtype.Int4{Int: 42, Status: pgtype.Present}) { + if v != (pgtype.Int4{Int: 42, Valid: true}) { b.Fatal("scan failed due to bad value") } } @@ -252,7 +252,7 @@ func TestScanPlanBinaryInt32ScanChangedType(t *testing.T) { err = plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &d) require.NoError(t, err) require.EqualValues(t, 42, d.Int) - require.EqualValues(t, pgtype.Present, d.Status) + require.True(t, d.Valid) } func BenchmarkConnInfoScanInt4IntoGoInt32(b *testing.B) { @@ -285,7 +285,7 @@ func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) { if err != nil { b.Fatal(err) } - if v != (pgtype.Int4{Int: 42, Status: pgtype.Present}) { + if v != (pgtype.Int4{Int: 42, Valid: true}) { b.Fatal("scan failed due to bad value") } } diff --git a/pguint32.go b/pguint32.go index a0e88ca2..e36ebb1f 100644 --- a/pguint32.go +++ b/pguint32.go @@ -13,8 +13,8 @@ import ( // pguint32 is the core type that is used to implement PostgreSQL types such as // CID and XID. type pguint32 struct { - Uint uint32 - Status Status + Uint uint32 + Valid bool } // Set converts from src to dst. Note that as pguint32 is not a general @@ -29,9 +29,9 @@ func (dst *pguint32) Set(src interface{}) error { if value > math.MaxUint32 { return fmt.Errorf("%d is greater than maximum value for pguint32", value) } - *dst = pguint32{Uint: uint32(value), Status: Present} + *dst = pguint32{Uint: uint32(value), Valid: true} case uint32: - *dst = pguint32{Uint: value, Status: Present} + *dst = pguint32{Uint: value, Valid: true} default: return fmt.Errorf("cannot convert %v to pguint32", value) } @@ -40,14 +40,10 @@ func (dst *pguint32) Set(src interface{}) error { } func (dst pguint32) Get() interface{} { - switch dst.Status { - case Present: - return dst.Uint - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Uint } // AssignTo assigns from src to dst. Note that as pguint32 is not a general number @@ -55,13 +51,13 @@ func (dst pguint32) Get() interface{} { func (src *pguint32) AssignTo(dst interface{}) error { switch v := dst.(type) { case *uint32: - if src.Status == Present { + if src.Valid { *v = src.Uint } else { return fmt.Errorf("cannot assign %v into %T", src, dst) } case **uint32: - if src.Status == Present { + if src.Valid { n := src.Uint *v = &n } else { @@ -74,7 +70,7 @@ func (src *pguint32) AssignTo(dst interface{}) error { func (dst *pguint32) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = pguint32{Status: Null} + *dst = pguint32{} return nil } @@ -83,13 +79,13 @@ func (dst *pguint32) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = pguint32{Uint: uint32(n), Status: Present} + *dst = pguint32{Uint: uint32(n), Valid: true} return nil } func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = pguint32{Status: Null} + *dst = pguint32{} return nil } @@ -98,27 +94,21 @@ func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { } n := binary.BigEndian.Uint32(src) - *dst = pguint32{Uint: n, Status: Present} + *dst = pguint32{Uint: n, Valid: true} return nil } func (src pguint32) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, strconv.FormatUint(uint64(src.Uint), 10)...), nil } func (src pguint32) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return pgio.AppendUint32(buf, src.Uint), nil @@ -127,16 +117,16 @@ func (src pguint32) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *pguint32) Scan(src interface{}) error { if src == nil { - *dst = pguint32{Status: Null} + *dst = pguint32{} return nil } switch src := src.(type) { case uint32: - *dst = pguint32{Uint: src, Status: Present} + *dst = pguint32{Uint: src, Valid: true} return nil case int64: - *dst = pguint32{Uint: uint32(src), Status: Present} + *dst = pguint32{Uint: uint32(src), Valid: true} return nil case string: return dst.DecodeText(nil, []byte(src)) @@ -151,12 +141,8 @@ func (dst *pguint32) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src pguint32) Value() (driver.Value, error) { - switch src.Status { - case Present: - return int64(src.Uint), nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + return int64(src.Uint), nil } diff --git a/point.go b/point.go index 0c799106..d35dbf03 100644 --- a/point.go +++ b/point.go @@ -18,13 +18,13 @@ type Vec2 struct { } type Point struct { - P Vec2 - Status Status + P Vec2 + Valid bool } func (dst *Point) Set(src interface{}) error { if src == nil { - dst.Status = Null + dst.Valid = false return nil } err := fmt.Errorf("cannot convert %v to Point", src) @@ -46,7 +46,7 @@ func (dst *Point) Set(src interface{}) error { func parsePoint(src []byte) (*Point, error) { if src == nil || bytes.Compare(src, []byte("null")) == 0 { - return &Point{Status: Null}, nil + return &Point{}, nil } if len(src) < 5 { @@ -70,18 +70,14 @@ func parsePoint(src []byte) (*Point, error) { return nil, err } - return &Point{P: Vec2{x, y}, Status: Present}, nil + return &Point{P: Vec2{x, y}, Valid: true}, nil } func (dst Point) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Point) AssignTo(dst interface{}) error { @@ -90,7 +86,7 @@ func (src *Point) AssignTo(dst interface{}) error { func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Point{Status: Null} + *dst = Point{} return nil } @@ -113,13 +109,13 @@ func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Point{P: Vec2{x, y}, Status: Present} + *dst = Point{P: Vec2{x, y}, Valid: true} return nil } func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Point{Status: Null} + *dst = Point{} return nil } @@ -131,18 +127,15 @@ func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { y := binary.BigEndian.Uint64(src[8:]) *dst = Point{ - P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, - Status: Present, + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + Valid: true, } return nil } func (src Point) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, fmt.Sprintf(`(%s,%s)`, @@ -152,11 +145,8 @@ func (src Point) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Point) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) @@ -167,7 +157,7 @@ func (src Point) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Point) Scan(src interface{}) error { if src == nil { - *dst = Point{Status: Null} + *dst = Point{} return nil } @@ -189,19 +179,15 @@ func (src Point) Value() (driver.Value, error) { } func (src Point) MarshalJSON() ([]byte, error) { - switch src.Status { - case Present: - var buff bytes.Buffer - buff.WriteByte('"') - buff.WriteString(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y)) - buff.WriteByte('"') - return buff.Bytes(), nil - case Null: + if !src.Valid { return []byte("null"), nil - case Undefined: - return nil, errUndefined } - return nil, errBadStatus + + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y)) + buff.WriteByte('"') + return buff.Bytes(), nil } func (dst *Point) UnmarshalJSON(point []byte) error { diff --git a/point_test.go b/point_test.go index 63f8df07..82f58e17 100644 --- a/point_test.go +++ b/point_test.go @@ -6,13 +6,14 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" ) func TestPointTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "point", []interface{}{ - &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Status: pgtype.Present}, - &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, - &pgtype.Point{Status: pgtype.Null}, + &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Valid: true}, + &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Valid: true}, + &pgtype.Point{}, }) } @@ -20,31 +21,31 @@ func TestPoint_Set(t *testing.T) { tests := []struct { name string arg interface{} - status pgtype.Status + valid bool wantErr bool }{ { name: "first", arg: "(12312.123123,123123.123123)", - status: pgtype.Present, + valid: true, wantErr: false, }, { name: "second", arg: "(1231s2.123123,123123.123123)", - status: pgtype.Undefined, + valid: false, wantErr: true, }, { name: "third", arg: []byte("(122.123123,123.123123)"), - status: pgtype.Present, + valid: true, wantErr: false, }, { name: "third", arg: nil, - status: pgtype.Null, + valid: false, wantErr: false, }, } @@ -54,8 +55,8 @@ func TestPoint_Set(t *testing.T) { if err := dst.Set(tt.arg); (err != nil) != tt.wantErr { t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) } - if dst.Status != tt.status { - t.Errorf("Expected status: %v; got: %v", tt.status, dst.Status) + if dst.Valid != tt.valid { + t.Errorf("Expected status: %v; got: %v", tt.valid, dst.Valid) } }) } @@ -63,46 +64,30 @@ func TestPoint_Set(t *testing.T) { func TestPoint_MarshalJSON(t *testing.T) { tests := []struct { - name string - point pgtype.Point - want []byte - wantErr bool + name string + point pgtype.Point + want []byte }{ - { - name: "first", - point: pgtype.Point{ - P: pgtype.Vec2{}, - Status: pgtype.Undefined, - }, - want: nil, - wantErr: true, - }, { name: "second", point: pgtype.Point{ - P: pgtype.Vec2{X: 12.245, Y: 432.12}, - Status: pgtype.Present, + P: pgtype.Vec2{X: 12.245, Y: 432.12}, + Valid: true, }, - want: []byte(`"(12.245,432.12)"`), - wantErr: false, + want: []byte(`"(12.245,432.12)"`), }, { name: "third", point: pgtype.Point{ - P: pgtype.Vec2{}, - Status: pgtype.Null, + P: pgtype.Vec2{}, }, - want: []byte("null"), - wantErr: false, + want: []byte("null"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.point.MarshalJSON() - if (err != nil) != tt.wantErr { - t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) - return - } + require.NoError(t, err) if !reflect.DeepEqual(got, tt.want) { t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) } @@ -113,25 +98,25 @@ func TestPoint_MarshalJSON(t *testing.T) { func TestPoint_UnmarshalJSON(t *testing.T) { tests := []struct { name string - status pgtype.Status + valid bool arg []byte wantErr bool }{ { name: "first", - status: pgtype.Present, + valid: true, arg: []byte(`"(123.123,54.12)"`), wantErr: false, }, { name: "second", - status: pgtype.Undefined, + valid: false, arg: []byte(`"(123.123,54.1sad2)"`), wantErr: true, }, { name: "third", - status: pgtype.Null, + valid: false, arg: []byte("null"), wantErr: false, }, @@ -142,8 +127,8 @@ func TestPoint_UnmarshalJSON(t *testing.T) { if err := dst.UnmarshalJSON(tt.arg); (err != nil) != tt.wantErr { t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) } - if dst.Status != tt.status { - t.Errorf("Status mismatch: %v != %v", dst.Status, tt.status) + if dst.Valid != tt.valid { + t.Errorf("Valid mismatch: %v != %v", dst.Valid, tt.valid) } }) } diff --git a/polygon.go b/polygon.go index 207cadc0..956920e6 100644 --- a/polygon.go +++ b/polygon.go @@ -12,8 +12,8 @@ import ( ) type Polygon struct { - P []Vec2 - Status Status + P []Vec2 + Valid bool } // Set converts src to dest. @@ -24,7 +24,7 @@ type Polygon struct { // Important that there are no spaces in it. func (dst *Polygon) Set(src interface{}) error { if src == nil { - dst.Status = Null + dst.Valid = false return nil } err := fmt.Errorf("cannot convert %v to Polygon", src) @@ -33,7 +33,7 @@ func (dst *Polygon) Set(src interface{}) error { case string: p, err = stringToPolygon(value) case []Vec2: - p = &Polygon{Status: Present, P: value} + p = &Polygon{Valid: true, P: value} err = nil case []float64: p, err = float64ToPolygon(value) @@ -54,15 +54,14 @@ func stringToPolygon(src string) (*Polygon, error) { } func float64ToPolygon(src []float64) (*Polygon, error) { - p := &Polygon{Status: Null} + p := &Polygon{} if len(src) == 0 { return p, nil } if len(src)%2 != 0 { - p.Status = Undefined return p, fmt.Errorf("invalid length for polygon: %v", len(src)) } - p.Status = Present + p.Valid = true p.P = make([]Vec2, 0) for i := 0; i < len(src); i += 2 { p.P = append(p.P, Vec2{X: src[i], Y: src[i+1]}) @@ -71,14 +70,10 @@ func float64ToPolygon(src []float64) (*Polygon, error) { } func (dst Polygon) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Polygon) AssignTo(dst interface{}) error { @@ -87,7 +82,7 @@ func (src *Polygon) AssignTo(dst interface{}) error { func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Polygon{Status: Null} + *dst = Polygon{} return nil } @@ -123,13 +118,13 @@ func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Polygon{P: points, Status: Present} + *dst = Polygon{P: points, Valid: true} return nil } func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Polygon{Status: Null} + *dst = Polygon{} return nil } @@ -154,18 +149,15 @@ func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { } *dst = Polygon{ - P: points, - Status: Present, + P: points, + Valid: true, } return nil } func (src Polygon) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = append(buf, '(') @@ -184,11 +176,8 @@ func (src Polygon) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Polygon) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = pgio.AppendInt32(buf, int32(len(src.P))) @@ -204,7 +193,7 @@ func (src Polygon) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Polygon) Scan(src interface{}) error { if src == nil { - *dst = Polygon{Status: Null} + *dst = Polygon{} return nil } diff --git a/polygon_test.go b/polygon_test.go index 1a139444..34f8d59a 100644 --- a/polygon_test.go +++ b/polygon_test.go @@ -10,14 +10,14 @@ import ( func TestPolygonTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "polygon", []interface{}{ &pgtype.Polygon{ - P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, - Status: pgtype.Present, + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, + Valid: true, }, &pgtype.Polygon{ - P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, - Status: pgtype.Present, + P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, + Valid: true, }, - &pgtype.Polygon{Status: pgtype.Null}, + &pgtype.Polygon{}, }) } @@ -25,53 +25,53 @@ func TestPolygon_Set(t *testing.T) { tests := []struct { name string arg interface{} - status pgtype.Status + valid bool wantErr bool }{ { name: "string", arg: "((3.14,1.678901234),(7.1,5.234),(5.0,3.234))", - status: pgtype.Present, + valid: true, wantErr: false, }, { name: "[]float64", arg: []float64{1, 2, 3.45, 6.78, 1.23, 4.567, 8.9, 1.0}, - status: pgtype.Present, + valid: true, wantErr: false, }, { name: "[]Vec2", arg: []pgtype.Vec2{{1, 2}, {2.3, 4.5}, {6.78, 9.123}}, - status: pgtype.Present, + valid: true, wantErr: false, }, { name: "null", arg: nil, - status: pgtype.Null, + valid: false, wantErr: false, }, { name: "invalid_string_1", arg: "((3.14,1.678901234),(7.1,5.234),(5.0,3.234x))", - status: pgtype.Undefined, + valid: false, wantErr: true, }, { name: "invalid_string_2", arg: "(3,4)", - status: pgtype.Undefined, + valid: false, wantErr: true, }, { name: "invalid_[]float64", arg: []float64{1, 2, 3.45, 6.78, 1.23, 4.567, 8.9}, - status: pgtype.Undefined, + valid: false, wantErr: true, }, { name: "invalid_type", arg: []int{1, 2, 3, 6}, - status: pgtype.Undefined, + valid: false, wantErr: true, }, { name: "empty_[]float64", arg: []float64{}, - status: pgtype.Null, + valid: false, wantErr: false, }, } @@ -81,8 +81,8 @@ func TestPolygon_Set(t *testing.T) { if err := dst.Set(tt.arg); (err != nil) != tt.wantErr { t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) } - if dst.Status != tt.status { - t.Errorf("Expected status: %v; got: %v", tt.status, dst.Status) + if dst.Valid != tt.valid { + t.Errorf("Expected valid: %v; got: %v", tt.valid, dst.Valid) } }) } diff --git a/qchar.go b/qchar.go index 574f6066..e56bf142 100644 --- a/qchar.go +++ b/qchar.go @@ -18,13 +18,13 @@ import ( // addition, database/sql Scanner and database/sql/driver Value are not // implemented. type QChar struct { - Int int8 - Status Status + Int int8 + Valid bool } func (dst *QChar) Set(src interface{}) error { if src == nil { - *dst = QChar{Status: Null} + *dst = QChar{} return nil } @@ -37,12 +37,12 @@ func (dst *QChar) Set(src interface{}) error { switch value := src.(type) { case int8: - *dst = QChar{Int: value, Status: Present} + *dst = QChar{Int: value, Valid: true} case uint8: if value > math.MaxInt8 { return fmt.Errorf("%d is greater than maximum value for QChar", value) } - *dst = QChar{Int: int8(value), Status: Present} + *dst = QChar{Int: int8(value), Valid: true} case int16: if value < math.MinInt8 { return fmt.Errorf("%d is greater than maximum value for QChar", value) @@ -50,12 +50,12 @@ func (dst *QChar) Set(src interface{}) error { if value > math.MaxInt8 { return fmt.Errorf("%d is greater than maximum value for QChar", value) } - *dst = QChar{Int: int8(value), Status: Present} + *dst = QChar{Int: int8(value), Valid: true} case uint16: if value > math.MaxInt8 { return fmt.Errorf("%d is greater than maximum value for QChar", value) } - *dst = QChar{Int: int8(value), Status: Present} + *dst = QChar{Int: int8(value), Valid: true} case int32: if value < math.MinInt8 { return fmt.Errorf("%d is greater than maximum value for QChar", value) @@ -63,12 +63,12 @@ func (dst *QChar) Set(src interface{}) error { if value > math.MaxInt8 { return fmt.Errorf("%d is greater than maximum value for QChar", value) } - *dst = QChar{Int: int8(value), Status: Present} + *dst = QChar{Int: int8(value), Valid: true} case uint32: if value > math.MaxInt8 { return fmt.Errorf("%d is greater than maximum value for QChar", value) } - *dst = QChar{Int: int8(value), Status: Present} + *dst = QChar{Int: int8(value), Valid: true} case int64: if value < math.MinInt8 { return fmt.Errorf("%d is greater than maximum value for QChar", value) @@ -76,12 +76,12 @@ func (dst *QChar) Set(src interface{}) error { if value > math.MaxInt8 { return fmt.Errorf("%d is greater than maximum value for QChar", value) } - *dst = QChar{Int: int8(value), Status: Present} + *dst = QChar{Int: int8(value), Valid: true} case uint64: if value > math.MaxInt8 { return fmt.Errorf("%d is greater than maximum value for QChar", value) } - *dst = QChar{Int: int8(value), Status: Present} + *dst = QChar{Int: int8(value), Valid: true} case int: if value < math.MinInt8 { return fmt.Errorf("%d is greater than maximum value for QChar", value) @@ -89,18 +89,18 @@ func (dst *QChar) Set(src interface{}) error { if value > math.MaxInt8 { return fmt.Errorf("%d is greater than maximum value for QChar", value) } - *dst = QChar{Int: int8(value), Status: Present} + *dst = QChar{Int: int8(value), Valid: true} case uint: if value > math.MaxInt8 { return fmt.Errorf("%d is greater than maximum value for QChar", value) } - *dst = QChar{Int: int8(value), Status: Present} + *dst = QChar{Int: int8(value), Valid: true} case string: num, err := strconv.ParseInt(value, 10, 8) if err != nil { return err } - *dst = QChar{Int: int8(num), Status: Present} + *dst = QChar{Int: int8(num), Valid: true} default: if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) @@ -112,23 +112,19 @@ func (dst *QChar) Set(src interface{}) error { } func (dst QChar) Get() interface{} { - switch dst.Status { - case Present: - return dst.Int - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Int } func (src *QChar) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Status, dst) + return int64AssignTo(int64(src.Int), src.Valid, dst) } func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = QChar{Status: Null} + *dst = QChar{} return nil } @@ -136,16 +132,13 @@ func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { return fmt.Errorf(`invalid length for "char": %v`, len(src)) } - *dst = QChar{Int: int8(src[0]), Status: Present} + *dst = QChar{Int: int8(src[0]), Valid: true} return nil } func (src QChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, byte(src.Int)), nil diff --git a/qchar_test.go b/qchar_test.go index 4b60339c..eb54bf65 100644 --- a/qchar_test.go +++ b/qchar_test.go @@ -11,12 +11,12 @@ import ( func TestQCharTranscode(t *testing.T) { testutil.TestPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ - &pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, - &pgtype.QChar{Int: -1, Status: pgtype.Present}, - &pgtype.QChar{Int: 0, Status: pgtype.Present}, - &pgtype.QChar{Int: 1, Status: pgtype.Present}, - &pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, - &pgtype.QChar{Int: 0, Status: pgtype.Null}, + &pgtype.QChar{Int: math.MinInt8, Valid: true}, + &pgtype.QChar{Int: -1, Valid: true}, + &pgtype.QChar{Int: 0, Valid: true}, + &pgtype.QChar{Int: 1, Valid: true}, + &pgtype.QChar{Int: math.MaxInt8, Valid: true}, + &pgtype.QChar{Int: 0}, }, func(a, b interface{}) bool { return reflect.DeepEqual(a, b) }) @@ -27,20 +27,20 @@ func TestQCharSet(t *testing.T) { source interface{} result pgtype.QChar }{ - {source: int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int8(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: int16(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: int32(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: int64(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: int8(-1), result: pgtype.QChar{Int: -1, Valid: true}}, + {source: int16(-1), result: pgtype.QChar{Int: -1, Valid: true}}, + {source: int32(-1), result: pgtype.QChar{Int: -1, Valid: true}}, + {source: int64(-1), result: pgtype.QChar{Int: -1, Valid: true}}, + {source: uint8(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: uint16(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: uint32(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: uint64(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: "1", result: pgtype.QChar{Int: 1, Valid: true}}, + {source: _int8(1), result: pgtype.QChar{Int: 1, Valid: true}}, } for i, tt := range successfulTests { @@ -76,19 +76,19 @@ func TestQCharAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i, expected: int(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.QChar{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.QChar{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, } for i, tt := range simpleTests { @@ -107,8 +107,8 @@ func TestQCharAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, } for i, tt := range pointerAllocTests { @@ -126,12 +126,12 @@ func TestQCharAssignTo(t *testing.T) { src pgtype.QChar dst interface{} }{ - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &i16}, + {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui8}, + {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui16}, + {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui32}, + {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui64}, + {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui}, + {src: pgtype.QChar{Int: 0}, dst: &i16}, } for i, tt := range errorTests { diff --git a/record.go b/record.go index 718c3570..20b119c6 100644 --- a/record.go +++ b/record.go @@ -12,12 +12,12 @@ import ( // PostgreSQL does not support input of generic records. type Record struct { Fields []Value - Status Status + Valid bool } func (dst *Record) Set(src interface{}) error { if src == nil { - *dst = Record{Status: Null} + *dst = Record{} return nil } @@ -30,7 +30,7 @@ func (dst *Record) Set(src interface{}) error { switch value := src.(type) { case []Value: - *dst = Record{Fields: value, Status: Present} + *dst = Record{Fields: value, Valid: true} default: return fmt.Errorf("cannot convert %v to Record", src) } @@ -39,41 +39,34 @@ func (dst *Record) Set(src interface{}) error { } func (dst Record) Get() interface{} { - switch dst.Status { - case Present: - return dst.Fields - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Fields } func (src *Record) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *[]Value: - *v = make([]Value, len(src.Fields)) - copy(*v, src.Fields) - return nil - case *[]interface{}: - *v = make([]interface{}, len(src.Fields)) - for i := range *v { - (*v)[i] = src.Fields[i].Get() - } - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + switch v := dst.(type) { + case *[]Value: + *v = make([]Value, len(src.Fields)) + copy(*v, src.Fields) + return nil + case *[]interface{}: + *v = make([]interface{}, len(src.Fields)) + for i := range *v { + (*v)[i] = src.Fields[i].Get() + } + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDecoder, error) { @@ -97,7 +90,7 @@ func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDec func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Record{Status: Null} + *dst = Record{} return nil } @@ -120,7 +113,7 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { return scanner.Err() } - *dst = Record{Fields: fields, Status: Present} + *dst = Record{Fields: fields, Valid: true} return nil } diff --git a/record_test.go b/record_test.go index 240812a6..c8e7d4b7 100644 --- a/record_test.go +++ b/record_test.go @@ -19,63 +19,61 @@ var recordTests = []struct { sql: `select row()`, expected: pgtype.Record{ Fields: []pgtype.Value{}, - Status: pgtype.Present, + Valid: true, }, }, { sql: `select row('foo'::text, 42::int4)`, expected: pgtype.Record{ Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, + &pgtype.Text{String: "foo", Valid: true}, + &pgtype.Int4{Int: 42, Valid: true}, }, - Status: pgtype.Present, + Valid: true, }, }, { sql: `select row(100.0::float4, 1.09::float4)`, expected: pgtype.Record{ Fields: []pgtype.Value{ - &pgtype.Float4{Float: 100, Status: pgtype.Present}, - &pgtype.Float4{Float: 1.09, Status: pgtype.Present}, + &pgtype.Float4{Float: 100, Valid: true}, + &pgtype.Float4{Float: 1.09, Valid: true}, }, - Status: pgtype.Present, + Valid: true, }, }, { sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, expected: pgtype.Record{ Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Text{String: "foo", Valid: true}, &pgtype.Int4Array{ Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 4, Status: pgtype.Present}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {}, + {Int: 4, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Valid: true}, }, - Status: pgtype.Present, + Valid: true, }, }, { sql: `select row(null)`, expected: pgtype.Record{ Fields: []pgtype.Value{ - &pgtype.Unknown{Status: pgtype.Null}, + &pgtype.Unknown{}, }, - Status: pgtype.Present, + Valid: true, }, }, { - sql: `select null::record`, - expected: pgtype.Record{ - Status: pgtype.Null, - }, + sql: `select null::record`, + expected: pgtype.Record{}, }, } @@ -139,35 +137,35 @@ func TestRecordAssignTo(t *testing.T) { { src: pgtype.Record{ Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, + &pgtype.Text{String: "foo", Valid: true}, + &pgtype.Int4{Int: 42, Valid: true}, }, - Status: pgtype.Present, + Valid: true, }, dst: &valueSlice, expected: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, + &pgtype.Text{String: "foo", Valid: true}, + &pgtype.Int4{Int: 42, Valid: true}, }, }, { src: pgtype.Record{ Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, + &pgtype.Text{String: "foo", Valid: true}, + &pgtype.Int4{Int: 42, Valid: true}, }, - Status: pgtype.Present, + Valid: true, }, dst: &interfaceSlice, expected: []interface{}{"foo", int32(42)}, }, { - src: pgtype.Record{Status: pgtype.Null}, + src: pgtype.Record{}, dst: &valueSlice, expected: (([]pgtype.Value)(nil)), }, { - src: pgtype.Record{Status: pgtype.Null}, + src: pgtype.Record{}, dst: &interfaceSlice, expected: (([]interface{})(nil)), }, diff --git a/text.go b/text.go index a01815d9..5d27c44f 100644 --- a/text.go +++ b/text.go @@ -8,12 +8,12 @@ import ( type Text struct { String string - Status Status + Valid bool } func (dst *Text) Set(src interface{}) error { if src == nil { - *dst = Text{Status: Null} + *dst = Text{} return nil } @@ -26,24 +26,24 @@ func (dst *Text) Set(src interface{}) error { switch value := src.(type) { case string: - *dst = Text{String: value, Status: Present} + *dst = Text{String: value, Valid: true} case *string: if value == nil { - *dst = Text{Status: Null} + *dst = Text{} } else { - *dst = Text{String: *value, Status: Present} + *dst = Text{String: *value, Valid: true} } case []byte: if value == nil { - *dst = Text{Status: Null} + *dst = Text{} } else { - *dst = Text{String: string(value), Status: Present} + *dst = Text{String: string(value), Valid: true} } case fmt.Stringer: if value == fmt.Stringer(nil) { - *dst = Text{Status: Null} + *dst = Text{} } else { - *dst = Text{String: value.String(), Status: Present} + *dst = Text{String: value.String(), Valid: true} } default: // Cannot be part of the switch: If Value() returns nil on @@ -54,7 +54,7 @@ func (dst *Text) Set(src interface{}) error { // pointer receiver and fmt.Stringer with value receiver. if value, ok := src.(driver.Valuer); ok { if value == driver.Valuer(nil) { - *dst = Text{Status: Null} + *dst = Text{} return nil } else { v, err := value.Value() @@ -64,7 +64,7 @@ func (dst *Text) Set(src interface{}) error { // Handles also v == nil case. if s, ok := v.(string); ok { - *dst = Text{String: s, Status: Present} + *dst = Text{String: s, Valid: true} return nil } } @@ -80,38 +80,31 @@ func (dst *Text) Set(src interface{}) error { } func (dst Text) Get() interface{} { - switch dst.Status { - case Present: - return dst.String - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.String } func (src *Text) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *string: - *v = src.String - return nil - case *[]byte: - *v = make([]byte, len(src.String)) - copy(*v, src.String) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + switch v := dst.(type) { + case *string: + *v = src.String + return nil + case *[]byte: + *v = make([]byte, len(src.String)) + copy(*v, src.String) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } func (Text) PreferredResultFormat() int16 { @@ -120,11 +113,11 @@ func (Text) PreferredResultFormat() int16 { func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Text{Status: Null} + *dst = Text{} return nil } - *dst = Text{String: string(src), Status: Present} + *dst = Text{String: string(src), Valid: true} return nil } @@ -137,11 +130,8 @@ func (Text) PreferredParamFormat() int16 { } func (src Text) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, src.String...), nil @@ -154,7 +144,7 @@ func (src Text) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Text) Scan(src interface{}) error { if src == nil { - *dst = Text{Status: Null} + *dst = Text{} return nil } @@ -172,27 +162,18 @@ func (dst *Text) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Text) Value() (driver.Value, error) { - switch src.Status { - case Present: - return src.String, nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + return src.String, nil } func (src Text) MarshalJSON() ([]byte, error) { - switch src.Status { - case Present: - return json.Marshal(src.String) - case Null: + if !src.Valid { return []byte("null"), nil - case Undefined: - return nil, errUndefined } - return nil, errBadStatus + return json.Marshal(src.String) } func (dst *Text) UnmarshalJSON(b []byte) error { @@ -203,9 +184,9 @@ func (dst *Text) UnmarshalJSON(b []byte) error { } if s == nil { - *dst = Text{Status: Null} + *dst = Text{} } else { - *dst = Text{String: *s, Status: Present} + *dst = Text{String: *s, Valid: true} } return nil diff --git a/text_array.go b/text_array.go index 2461966b..7fcc1c4d 100644 --- a/text_array.go +++ b/text_array.go @@ -14,13 +14,13 @@ import ( type TextArray struct { Elements []Text Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *TextArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = TextArray{Status: Null} + *dst = TextArray{} return nil } @@ -36,9 +36,9 @@ func (dst *TextArray) Set(src interface{}) error { case []string: if value == nil { - *dst = TextArray{Status: Null} + *dst = TextArray{} } else if len(value) == 0 { - *dst = TextArray{Status: Present} + *dst = TextArray{Valid: true} } else { elements := make([]Text, len(value)) for i := range value { @@ -49,15 +49,15 @@ func (dst *TextArray) Set(src interface{}) error { *dst = TextArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*string: if value == nil { - *dst = TextArray{Status: Null} + *dst = TextArray{} } else if len(value) == 0 { - *dst = TextArray{Status: Present} + *dst = TextArray{Valid: true} } else { elements := make([]Text, len(value)) for i := range value { @@ -68,20 +68,20 @@ func (dst *TextArray) Set(src interface{}) error { *dst = TextArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Text: if value == nil { - *dst = TextArray{Status: Null} + *dst = TextArray{} } else if len(value) == 0 { - *dst = TextArray{Status: Present} + *dst = TextArray{Valid: true} } else { *dst = TextArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -90,7 +90,7 @@ func (dst *TextArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = TextArray{Status: Null} + *dst = TextArray{} return nil } @@ -99,7 +99,7 @@ func (dst *TextArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for TextArray", src) } if elementsLength == 0 { - *dst = TextArray{Status: Present} + *dst = TextArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -112,7 +112,7 @@ func (dst *TextArray) Set(src interface{}) error { *dst = TextArray{ Elements: make([]Text, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -179,84 +179,77 @@ func (dst *TextArray) setRecursive(value reflect.Value, index, dimension int) (i } func (dst TextArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *TextArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -308,7 +301,7 @@ func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension in func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = TextArray{Status: Null} + *dst = TextArray{} return nil } @@ -337,14 +330,14 @@ func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = TextArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = TextArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = TextArray{Status: Null} + *dst = TextArray{} return nil } @@ -355,7 +348,7 @@ func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = TextArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = TextArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -380,16 +373,13 @@ func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = TextArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = TextArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -442,11 +432,8 @@ func (src TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -460,7 +447,7 @@ func (src TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/text_array_test.go b/text_array_test.go index a5d050f6..4caeb692 100644 --- a/text_array_test.go +++ b/text_array_test.go @@ -16,8 +16,8 @@ func TestTextArrayDecodeTextNull(t *testing.T) { err := textArray.DecodeText(nil, []byte(`{abc,"NULL",NULL,def}`)) require.NoError(t, err) require.Len(t, textArray.Elements, 4) - assert.Equal(t, pgtype.Present, textArray.Elements[1].Status) - assert.Equal(t, pgtype.Null, textArray.Elements[2].Status) + assert.Equal(t, true, textArray.Elements[1].Valid) + assert.Equal(t, false, textArray.Elements[2].Valid) } func TestTextArrayTranscode(t *testing.T) { @@ -25,41 +25,41 @@ func TestTextArrayTranscode(t *testing.T) { &pgtype.TextArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.TextArray{ Elements: []pgtype.Text{ - {String: "foo", Status: pgtype.Present}, - {Status: pgtype.Null}, + {String: "foo", Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.TextArray{Status: pgtype.Null}, + &pgtype.TextArray{}, &pgtype.TextArray{ Elements: []pgtype.Text{ - {String: "bar ", Status: pgtype.Present}, - {String: "NuLL", Status: pgtype.Present}, - {String: `wow"quz\`, Status: pgtype.Present}, - {String: "", Status: pgtype.Present}, - {Status: pgtype.Null}, - {String: "null", Status: pgtype.Present}, + {String: "bar ", Valid: true}, + {String: "NuLL", Valid: true}, + {String: `wow"quz\`, Valid: true}, + {String: "", Valid: true}, + {}, + {String: "null", Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.TextArray{ Elements: []pgtype.Text{ - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "quz", Status: pgtype.Present}, - {String: "foo", Status: pgtype.Present}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "quz", Valid: true}, + {String: "foo", Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -72,61 +72,61 @@ func TestTextArraySet(t *testing.T) { { source: []string{"foo"}, result: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, + Elements: []pgtype.Text{{String: "foo", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]string)(nil)), - result: pgtype.TextArray{Status: pgtype.Null}, + result: pgtype.TextArray{}, }, { source: [][]string{{"foo"}, {"bar"}}, result: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, result: pgtype.TextArray{ Elements: []pgtype.Text{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]string{{"foo"}, {"bar"}}, result: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, result: pgtype.TextArray{ Elements: []pgtype.Text{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -159,81 +159,81 @@ func TestTextArrayAssignTo(t *testing.T) { }{ { src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, + Elements: []pgtype.Text{{String: "foo", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &stringSlice, expected: []string{"foo"}, }, { src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Text{{String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &namedStringSlice, expected: _stringSlice{"bar"}, }, { - src: pgtype.TextArray{Status: pgtype.Null}, + src: pgtype.TextArray{}, dst: &stringSlice, expected: (([]string)(nil)), }, { - src: pgtype.TextArray{Status: pgtype.Present}, + src: pgtype.TextArray{Valid: true}, dst: &stringSlice, expected: []string{}, }, { src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringSliceDim2, expected: [][]string{{"foo"}, {"bar"}}, }, { src: pgtype.TextArray{ Elements: []pgtype.Text{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringSliceDim4, expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, }, { src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim2, expected: [2][1]string{{"foo"}, {"bar"}}, }, { src: pgtype.TextArray{ Elements: []pgtype.Text{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim4, expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, }, @@ -256,31 +256,31 @@ func TestTextArrayAssignTo(t *testing.T) { }{ { src: pgtype.TextArray{ - Elements: []pgtype.Text{{Status: pgtype.Null}}, + Elements: []pgtype.Text{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &stringSlice, }, { src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim2, }, { src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringSlice, }, { src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim4, }, } diff --git a/text_test.go b/text_test.go index cca3a05d..5f34f8c0 100644 --- a/text_test.go +++ b/text_test.go @@ -12,9 +12,9 @@ import ( func TestTextTranscode(t *testing.T) { for _, pgTypeName := range []string{"text", "varchar"} { testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ - &pgtype.Text{String: "", Status: pgtype.Present}, - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Text{Status: pgtype.Null}, + &pgtype.Text{String: "", Valid: true}, + &pgtype.Text{String: "foo", Valid: true}, + &pgtype.Text{}, }) } } @@ -24,9 +24,9 @@ func TestTextSet(t *testing.T) { source interface{} result pgtype.Text }{ - {source: "foo", result: pgtype.Text{String: "foo", Status: pgtype.Present}}, - {source: _string("bar"), result: pgtype.Text{String: "bar", Status: pgtype.Present}}, - {source: (*string)(nil), result: pgtype.Text{Status: pgtype.Null}}, + {source: "foo", result: pgtype.Text{String: "foo", Valid: true}}, + {source: _string("bar"), result: pgtype.Text{String: "bar", Valid: true}}, + {source: (*string)(nil), result: pgtype.Text{}}, } for i, tt := range successfulTests { @@ -51,8 +51,8 @@ func TestTextAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, - {src: pgtype.Text{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + {src: pgtype.Text{String: "foo", Valid: true}, dst: &s, expected: "foo"}, + {src: pgtype.Text{}, dst: &ps, expected: ((*string)(nil))}, } for i, tt := range stringTests { @@ -73,8 +73,8 @@ func TestTextAssignTo(t *testing.T) { dst *[]byte expected []byte }{ - {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &buf, expected: []byte("foo")}, - {src: pgtype.Text{Status: pgtype.Null}, dst: &buf, expected: nil}, + {src: pgtype.Text{String: "foo", Valid: true}, dst: &buf, expected: []byte("foo")}, + {src: pgtype.Text{}, dst: &buf, expected: nil}, } for i, tt := range bytesTests { @@ -93,7 +93,7 @@ func TestTextAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, + {src: pgtype.Text{String: "foo", Valid: true}, dst: &ps, expected: "foo"}, } for i, tt := range pointerAllocTests { @@ -111,7 +111,7 @@ func TestTextAssignTo(t *testing.T) { src pgtype.Text dst interface{} }{ - {src: pgtype.Text{Status: pgtype.Null}, dst: &s}, + {src: pgtype.Text{}, dst: &s}, } for i, tt := range errorTests { @@ -127,8 +127,8 @@ func TestTextMarshalJSON(t *testing.T) { source pgtype.Text result string }{ - {source: pgtype.Text{String: "", Status: pgtype.Null}, result: "null"}, - {source: pgtype.Text{String: "a", Status: pgtype.Present}, result: "\"a\""}, + {source: pgtype.Text{String: ""}, result: "null"}, + {source: pgtype.Text{String: "a", Valid: true}, result: "\"a\""}, } for i, tt := range successfulTests { r, err := tt.source.MarshalJSON() @@ -147,8 +147,8 @@ func TestTextUnmarshalJSON(t *testing.T) { source string result pgtype.Text }{ - {source: "null", result: pgtype.Text{String: "", Status: pgtype.Null}}, - {source: "\"a\"", result: pgtype.Text{String: "a", Status: pgtype.Present}}, + {source: "null", result: pgtype.Text{String: ""}}, + {source: "\"a\"", result: pgtype.Text{String: "a", Valid: true}}, } for i, tt := range successfulTests { var r pgtype.Text diff --git a/tid.go b/tid.go index 4bb57f64..0108d219 100644 --- a/tid.go +++ b/tid.go @@ -24,7 +24,7 @@ import ( type TID struct { BlockNumber uint32 OffsetNumber uint16 - Status Status + Valid bool } func (dst *TID) Set(src interface{}) error { @@ -32,36 +32,32 @@ func (dst *TID) Set(src interface{}) error { } func (dst TID) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *TID) AssignTo(dst interface{}) error { - if src.Status == Present { - switch v := dst.(type) { - case *string: - *v = fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } + if !src.Valid { + return fmt.Errorf("cannot assign %v to %T", src, dst) } - return fmt.Errorf("cannot assign %v to %T", src, dst) + switch v := dst.(type) { + case *string: + *v = fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = TID{Status: Null} + *dst = TID{} return nil } @@ -84,13 +80,13 @@ func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} + *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Valid: true} return nil } func (dst *TID) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = TID{Status: Null} + *dst = TID{} return nil } @@ -101,17 +97,14 @@ func (dst *TID) DecodeBinary(ci *ConnInfo, src []byte) error { *dst = TID{ BlockNumber: binary.BigEndian.Uint32(src), OffsetNumber: binary.BigEndian.Uint16(src[4:]), - Status: Present, + Valid: true, } return nil } func (src TID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = append(buf, fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber)...) @@ -119,11 +112,8 @@ func (src TID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src TID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = pgio.AppendUint32(buf, src.BlockNumber) @@ -134,7 +124,7 @@ func (src TID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *TID) Scan(src interface{}) error { if src == nil { - *dst = TID{Status: Null} + *dst = TID{} return nil } diff --git a/tid_test.go b/tid_test.go index 818be8af..fcf93259 100644 --- a/tid_test.go +++ b/tid_test.go @@ -10,9 +10,9 @@ import ( func TestTIDTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "tid", []interface{}{ - &pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, - &pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, - &pgtype.TID{Status: pgtype.Null}, + &pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, + &pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, + &pgtype.TID{}, }) } @@ -25,8 +25,8 @@ func TestTIDAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, dst: &s, expected: "(42,43)"}, - {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, dst: &s, expected: "(4294967295,65535)"}, + {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, dst: &s, expected: "(42,43)"}, + {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, dst: &s, expected: "(4294967295,65535)"}, } for i, tt := range simpleTests { @@ -45,8 +45,8 @@ func TestTIDAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, dst: &sp, expected: "(42,43)"}, - {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, dst: &sp, expected: "(4294967295,65535)"}, + {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, dst: &sp, expected: "(42,43)"}, + {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, dst: &sp, expected: "(4294967295,65535)"}, } for i, tt := range pointerAllocTests { @@ -60,4 +60,3 @@ func TestTIDAssignTo(t *testing.T) { } } } - diff --git a/time.go b/time.go index f7a28870..3252a633 100644 --- a/time.go +++ b/time.go @@ -17,13 +17,13 @@ import ( // to needing to handle 24:00:00. time.Time converts that to 00:00:00 on the following day. type Time struct { Microseconds int64 // Number of microseconds since midnight - Status Status + Valid bool } // Set converts src into a Time and stores in dst. func (dst *Time) Set(src interface{}) error { if src == nil { - *dst = Time{Status: Null} + *dst = Time{} return nil } @@ -40,10 +40,10 @@ func (dst *Time) Set(src interface{}) error { int64(value.Minute())*microsecondsPerMinute + int64(value.Second())*microsecondsPerSecond + int64(value.Nanosecond())/1000 - *dst = Time{Microseconds: usec, Status: Present} + *dst = Time{Microseconds: usec, Valid: true} case *time.Time: if value == nil { - *dst = Time{Status: Null} + *dst = Time{} } else { return dst.Set(*value) } @@ -58,54 +58,47 @@ func (dst *Time) Set(src interface{}) error { } func (dst Time) Get() interface{} { - switch dst.Status { - case Present: - return dst.Microseconds - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Microseconds } func (src *Time) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *time.Time: - // 24:00:00 is max allowed time in PostgreSQL, but time.Time will normalize that to 00:00:00 the next day. - var maxRepresentableByTime int64 = 24*60*60*1000000 - 1 - if src.Microseconds > maxRepresentableByTime { - return fmt.Errorf("%d microseconds cannot be represented as time.Time", src.Microseconds) - } - - usec := src.Microseconds - hours := usec / microsecondsPerHour - usec -= hours * microsecondsPerHour - minutes := usec / microsecondsPerMinute - usec -= minutes * microsecondsPerMinute - seconds := usec / microsecondsPerSecond - usec -= seconds * microsecondsPerSecond - ns := usec * 1000 - *v = time.Date(2000, 1, 1, int(hours), int(minutes), int(seconds), int(ns), time.UTC) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + switch v := dst.(type) { + case *time.Time: + // 24:00:00 is max allowed time in PostgreSQL, but time.Time will normalize that to 00:00:00 the next day. + var maxRepresentableByTime int64 = 24*60*60*1000000 - 1 + if src.Microseconds > maxRepresentableByTime { + return fmt.Errorf("%d microseconds cannot be represented as time.Time", src.Microseconds) + } + + usec := src.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + ns := usec * 1000 + *v = time.Date(2000, 1, 1, int(hours), int(minutes), int(seconds), int(ns), time.UTC) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } // DecodeText decodes from src into dst. func (dst *Time) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Time{Status: Null} + *dst = Time{} return nil } @@ -147,7 +140,7 @@ func (dst *Time) DecodeText(ci *ConnInfo, src []byte) error { usec += n } - *dst = Time{Microseconds: usec, Status: Present} + *dst = Time{Microseconds: usec, Valid: true} return nil } @@ -155,7 +148,7 @@ func (dst *Time) DecodeText(ci *ConnInfo, src []byte) error { // DecodeBinary decodes from src into dst. func (dst *Time) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Time{Status: Null} + *dst = Time{} return nil } @@ -164,18 +157,15 @@ func (dst *Time) DecodeBinary(ci *ConnInfo, src []byte) error { } usec := int64(binary.BigEndian.Uint64(src)) - *dst = Time{Microseconds: usec, Status: Present} + *dst = Time{Microseconds: usec, Valid: true} return nil } // EncodeText writes the text encoding of src into w. func (src Time) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } usec := src.Microseconds @@ -194,11 +184,8 @@ func (src Time) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. func (src Time) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return pgio.AppendInt64(buf, src.Microseconds), nil @@ -207,7 +194,7 @@ func (src Time) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Time) Scan(src interface{}) error { if src == nil { - *dst = Time{Status: Null} + *dst = Time{} return nil } diff --git a/time_test.go b/time_test.go index 09ca3c4d..4a989375 100644 --- a/time_test.go +++ b/time_test.go @@ -11,11 +11,11 @@ import ( func TestTimeTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "time", []interface{}{ - &pgtype.Time{Microseconds: 0, Status: pgtype.Present}, - &pgtype.Time{Microseconds: 1, Status: pgtype.Present}, - &pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}, - &pgtype.Time{Microseconds: 86400000000, Status: pgtype.Present}, - &pgtype.Time{Status: pgtype.Null}, + &pgtype.Time{Microseconds: 0, Valid: true}, + &pgtype.Time{Microseconds: 1, Valid: true}, + &pgtype.Time{Microseconds: 86399999999, Valid: true}, + &pgtype.Time{Microseconds: 86400000000, Valid: true}, + &pgtype.Time{}, }) } @@ -26,18 +26,18 @@ func TestTimeSet(t *testing.T) { source interface{} result pgtype.Time }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 0, Status: pgtype.Present}}, - {source: time.Date(1900, 1, 1, 1, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 3600000000, Status: pgtype.Present}}, - {source: time.Date(1900, 1, 1, 0, 1, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 60000000, Status: pgtype.Present}}, - {source: time.Date(1900, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Time{Microseconds: 1000000, Status: pgtype.Present}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 1, time.UTC), result: pgtype.Time{Microseconds: 0, Status: pgtype.Present}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 1000, time.UTC), result: pgtype.Time{Microseconds: 1, Status: pgtype.Present}}, - {source: time.Date(1999, 12, 31, 23, 59, 59, 999999999, time.UTC), result: pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}}, - {source: time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local), result: pgtype.Time{Microseconds: 2, Status: pgtype.Present}}, - {source: func(t time.Time) *time.Time { return &t }(time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local)), result: pgtype.Time{Microseconds: 2, Status: pgtype.Present}}, - {source: nil, result: pgtype.Time{Status: pgtype.Null}}, - {source: (*time.Time)(nil), result: pgtype.Time{Status: pgtype.Null}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 3000, time.UTC)), result: pgtype.Time{Microseconds: 3, Status: pgtype.Present}}, + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 0, Valid: true}}, + {source: time.Date(1900, 1, 1, 1, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 3600000000, Valid: true}}, + {source: time.Date(1900, 1, 1, 0, 1, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 60000000, Valid: true}}, + {source: time.Date(1900, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Time{Microseconds: 1000000, Valid: true}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 1, time.UTC), result: pgtype.Time{Microseconds: 0, Valid: true}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 1000, time.UTC), result: pgtype.Time{Microseconds: 1, Valid: true}}, + {source: time.Date(1999, 12, 31, 23, 59, 59, 999999999, time.UTC), result: pgtype.Time{Microseconds: 86399999999, Valid: true}}, + {source: time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local), result: pgtype.Time{Microseconds: 2, Valid: true}}, + {source: func(t time.Time) *time.Time { return &t }(time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local)), result: pgtype.Time{Microseconds: 2, Valid: true}}, + {source: nil, result: pgtype.Time{}}, + {source: (*time.Time)(nil), result: pgtype.Time{}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 3000, time.UTC)), result: pgtype.Time{Microseconds: 3, Valid: true}}, } for i, tt := range successfulTests { @@ -62,13 +62,13 @@ func TestTimeAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Time{Microseconds: 0, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 3600000000, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 60000000, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 1, 0, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 1000000, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 1, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 1000, time.UTC)}, - {src: pgtype.Time{Microseconds: 86399999999, Status: pgtype.Present}, dst: &tim, expected: time.Date(2000, 1, 1, 23, 59, 59, 999999000, time.UTC)}, - {src: pgtype.Time{Microseconds: 0, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + {src: pgtype.Time{Microseconds: 0, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 3600000000, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 60000000, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 1, 0, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 1000000, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 1, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 1000, time.UTC)}, + {src: pgtype.Time{Microseconds: 86399999999, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 23, 59, 59, 999999000, time.UTC)}, + {src: pgtype.Time{Microseconds: 0}, dst: &ptim, expected: ((*time.Time)(nil))}, } for i, tt := range simpleTests { @@ -87,7 +87,7 @@ func TestTimeAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Time{Microseconds: 0, Status: pgtype.Present}, dst: &ptim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 0, Valid: true}, dst: &ptim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, } for i, tt := range pointerAllocTests { @@ -105,7 +105,7 @@ func TestTimeAssignTo(t *testing.T) { src pgtype.Time dst interface{} }{ - {src: pgtype.Time{Microseconds: 86400000000, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Time{Microseconds: 86400000000, Valid: true}, dst: &tim}, } for i, tt := range errorTests { diff --git a/timestamp.go b/timestamp.go index 5517acb1..882cd41a 100644 --- a/timestamp.go +++ b/timestamp.go @@ -18,7 +18,7 @@ const pgTimestampFormat = "2006-01-02 15:04:05.999999999" // convert to UTC or return an error on non-UTC times. type Timestamp struct { Time time.Time // Time must always be in UTC. - Status Status + Valid bool InfinityModifier InfinityModifier } @@ -26,7 +26,7 @@ type Timestamp struct { // time.Time in a non-UTC time zone, the time zone is discarded. func (dst *Timestamp) Set(src interface{}) error { if src == nil { - *dst = Timestamp{Status: Null} + *dst = Timestamp{} return nil } @@ -39,15 +39,15 @@ func (dst *Timestamp) Set(src interface{}) error { switch value := src.(type) { case time.Time: - *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} + *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Valid: true} case *time.Time: if value == nil { - *dst = Timestamp{Status: Null} + *dst = Timestamp{} } else { return dst.Set(*value) } case InfinityModifier: - *dst = Timestamp{InfinityModifier: value, Status: Present} + *dst = Timestamp{InfinityModifier: value, Valid: true} default: if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) @@ -59,63 +59,56 @@ func (dst *Timestamp) Set(src interface{}) error { } func (dst Timestamp) Get() interface{} { - switch dst.Status { - case Present: - if dst.InfinityModifier != None { - return dst.InfinityModifier - } - return dst.Time - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time } func (src *Timestamp) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *time.Time: - if src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } // DecodeText decodes from src into dst. The decoded time is considered to // be in UTC. func (dst *Timestamp) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Timestamp{Status: Null} + *dst = Timestamp{} return nil } sbuf := string(src) switch sbuf { case "infinity": - *dst = Timestamp{Status: Present, InfinityModifier: Infinity} + *dst = Timestamp{Valid: true, InfinityModifier: Infinity} case "-infinity": - *dst = Timestamp{Status: Present, InfinityModifier: -Infinity} + *dst = Timestamp{Valid: true, InfinityModifier: -Infinity} default: tim, err := time.Parse(pgTimestampFormat, sbuf) if err != nil { return err } - *dst = Timestamp{Time: tim, Status: Present} + *dst = Timestamp{Time: tim, Valid: true} } return nil @@ -125,7 +118,7 @@ func (dst *Timestamp) DecodeText(ci *ConnInfo, src []byte) error { // be in UTC. func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Timestamp{Status: Null} + *dst = Timestamp{} return nil } @@ -137,15 +130,15 @@ func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { switch microsecSinceY2K { case infinityMicrosecondOffset: - *dst = Timestamp{Status: Present, InfinityModifier: Infinity} + *dst = Timestamp{Valid: true, InfinityModifier: Infinity} case negativeInfinityMicrosecondOffset: - *dst = Timestamp{Status: Present, InfinityModifier: -Infinity} + *dst = Timestamp{Valid: true, InfinityModifier: -Infinity} default: tim := time.Unix( microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), ).UTC() - *dst = Timestamp{Time: tim, Status: Present} + *dst = Timestamp{Time: tim, Valid: true} } return nil @@ -154,11 +147,8 @@ func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { // EncodeText writes the text encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. func (src Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if src.Time.Location() != time.UTC { return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") @@ -181,11 +171,8 @@ func (src Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. func (src Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if src.Time.Location() != time.UTC { return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") @@ -208,7 +195,7 @@ func (src Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Timestamp) Scan(src interface{}) error { if src == nil { - *dst = Timestamp{Status: Null} + *dst = Timestamp{} return nil } @@ -220,7 +207,7 @@ func (dst *Timestamp) Scan(src interface{}) error { copy(srcCopy, src) return dst.DecodeText(nil, srcCopy) case time.Time: - *dst = Timestamp{Time: src, Status: Present} + *dst = Timestamp{Time: src, Valid: true} return nil } @@ -229,15 +216,12 @@ func (dst *Timestamp) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Timestamp) Value() (driver.Value, error) { - switch src.Status { - case Present: - if src.InfinityModifier != None { - return src.InfinityModifier.String(), nil - } - return src.Time, nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil } diff --git a/timestamp_array.go b/timestamp_array.go index e12481e3..fbf7c48a 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -15,13 +15,13 @@ import ( type TimestampArray struct { Elements []Timestamp Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *TimestampArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = TimestampArray{Status: Null} + *dst = TimestampArray{} return nil } @@ -37,9 +37,9 @@ func (dst *TimestampArray) Set(src interface{}) error { case []time.Time: if value == nil { - *dst = TimestampArray{Status: Null} + *dst = TimestampArray{} } else if len(value) == 0 { - *dst = TimestampArray{Status: Present} + *dst = TimestampArray{Valid: true} } else { elements := make([]Timestamp, len(value)) for i := range value { @@ -50,15 +50,15 @@ func (dst *TimestampArray) Set(src interface{}) error { *dst = TimestampArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*time.Time: if value == nil { - *dst = TimestampArray{Status: Null} + *dst = TimestampArray{} } else if len(value) == 0 { - *dst = TimestampArray{Status: Present} + *dst = TimestampArray{Valid: true} } else { elements := make([]Timestamp, len(value)) for i := range value { @@ -69,20 +69,20 @@ func (dst *TimestampArray) Set(src interface{}) error { *dst = TimestampArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Timestamp: if value == nil { - *dst = TimestampArray{Status: Null} + *dst = TimestampArray{} } else if len(value) == 0 { - *dst = TimestampArray{Status: Present} + *dst = TimestampArray{Valid: true} } else { *dst = TimestampArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -91,7 +91,7 @@ func (dst *TimestampArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = TimestampArray{Status: Null} + *dst = TimestampArray{} return nil } @@ -100,7 +100,7 @@ func (dst *TimestampArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for TimestampArray", src) } if elementsLength == 0 { - *dst = TimestampArray{Status: Present} + *dst = TimestampArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -113,7 +113,7 @@ func (dst *TimestampArray) Set(src interface{}) error { *dst = TimestampArray{ Elements: make([]Timestamp, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -180,84 +180,77 @@ func (dst *TimestampArray) setRecursive(value reflect.Value, index, dimension in } func (dst TimestampArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *TimestampArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -309,7 +302,7 @@ func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimensi func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = TimestampArray{Status: Null} + *dst = TimestampArray{} return nil } @@ -338,14 +331,14 @@ func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = TimestampArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = TimestampArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = TimestampArray{Status: Null} + *dst = TimestampArray{} return nil } @@ -356,7 +349,7 @@ func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = TimestampArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = TimestampArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -381,16 +374,13 @@ func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = TimestampArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = TimestampArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -443,11 +433,8 @@ func (src TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -461,7 +448,7 @@ func (src TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/timestamp_array_test.go b/timestamp_array_test.go index 54d15b24..214c8a71 100644 --- a/timestamp_array_test.go +++ b/timestamp_array_test.go @@ -14,53 +14,53 @@ func TestTimestampArrayTranscode(t *testing.T) { &pgtype.TimestampArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.TimestampArray{Status: pgtype.Null}, + &pgtype.TimestampArray{}, &pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }, func(a, b interface{}) bool { ata := a.(pgtype.TimestampArray) bta := b.(pgtype.TimestampArray) - if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { + if len(ata.Elements) != len(bta.Elements) || ata.Valid != bta.Valid { return false } for i := range ata.Elements { ae, be := ata.Elements[i], bta.Elements[i] - if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { + if !(ae.Time.Equal(be.Time) && ae.Valid == be.Valid && ae.InfinityModifier == be.InfinityModifier) { return false } } @@ -77,13 +77,13 @@ func TestTimestampArraySet(t *testing.T) { { source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, result: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]time.Time)(nil)), - result: pgtype.TimestampArray{Status: pgtype.Null}, + result: pgtype.TimestampArray{}, }, { source: [][]time.Time{ @@ -91,10 +91,10 @@ func TestTimestampArraySet(t *testing.T) { {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, result: pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]time.Time{ @@ -108,18 +108,18 @@ func TestTimestampArraySet(t *testing.T) { time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, result: pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -150,30 +150,30 @@ func TestTimestampArrayAssignTo(t *testing.T) { }{ { src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &timeSlice, expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, }, { - src: pgtype.TimestampArray{Status: pgtype.Null}, + src: pgtype.TimestampArray{}, dst: &timeSlice, expected: (([]time.Time)(nil)), }, { - src: pgtype.TimestampArray{Status: pgtype.Present}, + src: pgtype.TimestampArray{Valid: true}, dst: &timeSlice, expected: []time.Time{}, }, { src: pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeSliceDim2, expected: [][]time.Time{ {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, @@ -182,18 +182,18 @@ func TestTimestampArrayAssignTo(t *testing.T) { { src: pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeSliceDim4, expected: [][][][]time.Time{ {{{ @@ -208,10 +208,10 @@ func TestTimestampArrayAssignTo(t *testing.T) { { src: pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeArrayDim2, expected: [2][1]time.Time{ {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, @@ -220,18 +220,18 @@ func TestTimestampArrayAssignTo(t *testing.T) { { src: pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeArrayDim4, expected: [2][1][1][3]time.Time{ {{{ @@ -262,37 +262,37 @@ func TestTimestampArrayAssignTo(t *testing.T) { }{ { src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{{Status: pgtype.Null}}, + Elements: []pgtype.Timestamp{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &timeSlice, }, { src: pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeArrayDim2, }, { src: pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeSlice, }, { src: pgtype.TimestampArray{ Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeArrayDim4, }, } diff --git a/timestamp_test.go b/timestamp_test.go index ea7ef57a..88e2bca8 100644 --- a/timestamp_test.go +++ b/timestamp_test.go @@ -13,24 +13,24 @@ import ( func TestTimestampTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ - &pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Status: pgtype.Null}, - &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + &pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{}, + &pgtype.Timestamp{Valid: true, InfinityModifier: pgtype.Infinity}, + &pgtype.Timestamp{Valid: true, InfinityModifier: -pgtype.Infinity}, }, func(a, b interface{}) bool { at := a.(pgtype.Timestamp) bt := b.(pgtype.Timestamp) - return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + return at.Time.Equal(bt.Time) && at.Valid == bt.Valid && at.InfinityModifier == bt.InfinityModifier }) } @@ -42,7 +42,7 @@ func TestTimestampTranscodeBigTimeBinary(t *testing.T) { } defer testutil.MustCloseContext(t, conn) - in := &pgtype.Timestamp{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Status: pgtype.Present} + in := &pgtype.Timestamp{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} var out pgtype.Timestamp err := conn.QueryRow(context.Background(), "select $1::timestamptz", in).Scan(&out) @@ -50,7 +50,7 @@ func TestTimestampTranscodeBigTimeBinary(t *testing.T) { t.Fatal(err) } - require.Equal(t, in.Status, out.Status) + require.Equal(t, in.Valid, out.Valid) require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) } @@ -64,7 +64,7 @@ func TestTimestampNanosecondsTruncated(t *testing.T) { } for i, tt := range tests { { - ts := pgtype.Timestamp{Time: tt.input, Status: pgtype.Present} + ts := pgtype.Timestamp{Time: tt.input, Valid: true} buf, err := ts.EncodeText(nil, nil) if err != nil { t.Errorf("%d. EncodeText failed - %v", i, err) @@ -75,13 +75,13 @@ func TestTimestampNanosecondsTruncated(t *testing.T) { t.Errorf("%d. DecodeText failed - %v", i, err) } - if !(ts.Status == pgtype.Present && ts.Time.Equal(tt.expected)) { + if !(ts.Valid && ts.Time.Equal(tt.expected)) { t.Errorf("%d. EncodeText did not truncate nanoseconds", i) } } { - ts := pgtype.Timestamp{Time: tt.input, Status: pgtype.Present} + ts := pgtype.Timestamp{Time: tt.input, Valid: true} buf, err := ts.EncodeBinary(nil, nil) if err != nil { t.Errorf("%d. EncodeBinary failed - %v", i, err) @@ -92,7 +92,7 @@ func TestTimestampNanosecondsTruncated(t *testing.T) { t.Errorf("%d. DecodeBinary failed - %v", i, err) } - if !(ts.Status == pgtype.Present && ts.Time.Equal(tt.expected)) { + if !(ts.Valid && ts.Time.Equal(tt.expected)) { t.Errorf("%d. EncodeBinary did not truncate nanoseconds", i) } } @@ -113,16 +113,16 @@ func TestTimestampSet(t *testing.T) { source interface{} result pgtype.Timestamp }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: pgtype.Infinity, result: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, - {source: pgtype.NegativeInfinity, result: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), Valid: true}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), Valid: true}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: pgtype.Infinity, result: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: pgtype.NegativeInfinity, result: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, } for i, tt := range successfulTests { @@ -147,8 +147,8 @@ func TestTimestampAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, - {src: pgtype.Timestamp{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, + {src: pgtype.Timestamp{Time: time.Time{}}, dst: &ptim, expected: ((*time.Time)(nil))}, } for i, tt := range simpleTests { @@ -167,7 +167,7 @@ func TestTimestampAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, } for i, tt := range pointerAllocTests { @@ -185,9 +185,9 @@ func TestTimestampAssignTo(t *testing.T) { src pgtype.Timestamp dst interface{} }{ - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Valid: true}, dst: &tim}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Valid: true}, dst: &tim}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, dst: &tim}, } for i, tt := range errorTests { diff --git a/timestamptz.go b/timestamptz.go index 299a8668..2a711ffa 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -22,13 +22,13 @@ const ( type Timestamptz struct { Time time.Time - Status Status + Valid bool InfinityModifier InfinityModifier } func (dst *Timestamptz) Set(src interface{}) error { if src == nil { - *dst = Timestamptz{Status: Null} + *dst = Timestamptz{} return nil } @@ -41,15 +41,15 @@ func (dst *Timestamptz) Set(src interface{}) error { switch value := src.(type) { case time.Time: - *dst = Timestamptz{Time: value, Status: Present} + *dst = Timestamptz{Time: value, Valid: true} case *time.Time: if value == nil { - *dst = Timestamptz{Status: Null} + *dst = Timestamptz{} } else { return dst.Set(*value) } case InfinityModifier: - *dst = Timestamptz{InfinityModifier: value, Status: Present} + *dst = Timestamptz{InfinityModifier: value, Valid: true} default: if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) @@ -61,54 +61,47 @@ func (dst *Timestamptz) Set(src interface{}) error { } func (dst Timestamptz) Get() interface{} { - switch dst.Status { - case Present: - if dst.InfinityModifier != None { - return dst.InfinityModifier - } - return dst.Time - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time } func (src *Timestamptz) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *time.Time: - if src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } } func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Timestamptz{Status: Null} + *dst = Timestamptz{} return nil } sbuf := string(src) switch sbuf { case "infinity": - *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} + *dst = Timestamptz{Valid: true, InfinityModifier: Infinity} case "-infinity": - *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} + *dst = Timestamptz{Valid: true, InfinityModifier: -Infinity} default: var format string if len(sbuf) >= 9 && (sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+') { @@ -124,7 +117,7 @@ func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Timestamptz{Time: tim, Status: Present} + *dst = Timestamptz{Time: tim, Valid: true} } return nil @@ -132,7 +125,7 @@ func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Timestamptz{Status: Null} + *dst = Timestamptz{} return nil } @@ -144,26 +137,23 @@ func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { switch microsecSinceY2K { case infinityMicrosecondOffset: - *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} + *dst = Timestamptz{Valid: true, InfinityModifier: Infinity} case negativeInfinityMicrosecondOffset: - *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} + *dst = Timestamptz{Valid: true, InfinityModifier: -Infinity} default: tim := time.Unix( microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), ) - *dst = Timestamptz{Time: tim, Status: Present} + *dst = Timestamptz{Time: tim, Valid: true} } return nil } func (src Timestamptz) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var s string @@ -181,11 +171,8 @@ func (src Timestamptz) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Timestamptz) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var microsecSinceY2K int64 @@ -205,7 +192,7 @@ func (src Timestamptz) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Timestamptz) Scan(src interface{}) error { if src == nil { - *dst = Timestamptz{Status: Null} + *dst = Timestamptz{} return nil } @@ -217,7 +204,7 @@ func (dst *Timestamptz) Scan(src interface{}) error { copy(srcCopy, src) return dst.DecodeText(nil, srcCopy) case time.Time: - *dst = Timestamptz{Time: src, Status: Present} + *dst = Timestamptz{Time: src, Valid: true} return nil } @@ -226,29 +213,19 @@ func (dst *Timestamptz) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Timestamptz) Value() (driver.Value, error) { - switch src.Status { - case Present: - if src.InfinityModifier != None { - return src.InfinityModifier.String(), nil - } - return src.Time, nil - case Null: + if !src.Valid { return nil, nil - default: - return nil, errUndefined } + + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil } func (src Timestamptz) MarshalJSON() ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return []byte("null"), nil - case Undefined: - return nil, errUndefined - } - - if src.Status != Present { - return nil, errBadStatus } var s string @@ -273,15 +250,15 @@ func (dst *Timestamptz) UnmarshalJSON(b []byte) error { } if s == nil { - *dst = Timestamptz{Status: Null} + *dst = Timestamptz{} return nil } switch *s { case "infinity": - *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} + *dst = Timestamptz{Valid: true, InfinityModifier: Infinity} case "-infinity": - *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} + *dst = Timestamptz{Valid: true, InfinityModifier: -Infinity} default: // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz tim, err := time.Parse(time.RFC3339Nano, *s) @@ -289,7 +266,7 @@ func (dst *Timestamptz) UnmarshalJSON(b []byte) error { return err } - *dst = Timestamptz{Time: tim, Status: Present} + *dst = Timestamptz{Time: tim, Valid: true} } return nil diff --git a/timestamptz_array.go b/timestamptz_array.go index a3b4b263..4523b251 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -15,13 +15,13 @@ import ( type TimestamptzArray struct { Elements []Timestamptz Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *TimestamptzArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = TimestamptzArray{Status: Null} + *dst = TimestamptzArray{} return nil } @@ -37,9 +37,9 @@ func (dst *TimestamptzArray) Set(src interface{}) error { case []time.Time: if value == nil { - *dst = TimestamptzArray{Status: Null} + *dst = TimestamptzArray{} } else if len(value) == 0 { - *dst = TimestamptzArray{Status: Present} + *dst = TimestamptzArray{Valid: true} } else { elements := make([]Timestamptz, len(value)) for i := range value { @@ -50,15 +50,15 @@ func (dst *TimestamptzArray) Set(src interface{}) error { *dst = TimestamptzArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*time.Time: if value == nil { - *dst = TimestamptzArray{Status: Null} + *dst = TimestamptzArray{} } else if len(value) == 0 { - *dst = TimestamptzArray{Status: Present} + *dst = TimestamptzArray{Valid: true} } else { elements := make([]Timestamptz, len(value)) for i := range value { @@ -69,20 +69,20 @@ func (dst *TimestamptzArray) Set(src interface{}) error { *dst = TimestamptzArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Timestamptz: if value == nil { - *dst = TimestamptzArray{Status: Null} + *dst = TimestamptzArray{} } else if len(value) == 0 { - *dst = TimestamptzArray{Status: Present} + *dst = TimestamptzArray{Valid: true} } else { *dst = TimestamptzArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -91,7 +91,7 @@ func (dst *TimestamptzArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = TimestamptzArray{Status: Null} + *dst = TimestamptzArray{} return nil } @@ -100,7 +100,7 @@ func (dst *TimestamptzArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for TimestamptzArray", src) } if elementsLength == 0 { - *dst = TimestamptzArray{Status: Present} + *dst = TimestamptzArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -113,7 +113,7 @@ func (dst *TimestamptzArray) Set(src interface{}) error { *dst = TimestamptzArray{ Elements: make([]Timestamptz, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -180,84 +180,77 @@ func (dst *TimestamptzArray) setRecursive(value reflect.Value, index, dimension } func (dst TimestamptzArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *TimestamptzArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -309,7 +302,7 @@ func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimen func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = TimestamptzArray{Status: Null} + *dst = TimestamptzArray{} return nil } @@ -338,14 +331,14 @@ func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = TimestamptzArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = TimestamptzArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = TimestamptzArray{Status: Null} + *dst = TimestamptzArray{} return nil } @@ -356,7 +349,7 @@ func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = TimestamptzArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = TimestamptzArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -381,16 +374,13 @@ func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = TimestamptzArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = TimestamptzArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -443,11 +433,8 @@ func (src TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) } func (src TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -461,7 +448,7 @@ func (src TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, erro } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/timestamptz_array_test.go b/timestamptz_array_test.go index 9856e4e7..22e07b59 100644 --- a/timestamptz_array_test.go +++ b/timestamptz_array_test.go @@ -14,53 +14,53 @@ func TestTimestamptzArrayTranscode(t *testing.T) { &pgtype.TimestamptzArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.TimestamptzArray{Status: pgtype.Null}, + &pgtype.TimestamptzArray{}, &pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }, func(a, b interface{}) bool { ata := a.(pgtype.TimestamptzArray) bta := b.(pgtype.TimestamptzArray) - if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { + if len(ata.Elements) != len(bta.Elements) || ata.Valid != bta.Valid { return false } for i := range ata.Elements { ae, be := ata.Elements[i], bta.Elements[i] - if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { + if !(ae.Time.Equal(be.Time) && ae.Valid == be.Valid && ae.InfinityModifier == be.InfinityModifier) { return false } } @@ -77,13 +77,13 @@ func TestTimestamptzArraySet(t *testing.T) { { source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, result: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]time.Time)(nil)), - result: pgtype.TimestamptzArray{Status: pgtype.Null}, + result: pgtype.TimestamptzArray{}, }, { source: [][]time.Time{ @@ -91,10 +91,10 @@ func TestTimestamptzArraySet(t *testing.T) { {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, result: pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]time.Time{ @@ -108,18 +108,18 @@ func TestTimestamptzArraySet(t *testing.T) { time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, result: pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]time.Time{ @@ -127,10 +127,10 @@ func TestTimestamptzArraySet(t *testing.T) { {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, result: pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]time.Time{ @@ -144,18 +144,18 @@ func TestTimestamptzArraySet(t *testing.T) { time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, result: pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -186,30 +186,30 @@ func TestTimestamptzArrayAssignTo(t *testing.T) { }{ { src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &timeSlice, expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, }, { - src: pgtype.TimestamptzArray{Status: pgtype.Null}, + src: pgtype.TimestamptzArray{}, dst: &timeSlice, expected: (([]time.Time)(nil)), }, { - src: pgtype.TimestamptzArray{Status: pgtype.Present}, + src: pgtype.TimestamptzArray{Valid: true}, dst: &timeSlice, expected: []time.Time{}, }, { src: pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeSliceDim2, expected: [][]time.Time{ {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, @@ -218,18 +218,18 @@ func TestTimestamptzArrayAssignTo(t *testing.T) { { src: pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeSliceDim4, expected: [][][][]time.Time{ {{{ @@ -244,10 +244,10 @@ func TestTimestamptzArrayAssignTo(t *testing.T) { { src: pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeArrayDim2, expected: [2][1]time.Time{ {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, @@ -256,18 +256,18 @@ func TestTimestamptzArrayAssignTo(t *testing.T) { { src: pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeArrayDim4, expected: [2][1][1][3]time.Time{ {{{ @@ -298,37 +298,37 @@ func TestTimestamptzArrayAssignTo(t *testing.T) { }{ { src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{{Status: pgtype.Null}}, + Elements: []pgtype.Timestamptz{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &timeSlice, }, { src: pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeArrayDim2, }, { src: pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeSlice, }, { src: pgtype.TimestamptzArray{ Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &timeArrayDim4, }, } diff --git a/timestamptz_test.go b/timestamptz_test.go index c3f63967..fa2a7e89 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -13,24 +13,24 @@ import ( func TestTimestamptzTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ - &pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Status: pgtype.Null}, - &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + &pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{}, + &pgtype.Timestamptz{Valid: true, InfinityModifier: pgtype.Infinity}, + &pgtype.Timestamptz{Valid: true, InfinityModifier: -pgtype.Infinity}, }, func(a, b interface{}) bool { at := a.(pgtype.Timestamptz) bt := b.(pgtype.Timestamptz) - return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + return at.Time.Equal(bt.Time) && at.Valid == bt.Valid && at.InfinityModifier == bt.InfinityModifier }) } @@ -42,7 +42,7 @@ func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) { } defer testutil.MustCloseContext(t, conn) - in := &pgtype.Timestamptz{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Status: pgtype.Present} + in := &pgtype.Timestamptz{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} var out pgtype.Timestamptz err := conn.QueryRow(context.Background(), "select $1::timestamptz", in).Scan(&out) @@ -50,7 +50,7 @@ func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) { t.Fatal(err) } - require.Equal(t, in.Status, out.Status) + require.Equal(t, in.Valid, out.Valid) require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) } @@ -64,7 +64,7 @@ func TestTimestamptzNanosecondsTruncated(t *testing.T) { } for i, tt := range tests { { - tstz := pgtype.Timestamptz{Time: tt.input, Status: pgtype.Present} + tstz := pgtype.Timestamptz{Time: tt.input, Valid: true} buf, err := tstz.EncodeText(nil, nil) if err != nil { t.Errorf("%d. EncodeText failed - %v", i, err) @@ -75,13 +75,13 @@ func TestTimestamptzNanosecondsTruncated(t *testing.T) { t.Errorf("%d. DecodeText failed - %v", i, err) } - if !(tstz.Status == pgtype.Present && tstz.Time.Equal(tt.expected)) { + if !(tstz.Valid && tstz.Time.Equal(tt.expected)) { t.Errorf("%d. EncodeText did not truncate nanoseconds", i) } } { - tstz := pgtype.Timestamptz{Time: tt.input, Status: pgtype.Present} + tstz := pgtype.Timestamptz{Time: tt.input, Valid: true} buf, err := tstz.EncodeBinary(nil, nil) if err != nil { t.Errorf("%d. EncodeBinary failed - %v", i, err) @@ -92,7 +92,7 @@ func TestTimestamptzNanosecondsTruncated(t *testing.T) { t.Errorf("%d. DecodeBinary failed - %v", i, err) } - if !(tstz.Status == pgtype.Present && tstz.Time.Equal(tt.expected)) { + if !(tstz.Valid && tstz.Time.Equal(tt.expected)) { t.Errorf("%d. EncodeBinary did not truncate nanoseconds", i) } } @@ -113,15 +113,15 @@ func TestTimestamptzSet(t *testing.T) { source interface{} result pgtype.Timestamptz }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: pgtype.Infinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, - {source: pgtype.NegativeInfinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Valid: true}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), Valid: true}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: pgtype.Infinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: pgtype.NegativeInfinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, } for i, tt := range successfulTests { @@ -146,8 +146,8 @@ func TestTimestamptzAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - {src: pgtype.Timestamptz{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Timestamptz{Time: time.Time{}}, dst: &ptim, expected: ((*time.Time)(nil))}, } for i, tt := range simpleTests { @@ -166,7 +166,7 @@ func TestTimestamptzAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, } for i, tt := range pointerAllocTests { @@ -184,9 +184,9 @@ func TestTimestamptzAssignTo(t *testing.T) { src pgtype.Timestamptz dst interface{} }{ - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Valid: true}, dst: &tim}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Valid: true}, dst: &tim}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, dst: &tim}, } for i, tt := range errorTests { @@ -202,11 +202,11 @@ func TestTimestamptzMarshalJSON(t *testing.T) { source pgtype.Timestamptz result string }{ - {source: pgtype.Timestamptz{Status: pgtype.Null}, result: "null"}, - {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29T10:05:45-06:00\""}, - {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29T10:05:45.555-06:00\""}, - {source: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, result: "\"infinity\""}, - {source: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, result: "\"-infinity\""}, + {source: pgtype.Timestamptz{}, result: "null"}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29T10:05:45-06:00\""}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29T10:05:45.555-06:00\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""}, } for i, tt := range successfulTests { r, err := tt.source.MarshalJSON() @@ -225,11 +225,11 @@ func TestTimestamptzUnmarshalJSON(t *testing.T) { source string result pgtype.Timestamptz }{ - {source: "null", result: pgtype.Timestamptz{Status: pgtype.Null}}, - {source: "\"2012-03-29T10:05:45-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, - {source: "\"2012-03-29T10:05:45.555-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, - {source: "\"infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, - {source: "\"-infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + {source: "null", result: pgtype.Timestamptz{}}, + {source: "\"2012-03-29T10:05:45-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"2012-03-29T10:05:45.555-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: "\"-infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, } for i, tt := range successfulTests { var r pgtype.Timestamptz @@ -238,7 +238,7 @@ func TestTimestamptzUnmarshalJSON(t *testing.T) { t.Errorf("%d: %v", i, err) } - if !r.Time.Equal(tt.result.Time) || r.Status != tt.result.Status || r.InfinityModifier != tt.result.InfinityModifier { + if !r.Time.Equal(tt.result.Time) || r.Valid != tt.result.Valid || r.InfinityModifier != tt.result.InfinityModifier { t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } } diff --git a/tsrange.go b/tsrange.go index 19ecf446..7495d972 100644 --- a/tsrange.go +++ b/tsrange.go @@ -12,13 +12,13 @@ type Tsrange struct { Upper Timestamp LowerType BoundType UpperType BoundType - Status Status + Valid bool } func (dst *Tsrange) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = Tsrange{Status: Null} + *dst = Tsrange{} return nil } @@ -36,15 +36,11 @@ func (dst *Tsrange) Set(src interface{}) error { return nil } -func (dst Tsrange) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: +func (src Tsrange) Get() interface{} { + if !src.Valid { return nil - default: - return dst.Status } + return src } func (src *Tsrange) AssignTo(dst interface{}) error { @@ -53,7 +49,7 @@ func (src *Tsrange) AssignTo(dst interface{}) error { func (dst *Tsrange) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Tsrange{Status: Null} + *dst = Tsrange{} return nil } @@ -62,7 +58,7 @@ func (dst *Tsrange) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Tsrange{Status: Present} + *dst = Tsrange{Valid: true} dst.LowerType = utr.LowerType dst.UpperType = utr.UpperType @@ -88,7 +84,7 @@ func (dst *Tsrange) DecodeText(ci *ConnInfo, src []byte) error { func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Tsrange{Status: Null} + *dst = Tsrange{} return nil } @@ -97,7 +93,7 @@ func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { return err } - *dst = Tsrange{Status: Present} + *dst = Tsrange{Valid: true} dst.LowerType = ubr.LowerType dst.UpperType = ubr.UpperType @@ -122,11 +118,8 @@ func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { } func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } switch src.LowerType { @@ -175,11 +168,8 @@ func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var rangeType byte @@ -245,7 +235,7 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Tsrange) Scan(src interface{}) error { if src == nil { - *dst = Tsrange{Status: Null} + *dst = Tsrange{} return nil } diff --git a/tsrange_array.go b/tsrange_array.go index c64048eb..2af25f8d 100644 --- a/tsrange_array.go +++ b/tsrange_array.go @@ -14,13 +14,13 @@ import ( type TsrangeArray struct { Elements []Tsrange Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *TsrangeArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = TsrangeArray{Status: Null} + *dst = TsrangeArray{} return nil } @@ -36,14 +36,14 @@ func (dst *TsrangeArray) Set(src interface{}) error { case []Tsrange: if value == nil { - *dst = TsrangeArray{Status: Null} + *dst = TsrangeArray{} } else if len(value) == 0 { - *dst = TsrangeArray{Status: Present} + *dst = TsrangeArray{Valid: true} } else { *dst = TsrangeArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -52,7 +52,7 @@ func (dst *TsrangeArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = TsrangeArray{Status: Null} + *dst = TsrangeArray{} return nil } @@ -61,7 +61,7 @@ func (dst *TsrangeArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for TsrangeArray", src) } if elementsLength == 0 { - *dst = TsrangeArray{Status: Present} + *dst = TsrangeArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -74,7 +74,7 @@ func (dst *TsrangeArray) Set(src interface{}) error { *dst = TsrangeArray{ Elements: make([]Tsrange, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -141,75 +141,68 @@ func (dst *TsrangeArray) setRecursive(value reflect.Value, index, dimension int) } func (dst TsrangeArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *TsrangeArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]Tsrange: - *v = make([]Tsrange, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]Tsrange: + *v = make([]Tsrange, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *TsrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -261,7 +254,7 @@ func (src *TsrangeArray) assignToRecursive(value reflect.Value, index, dimension func (dst *TsrangeArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = TsrangeArray{Status: Null} + *dst = TsrangeArray{} return nil } @@ -290,14 +283,14 @@ func (dst *TsrangeArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = TsrangeArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = TsrangeArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *TsrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = TsrangeArray{Status: Null} + *dst = TsrangeArray{} return nil } @@ -308,7 +301,7 @@ func (dst *TsrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = TsrangeArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = TsrangeArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -333,16 +326,13 @@ func (dst *TsrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = TsrangeArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = TsrangeArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src TsrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -395,11 +385,8 @@ func (src TsrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src TsrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -413,7 +400,7 @@ func (src TsrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/tsrange_test.go b/tsrange_test.go index 1be0c7d2..daea59bb 100644 --- a/tsrange_test.go +++ b/tsrange_test.go @@ -10,32 +10,32 @@ import ( func TestTsrangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ - &pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, &pgtype.Tsrange{ - Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Timestamp{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Timestamp{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, &pgtype.Tsrange{ - Lower: pgtype.Timestamp{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + Lower: pgtype.Timestamp{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, - &pgtype.Tsrange{Status: pgtype.Null}, + &pgtype.Tsrange{}, }, func(aa, bb interface{}) bool { a := aa.(pgtype.Tsrange) b := bb.(pgtype.Tsrange) - return a.Status == b.Status && + return a.Valid == b.Valid && a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Status == b.Lower.Status && + a.Lower.Valid == b.Lower.Valid && a.Lower.InfinityModifier == b.Lower.InfinityModifier && a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Status == b.Upper.Status && + a.Upper.Valid == b.Upper.Valid && a.Upper.InfinityModifier == b.Upper.InfinityModifier }) } diff --git a/tstzrange.go b/tstzrange.go index 25576308..3d4e2cde 100644 --- a/tstzrange.go +++ b/tstzrange.go @@ -12,13 +12,13 @@ type Tstzrange struct { Upper Timestamptz LowerType BoundType UpperType BoundType - Status Status + Valid bool } func (dst *Tstzrange) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = Tstzrange{Status: Null} + *dst = Tstzrange{} return nil } @@ -36,15 +36,11 @@ func (dst *Tstzrange) Set(src interface{}) error { return nil } -func (dst Tstzrange) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: +func (src Tstzrange) Get() interface{} { + if !src.Valid { return nil - default: - return dst.Status } + return src } func (src *Tstzrange) AssignTo(dst interface{}) error { @@ -53,7 +49,7 @@ func (src *Tstzrange) AssignTo(dst interface{}) error { func (dst *Tstzrange) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Tstzrange{Status: Null} + *dst = Tstzrange{} return nil } @@ -62,7 +58,7 @@ func (dst *Tstzrange) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Tstzrange{Status: Present} + *dst = Tstzrange{Valid: true} dst.LowerType = utr.LowerType dst.UpperType = utr.UpperType @@ -88,7 +84,7 @@ func (dst *Tstzrange) DecodeText(ci *ConnInfo, src []byte) error { func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Tstzrange{Status: Null} + *dst = Tstzrange{} return nil } @@ -97,7 +93,7 @@ func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { return err } - *dst = Tstzrange{Status: Present} + *dst = Tstzrange{Valid: true} dst.LowerType = ubr.LowerType dst.UpperType = ubr.UpperType @@ -122,11 +118,8 @@ func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { } func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } switch src.LowerType { @@ -175,11 +168,8 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var rangeType byte @@ -245,7 +235,7 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Tstzrange) Scan(src interface{}) error { if src == nil { - *dst = Tstzrange{Status: Null} + *dst = Tstzrange{} return nil } diff --git a/tstzrange_array.go b/tstzrange_array.go index a216820a..389d6b4c 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -14,13 +14,13 @@ import ( type TstzrangeArray struct { Elements []Tstzrange Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *TstzrangeArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = TstzrangeArray{Status: Null} + *dst = TstzrangeArray{} return nil } @@ -36,14 +36,14 @@ func (dst *TstzrangeArray) Set(src interface{}) error { case []Tstzrange: if value == nil { - *dst = TstzrangeArray{Status: Null} + *dst = TstzrangeArray{} } else if len(value) == 0 { - *dst = TstzrangeArray{Status: Present} + *dst = TstzrangeArray{Valid: true} } else { *dst = TstzrangeArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -52,7 +52,7 @@ func (dst *TstzrangeArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = TstzrangeArray{Status: Null} + *dst = TstzrangeArray{} return nil } @@ -61,7 +61,7 @@ func (dst *TstzrangeArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for TstzrangeArray", src) } if elementsLength == 0 { - *dst = TstzrangeArray{Status: Present} + *dst = TstzrangeArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -74,7 +74,7 @@ func (dst *TstzrangeArray) Set(src interface{}) error { *dst = TstzrangeArray{ Elements: make([]Tstzrange, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -141,75 +141,68 @@ func (dst *TstzrangeArray) setRecursive(value reflect.Value, index, dimension in } func (dst TstzrangeArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *TstzrangeArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]Tstzrange: - *v = make([]Tstzrange, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]Tstzrange: + *v = make([]Tstzrange, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -261,7 +254,7 @@ func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimensi func (dst *TstzrangeArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = TstzrangeArray{Status: Null} + *dst = TstzrangeArray{} return nil } @@ -290,14 +283,14 @@ func (dst *TstzrangeArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = TstzrangeArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = TstzrangeArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *TstzrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = TstzrangeArray{Status: Null} + *dst = TstzrangeArray{} return nil } @@ -308,7 +301,7 @@ func (dst *TstzrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = TstzrangeArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = TstzrangeArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -333,16 +326,13 @@ func (dst *TstzrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = TstzrangeArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = TstzrangeArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src TstzrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -395,11 +385,8 @@ func (src TstzrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src TstzrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -413,7 +400,7 @@ func (src TstzrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/tstzrange_test.go b/tstzrange_test.go index f8e2c2c5..49cfc63e 100644 --- a/tstzrange_test.go +++ b/tstzrange_test.go @@ -11,32 +11,32 @@ import ( func TestTstzrangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ - &pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, &pgtype.Tstzrange{ - Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Timestamptz{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Timestamptz{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, &pgtype.Tstzrange{ - Lower: pgtype.Timestamptz{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + Lower: pgtype.Timestamptz{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, - Status: pgtype.Present, + Valid: true, }, - &pgtype.Tstzrange{Status: pgtype.Null}, + &pgtype.Tstzrange{}, }, func(aa, bb interface{}) bool { a := aa.(pgtype.Tstzrange) b := bb.(pgtype.Tstzrange) - return a.Status == b.Status && + return a.Valid == b.Valid && a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Status == b.Lower.Status && + a.Lower.Valid == b.Lower.Valid && a.Lower.InfinityModifier == b.Lower.InfinityModifier && a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Status == b.Upper.Status && + a.Upper.Valid == b.Upper.Valid && a.Upper.InfinityModifier == b.Upper.InfinityModifier }) } diff --git a/typed_array.go.erb b/typed_array.go.erb index 5788626b..e1ead59c 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -13,13 +13,13 @@ import ( type <%= pgtype_array_type %> struct { Elements []<%= pgtype_element_type %> Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = <%= pgtype_array_type %>{Status: Null} + *dst = <%= pgtype_array_type %>{} return nil } @@ -36,9 +36,9 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { <% if t != "[]#{pgtype_element_type}" %> case <%= t %>: if value == nil { - *dst = <%= pgtype_array_type %>{Status: Null} + *dst = <%= pgtype_array_type %>{} } else if len(value) == 0 { - *dst = <%= pgtype_array_type %>{Status: Present} + *dst = <%= pgtype_array_type %>{Valid: true} } else { elements := make([]<%= pgtype_element_type %>, len(value)) for i := range value { @@ -49,21 +49,21 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { *dst = <%= pgtype_array_type %>{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } <% end %> <% end %> case []<%= pgtype_element_type %>: if value == nil { - *dst = <%= pgtype_array_type %>{Status: Null} + *dst = <%= pgtype_array_type %>{} } else if len(value) == 0 { - *dst = <%= pgtype_array_type %>{Status: Present} + *dst = <%= pgtype_array_type %>{Valid: true} } else { *dst = <%= pgtype_array_type %>{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status : Present, + Valid: true, } } default: @@ -72,7 +72,7 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = <%= pgtype_array_type %>{Status: Null} + *dst = <%= pgtype_array_type %>{} return nil } @@ -81,7 +81,7 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for <%= pgtype_array_type %>", src) } if elementsLength == 0 { - *dst = <%= pgtype_array_type %>{Status: Present} + *dst = <%= pgtype_array_type %>{Valid: true} return nil } if len(dimensions) == 0 { @@ -94,7 +94,7 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { *dst = <%= pgtype_array_type %> { Elements: make([]<%= pgtype_element_type %>, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -161,75 +161,68 @@ func (dst *<%= pgtype_array_type %>) setRecursive(value reflect.Value, index, di } func (dst <%= pgtype_array_type %>) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } + if !dst.Valid { + return nil + } + return dst } func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1{ - // Attempt to match to select common types: - switch v := dst.(type) { - <% go_array_types.split(",").each do |t| %> - case *<%= t %>: - *v = make(<%= t %>, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - <% end %> - } - } + if !src.Valid { + return NullAssignTo(dst) + } - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } + if len(src.Dimensions) <= 1{ + // Attempt to match to select common types: + switch v := dst.(type) { + <% go_array_types.split(",").each do |t| %> + case *<%= t %>: + *v = make(<%= t %>, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + <% end %> + } + } - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } - return nil - case Null: - return NullAssignTo(dst) - } + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + return nil } func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -281,7 +274,7 @@ func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, inde func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = <%= pgtype_array_type %>{Status: Null} + *dst = <%= pgtype_array_type %>{} return nil } @@ -310,7 +303,7 @@ func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error } } - *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } @@ -318,7 +311,7 @@ func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error <% if binary_format == "true" %> func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = <%= pgtype_array_type %>{Status: Null} + *dst = <%= pgtype_array_type %>{} return nil } @@ -329,7 +322,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) erro } if len(arrayHeader.Dimensions) == 0 { - *dst = <%= pgtype_array_type %>{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = <%= pgtype_array_type %>{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -354,18 +347,15 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) erro } } - *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } <% end %> func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined - } + } if len(src.Dimensions) == 0 { return append(buf, '{', '}'), nil @@ -418,12 +408,9 @@ func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte <% if binary_format == "true" %> func (src <%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined - } + } arrayHeader := ArrayHeader{ Dimensions: src.Dimensions, @@ -436,7 +423,7 @@ func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/typed_range.go.erb b/typed_range.go.erb index 5625587a..99d8c22d 100644 --- a/typed_range.go.erb +++ b/typed_range.go.erb @@ -14,13 +14,13 @@ type <%= range_type %> struct { Upper <%= element_type %> LowerType BoundType UpperType BoundType - Status Status + Valid bool } func (dst *<%= range_type %>) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = <%= range_type %>{Status: Null} + *dst = <%= range_type %>{} return nil } @@ -38,15 +38,11 @@ func (dst *<%= range_type %>) Set(src interface{}) error { return nil } -func (dst <%= range_type %>) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } +func (src <%= range_type %>) Get() interface{} { + if !src.Valid { + return nil + } + return src } func (src *<%= range_type %>) AssignTo(dst interface{}) error { @@ -55,7 +51,7 @@ func (src *<%= range_type %>) AssignTo(dst interface{}) error { func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = <%= range_type %>{Status: Null} + *dst = <%= range_type %>{} return nil } @@ -64,7 +60,7 @@ func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = <%= range_type %>{Status: Present} + *dst = <%= range_type %>{Valid: true} dst.LowerType = utr.LowerType dst.UpperType = utr.UpperType @@ -90,7 +86,7 @@ func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { func (dst *<%= range_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = <%= range_type %>{Status: Null} + *dst = <%= range_type %>{} return nil } @@ -99,7 +95,7 @@ func (dst *<%= range_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { return err } - *dst = <%= range_type %>{Status: Present} + *dst = <%= range_type %>{Valid: true} dst.LowerType = ubr.LowerType dst.UpperType = ubr.UpperType @@ -124,11 +120,8 @@ func (dst *<%= range_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { } func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } switch src.LowerType { @@ -177,11 +170,8 @@ func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error } func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var rangeType byte @@ -247,7 +237,7 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err // Scan implements the database/sql Scanner interface. func (dst *<%= range_type %>) Scan(src interface{}) error { if src == nil { - *dst = <%= range_type %>{Status: Null} + *dst = <%= range_type %>{} return nil } diff --git a/unknown.go b/unknown.go index c591b708..0e576ee9 100644 --- a/unknown.go +++ b/unknown.go @@ -8,7 +8,7 @@ import "database/sql/driver" // type information. e.g. SELECT NULL; type Unknown struct { String string - Status Status + Valid bool } func (dst *Unknown) Set(src interface{}) error { diff --git a/uuid.go b/uuid.go index fa0be07f..d46111d3 100644 --- a/uuid.go +++ b/uuid.go @@ -8,13 +8,13 @@ import ( ) type UUID struct { - Bytes [16]byte - Status Status + Bytes [16]byte + Valid bool } func (dst *UUID) Set(src interface{}) error { if src == nil { - *dst = UUID{Status: Null} + *dst = UUID{} return nil } @@ -27,26 +27,26 @@ func (dst *UUID) Set(src interface{}) error { switch value := src.(type) { case [16]byte: - *dst = UUID{Bytes: value, Status: Present} + *dst = UUID{Bytes: value, Valid: true} case []byte: if value != nil { if len(value) != 16 { return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) } - *dst = UUID{Status: Present} + *dst = UUID{Valid: true} copy(dst.Bytes[:], value) } else { - *dst = UUID{Status: Null} + *dst = UUID{} } case string: uuid, err := parseUUID(value) if err != nil { return err } - *dst = UUID{Bytes: uuid, Status: Present} + *dst = UUID{Bytes: uuid, Valid: true} case *string: if value == nil { - *dst = UUID{Status: Null} + *dst = UUID{} } else { return dst.Set(*value) } @@ -61,40 +61,35 @@ func (dst *UUID) Set(src interface{}) error { } func (dst UUID) Get() interface{} { - switch dst.Status { - case Present: - return dst.Bytes - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst.Bytes } func (src *UUID) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *[16]byte: - *v = src.Bytes - return nil - case *[]byte: - *v = make([]byte, 16) - copy(*v, src.Bytes[:]) - return nil - case *string: - *v = encodeUUID(src.Bytes) - return nil - default: - if nextDst, retry := GetAssignToDstType(v); retry { - return src.AssignTo(nextDst) - } - } - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot assign %v into %T", src, dst) + switch v := dst.(type) { + case *[16]byte: + *v = src.Bytes + return nil + case *[]byte: + *v = make([]byte, 16) + copy(*v, src.Bytes[:]) + return nil + case *string: + *v = encodeUUID(src.Bytes) + return nil + default: + if nextDst, retry := GetAssignToDstType(v); retry { + return src.AssignTo(nextDst) + } + } + + return nil } // parseUUID converts a string UUID in standard form to a byte array. @@ -125,7 +120,7 @@ func encodeUUID(src [16]byte) string { func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = UUID{Status: Null} + *dst = UUID{} return nil } @@ -138,13 +133,13 @@ func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = UUID{Bytes: buf, Status: Present} + *dst = UUID{Bytes: buf, Valid: true} return nil } func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = UUID{Status: Null} + *dst = UUID{} return nil } @@ -152,28 +147,22 @@ func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { return fmt.Errorf("invalid length for UUID: %v", len(src)) } - *dst = UUID{Status: Present} + *dst = UUID{Valid: true} copy(dst.Bytes[:], src) return nil } func (src UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, encodeUUID(src.Bytes)...), nil } func (src UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } return append(buf, src.Bytes[:]...), nil @@ -182,7 +171,7 @@ func (src UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *UUID) Scan(src interface{}) error { if src == nil { - *dst = UUID{Status: Null} + *dst = UUID{} return nil } @@ -204,19 +193,15 @@ func (src UUID) Value() (driver.Value, error) { } func (src UUID) MarshalJSON() ([]byte, error) { - switch src.Status { - case Present: - var buff bytes.Buffer - buff.WriteByte('"') - buff.WriteString(encodeUUID(src.Bytes)) - buff.WriteByte('"') - return buff.Bytes(), nil - case Null: + if !src.Valid { return []byte("null"), nil - case Undefined: - return nil, errUndefined } - return nil, errBadStatus + + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(encodeUUID(src.Bytes)) + buff.WriteByte('"') + return buff.Bytes(), nil } func (dst *UUID) UnmarshalJSON(src []byte) error { diff --git a/uuid_array.go b/uuid_array.go index 00721ef9..98904f9f 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -14,13 +14,13 @@ import ( type UUIDArray struct { Elements []UUID Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *UUIDArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = UUIDArray{Status: Null} + *dst = UUIDArray{} return nil } @@ -36,9 +36,9 @@ func (dst *UUIDArray) Set(src interface{}) error { case [][16]byte: if value == nil { - *dst = UUIDArray{Status: Null} + *dst = UUIDArray{} } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} + *dst = UUIDArray{Valid: true} } else { elements := make([]UUID, len(value)) for i := range value { @@ -49,15 +49,15 @@ func (dst *UUIDArray) Set(src interface{}) error { *dst = UUIDArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case [][]byte: if value == nil { - *dst = UUIDArray{Status: Null} + *dst = UUIDArray{} } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} + *dst = UUIDArray{Valid: true} } else { elements := make([]UUID, len(value)) for i := range value { @@ -68,15 +68,15 @@ func (dst *UUIDArray) Set(src interface{}) error { *dst = UUIDArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []string: if value == nil { - *dst = UUIDArray{Status: Null} + *dst = UUIDArray{} } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} + *dst = UUIDArray{Valid: true} } else { elements := make([]UUID, len(value)) for i := range value { @@ -87,15 +87,15 @@ func (dst *UUIDArray) Set(src interface{}) error { *dst = UUIDArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*string: if value == nil { - *dst = UUIDArray{Status: Null} + *dst = UUIDArray{} } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} + *dst = UUIDArray{Valid: true} } else { elements := make([]UUID, len(value)) for i := range value { @@ -106,20 +106,20 @@ func (dst *UUIDArray) Set(src interface{}) error { *dst = UUIDArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []UUID: if value == nil { - *dst = UUIDArray{Status: Null} + *dst = UUIDArray{} } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} + *dst = UUIDArray{Valid: true} } else { *dst = UUIDArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -128,7 +128,7 @@ func (dst *UUIDArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = UUIDArray{Status: Null} + *dst = UUIDArray{} return nil } @@ -137,7 +137,7 @@ func (dst *UUIDArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for UUIDArray", src) } if elementsLength == 0 { - *dst = UUIDArray{Status: Present} + *dst = UUIDArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -150,7 +150,7 @@ func (dst *UUIDArray) Set(src interface{}) error { *dst = UUIDArray{ Elements: make([]UUID, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -217,102 +217,95 @@ func (dst *UUIDArray) setRecursive(value reflect.Value, index, dimension int) (i } func (dst UUIDArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *UUIDArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[][16]byte: - *v = make([][16]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[][16]byte: + *v = make([][16]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -364,7 +357,7 @@ func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension in func (dst *UUIDArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = UUIDArray{Status: Null} + *dst = UUIDArray{} return nil } @@ -393,14 +386,14 @@ func (dst *UUIDArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = UUIDArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = UUIDArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *UUIDArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = UUIDArray{Status: Null} + *dst = UUIDArray{} return nil } @@ -411,7 +404,7 @@ func (dst *UUIDArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = UUIDArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = UUIDArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -436,16 +429,13 @@ func (dst *UUIDArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = UUIDArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = UUIDArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -498,11 +488,8 @@ func (src UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -516,7 +503,7 @@ func (src UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/uuid_array_test.go b/uuid_array_test.go index 7d822e7a..47afadff 100644 --- a/uuid_array_test.go +++ b/uuid_array_test.go @@ -13,41 +13,41 @@ func TestUUIDArrayTranscode(t *testing.T) { &pgtype.UUIDArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.UUIDArray{ Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Status: pgtype.Null}, + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.UUIDArray{Status: pgtype.Null}, + &pgtype.UUIDArray{}, &pgtype.UUIDArray{ Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, + {}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.UUIDArray{ Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -59,29 +59,29 @@ func TestUUIDArraySet(t *testing.T) { }{ { source: nil, - result: pgtype.UUIDArray{Status: pgtype.Null}, + result: pgtype.UUIDArray{}, }, { source: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][16]byte{}, - result: pgtype.UUIDArray{Status: pgtype.Present}, + result: pgtype.UUIDArray{Valid: true}, }, { source: ([][16]byte)(nil), - result: pgtype.UUIDArray{Status: pgtype.Null}, + result: pgtype.UUIDArray{}, }, { source: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][]byte{ @@ -92,36 +92,36 @@ func TestUUIDArraySet(t *testing.T) { }, result: pgtype.UUIDArray{ Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 4}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][]byte{}, - result: pgtype.UUIDArray{Status: pgtype.Present}, + result: pgtype.UUIDArray{Valid: true}, }, { source: ([][]byte)(nil), - result: pgtype.UUIDArray{Status: pgtype.Null}, + result: pgtype.UUIDArray{}, }, { source: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: []string{}, - result: pgtype.UUIDArray{Status: pgtype.Present}, + result: pgtype.UUIDArray{Valid: true}, }, { source: ([]string)(nil), - result: pgtype.UUIDArray{Status: pgtype.Null}, + result: pgtype.UUIDArray{}, }, { source: [][][16]byte{{ @@ -129,10 +129,10 @@ func TestUUIDArraySet(t *testing.T) { {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, result: pgtype.UUIDArray{ Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]string{ @@ -146,18 +146,18 @@ func TestUUIDArraySet(t *testing.T) { "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, result: pgtype.UUIDArray{ Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, - {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][16]byte{{ @@ -165,10 +165,10 @@ func TestUUIDArraySet(t *testing.T) { {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, result: pgtype.UUIDArray{ Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]string{ @@ -182,18 +182,18 @@ func TestUUIDArraySet(t *testing.T) { "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, result: pgtype.UUIDArray{ Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, - {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -227,63 +227,63 @@ func TestUUIDArrayAssignTo(t *testing.T) { }{ { src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &byteArraySlice, expected: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, }, { - src: pgtype.UUIDArray{Status: pgtype.Null}, + src: pgtype.UUIDArray{}, dst: &byteArraySlice, expected: ([][16]byte)(nil), }, { src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &byteSliceSlice, expected: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, }, { - src: pgtype.UUIDArray{Status: pgtype.Null}, + src: pgtype.UUIDArray{}, dst: &byteSliceSlice, expected: ([][]byte)(nil), }, { - src: pgtype.UUIDArray{Status: pgtype.Present}, + src: pgtype.UUIDArray{Valid: true}, dst: &byteSlice, expected: []byte{}, }, { - src: pgtype.UUIDArray{Status: pgtype.Present}, + src: pgtype.UUIDArray{Valid: true}, dst: &stringSlice, expected: []string{}, }, { src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &stringSlice, expected: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, }, { - src: pgtype.UUIDArray{Status: pgtype.Null}, + src: pgtype.UUIDArray{}, dst: &stringSlice, expected: ([]string)(nil), }, { src: pgtype.UUIDArray{ Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &byteArraySliceDim2, expected: [][][16]byte{{ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, @@ -292,18 +292,18 @@ func TestUUIDArrayAssignTo(t *testing.T) { { src: pgtype.UUIDArray{ Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, - {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringSliceDim4, expected: [][][][]string{ {{{ @@ -318,10 +318,10 @@ func TestUUIDArrayAssignTo(t *testing.T) { { src: pgtype.UUIDArray{ Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &byteArrayDim2, expected: [2][1][16]byte{{ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, @@ -330,18 +330,18 @@ func TestUUIDArrayAssignTo(t *testing.T) { { src: pgtype.UUIDArray{ Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, - {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim4, expected: [2][1][1][3]string{ {{{ diff --git a/uuid_test.go b/uuid_test.go index 5a93ea8d..887f45dd 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -7,12 +7,13 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" ) func TestUUIDTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - &pgtype.UUID{Status: pgtype.Null}, + &pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + &pgtype.UUID{}, }) } @@ -29,31 +30,31 @@ func TestUUIDSet(t *testing.T) { }{ { source: nil, - result: pgtype.UUID{Status: pgtype.Null}, + result: pgtype.UUID{}, }, { source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, }, { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, }, { source: SomeUUIDType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, }, { source: ([]byte)(nil), - result: pgtype.UUID{Status: pgtype.Null}, + result: pgtype.UUID{}, }, { source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, }, { source: "000102030405060708090a0b0c0d0e0f", - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, }, } @@ -72,7 +73,7 @@ func TestUUIDSet(t *testing.T) { func TestUUIDAssignTo(t *testing.T) { { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} var dst [16]byte expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -87,7 +88,7 @@ func TestUUIDAssignTo(t *testing.T) { } { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} var dst []byte expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -102,7 +103,7 @@ func TestUUIDAssignTo(t *testing.T) { } { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} var dst SomeUUIDType expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -117,7 +118,7 @@ func TestUUIDAssignTo(t *testing.T) { } { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} var dst string expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" @@ -132,7 +133,7 @@ func TestUUIDAssignTo(t *testing.T) { } { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} var dst SomeUUIDWrapper expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -149,46 +150,30 @@ func TestUUIDAssignTo(t *testing.T) { func TestUUID_MarshalJSON(t *testing.T) { tests := []struct { - name string - src pgtype.UUID - want []byte - wantErr bool + name string + src pgtype.UUID + want []byte }{ { name: "first", src: pgtype.UUID{ - Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, - Status: pgtype.Present, + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Valid: true, }, - want: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), - wantErr: false, - }, - { - name: "second", - src: pgtype.UUID{ - Bytes: [16]byte{}, - Status: pgtype.Undefined, - }, - want: nil, - wantErr: true, + want: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), }, { name: "third", src: pgtype.UUID{ - Bytes: [16]byte{}, - Status: pgtype.Null, + Bytes: [16]byte{}, }, - want: []byte("null"), - wantErr: false, + want: []byte("null"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.src.MarshalJSON() - if (err != nil) != tt.wantErr { - t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) - return - } + require.NoError(t, err) if !reflect.DeepEqual(got, tt.want) { t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) } @@ -206,8 +191,8 @@ func TestUUID_UnmarshalJSON(t *testing.T) { { name: "first", want: &pgtype.UUID{ - Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, - Status: pgtype.Present, + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Valid: true, }, src: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), wantErr: false, @@ -215,8 +200,7 @@ func TestUUID_UnmarshalJSON(t *testing.T) { { name: "second", want: &pgtype.UUID{ - Bytes: [16]byte{}, - Status: pgtype.Null, + Bytes: [16]byte{}, }, src: []byte("null"), wantErr: false, @@ -224,8 +208,8 @@ func TestUUID_UnmarshalJSON(t *testing.T) { { name: "third", want: &pgtype.UUID{ - Bytes: [16]byte{}, - Status: pgtype.Undefined, + Bytes: [16]byte{}, + Valid: false, }, src: []byte("1d485a7a-6d18-4599-8c6c-34425616887a"), wantErr: true, diff --git a/varbit.go b/varbit.go index f24dc5bc..bc6fdac4 100644 --- a/varbit.go +++ b/varbit.go @@ -9,9 +9,9 @@ import ( ) type Varbit struct { - Bytes []byte - Len int32 // Number of bits - Status Status + Bytes []byte + Len int32 // Number of bits + Valid bool } func (dst *Varbit) Set(src interface{}) error { @@ -19,14 +19,10 @@ func (dst *Varbit) Set(src interface{}) error { } func (dst Varbit) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *Varbit) AssignTo(dst interface{}) error { @@ -35,7 +31,7 @@ func (src *Varbit) AssignTo(dst interface{}) error { func (dst *Varbit) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Varbit{Status: Null} + *dst = Varbit{} return nil } @@ -54,13 +50,13 @@ func (dst *Varbit) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Varbit{Bytes: buf, Len: int32(bitLen), Status: Present} + *dst = Varbit{Bytes: buf, Len: int32(bitLen), Valid: true} return nil } func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Varbit{Status: Null} + *dst = Varbit{} return nil } @@ -71,16 +67,13 @@ func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { bitLen := int32(binary.BigEndian.Uint32(src)) rp := 4 - *dst = Varbit{Bytes: src[rp:], Len: bitLen, Status: Present} + *dst = Varbit{Bytes: src[rp:], Len: bitLen, Valid: true} return nil } func (src Varbit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } for i := int32(0); i < src.Len; i++ { @@ -97,11 +90,8 @@ func (src Varbit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src Varbit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } buf = pgio.AppendInt32(buf, src.Len) @@ -111,7 +101,7 @@ func (src Varbit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Varbit) Scan(src interface{}) error { if src == nil { - *dst = Varbit{Status: Null} + *dst = Varbit{} return nil } diff --git a/varbit_test.go b/varbit_test.go index 3c5aea1e..b81bdc0e 100644 --- a/varbit_test.go +++ b/varbit_test.go @@ -9,10 +9,10 @@ import ( func TestVarbitTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "varbit", []interface{}{ - &pgtype.Varbit{Bytes: []byte{}, Len: 0, Status: pgtype.Present}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Status: pgtype.Present}, - &pgtype.Varbit{Status: pgtype.Null}, + &pgtype.Varbit{Bytes: []byte{}, Len: 0, Valid: true}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}, + &pgtype.Varbit{}, }) } @@ -20,7 +20,7 @@ func TestVarbitNormalize(t *testing.T) { testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { SQL: "select B'111111111'", - Value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, + Value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Valid: true}, }, }) } diff --git a/varchar_array.go b/varchar_array.go index 8a309a3f..3e0913dc 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -14,13 +14,13 @@ import ( type VarcharArray struct { Elements []Varchar Dimensions []ArrayDimension - Status Status + Valid bool } func (dst *VarcharArray) Set(src interface{}) error { // untyped nil and typed nil interfaces are different if src == nil { - *dst = VarcharArray{Status: Null} + *dst = VarcharArray{} return nil } @@ -36,9 +36,9 @@ func (dst *VarcharArray) Set(src interface{}) error { case []string: if value == nil { - *dst = VarcharArray{Status: Null} + *dst = VarcharArray{} } else if len(value) == 0 { - *dst = VarcharArray{Status: Present} + *dst = VarcharArray{Valid: true} } else { elements := make([]Varchar, len(value)) for i := range value { @@ -49,15 +49,15 @@ func (dst *VarcharArray) Set(src interface{}) error { *dst = VarcharArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []*string: if value == nil { - *dst = VarcharArray{Status: Null} + *dst = VarcharArray{} } else if len(value) == 0 { - *dst = VarcharArray{Status: Present} + *dst = VarcharArray{Valid: true} } else { elements := make([]Varchar, len(value)) for i := range value { @@ -68,20 +68,20 @@ func (dst *VarcharArray) Set(src interface{}) error { *dst = VarcharArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, + Valid: true, } } case []Varchar: if value == nil { - *dst = VarcharArray{Status: Null} + *dst = VarcharArray{} } else if len(value) == 0 { - *dst = VarcharArray{Status: Present} + *dst = VarcharArray{Valid: true} } else { *dst = VarcharArray{ Elements: value, Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, + Valid: true, } } default: @@ -90,7 +90,7 @@ func (dst *VarcharArray) Set(src interface{}) error { // but it comes with a 20-50% performance penalty for large arrays/slices reflectedValue := reflect.ValueOf(src) if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = VarcharArray{Status: Null} + *dst = VarcharArray{} return nil } @@ -99,7 +99,7 @@ func (dst *VarcharArray) Set(src interface{}) error { return fmt.Errorf("cannot find dimensions of %v for VarcharArray", src) } if elementsLength == 0 { - *dst = VarcharArray{Status: Present} + *dst = VarcharArray{Valid: true} return nil } if len(dimensions) == 0 { @@ -112,7 +112,7 @@ func (dst *VarcharArray) Set(src interface{}) error { *dst = VarcharArray{ Elements: make([]Varchar, elementsLength), Dimensions: dimensions, - Status: Present, + Valid: true, } elementCount, err := dst.setRecursive(reflectedValue, 0, 0) if err != nil { @@ -179,84 +179,77 @@ func (dst *VarcharArray) setRecursive(value reflect.Value, index, dimension int) } func (dst VarcharArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: + if !dst.Valid { return nil - default: - return dst.Status } + return dst } func (src *VarcharArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil - case Null: + if !src.Valid { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %#v into %T", src, dst) + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil } func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { @@ -308,7 +301,7 @@ func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = VarcharArray{Status: Null} + *dst = VarcharArray{} return nil } @@ -337,14 +330,14 @@ func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = VarcharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = VarcharArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} return nil } func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = VarcharArray{Status: Null} + *dst = VarcharArray{} return nil } @@ -355,7 +348,7 @@ func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = VarcharArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = VarcharArray{Dimensions: arrayHeader.Dimensions, Valid: true} return nil } @@ -380,16 +373,13 @@ func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = VarcharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = VarcharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} return nil } func (src VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } if len(src.Dimensions) == 0 { @@ -442,11 +432,8 @@ func (src VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } func (src VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !src.Valid { return nil, nil - case Undefined: - return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -460,7 +447,7 @@ func (src VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } for i := range src.Elements { - if src.Elements[i].Status == Null { + if !src.Elements[i].Valid { arrayHeader.ContainsNull = true break } diff --git a/varchar_array_test.go b/varchar_array_test.go index 5fb7326d..cf0efd6d 100644 --- a/varchar_array_test.go +++ b/varchar_array_test.go @@ -13,41 +13,41 @@ func TestVarcharArrayTranscode(t *testing.T) { &pgtype.VarcharArray{ Elements: nil, Dimensions: nil, - Status: pgtype.Present, + Valid: true, }, &pgtype.VarcharArray{ Elements: []pgtype.Varchar{ - {String: "foo", Status: pgtype.Present}, - {Status: pgtype.Null}, + {String: "foo", Valid: true}, + {}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, - &pgtype.VarcharArray{Status: pgtype.Null}, + &pgtype.VarcharArray{}, &pgtype.VarcharArray{ Elements: []pgtype.Varchar{ - {String: "bar ", Status: pgtype.Present}, - {String: "NuLL", Status: pgtype.Present}, - {String: `wow"quz\`, Status: pgtype.Present}, - {String: "", Status: pgtype.Present}, - {Status: pgtype.Null}, - {String: "null", Status: pgtype.Present}, + {String: "bar ", Valid: true}, + {String: "NuLL", Valid: true}, + {String: `wow"quz\`, Valid: true}, + {String: "", Valid: true}, + {}, + {String: "null", Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, }, &pgtype.VarcharArray{ Elements: []pgtype.Varchar{ - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "quz", Status: pgtype.Present}, - {String: "foo", Status: pgtype.Present}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "quz", Valid: true}, + {String: "foo", Valid: true}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, - Status: pgtype.Present, + Valid: true, }, }) } @@ -60,61 +60,61 @@ func TestVarcharArraySet(t *testing.T) { { source: []string{"foo"}, result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Elements: []pgtype.Varchar{{String: "foo", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: (([]string)(nil)), - result: pgtype.VarcharArray{Status: pgtype.Null}, + result: pgtype.VarcharArray{}, }, { source: [][]string{{"foo"}, {"bar"}}, result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, result: pgtype.VarcharArray{ Elements: []pgtype.Varchar{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1]string{{"foo"}, {"bar"}}, result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, }, { source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, result: pgtype.VarcharArray{ Elements: []pgtype.Varchar{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, }, } @@ -147,81 +147,81 @@ func TestVarcharArrayAssignTo(t *testing.T) { }{ { src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Elements: []pgtype.Varchar{{String: "foo", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &stringSlice, expected: []string{"foo"}, }, { src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Varchar{{String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &namedStringSlice, expected: _stringSlice{"bar"}, }, { - src: pgtype.VarcharArray{Status: pgtype.Null}, + src: pgtype.VarcharArray{}, dst: &stringSlice, expected: (([]string)(nil)), }, { - src: pgtype.VarcharArray{Status: pgtype.Present}, + src: pgtype.VarcharArray{Valid: true}, dst: &stringSlice, expected: []string{}, }, { src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringSliceDim2, expected: [][]string{{"foo"}, {"bar"}}, }, { src: pgtype.VarcharArray{ Elements: []pgtype.Varchar{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringSliceDim4, expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, }, { src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim2, expected: [2][1]string{{"foo"}, {"bar"}}, }, { src: pgtype.VarcharArray{ Elements: []pgtype.Varchar{ - {String: "foo", Status: pgtype.Present}, - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "wibble", Status: pgtype.Present}, - {String: "wobble", Status: pgtype.Present}, - {String: "wubble", Status: pgtype.Present}}, + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, Dimensions: []pgtype.ArrayDimension{ {LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 3}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim4, expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, }, @@ -244,31 +244,31 @@ func TestVarcharArrayAssignTo(t *testing.T) { }{ { src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{Status: pgtype.Null}}, + Elements: []pgtype.Varchar{{}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, + Valid: true, }, dst: &stringSlice, }, { src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim2, }, { src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringSlice, }, { src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, + Valid: true}, dst: &stringArrayDim4, }, } diff --git a/xid_test.go b/xid_test.go index 531867f6..fab10f79 100644 --- a/xid_test.go +++ b/xid_test.go @@ -11,8 +11,8 @@ import ( func TestXIDTranscode(t *testing.T) { pgTypeName := "xid" values := []interface{}{ - &pgtype.XID{Uint: 42, Status: pgtype.Present}, - &pgtype.XID{Status: pgtype.Null}, + &pgtype.XID{Uint: 42, Valid: true}, + &pgtype.XID{}, } eqFunc := func(a, b interface{}) bool { return reflect.DeepEqual(a, b) @@ -27,7 +27,7 @@ func TestXIDSet(t *testing.T) { source interface{} result pgtype.XID }{ - {source: uint32(1), result: pgtype.XID{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.XID{Uint: 1, Valid: true}}, } for i, tt := range successfulTests { @@ -52,8 +52,8 @@ func TestXIDAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.XID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.XID{Uint: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.XID{}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -72,7 +72,7 @@ func TestXIDAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.XID{Uint: 42, Valid: true}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -90,7 +90,7 @@ func TestXIDAssignTo(t *testing.T) { src pgtype.XID dst interface{} }{ - {src: pgtype.XID{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.XID{}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/zeronull/float8.go b/zeronull/float8.go index ebc86ac3..07d5e1a5 100644 --- a/zeronull/float8.go +++ b/zeronull/float8.go @@ -15,7 +15,7 @@ func (dst *Float8) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Float8(nullable.Float) } else { *dst = 0 @@ -31,7 +31,7 @@ func (dst *Float8) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Float8(nullable.Float) } else { *dst = 0 @@ -46,8 +46,8 @@ func (src Float8) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { } nullable := pgtype.Float8{ - Float: float64(src), - Status: pgtype.Present, + Float: float64(src), + Valid: true, } return nullable.EncodeText(ci, buf) @@ -59,8 +59,8 @@ func (src Float8) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) } nullable := pgtype.Float8{ - Float: float64(src), - Status: pgtype.Present, + Float: float64(src), + Valid: true, } return nullable.EncodeBinary(ci, buf) diff --git a/zeronull/int2.go b/zeronull/int2.go index a528642f..b3f9c328 100644 --- a/zeronull/int2.go +++ b/zeronull/int2.go @@ -15,7 +15,7 @@ func (dst *Int2) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Int2(nullable.Int) } else { *dst = 0 @@ -31,7 +31,7 @@ func (dst *Int2) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Int2(nullable.Int) } else { *dst = 0 @@ -46,8 +46,8 @@ func (src Int2) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { } nullable := pgtype.Int2{ - Int: int16(src), - Status: pgtype.Present, + Int: int16(src), + Valid: true, } return nullable.EncodeText(ci, buf) @@ -59,8 +59,8 @@ func (src Int2) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { } nullable := pgtype.Int2{ - Int: int16(src), - Status: pgtype.Present, + Int: int16(src), + Valid: true, } return nullable.EncodeBinary(ci, buf) diff --git a/zeronull/int4.go b/zeronull/int4.go index c539e43a..3efca4e6 100644 --- a/zeronull/int4.go +++ b/zeronull/int4.go @@ -15,7 +15,7 @@ func (dst *Int4) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Int4(nullable.Int) } else { *dst = 0 @@ -31,7 +31,7 @@ func (dst *Int4) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Int4(nullable.Int) } else { *dst = 0 @@ -46,8 +46,8 @@ func (src Int4) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { } nullable := pgtype.Int4{ - Int: int32(src), - Status: pgtype.Present, + Int: int32(src), + Valid: true, } return nullable.EncodeText(ci, buf) @@ -59,8 +59,8 @@ func (src Int4) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { } nullable := pgtype.Int4{ - Int: int32(src), - Status: pgtype.Present, + Int: int32(src), + Valid: true, } return nullable.EncodeBinary(ci, buf) diff --git a/zeronull/int8.go b/zeronull/int8.go index 19774645..5cb063d8 100644 --- a/zeronull/int8.go +++ b/zeronull/int8.go @@ -15,7 +15,7 @@ func (dst *Int8) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Int8(nullable.Int) } else { *dst = 0 @@ -31,7 +31,7 @@ func (dst *Int8) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Int8(nullable.Int) } else { *dst = 0 @@ -46,8 +46,8 @@ func (src Int8) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { } nullable := pgtype.Int8{ - Int: int64(src), - Status: pgtype.Present, + Int: int64(src), + Valid: true, } return nullable.EncodeText(ci, buf) @@ -59,8 +59,8 @@ func (src Int8) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { } nullable := pgtype.Int8{ - Int: int64(src), - Status: pgtype.Present, + Int: int64(src), + Valid: true, } return nullable.EncodeBinary(ci, buf) diff --git a/zeronull/text.go b/zeronull/text.go index 8e79fc6a..afcb1a42 100644 --- a/zeronull/text.go +++ b/zeronull/text.go @@ -15,7 +15,7 @@ func (dst *Text) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Text(nullable.String) } else { *dst = Text("") @@ -31,7 +31,7 @@ func (dst *Text) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Text(nullable.String) } else { *dst = Text("") @@ -47,7 +47,7 @@ func (src Text) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { nullable := pgtype.Text{ String: string(src), - Status: pgtype.Present, + Valid: true, } return nullable.EncodeText(ci, buf) @@ -60,7 +60,7 @@ func (src Text) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { nullable := pgtype.Text{ String: string(src), - Status: pgtype.Present, + Valid: true, } return nullable.EncodeBinary(ci, buf) diff --git a/zeronull/timestamp.go b/zeronull/timestamp.go index a94c67cc..61787818 100644 --- a/zeronull/timestamp.go +++ b/zeronull/timestamp.go @@ -16,7 +16,7 @@ func (dst *Timestamp) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Timestamp(nullable.Time) } else { *dst = Timestamp{} @@ -32,7 +32,7 @@ func (dst *Timestamp) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Timestamp(nullable.Time) } else { *dst = Timestamp{} @@ -47,8 +47,8 @@ func (src Timestamp) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) } nullable := pgtype.Timestamp{ - Time: time.Time(src), - Status: pgtype.Present, + Time: time.Time(src), + Valid: true, } return nullable.EncodeText(ci, buf) @@ -60,8 +60,8 @@ func (src Timestamp) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, erro } nullable := pgtype.Timestamp{ - Time: time.Time(src), - Status: pgtype.Present, + Time: time.Time(src), + Valid: true, } return nullable.EncodeBinary(ci, buf) diff --git a/zeronull/timestamptz.go b/zeronull/timestamptz.go index c641ca10..4896e9b7 100644 --- a/zeronull/timestamptz.go +++ b/zeronull/timestamptz.go @@ -16,7 +16,7 @@ func (dst *Timestamptz) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Timestamptz(nullable.Time) } else { *dst = Timestamptz{} @@ -32,7 +32,7 @@ func (dst *Timestamptz) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = Timestamptz(nullable.Time) } else { *dst = Timestamptz{} @@ -47,8 +47,8 @@ func (src Timestamptz) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, erro } nullable := pgtype.Timestamptz{ - Time: time.Time(src), - Status: pgtype.Present, + Time: time.Time(src), + Valid: true, } return nullable.EncodeText(ci, buf) @@ -60,8 +60,8 @@ func (src Timestamptz) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, er } nullable := pgtype.Timestamptz{ - Time: time.Time(src), - Status: pgtype.Present, + Time: time.Time(src), + Valid: true, } return nullable.EncodeBinary(ci, buf) diff --git a/zeronull/uuid.go b/zeronull/uuid.go index 18fc667e..25211122 100644 --- a/zeronull/uuid.go +++ b/zeronull/uuid.go @@ -15,7 +15,7 @@ func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = UUID(nullable.Bytes) } else { *dst = UUID{} @@ -31,7 +31,7 @@ func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - if nullable.Status == pgtype.Present { + if nullable.Valid { *dst = UUID(nullable.Bytes) } else { *dst = UUID{} @@ -46,8 +46,8 @@ func (src UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { } nullable := pgtype.UUID{ - Bytes: [16]byte(src), - Status: pgtype.Present, + Bytes: [16]byte(src), + Valid: true, } return nullable.EncodeText(ci, buf) @@ -59,8 +59,8 @@ func (src UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { } nullable := pgtype.UUID{ - Bytes: [16]byte(src), - Status: pgtype.Present, + Bytes: [16]byte(src), + Valid: true, } return nullable.EncodeBinary(ci, buf) From 2886673a3cb7c26f72fc5d3cc132a8177bdd1fa5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 28 Aug 2021 14:07:13 -0500 Subject: [PATCH 0740/1158] Add full query decoding benchmarks --- integration_benchmark_test.go | 1292 +++++++++++++++++++++++++++++ integration_benchmark_test.go.erb | 44 + integration_benchmark_test_gen.sh | 2 + 3 files changed, 1338 insertions(+) create mode 100644 integration_benchmark_test.go create mode 100644 integration_benchmark_test.go.erb create mode 100755 integration_benchmark_test_gen.sh diff --git a/integration_benchmark_test.go b/integration_benchmark_test.go new file mode 100644 index 00000000..d3af7c31 --- /dev/null +++ b/integration_benchmark_test.go @@ -0,0 +1,1292 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" +) + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/integration_benchmark_test.go.erb b/integration_benchmark_test.go.erb new file mode 100644 index 00000000..037c96c3 --- /dev/null +++ b/integration_benchmark_test.go.erb @@ -0,0 +1,44 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" +) + +<% + [ + ["int4", ["int16", "int32", "int64", "uint64", "pgtype.Int4"], [[1, 1], [1, 10], [10, 1], [100, 10]]], + ["numeric", ["int64", "float64", "pgtype.Numeric"], [[1, 1], [1, 10], [10, 1], [100, 10]]], + ].each do |pg_type, go_types, rows_columns| +%> +<% go_types.each do |go_type| %> +<% rows_columns.each do |rows, columns| %> +<% [["Text", "pgx.TextFormatCode"], ["Binary", "pgx.BinaryFormatCode"]].each do |formatName, formatCode| %> +func BenchmarkQuery<%= formatName %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go_type.gsub(/\W/, "_") %>_<%= rows %>_rows_<%= columns %>_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [<%= columns %>]<%= go_type %> + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, + []interface{}{pgx.QueryResultFormats{<%= formatCode %>}}, + []interface{}{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} +<% end %> +<% end %> +<% end %> +<% end %> diff --git a/integration_benchmark_test_gen.sh b/integration_benchmark_test_gen.sh new file mode 100755 index 00000000..22ac01aa --- /dev/null +++ b/integration_benchmark_test_gen.sh @@ -0,0 +1,2 @@ +erb integration_benchmark_test.go.erb > integration_benchmark_test.go +goimports -w integration_benchmark_test.go From 63a8fe12d7e92a6d8e75d67f52f75f5bea612b17 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 28 Aug 2021 18:23:54 -0500 Subject: [PATCH 0741/1158] Add hooks for efficiently integrating with 3rd party types --- numeric.go | 17 +++++++++++++++++ pgtype.go | 27 +++++++++++++++++++++------ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/numeric.go b/numeric.go index 3d209ff2..83791d4e 100644 --- a/numeric.go +++ b/numeric.go @@ -235,7 +235,24 @@ func (dst Numeric) Get() interface{} { return dst } +var NumericDecoderWrapper func(interface{}) NumericDecoder + +type NumericDecoder interface { + DecodeNumeric(*Numeric) error +} + func (src *Numeric) AssignTo(dst interface{}) error { + if d, ok := dst.(NumericDecoder); ok { + return d.DecodeNumeric(src) + } else { + if NumericDecoderWrapper != nil { + d = NumericDecoderWrapper(dst) + if d != nil { + return d.DecodeNumeric(src) + } + } + } + if !src.Valid { return NullAssignTo(dst) } diff --git a/pgtype.go b/pgtype.go index c4fe870d..39e0ad79 100644 --- a/pgtype.go +++ b/pgtype.go @@ -225,15 +225,18 @@ type ConnInfo struct { oidToResultFormatCode map[uint32]int16 reflectTypeToDataType map[reflect.Type]*DataType + + preferAssignToOverSQLScannerTypes map[reflect.Type]struct{} } func newConnInfo() *ConnInfo { return &ConnInfo{ - oidToDataType: make(map[uint32]*DataType), - nameToDataType: make(map[string]*DataType), - reflectTypeToName: make(map[reflect.Type]string), - oidToParamFormatCode: make(map[uint32]int16), - oidToResultFormatCode: make(map[uint32]int16), + oidToDataType: make(map[uint32]*DataType), + nameToDataType: make(map[string]*DataType), + reflectTypeToName: make(map[reflect.Type]string), + oidToParamFormatCode: make(map[uint32]int16), + oidToResultFormatCode: make(map[uint32]int16), + preferAssignToOverSQLScannerTypes: make(map[reflect.Type]struct{}), } } @@ -462,6 +465,12 @@ func (ci *ConnInfo) ResultFormatCodeForOID(oid uint32) int16 { return TextFormatCode } +// PreferAssignToOverSQLScannerForType makes a sql.Scanner type use the AssignTo scan path instead of sql.Scanner. +// This is primarily for efficient integration with 3rd party numeric and UUID types. +func (ci *ConnInfo) PreferAssignToOverSQLScannerForType(value interface{}) { + ci.preferAssignToOverSQLScannerTypes[reflect.TypeOf(value)] = struct{}{} +} + // DeepCopy makes a deep copy of the ConnInfo. func (ci *ConnInfo) DeepCopy() *ConnInfo { ci2 := newConnInfo() @@ -478,6 +487,10 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { ci2.reflectTypeToName[t] = n } + for t, _ := range ci.preferAssignToOverSQLScannerTypes { + ci2.preferAssignToOverSQLScannerTypes[t] = struct{}{} + } + return ci2 } @@ -808,7 +821,9 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan if dt != nil { if _, ok := dst.(sql.Scanner); ok { - return (*scanPlanDataTypeSQLScanner)(dt) + if _, found := ci.preferAssignToOverSQLScannerTypes[reflect.TypeOf(dst)]; !found { + return (*scanPlanDataTypeSQLScanner)(dt) + } } return (*scanPlanDataTypeAssignTo)(dt) } From c0eae32e8b498fd517bc4ddc5732bd0c5561c990 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 28 Aug 2021 18:25:51 -0500 Subject: [PATCH 0742/1158] Remove ConnInfo.DeepCopy() --- pgtype.go | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/pgtype.go b/pgtype.go index 39e0ad79..4fa6eebe 100644 --- a/pgtype.go +++ b/pgtype.go @@ -471,29 +471,6 @@ func (ci *ConnInfo) PreferAssignToOverSQLScannerForType(value interface{}) { ci.preferAssignToOverSQLScannerTypes[reflect.TypeOf(value)] = struct{}{} } -// DeepCopy makes a deep copy of the ConnInfo. -func (ci *ConnInfo) DeepCopy() *ConnInfo { - ci2 := newConnInfo() - - for _, dt := range ci.oidToDataType { - ci2.RegisterDataType(DataType{ - Value: NewValue(dt.Value), - Name: dt.Name, - OID: dt.OID, - }) - } - - for t, n := range ci.reflectTypeToName { - ci2.reflectTypeToName[t] = n - } - - for t, _ := range ci.preferAssignToOverSQLScannerTypes { - ci2.preferAssignToOverSQLScannerTypes[t] = struct{}{} - } - - return ci2 -} - // ScanPlan is a precompiled plan to scan into a type of destination. type ScanPlan interface { // Scan scans src into dst. If the dst type has changed in an incompatible way a ScanPlan should automatically From 55ad9007cd9c3c296927897a8dfeed9b5a672942 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 28 Aug 2021 19:11:35 -0500 Subject: [PATCH 0743/1158] Finish Numeric changes for easy integration with 3rd party types --- numeric.go | 115 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 73 insertions(+), 42 deletions(-) diff --git a/numeric.go b/numeric.go index 83791d4e..85648dc2 100644 --- a/numeric.go +++ b/numeric.go @@ -57,14 +57,47 @@ var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) type Numeric struct { Int *big.Int Exp int32 - Valid bool NaN bool InfinityModifier InfinityModifier + Valid bool + + NumericDecoderWrapper func(interface{}) NumericDecoder +} + +func (n *Numeric) NewTypeValue() Value { + return &Numeric{ + NumericDecoderWrapper: n.NumericDecoderWrapper, + } +} + +func (n *Numeric) TypeName() string { + return "numeric" +} + +func (dst *Numeric) setNil() { + dst.Int = nil + dst.Exp = 0 + dst.NaN = false + dst.Valid = false +} + +func (dst *Numeric) setNaN() { + dst.Int = nil + dst.Exp = 0 + dst.NaN = true + dst.Valid = true +} + +func (dst *Numeric) setNumber(i *big.Int, exp int32) { + dst.Int = i + dst.Exp = exp + dst.NaN = false + dst.Valid = true } func (dst *Numeric) Set(src interface{}) error { if src == nil { - *dst = Numeric{} + dst.setNil() return nil } @@ -78,7 +111,7 @@ func (dst *Numeric) Set(src interface{}) error { switch value := src.(type) { case float32: if math.IsNaN(float64(value)) { - *dst = Numeric{Valid: true, NaN: true} + dst.setNaN() return nil } else if math.IsInf(float64(value), 1) { *dst = Numeric{Valid: true, InfinityModifier: Infinity} @@ -91,10 +124,10 @@ func (dst *Numeric) Set(src interface{}) error { if err != nil { return err } - *dst = Numeric{Int: num, Exp: exp, Valid: true} + dst.setNumber(num, exp) case float64: if math.IsNaN(value) { - *dst = Numeric{Valid: true, NaN: true} + dst.setNaN() return nil } else if math.IsInf(value, 1) { *dst = Numeric{Valid: true, InfinityModifier: Infinity} @@ -107,108 +140,108 @@ func (dst *Numeric) Set(src interface{}) error { if err != nil { return err } - *dst = Numeric{Int: num, Exp: exp, Valid: true} + dst.setNumber(num, exp) case int8: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case uint8: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case int16: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case uint16: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case int32: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case uint32: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case int64: - *dst = Numeric{Int: big.NewInt(value), Valid: true} + dst.setNumber(big.NewInt(value), 0) case uint64: - *dst = Numeric{Int: (&big.Int{}).SetUint64(value), Valid: true} + dst.setNumber((&big.Int{}).SetUint64(value), 0) case int: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case uint: - *dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Valid: true} + dst.setNumber((&big.Int{}).SetUint64(uint64(value)), 0) case string: num, exp, err := parseNumericString(value) if err != nil { return err } - *dst = Numeric{Int: num, Exp: exp, Valid: true} + dst.setNumber(num, exp) case *float64: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *float32: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *int8: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *uint8: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *int16: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *uint16: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *int32: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *uint32: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *int64: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *uint64: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *int: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *uint: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *string: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } @@ -235,8 +268,6 @@ func (dst Numeric) Get() interface{} { return dst } -var NumericDecoderWrapper func(interface{}) NumericDecoder - type NumericDecoder interface { DecodeNumeric(*Numeric) error } @@ -245,8 +276,8 @@ func (src *Numeric) AssignTo(dst interface{}) error { if d, ok := dst.(NumericDecoder); ok { return d.DecodeNumeric(src) } else { - if NumericDecoderWrapper != nil { - d = NumericDecoderWrapper(dst) + if src.NumericDecoderWrapper != nil { + d = src.NumericDecoderWrapper(dst) if d != nil { return d.DecodeNumeric(src) } @@ -443,12 +474,12 @@ func (src *Numeric) toFloat64() (float64, error) { func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Numeric{} + dst.setNil() return nil } if string(src) == "NaN" { - *dst = Numeric{Valid: true, NaN: true} + dst.setNaN() return nil } else if string(src) == "Infinity" { *dst = Numeric{Valid: true, InfinityModifier: Infinity} @@ -463,7 +494,7 @@ func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Numeric{Int: num, Exp: exp, Valid: true} + dst.setNumber(num, exp) return nil } @@ -490,7 +521,7 @@ func parseNumericString(str string) (n *big.Int, exp int32, err error) { func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Numeric{} + dst.setNil() return nil } @@ -509,7 +540,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { rp += 2 if sign == pgNumericNaNSign { - *dst = Numeric{Valid: true, NaN: true} + dst.setNaN() return nil } else if sign == pgNumericPosInfSign { *dst = Numeric{Valid: true, InfinityModifier: Infinity} @@ -520,7 +551,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { } if ndigits == 0 { - *dst = Numeric{Int: big.NewInt(0), Valid: true} + dst.setNumber(big.NewInt(0), 0) return nil } @@ -592,7 +623,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { accum.Neg(accum) } - *dst = Numeric{Int: accum, Exp: exp, Valid: true} + dst.setNumber(accum, exp) return nil @@ -741,7 +772,7 @@ func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Numeric) Scan(src interface{}) error { if src == nil { - *dst = Numeric{} + dst.setNil() return nil } From 1a3e5b0266a97cf99c9f47258b1b272e51d50d56 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 28 Aug 2021 19:12:36 -0500 Subject: [PATCH 0744/1158] Remove explicit shopspring/decimal integration Better integration is now enabled by github.com/jackc/pgx-shopspring-decimal. --- ext/shopspring-numeric/decimal.go | 329 ---------------------- ext/shopspring-numeric/decimal_test.go | 360 ------------------------- 2 files changed, 689 deletions(-) delete mode 100644 ext/shopspring-numeric/decimal.go delete mode 100644 ext/shopspring-numeric/decimal_test.go diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go deleted file mode 100644 index 3a9d99ba..00000000 --- a/ext/shopspring-numeric/decimal.go +++ /dev/null @@ -1,329 +0,0 @@ -package numeric - -import ( - "database/sql/driver" - "fmt" - "strconv" - - "github.com/jackc/pgtype" - "github.com/shopspring/decimal" -) - -type Numeric struct { - Decimal decimal.Decimal - Valid bool -} - -func (dst *Numeric) Set(src interface{}) error { - if src == nil { - *dst = Numeric{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case decimal.Decimal: - *dst = Numeric{Decimal: value, Valid: true} - case decimal.NullDecimal: - if value.Valid { - *dst = Numeric{Decimal: value.Decimal, Valid: true} - } else { - *dst = Numeric{} - } - case float32: - *dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Valid: true} - case float64: - *dst = Numeric{Decimal: decimal.NewFromFloat(value), Valid: true} - case int8: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case uint8: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case int16: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case uint16: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case int32: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case uint32: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case int64: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case uint64: - // uint64 could be greater than int64 so convert to string then to decimal - dec, err := decimal.NewFromString(strconv.FormatUint(value, 10)) - if err != nil { - return err - } - *dst = Numeric{Decimal: dec, Valid: true} - case int: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case uint: - // uint could be greater than int64 so convert to string then to decimal - dec, err := decimal.NewFromString(strconv.FormatUint(uint64(value), 10)) - if err != nil { - return err - } - *dst = Numeric{Decimal: dec, Valid: true} - case string: - dec, err := decimal.NewFromString(value) - if err != nil { - return err - } - *dst = Numeric{Decimal: dec, Valid: true} - default: - // If all else fails see if pgtype.Numeric can handle it. If so, translate through that. - num := &pgtype.Numeric{} - if err := num.Set(value); err != nil { - return fmt.Errorf("cannot convert %v to Numeric", value) - } - - buf, err := num.EncodeText(nil, nil) - if err != nil { - return fmt.Errorf("cannot convert %v to Numeric", value) - } - - dec, err := decimal.NewFromString(string(buf)) - if err != nil { - return fmt.Errorf("cannot convert %v to Numeric", value) - } - *dst = Numeric{Decimal: dec, Valid: true} - } - - return nil -} - -func (dst Numeric) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.Decimal -} - -func (src *Numeric) AssignTo(dst interface{}) error { - if !src.Valid { - if v, ok := dst.(*decimal.NullDecimal); ok { - (*v).Valid = false - (*v).Decimal = src.Decimal - return nil - } - return pgtype.NullAssignTo(dst) - } - - switch v := dst.(type) { - case *decimal.Decimal: - *v = src.Decimal - case *decimal.NullDecimal: - (*v).Valid = true - (*v).Decimal = src.Decimal - case *float32: - f, _ := src.Decimal.Float64() - *v = float32(f) - case *float64: - f, _ := src.Decimal.Float64() - *v = f - case *int: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int(n) - case *int8: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 8) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int8(n) - case *int16: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 16) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int16(n) - case *int32: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 32) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int32(n) - case *int64: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 64) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int64(n) - case *uint: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint(n) - case *uint8: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 8) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint8(n) - case *uint16: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 16) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint16(n) - case *uint32: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 32) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint32(n) - case *uint64: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 64) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint64(n) - default: - if nextDst, retry := pgtype.GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - - return nil -} - -func (dst *Numeric) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = Numeric{} - return nil - } - - dec, err := decimal.NewFromString(string(src)) - if err != nil { - return err - } - - *dst = Numeric{Decimal: dec, Valid: true} - return nil -} - -func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = Numeric{} - return nil - } - - // For now at least, implement this in terms of pgtype.Numeric - - num := &pgtype.Numeric{} - if err := num.DecodeBinary(ci, src); err != nil { - return err - } - - *dst = Numeric{Decimal: decimal.NewFromBigInt(num.Int, num.Exp), Valid: true} - - return nil -} - -func (src Numeric) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - return append(buf, src.Decimal.String()...), nil -} - -func (src Numeric) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - // For now at least, implement this in terms of pgtype.Numeric - num := &pgtype.Numeric{} - if err := num.DecodeText(ci, []byte(src.Decimal.String())); err != nil { - return nil, err - } - - return num.EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Numeric) Scan(src interface{}) error { - if src == nil { - *dst = Numeric{} - return nil - } - - switch src := src.(type) { - case float64: - *dst = Numeric{Decimal: decimal.NewFromFloat(src), Valid: true} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - return dst.DecodeText(nil, src) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Numeric) Value() (driver.Value, error) { - if !src.Valid { - return nil, nil - } - return src.Decimal.Value() -} - -func (src Numeric) MarshalJSON() ([]byte, error) { - if !src.Valid { - return []byte("null"), nil - } - return src.Decimal.MarshalJSON() -} - -func (dst *Numeric) UnmarshalJSON(b []byte) error { - d := decimal.NullDecimal{} - err := d.UnmarshalJSON(b) - if err != nil { - return err - } - - *dst = Numeric{Decimal: d.Decimal, Valid: d.Valid} - - return nil -} diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go deleted file mode 100644 index d130a69a..00000000 --- a/ext/shopspring-numeric/decimal_test.go +++ /dev/null @@ -1,360 +0,0 @@ -package numeric_test - -import ( - "fmt" - "math/big" - "math/rand" - "reflect" - "testing" - - shopspring "github.com/jackc/pgtype/ext/shopspring-numeric" - "github.com/jackc/pgtype/testutil" - "github.com/shopspring/decimal" - "github.com/stretchr/testify/require" -) - -func mustParseDecimal(t *testing.T, src string) decimal.Decimal { - dec, err := decimal.NewFromString(src) - if err != nil { - t.Fatal(err) - } - return dec -} - -func TestNumericNormalize(t *testing.T) { - testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ - { - SQL: "select '0'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Valid: true}, - }, - { - SQL: "select '1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}, - }, - { - SQL: "select '10.00'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Valid: true}, - }, - { - SQL: "select '1e-3'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Valid: true}, - }, - { - SQL: "select '-1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, - }, - { - SQL: "select '10000'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Valid: true}, - }, - { - SQL: "select '3.14'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Valid: true}, - }, - { - SQL: "select '1.1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Valid: true}, - }, - { - SQL: "select '100010001'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Valid: true}, - }, - { - SQL: "select '100010001.0001'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Valid: true}, - }, - { - SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", - Value: &shopspring.Numeric{ - Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), - Valid: true, - }, - }, - { - SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", - Value: &shopspring.Numeric{ - Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), - Valid: true, - }, - }, - { - SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", - Value: &shopspring.Numeric{ - Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), - Valid: true, - }, - }, - }, func(aa, bb interface{}) bool { - a := aa.(shopspring.Numeric) - b := bb.(shopspring.Numeric) - - return a.Valid == b.Valid && a.Decimal.Equal(b.Decimal) - }) -} - -func TestNumericTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Valid: true}, - - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Valid: true}, - - &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Valid: true}, - &shopspring.Numeric{}, - }, func(aa, bb interface{}) bool { - a := aa.(shopspring.Numeric) - b := bb.(shopspring.Numeric) - - return a.Valid == b.Valid && a.Decimal.Equal(b.Decimal) - }) - -} - -func TestNumericTranscodeFuzz(t *testing.T) { - r := rand.New(rand.NewSource(0)) - max := &big.Int{} - max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) - - values := make([]interface{}, 0, 2000) - for i := 0; i < 500; i++ { - num := fmt.Sprintf("%s.%s", (&big.Int{}).Rand(r, max).String(), (&big.Int{}).Rand(r, max).String()) - negNum := "-" + num - values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, num), Valid: true}) - values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Valid: true}) - } - - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, - func(aa, bb interface{}) bool { - a := aa.(shopspring.Numeric) - b := bb.(shopspring.Numeric) - - return a.Valid == b.Valid && a.Decimal.Equal(b.Decimal) - }) -} - -func TestNumericSet(t *testing.T) { - type _int8 int8 - - successfulTests := []struct { - source interface{} - result *shopspring.Numeric - }{ - {source: decimal.New(1, 0), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: decimal.NullDecimal{Valid: true, Decimal: decimal.New(1, 0)}, result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: decimal.NullDecimal{Valid: false}, result: &shopspring.Numeric{}}, - {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}}, - {source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}}, - {source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}}, - {source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}}, - {source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Valid: true}}, - {source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Valid: true}}, - {source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Valid: true}}, - {source: float64(1.25), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.25"), Valid: true}}, - } - - for i, tt := range successfulTests { - r := &shopspring.Numeric{} - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !(r.Valid == tt.result.Valid && r.Decimal.Equal(tt.result.Decimal)) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestNumericAssignTo(t *testing.T) { - type _int8 int8 - - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - var f32 float32 - var f64 float64 - var pf32 *float32 - var pf64 *float64 - var d decimal.Decimal - var nd decimal.NullDecimal - - simpleTests := []struct { - src *shopspring.Numeric - dst interface{} - expected interface{} - }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &f32, expected: float32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &f64, expected: float64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Valid: true}, dst: &f32, expected: float32(4.2)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Valid: true}, dst: &f64, expected: float64(4.2)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i16, expected: int16(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i32, expected: int32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i64, expected: int64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Valid: true}, dst: &i64, expected: int64(42000)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i, expected: int(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui8, expected: uint8(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui16, expected: uint16(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui64, expected: uint64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui, expected: uint(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &_i8, expected: _int8(42)}, - {src: &shopspring.Numeric{}, dst: &pi8, expected: ((*int8)(nil))}, - {src: &shopspring.Numeric{}, dst: &_pi8, expected: ((*_int8)(nil))}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &d, expected: decimal.New(42, 0)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Valid: true}, dst: &d, expected: decimal.New(42, 3)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.042"), Valid: true}, dst: &d, expected: decimal.New(42, -3)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &nd, expected: decimal.NullDecimal{Valid: true, Decimal: decimal.New(42, 0)}}, - {src: &shopspring.Numeric{}, dst: &nd, expected: decimal.NullDecimal{Valid: false}}, - } - - for i, tt := range simpleTests { - // Zero out the destination variable - reflect.ValueOf(tt.dst).Elem().Set(reflect.Zero(reflect.TypeOf(tt.dst).Elem())) - - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - // Need to specially handle Decimal or NullDecimal methods so we can use their Equal method. Without this - // we end up checking reference equality on the *big.Int they contain. - switch dst := tt.dst.(type) { - case *decimal.Decimal: - if !dst.Equal(tt.expected.(decimal.Decimal)) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, d) - } - case *decimal.NullDecimal: - expected := tt.expected.(decimal.NullDecimal) - - if dst.Valid != expected.Valid { - t.Errorf("%d: expected %v to assign NullDecimal.Valid = %v, but result was NullDecimal.Valid = %v", i, tt.src, expected.Valid, dst.Valid) - } - if !dst.Decimal.Equal(expected.Decimal) { - t.Errorf("%d: expected %v to assign NullDecimal.Decimal = %v, but result was NullDecimal.Decimal = %v", i, tt.src, expected.Decimal, dst.Decimal) - } - default: - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - } - - pointerAllocTests := []struct { - src *shopspring.Numeric - dst interface{} - expected interface{} - }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &pf32, expected: float32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &pf64, expected: float64(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src *shopspring.Numeric - dst interface{} - }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Valid: true}, dst: &i8}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Valid: true}, dst: &i16}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui8}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui16}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui32}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui64}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui}, - {src: &shopspring.Numeric{}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - -func BenchmarkDecode(b *testing.B) { - benchmarks := []struct { - name string - numberStr string - }{ - {"Zero", "0"}, - {"Small", "12345"}, - {"Medium", "12345.12345"}, - {"Large", "123457890.1234567890"}, - {"Huge", "123457890123457890123457890.1234567890123457890123457890"}, - } - - for _, bm := range benchmarks { - src := &shopspring.Numeric{} - err := src.Set(bm.numberStr) - require.NoError(b, err) - textFormat, err := src.EncodeText(nil, nil) - require.NoError(b, err) - binaryFormat, err := src.EncodeBinary(nil, nil) - require.NoError(b, err) - - b.Run(fmt.Sprintf("%s-Text", bm.name), func(b *testing.B) { - dst := &shopspring.Numeric{} - for i := 0; i < b.N; i++ { - err := dst.DecodeText(nil, textFormat) - if err != nil { - b.Fatal(err) - } - } - }) - - b.Run(fmt.Sprintf("%s-Binary", bm.name), func(b *testing.B) { - dst := &shopspring.Numeric{} - for i := 0; i < b.N; i++ { - err := dst.DecodeBinary(nil, binaryFormat) - if err != nil { - b.Fatal(err) - } - } - }) - } -} From 55195b3a64647fbbe017a9706d9bd7585959c13f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 2 Sep 2021 15:55:50 -0500 Subject: [PATCH 0745/1158] Add Numeric.Getter --- numeric.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/numeric.go b/numeric.go index 85648dc2..4cfbb657 100644 --- a/numeric.go +++ b/numeric.go @@ -62,11 +62,13 @@ type Numeric struct { Valid bool NumericDecoderWrapper func(interface{}) NumericDecoder + Getter func(Numeric) interface{} } func (n *Numeric) NewTypeValue() Value { return &Numeric{ NumericDecoderWrapper: n.NumericDecoderWrapper, + Getter: n.Getter, } } @@ -258,6 +260,10 @@ func (dst *Numeric) Set(src interface{}) error { } func (dst Numeric) Get() interface{} { + if dst.Getter != nil { + return dst.Getter(dst) + } + if !dst.Valid { return nil } From 0d9bd0366b9b3a4c8b5eedd9efa630504f6d5582 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 2 Sep 2021 17:16:00 -0500 Subject: [PATCH 0746/1158] Add Numeric.MarshalJSON --- numeric.go | 39 +++++++++++++++++++++++++++++++++++++++ numeric_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/numeric.go b/numeric.go index 4cfbb657..b24f433c 100644 --- a/numeric.go +++ b/numeric.go @@ -1,6 +1,7 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/binary" "fmt" @@ -807,3 +808,41 @@ func (src Numeric) Value() (driver.Value, error) { return string(buf), nil } + +func (src Numeric) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + if src.NaN { + return []byte(`"NaN"`), nil + } + + intStr := src.Int.String() + buf := &bytes.Buffer{} + exp := int(src.Exp) + if exp > 0 { + buf.WriteString(intStr) + for i := 0; i < exp; i++ { + buf.WriteByte('0') + } + } else if exp < 0 { + if len(intStr) <= -exp { + buf.WriteString("0.") + leadingZeros := -exp - len(intStr) + for i := 0; i < leadingZeros; i++ { + buf.WriteByte('0') + } + buf.WriteString(intStr) + } else if len(intStr) > -exp { + dpPos := len(intStr) + exp + buf.WriteString(intStr[:dpPos]) + buf.WriteByte('.') + buf.WriteString(intStr[dpPos:]) + } + } else { + buf.WriteString(intStr) + } + + return buf.Bytes(), nil +} diff --git a/numeric_test.go b/numeric_test.go index 58ce5c0f..7f0734d0 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -1,6 +1,8 @@ package pgtype_test import ( + "context" + "encoding/json" "math" "math/big" "math/rand" @@ -9,6 +11,7 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" ) // For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) @@ -410,3 +413,35 @@ func TestNumericEncodeDecodeBinary(t *testing.T) { } } } + +func TestNumericMarshalJSON(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for i, tt := range []struct { + decString string + }{ + {"NaN"}, + {"0"}, + {"1"}, + {"-1"}, + {"1000000000000000000"}, + {"1234.56789"}, + {"1.56789"}, + {"0.00000000000056789"}, + {"0.00123000"}, + {"123e-3"}, + {"243723409723490243842378942378901237502734019231380123e23790"}, + {"3409823409243892349028349023482934092340892390101e-14021"}, + } { + var num pgtype.Numeric + var pgJSON string + err := conn.QueryRow(context.Background(), `select $1::numeric, to_json($1::numeric)`, tt.decString).Scan(&num, &pgJSON) + require.NoErrorf(t, err, "%d", i) + + goJSON, err := json.Marshal(num) + require.NoErrorf(t, err, "%d", i) + + require.Equal(t, pgJSON, string(goJSON)) + } +} From 2226a5e14ece8446e1c6bec70517455b7c1d6f1a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Sep 2021 09:42:53 -0500 Subject: [PATCH 0747/1158] Remove explicit https://github.com/gofrs/uuid integration Better integration is now enabled by github.com/jackc/pgx-gofrs-uuid. --- ext/gofrs-uuid/uuid.go | 176 ------------------------------------ ext/gofrs-uuid/uuid_test.go | 100 -------------------- go.mod | 6 +- go.sum | 11 ++- uuid.go | 103 ++++++++++++++++----- uuid_test.go | 2 +- 6 files changed, 89 insertions(+), 309 deletions(-) delete mode 100644 ext/gofrs-uuid/uuid.go delete mode 100644 ext/gofrs-uuid/uuid_test.go diff --git a/ext/gofrs-uuid/uuid.go b/ext/gofrs-uuid/uuid.go deleted file mode 100644 index 0e0ebed3..00000000 --- a/ext/gofrs-uuid/uuid.go +++ /dev/null @@ -1,176 +0,0 @@ -package uuid - -import ( - "database/sql/driver" - "fmt" - - "github.com/gofrs/uuid" - "github.com/jackc/pgtype" -) - -type UUID struct { - UUID uuid.UUID - Valid bool -} - -func (dst *UUID) Set(src interface{}) error { - if src == nil { - *dst = UUID{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case uuid.UUID: - *dst = UUID{UUID: value, Valid: true} - case [16]byte: - *dst = UUID{UUID: uuid.UUID(value), Valid: true} - case []byte: - if len(value) != 16 { - return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) - } - *dst = UUID{Valid: true} - copy(dst.UUID[:], value) - case string: - uuid, err := uuid.FromString(value) - if err != nil { - return err - } - *dst = UUID{UUID: uuid, Valid: true} - default: - // If all else fails see if pgtype.UUID can handle it. If so, translate through that. - pgUUID := &pgtype.UUID{} - if err := pgUUID.Set(value); err != nil { - return fmt.Errorf("cannot convert %v to UUID", value) - } - - *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Valid: pgUUID.Valid} - } - - return nil -} - -func (dst UUID) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.UUID -} - -func (src *UUID) AssignTo(dst interface{}) error { - if !src.Valid { - return pgtype.NullAssignTo(dst) - } - - switch v := dst.(type) { - case *uuid.UUID: - *v = src.UUID - return nil - case *[16]byte: - *v = [16]byte(src.UUID) - return nil - case *[]byte: - *v = make([]byte, 16) - copy(*v, src.UUID[:]) - return nil - case *string: - *v = src.UUID.String() - return nil - default: - if nextDst, retry := pgtype.GetAssignToDstType(v); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = UUID{} - return nil - } - - u, err := uuid.FromString(string(src)) - if err != nil { - return err - } - - *dst = UUID{UUID: u, Valid: true} - return nil -} - -func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = UUID{} - return nil - } - - if len(src) != 16 { - return fmt.Errorf("invalid length for UUID: %v", len(src)) - } - - *dst = UUID{Valid: true} - copy(dst.UUID[:], src) - return nil -} - -func (src UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - return append(buf, src.UUID.String()...), nil -} - -func (src UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - return append(buf, src.UUID[:]...), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *UUID) Scan(src interface{}) error { - if src == nil { - *dst = UUID{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - return dst.DecodeText(nil, src) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src UUID) Value() (driver.Value, error) { - return pgtype.EncodeValueText(src) -} - -func (src UUID) MarshalJSON() ([]byte, error) { - if !src.Valid { - return []byte("null"), nil - } - return []byte(`"` + src.UUID.String() + `"`), nil -} - -func (dst *UUID) UnmarshalJSON(b []byte) error { - u := uuid.NullUUID{} - err := u.UnmarshalJSON(b) - if err != nil { - return err - } - - *dst = UUID{UUID: u.UUID, Valid: u.Valid} - - return nil -} diff --git a/ext/gofrs-uuid/uuid_test.go b/ext/gofrs-uuid/uuid_test.go deleted file mode 100644 index 3e5e4d82..00000000 --- a/ext/gofrs-uuid/uuid_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package uuid_test - -import ( - "bytes" - "testing" - - gofrs "github.com/jackc/pgtype/ext/gofrs-uuid" - "github.com/jackc/pgtype/testutil" -) - -func TestUUIDTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - &gofrs.UUID{}, - }) -} - -func TestUUIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result gofrs.UUID - }{ - { - source: &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - }, - { - source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - }, - { - source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - }, - { - source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r gofrs.UUID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestUUIDAssignTo(t *testing.T) { - { - src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} - var dst [16]byte - expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} - var dst []byte - expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if bytes.Compare(dst, expected) != 0 { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} - var dst string - expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - -} diff --git a/go.mod b/go.mod index 99c5b26e..b2f1cc10 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,8 @@ module github.com/jackc/pgtype go 1.13 require ( - github.com/gofrs/uuid v4.0.0+incompatible - github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530 + github.com/jackc/pgconn v1.10.1 github.com/jackc/pgio v1.0.0 - github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c - github.com/shopspring/decimal v1.2.0 + github.com/jackc/pgx/v4 v4.14.2-0.20211129172902-cf0de913ee8f github.com/stretchr/testify v1.7.0 ) diff --git a/go.sum b/go.sum index 8f2d760e..2a835726 100644 --- a/go.sum +++ b/go.sum @@ -25,8 +25,9 @@ github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= -github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530 h1:dUJ578zuPEsXjtzOfEF0q9zDAfljJ9oFnTHcQaNkccw= github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.10.1 h1:DzdIHIjG1AxGwoEEqS+mGsURyjt4enSmqzACXvVzOT8= +github.com/jackc/pgconn v1.10.1/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= @@ -42,22 +43,26 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.2.0 h1:r7JypeP2D3onoQTCxWdTpCtJ4D+qpKr0TxvoyMhZ5ns= +github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= +github.com/jackc/pgtype v1.9.1/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= -github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c h1:Dznn52SgVIVst9UyOT9brctYUgxs+CvVfPaC3jKrA50= github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= +github.com/jackc/pgx/v4 v4.14.2-0.20211129172902-cf0de913ee8f h1:Y3Es3mIYatTvP4CXPXfmJtHWe8eq4E8owY6Fq61hEik= +github.com/jackc/pgx/v4 v4.14.2-0.20211129172902-cf0de913ee8f/go.mod h1:RgDuE4Z34o7XE92RpLsvFiOEfrAUT0Xt2KxvX73W06M= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= diff --git a/uuid.go b/uuid.go index d46111d3..4533aa06 100644 --- a/uuid.go +++ b/uuid.go @@ -10,11 +10,58 @@ import ( type UUID struct { Bytes [16]byte Valid bool + + UUIDDecoderWrapper func(interface{}) UUIDDecoder + Getter func(UUID) interface{} +} + +func (n *UUID) NewTypeValue() Value { + return &UUID{ + UUIDDecoderWrapper: n.UUIDDecoderWrapper, + Getter: n.Getter, + } +} + +func (n *UUID) TypeName() string { + return "uuid" +} + +func (dst *UUID) setNil() { + dst.Bytes = [16]byte{} + dst.Valid = false +} + +func (dst *UUID) setByteArray(value [16]byte) { + dst.Bytes = value + dst.Valid = true +} + +func (dst *UUID) setByteSlice(value []byte) error { + if value != nil { + if len(value) != 16 { + return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + } + copy(dst.Bytes[:], value) + dst.Valid = true + } else { + dst.setNil() + } + + return nil +} + +func (dst *UUID) setString(value string) error { + uuid, err := parseUUID(value) + if err != nil { + return err + } + dst.setByteArray(uuid) + return nil } func (dst *UUID) Set(src interface{}) error { if src == nil { - *dst = UUID{} + dst.setNil() return nil } @@ -27,28 +74,16 @@ func (dst *UUID) Set(src interface{}) error { switch value := src.(type) { case [16]byte: - *dst = UUID{Bytes: value, Valid: true} + dst.setByteArray(value) case []byte: - if value != nil { - if len(value) != 16 { - return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) - } - *dst = UUID{Valid: true} - copy(dst.Bytes[:], value) - } else { - *dst = UUID{} - } + return dst.setByteSlice(value) case string: - uuid, err := parseUUID(value) - if err != nil { - return err - } - *dst = UUID{Bytes: uuid, Valid: true} + return dst.setString(value) case *string: if value == nil { - *dst = UUID{} + dst.setNil() } else { - return dst.Set(*value) + return dst.setString(*value) } default: if originalSrc, ok := underlyingUUIDType(src); ok { @@ -61,13 +96,33 @@ func (dst *UUID) Set(src interface{}) error { } func (dst UUID) Get() interface{} { + if dst.Getter != nil { + return dst.Getter(dst) + } + if !dst.Valid { return nil } + return dst.Bytes } +type UUIDDecoder interface { + DecodeUUID(*UUID) error +} + func (src *UUID) AssignTo(dst interface{}) error { + if d, ok := dst.(UUIDDecoder); ok { + return d.DecodeUUID(src) + } else { + if src.UUIDDecoderWrapper != nil { + d = src.UUIDDecoderWrapper(dst) + if d != nil { + return d.DecodeUUID(src) + } + } + } + if !src.Valid { return NullAssignTo(dst) } @@ -120,7 +175,7 @@ func encodeUUID(src [16]byte) string { func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = UUID{} + dst.setNil() return nil } @@ -133,13 +188,13 @@ func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = UUID{Bytes: buf, Valid: true} + dst.setByteArray(buf) return nil } func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = UUID{} + dst.setNil() return nil } @@ -147,9 +202,7 @@ func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { return fmt.Errorf("invalid length for UUID: %v", len(src)) } - *dst = UUID{Valid: true} - copy(dst.Bytes[:], src) - return nil + return dst.setByteSlice(src) } func (src UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { @@ -171,7 +224,7 @@ func (src UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *UUID) Scan(src interface{}) error { if src == nil { - *dst = UUID{} + dst.setNil() return nil } diff --git a/uuid_test.go b/uuid_test.go index 887f45dd..63797178 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -65,7 +65,7 @@ func TestUUIDSet(t *testing.T) { t.Errorf("%d: %v", i, err) } - if r != tt.result { + if r.Bytes != tt.result.Bytes || r.Valid != tt.result.Valid { t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } } From 75446032b914bb0be5e07da29c976034c0a666cf Mon Sep 17 00:00:00 2001 From: Torkel Rogstad Date: Fri, 12 Mar 2021 14:42:51 +0100 Subject: [PATCH 0748/1158] Normalize UTC timestamps to comply with stdlib --- timestamptz.go | 22 ++++++++++++++++++++-- timestamptz_test.go | 6 ++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/timestamptz.go b/timestamptz.go index 299a8668..58701970 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -124,7 +124,7 @@ func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Timestamptz{Time: tim, Status: Present} + *dst = Timestamptz{Time: normalizePotentialUTC(tim), Status: Present} } return nil @@ -231,6 +231,9 @@ func (src Timestamptz) Value() (driver.Value, error) { if src.InfinityModifier != None { return src.InfinityModifier.String(), nil } + if src.Time.Location().String() == time.UTC.String() { + return src.Time.UTC(), nil + } return src.Time, nil case Null: return nil, nil @@ -289,8 +292,23 @@ func (dst *Timestamptz) UnmarshalJSON(b []byte) error { return err } - *dst = Timestamptz{Time: tim, Status: Present} + *dst = Timestamptz{Time: normalizePotentialUTC(tim), Status: Present} } return nil } + +// Normalize timestamps in UTC location to behave similarly to how the Golang +// standard library does it: UTC timestamps lack a .loc value. +// +// Reason for this: when comparing two timestamps with reflect.DeepEqual (generally +// speaking not a good idea, but several testing libraries (for example testify) +// does this), their location data needs to be equal for them to be considered +// equal. +func normalizePotentialUTC(timestamp time.Time) time.Time { + if timestamp.Location().String() != time.UTC.String() { + return timestamp + } + + return timestamp.UTC() +} diff --git a/timestamptz_test.go b/timestamptz_test.go index c3f63967..2ff326bb 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -70,8 +70,7 @@ func TestTimestamptzNanosecondsTruncated(t *testing.T) { t.Errorf("%d. EncodeText failed - %v", i, err) } - tstz.DecodeText(nil, buf) - if err != nil { + if err := tstz.DecodeText(nil, buf); err != nil { t.Errorf("%d. DecodeText failed - %v", i, err) } @@ -87,8 +86,7 @@ func TestTimestamptzNanosecondsTruncated(t *testing.T) { t.Errorf("%d. EncodeBinary failed - %v", i, err) } - tstz.DecodeBinary(nil, buf) - if err != nil { + if err := tstz.DecodeBinary(nil, buf); err != nil { t.Errorf("%d. DecodeBinary failed - %v", i, err) } From 8f454e4cd6966adecea27084d8d22ec1829a5911 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Sep 2021 11:30:36 -0500 Subject: [PATCH 0749/1158] Add initial ParamEncoder and ResultDecoder support to core types --- new_pg_value.erb | 37 +++++++++++++++++++++++++++++++ new_pg_value_gen.sh | 45 ++++++++++++++++++++++++++++++++++++++ pgtype.go | 51 ++++++++++++++++++++++++++++++++++++++----- zzz.aclitem.go | 35 +++++++++++++++++++++++++++++ zzz.bit.go | 35 +++++++++++++++++++++++++++++ zzz.bool.go | 35 +++++++++++++++++++++++++++++ zzz.box.go | 35 +++++++++++++++++++++++++++++ zzz.bpchar.go | 35 +++++++++++++++++++++++++++++ zzz.bytea.go | 35 +++++++++++++++++++++++++++++ zzz.cid.go | 35 +++++++++++++++++++++++++++++ zzz.cidr.go | 35 +++++++++++++++++++++++++++++ zzz.circle.go | 35 +++++++++++++++++++++++++++++ zzz.date.go | 35 +++++++++++++++++++++++++++++ zzz.float4.go | 35 +++++++++++++++++++++++++++++ zzz.float8.go | 35 +++++++++++++++++++++++++++++ zzz.generic_binary.go | 35 +++++++++++++++++++++++++++++ zzz.generic_text.go | 35 +++++++++++++++++++++++++++++ zzz.hstore.go | 35 +++++++++++++++++++++++++++++ zzz.inet.go | 35 +++++++++++++++++++++++++++++ zzz.int2.go | 35 +++++++++++++++++++++++++++++ zzz.int4.go | 35 +++++++++++++++++++++++++++++ zzz.int8.go | 35 +++++++++++++++++++++++++++++ zzz.interval.go | 35 +++++++++++++++++++++++++++++ zzz.json.go | 35 +++++++++++++++++++++++++++++ zzz.jsonb.go | 35 +++++++++++++++++++++++++++++ zzz.line.go | 35 +++++++++++++++++++++++++++++ zzz.lseg.go | 35 +++++++++++++++++++++++++++++ zzz.macadder.go | 35 +++++++++++++++++++++++++++++ zzz.name.go | 35 +++++++++++++++++++++++++++++ zzz.numeric.go | 35 +++++++++++++++++++++++++++++ zzz.oid.go | 35 +++++++++++++++++++++++++++++ zzz.oid_value.go | 35 +++++++++++++++++++++++++++++ zzz.path.go | 35 +++++++++++++++++++++++++++++ zzz.pguint32.go | 35 +++++++++++++++++++++++++++++ zzz.point.go | 35 +++++++++++++++++++++++++++++ zzz.polygon.go | 35 +++++++++++++++++++++++++++++ zzz.qchar.go | 35 +++++++++++++++++++++++++++++ zzz.text.go | 35 +++++++++++++++++++++++++++++ zzz.tid.go | 35 +++++++++++++++++++++++++++++ zzz.time.go | 35 +++++++++++++++++++++++++++++ zzz.timestamp.go | 35 +++++++++++++++++++++++++++++ zzz.timestamptz.go | 35 +++++++++++++++++++++++++++++ zzz.uuid.go | 35 +++++++++++++++++++++++++++++ zzz.varbit.go | 35 +++++++++++++++++++++++++++++ zzz.varchar.go | 35 +++++++++++++++++++++++++++++ zzz.xid.go | 35 +++++++++++++++++++++++++++++ 46 files changed, 1632 insertions(+), 6 deletions(-) create mode 100644 new_pg_value.erb create mode 100644 new_pg_value_gen.sh create mode 100644 zzz.aclitem.go create mode 100644 zzz.bit.go create mode 100644 zzz.bool.go create mode 100644 zzz.box.go create mode 100644 zzz.bpchar.go create mode 100644 zzz.bytea.go create mode 100644 zzz.cid.go create mode 100644 zzz.cidr.go create mode 100644 zzz.circle.go create mode 100644 zzz.date.go create mode 100644 zzz.float4.go create mode 100644 zzz.float8.go create mode 100644 zzz.generic_binary.go create mode 100644 zzz.generic_text.go create mode 100644 zzz.hstore.go create mode 100644 zzz.inet.go create mode 100644 zzz.int2.go create mode 100644 zzz.int4.go create mode 100644 zzz.int8.go create mode 100644 zzz.interval.go create mode 100644 zzz.json.go create mode 100644 zzz.jsonb.go create mode 100644 zzz.line.go create mode 100644 zzz.lseg.go create mode 100644 zzz.macadder.go create mode 100644 zzz.name.go create mode 100644 zzz.numeric.go create mode 100644 zzz.oid.go create mode 100644 zzz.oid_value.go create mode 100644 zzz.path.go create mode 100644 zzz.pguint32.go create mode 100644 zzz.point.go create mode 100644 zzz.polygon.go create mode 100644 zzz.qchar.go create mode 100644 zzz.text.go create mode 100644 zzz.tid.go create mode 100644 zzz.time.go create mode 100644 zzz.timestamp.go create mode 100644 zzz.timestamptz.go create mode 100644 zzz.uuid.go create mode 100644 zzz.varbit.go create mode 100644 zzz.varchar.go create mode 100644 zzz.xid.go diff --git a/new_pg_value.erb b/new_pg_value.erb new file mode 100644 index 00000000..71a0da7f --- /dev/null +++ b/new_pg_value.erb @@ -0,0 +1,37 @@ +package pgtype + +<% skip_binary ||= false %> +<% skip_text ||= false %> +<% prefer_text_format ||= false %> + +func (<%= go_type %>) BinaryFormatSupported() bool { + return true +} + +func (<%= go_type %>) TextFormatSupported() bool { + return true +} + +func (<%= go_type %>) PreferredFormat() int16 { + return <%= prefer_text_format ? "Text" : "Binary" %>FormatCode +} + +func (dst *<%= go_type %>) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + <% if skip_binary %> return fmt.Errorf("binary format not supported for %T", dst) <% else %> return dst.DecodeBinary(ci, src) <% end %> + case TextFormatCode: + <% if skip_text %> return fmt.Errorf("text format not supported for %T", dst) <% else %> return dst.DecodeText(ci, src) <% end %> + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src <%= go_type %>) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + <% if skip_binary %>return nil, fmt.Errorf("binary format not supported for %T", src)<% else %>return src.EncodeBinary(ci, buf)<% end %> + case TextFormatCode: + <% if skip_text %>return nil, fmt.Errorf("text format not supported for %T", src)<% else %>return src.EncodeText(ci, buf)<% end %> + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/new_pg_value_gen.sh b/new_pg_value_gen.sh new file mode 100644 index 00000000..3dad08de --- /dev/null +++ b/new_pg_value_gen.sh @@ -0,0 +1,45 @@ +erb go_type=ACLItem skip_binary=true prefer_text_format=true new_pg_value.erb > zzz.aclitem.go +erb go_type=Bit new_pg_value.erb > zzz.bit.go +erb go_type=Bool new_pg_value.erb > zzz.bool.go +erb go_type=Box new_pg_value.erb > zzz.box.go +erb go_type=BPChar prefer_text_format=true new_pg_value.erb > zzz.bpchar.go +erb go_type=Bytea new_pg_value.erb > zzz.bytea.go +erb go_type=CID new_pg_value.erb > zzz.cid.go +erb go_type=CIDR new_pg_value.erb > zzz.cidr.go +erb go_type=Circle new_pg_value.erb > zzz.circle.go +erb go_type=Date new_pg_value.erb > zzz.date.go +erb go_type=Float4 new_pg_value.erb > zzz.float4.go +erb go_type=Float8 new_pg_value.erb > zzz.float8.go +erb go_type=GenericBinary skip_text=true new_pg_value.erb > zzz.generic_binary.go +erb go_type=GenericText skip_binary=true prefer_text_format=true new_pg_value.erb > zzz.generic_text.go +erb go_type=Hstore new_pg_value.erb > zzz.hstore.go +erb go_type=Inet new_pg_value.erb > zzz.inet.go +erb go_type=Int2 new_pg_value.erb > zzz.int2.go +erb go_type=Int4 new_pg_value.erb > zzz.int4.go +erb go_type=Int8 new_pg_value.erb > zzz.int8.go +erb go_type=Interval new_pg_value.erb > zzz.interval.go +erb go_type=JSON prefer_text_format=true new_pg_value.erb > zzz.json.go +erb go_type=JSONB prefer_text_format=true new_pg_value.erb > zzz.jsonb.go +erb go_type=Line new_pg_value.erb > zzz.line.go +erb go_type=Lseg new_pg_value.erb > zzz.lseg.go +erb go_type=Macaddr new_pg_value.erb > zzz.macadder.go +erb go_type=Name new_pg_value.erb > zzz.name.go +erb go_type=Numeric new_pg_value.erb > zzz.numeric.go +erb go_type=OIDValue new_pg_value.erb > zzz.oid_value.go +erb go_type=OID new_pg_value.erb > zzz.oid.go +erb go_type=Path new_pg_value.erb > zzz.path.go +erb go_type=pguint32 new_pg_value.erb > zzz.pguint32.go +erb go_type=Point new_pg_value.erb > zzz.point.go +erb go_type=Polygon new_pg_value.erb > zzz.polygon.go +erb go_type=QChar skip_text=true new_pg_value.erb > zzz.qchar.go +erb go_type=Text prefer_text_format=true new_pg_value.erb > zzz.text.go +erb go_type=TID new_pg_value.erb > zzz.tid.go +erb go_type=Time new_pg_value.erb > zzz.time.go +erb go_type=Timestamp new_pg_value.erb > zzz.timestamp.go +erb go_type=Timestamptz new_pg_value.erb > zzz.timestamptz.go +# erb go_type=Unknown new_pg_value.erb > zzz.unknown.go +erb go_type=UUID new_pg_value.erb > zzz.uuid.go +erb go_type=Varbit new_pg_value.erb > zzz.varbit.go +erb go_type=Varchar prefer_text_format=true new_pg_value.erb > zzz.varchar.go +erb go_type=XID new_pg_value.erb > zzz.xid.go +goimports -w zzz.* diff --git a/pgtype.go b/pgtype.go index 4fa6eebe..b9067fab 100644 --- a/pgtype.go +++ b/pgtype.go @@ -153,6 +153,22 @@ type ValueTranscoder interface { BinaryDecoder } +type FormatSupport interface { + BinaryFormatSupported() bool + TextFormatSupported() bool + PreferredFormat() int16 +} + +type ParamEncoder interface { + FormatSupport + EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) +} + +type ResultDecoder interface { + FormatSupport + DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error +} + // ResultFormatPreferrer allows a type to specify its preferred result format instead of it being inferred from // whether it is also a BinaryDecoder. type ResultFormatPreferrer interface { @@ -210,6 +226,8 @@ func (e *nullAssignmentError) Error() string { type DataType struct { Value Value + resultDecoder ResultDecoder + textDecoder TextDecoder binaryDecoder BinaryDecoder @@ -380,7 +398,9 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { { var formatCode int16 - if rfp, ok := t.Value.(ResultFormatPreferrer); ok { + if fs, ok := t.Value.(FormatSupport); ok { + formatCode = fs.PreferredFormat() + } else if rfp, ok := t.Value.(ResultFormatPreferrer); ok { formatCode = rfp.PreferredResultFormat() } else if _, ok := t.Value.(BinaryDecoder); ok { formatCode = BinaryFormatCode @@ -388,6 +408,10 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { ci.oidToResultFormatCode[t.OID] = formatCode } + if d, ok := t.Value.(ResultDecoder); ok { + t.resultDecoder = d + } + if d, ok := t.Value.(TextDecoder); ok { t.textDecoder = d } @@ -478,6 +502,17 @@ type ScanPlan interface { Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error } +type scanPlanDstResultDecoder struct{} + +func (scanPlanDstResultDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if d, ok := (dst).(ResultDecoder); ok { + return d.DecodeResult(ci, oid, formatCode, src) + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + type scanPlanDstBinaryDecoder struct{} func (scanPlanDstBinaryDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { @@ -533,11 +568,15 @@ type scanPlanDataTypeAssignTo DataType func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { dt := (*DataType)(plan) var err error - switch formatCode { - case BinaryFormatCode: - err = dt.binaryDecoder.DecodeBinary(ci, src) - case TextFormatCode: - err = dt.textDecoder.DecodeText(ci, src) + if dt.resultDecoder != nil { + err = dt.resultDecoder.DecodeResult(ci, oid, formatCode, src) + } else { + switch formatCode { + case BinaryFormatCode: + err = dt.binaryDecoder.DecodeBinary(ci, src) + case TextFormatCode: + err = dt.textDecoder.DecodeText(ci, src) + } } if err != nil { return err diff --git a/zzz.aclitem.go b/zzz.aclitem.go new file mode 100644 index 00000000..6ac1f94a --- /dev/null +++ b/zzz.aclitem.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (ACLItem) BinaryFormatSupported() bool { + return true +} + +func (ACLItem) TextFormatSupported() bool { + return true +} + +func (ACLItem) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *ACLItem) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return fmt.Errorf("binary format not supported for %T", dst) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src ACLItem) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return nil, fmt.Errorf("binary format not supported for %T", src) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.bit.go b/zzz.bit.go new file mode 100644 index 00000000..e95df74d --- /dev/null +++ b/zzz.bit.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Bit) BinaryFormatSupported() bool { + return true +} + +func (Bit) TextFormatSupported() bool { + return true +} + +func (Bit) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Bit) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Bit) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.bool.go b/zzz.bool.go new file mode 100644 index 00000000..e6ed52de --- /dev/null +++ b/zzz.bool.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Bool) BinaryFormatSupported() bool { + return true +} + +func (Bool) TextFormatSupported() bool { + return true +} + +func (Bool) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Bool) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Bool) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.box.go b/zzz.box.go new file mode 100644 index 00000000..5ca2df43 --- /dev/null +++ b/zzz.box.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Box) BinaryFormatSupported() bool { + return true +} + +func (Box) TextFormatSupported() bool { + return true +} + +func (Box) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Box) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Box) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.bpchar.go b/zzz.bpchar.go new file mode 100644 index 00000000..c3178670 --- /dev/null +++ b/zzz.bpchar.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (BPChar) BinaryFormatSupported() bool { + return true +} + +func (BPChar) TextFormatSupported() bool { + return true +} + +func (BPChar) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *BPChar) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src BPChar) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.bytea.go b/zzz.bytea.go new file mode 100644 index 00000000..4da5ad4f --- /dev/null +++ b/zzz.bytea.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Bytea) BinaryFormatSupported() bool { + return true +} + +func (Bytea) TextFormatSupported() bool { + return true +} + +func (Bytea) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Bytea) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Bytea) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.cid.go b/zzz.cid.go new file mode 100644 index 00000000..4cb9671d --- /dev/null +++ b/zzz.cid.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (CID) BinaryFormatSupported() bool { + return true +} + +func (CID) TextFormatSupported() bool { + return true +} + +func (CID) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *CID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src CID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.cidr.go b/zzz.cidr.go new file mode 100644 index 00000000..714908e0 --- /dev/null +++ b/zzz.cidr.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (CIDR) BinaryFormatSupported() bool { + return true +} + +func (CIDR) TextFormatSupported() bool { + return true +} + +func (CIDR) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *CIDR) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src CIDR) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.circle.go b/zzz.circle.go new file mode 100644 index 00000000..b111c06d --- /dev/null +++ b/zzz.circle.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Circle) BinaryFormatSupported() bool { + return true +} + +func (Circle) TextFormatSupported() bool { + return true +} + +func (Circle) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Circle) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Circle) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.date.go b/zzz.date.go new file mode 100644 index 00000000..66132082 --- /dev/null +++ b/zzz.date.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Date) BinaryFormatSupported() bool { + return true +} + +func (Date) TextFormatSupported() bool { + return true +} + +func (Date) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Date) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Date) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.float4.go b/zzz.float4.go new file mode 100644 index 00000000..b600805e --- /dev/null +++ b/zzz.float4.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Float4) BinaryFormatSupported() bool { + return true +} + +func (Float4) TextFormatSupported() bool { + return true +} + +func (Float4) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Float4) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Float4) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.float8.go b/zzz.float8.go new file mode 100644 index 00000000..dd3ba0fa --- /dev/null +++ b/zzz.float8.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Float8) BinaryFormatSupported() bool { + return true +} + +func (Float8) TextFormatSupported() bool { + return true +} + +func (Float8) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Float8) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Float8) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.generic_binary.go b/zzz.generic_binary.go new file mode 100644 index 00000000..b50f1f45 --- /dev/null +++ b/zzz.generic_binary.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (GenericBinary) BinaryFormatSupported() bool { + return true +} + +func (GenericBinary) TextFormatSupported() bool { + return true +} + +func (GenericBinary) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *GenericBinary) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return fmt.Errorf("text format not supported for %T", dst) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src GenericBinary) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return nil, fmt.Errorf("text format not supported for %T", src) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.generic_text.go b/zzz.generic_text.go new file mode 100644 index 00000000..5ab771cf --- /dev/null +++ b/zzz.generic_text.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (GenericText) BinaryFormatSupported() bool { + return true +} + +func (GenericText) TextFormatSupported() bool { + return true +} + +func (GenericText) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *GenericText) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return fmt.Errorf("binary format not supported for %T", dst) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src GenericText) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return nil, fmt.Errorf("binary format not supported for %T", src) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.hstore.go b/zzz.hstore.go new file mode 100644 index 00000000..ebd7bdee --- /dev/null +++ b/zzz.hstore.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Hstore) BinaryFormatSupported() bool { + return true +} + +func (Hstore) TextFormatSupported() bool { + return true +} + +func (Hstore) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Hstore) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Hstore) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.inet.go b/zzz.inet.go new file mode 100644 index 00000000..51daeee6 --- /dev/null +++ b/zzz.inet.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Inet) BinaryFormatSupported() bool { + return true +} + +func (Inet) TextFormatSupported() bool { + return true +} + +func (Inet) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Inet) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Inet) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.int2.go b/zzz.int2.go new file mode 100644 index 00000000..f2d959f9 --- /dev/null +++ b/zzz.int2.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Int2) BinaryFormatSupported() bool { + return true +} + +func (Int2) TextFormatSupported() bool { + return true +} + +func (Int2) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Int2) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Int2) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.int4.go b/zzz.int4.go new file mode 100644 index 00000000..bd7f9bda --- /dev/null +++ b/zzz.int4.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Int4) BinaryFormatSupported() bool { + return true +} + +func (Int4) TextFormatSupported() bool { + return true +} + +func (Int4) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Int4) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Int4) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.int8.go b/zzz.int8.go new file mode 100644 index 00000000..d6e98262 --- /dev/null +++ b/zzz.int8.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Int8) BinaryFormatSupported() bool { + return true +} + +func (Int8) TextFormatSupported() bool { + return true +} + +func (Int8) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Int8) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Int8) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.interval.go b/zzz.interval.go new file mode 100644 index 00000000..a34f2d59 --- /dev/null +++ b/zzz.interval.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Interval) BinaryFormatSupported() bool { + return true +} + +func (Interval) TextFormatSupported() bool { + return true +} + +func (Interval) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Interval) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Interval) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.json.go b/zzz.json.go new file mode 100644 index 00000000..40a736c9 --- /dev/null +++ b/zzz.json.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (JSON) BinaryFormatSupported() bool { + return true +} + +func (JSON) TextFormatSupported() bool { + return true +} + +func (JSON) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *JSON) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src JSON) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.jsonb.go b/zzz.jsonb.go new file mode 100644 index 00000000..a07934b7 --- /dev/null +++ b/zzz.jsonb.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (JSONB) BinaryFormatSupported() bool { + return true +} + +func (JSONB) TextFormatSupported() bool { + return true +} + +func (JSONB) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *JSONB) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src JSONB) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.line.go b/zzz.line.go new file mode 100644 index 00000000..7365744b --- /dev/null +++ b/zzz.line.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Line) BinaryFormatSupported() bool { + return true +} + +func (Line) TextFormatSupported() bool { + return true +} + +func (Line) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Line) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Line) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.lseg.go b/zzz.lseg.go new file mode 100644 index 00000000..1a95af09 --- /dev/null +++ b/zzz.lseg.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Lseg) BinaryFormatSupported() bool { + return true +} + +func (Lseg) TextFormatSupported() bool { + return true +} + +func (Lseg) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Lseg) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Lseg) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.macadder.go b/zzz.macadder.go new file mode 100644 index 00000000..5758d68f --- /dev/null +++ b/zzz.macadder.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Macaddr) BinaryFormatSupported() bool { + return true +} + +func (Macaddr) TextFormatSupported() bool { + return true +} + +func (Macaddr) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Macaddr) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Macaddr) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.name.go b/zzz.name.go new file mode 100644 index 00000000..6949c337 --- /dev/null +++ b/zzz.name.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Name) BinaryFormatSupported() bool { + return true +} + +func (Name) TextFormatSupported() bool { + return true +} + +func (Name) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Name) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Name) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.numeric.go b/zzz.numeric.go new file mode 100644 index 00000000..838bed40 --- /dev/null +++ b/zzz.numeric.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Numeric) BinaryFormatSupported() bool { + return true +} + +func (Numeric) TextFormatSupported() bool { + return true +} + +func (Numeric) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Numeric) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Numeric) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.oid.go b/zzz.oid.go new file mode 100644 index 00000000..bc3ba7d2 --- /dev/null +++ b/zzz.oid.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (OID) BinaryFormatSupported() bool { + return true +} + +func (OID) TextFormatSupported() bool { + return true +} + +func (OID) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *OID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src OID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.oid_value.go b/zzz.oid_value.go new file mode 100644 index 00000000..6fba9e44 --- /dev/null +++ b/zzz.oid_value.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (OIDValue) BinaryFormatSupported() bool { + return true +} + +func (OIDValue) TextFormatSupported() bool { + return true +} + +func (OIDValue) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *OIDValue) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src OIDValue) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.path.go b/zzz.path.go new file mode 100644 index 00000000..d761ac40 --- /dev/null +++ b/zzz.path.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Path) BinaryFormatSupported() bool { + return true +} + +func (Path) TextFormatSupported() bool { + return true +} + +func (Path) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Path) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Path) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.pguint32.go b/zzz.pguint32.go new file mode 100644 index 00000000..c869da8f --- /dev/null +++ b/zzz.pguint32.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (pguint32) BinaryFormatSupported() bool { + return true +} + +func (pguint32) TextFormatSupported() bool { + return true +} + +func (pguint32) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *pguint32) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src pguint32) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.point.go b/zzz.point.go new file mode 100644 index 00000000..083ded95 --- /dev/null +++ b/zzz.point.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Point) BinaryFormatSupported() bool { + return true +} + +func (Point) TextFormatSupported() bool { + return true +} + +func (Point) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Point) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Point) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.polygon.go b/zzz.polygon.go new file mode 100644 index 00000000..2bfdbbd4 --- /dev/null +++ b/zzz.polygon.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Polygon) BinaryFormatSupported() bool { + return true +} + +func (Polygon) TextFormatSupported() bool { + return true +} + +func (Polygon) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Polygon) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Polygon) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.qchar.go b/zzz.qchar.go new file mode 100644 index 00000000..adc0f462 --- /dev/null +++ b/zzz.qchar.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (QChar) BinaryFormatSupported() bool { + return true +} + +func (QChar) TextFormatSupported() bool { + return true +} + +func (QChar) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *QChar) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return fmt.Errorf("text format not supported for %T", dst) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src QChar) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return nil, fmt.Errorf("text format not supported for %T", src) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.text.go b/zzz.text.go new file mode 100644 index 00000000..e1a3908f --- /dev/null +++ b/zzz.text.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Text) BinaryFormatSupported() bool { + return true +} + +func (Text) TextFormatSupported() bool { + return true +} + +func (Text) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *Text) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Text) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.tid.go b/zzz.tid.go new file mode 100644 index 00000000..1a705277 --- /dev/null +++ b/zzz.tid.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (TID) BinaryFormatSupported() bool { + return true +} + +func (TID) TextFormatSupported() bool { + return true +} + +func (TID) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *TID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src TID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.time.go b/zzz.time.go new file mode 100644 index 00000000..be9a96a7 --- /dev/null +++ b/zzz.time.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Time) BinaryFormatSupported() bool { + return true +} + +func (Time) TextFormatSupported() bool { + return true +} + +func (Time) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Time) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Time) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.timestamp.go b/zzz.timestamp.go new file mode 100644 index 00000000..ce6135c7 --- /dev/null +++ b/zzz.timestamp.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Timestamp) BinaryFormatSupported() bool { + return true +} + +func (Timestamp) TextFormatSupported() bool { + return true +} + +func (Timestamp) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Timestamp) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Timestamp) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.timestamptz.go b/zzz.timestamptz.go new file mode 100644 index 00000000..1147b257 --- /dev/null +++ b/zzz.timestamptz.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Timestamptz) BinaryFormatSupported() bool { + return true +} + +func (Timestamptz) TextFormatSupported() bool { + return true +} + +func (Timestamptz) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Timestamptz) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Timestamptz) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.uuid.go b/zzz.uuid.go new file mode 100644 index 00000000..a0aefaf6 --- /dev/null +++ b/zzz.uuid.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (UUID) BinaryFormatSupported() bool { + return true +} + +func (UUID) TextFormatSupported() bool { + return true +} + +func (UUID) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *UUID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src UUID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.varbit.go b/zzz.varbit.go new file mode 100644 index 00000000..2b090ebf --- /dev/null +++ b/zzz.varbit.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Varbit) BinaryFormatSupported() bool { + return true +} + +func (Varbit) TextFormatSupported() bool { + return true +} + +func (Varbit) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Varbit) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Varbit) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.varchar.go b/zzz.varchar.go new file mode 100644 index 00000000..9771d412 --- /dev/null +++ b/zzz.varchar.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Varchar) BinaryFormatSupported() bool { + return true +} + +func (Varchar) TextFormatSupported() bool { + return true +} + +func (Varchar) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *Varchar) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Varchar) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/zzz.xid.go b/zzz.xid.go new file mode 100644 index 00000000..2754d98e --- /dev/null +++ b/zzz.xid.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (XID) BinaryFormatSupported() bool { + return true +} + +func (XID) TextFormatSupported() bool { + return true +} + +func (XID) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *XID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src XID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} From e22675d20b262cb065c55f2a77b1252ccddf7556 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Dec 2021 12:45:20 -0600 Subject: [PATCH 0750/1158] ValueTranscoder uses new interfaces --- array_type.go | 47 +++++++++++++++++----- composite_fields.go | 12 +++--- composite_type.go | 96 +++++++++++++++++++++++++++++++++------------ pgtype.go | 16 +++++--- 4 files changed, 123 insertions(+), 48 deletions(-) diff --git a/array_type.go b/array_type.go index 1df1689f..c4f162af 100644 --- a/array_type.go +++ b/array_type.go @@ -129,12 +129,44 @@ func (src *ArrayType) AssignTo(dst interface{}) error { } } -func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { +func (ArrayType) BinaryFormatSupported() bool { + return true +} + +func (ArrayType) TextFormatSupported() bool { + return true +} + +func (ArrayType) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *ArrayType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { if src == nil { dst.setNil() return nil } + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src ArrayType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} + +func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err @@ -151,7 +183,7 @@ func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(ci, elemSrc) + err = elem.DecodeResult(ci, dst.elementOID, TextFormatCode, elemSrc) if err != nil { return err } @@ -168,11 +200,6 @@ func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { } func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - dst.setNil() - return nil - } - var arrayHeader ArrayHeader rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { @@ -204,7 +231,7 @@ func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elem.DecodeBinary(ci, elemSrc) + err = elem.DecodeResult(ci, dst.elementOID, BinaryFormatCode, elemSrc) if err != nil { return err } @@ -253,7 +280,7 @@ func (src ArrayType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } } - elemBuf, err := elem.EncodeText(ci, inElemBuf) + elemBuf, err := elem.EncodeParam(ci, src.elementOID, TextFormatCode, inElemBuf) if err != nil { return nil, err } @@ -296,7 +323,7 @@ func (src ArrayType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { sp := len(buf) buf = pgio.AppendInt32(buf, -1) - elemBuf, err := src.elements[i].EncodeBinary(ci, buf) + elemBuf, err := src.elements[i].EncodeParam(ci, src.elementOID, BinaryFormatCode, buf) if err != nil { return nil, err } diff --git a/composite_fields.go b/composite_fields.go index b6d09fcf..e7ca89c7 100644 --- a/composite_fields.go +++ b/composite_fields.go @@ -59,8 +59,8 @@ func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { b := NewCompositeTextBuilder(ci, buf) for _, f := range cf { - if textEncoder, ok := f.(TextEncoder); ok { - b.AppendEncoder(textEncoder) + if paramEncoder, ok := f.(ParamEncoder); ok { + b.AppendEncoder(paramEncoder) } else { b.AppendValue(f) } @@ -88,15 +88,15 @@ func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) return nil, fmt.Errorf("Unknown OID for %#v", f) } - if binaryEncoder, ok := f.(BinaryEncoder); ok { - b.AppendEncoder(dt.OID, binaryEncoder) + if paramEncoder, ok := f.(ParamEncoder); ok { + b.AppendEncoder(dt.OID, paramEncoder) } else { err := dt.Value.Set(f) if err != nil { return nil, err } - if binaryEncoder, ok := dt.Value.(BinaryEncoder); ok { - b.AppendEncoder(dt.OID, binaryEncoder) + if paramEncoder, ok := dt.Value.(ParamEncoder); ok { + b.AppendEncoder(dt.OID, paramEncoder) } else { return nil, fmt.Errorf("Cannot encode binary format for %v", f) } diff --git a/composite_type.go b/composite_type.go index 90b7b6ff..85ab5910 100644 --- a/composite_type.go +++ b/composite_type.go @@ -91,9 +91,13 @@ func (ct *CompositeType) Fields() []CompositeTypeField { return ct.fields } +func (dst *CompositeType) setNil() { + dst.valid = false +} + func (dst *CompositeType) Set(src interface{}) error { if src == nil { - dst.valid = false + dst.setNil() return nil } @@ -110,7 +114,7 @@ func (dst *CompositeType) Set(src interface{}) error { dst.valid = true case *[]interface{}: if value == nil { - dst.valid = false + dst.setNil() return nil } return dst.Set(*value) @@ -213,6 +217,56 @@ func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) { return true, nil } +func (ct *CompositeType) BinaryFormatSupported() bool { + for _, vt := range ct.valueTranscoders { + if !vt.BinaryFormatSupported() { + return false + } + } + return true +} + +func (ct *CompositeType) TextFormatSupported() bool { + for _, vt := range ct.valueTranscoders { + if !vt.TextFormatSupported() { + return false + } + } + return true +} + +func (ct *CompositeType) PreferredFormat() int16 { + if ct.BinaryFormatSupported() { + return BinaryFormatCode + } + return TextFormatCode +} + +func (dst *CompositeType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + if src == nil { + dst.setNil() + return nil + } + + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src CompositeType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} + func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { if !src.valid { return nil, nil @@ -231,11 +285,6 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, // and decoding fails if SQL value can't be assigned due to // type mismatch func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { - if buf == nil { - dst.valid = false - return nil - } - scanner := NewCompositeBinaryScanner(ci, buf) for _, f := range dst.valueTranscoders { @@ -252,11 +301,6 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { } func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { - if buf == nil { - dst.valid = false - return nil - } - scanner := NewCompositeTextScanner(ci, buf) for _, f := range dst.valueTranscoders { @@ -315,13 +359,13 @@ func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner } // ScanDecoder calls Next and decodes the result with d. -func (cfs *CompositeBinaryScanner) ScanDecoder(d BinaryDecoder) { +func (cfs *CompositeBinaryScanner) ScanDecoder(d ResultDecoder) { if cfs.err != nil { return } if cfs.Next() { - cfs.err = d.DecodeBinary(cfs.ci, cfs.fieldBytes) + cfs.err = d.DecodeResult(cfs.ci, 0, BinaryFormatCode, cfs.fieldBytes) } else { cfs.err = errors.New("read past end of composite") } @@ -425,13 +469,13 @@ func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner { } // ScanDecoder calls Next and decodes the result with d. -func (cfs *CompositeTextScanner) ScanDecoder(d TextDecoder) { +func (cfs *CompositeTextScanner) ScanDecoder(d ResultDecoder) { if cfs.err != nil { return } if cfs.Next() { - cfs.err = d.DecodeText(cfs.ci, cfs.fieldBytes) + cfs.err = d.DecodeResult(cfs.ci, 0, TextFormatCode, cfs.fieldBytes) } else { cfs.err = errors.New("read past end of composite") } @@ -547,16 +591,16 @@ func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { return } - binaryEncoder, ok := dt.Value.(BinaryEncoder) + paramEncoder, ok := dt.Value.(ParamEncoder) if !ok { - b.err = fmt.Errorf("unable to encode binary for OID: %d", oid) + b.err = fmt.Errorf("unable to encode for OID: %d", oid) return } - b.AppendEncoder(oid, binaryEncoder) + b.AppendEncoder(oid, paramEncoder) } -func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder) { +func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field ParamEncoder) { if b.err != nil { return } @@ -564,7 +608,7 @@ func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder) b.buf = pgio.AppendUint32(b.buf, oid) lengthPos := len(b.buf) b.buf = pgio.AppendInt32(b.buf, -1) - fieldBuf, err := field.EncodeBinary(b.ci, b.buf) + fieldBuf, err := field.EncodeParam(b.ci, oid, BinaryFormatCode, b.buf) if err != nil { b.err = err return @@ -622,21 +666,21 @@ func (b *CompositeTextBuilder) AppendValue(field interface{}) { return } - textEncoder, ok := dt.Value.(TextEncoder) + paramEncoder, ok := dt.Value.(ParamEncoder) if !ok { - b.err = fmt.Errorf("unable to encode text for value: %v", field) + b.err = fmt.Errorf("unable to encode for value: %v", field) return } - b.AppendEncoder(textEncoder) + b.AppendEncoder(paramEncoder) } -func (b *CompositeTextBuilder) AppendEncoder(field TextEncoder) { +func (b *CompositeTextBuilder) AppendEncoder(field ParamEncoder) { if b.err != nil { return } - fieldBuf, err := field.EncodeText(b.ci, b.fieldBuf[0:0]) + fieldBuf, err := field.EncodeParam(b.ci, 0, TextFormatCode, b.fieldBuf[0:0]) if err != nil { b.err = err return diff --git a/pgtype.go b/pgtype.go index b9067fab..1705ae41 100644 --- a/pgtype.go +++ b/pgtype.go @@ -147,10 +147,9 @@ type TypeValue interface { // ValueTranscoder is a value that implements the text and binary encoding and decoding interfaces. type ValueTranscoder interface { Value - TextEncoder - BinaryEncoder - TextDecoder - BinaryDecoder + FormatSupport + ParamEncoder + ResultDecoder } type FormatSupport interface { @@ -160,12 +159,17 @@ type FormatSupport interface { } type ParamEncoder interface { - FormatSupport + // EncodeParam should append the encoded value of self to buf. If self is the + // SQL value NULL then append nothing and return (nil, nil). The caller of + // EncodeText is responsible for writing the correct NULL value or the + // length of the data written. EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) } type ResultDecoder interface { - FormatSupport + // DecodeResult decodes src into ResultDecoder. If src is nil then the + // original SQL value is NULL. ResultDecoder takes ownership of src. The + // caller MUST not use it again. DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error } From 550cc7b529735bfb4a31c413d6753110ce616368 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Dec 2021 12:53:20 -0600 Subject: [PATCH 0751/1158] wip --- pgtype.go | 40 +++++++--------------------------------- pgtype_test.go | 16 +++------------- 2 files changed, 10 insertions(+), 46 deletions(-) diff --git a/pgtype.go b/pgtype.go index 1705ae41..d8dd5abf 100644 --- a/pgtype.go +++ b/pgtype.go @@ -179,12 +179,6 @@ type ResultFormatPreferrer interface { PreferredResultFormat() int16 } -// ParamFormatPreferrer allows a type to specify its preferred param format instead of it being inferred from -// whether it is also a BinaryEncoder. -type ParamFormatPreferrer interface { - PreferredParamFormat() int16 -} - type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the // original SQL value is NULL. BinaryDecoder takes ownership of src. The @@ -243,7 +237,7 @@ type ConnInfo struct { oidToDataType map[uint32]*DataType nameToDataType map[string]*DataType reflectTypeToName map[reflect.Type]string - oidToParamFormatCode map[uint32]int16 + oidToFormatCode map[uint32]int16 oidToResultFormatCode map[uint32]int16 reflectTypeToDataType map[reflect.Type]*DataType @@ -256,7 +250,7 @@ func newConnInfo() *ConnInfo { oidToDataType: make(map[uint32]*DataType), nameToDataType: make(map[string]*DataType), reflectTypeToName: make(map[reflect.Type]string), - oidToParamFormatCode: make(map[uint32]int16), + oidToFormatCode: make(map[uint32]int16), oidToResultFormatCode: make(map[uint32]int16), preferAssignToOverSQLScannerTypes: make(map[reflect.Type]struct{}), } @@ -392,24 +386,12 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { { var formatCode int16 - if pfp, ok := t.Value.(ParamFormatPreferrer); ok { - formatCode = pfp.PreferredParamFormat() + if pfp, ok := t.Value.(FormatSupport); ok { + formatCode = pfp.PreferredFormat() } else if _, ok := t.Value.(BinaryEncoder); ok { formatCode = BinaryFormatCode } - ci.oidToParamFormatCode[t.OID] = formatCode - } - - { - var formatCode int16 - if fs, ok := t.Value.(FormatSupport); ok { - formatCode = fs.PreferredFormat() - } else if rfp, ok := t.Value.(ResultFormatPreferrer); ok { - formatCode = rfp.PreferredResultFormat() - } else if _, ok := t.Value.(BinaryDecoder); ok { - formatCode = BinaryFormatCode - } - ci.oidToResultFormatCode[t.OID] = formatCode + ci.oidToFormatCode[t.OID] = formatCode } if d, ok := t.Value.(ResultDecoder); ok { @@ -477,16 +459,8 @@ func (ci *ConnInfo) DataTypeForValue(v interface{}) (*DataType, bool) { return dt, ok } -func (ci *ConnInfo) ParamFormatCodeForOID(oid uint32) int16 { - fc, ok := ci.oidToParamFormatCode[oid] - if ok { - return fc - } - return TextFormatCode -} - -func (ci *ConnInfo) ResultFormatCodeForOID(oid uint32) int16 { - fc, ok := ci.oidToResultFormatCode[oid] +func (ci *ConnInfo) FormatCodeForOID(oid uint32) int16 { + fc, ok := ci.oidToFormatCode[oid] if ok { return fc } diff --git a/pgtype_test.go b/pgtype_test.go index 7ae756e5..9bf1f242 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -67,24 +67,14 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { return addr } -func TestConnInfoResultFormatCodeForOID(t *testing.T) { - ci := pgtype.NewConnInfo() - - // pgtype.JSONB implements BinaryDecoder but also implements ResultFormatPreferrer to override it to text. - assert.Equal(t, int16(pgtype.TextFormatCode), ci.ResultFormatCodeForOID(pgtype.JSONBOID)) - - // pgtype.Int4 implements BinaryDecoder but does not implement ResultFormatPreferrer so it should be binary. - assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ResultFormatCodeForOID(pgtype.Int4OID)) -} - -func TestConnInfoParamFormatCodeForOID(t *testing.T) { +func TestConnInfoFormatCodeForOID(t *testing.T) { ci := pgtype.NewConnInfo() // pgtype.JSONB implements BinaryEncoder but also implements ParamFormatPreferrer to override it to text. - assert.Equal(t, int16(pgtype.TextFormatCode), ci.ParamFormatCodeForOID(pgtype.JSONBOID)) + assert.Equal(t, int16(pgtype.TextFormatCode), ci.FormatCodeForOID(pgtype.JSONBOID)) // pgtype.Int4 implements BinaryEncoder but does not implement ParamFormatPreferrer so it should be binary. - assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ParamFormatCodeForOID(pgtype.Int4OID)) + assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.FormatCodeForOID(pgtype.Int4OID)) } func TestConnInfoScanNilIsNoOp(t *testing.T) { From 44214b78541d55dc777ff46694c7418ea1050331 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Dec 2021 13:07:54 -0600 Subject: [PATCH 0752/1158] Import to pgx main repo in pgtype subdir --- CHANGELOG.md => pgtype/CHANGELOG.md | 0 LICENSE => pgtype/LICENSE | 0 README.md => pgtype/README.md | 0 aclitem.go => pgtype/aclitem.go | 0 aclitem_array.go => pgtype/aclitem_array.go | 0 aclitem_array_test.go => pgtype/aclitem_array_test.go | 0 aclitem_test.go => pgtype/aclitem_test.go | 0 array.go => pgtype/array.go | 0 array_test.go => pgtype/array_test.go | 0 array_type.go => pgtype/array_type.go | 0 array_type_test.go => pgtype/array_type_test.go | 0 bit.go => pgtype/bit.go | 0 bit_test.go => pgtype/bit_test.go | 0 bool.go => pgtype/bool.go | 0 bool_array.go => pgtype/bool_array.go | 0 bool_array_test.go => pgtype/bool_array_test.go | 0 bool_test.go => pgtype/bool_test.go | 0 box.go => pgtype/box.go | 0 box_test.go => pgtype/box_test.go | 0 bpchar.go => pgtype/bpchar.go | 0 bpchar_array.go => pgtype/bpchar_array.go | 0 bpchar_array_test.go => pgtype/bpchar_array_test.go | 0 bpchar_test.go => pgtype/bpchar_test.go | 0 bytea.go => pgtype/bytea.go | 0 bytea_array.go => pgtype/bytea_array.go | 0 bytea_array_test.go => pgtype/bytea_array_test.go | 0 bytea_test.go => pgtype/bytea_test.go | 0 cid.go => pgtype/cid.go | 0 cid_test.go => pgtype/cid_test.go | 0 cidr.go => pgtype/cidr.go | 0 cidr_array.go => pgtype/cidr_array.go | 0 cidr_array_test.go => pgtype/cidr_array_test.go | 0 circle.go => pgtype/circle.go | 0 circle_test.go => pgtype/circle_test.go | 0 composite_bench_test.go => pgtype/composite_bench_test.go | 0 composite_fields.go => pgtype/composite_fields.go | 0 composite_fields_test.go => pgtype/composite_fields_test.go | 0 composite_type.go => pgtype/composite_type.go | 0 composite_type_test.go => pgtype/composite_type_test.go | 0 convert.go => pgtype/convert.go | 0 custom_composite_test.go => pgtype/custom_composite_test.go | 0 database_sql.go => pgtype/database_sql.go | 0 date.go => pgtype/date.go | 0 date_array.go => pgtype/date_array.go | 0 date_array_test.go => pgtype/date_array_test.go | 0 date_test.go => pgtype/date_test.go | 0 daterange.go => pgtype/daterange.go | 0 daterange_test.go => pgtype/daterange_test.go | 0 enum_array.go => pgtype/enum_array.go | 0 enum_array_test.go => pgtype/enum_array_test.go | 0 enum_type.go => pgtype/enum_type.go | 0 enum_type_test.go => pgtype/enum_type_test.go | 0 float4.go => pgtype/float4.go | 0 float4_array.go => pgtype/float4_array.go | 0 float4_array_test.go => pgtype/float4_array_test.go | 0 float4_test.go => pgtype/float4_test.go | 0 float8.go => pgtype/float8.go | 0 float8_array.go => pgtype/float8_array.go | 0 float8_array_test.go => pgtype/float8_array_test.go | 0 float8_test.go => pgtype/float8_test.go | 0 generic_binary.go => pgtype/generic_binary.go | 0 generic_text.go => pgtype/generic_text.go | 0 go.mod => pgtype/go.mod | 0 go.sum => pgtype/go.sum | 0 hstore.go => pgtype/hstore.go | 0 hstore_array.go => pgtype/hstore_array.go | 0 hstore_array_test.go => pgtype/hstore_array_test.go | 0 hstore_test.go => pgtype/hstore_test.go | 0 inet.go => pgtype/inet.go | 0 inet_array.go => pgtype/inet_array.go | 0 inet_array_test.go => pgtype/inet_array_test.go | 0 inet_test.go => pgtype/inet_test.go | 0 int2.go => pgtype/int2.go | 0 int2_array.go => pgtype/int2_array.go | 0 int2_array_test.go => pgtype/int2_array_test.go | 0 int2_test.go => pgtype/int2_test.go | 0 int4.go => pgtype/int4.go | 0 int4_array.go => pgtype/int4_array.go | 0 int4_array_test.go => pgtype/int4_array_test.go | 0 int4_test.go => pgtype/int4_test.go | 0 int4range.go => pgtype/int4range.go | 0 int4range_test.go => pgtype/int4range_test.go | 0 int8.go => pgtype/int8.go | 0 int8_array.go => pgtype/int8_array.go | 0 int8_array_test.go => pgtype/int8_array_test.go | 0 int8_test.go => pgtype/int8_test.go | 0 int8range.go => pgtype/int8range.go | 0 int8range_test.go => pgtype/int8range_test.go | 0 .../integration_benchmark_test.go | 0 .../integration_benchmark_test.go.erb | 0 .../integration_benchmark_test_gen.sh | 0 interval.go => pgtype/interval.go | 0 interval_test.go => pgtype/interval_test.go | 0 json.go => pgtype/json.go | 0 json_test.go => pgtype/json_test.go | 0 jsonb.go => pgtype/jsonb.go | 0 jsonb_array.go => pgtype/jsonb_array.go | 0 jsonb_array_test.go => pgtype/jsonb_array_test.go | 0 jsonb_test.go => pgtype/jsonb_test.go | 0 line.go => pgtype/line.go | 0 line_test.go => pgtype/line_test.go | 0 lseg.go => pgtype/lseg.go | 0 lseg_test.go => pgtype/lseg_test.go | 0 macaddr.go => pgtype/macaddr.go | 0 macaddr_array.go => pgtype/macaddr_array.go | 0 macaddr_array_test.go => pgtype/macaddr_array_test.go | 0 macaddr_test.go => pgtype/macaddr_test.go | 0 name.go => pgtype/name.go | 0 name_test.go => pgtype/name_test.go | 0 new_pg_value.erb => pgtype/new_pg_value.erb | 0 new_pg_value_gen.sh => pgtype/new_pg_value_gen.sh | 0 numeric.go => pgtype/numeric.go | 0 numeric_array.go => pgtype/numeric_array.go | 0 numeric_array_test.go => pgtype/numeric_array_test.go | 0 numeric_test.go => pgtype/numeric_test.go | 0 numrange.go => pgtype/numrange.go | 0 numrange_test.go => pgtype/numrange_test.go | 0 oid.go => pgtype/oid.go | 0 oid_value.go => pgtype/oid_value.go | 0 oid_value_test.go => pgtype/oid_value_test.go | 0 path.go => pgtype/path.go | 0 path_test.go => pgtype/path_test.go | 0 pgtype.go => pgtype/pgtype.go | 0 pgtype_test.go => pgtype/pgtype_test.go | 0 pguint32.go => pgtype/pguint32.go | 0 {pgxtype => pgtype/pgxtype}/README.md | 0 {pgxtype => pgtype/pgxtype}/pgxtype.go | 0 point.go => pgtype/point.go | 0 point_test.go => pgtype/point_test.go | 0 polygon.go => pgtype/polygon.go | 0 polygon_test.go => pgtype/polygon_test.go | 0 qchar.go => pgtype/qchar.go | 0 qchar_test.go => pgtype/qchar_test.go | 0 range.go => pgtype/range.go | 0 range_test.go => pgtype/range_test.go | 0 record.go => pgtype/record.go | 0 record_test.go => pgtype/record_test.go | 0 {testutil => pgtype/testutil}/testutil.go | 0 text.go => pgtype/text.go | 0 text_array.go => pgtype/text_array.go | 0 text_array_test.go => pgtype/text_array_test.go | 0 text_test.go => pgtype/text_test.go | 0 tid.go => pgtype/tid.go | 0 tid_test.go => pgtype/tid_test.go | 0 time.go => pgtype/time.go | 0 time_test.go => pgtype/time_test.go | 0 timestamp.go => pgtype/timestamp.go | 0 timestamp_array.go => pgtype/timestamp_array.go | 0 timestamp_array_test.go => pgtype/timestamp_array_test.go | 0 timestamp_test.go => pgtype/timestamp_test.go | 0 timestamptz.go => pgtype/timestamptz.go | 0 timestamptz_array.go => pgtype/timestamptz_array.go | 0 timestamptz_array_test.go => pgtype/timestamptz_array_test.go | 0 timestamptz_test.go => pgtype/timestamptz_test.go | 0 tsrange.go => pgtype/tsrange.go | 0 tsrange_array.go => pgtype/tsrange_array.go | 0 tsrange_test.go => pgtype/tsrange_test.go | 0 tstzrange.go => pgtype/tstzrange.go | 0 tstzrange_array.go => pgtype/tstzrange_array.go | 0 tstzrange_test.go => pgtype/tstzrange_test.go | 0 typed_array.go.erb => pgtype/typed_array.go.erb | 0 typed_array_gen.sh => pgtype/typed_array_gen.sh | 0 typed_range.go.erb => pgtype/typed_range.go.erb | 0 typed_range_gen.sh => pgtype/typed_range_gen.sh | 0 unknown.go => pgtype/unknown.go | 0 uuid.go => pgtype/uuid.go | 0 uuid_array.go => pgtype/uuid_array.go | 0 uuid_array_test.go => pgtype/uuid_array_test.go | 0 uuid_test.go => pgtype/uuid_test.go | 0 varbit.go => pgtype/varbit.go | 0 varbit_test.go => pgtype/varbit_test.go | 0 varchar.go => pgtype/varchar.go | 0 varchar_array.go => pgtype/varchar_array.go | 0 varchar_array_test.go => pgtype/varchar_array_test.go | 0 {.github => pgtype}/workflows/ci.yml | 0 xid.go => pgtype/xid.go | 0 xid_test.go => pgtype/xid_test.go | 0 {zeronull => pgtype/zeronull}/doc.go | 0 {zeronull => pgtype/zeronull}/float8.go | 0 {zeronull => pgtype/zeronull}/float8_test.go | 0 {zeronull => pgtype/zeronull}/int2.go | 0 {zeronull => pgtype/zeronull}/int2_test.go | 0 {zeronull => pgtype/zeronull}/int4.go | 0 {zeronull => pgtype/zeronull}/int4_test.go | 0 {zeronull => pgtype/zeronull}/int8.go | 0 {zeronull => pgtype/zeronull}/int8_test.go | 0 {zeronull => pgtype/zeronull}/text.go | 0 {zeronull => pgtype/zeronull}/text_test.go | 0 {zeronull => pgtype/zeronull}/timestamp.go | 0 {zeronull => pgtype/zeronull}/timestamp_test.go | 0 {zeronull => pgtype/zeronull}/timestamptz.go | 0 {zeronull => pgtype/zeronull}/timestamptz_test.go | 0 {zeronull => pgtype/zeronull}/uuid.go | 0 {zeronull => pgtype/zeronull}/uuid_test.go | 0 zzz.aclitem.go => pgtype/zzz.aclitem.go | 0 zzz.bit.go => pgtype/zzz.bit.go | 0 zzz.bool.go => pgtype/zzz.bool.go | 0 zzz.box.go => pgtype/zzz.box.go | 0 zzz.bpchar.go => pgtype/zzz.bpchar.go | 0 zzz.bytea.go => pgtype/zzz.bytea.go | 0 zzz.cid.go => pgtype/zzz.cid.go | 0 zzz.cidr.go => pgtype/zzz.cidr.go | 0 zzz.circle.go => pgtype/zzz.circle.go | 0 zzz.date.go => pgtype/zzz.date.go | 0 zzz.float4.go => pgtype/zzz.float4.go | 0 zzz.float8.go => pgtype/zzz.float8.go | 0 zzz.generic_binary.go => pgtype/zzz.generic_binary.go | 0 zzz.generic_text.go => pgtype/zzz.generic_text.go | 0 zzz.hstore.go => pgtype/zzz.hstore.go | 0 zzz.inet.go => pgtype/zzz.inet.go | 0 zzz.int2.go => pgtype/zzz.int2.go | 0 zzz.int4.go => pgtype/zzz.int4.go | 0 zzz.int8.go => pgtype/zzz.int8.go | 0 zzz.interval.go => pgtype/zzz.interval.go | 0 zzz.json.go => pgtype/zzz.json.go | 0 zzz.jsonb.go => pgtype/zzz.jsonb.go | 0 zzz.line.go => pgtype/zzz.line.go | 0 zzz.lseg.go => pgtype/zzz.lseg.go | 0 zzz.macadder.go => pgtype/zzz.macadder.go | 0 zzz.name.go => pgtype/zzz.name.go | 0 zzz.numeric.go => pgtype/zzz.numeric.go | 0 zzz.oid.go => pgtype/zzz.oid.go | 0 zzz.oid_value.go => pgtype/zzz.oid_value.go | 0 zzz.path.go => pgtype/zzz.path.go | 0 zzz.pguint32.go => pgtype/zzz.pguint32.go | 0 zzz.point.go => pgtype/zzz.point.go | 0 zzz.polygon.go => pgtype/zzz.polygon.go | 0 zzz.qchar.go => pgtype/zzz.qchar.go | 0 zzz.text.go => pgtype/zzz.text.go | 0 zzz.tid.go => pgtype/zzz.tid.go | 0 zzz.time.go => pgtype/zzz.time.go | 0 zzz.timestamp.go => pgtype/zzz.timestamp.go | 0 zzz.timestamptz.go => pgtype/zzz.timestamptz.go | 0 zzz.uuid.go => pgtype/zzz.uuid.go | 0 zzz.varbit.go => pgtype/zzz.varbit.go | 0 zzz.varchar.go => pgtype/zzz.varchar.go | 0 zzz.xid.go => pgtype/zzz.xid.go | 0 237 files changed, 0 insertions(+), 0 deletions(-) rename CHANGELOG.md => pgtype/CHANGELOG.md (100%) rename LICENSE => pgtype/LICENSE (100%) rename README.md => pgtype/README.md (100%) rename aclitem.go => pgtype/aclitem.go (100%) rename aclitem_array.go => pgtype/aclitem_array.go (100%) rename aclitem_array_test.go => pgtype/aclitem_array_test.go (100%) rename aclitem_test.go => pgtype/aclitem_test.go (100%) rename array.go => pgtype/array.go (100%) rename array_test.go => pgtype/array_test.go (100%) rename array_type.go => pgtype/array_type.go (100%) rename array_type_test.go => pgtype/array_type_test.go (100%) rename bit.go => pgtype/bit.go (100%) rename bit_test.go => pgtype/bit_test.go (100%) rename bool.go => pgtype/bool.go (100%) rename bool_array.go => pgtype/bool_array.go (100%) rename bool_array_test.go => pgtype/bool_array_test.go (100%) rename bool_test.go => pgtype/bool_test.go (100%) rename box.go => pgtype/box.go (100%) rename box_test.go => pgtype/box_test.go (100%) rename bpchar.go => pgtype/bpchar.go (100%) rename bpchar_array.go => pgtype/bpchar_array.go (100%) rename bpchar_array_test.go => pgtype/bpchar_array_test.go (100%) rename bpchar_test.go => pgtype/bpchar_test.go (100%) rename bytea.go => pgtype/bytea.go (100%) rename bytea_array.go => pgtype/bytea_array.go (100%) rename bytea_array_test.go => pgtype/bytea_array_test.go (100%) rename bytea_test.go => pgtype/bytea_test.go (100%) rename cid.go => pgtype/cid.go (100%) rename cid_test.go => pgtype/cid_test.go (100%) rename cidr.go => pgtype/cidr.go (100%) rename cidr_array.go => pgtype/cidr_array.go (100%) rename cidr_array_test.go => pgtype/cidr_array_test.go (100%) rename circle.go => pgtype/circle.go (100%) rename circle_test.go => pgtype/circle_test.go (100%) rename composite_bench_test.go => pgtype/composite_bench_test.go (100%) rename composite_fields.go => pgtype/composite_fields.go (100%) rename composite_fields_test.go => pgtype/composite_fields_test.go (100%) rename composite_type.go => pgtype/composite_type.go (100%) rename composite_type_test.go => pgtype/composite_type_test.go (100%) rename convert.go => pgtype/convert.go (100%) rename custom_composite_test.go => pgtype/custom_composite_test.go (100%) rename database_sql.go => pgtype/database_sql.go (100%) rename date.go => pgtype/date.go (100%) rename date_array.go => pgtype/date_array.go (100%) rename date_array_test.go => pgtype/date_array_test.go (100%) rename date_test.go => pgtype/date_test.go (100%) rename daterange.go => pgtype/daterange.go (100%) rename daterange_test.go => pgtype/daterange_test.go (100%) rename enum_array.go => pgtype/enum_array.go (100%) rename enum_array_test.go => pgtype/enum_array_test.go (100%) rename enum_type.go => pgtype/enum_type.go (100%) rename enum_type_test.go => pgtype/enum_type_test.go (100%) rename float4.go => pgtype/float4.go (100%) rename float4_array.go => pgtype/float4_array.go (100%) rename float4_array_test.go => pgtype/float4_array_test.go (100%) rename float4_test.go => pgtype/float4_test.go (100%) rename float8.go => pgtype/float8.go (100%) rename float8_array.go => pgtype/float8_array.go (100%) rename float8_array_test.go => pgtype/float8_array_test.go (100%) rename float8_test.go => pgtype/float8_test.go (100%) rename generic_binary.go => pgtype/generic_binary.go (100%) rename generic_text.go => pgtype/generic_text.go (100%) rename go.mod => pgtype/go.mod (100%) rename go.sum => pgtype/go.sum (100%) rename hstore.go => pgtype/hstore.go (100%) rename hstore_array.go => pgtype/hstore_array.go (100%) rename hstore_array_test.go => pgtype/hstore_array_test.go (100%) rename hstore_test.go => pgtype/hstore_test.go (100%) rename inet.go => pgtype/inet.go (100%) rename inet_array.go => pgtype/inet_array.go (100%) rename inet_array_test.go => pgtype/inet_array_test.go (100%) rename inet_test.go => pgtype/inet_test.go (100%) rename int2.go => pgtype/int2.go (100%) rename int2_array.go => pgtype/int2_array.go (100%) rename int2_array_test.go => pgtype/int2_array_test.go (100%) rename int2_test.go => pgtype/int2_test.go (100%) rename int4.go => pgtype/int4.go (100%) rename int4_array.go => pgtype/int4_array.go (100%) rename int4_array_test.go => pgtype/int4_array_test.go (100%) rename int4_test.go => pgtype/int4_test.go (100%) rename int4range.go => pgtype/int4range.go (100%) rename int4range_test.go => pgtype/int4range_test.go (100%) rename int8.go => pgtype/int8.go (100%) rename int8_array.go => pgtype/int8_array.go (100%) rename int8_array_test.go => pgtype/int8_array_test.go (100%) rename int8_test.go => pgtype/int8_test.go (100%) rename int8range.go => pgtype/int8range.go (100%) rename int8range_test.go => pgtype/int8range_test.go (100%) rename integration_benchmark_test.go => pgtype/integration_benchmark_test.go (100%) rename integration_benchmark_test.go.erb => pgtype/integration_benchmark_test.go.erb (100%) rename integration_benchmark_test_gen.sh => pgtype/integration_benchmark_test_gen.sh (100%) rename interval.go => pgtype/interval.go (100%) rename interval_test.go => pgtype/interval_test.go (100%) rename json.go => pgtype/json.go (100%) rename json_test.go => pgtype/json_test.go (100%) rename jsonb.go => pgtype/jsonb.go (100%) rename jsonb_array.go => pgtype/jsonb_array.go (100%) rename jsonb_array_test.go => pgtype/jsonb_array_test.go (100%) rename jsonb_test.go => pgtype/jsonb_test.go (100%) rename line.go => pgtype/line.go (100%) rename line_test.go => pgtype/line_test.go (100%) rename lseg.go => pgtype/lseg.go (100%) rename lseg_test.go => pgtype/lseg_test.go (100%) rename macaddr.go => pgtype/macaddr.go (100%) rename macaddr_array.go => pgtype/macaddr_array.go (100%) rename macaddr_array_test.go => pgtype/macaddr_array_test.go (100%) rename macaddr_test.go => pgtype/macaddr_test.go (100%) rename name.go => pgtype/name.go (100%) rename name_test.go => pgtype/name_test.go (100%) rename new_pg_value.erb => pgtype/new_pg_value.erb (100%) rename new_pg_value_gen.sh => pgtype/new_pg_value_gen.sh (100%) rename numeric.go => pgtype/numeric.go (100%) rename numeric_array.go => pgtype/numeric_array.go (100%) rename numeric_array_test.go => pgtype/numeric_array_test.go (100%) rename numeric_test.go => pgtype/numeric_test.go (100%) rename numrange.go => pgtype/numrange.go (100%) rename numrange_test.go => pgtype/numrange_test.go (100%) rename oid.go => pgtype/oid.go (100%) rename oid_value.go => pgtype/oid_value.go (100%) rename oid_value_test.go => pgtype/oid_value_test.go (100%) rename path.go => pgtype/path.go (100%) rename path_test.go => pgtype/path_test.go (100%) rename pgtype.go => pgtype/pgtype.go (100%) rename pgtype_test.go => pgtype/pgtype_test.go (100%) rename pguint32.go => pgtype/pguint32.go (100%) rename {pgxtype => pgtype/pgxtype}/README.md (100%) rename {pgxtype => pgtype/pgxtype}/pgxtype.go (100%) rename point.go => pgtype/point.go (100%) rename point_test.go => pgtype/point_test.go (100%) rename polygon.go => pgtype/polygon.go (100%) rename polygon_test.go => pgtype/polygon_test.go (100%) rename qchar.go => pgtype/qchar.go (100%) rename qchar_test.go => pgtype/qchar_test.go (100%) rename range.go => pgtype/range.go (100%) rename range_test.go => pgtype/range_test.go (100%) rename record.go => pgtype/record.go (100%) rename record_test.go => pgtype/record_test.go (100%) rename {testutil => pgtype/testutil}/testutil.go (100%) rename text.go => pgtype/text.go (100%) rename text_array.go => pgtype/text_array.go (100%) rename text_array_test.go => pgtype/text_array_test.go (100%) rename text_test.go => pgtype/text_test.go (100%) rename tid.go => pgtype/tid.go (100%) rename tid_test.go => pgtype/tid_test.go (100%) rename time.go => pgtype/time.go (100%) rename time_test.go => pgtype/time_test.go (100%) rename timestamp.go => pgtype/timestamp.go (100%) rename timestamp_array.go => pgtype/timestamp_array.go (100%) rename timestamp_array_test.go => pgtype/timestamp_array_test.go (100%) rename timestamp_test.go => pgtype/timestamp_test.go (100%) rename timestamptz.go => pgtype/timestamptz.go (100%) rename timestamptz_array.go => pgtype/timestamptz_array.go (100%) rename timestamptz_array_test.go => pgtype/timestamptz_array_test.go (100%) rename timestamptz_test.go => pgtype/timestamptz_test.go (100%) rename tsrange.go => pgtype/tsrange.go (100%) rename tsrange_array.go => pgtype/tsrange_array.go (100%) rename tsrange_test.go => pgtype/tsrange_test.go (100%) rename tstzrange.go => pgtype/tstzrange.go (100%) rename tstzrange_array.go => pgtype/tstzrange_array.go (100%) rename tstzrange_test.go => pgtype/tstzrange_test.go (100%) rename typed_array.go.erb => pgtype/typed_array.go.erb (100%) rename typed_array_gen.sh => pgtype/typed_array_gen.sh (100%) rename typed_range.go.erb => pgtype/typed_range.go.erb (100%) rename typed_range_gen.sh => pgtype/typed_range_gen.sh (100%) rename unknown.go => pgtype/unknown.go (100%) rename uuid.go => pgtype/uuid.go (100%) rename uuid_array.go => pgtype/uuid_array.go (100%) rename uuid_array_test.go => pgtype/uuid_array_test.go (100%) rename uuid_test.go => pgtype/uuid_test.go (100%) rename varbit.go => pgtype/varbit.go (100%) rename varbit_test.go => pgtype/varbit_test.go (100%) rename varchar.go => pgtype/varchar.go (100%) rename varchar_array.go => pgtype/varchar_array.go (100%) rename varchar_array_test.go => pgtype/varchar_array_test.go (100%) rename {.github => pgtype}/workflows/ci.yml (100%) rename xid.go => pgtype/xid.go (100%) rename xid_test.go => pgtype/xid_test.go (100%) rename {zeronull => pgtype/zeronull}/doc.go (100%) rename {zeronull => pgtype/zeronull}/float8.go (100%) rename {zeronull => pgtype/zeronull}/float8_test.go (100%) rename {zeronull => pgtype/zeronull}/int2.go (100%) rename {zeronull => pgtype/zeronull}/int2_test.go (100%) rename {zeronull => pgtype/zeronull}/int4.go (100%) rename {zeronull => pgtype/zeronull}/int4_test.go (100%) rename {zeronull => pgtype/zeronull}/int8.go (100%) rename {zeronull => pgtype/zeronull}/int8_test.go (100%) rename {zeronull => pgtype/zeronull}/text.go (100%) rename {zeronull => pgtype/zeronull}/text_test.go (100%) rename {zeronull => pgtype/zeronull}/timestamp.go (100%) rename {zeronull => pgtype/zeronull}/timestamp_test.go (100%) rename {zeronull => pgtype/zeronull}/timestamptz.go (100%) rename {zeronull => pgtype/zeronull}/timestamptz_test.go (100%) rename {zeronull => pgtype/zeronull}/uuid.go (100%) rename {zeronull => pgtype/zeronull}/uuid_test.go (100%) rename zzz.aclitem.go => pgtype/zzz.aclitem.go (100%) rename zzz.bit.go => pgtype/zzz.bit.go (100%) rename zzz.bool.go => pgtype/zzz.bool.go (100%) rename zzz.box.go => pgtype/zzz.box.go (100%) rename zzz.bpchar.go => pgtype/zzz.bpchar.go (100%) rename zzz.bytea.go => pgtype/zzz.bytea.go (100%) rename zzz.cid.go => pgtype/zzz.cid.go (100%) rename zzz.cidr.go => pgtype/zzz.cidr.go (100%) rename zzz.circle.go => pgtype/zzz.circle.go (100%) rename zzz.date.go => pgtype/zzz.date.go (100%) rename zzz.float4.go => pgtype/zzz.float4.go (100%) rename zzz.float8.go => pgtype/zzz.float8.go (100%) rename zzz.generic_binary.go => pgtype/zzz.generic_binary.go (100%) rename zzz.generic_text.go => pgtype/zzz.generic_text.go (100%) rename zzz.hstore.go => pgtype/zzz.hstore.go (100%) rename zzz.inet.go => pgtype/zzz.inet.go (100%) rename zzz.int2.go => pgtype/zzz.int2.go (100%) rename zzz.int4.go => pgtype/zzz.int4.go (100%) rename zzz.int8.go => pgtype/zzz.int8.go (100%) rename zzz.interval.go => pgtype/zzz.interval.go (100%) rename zzz.json.go => pgtype/zzz.json.go (100%) rename zzz.jsonb.go => pgtype/zzz.jsonb.go (100%) rename zzz.line.go => pgtype/zzz.line.go (100%) rename zzz.lseg.go => pgtype/zzz.lseg.go (100%) rename zzz.macadder.go => pgtype/zzz.macadder.go (100%) rename zzz.name.go => pgtype/zzz.name.go (100%) rename zzz.numeric.go => pgtype/zzz.numeric.go (100%) rename zzz.oid.go => pgtype/zzz.oid.go (100%) rename zzz.oid_value.go => pgtype/zzz.oid_value.go (100%) rename zzz.path.go => pgtype/zzz.path.go (100%) rename zzz.pguint32.go => pgtype/zzz.pguint32.go (100%) rename zzz.point.go => pgtype/zzz.point.go (100%) rename zzz.polygon.go => pgtype/zzz.polygon.go (100%) rename zzz.qchar.go => pgtype/zzz.qchar.go (100%) rename zzz.text.go => pgtype/zzz.text.go (100%) rename zzz.tid.go => pgtype/zzz.tid.go (100%) rename zzz.time.go => pgtype/zzz.time.go (100%) rename zzz.timestamp.go => pgtype/zzz.timestamp.go (100%) rename zzz.timestamptz.go => pgtype/zzz.timestamptz.go (100%) rename zzz.uuid.go => pgtype/zzz.uuid.go (100%) rename zzz.varbit.go => pgtype/zzz.varbit.go (100%) rename zzz.varchar.go => pgtype/zzz.varchar.go (100%) rename zzz.xid.go => pgtype/zzz.xid.go (100%) diff --git a/CHANGELOG.md b/pgtype/CHANGELOG.md similarity index 100% rename from CHANGELOG.md rename to pgtype/CHANGELOG.md diff --git a/LICENSE b/pgtype/LICENSE similarity index 100% rename from LICENSE rename to pgtype/LICENSE diff --git a/README.md b/pgtype/README.md similarity index 100% rename from README.md rename to pgtype/README.md diff --git a/aclitem.go b/pgtype/aclitem.go similarity index 100% rename from aclitem.go rename to pgtype/aclitem.go diff --git a/aclitem_array.go b/pgtype/aclitem_array.go similarity index 100% rename from aclitem_array.go rename to pgtype/aclitem_array.go diff --git a/aclitem_array_test.go b/pgtype/aclitem_array_test.go similarity index 100% rename from aclitem_array_test.go rename to pgtype/aclitem_array_test.go diff --git a/aclitem_test.go b/pgtype/aclitem_test.go similarity index 100% rename from aclitem_test.go rename to pgtype/aclitem_test.go diff --git a/array.go b/pgtype/array.go similarity index 100% rename from array.go rename to pgtype/array.go diff --git a/array_test.go b/pgtype/array_test.go similarity index 100% rename from array_test.go rename to pgtype/array_test.go diff --git a/array_type.go b/pgtype/array_type.go similarity index 100% rename from array_type.go rename to pgtype/array_type.go diff --git a/array_type_test.go b/pgtype/array_type_test.go similarity index 100% rename from array_type_test.go rename to pgtype/array_type_test.go diff --git a/bit.go b/pgtype/bit.go similarity index 100% rename from bit.go rename to pgtype/bit.go diff --git a/bit_test.go b/pgtype/bit_test.go similarity index 100% rename from bit_test.go rename to pgtype/bit_test.go diff --git a/bool.go b/pgtype/bool.go similarity index 100% rename from bool.go rename to pgtype/bool.go diff --git a/bool_array.go b/pgtype/bool_array.go similarity index 100% rename from bool_array.go rename to pgtype/bool_array.go diff --git a/bool_array_test.go b/pgtype/bool_array_test.go similarity index 100% rename from bool_array_test.go rename to pgtype/bool_array_test.go diff --git a/bool_test.go b/pgtype/bool_test.go similarity index 100% rename from bool_test.go rename to pgtype/bool_test.go diff --git a/box.go b/pgtype/box.go similarity index 100% rename from box.go rename to pgtype/box.go diff --git a/box_test.go b/pgtype/box_test.go similarity index 100% rename from box_test.go rename to pgtype/box_test.go diff --git a/bpchar.go b/pgtype/bpchar.go similarity index 100% rename from bpchar.go rename to pgtype/bpchar.go diff --git a/bpchar_array.go b/pgtype/bpchar_array.go similarity index 100% rename from bpchar_array.go rename to pgtype/bpchar_array.go diff --git a/bpchar_array_test.go b/pgtype/bpchar_array_test.go similarity index 100% rename from bpchar_array_test.go rename to pgtype/bpchar_array_test.go diff --git a/bpchar_test.go b/pgtype/bpchar_test.go similarity index 100% rename from bpchar_test.go rename to pgtype/bpchar_test.go diff --git a/bytea.go b/pgtype/bytea.go similarity index 100% rename from bytea.go rename to pgtype/bytea.go diff --git a/bytea_array.go b/pgtype/bytea_array.go similarity index 100% rename from bytea_array.go rename to pgtype/bytea_array.go diff --git a/bytea_array_test.go b/pgtype/bytea_array_test.go similarity index 100% rename from bytea_array_test.go rename to pgtype/bytea_array_test.go diff --git a/bytea_test.go b/pgtype/bytea_test.go similarity index 100% rename from bytea_test.go rename to pgtype/bytea_test.go diff --git a/cid.go b/pgtype/cid.go similarity index 100% rename from cid.go rename to pgtype/cid.go diff --git a/cid_test.go b/pgtype/cid_test.go similarity index 100% rename from cid_test.go rename to pgtype/cid_test.go diff --git a/cidr.go b/pgtype/cidr.go similarity index 100% rename from cidr.go rename to pgtype/cidr.go diff --git a/cidr_array.go b/pgtype/cidr_array.go similarity index 100% rename from cidr_array.go rename to pgtype/cidr_array.go diff --git a/cidr_array_test.go b/pgtype/cidr_array_test.go similarity index 100% rename from cidr_array_test.go rename to pgtype/cidr_array_test.go diff --git a/circle.go b/pgtype/circle.go similarity index 100% rename from circle.go rename to pgtype/circle.go diff --git a/circle_test.go b/pgtype/circle_test.go similarity index 100% rename from circle_test.go rename to pgtype/circle_test.go diff --git a/composite_bench_test.go b/pgtype/composite_bench_test.go similarity index 100% rename from composite_bench_test.go rename to pgtype/composite_bench_test.go diff --git a/composite_fields.go b/pgtype/composite_fields.go similarity index 100% rename from composite_fields.go rename to pgtype/composite_fields.go diff --git a/composite_fields_test.go b/pgtype/composite_fields_test.go similarity index 100% rename from composite_fields_test.go rename to pgtype/composite_fields_test.go diff --git a/composite_type.go b/pgtype/composite_type.go similarity index 100% rename from composite_type.go rename to pgtype/composite_type.go diff --git a/composite_type_test.go b/pgtype/composite_type_test.go similarity index 100% rename from composite_type_test.go rename to pgtype/composite_type_test.go diff --git a/convert.go b/pgtype/convert.go similarity index 100% rename from convert.go rename to pgtype/convert.go diff --git a/custom_composite_test.go b/pgtype/custom_composite_test.go similarity index 100% rename from custom_composite_test.go rename to pgtype/custom_composite_test.go diff --git a/database_sql.go b/pgtype/database_sql.go similarity index 100% rename from database_sql.go rename to pgtype/database_sql.go diff --git a/date.go b/pgtype/date.go similarity index 100% rename from date.go rename to pgtype/date.go diff --git a/date_array.go b/pgtype/date_array.go similarity index 100% rename from date_array.go rename to pgtype/date_array.go diff --git a/date_array_test.go b/pgtype/date_array_test.go similarity index 100% rename from date_array_test.go rename to pgtype/date_array_test.go diff --git a/date_test.go b/pgtype/date_test.go similarity index 100% rename from date_test.go rename to pgtype/date_test.go diff --git a/daterange.go b/pgtype/daterange.go similarity index 100% rename from daterange.go rename to pgtype/daterange.go diff --git a/daterange_test.go b/pgtype/daterange_test.go similarity index 100% rename from daterange_test.go rename to pgtype/daterange_test.go diff --git a/enum_array.go b/pgtype/enum_array.go similarity index 100% rename from enum_array.go rename to pgtype/enum_array.go diff --git a/enum_array_test.go b/pgtype/enum_array_test.go similarity index 100% rename from enum_array_test.go rename to pgtype/enum_array_test.go diff --git a/enum_type.go b/pgtype/enum_type.go similarity index 100% rename from enum_type.go rename to pgtype/enum_type.go diff --git a/enum_type_test.go b/pgtype/enum_type_test.go similarity index 100% rename from enum_type_test.go rename to pgtype/enum_type_test.go diff --git a/float4.go b/pgtype/float4.go similarity index 100% rename from float4.go rename to pgtype/float4.go diff --git a/float4_array.go b/pgtype/float4_array.go similarity index 100% rename from float4_array.go rename to pgtype/float4_array.go diff --git a/float4_array_test.go b/pgtype/float4_array_test.go similarity index 100% rename from float4_array_test.go rename to pgtype/float4_array_test.go diff --git a/float4_test.go b/pgtype/float4_test.go similarity index 100% rename from float4_test.go rename to pgtype/float4_test.go diff --git a/float8.go b/pgtype/float8.go similarity index 100% rename from float8.go rename to pgtype/float8.go diff --git a/float8_array.go b/pgtype/float8_array.go similarity index 100% rename from float8_array.go rename to pgtype/float8_array.go diff --git a/float8_array_test.go b/pgtype/float8_array_test.go similarity index 100% rename from float8_array_test.go rename to pgtype/float8_array_test.go diff --git a/float8_test.go b/pgtype/float8_test.go similarity index 100% rename from float8_test.go rename to pgtype/float8_test.go diff --git a/generic_binary.go b/pgtype/generic_binary.go similarity index 100% rename from generic_binary.go rename to pgtype/generic_binary.go diff --git a/generic_text.go b/pgtype/generic_text.go similarity index 100% rename from generic_text.go rename to pgtype/generic_text.go diff --git a/go.mod b/pgtype/go.mod similarity index 100% rename from go.mod rename to pgtype/go.mod diff --git a/go.sum b/pgtype/go.sum similarity index 100% rename from go.sum rename to pgtype/go.sum diff --git a/hstore.go b/pgtype/hstore.go similarity index 100% rename from hstore.go rename to pgtype/hstore.go diff --git a/hstore_array.go b/pgtype/hstore_array.go similarity index 100% rename from hstore_array.go rename to pgtype/hstore_array.go diff --git a/hstore_array_test.go b/pgtype/hstore_array_test.go similarity index 100% rename from hstore_array_test.go rename to pgtype/hstore_array_test.go diff --git a/hstore_test.go b/pgtype/hstore_test.go similarity index 100% rename from hstore_test.go rename to pgtype/hstore_test.go diff --git a/inet.go b/pgtype/inet.go similarity index 100% rename from inet.go rename to pgtype/inet.go diff --git a/inet_array.go b/pgtype/inet_array.go similarity index 100% rename from inet_array.go rename to pgtype/inet_array.go diff --git a/inet_array_test.go b/pgtype/inet_array_test.go similarity index 100% rename from inet_array_test.go rename to pgtype/inet_array_test.go diff --git a/inet_test.go b/pgtype/inet_test.go similarity index 100% rename from inet_test.go rename to pgtype/inet_test.go diff --git a/int2.go b/pgtype/int2.go similarity index 100% rename from int2.go rename to pgtype/int2.go diff --git a/int2_array.go b/pgtype/int2_array.go similarity index 100% rename from int2_array.go rename to pgtype/int2_array.go diff --git a/int2_array_test.go b/pgtype/int2_array_test.go similarity index 100% rename from int2_array_test.go rename to pgtype/int2_array_test.go diff --git a/int2_test.go b/pgtype/int2_test.go similarity index 100% rename from int2_test.go rename to pgtype/int2_test.go diff --git a/int4.go b/pgtype/int4.go similarity index 100% rename from int4.go rename to pgtype/int4.go diff --git a/int4_array.go b/pgtype/int4_array.go similarity index 100% rename from int4_array.go rename to pgtype/int4_array.go diff --git a/int4_array_test.go b/pgtype/int4_array_test.go similarity index 100% rename from int4_array_test.go rename to pgtype/int4_array_test.go diff --git a/int4_test.go b/pgtype/int4_test.go similarity index 100% rename from int4_test.go rename to pgtype/int4_test.go diff --git a/int4range.go b/pgtype/int4range.go similarity index 100% rename from int4range.go rename to pgtype/int4range.go diff --git a/int4range_test.go b/pgtype/int4range_test.go similarity index 100% rename from int4range_test.go rename to pgtype/int4range_test.go diff --git a/int8.go b/pgtype/int8.go similarity index 100% rename from int8.go rename to pgtype/int8.go diff --git a/int8_array.go b/pgtype/int8_array.go similarity index 100% rename from int8_array.go rename to pgtype/int8_array.go diff --git a/int8_array_test.go b/pgtype/int8_array_test.go similarity index 100% rename from int8_array_test.go rename to pgtype/int8_array_test.go diff --git a/int8_test.go b/pgtype/int8_test.go similarity index 100% rename from int8_test.go rename to pgtype/int8_test.go diff --git a/int8range.go b/pgtype/int8range.go similarity index 100% rename from int8range.go rename to pgtype/int8range.go diff --git a/int8range_test.go b/pgtype/int8range_test.go similarity index 100% rename from int8range_test.go rename to pgtype/int8range_test.go diff --git a/integration_benchmark_test.go b/pgtype/integration_benchmark_test.go similarity index 100% rename from integration_benchmark_test.go rename to pgtype/integration_benchmark_test.go diff --git a/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb similarity index 100% rename from integration_benchmark_test.go.erb rename to pgtype/integration_benchmark_test.go.erb diff --git a/integration_benchmark_test_gen.sh b/pgtype/integration_benchmark_test_gen.sh similarity index 100% rename from integration_benchmark_test_gen.sh rename to pgtype/integration_benchmark_test_gen.sh diff --git a/interval.go b/pgtype/interval.go similarity index 100% rename from interval.go rename to pgtype/interval.go diff --git a/interval_test.go b/pgtype/interval_test.go similarity index 100% rename from interval_test.go rename to pgtype/interval_test.go diff --git a/json.go b/pgtype/json.go similarity index 100% rename from json.go rename to pgtype/json.go diff --git a/json_test.go b/pgtype/json_test.go similarity index 100% rename from json_test.go rename to pgtype/json_test.go diff --git a/jsonb.go b/pgtype/jsonb.go similarity index 100% rename from jsonb.go rename to pgtype/jsonb.go diff --git a/jsonb_array.go b/pgtype/jsonb_array.go similarity index 100% rename from jsonb_array.go rename to pgtype/jsonb_array.go diff --git a/jsonb_array_test.go b/pgtype/jsonb_array_test.go similarity index 100% rename from jsonb_array_test.go rename to pgtype/jsonb_array_test.go diff --git a/jsonb_test.go b/pgtype/jsonb_test.go similarity index 100% rename from jsonb_test.go rename to pgtype/jsonb_test.go diff --git a/line.go b/pgtype/line.go similarity index 100% rename from line.go rename to pgtype/line.go diff --git a/line_test.go b/pgtype/line_test.go similarity index 100% rename from line_test.go rename to pgtype/line_test.go diff --git a/lseg.go b/pgtype/lseg.go similarity index 100% rename from lseg.go rename to pgtype/lseg.go diff --git a/lseg_test.go b/pgtype/lseg_test.go similarity index 100% rename from lseg_test.go rename to pgtype/lseg_test.go diff --git a/macaddr.go b/pgtype/macaddr.go similarity index 100% rename from macaddr.go rename to pgtype/macaddr.go diff --git a/macaddr_array.go b/pgtype/macaddr_array.go similarity index 100% rename from macaddr_array.go rename to pgtype/macaddr_array.go diff --git a/macaddr_array_test.go b/pgtype/macaddr_array_test.go similarity index 100% rename from macaddr_array_test.go rename to pgtype/macaddr_array_test.go diff --git a/macaddr_test.go b/pgtype/macaddr_test.go similarity index 100% rename from macaddr_test.go rename to pgtype/macaddr_test.go diff --git a/name.go b/pgtype/name.go similarity index 100% rename from name.go rename to pgtype/name.go diff --git a/name_test.go b/pgtype/name_test.go similarity index 100% rename from name_test.go rename to pgtype/name_test.go diff --git a/new_pg_value.erb b/pgtype/new_pg_value.erb similarity index 100% rename from new_pg_value.erb rename to pgtype/new_pg_value.erb diff --git a/new_pg_value_gen.sh b/pgtype/new_pg_value_gen.sh similarity index 100% rename from new_pg_value_gen.sh rename to pgtype/new_pg_value_gen.sh diff --git a/numeric.go b/pgtype/numeric.go similarity index 100% rename from numeric.go rename to pgtype/numeric.go diff --git a/numeric_array.go b/pgtype/numeric_array.go similarity index 100% rename from numeric_array.go rename to pgtype/numeric_array.go diff --git a/numeric_array_test.go b/pgtype/numeric_array_test.go similarity index 100% rename from numeric_array_test.go rename to pgtype/numeric_array_test.go diff --git a/numeric_test.go b/pgtype/numeric_test.go similarity index 100% rename from numeric_test.go rename to pgtype/numeric_test.go diff --git a/numrange.go b/pgtype/numrange.go similarity index 100% rename from numrange.go rename to pgtype/numrange.go diff --git a/numrange_test.go b/pgtype/numrange_test.go similarity index 100% rename from numrange_test.go rename to pgtype/numrange_test.go diff --git a/oid.go b/pgtype/oid.go similarity index 100% rename from oid.go rename to pgtype/oid.go diff --git a/oid_value.go b/pgtype/oid_value.go similarity index 100% rename from oid_value.go rename to pgtype/oid_value.go diff --git a/oid_value_test.go b/pgtype/oid_value_test.go similarity index 100% rename from oid_value_test.go rename to pgtype/oid_value_test.go diff --git a/path.go b/pgtype/path.go similarity index 100% rename from path.go rename to pgtype/path.go diff --git a/path_test.go b/pgtype/path_test.go similarity index 100% rename from path_test.go rename to pgtype/path_test.go diff --git a/pgtype.go b/pgtype/pgtype.go similarity index 100% rename from pgtype.go rename to pgtype/pgtype.go diff --git a/pgtype_test.go b/pgtype/pgtype_test.go similarity index 100% rename from pgtype_test.go rename to pgtype/pgtype_test.go diff --git a/pguint32.go b/pgtype/pguint32.go similarity index 100% rename from pguint32.go rename to pgtype/pguint32.go diff --git a/pgxtype/README.md b/pgtype/pgxtype/README.md similarity index 100% rename from pgxtype/README.md rename to pgtype/pgxtype/README.md diff --git a/pgxtype/pgxtype.go b/pgtype/pgxtype/pgxtype.go similarity index 100% rename from pgxtype/pgxtype.go rename to pgtype/pgxtype/pgxtype.go diff --git a/point.go b/pgtype/point.go similarity index 100% rename from point.go rename to pgtype/point.go diff --git a/point_test.go b/pgtype/point_test.go similarity index 100% rename from point_test.go rename to pgtype/point_test.go diff --git a/polygon.go b/pgtype/polygon.go similarity index 100% rename from polygon.go rename to pgtype/polygon.go diff --git a/polygon_test.go b/pgtype/polygon_test.go similarity index 100% rename from polygon_test.go rename to pgtype/polygon_test.go diff --git a/qchar.go b/pgtype/qchar.go similarity index 100% rename from qchar.go rename to pgtype/qchar.go diff --git a/qchar_test.go b/pgtype/qchar_test.go similarity index 100% rename from qchar_test.go rename to pgtype/qchar_test.go diff --git a/range.go b/pgtype/range.go similarity index 100% rename from range.go rename to pgtype/range.go diff --git a/range_test.go b/pgtype/range_test.go similarity index 100% rename from range_test.go rename to pgtype/range_test.go diff --git a/record.go b/pgtype/record.go similarity index 100% rename from record.go rename to pgtype/record.go diff --git a/record_test.go b/pgtype/record_test.go similarity index 100% rename from record_test.go rename to pgtype/record_test.go diff --git a/testutil/testutil.go b/pgtype/testutil/testutil.go similarity index 100% rename from testutil/testutil.go rename to pgtype/testutil/testutil.go diff --git a/text.go b/pgtype/text.go similarity index 100% rename from text.go rename to pgtype/text.go diff --git a/text_array.go b/pgtype/text_array.go similarity index 100% rename from text_array.go rename to pgtype/text_array.go diff --git a/text_array_test.go b/pgtype/text_array_test.go similarity index 100% rename from text_array_test.go rename to pgtype/text_array_test.go diff --git a/text_test.go b/pgtype/text_test.go similarity index 100% rename from text_test.go rename to pgtype/text_test.go diff --git a/tid.go b/pgtype/tid.go similarity index 100% rename from tid.go rename to pgtype/tid.go diff --git a/tid_test.go b/pgtype/tid_test.go similarity index 100% rename from tid_test.go rename to pgtype/tid_test.go diff --git a/time.go b/pgtype/time.go similarity index 100% rename from time.go rename to pgtype/time.go diff --git a/time_test.go b/pgtype/time_test.go similarity index 100% rename from time_test.go rename to pgtype/time_test.go diff --git a/timestamp.go b/pgtype/timestamp.go similarity index 100% rename from timestamp.go rename to pgtype/timestamp.go diff --git a/timestamp_array.go b/pgtype/timestamp_array.go similarity index 100% rename from timestamp_array.go rename to pgtype/timestamp_array.go diff --git a/timestamp_array_test.go b/pgtype/timestamp_array_test.go similarity index 100% rename from timestamp_array_test.go rename to pgtype/timestamp_array_test.go diff --git a/timestamp_test.go b/pgtype/timestamp_test.go similarity index 100% rename from timestamp_test.go rename to pgtype/timestamp_test.go diff --git a/timestamptz.go b/pgtype/timestamptz.go similarity index 100% rename from timestamptz.go rename to pgtype/timestamptz.go diff --git a/timestamptz_array.go b/pgtype/timestamptz_array.go similarity index 100% rename from timestamptz_array.go rename to pgtype/timestamptz_array.go diff --git a/timestamptz_array_test.go b/pgtype/timestamptz_array_test.go similarity index 100% rename from timestamptz_array_test.go rename to pgtype/timestamptz_array_test.go diff --git a/timestamptz_test.go b/pgtype/timestamptz_test.go similarity index 100% rename from timestamptz_test.go rename to pgtype/timestamptz_test.go diff --git a/tsrange.go b/pgtype/tsrange.go similarity index 100% rename from tsrange.go rename to pgtype/tsrange.go diff --git a/tsrange_array.go b/pgtype/tsrange_array.go similarity index 100% rename from tsrange_array.go rename to pgtype/tsrange_array.go diff --git a/tsrange_test.go b/pgtype/tsrange_test.go similarity index 100% rename from tsrange_test.go rename to pgtype/tsrange_test.go diff --git a/tstzrange.go b/pgtype/tstzrange.go similarity index 100% rename from tstzrange.go rename to pgtype/tstzrange.go diff --git a/tstzrange_array.go b/pgtype/tstzrange_array.go similarity index 100% rename from tstzrange_array.go rename to pgtype/tstzrange_array.go diff --git a/tstzrange_test.go b/pgtype/tstzrange_test.go similarity index 100% rename from tstzrange_test.go rename to pgtype/tstzrange_test.go diff --git a/typed_array.go.erb b/pgtype/typed_array.go.erb similarity index 100% rename from typed_array.go.erb rename to pgtype/typed_array.go.erb diff --git a/typed_array_gen.sh b/pgtype/typed_array_gen.sh similarity index 100% rename from typed_array_gen.sh rename to pgtype/typed_array_gen.sh diff --git a/typed_range.go.erb b/pgtype/typed_range.go.erb similarity index 100% rename from typed_range.go.erb rename to pgtype/typed_range.go.erb diff --git a/typed_range_gen.sh b/pgtype/typed_range_gen.sh similarity index 100% rename from typed_range_gen.sh rename to pgtype/typed_range_gen.sh diff --git a/unknown.go b/pgtype/unknown.go similarity index 100% rename from unknown.go rename to pgtype/unknown.go diff --git a/uuid.go b/pgtype/uuid.go similarity index 100% rename from uuid.go rename to pgtype/uuid.go diff --git a/uuid_array.go b/pgtype/uuid_array.go similarity index 100% rename from uuid_array.go rename to pgtype/uuid_array.go diff --git a/uuid_array_test.go b/pgtype/uuid_array_test.go similarity index 100% rename from uuid_array_test.go rename to pgtype/uuid_array_test.go diff --git a/uuid_test.go b/pgtype/uuid_test.go similarity index 100% rename from uuid_test.go rename to pgtype/uuid_test.go diff --git a/varbit.go b/pgtype/varbit.go similarity index 100% rename from varbit.go rename to pgtype/varbit.go diff --git a/varbit_test.go b/pgtype/varbit_test.go similarity index 100% rename from varbit_test.go rename to pgtype/varbit_test.go diff --git a/varchar.go b/pgtype/varchar.go similarity index 100% rename from varchar.go rename to pgtype/varchar.go diff --git a/varchar_array.go b/pgtype/varchar_array.go similarity index 100% rename from varchar_array.go rename to pgtype/varchar_array.go diff --git a/varchar_array_test.go b/pgtype/varchar_array_test.go similarity index 100% rename from varchar_array_test.go rename to pgtype/varchar_array_test.go diff --git a/.github/workflows/ci.yml b/pgtype/workflows/ci.yml similarity index 100% rename from .github/workflows/ci.yml rename to pgtype/workflows/ci.yml diff --git a/xid.go b/pgtype/xid.go similarity index 100% rename from xid.go rename to pgtype/xid.go diff --git a/xid_test.go b/pgtype/xid_test.go similarity index 100% rename from xid_test.go rename to pgtype/xid_test.go diff --git a/zeronull/doc.go b/pgtype/zeronull/doc.go similarity index 100% rename from zeronull/doc.go rename to pgtype/zeronull/doc.go diff --git a/zeronull/float8.go b/pgtype/zeronull/float8.go similarity index 100% rename from zeronull/float8.go rename to pgtype/zeronull/float8.go diff --git a/zeronull/float8_test.go b/pgtype/zeronull/float8_test.go similarity index 100% rename from zeronull/float8_test.go rename to pgtype/zeronull/float8_test.go diff --git a/zeronull/int2.go b/pgtype/zeronull/int2.go similarity index 100% rename from zeronull/int2.go rename to pgtype/zeronull/int2.go diff --git a/zeronull/int2_test.go b/pgtype/zeronull/int2_test.go similarity index 100% rename from zeronull/int2_test.go rename to pgtype/zeronull/int2_test.go diff --git a/zeronull/int4.go b/pgtype/zeronull/int4.go similarity index 100% rename from zeronull/int4.go rename to pgtype/zeronull/int4.go diff --git a/zeronull/int4_test.go b/pgtype/zeronull/int4_test.go similarity index 100% rename from zeronull/int4_test.go rename to pgtype/zeronull/int4_test.go diff --git a/zeronull/int8.go b/pgtype/zeronull/int8.go similarity index 100% rename from zeronull/int8.go rename to pgtype/zeronull/int8.go diff --git a/zeronull/int8_test.go b/pgtype/zeronull/int8_test.go similarity index 100% rename from zeronull/int8_test.go rename to pgtype/zeronull/int8_test.go diff --git a/zeronull/text.go b/pgtype/zeronull/text.go similarity index 100% rename from zeronull/text.go rename to pgtype/zeronull/text.go diff --git a/zeronull/text_test.go b/pgtype/zeronull/text_test.go similarity index 100% rename from zeronull/text_test.go rename to pgtype/zeronull/text_test.go diff --git a/zeronull/timestamp.go b/pgtype/zeronull/timestamp.go similarity index 100% rename from zeronull/timestamp.go rename to pgtype/zeronull/timestamp.go diff --git a/zeronull/timestamp_test.go b/pgtype/zeronull/timestamp_test.go similarity index 100% rename from zeronull/timestamp_test.go rename to pgtype/zeronull/timestamp_test.go diff --git a/zeronull/timestamptz.go b/pgtype/zeronull/timestamptz.go similarity index 100% rename from zeronull/timestamptz.go rename to pgtype/zeronull/timestamptz.go diff --git a/zeronull/timestamptz_test.go b/pgtype/zeronull/timestamptz_test.go similarity index 100% rename from zeronull/timestamptz_test.go rename to pgtype/zeronull/timestamptz_test.go diff --git a/zeronull/uuid.go b/pgtype/zeronull/uuid.go similarity index 100% rename from zeronull/uuid.go rename to pgtype/zeronull/uuid.go diff --git a/zeronull/uuid_test.go b/pgtype/zeronull/uuid_test.go similarity index 100% rename from zeronull/uuid_test.go rename to pgtype/zeronull/uuid_test.go diff --git a/zzz.aclitem.go b/pgtype/zzz.aclitem.go similarity index 100% rename from zzz.aclitem.go rename to pgtype/zzz.aclitem.go diff --git a/zzz.bit.go b/pgtype/zzz.bit.go similarity index 100% rename from zzz.bit.go rename to pgtype/zzz.bit.go diff --git a/zzz.bool.go b/pgtype/zzz.bool.go similarity index 100% rename from zzz.bool.go rename to pgtype/zzz.bool.go diff --git a/zzz.box.go b/pgtype/zzz.box.go similarity index 100% rename from zzz.box.go rename to pgtype/zzz.box.go diff --git a/zzz.bpchar.go b/pgtype/zzz.bpchar.go similarity index 100% rename from zzz.bpchar.go rename to pgtype/zzz.bpchar.go diff --git a/zzz.bytea.go b/pgtype/zzz.bytea.go similarity index 100% rename from zzz.bytea.go rename to pgtype/zzz.bytea.go diff --git a/zzz.cid.go b/pgtype/zzz.cid.go similarity index 100% rename from zzz.cid.go rename to pgtype/zzz.cid.go diff --git a/zzz.cidr.go b/pgtype/zzz.cidr.go similarity index 100% rename from zzz.cidr.go rename to pgtype/zzz.cidr.go diff --git a/zzz.circle.go b/pgtype/zzz.circle.go similarity index 100% rename from zzz.circle.go rename to pgtype/zzz.circle.go diff --git a/zzz.date.go b/pgtype/zzz.date.go similarity index 100% rename from zzz.date.go rename to pgtype/zzz.date.go diff --git a/zzz.float4.go b/pgtype/zzz.float4.go similarity index 100% rename from zzz.float4.go rename to pgtype/zzz.float4.go diff --git a/zzz.float8.go b/pgtype/zzz.float8.go similarity index 100% rename from zzz.float8.go rename to pgtype/zzz.float8.go diff --git a/zzz.generic_binary.go b/pgtype/zzz.generic_binary.go similarity index 100% rename from zzz.generic_binary.go rename to pgtype/zzz.generic_binary.go diff --git a/zzz.generic_text.go b/pgtype/zzz.generic_text.go similarity index 100% rename from zzz.generic_text.go rename to pgtype/zzz.generic_text.go diff --git a/zzz.hstore.go b/pgtype/zzz.hstore.go similarity index 100% rename from zzz.hstore.go rename to pgtype/zzz.hstore.go diff --git a/zzz.inet.go b/pgtype/zzz.inet.go similarity index 100% rename from zzz.inet.go rename to pgtype/zzz.inet.go diff --git a/zzz.int2.go b/pgtype/zzz.int2.go similarity index 100% rename from zzz.int2.go rename to pgtype/zzz.int2.go diff --git a/zzz.int4.go b/pgtype/zzz.int4.go similarity index 100% rename from zzz.int4.go rename to pgtype/zzz.int4.go diff --git a/zzz.int8.go b/pgtype/zzz.int8.go similarity index 100% rename from zzz.int8.go rename to pgtype/zzz.int8.go diff --git a/zzz.interval.go b/pgtype/zzz.interval.go similarity index 100% rename from zzz.interval.go rename to pgtype/zzz.interval.go diff --git a/zzz.json.go b/pgtype/zzz.json.go similarity index 100% rename from zzz.json.go rename to pgtype/zzz.json.go diff --git a/zzz.jsonb.go b/pgtype/zzz.jsonb.go similarity index 100% rename from zzz.jsonb.go rename to pgtype/zzz.jsonb.go diff --git a/zzz.line.go b/pgtype/zzz.line.go similarity index 100% rename from zzz.line.go rename to pgtype/zzz.line.go diff --git a/zzz.lseg.go b/pgtype/zzz.lseg.go similarity index 100% rename from zzz.lseg.go rename to pgtype/zzz.lseg.go diff --git a/zzz.macadder.go b/pgtype/zzz.macadder.go similarity index 100% rename from zzz.macadder.go rename to pgtype/zzz.macadder.go diff --git a/zzz.name.go b/pgtype/zzz.name.go similarity index 100% rename from zzz.name.go rename to pgtype/zzz.name.go diff --git a/zzz.numeric.go b/pgtype/zzz.numeric.go similarity index 100% rename from zzz.numeric.go rename to pgtype/zzz.numeric.go diff --git a/zzz.oid.go b/pgtype/zzz.oid.go similarity index 100% rename from zzz.oid.go rename to pgtype/zzz.oid.go diff --git a/zzz.oid_value.go b/pgtype/zzz.oid_value.go similarity index 100% rename from zzz.oid_value.go rename to pgtype/zzz.oid_value.go diff --git a/zzz.path.go b/pgtype/zzz.path.go similarity index 100% rename from zzz.path.go rename to pgtype/zzz.path.go diff --git a/zzz.pguint32.go b/pgtype/zzz.pguint32.go similarity index 100% rename from zzz.pguint32.go rename to pgtype/zzz.pguint32.go diff --git a/zzz.point.go b/pgtype/zzz.point.go similarity index 100% rename from zzz.point.go rename to pgtype/zzz.point.go diff --git a/zzz.polygon.go b/pgtype/zzz.polygon.go similarity index 100% rename from zzz.polygon.go rename to pgtype/zzz.polygon.go diff --git a/zzz.qchar.go b/pgtype/zzz.qchar.go similarity index 100% rename from zzz.qchar.go rename to pgtype/zzz.qchar.go diff --git a/zzz.text.go b/pgtype/zzz.text.go similarity index 100% rename from zzz.text.go rename to pgtype/zzz.text.go diff --git a/zzz.tid.go b/pgtype/zzz.tid.go similarity index 100% rename from zzz.tid.go rename to pgtype/zzz.tid.go diff --git a/zzz.time.go b/pgtype/zzz.time.go similarity index 100% rename from zzz.time.go rename to pgtype/zzz.time.go diff --git a/zzz.timestamp.go b/pgtype/zzz.timestamp.go similarity index 100% rename from zzz.timestamp.go rename to pgtype/zzz.timestamp.go diff --git a/zzz.timestamptz.go b/pgtype/zzz.timestamptz.go similarity index 100% rename from zzz.timestamptz.go rename to pgtype/zzz.timestamptz.go diff --git a/zzz.uuid.go b/pgtype/zzz.uuid.go similarity index 100% rename from zzz.uuid.go rename to pgtype/zzz.uuid.go diff --git a/zzz.varbit.go b/pgtype/zzz.varbit.go similarity index 100% rename from zzz.varbit.go rename to pgtype/zzz.varbit.go diff --git a/zzz.varchar.go b/pgtype/zzz.varchar.go similarity index 100% rename from zzz.varchar.go rename to pgtype/zzz.varchar.go diff --git a/zzz.xid.go b/pgtype/zzz.xid.go similarity index 100% rename from zzz.xid.go rename to pgtype/zzz.xid.go From 7e13db45388d1e8ea642035509dc7dcea67fea54 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Dec 2021 13:33:09 -0600 Subject: [PATCH 0753/1158] Finish import of pgtype repo Fix some tests that broke by merging repos Tweak readme wording --- CHANGELOG.md | 10 ++ README.md | 4 +- bench_test.go | 8 +- conn.go | 8 +- conn_test.go | 2 +- example_custom_type_test.go | 21 ++- extended_query_builder.go | 2 +- go.mod | 1 - go.sum | 39 +---- messages.go | 2 +- pgtype/CHANGELOG.md | 121 --------------- pgtype/LICENSE | 22 --- pgtype/README.md | 8 - pgtype/aclitem_array_test.go | 4 +- pgtype/aclitem_test.go | 4 +- pgtype/array_test.go | 2 +- pgtype/array_type_test.go | 4 +- pgtype/bit_test.go | 4 +- pgtype/bool_array_test.go | 4 +- pgtype/bool_test.go | 4 +- pgtype/box_test.go | 4 +- pgtype/bpchar_array_test.go | 4 +- pgtype/bpchar_test.go | 4 +- pgtype/bytea_array_test.go | 4 +- pgtype/bytea_test.go | 4 +- pgtype/cid_test.go | 4 +- pgtype/cidr_array_test.go | 4 +- pgtype/circle_test.go | 4 +- pgtype/composite_bench_test.go | 2 +- pgtype/composite_fields_test.go | 4 +- pgtype/composite_type_test.go | 4 +- pgtype/custom_composite_test.go | 2 +- pgtype/date_array_test.go | 4 +- pgtype/date_test.go | 4 +- pgtype/daterange_test.go | 4 +- pgtype/enum_array_test.go | 4 +- pgtype/enum_type_test.go | 4 +- pgtype/float4_array_test.go | 4 +- pgtype/float4_test.go | 4 +- pgtype/float8_array_test.go | 4 +- pgtype/float8_test.go | 4 +- pgtype/go.mod | 10 -- pgtype/go.sum | 180 ----------------------- pgtype/hstore_array_test.go | 4 +- pgtype/hstore_test.go | 4 +- pgtype/inet_array_test.go | 4 +- pgtype/inet_test.go | 4 +- pgtype/int2_array_test.go | 4 +- pgtype/int2_test.go | 4 +- pgtype/int4_array_test.go | 4 +- pgtype/int4_test.go | 4 +- pgtype/int4range_test.go | 4 +- pgtype/int8_array_test.go | 4 +- pgtype/int8_test.go | 4 +- pgtype/int8range_test.go | 4 +- pgtype/integration_benchmark_test.go | 4 +- pgtype/integration_benchmark_test.go.erb | 2 +- pgtype/interval_test.go | 4 +- pgtype/json_test.go | 4 +- pgtype/jsonb_array_test.go | 4 +- pgtype/jsonb_test.go | 4 +- pgtype/line_test.go | 4 +- pgtype/lseg_test.go | 4 +- pgtype/macaddr_array_test.go | 4 +- pgtype/macaddr_test.go | 4 +- pgtype/name_test.go | 4 +- pgtype/numeric_array_test.go | 4 +- pgtype/numeric_test.go | 6 +- pgtype/numrange_test.go | 4 +- pgtype/oid_value_test.go | 4 +- pgtype/path_test.go | 4 +- pgtype/pgtype_test.go | 2 +- pgtype/pgxtype/pgxtype.go | 2 +- pgtype/point_test.go | 4 +- pgtype/polygon_test.go | 4 +- pgtype/qchar_test.go | 4 +- pgtype/record.go | 22 +++ pgtype/record_test.go | 4 +- pgtype/testutil/testutil.go | 2 +- pgtype/text_array_test.go | 6 +- pgtype/text_test.go | 4 +- pgtype/tid_test.go | 4 +- pgtype/time_test.go | 4 +- pgtype/timestamp_array_test.go | 4 +- pgtype/timestamp_test.go | 8 +- pgtype/timestamptz_array_test.go | 4 +- pgtype/timestamptz_test.go | 8 +- pgtype/tsrange_test.go | 4 +- pgtype/tstzrange_test.go | 6 +- pgtype/uuid_array_test.go | 4 +- pgtype/uuid_test.go | 4 +- pgtype/varbit_test.go | 4 +- pgtype/varchar_array_test.go | 4 +- pgtype/workflows/ci.yml | 52 ------- pgtype/xid_test.go | 4 +- pgtype/zeronull/float8.go | 2 +- pgtype/zeronull/float8_test.go | 4 +- pgtype/zeronull/int2.go | 2 +- pgtype/zeronull/int2_test.go | 4 +- pgtype/zeronull/int4.go | 2 +- pgtype/zeronull/int4_test.go | 4 +- pgtype/zeronull/int8.go | 2 +- pgtype/zeronull/int8_test.go | 4 +- pgtype/zeronull/text.go | 2 +- pgtype/zeronull/text_test.go | 4 +- pgtype/zeronull/timestamp.go | 2 +- pgtype/zeronull/timestamp_test.go | 4 +- pgtype/zeronull/timestamptz.go | 2 +- pgtype/zeronull/timestamptz_test.go | 4 +- pgtype/zeronull/uuid.go | 2 +- pgtype/zeronull/uuid_test.go | 4 +- query_test.go | 22 +-- rows.go | 2 +- stdlib/sql.go | 2 +- values.go | 8 +- values_test.go | 74 +++++----- 116 files changed, 294 insertions(+), 686 deletions(-) delete mode 100644 pgtype/CHANGELOG.md delete mode 100644 pgtype/LICENSE delete mode 100644 pgtype/README.md delete mode 100644 pgtype/go.mod delete mode 100644 pgtype/go.sum delete mode 100644 pgtype/workflows/ci.yml diff --git a/CHANGELOG.md b/CHANGELOG.md index 198a6ea4..1f285400 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,13 @@ +# Unreleased v5 + +* Import pgtype repository + +## pgtype Changes + +* Types now have Valid boolean field instead of Status byte. This matches database/sql pattern. +* Extracted integrations with github.com/shopspring/decimal and github.com/gofrs/uuid to https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. + + # 4.14.1 (November 28, 2021) * Upgrade pgtype to v1.9.1 (fixes unintentional change to timestamp binary decoding) diff --git a/README.md b/README.md index 110d4f02..5a65a00f 100644 --- a/README.md +++ b/README.md @@ -172,9 +172,9 @@ from pgx for lower-level control. This is a `database/sql` compatibility layer for pgx. pgx can be used as a normal `database/sql` driver, but at any time, the native interface can be acquired for more performance or PostgreSQL specific functionality. -### [github.com/jackc/pgtype](https://github.com/jackc/pgtype) +### [github.com/jackc/pgx/v5/pgtype](https://github.com/jackc/pgx/tree/master/pgtype) -Over 70 PostgreSQL types are supported including `uuid`, `hstore`, `json`, `bytea`, `numeric`, `interval`, `inet`, and arrays. These types support `database/sql` interfaces and are usable outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver. +Over 70 PostgreSQL types are supported including `uuid`, `hstore`, `json`, `bytea`, `numeric`, `interval`, `inet`, and arrays. ### [github.com/jackc/pgproto3](https://github.com/jackc/pgproto3) diff --git a/bench_test.go b/bench_test.go index f2d98bab..1a35a1a4 100644 --- a/bench_test.go +++ b/bench_test.go @@ -14,8 +14,8 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgconn/stmtcache" - "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" "github.com/stretchr/testify/require" ) @@ -459,12 +459,12 @@ func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource { row: []interface{}{ "varchar_1", "varchar_2", - &pgtype.Text{Status: pgtype.Null}, + &pgtype.Text{}, time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), - &pgtype.Date{Status: pgtype.Null}, + &pgtype.Date{}, 1, 2, - &pgtype.Int4{Status: pgtype.Null}, + &pgtype.Int4{}, time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local), true, diff --git a/conn.go b/conn.go index 102158ab..b0cbf72b 100644 --- a/conn.go +++ b/conn.go @@ -11,8 +11,8 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgtype" "github.com/jackc/pgx/v4/internal/sanitize" + "github.com/jackc/pgx/v4/pgtype" ) // ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and @@ -508,7 +508,7 @@ func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, argu } for i := range sd.Fields { - c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) + c.eqb.AppendResultFormat(c.ConnInfo().FormatCodeForOID(sd.Fields[i].DataTypeOID)) } return nil @@ -668,7 +668,7 @@ optionLoop: if resultFormats == nil { for i := range sd.Fields { - c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) + c.eqb.AppendResultFormat(c.ConnInfo().FormatCodeForOID(sd.Fields[i].DataTypeOID)) } resultFormats = c.eqb.resultFormats @@ -819,7 +819,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } for i := range sd.Fields { - c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) + c.eqb.AppendResultFormat(c.ConnInfo().FormatCodeForOID(sd.Fields[i].DataTypeOID)) } if sd.Name == "" { diff --git a/conn_test.go b/conn_test.go index beddcdcd..d18ad1d9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -10,8 +10,8 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgconn/stmtcache" - "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 34331f5b..7723df6a 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -7,16 +7,16 @@ import ( "regexp" "strconv" - "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" ) var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) // Point represents a point that may be null. type Point struct { - X, Y float64 // Coordinates of point - Status pgtype.Status + X, Y float64 // Coordinates of point + Valid bool } func (dst *Point) Set(src interface{}) error { @@ -24,14 +24,11 @@ func (dst *Point) Set(src interface{}) error { } func (dst *Point) Get() interface{} { - switch dst.Status { - case pgtype.Present: - return dst - case pgtype.Null: + if !dst.Valid { return nil - default: - return dst.Status } + + return dst } func (src *Point) AssignTo(dst interface{}) error { @@ -40,7 +37,7 @@ func (src *Point) AssignTo(dst interface{}) error { func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { if src == nil { - *dst = Point{Status: pgtype.Null} + *dst = Point{} return nil } @@ -59,13 +56,13 @@ func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return fmt.Errorf("Received invalid point: %v", s) } - *dst = Point{X: x, Y: y, Status: pgtype.Present} + *dst = Point{X: x, Y: y, Valid: true} return nil } func (src *Point) String() string { - if src.Status == pgtype.Null { + if !src.Valid { return "null point" } diff --git a/extended_query_builder.go b/extended_query_builder.go index d06f63fd..53ea75ff 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -5,7 +5,7 @@ import ( "fmt" "reflect" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" ) type extendedQueryBuilder struct { diff --git a/go.mod b/go.mod index 1bc04650..577cc76a 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,6 @@ require ( github.com/jackc/pgconn v1.10.1 github.com/jackc/pgio v1.0.0 github.com/jackc/pgproto3/v2 v2.2.0 - github.com/jackc/pgtype v1.9.1 github.com/jackc/puddle v1.2.0 github.com/rs/zerolog v1.15.0 github.com/shopspring/decimal v1.2.0 diff --git a/go.sum b/go.sum index 90193e6d..e95937cc 100644 --- a/go.sum +++ b/go.sum @@ -29,9 +29,6 @@ github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= -github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= -github.com/jackc/pgconn v1.10.0 h1:4EYhlDVEMsJ30nNj0mmgwIUXoq7e9sMJrVC2ED6QlCU= -github.com/jackc/pgconn v1.10.0/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= github.com/jackc/pgconn v1.10.1 h1:DzdIHIjG1AxGwoEEqS+mGsURyjt4enSmqzACXvVzOT8= github.com/jackc/pgconn v1.10.1/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= @@ -49,7 +46,6 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.2.0 h1:r7JypeP2D3onoQTCxWdTpCtJ4D+qpKr0TxvoyMhZ5ns= github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= @@ -58,23 +54,11 @@ github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4 github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= -github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= -github.com/jackc/pgtype v1.8.1 h1:9k0IXtdJXHJbyAWQgbWr1lU+MEhPXZz6RIXxfR5oxXs= -github.com/jackc/pgtype v1.8.1/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= -github.com/jackc/pgtype v1.9.0 h1:/SH1RxEtltvJgsDqp3TbiTFApD3mey3iygpuEGeuBXk= -github.com/jackc/pgtype v1.9.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= -github.com/jackc/pgtype v1.9.1 h1:MJc2s0MFS8C3ok1wQTdQxWuXQcB6+HwAm5x1CzW7mf0= -github.com/jackc/pgtype v1.9.1/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= -github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.1.3 h1:JnPg/5Q9xVJGfjsO5CPUOjnJps1JaRUm8I9FXVCFK94= -github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.1.4 h1:5Ey/o5IfV7dYX6Znivq+N9MdK1S18OJI5OJq6EAAADw= -github.com/jackc/puddle v1.1.4/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.0 h1:DNDKdn/pDrWvDWyT2FYvpZVE81OAhWrjCv19I9n108Q= github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -89,16 +73,13 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= -github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.1 h1:G1f5SKeVxmagw/IyvzvtZE4Gybcc4Tr1tf7I8z0XgOg= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= -github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE= -github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7 h1:UvyT9uN+3r7yLEYSlJsbQGdsaB/a0DlgWP3pql6iwOc= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -127,13 +108,11 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0 h1:OI5t8sDa1Or+q8AeE+yKeB/SDYioSHAgcVljj9JIETY= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= -go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= -go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0 h1:sFPn2GLc3poCkfrpIXGhBD2X0CMIo4Q/zSULXrj/+uc= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= -go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= -go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= @@ -144,7 +123,6 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -153,7 +131,6 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -168,8 +145,6 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -187,13 +162,11 @@ golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5 h1:hKsoRgsbwY1NafxrwTs+k64bikrLBkAgPir1TNCj3Zs= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20200103221440-774c71fcf114 h1:DnSr2mCsxyCE6ZgIkmcWUQY2R5cH/6wL7eIxEmQOMSE= -golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/messages.go b/messages.go index 5324cbb5..d7af6973 100644 --- a/messages.go +++ b/messages.go @@ -3,7 +3,7 @@ package pgx import ( "database/sql/driver" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" ) func convertDriverValuers(args []interface{}) ([]interface{}, error) { diff --git a/pgtype/CHANGELOG.md b/pgtype/CHANGELOG.md deleted file mode 100644 index e34c7979..00000000 --- a/pgtype/CHANGELOG.md +++ /dev/null @@ -1,121 +0,0 @@ -# 1.9.1 (November 28, 2021) - -* Fix: binary timestamp is assumed to be in UTC (restored behavior changed in v1.9.0) - -# 1.9.0 (November 20, 2021) - -* Fix binary hstore null decoding -* Add shopspring/decimal.NullDecimal support to integration (Eli Treuherz) -* Inet.Set supports bare IP address (Carl Dunham) -* Add zeronull.Float8 -* Fix NULL being lost when scanning unknown OID into sql.Scanner -* Fix BPChar.AssignTo **rune -* Add support for fmt.Stringer and driver.Valuer in String fields encoding (Jan Dubsky) -* Fix really big timestamp(tz)s binary format parsing (e.g. year 294276) (Jim Tsao) -* Support `map[string]*string` as hstore (Adrian Sieger) -* Fix parsing text array with negative bounds -* Add infinity support for numeric (Jim Tsao) - -# 1.8.1 (July 24, 2021) - -* Cleaned up Go module dependency chain - -# 1.8.0 (July 10, 2021) - -* Maintain host bits for inet types (Cameron Daniel) -* Support pointers of wrapping structs (Ivan Daunis) -* Register JSONBArray at NewConnInfo() (Rueian) -* CompositeTextScanner handles backslash escapes - -# 1.7.0 (March 25, 2021) - -* Fix scanning int into **sql.Scanner implementor -* Add tsrange array type (Vasilii Novikov) -* Fix: escaped strings when they start or end with a newline char (Stephane Martin) -* Accept nil *time.Time in Time.Set -* Fix numeric NaN support -* Use Go 1.13 errors instead of xerrors - -# 1.6.2 (December 3, 2020) - -* Fix panic on assigning empty array to non-slice or array -* Fix text array parsing disambiguates NULL and "NULL" -* Fix Timestamptz.DecodeText with too short text - -# 1.6.1 (October 31, 2020) - -* Fix simple protocol empty array support - -# 1.6.0 (October 24, 2020) - -* Fix AssignTo pointer to pointer to slice and named types. -* Fix zero length array assignment (Simo Haasanen) -* Add float64, float32 convert to int2, int4, int8 (lqu3j) -* Support setting infinite timestamps (Erik Agsjö) -* Polygon improvements (duohedron) -* Fix Inet.Set with nil (Tomas Volf) - -# 1.5.0 (September 26, 2020) - -* Add slice of slice mapping to multi-dimensional arrays (Simo Haasanen) -* Fix JSONBArray -* Fix selecting empty array -* Text formatted values except bytea can be directly scanned to []byte -* Add JSON marshalling for UUID (bakmataliev) -* Improve point type conversions (bakmataliev) - -# 1.4.2 (July 22, 2020) - -* Fix encoding of a large composite data type (Yaz Saito) - -# 1.4.1 (July 14, 2020) - -* Fix ArrayType DecodeBinary empty array breaks future reads - -# 1.4.0 (June 27, 2020) - -* Add JSON support to ext/gofrs-uuid -* Performance improvements in Scan path -* Improved ext/shopspring-numeric binary decoding performance -* Add composite type support (Maxim Ivanov and Jack Christensen) -* Add better generic enum type support -* Add generic array type support -* Clarify and normalize Value semantics -* Fix hstore with empty string values -* Numeric supports NaN values (leighhopcroft) -* Add slice of pointer support to array types (megaturbo) -* Add jsonb array type (tserakhau) -* Allow converting intervals with months and days to duration - -# 1.3.0 (March 30, 2020) - -* Get implemented on T instead of *T -* Set will call Get on src if possible -* Range types Set method supports its own type, string, and nil -* Date.Set parses string -* Fix correct format verb for unknown type error (Robert Welin) -* Truncate nanoseconds in EncodeText for Timestamptz and Timestamp - -# 1.2.0 (February 5, 2020) - -* Add zeronull package for easier NULL <-> zero conversion -* Add JSON marshalling for shopspring-numeric extension -* Add JSON marshalling for Bool, Date, JSON/B, Timestamptz (Jeffrey Stiles) -* Fix null status in UnmarshalJSON for some types (Jeffrey Stiles) - -# 1.1.0 (January 11, 2020) - -* Add PostgreSQL time type support -* Add more automatic conversions of integer arrays of different types (Jean-Philippe Quéméner) - -# 1.0.3 (November 16, 2019) - -* Support initializing Array types from a slice of the value (Alex Gaynor) - -# 1.0.2 (October 22, 2019) - -* Fix scan into null into pointer to pointer implementing Decode* interface. (Jeremy Altavilla) - -# 1.0.1 (September 19, 2019) - -* Fix daterange OID diff --git a/pgtype/LICENSE b/pgtype/LICENSE deleted file mode 100644 index 5c486c39..00000000 --- a/pgtype/LICENSE +++ /dev/null @@ -1,22 +0,0 @@ -Copyright (c) 2013-2021 Jack Christensen - -MIT License - -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/pgtype/README.md b/pgtype/README.md deleted file mode 100644 index bc4e72f9..00000000 --- a/pgtype/README.md +++ /dev/null @@ -1,8 +0,0 @@ -[![](https://godoc.org/github.com/jackc/pgtype?status.svg)](https://godoc.org/github.com/jackc/pgtype) -![CI](https://github.com/jackc/pgtype/workflows/CI/badge.svg) - -# pgtype - -pgtype implements Go types for over 70 PostgreSQL types. pgtype is the type system underlying the -https://github.com/jackc/pgx PostgreSQL driver. These types support the binary format for enhanced performance with pgx. -They also support the database/sql `Scan` and `Value` interfaces. diff --git a/pgtype/aclitem_array_test.go b/pgtype/aclitem_array_test.go index 0d6adb1d..736d0967 100644 --- a/pgtype/aclitem_array_test.go +++ b/pgtype/aclitem_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestACLItemArrayTranscode(t *testing.T) { diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go index 4e9bc5b0..afc6a1e3 100644 --- a/pgtype/aclitem_test.go +++ b/pgtype/aclitem_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestACLItemTranscode(t *testing.T) { diff --git a/pgtype/array_test.go b/pgtype/array_test.go index f1fe90f4..549467bf 100644 --- a/pgtype/array_test.go +++ b/pgtype/array_test.go @@ -4,7 +4,7 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" "github.com/stretchr/testify/require" ) diff --git a/pgtype/array_type_test.go b/pgtype/array_type_test.go index 626df4dc..62b25dfa 100644 --- a/pgtype/array_type_test.go +++ b/pgtype/array_type_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/stretchr/testify/require" ) diff --git a/pgtype/bit_test.go b/pgtype/bit_test.go index df5fe4cb..51c12765 100644 --- a/pgtype/bit_test.go +++ b/pgtype/bit_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestBitTranscode(t *testing.T) { diff --git a/pgtype/bool_array_test.go b/pgtype/bool_array_test.go index cfb9ad79..9278c864 100644 --- a/pgtype/bool_array_test.go +++ b/pgtype/bool_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestBoolArrayTranscode(t *testing.T) { diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index a1ba9bb0..f323c5e7 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestBoolTranscode(t *testing.T) { diff --git a/pgtype/box_test.go b/pgtype/box_test.go index c7e00553..d6a928c9 100644 --- a/pgtype/box_test.go +++ b/pgtype/box_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestBoxTranscode(t *testing.T) { diff --git a/pgtype/bpchar_array_test.go b/pgtype/bpchar_array_test.go index 277f6e3c..4714b261 100644 --- a/pgtype/bpchar_array_test.go +++ b/pgtype/bpchar_array_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestBPCharArrayTranscode(t *testing.T) { diff --git a/pgtype/bpchar_test.go b/pgtype/bpchar_test.go index fe7e651c..68fbfc9f 100644 --- a/pgtype/bpchar_test.go +++ b/pgtype/bpchar_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestChar3Transcode(t *testing.T) { diff --git a/pgtype/bytea_array_test.go b/pgtype/bytea_array_test.go index 1473eb9c..d081db11 100644 --- a/pgtype/bytea_array_test.go +++ b/pgtype/bytea_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestByteaArrayTranscode(t *testing.T) { diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index 0f47cb7f..b87b3c96 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestByteaTranscode(t *testing.T) { diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go index 041cb805..e915e534 100644 --- a/pgtype/cid_test.go +++ b/pgtype/cid_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestCIDTranscode(t *testing.T) { diff --git a/pgtype/cidr_array_test.go b/pgtype/cidr_array_test.go index 7821cf44..93d3933d 100644 --- a/pgtype/cidr_array_test.go +++ b/pgtype/cidr_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestCIDRArrayTranscode(t *testing.T) { diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go index 416a1a41..2e5a8c86 100644 --- a/pgtype/circle_test.go +++ b/pgtype/circle_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestCircleTranscode(t *testing.T) { diff --git a/pgtype/composite_bench_test.go b/pgtype/composite_bench_test.go index a1d91f8e..92330905 100644 --- a/pgtype/composite_bench_test.go +++ b/pgtype/composite_bench_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/jackc/pgio" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" "github.com/stretchr/testify/require" ) diff --git a/pgtype/composite_fields_test.go b/pgtype/composite_fields_test.go index be0b8125..07b3954e 100644 --- a/pgtype/composite_fields_test.go +++ b/pgtype/composite_fields_test.go @@ -4,9 +4,9 @@ import ( "context" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgtype/composite_type_test.go b/pgtype/composite_type_test.go index e06927fa..80dd4a5c 100644 --- a/pgtype/composite_type_test.go +++ b/pgtype/composite_type_test.go @@ -6,9 +6,9 @@ import ( "os" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" pgx "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgtype/custom_composite_test.go b/pgtype/custom_composite_test.go index 86203828..0cc14442 100644 --- a/pgtype/custom_composite_test.go +++ b/pgtype/custom_composite_test.go @@ -6,8 +6,8 @@ import ( "fmt" "os" - "github.com/jackc/pgtype" pgx "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" ) type MyType struct { diff --git a/pgtype/date_array_test.go b/pgtype/date_array_test.go index 421427cd..93e423a0 100644 --- a/pgtype/date_array_test.go +++ b/pgtype/date_array_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestDateArrayTranscode(t *testing.T) { diff --git a/pgtype/date_test.go b/pgtype/date_test.go index 87425540..0df84468 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestDateTranscode(t *testing.T) { diff --git a/pgtype/daterange_test.go b/pgtype/daterange_test.go index 830942d0..d0bb8d60 100644 --- a/pgtype/daterange_test.go +++ b/pgtype/daterange_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestDaterangeTranscode(t *testing.T) { diff --git a/pgtype/enum_array_test.go b/pgtype/enum_array_test.go index 7d0ff864..7b9c4d23 100644 --- a/pgtype/enum_array_test.go +++ b/pgtype/enum_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestEnumArrayTranscode(t *testing.T) { diff --git a/pgtype/enum_type_test.go b/pgtype/enum_type_test.go index 4dd88f2a..965f713a 100644 --- a/pgtype/enum_type_test.go +++ b/pgtype/enum_type_test.go @@ -5,9 +5,9 @@ import ( "context" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgtype/float4_array_test.go b/pgtype/float4_array_test.go index 9b401ac8..d28cd38c 100644 --- a/pgtype/float4_array_test.go +++ b/pgtype/float4_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestFloat4ArrayTranscode(t *testing.T) { diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go index 191df65e..3ad480f5 100644 --- a/pgtype/float4_test.go +++ b/pgtype/float4_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestFloat4Transcode(t *testing.T) { diff --git a/pgtype/float8_array_test.go b/pgtype/float8_array_test.go index 52209238..6fc85993 100644 --- a/pgtype/float8_array_test.go +++ b/pgtype/float8_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestFloat8ArrayTranscode(t *testing.T) { diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go index dcc45879..2bc8de0c 100644 --- a/pgtype/float8_test.go +++ b/pgtype/float8_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestFloat8Transcode(t *testing.T) { diff --git a/pgtype/go.mod b/pgtype/go.mod deleted file mode 100644 index b2f1cc10..00000000 --- a/pgtype/go.mod +++ /dev/null @@ -1,10 +0,0 @@ -module github.com/jackc/pgtype - -go 1.13 - -require ( - github.com/jackc/pgconn v1.10.1 - github.com/jackc/pgio v1.0.0 - github.com/jackc/pgx/v4 v4.14.2-0.20211129172902-cf0de913ee8f - github.com/stretchr/testify v1.7.0 -) diff --git a/pgtype/go.sum b/pgtype/go.sum deleted file mode 100644 index 2a835726..00000000 --- a/pgtype/go.sum +++ /dev/null @@ -1,180 +0,0 @@ -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= -github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= -github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= -github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= -github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= -github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= -github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= -github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= -github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= -github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= -github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= -github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= -github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= -github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= -github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= -github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= -github.com/jackc/pgconn v1.10.1 h1:DzdIHIjG1AxGwoEEqS+mGsURyjt4enSmqzACXvVzOT8= -github.com/jackc/pgconn v1.10.1/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= -github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= -github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= -github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= -github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= -github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= -github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= -github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= -github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= -github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= -github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.2.0 h1:r7JypeP2D3onoQTCxWdTpCtJ4D+qpKr0TxvoyMhZ5ns= -github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= -github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= -github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= -github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= -github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= -github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= -github.com/jackc/pgtype v1.9.1/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= -github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= -github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= -github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= -github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= -github.com/jackc/pgx/v4 v4.14.2-0.20211129172902-cf0de913ee8f h1:Y3Es3mIYatTvP4CXPXfmJtHWe8eq4E8owY6Fq61hEik= -github.com/jackc/pgx/v4 v4.14.2-0.20211129172902-cf0de913ee8f/go.mod h1:RgDuE4Z34o7XE92RpLsvFiOEfrAUT0Xt2KxvX73W06M= -github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= -github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= -github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= -github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= -github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= -github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= -github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= -github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= -github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= -github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= -github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= -go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= -go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= -go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= -go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= -golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/pgtype/hstore_array_test.go b/pgtype/hstore_array_test.go index 11290fb1..b164f598 100644 --- a/pgtype/hstore_array_test.go +++ b/pgtype/hstore_array_test.go @@ -5,9 +5,9 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestHstoreArrayTranscode(t *testing.T) { diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index 9c26a3df..6faf7496 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestHstoreTranscode(t *testing.T) { diff --git a/pgtype/inet_array_test.go b/pgtype/inet_array_test.go index 1019c7eb..b111746d 100644 --- a/pgtype/inet_array_test.go +++ b/pgtype/inet_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInetArrayTranscode(t *testing.T) { diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index c2a5dc28..56705e4d 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/stretchr/testify/assert" ) diff --git a/pgtype/int2_array_test.go b/pgtype/int2_array_test.go index 78dc532a..e5366edd 100644 --- a/pgtype/int2_array_test.go +++ b/pgtype/int2_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt2ArrayTranscode(t *testing.T) { diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go index 6ed8fe90..26f43eec 100644 --- a/pgtype/int2_test.go +++ b/pgtype/int2_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt2Transcode(t *testing.T) { diff --git a/pgtype/int4_array_test.go b/pgtype/int4_array_test.go index a9c9acd9..bcabe8ca 100644 --- a/pgtype/int4_array_test.go +++ b/pgtype/int4_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt4ArrayTranscode(t *testing.T) { diff --git a/pgtype/int4_test.go b/pgtype/int4_test.go index 3085babd..cdff4b44 100644 --- a/pgtype/int4_test.go +++ b/pgtype/int4_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt4Transcode(t *testing.T) { diff --git a/pgtype/int4range_test.go b/pgtype/int4range_test.go index 8b990036..a45e4779 100644 --- a/pgtype/int4range_test.go +++ b/pgtype/int4range_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt4rangeTranscode(t *testing.T) { diff --git a/pgtype/int8_array_test.go b/pgtype/int8_array_test.go index 29eaf8cb..c4de8bb1 100644 --- a/pgtype/int8_array_test.go +++ b/pgtype/int8_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt8ArrayTranscode(t *testing.T) { diff --git a/pgtype/int8_test.go b/pgtype/int8_test.go index 8aca741d..9f96a1e3 100644 --- a/pgtype/int8_test.go +++ b/pgtype/int8_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt8Transcode(t *testing.T) { diff --git a/pgtype/int8range_test.go b/pgtype/int8range_test.go index f2e4098d..aefa2f53 100644 --- a/pgtype/int8range_test.go +++ b/pgtype/int8range_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestInt8rangeTranscode(t *testing.T) { diff --git a/pgtype/integration_benchmark_test.go b/pgtype/integration_benchmark_test.go index d3af7c31..cca6dd1e 100644 --- a/pgtype/integration_benchmark_test.go +++ b/pgtype/integration_benchmark_test.go @@ -6,9 +6,9 @@ import ( "context" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *testing.B) { diff --git a/pgtype/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb index 037c96c3..d9bb7937 100644 --- a/pgtype/integration_benchmark_test.go.erb +++ b/pgtype/integration_benchmark_test.go.erb @@ -6,7 +6,7 @@ import ( "context" "testing" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/jackc/pgx/v4" ) diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go index 844f3866..6754a222 100644 --- a/pgtype/interval_test.go +++ b/pgtype/interval_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgtype/json_test.go b/pgtype/json_test.go index c56f403f..a42b7ab4 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestJSONTranscode(t *testing.T) { diff --git a/pgtype/jsonb_array_test.go b/pgtype/jsonb_array_test.go index 4f293e9e..172892bc 100644 --- a/pgtype/jsonb_array_test.go +++ b/pgtype/jsonb_array_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestJSONBArrayTranscode(t *testing.T) { diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 41df18fa..242014f3 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestJSONBTranscode(t *testing.T) { diff --git a/pgtype/line_test.go b/pgtype/line_test.go index c47f6512..6f38d85f 100644 --- a/pgtype/line_test.go +++ b/pgtype/line_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestLineTranscode(t *testing.T) { diff --git a/pgtype/lseg_test.go b/pgtype/lseg_test.go index af2faf3f..9122f76b 100644 --- a/pgtype/lseg_test.go +++ b/pgtype/lseg_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestLsegTranscode(t *testing.T) { diff --git a/pgtype/macaddr_array_test.go b/pgtype/macaddr_array_test.go index a4a55cb0..4941ad80 100644 --- a/pgtype/macaddr_array_test.go +++ b/pgtype/macaddr_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestMacaddrArrayTranscode(t *testing.T) { diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go index dc475c41..9a78521b 100644 --- a/pgtype/macaddr_test.go +++ b/pgtype/macaddr_test.go @@ -6,8 +6,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestMacaddrTranscode(t *testing.T) { diff --git a/pgtype/name_test.go b/pgtype/name_test.go index 5f429d83..b71ea490 100644 --- a/pgtype/name_test.go +++ b/pgtype/name_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestNameTranscode(t *testing.T) { diff --git a/pgtype/numeric_array_test.go b/pgtype/numeric_array_test.go index ee36d1a7..82a4fb6c 100644 --- a/pgtype/numeric_array_test.go +++ b/pgtype/numeric_array_test.go @@ -6,8 +6,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestNumericArrayTranscode(t *testing.T) { diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 7f0734d0..22bd22ef 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -9,8 +9,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/stretchr/testify/require" ) @@ -289,7 +289,7 @@ func TestNumericAssignTo(t *testing.T) { {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &_i8, expected: _int8(42)}, {src: &pgtype.Numeric{Int: big.NewInt(0)}, dst: &pi8, expected: ((*int8)(nil))}, {src: &pgtype.Numeric{Int: big.NewInt(0)}, dst: &_pi8, expected: ((*_int8)(nil))}, - {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Valid: true}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27 + {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Valid: true}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgx/v4/pgtype/issues/27 {src: &pgtype.Numeric{Valid: true, NaN: true}, dst: &f64, expected: math.NaN()}, {src: &pgtype.Numeric{Valid: true, NaN: true}, dst: &f32, expected: float32(math.NaN())}, {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}, dst: &f64, expected: math.Inf(1)}, diff --git a/pgtype/numrange_test.go b/pgtype/numrange_test.go index b9ea7658..3e89dc73 100644 --- a/pgtype/numrange_test.go +++ b/pgtype/numrange_test.go @@ -4,8 +4,8 @@ import ( "math/big" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestNumrangeTranscode(t *testing.T) { diff --git a/pgtype/oid_value_test.go b/pgtype/oid_value_test.go index 021f81d3..e3d2e014 100644 --- a/pgtype/oid_value_test.go +++ b/pgtype/oid_value_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestOIDValueTranscode(t *testing.T) { diff --git a/pgtype/path_test.go b/pgtype/path_test.go index 9a66996e..af410540 100644 --- a/pgtype/path_test.go +++ b/pgtype/path_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestPathTranscode(t *testing.T) { diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 9bf1f242..6540842c 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -7,8 +7,8 @@ import ( "net" "testing" - "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" _ "github.com/jackc/pgx/v4/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/pgtype/pgxtype/pgxtype.go b/pgtype/pgxtype/pgxtype.go index 041f2545..db4d8926 100644 --- a/pgtype/pgxtype/pgxtype.go +++ b/pgtype/pgxtype/pgxtype.go @@ -5,8 +5,8 @@ import ( "errors" "github.com/jackc/pgconn" - "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" ) type Querier interface { diff --git a/pgtype/point_test.go b/pgtype/point_test.go index 82f58e17..b8681cd5 100644 --- a/pgtype/point_test.go +++ b/pgtype/point_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/stretchr/testify/require" ) diff --git a/pgtype/polygon_test.go b/pgtype/polygon_test.go index 34f8d59a..25cbd7dc 100644 --- a/pgtype/polygon_test.go +++ b/pgtype/polygon_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestPolygonTranscode(t *testing.T) { diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go index eb54bf65..a27cb098 100644 --- a/pgtype/qchar_test.go +++ b/pgtype/qchar_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestQCharTranscode(t *testing.T) { diff --git a/pgtype/record.go b/pgtype/record.go index 20b119c6..5bb4d701 100644 --- a/pgtype/record.go +++ b/pgtype/record.go @@ -88,6 +88,28 @@ func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDec return binaryDecoder, nil } +func (Record) BinaryFormatSupported() bool { + return true +} + +func (Record) TextFormatSupported() bool { + return false +} + +func (Record) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Record) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return fmt.Errorf("text format is not supported") + } + return fmt.Errorf("unknown format code %d", format) +} + func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Record{} diff --git a/pgtype/record_test.go b/pgtype/record_test.go index c8e7d4b7..6e052b71 100644 --- a/pgtype/record_test.go +++ b/pgtype/record_test.go @@ -6,9 +6,9 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) var recordTests = []struct { diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go index 5dded2b9..6007d7a4 100644 --- a/pgtype/testutil/testutil.go +++ b/pgtype/testutil/testutil.go @@ -8,8 +8,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" _ "github.com/jackc/pgx/v4/stdlib" ) diff --git a/pgtype/text_array_test.go b/pgtype/text_array_test.go index 4caeb692..ce4b0d20 100644 --- a/pgtype/text_array_test.go +++ b/pgtype/text_array_test.go @@ -4,13 +4,13 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// https://github.com/jackc/pgtype/issues/78 +// https://github.com/jackc/pgx/v4/pgtype/issues/78 func TestTextArrayDecodeTextNull(t *testing.T) { textArray := &pgtype.TextArray{} err := textArray.DecodeText(nil, []byte(`{abc,"NULL",NULL,def}`)) diff --git a/pgtype/text_test.go b/pgtype/text_test.go index 5f34f8c0..17201764 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTextTranscode(t *testing.T) { diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go index fcf93259..133e5a35 100644 --- a/pgtype/tid_test.go +++ b/pgtype/tid_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTIDTranscode(t *testing.T) { diff --git a/pgtype/time_test.go b/pgtype/time_test.go index 4a989375..008b1448 100644 --- a/pgtype/time_test.go +++ b/pgtype/time_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTimeTranscode(t *testing.T) { diff --git a/pgtype/timestamp_array_test.go b/pgtype/timestamp_array_test.go index 214c8a71..c354f0cf 100644 --- a/pgtype/timestamp_array_test.go +++ b/pgtype/timestamp_array_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTimestampArrayTranscode(t *testing.T) { diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 88e2bca8..f906462d 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/stretchr/testify/require" ) @@ -34,7 +34,7 @@ func TestTimestampTranscode(t *testing.T) { }) } -// https://github.com/jackc/pgtype/pull/128 +// https://github.com/jackc/pgx/v4/pgtype/pull/128 func TestTimestampTranscodeBigTimeBinary(t *testing.T) { conn := testutil.MustConnectPgx(t) if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { @@ -99,7 +99,7 @@ func TestTimestampNanosecondsTruncated(t *testing.T) { } } -// https://github.com/jackc/pgtype/issues/74 +// https://github.com/jackc/pgx/v4/pgtype/issues/74 func TestTimestampDecodeTextInvalid(t *testing.T) { tstz := &pgtype.Timestamp{} err := tstz.DecodeText(nil, []byte(`eeeee`)) diff --git a/pgtype/timestamptz_array_test.go b/pgtype/timestamptz_array_test.go index 22e07b59..fd40cd35 100644 --- a/pgtype/timestamptz_array_test.go +++ b/pgtype/timestamptz_array_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTimestamptzArrayTranscode(t *testing.T) { diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index fa2a7e89..8815d363 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/stretchr/testify/require" ) @@ -34,7 +34,7 @@ func TestTimestamptzTranscode(t *testing.T) { }) } -// https://github.com/jackc/pgtype/pull/128 +// https://github.com/jackc/pgx/v4/pgtype/pull/128 func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) { conn := testutil.MustConnectPgx(t) if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { @@ -99,7 +99,7 @@ func TestTimestamptzNanosecondsTruncated(t *testing.T) { } } -// https://github.com/jackc/pgtype/issues/74 +// https://github.com/jackc/pgx/v4/pgtype/issues/74 func TestTimestamptzDecodeTextInvalid(t *testing.T) { tstz := &pgtype.Timestamptz{} err := tstz.DecodeText(nil, []byte(`eeeee`)) diff --git a/pgtype/tsrange_test.go b/pgtype/tsrange_test.go index daea59bb..f24f824b 100644 --- a/pgtype/tsrange_test.go +++ b/pgtype/tsrange_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestTsrangeTranscode(t *testing.T) { diff --git a/pgtype/tstzrange_test.go b/pgtype/tstzrange_test.go index 49cfc63e..bf604ed5 100644 --- a/pgtype/tstzrange_test.go +++ b/pgtype/tstzrange_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/stretchr/testify/require" ) @@ -41,7 +41,7 @@ func TestTstzrangeTranscode(t *testing.T) { }) } -// https://github.com/jackc/pgtype/issues/74 +// https://github.com/jackc/pgx/v4/pgtype/issues/74 func TestTstzRangeDecodeTextInvalid(t *testing.T) { tstzrange := &pgtype.Tstzrange{} err := tstzrange.DecodeText(nil, []byte(`[eeee,)`)) diff --git a/pgtype/uuid_array_test.go b/pgtype/uuid_array_test.go index 47afadff..b4ec2f86 100644 --- a/pgtype/uuid_array_test.go +++ b/pgtype/uuid_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestUUIDArrayTranscode(t *testing.T) { diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 63797178..036c0dd8 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" "github.com/stretchr/testify/require" ) diff --git a/pgtype/varbit_test.go b/pgtype/varbit_test.go index b81bdc0e..1ca5357b 100644 --- a/pgtype/varbit_test.go +++ b/pgtype/varbit_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestVarbitTranscode(t *testing.T) { diff --git a/pgtype/varchar_array_test.go b/pgtype/varchar_array_test.go index cf0efd6d..c45162a0 100644 --- a/pgtype/varchar_array_test.go +++ b/pgtype/varchar_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestVarcharArrayTranscode(t *testing.T) { diff --git a/pgtype/workflows/ci.yml b/pgtype/workflows/ci.yml deleted file mode 100644 index 4b5a72f2..00000000 --- a/pgtype/workflows/ci.yml +++ /dev/null @@ -1,52 +0,0 @@ -name: CI - -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - -jobs: - - test: - name: Test - runs-on: ubuntu-latest - - services: - postgres: - image: postgres - env: - POSTGRES_PASSWORD: secret - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 - - steps: - - - name: Set up Go 1.x - uses: actions/setup-go@v2 - with: - go-version: ^1.13 - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - - name: Create hstore extension - run: psql -c 'create extension hstore' - env: - PGHOST: localhost - PGUSER: postgres - PGPASSWORD: secret - PGSSLMODE: disable - - - name: Test - run: go test -v ./... - env: - PGHOST: localhost - PGUSER: postgres - PGPASSWORD: secret - PGSSLMODE: disable diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go index fab10f79..5b30753a 100644 --- a/pgtype/xid_test.go +++ b/pgtype/xid_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgtype" - "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v4/pgtype/testutil" ) func TestXIDTranscode(t *testing.T) { diff --git a/pgtype/zeronull/float8.go b/pgtype/zeronull/float8.go index 07d5e1a5..3d9d4d22 100644 --- a/pgtype/zeronull/float8.go +++ b/pgtype/zeronull/float8.go @@ -3,7 +3,7 @@ package zeronull import ( "database/sql/driver" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" ) type Float8 float64 diff --git a/pgtype/zeronull/float8_test.go b/pgtype/zeronull/float8_test.go index 27fb785e..cdc51245 100644 --- a/pgtype/zeronull/float8_test.go +++ b/pgtype/zeronull/float8_test.go @@ -3,8 +3,8 @@ package zeronull_test import ( "testing" - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" + "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype/zeronull" ) func TestFloat8Transcode(t *testing.T) { diff --git a/pgtype/zeronull/int2.go b/pgtype/zeronull/int2.go index b3f9c328..011e96d5 100644 --- a/pgtype/zeronull/int2.go +++ b/pgtype/zeronull/int2.go @@ -3,7 +3,7 @@ package zeronull import ( "database/sql/driver" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" ) type Int2 int16 diff --git a/pgtype/zeronull/int2_test.go b/pgtype/zeronull/int2_test.go index 2dcb4e79..9cbd75db 100644 --- a/pgtype/zeronull/int2_test.go +++ b/pgtype/zeronull/int2_test.go @@ -3,8 +3,8 @@ package zeronull_test import ( "testing" - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" + "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype/zeronull" ) func TestInt2Transcode(t *testing.T) { diff --git a/pgtype/zeronull/int4.go b/pgtype/zeronull/int4.go index 3efca4e6..9d34c163 100644 --- a/pgtype/zeronull/int4.go +++ b/pgtype/zeronull/int4.go @@ -3,7 +3,7 @@ package zeronull import ( "database/sql/driver" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" ) type Int4 int32 diff --git a/pgtype/zeronull/int4_test.go b/pgtype/zeronull/int4_test.go index 309e4125..456f15d2 100644 --- a/pgtype/zeronull/int4_test.go +++ b/pgtype/zeronull/int4_test.go @@ -3,8 +3,8 @@ package zeronull_test import ( "testing" - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" + "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype/zeronull" ) func TestInt4Transcode(t *testing.T) { diff --git a/pgtype/zeronull/int8.go b/pgtype/zeronull/int8.go index 5cb063d8..185fdb8f 100644 --- a/pgtype/zeronull/int8.go +++ b/pgtype/zeronull/int8.go @@ -3,7 +3,7 @@ package zeronull import ( "database/sql/driver" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" ) type Int8 int64 diff --git a/pgtype/zeronull/int8_test.go b/pgtype/zeronull/int8_test.go index ae80bc0a..ca261d36 100644 --- a/pgtype/zeronull/int8_test.go +++ b/pgtype/zeronull/int8_test.go @@ -3,8 +3,8 @@ package zeronull_test import ( "testing" - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" + "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype/zeronull" ) func TestInt8Transcode(t *testing.T) { diff --git a/pgtype/zeronull/text.go b/pgtype/zeronull/text.go index afcb1a42..5fc9d94b 100644 --- a/pgtype/zeronull/text.go +++ b/pgtype/zeronull/text.go @@ -3,7 +3,7 @@ package zeronull import ( "database/sql/driver" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" ) type Text string diff --git a/pgtype/zeronull/text_test.go b/pgtype/zeronull/text_test.go index f08a0d2a..8595253c 100644 --- a/pgtype/zeronull/text_test.go +++ b/pgtype/zeronull/text_test.go @@ -3,8 +3,8 @@ package zeronull_test import ( "testing" - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" + "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype/zeronull" ) func TestTextTranscode(t *testing.T) { diff --git a/pgtype/zeronull/timestamp.go b/pgtype/zeronull/timestamp.go index 61787818..193bc959 100644 --- a/pgtype/zeronull/timestamp.go +++ b/pgtype/zeronull/timestamp.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "time" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" ) type Timestamp time.Time diff --git a/pgtype/zeronull/timestamp_test.go b/pgtype/zeronull/timestamp_test.go index ec96ff07..787c6de9 100644 --- a/pgtype/zeronull/timestamp_test.go +++ b/pgtype/zeronull/timestamp_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" + "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype/zeronull" ) func TestTimestampTranscode(t *testing.T) { diff --git a/pgtype/zeronull/timestamptz.go b/pgtype/zeronull/timestamptz.go index 4896e9b7..5ecefe64 100644 --- a/pgtype/zeronull/timestamptz.go +++ b/pgtype/zeronull/timestamptz.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "time" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" ) type Timestamptz time.Time diff --git a/pgtype/zeronull/timestamptz_test.go b/pgtype/zeronull/timestamptz_test.go index 3a401c49..dcbd0d58 100644 --- a/pgtype/zeronull/timestamptz_test.go +++ b/pgtype/zeronull/timestamptz_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" + "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype/zeronull" ) func TestTimestamptzTranscode(t *testing.T) { diff --git a/pgtype/zeronull/uuid.go b/pgtype/zeronull/uuid.go index 25211122..2e54a933 100644 --- a/pgtype/zeronull/uuid.go +++ b/pgtype/zeronull/uuid.go @@ -3,7 +3,7 @@ package zeronull import ( "database/sql/driver" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" ) type UUID [16]byte diff --git a/pgtype/zeronull/uuid_test.go b/pgtype/zeronull/uuid_test.go index 162bdf1f..e79503c6 100644 --- a/pgtype/zeronull/uuid_test.go +++ b/pgtype/zeronull/uuid_test.go @@ -3,8 +3,8 @@ package zeronull_test import ( "testing" - "github.com/jackc/pgtype/testutil" - "github.com/jackc/pgtype/zeronull" + "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v4/pgtype/zeronull" ) func TestUUIDTranscode(t *testing.T) { diff --git a/query_test.go b/query_test.go index 968c0ecc..7580d8b0 100644 --- a/query_test.go +++ b/query_test.go @@ -17,8 +17,8 @@ import ( "github.com/gofrs/uuid" "github.com/jackc/pgconn" "github.com/jackc/pgconn/stmtcache" - "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -257,7 +257,7 @@ func TestConnQueryReadRowMultipleTimes(t *testing.T) { require.Equal(t, "foo", a) require.Equal(t, "bar", b) require.Equal(t, rowCount, c) - require.Equal(t, pgtype.Null, d.Status) + require.False(t, d.Valid) require.Equal(t, rowCount, e) } } @@ -275,22 +275,22 @@ func TestConnQueryValuesWithMultipleComplexColumnsOfSameType(t *testing.T) { expected0 := &pgtype.Int8Array{ Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, } expected1 := &pgtype.Int8Array{ Elements: []pgtype.Int8{ - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, - Status: pgtype.Present, + Valid: true, } var rowCount int32 @@ -1792,7 +1792,7 @@ func TestConnSimpleProtocol(t *testing.T) { { if conn.PgConn().ParameterStatus("crdb_version") == "" { // CockroachDB doesn't support circle type. - expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present} + expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Valid: true} actual := expected err := conn.QueryRow( context.Background(), diff --git a/rows.go b/rows.go index d57d5cbf..539ce3a5 100644 --- a/rows.go +++ b/rows.go @@ -8,7 +8,7 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" ) // Rows is the result set returned from *Conn.Query. Rows must be closed before diff --git a/stdlib/sql.go b/stdlib/sql.go index fa81e73d..20892ab3 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -64,8 +64,8 @@ import ( "time" "github.com/jackc/pgconn" - "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgtype" ) // Only intrinsic types should be binary format with database/sql. diff --git a/values.go b/values.go index 1a945475..2978e5a3 100644 --- a/values.go +++ b/values.go @@ -8,7 +8,7 @@ import ( "time" "github.com/jackc/pgio" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/pgtype" ) // PostgreSQL format codes @@ -228,15 +228,15 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 // determination can be made. func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 { switch arg := arg.(type) { - case pgtype.ParamFormatPreferrer: - return arg.PreferredParamFormat() + case pgtype.FormatSupport: + return arg.PreferredFormat() case pgtype.BinaryEncoder: return BinaryFormatCode case string, *string, pgtype.TextEncoder: return TextFormatCode } - return ci.ParamFormatCodeForOID(oid) + return ci.FormatCodeForOID(oid) } func stripNamedType(val *reflect.Value) (interface{}, bool) { diff --git a/values_test.go b/values_test.go index 6ae6c8a0..47aacf89 100644 --- a/values_test.go +++ b/values_test.go @@ -942,50 +942,50 @@ func TestEncodeTypeRename(t *testing.T) { }) } -func TestRowDecodeBinary(t *testing.T) { - t.Parallel() +// func TestRowDecodeBinary(t *testing.T) { +// t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) +// conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) +// defer closeConn(t, conn) - tests := []struct { - sql string - expected []interface{} - }{ - { - "select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)", - []interface{}{ - int32(1), - "cat", - time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(), - }, - }, - { - "select row(100.0::float, 1.09::float)", - []interface{}{ - float64(100), - float64(1.09), - }, - }, - } +// tests := []struct { +// sql string +// expected []interface{} +// }{ +// { +// "select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)", +// []interface{}{ +// int32(1), +// "cat", +// time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(), +// }, +// }, +// { +// "select row(100.0::float, 1.09::float)", +// []interface{}{ +// float64(100), +// float64(1.09), +// }, +// }, +// } - for i, tt := range tests { - var actual []interface{} +// for i, tt := range tests { +// var actual []interface{} - err := conn.QueryRow(context.Background(), tt.sql).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) - continue - } +// err := conn.QueryRow(context.Background(), tt.sql).Scan(&actual) +// if err != nil { +// t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) +// continue +// } - for j := range tt.expected { - assert.EqualValuesf(t, tt.expected[j], actual[j], "%d. [%d]", i, j) +// for j := range tt.expected { +// assert.EqualValuesf(t, tt.expected[j], actual[j], "%d. [%d]", i, j) - } +// } - ensureConnValid(t, conn) - } -} +// ensureConnValid(t, conn) +// } +// } // https://github.com/jackc/pgx/issues/810 func TestRowsScanNilThenScanValue(t *testing.T) { From 19ec4d505ffaf0d1fecefb733c722c319e5df081 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Dec 2021 13:51:24 -0600 Subject: [PATCH 0754/1158] Import to pgx main repo in pgconn subdir --- {.github => pgconn/.github}/workflows/ci.yml | 0 .gitignore => pgconn/.gitignore | 0 CHANGELOG.md => pgconn/CHANGELOG.md | 0 LICENSE => pgconn/LICENSE | 0 README.md => pgconn/README.md | 0 auth_scram.go => pgconn/auth_scram.go | 0 benchmark_test.go => pgconn/benchmark_test.go | 0 {ci => pgconn/ci}/script.bash | 0 {ci => pgconn/ci}/setup_test.bash | 0 config.go => pgconn/config.go | 0 config_test.go => pgconn/config_test.go | 0 defaults.go => pgconn/defaults.go | 0 defaults_windows.go => pgconn/defaults_windows.go | 0 doc.go => pgconn/doc.go | 0 errors.go => pgconn/errors.go | 0 errors_test.go => pgconn/errors_test.go | 0 export_test.go => pgconn/export_test.go | 0 frontend_test.go => pgconn/frontend_test.go | 0 go.mod => pgconn/go.mod | 0 go.sum => pgconn/go.sum | 0 helper_test.go => pgconn/helper_test.go | 0 {internal => pgconn/internal}/ctxwatch/context_watcher.go | 0 {internal => pgconn/internal}/ctxwatch/context_watcher_test.go | 0 pgconn.go => pgconn/pgconn.go | 0 pgconn_stress_test.go => pgconn/pgconn_stress_test.go | 0 pgconn_test.go => pgconn/pgconn_test.go | 0 {stmtcache => pgconn/stmtcache}/lru.go | 0 {stmtcache => pgconn/stmtcache}/lru_test.go | 0 {stmtcache => pgconn/stmtcache}/stmtcache.go | 0 29 files changed, 0 insertions(+), 0 deletions(-) rename {.github => pgconn/.github}/workflows/ci.yml (100%) rename .gitignore => pgconn/.gitignore (100%) rename CHANGELOG.md => pgconn/CHANGELOG.md (100%) rename LICENSE => pgconn/LICENSE (100%) rename README.md => pgconn/README.md (100%) rename auth_scram.go => pgconn/auth_scram.go (100%) rename benchmark_test.go => pgconn/benchmark_test.go (100%) rename {ci => pgconn/ci}/script.bash (100%) rename {ci => pgconn/ci}/setup_test.bash (100%) rename config.go => pgconn/config.go (100%) rename config_test.go => pgconn/config_test.go (100%) rename defaults.go => pgconn/defaults.go (100%) rename defaults_windows.go => pgconn/defaults_windows.go (100%) rename doc.go => pgconn/doc.go (100%) rename errors.go => pgconn/errors.go (100%) rename errors_test.go => pgconn/errors_test.go (100%) rename export_test.go => pgconn/export_test.go (100%) rename frontend_test.go => pgconn/frontend_test.go (100%) rename go.mod => pgconn/go.mod (100%) rename go.sum => pgconn/go.sum (100%) rename helper_test.go => pgconn/helper_test.go (100%) rename {internal => pgconn/internal}/ctxwatch/context_watcher.go (100%) rename {internal => pgconn/internal}/ctxwatch/context_watcher_test.go (100%) rename pgconn.go => pgconn/pgconn.go (100%) rename pgconn_stress_test.go => pgconn/pgconn_stress_test.go (100%) rename pgconn_test.go => pgconn/pgconn_test.go (100%) rename {stmtcache => pgconn/stmtcache}/lru.go (100%) rename {stmtcache => pgconn/stmtcache}/lru_test.go (100%) rename {stmtcache => pgconn/stmtcache}/stmtcache.go (100%) diff --git a/.github/workflows/ci.yml b/pgconn/.github/workflows/ci.yml similarity index 100% rename from .github/workflows/ci.yml rename to pgconn/.github/workflows/ci.yml diff --git a/.gitignore b/pgconn/.gitignore similarity index 100% rename from .gitignore rename to pgconn/.gitignore diff --git a/CHANGELOG.md b/pgconn/CHANGELOG.md similarity index 100% rename from CHANGELOG.md rename to pgconn/CHANGELOG.md diff --git a/LICENSE b/pgconn/LICENSE similarity index 100% rename from LICENSE rename to pgconn/LICENSE diff --git a/README.md b/pgconn/README.md similarity index 100% rename from README.md rename to pgconn/README.md diff --git a/auth_scram.go b/pgconn/auth_scram.go similarity index 100% rename from auth_scram.go rename to pgconn/auth_scram.go diff --git a/benchmark_test.go b/pgconn/benchmark_test.go similarity index 100% rename from benchmark_test.go rename to pgconn/benchmark_test.go diff --git a/ci/script.bash b/pgconn/ci/script.bash similarity index 100% rename from ci/script.bash rename to pgconn/ci/script.bash diff --git a/ci/setup_test.bash b/pgconn/ci/setup_test.bash similarity index 100% rename from ci/setup_test.bash rename to pgconn/ci/setup_test.bash diff --git a/config.go b/pgconn/config.go similarity index 100% rename from config.go rename to pgconn/config.go diff --git a/config_test.go b/pgconn/config_test.go similarity index 100% rename from config_test.go rename to pgconn/config_test.go diff --git a/defaults.go b/pgconn/defaults.go similarity index 100% rename from defaults.go rename to pgconn/defaults.go diff --git a/defaults_windows.go b/pgconn/defaults_windows.go similarity index 100% rename from defaults_windows.go rename to pgconn/defaults_windows.go diff --git a/doc.go b/pgconn/doc.go similarity index 100% rename from doc.go rename to pgconn/doc.go diff --git a/errors.go b/pgconn/errors.go similarity index 100% rename from errors.go rename to pgconn/errors.go diff --git a/errors_test.go b/pgconn/errors_test.go similarity index 100% rename from errors_test.go rename to pgconn/errors_test.go diff --git a/export_test.go b/pgconn/export_test.go similarity index 100% rename from export_test.go rename to pgconn/export_test.go diff --git a/frontend_test.go b/pgconn/frontend_test.go similarity index 100% rename from frontend_test.go rename to pgconn/frontend_test.go diff --git a/go.mod b/pgconn/go.mod similarity index 100% rename from go.mod rename to pgconn/go.mod diff --git a/go.sum b/pgconn/go.sum similarity index 100% rename from go.sum rename to pgconn/go.sum diff --git a/helper_test.go b/pgconn/helper_test.go similarity index 100% rename from helper_test.go rename to pgconn/helper_test.go diff --git a/internal/ctxwatch/context_watcher.go b/pgconn/internal/ctxwatch/context_watcher.go similarity index 100% rename from internal/ctxwatch/context_watcher.go rename to pgconn/internal/ctxwatch/context_watcher.go diff --git a/internal/ctxwatch/context_watcher_test.go b/pgconn/internal/ctxwatch/context_watcher_test.go similarity index 100% rename from internal/ctxwatch/context_watcher_test.go rename to pgconn/internal/ctxwatch/context_watcher_test.go diff --git a/pgconn.go b/pgconn/pgconn.go similarity index 100% rename from pgconn.go rename to pgconn/pgconn.go diff --git a/pgconn_stress_test.go b/pgconn/pgconn_stress_test.go similarity index 100% rename from pgconn_stress_test.go rename to pgconn/pgconn_stress_test.go diff --git a/pgconn_test.go b/pgconn/pgconn_test.go similarity index 100% rename from pgconn_test.go rename to pgconn/pgconn_test.go diff --git a/stmtcache/lru.go b/pgconn/stmtcache/lru.go similarity index 100% rename from stmtcache/lru.go rename to pgconn/stmtcache/lru.go diff --git a/stmtcache/lru_test.go b/pgconn/stmtcache/lru_test.go similarity index 100% rename from stmtcache/lru_test.go rename to pgconn/stmtcache/lru_test.go diff --git a/stmtcache/stmtcache.go b/pgconn/stmtcache/stmtcache.go similarity index 100% rename from stmtcache/stmtcache.go rename to pgconn/stmtcache/stmtcache.go From 0e293b966c56e88027b1316651e7623611e84892 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Dec 2021 14:06:57 -0600 Subject: [PATCH 0755/1158] Finish import of pgconn --- .github/workflows/ci.yml | 37 +++++ README.md | 2 +- batch.go | 2 +- batch_test.go | 4 +- bench_test.go | 4 +- ci/setup_test.bash | 11 ++ conn.go | 4 +- conn_test.go | 4 +- copy_from.go | 2 +- copy_from_test.go | 2 +- go.mod | 9 +- go.sum | 16 ++- helper_test.go | 2 +- large_objects_test.go | 2 +- pgbouncer_test.go | 4 +- pgconn/.github/workflows/ci.yml | 81 ----------- pgconn/.gitignore | 3 - pgconn/CHANGELOG.md | 129 ----------------- pgconn/LICENSE | 22 --- pgconn/README.md | 3 - pgconn/benchmark_test.go | 2 +- pgconn/ci/script.bash | 10 -- pgconn/ci/setup_test.bash | 59 -------- pgconn/config_test.go | 2 +- pgconn/errors_test.go | 2 +- pgconn/frontend_test.go | 2 +- pgconn/go.mod | 15 -- pgconn/go.sum | 130 ------------------ pgconn/helper_test.go | 2 +- .../internal/ctxwatch/context_watcher_test.go | 2 +- pgconn/pgconn.go | 2 +- pgconn/pgconn_stress_test.go | 2 +- pgconn/pgconn_test.go | 2 +- pgconn/stmtcache/lru.go | 2 +- pgconn/stmtcache/lru_test.go | 4 +- pgconn/stmtcache/stmtcache.go | 2 +- pgtype/pgxtype/pgxtype.go | 2 +- pgxpool/batch_results.go | 2 +- pgxpool/common_test.go | 2 +- pgxpool/conn.go | 2 +- pgxpool/pool.go | 2 +- pgxpool/rows.go | 2 +- pgxpool/tx.go | 2 +- query_test.go | 4 +- rows.go | 2 +- stdlib/sql.go | 2 +- stdlib/sql_test.go | 2 +- tx.go | 2 +- tx_test.go | 2 +- 49 files changed, 109 insertions(+), 502 deletions(-) delete mode 100644 pgconn/.github/workflows/ci.yml delete mode 100644 pgconn/.gitignore delete mode 100644 pgconn/CHANGELOG.md delete mode 100644 pgconn/LICENSE delete mode 100755 pgconn/ci/script.bash delete mode 100755 pgconn/ci/setup_test.bash delete mode 100644 pgconn/go.mod delete mode 100644 pgconn/go.sum diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b66bba46..36a9ec4e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,16 +19,47 @@ jobs: include: - pg-version: 10 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 11 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 12 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 13 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 14 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: cockroachdb pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" + pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" steps: @@ -49,3 +80,9 @@ jobs: run: go test -race ./... env: PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }} + PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }} + PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }} + PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }} + PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }} + PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }} + PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }} diff --git a/README.md b/README.md index 5a65a00f..63a0cff8 100644 --- a/README.md +++ b/README.md @@ -160,7 +160,7 @@ pgx follows semantic versioning for the documented public API on stable releases pgx is the head of a family of PostgreSQL libraries. Many of these can be used independently. Many can also be accessed from pgx for lower-level control. -### [github.com/jackc/pgconn](https://github.com/jackc/pgconn) +### [github.com/jackc/v4/pgconn](https://github.com/jackc/pgx/tree/master/pgconn) `pgconn` is a lower-level PostgreSQL database driver that operates at nearly the same level as the C library `libpq`. diff --git a/batch.go b/batch.go index f0479ea6..547a0efc 100644 --- a/batch.go +++ b/batch.go @@ -4,7 +4,7 @@ import ( "context" "errors" - "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4/pgconn" ) type batchItem struct { diff --git a/batch_test.go b/batch_test.go index 988a1682..c0eb32d5 100644 --- a/batch_test.go +++ b/batch_test.go @@ -5,9 +5,9 @@ import ( "os" "testing" - "github.com/jackc/pgconn" - "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v4/pgconn/stmtcache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/bench_test.go b/bench_test.go index 1a35a1a4..40e24da3 100644 --- a/bench_test.go +++ b/bench_test.go @@ -12,9 +12,9 @@ import ( "testing" "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v4/pgconn/stmtcache" "github.com/jackc/pgx/v4/pgtype" "github.com/stretchr/testify/require" ) diff --git a/ci/setup_test.bash b/ci/setup_test.bash index c279a7a4..8f3f26f0 100755 --- a/ci/setup_test.bash +++ b/ci/setup_test.bash @@ -13,6 +13,10 @@ then echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf @@ -24,8 +28,15 @@ then psql -U postgres -c 'create database pgx_test' psql -U postgres pgx_test -c 'create extension hstore' psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' + psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user `whoami`" + psql -U postgres -c "create user pgx_replication with replication password 'secret'" + + # The tricky test user, below, has to actually exist so that it can be used in a test + # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. + psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" fi if [[ "${PGVERSION-}" =~ ^cockroach ]] diff --git a/conn.go b/conn.go index b0cbf72b..22cc4f74 100644 --- a/conn.go +++ b/conn.go @@ -8,10 +8,10 @@ import ( "strings" "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgx/v4/internal/sanitize" + "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v4/pgconn/stmtcache" "github.com/jackc/pgx/v4/pgtype" ) diff --git a/conn_test.go b/conn_test.go index d18ad1d9..fa764b18 100644 --- a/conn_test.go +++ b/conn_test.go @@ -8,9 +8,9 @@ import ( "testing" "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v4/pgconn/stmtcache" "github.com/jackc/pgx/v4/pgtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/copy_from.go b/copy_from.go index 3494e28f..52d69cdd 100644 --- a/copy_from.go +++ b/copy_from.go @@ -7,8 +7,8 @@ import ( "io" "time" - "github.com/jackc/pgconn" "github.com/jackc/pgio" + "github.com/jackc/pgx/v4/pgconn" ) // CopyFromRows returns a CopyFromSource interface over the provided rows slice diff --git a/copy_from_test.go b/copy_from_test.go index 20e5b247..b514fd8b 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -8,8 +8,8 @@ import ( "testing" "time" - "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" "github.com/stretchr/testify/require" ) diff --git a/go.mod b/go.mod index 577cc76a..ebcdc2f5 100644 --- a/go.mod +++ b/go.mod @@ -7,14 +7,21 @@ require ( github.com/cockroachdb/apd v1.1.0 github.com/go-kit/log v0.1.0 github.com/gofrs/uuid v4.0.0+incompatible - github.com/jackc/pgconn v1.10.1 + github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/pgio v1.0.0 + github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 + github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.2.0 + github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/jackc/puddle v1.2.0 + github.com/lib/pq v1.10.4 // indirect + github.com/mattn/go-colorable v0.1.12 // indirect github.com/rs/zerolog v1.15.0 github.com/shopspring/decimal v1.2.0 github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.7.0 go.uber.org/zap v1.13.0 + golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 + golang.org/x/text v0.3.6 gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec ) diff --git a/go.sum b/go.sum index e95937cc..97b704ee 100644 --- a/go.sum +++ b/go.sum @@ -28,9 +28,8 @@ github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9 github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0 h1:gqibKSTJup/ahCsNKyMZAniPuZEfIqfXFc8FOWVYR+Q= github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= -github.com/jackc/pgconn v1.10.1 h1:DzdIHIjG1AxGwoEEqS+mGsURyjt4enSmqzACXvVzOT8= -github.com/jackc/pgconn v1.10.1/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= @@ -73,13 +72,16 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-colorable v0.1.1 h1:G1f5SKeVxmagw/IyvzvtZE4Gybcc4Tr1tf7I8z0XgOg= +github.com/lib/pq v1.10.4 h1:SO9z7FRPzA03QhHKJrH5BXA6HU1rS4V2nIVrrNC1iYk= +github.com/lib/pq v1.10.4/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.7 h1:UvyT9uN+3r7yLEYSlJsbQGdsaB/a0DlgWP3pql6iwOc= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -146,8 +148,10 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 h1:foEbQz/B0Oz6YIqu/69kfXPYeFQAuuMYFkjaqXzl5Wo= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/helper_test.go b/helper_test.go index b9cd21c1..25280ac3 100644 --- a/helper_test.go +++ b/helper_test.go @@ -7,8 +7,8 @@ import ( "github.com/stretchr/testify/assert" - "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" "github.com/stretchr/testify/require" ) diff --git a/large_objects_test.go b/large_objects_test.go index 672729ee..167e50d3 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" ) func TestLargeObjects(t *testing.T) { diff --git a/pgbouncer_test.go b/pgbouncer_test.go index e3fa4d0c..abc349a1 100644 --- a/pgbouncer_test.go +++ b/pgbouncer_test.go @@ -5,9 +5,9 @@ import ( "os" "testing" - "github.com/jackc/pgconn" - "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v4/pgconn/stmtcache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgconn/.github/workflows/ci.yml b/pgconn/.github/workflows/ci.yml deleted file mode 100644 index d84462da..00000000 --- a/pgconn/.github/workflows/ci.yml +++ /dev/null @@ -1,81 +0,0 @@ -name: CI - -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - -jobs: - - test: - name: Test - runs-on: ubuntu-18.04 - - strategy: - matrix: - go-version: [1.15, 1.16] - pg-version: [9.6, 10, 11, 12, 13, cockroachdb] - include: - - pg-version: 9.6 - pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" - pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - - pg-version: 10 - pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" - pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - - pg-version: 11 - pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" - pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - - pg-version: 12 - pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" - pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - - pg-version: 13 - pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" - pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - - pg-version: cockroachdb - pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" - - steps: - - - name: Set up Go 1.x - uses: actions/setup-go@v2 - with: - go-version: ${{ matrix.go-version }} - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - - name: Setup database server for testing - run: ci/setup_test.bash - env: - PGVERSION: ${{ matrix.pg-version }} - - - name: Test - run: go test -v -race ./... - env: - PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }} - PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }} - PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }} - PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }} - PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }} - PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }} diff --git a/pgconn/.gitignore b/pgconn/.gitignore deleted file mode 100644 index e980f555..00000000 --- a/pgconn/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -.envrc -vendor/ -.vscode diff --git a/pgconn/CHANGELOG.md b/pgconn/CHANGELOG.md deleted file mode 100644 index 63933a3a..00000000 --- a/pgconn/CHANGELOG.md +++ /dev/null @@ -1,129 +0,0 @@ -# 1.10.1 (November 20, 2021) - -* Close without waiting for response (Kei Kamikawa) -* Save waiting for network round-trip in CopyFrom (Rueian) -* Fix concurrency issue with ContextWatcher -* LRU.Get always checks context for cancellation / expiration (Georges Varouchas) - -# 1.10.0 (July 24, 2021) - -* net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned. - -# 1.9.0 (July 10, 2021) - -* pgconn.Timeout only is true for errors originating in pgconn (Michael Darr) -* Add defaults for sslcert, sslkey, and sslrootcert (Joshua Brindle) -* Solve issue with 'sslmode=verify-full' when there are multiple hosts (mgoddard) -* Fix default host when parsing URL without host but with port -* Allow dbname query parameter in URL conn string -* Update underlying dependencies - -# 1.8.1 (March 25, 2021) - -* Better connection string sanitization (ip.novikov) -* Use proper pgpass location on Windows (Moshe Katz) -* Use errors instead of golang.org/x/xerrors -* Resume fallback on server error in Connect (Andrey Borodin) - -# 1.8.0 (December 3, 2020) - -* Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes) - -# 1.7.2 (November 3, 2020) - -* Fix data value slices into work buffer with capacities larger than length. - -# 1.7.1 (October 31, 2020) - -* Do not asyncClose after receiving FATAL error from PostgreSQL server - -# 1.7.0 (September 26, 2020) - -* Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded -* Add ReceiveResults (Sebastiaan Mannem) -* Fix parsing DSN connection with bad backslash -* Add PgConn.CleanupDone so connection pools can determine when async close is complete - -# 1.6.4 (July 29, 2020) - -* Fix deadlock on error after CommandComplete but before ReadyForQuery -* Fix panic on parsing DSN with trailing '=' - -# 1.6.3 (July 22, 2020) - -* Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo) - -# 1.6.2 (July 14, 2020) - -* Update pgservicefile library - -# 1.6.1 (June 27, 2020) - -* Update golang.org/x/crypto to latest -* Update golang.org/x/text to 0.3.3 -* Fix error handling for bad PGSERVICE definition -* Redact passwords in ParseConfig errors (Lukas Vogel) - -# 1.6.0 (June 6, 2020) - -* Fix panic when closing conn during cancellable query -* Fix behavior of sslmode=require with sslrootcert present (Petr Jediný) -* Fix field descriptions available after command concluded (Tobias Salzmann) -* Support connect_timeout (georgysavva) -* Handle IPv6 in connection URLs (Lukas Vogel) -* Fix ValidateConnect with cancelable context -* Improve CopyFrom performance -* Add Config.Copy (georgysavva) - -# 1.5.0 (March 30, 2020) - -* Update golang.org/x/crypto for security fix -* Implement "verify-ca" SSL mode (Greg Curtis) - -# 1.4.0 (March 7, 2020) - -* Fix ExecParams and ExecPrepared handling of empty query. -* Support reading config from PostgreSQL service files. - -# 1.3.2 (February 14, 2020) - -* Update chunkreader to v2.0.1 for optimized default buffer size. - -# 1.3.1 (February 5, 2020) - -* Fix CopyFrom deadlock when multiple NoticeResponse received during copy - -# 1.3.0 (January 23, 2020) - -* Add Hijack and Construct. -* Update pgproto3 to v2.0.1. - -# 1.2.1 (January 13, 2020) - -* Fix data race in context cancellation introduced in v1.2.0. - -# 1.2.0 (January 11, 2020) - -## Features - -* Add Insert(), Update(), Delete(), and Select() statement type query methods to CommandTag. -* Add PgError.SQLState method. This could be used for compatibility with other drivers and databases. - -## Performance - -* Improve performance when context.Background() is used. (bakape) -* CommandTag.RowsAffected is faster and does not allocate. - -## Fixes - -* Try to cancel any in-progress query when a conn is closed by ctx cancel. -* Handle NoticeResponse during CopyFrom. -* Ignore errors sending Terminate message while closing connection. This mimics the behavior of libpq PGfinish. - -# 1.1.0 (October 12, 2019) - -* Add PgConn.IsBusy() method. - -# 1.0.1 (September 19, 2019) - -* Fix statement cache not properly cleaning discarded statements. diff --git a/pgconn/LICENSE b/pgconn/LICENSE deleted file mode 100644 index aebadd6c..00000000 --- a/pgconn/LICENSE +++ /dev/null @@ -1,22 +0,0 @@ -Copyright (c) 2019-2021 Jack Christensen - -MIT License - -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/pgconn/README.md b/pgconn/README.md index 1c698a11..4f0349f2 100644 --- a/pgconn/README.md +++ b/pgconn/README.md @@ -1,6 +1,3 @@ -[![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn) -![CI](https://github.com/jackc/pgconn/workflows/CI/badge.svg) - # pgconn Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq. diff --git a/pgconn/benchmark_test.go b/pgconn/benchmark_test.go index ced785b6..48955cf6 100644 --- a/pgconn/benchmark_test.go +++ b/pgconn/benchmark_test.go @@ -7,7 +7,7 @@ import ( "strings" "testing" - "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4/pgconn" "github.com/stretchr/testify/require" ) diff --git a/pgconn/ci/script.bash b/pgconn/ci/script.bash deleted file mode 100755 index 5bf1b77e..00000000 --- a/pgconn/ci/script.bash +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env bash -set -eux - -if [ "${PGVERSION-}" != "" ] -then - go test -v -race ./... -elif [ "${CRATEVERSION-}" != "" ] -then - go test -v -race -run 'TestCrateDBConnect' -fi diff --git a/pgconn/ci/setup_test.bash b/pgconn/ci/setup_test.bash deleted file mode 100755 index f71bd98c..00000000 --- a/pgconn/ci/setup_test.bash +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env bash -set -eux - -if [[ "${PGVERSION-}" =~ ^[0-9.]+$ ]] -then - sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common - sudo rm -rf /var/lib/postgresql - wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - - sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" - sudo apt-get update -qq - sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION - sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf - if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then - echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf - echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf - echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf - fi - sudo /etc/init.d/postgresql restart - - # The tricky test user, below, has to actually exist so that it can be used in a test - # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. - psql -U postgres -c 'create database pgx_test' - psql -U postgres pgx_test -c 'create extension hstore' - psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' - psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user `whoami`" - psql -U postgres -c "create user pgx_replication with replication password 'secret'" - psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" -fi - -if [[ "${PGVERSION-}" =~ ^cockroach ]] -then - wget -qO- https://binaries.cockroachdb.com/cockroach-v20.2.5.linux-amd64.tgz | tar xvz - sudo mv cockroach-v20.2.5.linux-amd64/cockroach /usr/local/bin/ - cockroach start-single-node --insecure --background --listen-addr=localhost - cockroach sql --insecure -e 'create database pgx_test' -fi - -if [ "${CRATEVERSION-}" != "" ] -then - docker run \ - -p "6543:5432" \ - -d \ - crate:"$CRATEVERSION" \ - crate \ - -Cnetwork.host=0.0.0.0 \ - -Ctransport.host=localhost \ - -Clicense.enterprise=false -fi diff --git a/pgconn/config_test.go b/pgconn/config_test.go index d29173d1..367c2d3e 100644 --- a/pgconn/config_test.go +++ b/pgconn/config_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgconn/errors_test.go b/pgconn/errors_test.go index 1bff3656..dafe4332 100644 --- a/pgconn/errors_test.go +++ b/pgconn/errors_test.go @@ -3,7 +3,7 @@ package pgconn_test import ( "testing" - "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4/pgconn" "github.com/stretchr/testify/assert" ) diff --git a/pgconn/frontend_test.go b/pgconn/frontend_test.go index b82552bf..f1c3830c 100644 --- a/pgconn/frontend_test.go +++ b/pgconn/frontend_test.go @@ -6,8 +6,8 @@ import ( "os" "testing" - "github.com/jackc/pgconn" "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgx/v4/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgconn/go.mod b/pgconn/go.mod deleted file mode 100644 index 6fdd0e97..00000000 --- a/pgconn/go.mod +++ /dev/null @@ -1,15 +0,0 @@ -module github.com/jackc/pgconn - -go 1.12 - -require ( - github.com/jackc/chunkreader/v2 v2.0.1 - github.com/jackc/pgio v1.0.0 - github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 - github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.1.1 - github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b - github.com/stretchr/testify v1.7.0 - golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 - golang.org/x/text v0.3.6 -) diff --git a/pgconn/go.sum b/pgconn/go.sum deleted file mode 100644 index 3c77ee21..00000000 --- a/pgconn/go.sum +++ /dev/null @@ -1,130 +0,0 @@ -github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= -github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= -github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= -github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= -github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= -github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= -github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= -github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= -github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= -github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= -github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= -github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= -github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= -github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= -github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= -github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= -github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= -github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= -github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= -github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= -github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= -github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= -github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= -github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= -github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= -github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= -github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= -github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= -github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= -github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= -github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= -github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= -github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= -github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= -github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= -github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= -golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= -golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pgconn/helper_test.go b/pgconn/helper_test.go index 87613dc9..eb4eaa6b 100644 --- a/pgconn/helper_test.go +++ b/pgconn/helper_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/pgconn/internal/ctxwatch/context_watcher_test.go b/pgconn/internal/ctxwatch/context_watcher_test.go index 289606c3..d9061812 100644 --- a/pgconn/internal/ctxwatch/context_watcher_test.go +++ b/pgconn/internal/ctxwatch/context_watcher_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/jackc/pgconn/internal/ctxwatch" + "github.com/jackc/pgx/v4/pgconn/internal/ctxwatch" "github.com/stretchr/testify/require" ) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 382ad33c..9b0e2735 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -15,9 +15,9 @@ import ( "sync" "time" - "github.com/jackc/pgconn/internal/ctxwatch" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgx/v4/pgconn/internal/ctxwatch" ) const ( diff --git a/pgconn/pgconn_stress_test.go b/pgconn/pgconn_stress_test.go index 356b529a..83847593 100644 --- a/pgconn/pgconn_stress_test.go +++ b/pgconn/pgconn_stress_test.go @@ -8,7 +8,7 @@ import ( "strconv" "testing" - "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4/pgconn" "github.com/stretchr/testify/require" ) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index c20b7425..79ded806 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -18,9 +18,9 @@ import ( "testing" "time" - "github.com/jackc/pgconn" "github.com/jackc/pgmock" "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgx/v4/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgconn/stmtcache/lru.go b/pgconn/stmtcache/lru.go index 90fb76c2..7935a8c9 100644 --- a/pgconn/stmtcache/lru.go +++ b/pgconn/stmtcache/lru.go @@ -6,7 +6,7 @@ import ( "fmt" "sync/atomic" - "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4/pgconn" ) var lruCount uint64 diff --git a/pgconn/stmtcache/lru_test.go b/pgconn/stmtcache/lru_test.go index f594ceac..9131defb 100644 --- a/pgconn/stmtcache/lru_test.go +++ b/pgconn/stmtcache/lru_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - "github.com/jackc/pgconn" - "github.com/jackc/pgconn/stmtcache" + "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v4/pgconn/stmtcache" "github.com/stretchr/testify/require" ) diff --git a/pgconn/stmtcache/stmtcache.go b/pgconn/stmtcache/stmtcache.go index d083e1b4..91538cf5 100644 --- a/pgconn/stmtcache/stmtcache.go +++ b/pgconn/stmtcache/stmtcache.go @@ -4,7 +4,7 @@ package stmtcache import ( "context" - "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4/pgconn" ) const ( diff --git a/pgtype/pgxtype/pgxtype.go b/pgtype/pgxtype/pgxtype.go index db4d8926..a16c0389 100644 --- a/pgtype/pgxtype/pgxtype.go +++ b/pgtype/pgxtype/pgxtype.go @@ -4,8 +4,8 @@ import ( "context" "errors" - "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" "github.com/jackc/pgx/v4/pgtype" ) diff --git a/pgxpool/batch_results.go b/pgxpool/batch_results.go index c625a474..42f597a5 100644 --- a/pgxpool/batch_results.go +++ b/pgxpool/batch_results.go @@ -1,8 +1,8 @@ package pgxpool import ( - "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" ) type errBatchResults struct { diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index c701e1f7..8ef85052 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -7,8 +7,8 @@ import ( "github.com/jackc/pgx/v4/pgxpool" - "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgxpool/conn.go b/pgxpool/conn.go index 0b59d741..fee1ff2b 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -4,8 +4,8 @@ import ( "context" "time" - "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" "github.com/jackc/puddle" ) diff --git a/pgxpool/pool.go b/pgxpool/pool.go index f287ad88..84d05ad9 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -8,8 +8,8 @@ import ( "sync" "time" - "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" "github.com/jackc/puddle" ) diff --git a/pgxpool/rows.go b/pgxpool/rows.go index 6dc0cc34..cc45c76f 100644 --- a/pgxpool/rows.go +++ b/pgxpool/rows.go @@ -1,9 +1,9 @@ package pgxpool import ( - "github.com/jackc/pgconn" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" ) type errRows struct { diff --git a/pgxpool/tx.go b/pgxpool/tx.go index 6f566e41..962b8725 100644 --- a/pgxpool/tx.go +++ b/pgxpool/tx.go @@ -3,8 +3,8 @@ package pgxpool import ( "context" - "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" ) // Tx represents a database transaction acquired from a Pool. diff --git a/query_test.go b/query_test.go index 7580d8b0..393235b7 100644 --- a/query_test.go +++ b/query_test.go @@ -15,9 +15,9 @@ import ( "github.com/cockroachdb/apd" "github.com/gofrs/uuid" - "github.com/jackc/pgconn" - "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v4/pgconn/stmtcache" "github.com/jackc/pgx/v4/pgtype" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" diff --git a/rows.go b/rows.go index 539ce3a5..14bc50ba 100644 --- a/rows.go +++ b/rows.go @@ -6,8 +6,8 @@ import ( "fmt" "time" - "github.com/jackc/pgconn" "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgx/v4/pgconn" "github.com/jackc/pgx/v4/pgtype" ) diff --git a/stdlib/sql.go b/stdlib/sql.go index 20892ab3..c9e36eef 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -63,8 +63,8 @@ import ( "sync" "time" - "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" "github.com/jackc/pgx/v4/pgtype" ) diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 099320c0..6b6440f7 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -13,8 +13,8 @@ import ( "time" "github.com/Masterminds/semver/v3" - "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" "github.com/jackc/pgx/v4/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/tx.go b/tx.go index 1971ed67..44311f1f 100644 --- a/tx.go +++ b/tx.go @@ -7,7 +7,7 @@ import ( "fmt" "strconv" - "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4/pgconn" ) // TxIsoLevel is the transaction isolation level (serializable, repeatable read, read committed or read uncommitted) diff --git a/tx_test.go b/tx_test.go index e9830d32..85083830 100644 --- a/tx_test.go +++ b/tx_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgconn" "github.com/stretchr/testify/require" ) From fbbf403cf21674d220ab0a65a0b6e23fc5ae43f0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 08:56:41 -0600 Subject: [PATCH 0756/1158] Update changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f285400..bd4001dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Unreleased v5 -* Import pgtype repository +* Import github.com/jackc/pgtype repository +* Import github.com/jackc/pgconn repository ## pgtype Changes From d9e53647ecb714638934e134f3fca3ddd0d83f79 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 09:08:05 -0600 Subject: [PATCH 0757/1158] Use ideomatic casing --- pgtype/integration_benchmark_test.go.erb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pgtype/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb index d9bb7937..761087b7 100644 --- a/pgtype/integration_benchmark_test.go.erb +++ b/pgtype/integration_benchmark_test.go.erb @@ -18,8 +18,8 @@ import ( %> <% go_types.each do |go_type| %> <% rows_columns.each do |rows, columns| %> -<% [["Text", "pgx.TextFormatCode"], ["Binary", "pgx.BinaryFormatCode"]].each do |formatName, formatCode| %> -func BenchmarkQuery<%= formatName %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go_type.gsub(/\W/, "_") %>_<%= rows %>_rows_<%= columns %>_columns(b *testing.B) { +<% [["Text", "pgx.TextFormatCode"], ["Binary", "pgx.BinaryFormatCode"]].each do |format_name, format_code| %> +func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go_type.gsub(/\W/, "_") %>_<%= rows %>_rows_<%= columns %>_columns(b *testing.B) { conn := testutil.MustConnectPgx(b) defer testutil.MustCloseContext(b, conn) @@ -29,7 +29,7 @@ func BenchmarkQuery<%= formatName %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go_ _, err := conn.QueryFunc( context.Background(), `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, - []interface{}{pgx.QueryResultFormats{<%= formatCode %>}}, + []interface{}{pgx.QueryResultFormats{<%= format_code %>}}, []interface{}{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, func(pgx.QueryFuncRow) error { return nil }, ) From 390bd79757c4342a320b64f2e04296d6a2636ebf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 09:19:11 -0600 Subject: [PATCH 0758/1158] Add array integration benchmarks --- pgtype/integration_benchmark_test.go | 276 +++++++++++++++++++++++ pgtype/integration_benchmark_test.go.erb | 50 ++++ 2 files changed, 326 insertions(+) diff --git a/pgtype/integration_benchmark_test.go b/pgtype/integration_benchmark_test.go index cca6dd1e..0ee87ba3 100644 --- a/pgtype/integration_benchmark_test.go +++ b/pgtype/integration_benchmark_test.go @@ -1290,3 +1290,279 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_1 } } } + +func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_ArrayType_10(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), + Name: "_int4", + OID: pgtype.Int4ArrayOID, + }) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_ArrayType_10(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), + Name: "_int4", + OID: pgtype.Int4ArrayOID, + }) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_ArrayType_100(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), + Name: "_int4", + OID: pgtype.Int4ArrayOID, + }) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_ArrayType_100(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), + Name: "_int4", + OID: pgtype.Int4ArrayOID, + }) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, 1000) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_ArrayType_1000(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), + Name: "_int4", + OID: pgtype.Int4ArrayOID, + }) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, 1000) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, 1000) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_ArrayType_1000(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), + Name: "_int4", + OID: pgtype.Int4ArrayOID, + }) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, 1000) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/pgtype/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb index 761087b7..f642e6ca 100644 --- a/pgtype/integration_benchmark_test.go.erb +++ b/pgtype/integration_benchmark_test.go.erb @@ -42,3 +42,53 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go <% end %> <% end %> <% end %> + +<% [10, 100, 1000].each do |array_size| %> +<% [["Text", "pgx.TextFormatCode"], ["Binary", "pgx.BinaryFormatCode"]].each do |format_name, format_code| %> +func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_Int4Array_<%= array_size %>(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, <%= array_size %>) n`, + []interface{}{pgx.QueryResultFormats{<%= format_code %>}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_ArrayType_<%= array_size %>(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), + Name: "_int4", + OID: pgtype.Int4ArrayOID, + }) + + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select array_agg(n) from generate_series(1, <%= array_size %>) n`, + []interface{}{pgx.QueryResultFormats{<%= format_code %>}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} +<% end %> +<% end %> From 72cc95e4dd18576691da38c46c7e6a101869c10a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 13:29:03 -0600 Subject: [PATCH 0759/1158] Bump module version to v5 --- README.md | 8 ++++---- batch.go | 2 +- batch_test.go | 6 +++--- bench_test.go | 8 ++++---- conn.go | 8 ++++---- conn_test.go | 8 ++++---- copy_from.go | 2 +- copy_from_test.go | 4 ++-- example_custom_type_test.go | 4 ++-- example_json_test.go | 2 +- examples/chat/main.go | 2 +- examples/todo/main.go | 2 +- examples/url_shortener/main.go | 6 +++--- extended_query_builder.go | 2 +- go.mod | 2 +- helper_test.go | 4 ++-- internal/sanitize/sanitize_test.go | 2 +- large_objects_test.go | 4 ++-- log/kitlogadapter/adapter.go | 2 +- log/log15adapter/adapter.go | 2 +- log/logrusadapter/adapter.go | 2 +- log/testingadapter/adapter.go | 2 +- log/zapadapter/adapter.go | 2 +- log/zerologadapter/adapter.go | 2 +- log/zerologadapter/adapter_test.go | 4 ++-- messages.go | 2 +- pgbouncer_test.go | 6 +++--- pgconn/benchmark_test.go | 2 +- pgconn/config_test.go | 2 +- pgconn/errors_test.go | 2 +- pgconn/frontend_test.go | 2 +- pgconn/helper_test.go | 2 +- pgconn/internal/ctxwatch/context_watcher_test.go | 2 +- pgconn/pgconn.go | 2 +- pgconn/pgconn_stress_test.go | 2 +- pgconn/pgconn_test.go | 2 +- pgconn/stmtcache/lru.go | 2 +- pgconn/stmtcache/lru_test.go | 4 ++-- pgconn/stmtcache/stmtcache.go | 2 +- pgtype/aclitem_array_test.go | 4 ++-- pgtype/aclitem_test.go | 4 ++-- pgtype/array_test.go | 2 +- pgtype/array_type_test.go | 4 ++-- pgtype/bit_test.go | 4 ++-- pgtype/bool_array_test.go | 4 ++-- pgtype/bool_test.go | 4 ++-- pgtype/box_test.go | 4 ++-- pgtype/bpchar_array_test.go | 4 ++-- pgtype/bpchar_test.go | 4 ++-- pgtype/bytea_array_test.go | 4 ++-- pgtype/bytea_test.go | 4 ++-- pgtype/cid_test.go | 4 ++-- pgtype/cidr_array_test.go | 4 ++-- pgtype/circle_test.go | 4 ++-- pgtype/composite_bench_test.go | 2 +- pgtype/composite_fields_test.go | 6 +++--- pgtype/composite_type_test.go | 6 +++--- pgtype/custom_composite_test.go | 4 ++-- pgtype/date_array_test.go | 4 ++-- pgtype/date_test.go | 4 ++-- pgtype/daterange_test.go | 4 ++-- pgtype/enum_array_test.go | 4 ++-- pgtype/enum_type_test.go | 6 +++--- pgtype/float4_array_test.go | 4 ++-- pgtype/float4_test.go | 4 ++-- pgtype/float8_array_test.go | 4 ++-- pgtype/float8_test.go | 4 ++-- pgtype/hstore_array_test.go | 6 +++--- pgtype/hstore_test.go | 4 ++-- pgtype/inet_array_test.go | 4 ++-- pgtype/inet_test.go | 4 ++-- pgtype/int2_array_test.go | 4 ++-- pgtype/int2_test.go | 4 ++-- pgtype/int4_array_test.go | 4 ++-- pgtype/int4_test.go | 4 ++-- pgtype/int4range_test.go | 4 ++-- pgtype/int8_array_test.go | 4 ++-- pgtype/int8_test.go | 4 ++-- pgtype/int8range_test.go | 4 ++-- pgtype/integration_benchmark_test.go | 6 +++--- pgtype/integration_benchmark_test.go.erb | 4 ++-- pgtype/interval_test.go | 4 ++-- pgtype/json_test.go | 4 ++-- pgtype/jsonb_array_test.go | 4 ++-- pgtype/jsonb_test.go | 4 ++-- pgtype/line_test.go | 4 ++-- pgtype/lseg_test.go | 4 ++-- pgtype/macaddr_array_test.go | 4 ++-- pgtype/macaddr_test.go | 4 ++-- pgtype/name_test.go | 4 ++-- pgtype/numeric_array_test.go | 4 ++-- pgtype/numeric_test.go | 6 +++--- pgtype/numrange_test.go | 4 ++-- pgtype/oid_value_test.go | 4 ++-- pgtype/path_test.go | 4 ++-- pgtype/pgtype_test.go | 6 +++--- pgtype/pgxtype/pgxtype.go | 6 +++--- pgtype/point_test.go | 4 ++-- pgtype/polygon_test.go | 4 ++-- pgtype/qchar_test.go | 4 ++-- pgtype/record_test.go | 6 +++--- pgtype/testutil/testutil.go | 6 +++--- pgtype/text_array_test.go | 4 ++-- pgtype/text_test.go | 4 ++-- pgtype/tid_test.go | 4 ++-- pgtype/time_test.go | 4 ++-- pgtype/timestamp_array_test.go | 4 ++-- pgtype/timestamp_test.go | 4 ++-- pgtype/timestamptz_array_test.go | 4 ++-- pgtype/timestamptz_test.go | 4 ++-- pgtype/tsrange_test.go | 4 ++-- pgtype/tstzrange_test.go | 4 ++-- pgtype/uuid_array_test.go | 4 ++-- pgtype/uuid_test.go | 4 ++-- pgtype/varbit_test.go | 4 ++-- pgtype/varchar_array_test.go | 4 ++-- pgtype/xid_test.go | 4 ++-- pgtype/zeronull/float8.go | 2 +- pgtype/zeronull/float8_test.go | 4 ++-- pgtype/zeronull/int2.go | 2 +- pgtype/zeronull/int2_test.go | 4 ++-- pgtype/zeronull/int4.go | 2 +- pgtype/zeronull/int4_test.go | 4 ++-- pgtype/zeronull/int8.go | 2 +- pgtype/zeronull/int8_test.go | 4 ++-- pgtype/zeronull/text.go | 2 +- pgtype/zeronull/text_test.go | 4 ++-- pgtype/zeronull/timestamp.go | 2 +- pgtype/zeronull/timestamp_test.go | 4 ++-- pgtype/zeronull/timestamptz.go | 2 +- pgtype/zeronull/timestamptz_test.go | 4 ++-- pgtype/zeronull/uuid.go | 2 +- pgtype/zeronull/uuid_test.go | 4 ++-- pgxpool/batch_results.go | 4 ++-- pgxpool/bench_test.go | 4 ++-- pgxpool/common_test.go | 6 +++--- pgxpool/conn.go | 4 ++-- pgxpool/conn_test.go | 2 +- pgxpool/pool.go | 4 ++-- pgxpool/pool_test.go | 4 ++-- pgxpool/rows.go | 4 ++-- pgxpool/tx.go | 4 ++-- pgxpool/tx_test.go | 2 +- query_test.go | 8 ++++---- rows.go | 4 ++-- stdlib/sql.go | 6 +++--- stdlib/sql_test.go | 6 +++--- tx.go | 2 +- tx_test.go | 4 ++-- values.go | 2 +- values_test.go | 2 +- 151 files changed, 287 insertions(+), 287 deletions(-) diff --git a/README.md b/README.md index 63a0cff8..03f23bdb 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://pkg.go.dev/github.com/jackc/pgx/v4) +[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5) [![Build Status](https://travis-ci.org/jackc/pgx.svg)](https://travis-ci.org/jackc/pgx) # pgx - PostgreSQL Driver and Toolkit @@ -25,7 +25,7 @@ import ( "fmt" "os" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" ) func main() { @@ -164,11 +164,11 @@ from pgx for lower-level control. `pgconn` is a lower-level PostgreSQL database driver that operates at nearly the same level as the C library `libpq`. -### [github.com/jackc/pgx/v4/pgxpool](https://github.com/jackc/pgx/tree/master/pgxpool) +### [github.com/jackc/pgx/v5/pgxpool](https://github.com/jackc/pgx/tree/master/pgxpool) `pgxpool` is a connection pool for pgx. pgx is entirely decoupled from its default pool implementation. This means that pgx can be used with a different pool or without any pool at all. -### [github.com/jackc/pgx/v4/stdlib](https://github.com/jackc/pgx/tree/master/stdlib) +### [github.com/jackc/pgx/v5/stdlib](https://github.com/jackc/pgx/tree/master/stdlib) This is a `database/sql` compatibility layer for pgx. pgx can be used as a normal `database/sql` driver, but at any time, the native interface can be acquired for more performance or PostgreSQL specific functionality. diff --git a/batch.go b/batch.go index 547a0efc..18ee8339 100644 --- a/batch.go +++ b/batch.go @@ -4,7 +4,7 @@ import ( "context" "errors" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5/pgconn" ) type batchItem struct { diff --git a/batch_test.go b/batch_test.go index c0eb32d5..dc57b379 100644 --- a/batch_test.go +++ b/batch_test.go @@ -5,9 +5,9 @@ import ( "os" "testing" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" - "github.com/jackc/pgx/v4/pgconn/stmtcache" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgconn/stmtcache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/bench_test.go b/bench_test.go index 40e24da3..06cfd0c4 100644 --- a/bench_test.go +++ b/bench_test.go @@ -12,10 +12,10 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" - "github.com/jackc/pgx/v4/pgconn/stmtcache" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgconn/stmtcache" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" ) diff --git a/conn.go b/conn.go index 22cc4f74..4412e174 100644 --- a/conn.go +++ b/conn.go @@ -9,10 +9,10 @@ import ( "time" "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgx/v4/internal/sanitize" - "github.com/jackc/pgx/v4/pgconn" - "github.com/jackc/pgx/v4/pgconn/stmtcache" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/internal/sanitize" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgconn/stmtcache" + "github.com/jackc/pgx/v5/pgtype" ) // ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and diff --git a/conn_test.go b/conn_test.go index fa764b18..857fd828 100644 --- a/conn_test.go +++ b/conn_test.go @@ -8,10 +8,10 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" - "github.com/jackc/pgx/v4/pgconn/stmtcache" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgconn/stmtcache" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/copy_from.go b/copy_from.go index 52d69cdd..0bf3478f 100644 --- a/copy_from.go +++ b/copy_from.go @@ -8,7 +8,7 @@ import ( "time" "github.com/jackc/pgio" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5/pgconn" ) // CopyFromRows returns a CopyFromSource interface over the provided rows slice diff --git a/copy_from_test.go b/copy_from_test.go index b514fd8b..32644f38 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -8,8 +8,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/require" ) diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 7723df6a..10014278 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -7,8 +7,8 @@ import ( "regexp" "strconv" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" ) var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) diff --git a/example_json_test.go b/example_json_test.go index 33bd7519..017699b9 100644 --- a/example_json_test.go +++ b/example_json_test.go @@ -5,7 +5,7 @@ import ( "fmt" "os" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" ) func Example_JSON() { diff --git a/examples/chat/main.go b/examples/chat/main.go index 6be4ee1c..6e705fb6 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -6,7 +6,7 @@ import ( "fmt" "os" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5/pgxpool" ) var pool *pgxpool.Pool diff --git a/examples/todo/main.go b/examples/todo/main.go index 9aa8c1cb..6c644ede 100644 --- a/examples/todo/main.go +++ b/examples/todo/main.go @@ -6,7 +6,7 @@ import ( "os" "strconv" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" ) var conn *pgx.Conn diff --git a/examples/url_shortener/main.go b/examples/url_shortener/main.go index c5e87eb3..cb474f32 100644 --- a/examples/url_shortener/main.go +++ b/examples/url_shortener/main.go @@ -6,9 +6,9 @@ import ( "net/http" "os" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/log/log15adapter" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/log/log15adapter" + "github.com/jackc/pgx/v5/pgxpool" log "gopkg.in/inconshreveable/log15.v2" ) diff --git a/extended_query_builder.go b/extended_query_builder.go index 53ea75ff..1420a808 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -5,7 +5,7 @@ import ( "fmt" "reflect" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgtype" ) type extendedQueryBuilder struct { diff --git a/go.mod b/go.mod index ebcdc2f5..737db8cc 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/jackc/pgx/v4 +module github.com/jackc/pgx/v5 go 1.13 diff --git a/helper_test.go b/helper_test.go index 25280ac3..74c17431 100644 --- a/helper_test.go +++ b/helper_test.go @@ -7,8 +7,8 @@ import ( "github.com/stretchr/testify/assert" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/require" ) diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index 344c46b0..acfac2ec 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/internal/sanitize" + "github.com/jackc/pgx/v5/internal/sanitize" ) func TestNewQuery(t *testing.T) { diff --git a/large_objects_test.go b/large_objects_test.go index 167e50d3..e42a90e7 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) func TestLargeObjects(t *testing.T) { diff --git a/log/kitlogadapter/adapter.go b/log/kitlogadapter/adapter.go index 0a46197f..95fc6c49 100644 --- a/log/kitlogadapter/adapter.go +++ b/log/kitlogadapter/adapter.go @@ -5,7 +5,7 @@ import ( "github.com/go-kit/log" kitlevel "github.com/go-kit/log/level" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" ) type Logger struct { diff --git a/log/log15adapter/adapter.go b/log/log15adapter/adapter.go index 70608e33..dece65b4 100644 --- a/log/log15adapter/adapter.go +++ b/log/log15adapter/adapter.go @@ -5,7 +5,7 @@ package log15adapter import ( "context" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" ) // Log15Logger interface defines the subset of diff --git a/log/logrusadapter/adapter.go b/log/logrusadapter/adapter.go index e0cd6328..65a64230 100644 --- a/log/logrusadapter/adapter.go +++ b/log/logrusadapter/adapter.go @@ -5,7 +5,7 @@ package logrusadapter import ( "context" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" "github.com/sirupsen/logrus" ) diff --git a/log/testingadapter/adapter.go b/log/testingadapter/adapter.go index 3ddce5a1..aa1b4bd6 100644 --- a/log/testingadapter/adapter.go +++ b/log/testingadapter/adapter.go @@ -6,7 +6,7 @@ import ( "context" "fmt" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" ) // TestingLogger interface defines the subset of testing.TB methods used by this diff --git a/log/zapadapter/adapter.go b/log/zapadapter/adapter.go index ebc540aa..4dc47cd0 100644 --- a/log/zapadapter/adapter.go +++ b/log/zapadapter/adapter.go @@ -4,7 +4,7 @@ package zapadapter import ( "context" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) diff --git a/log/zerologadapter/adapter.go b/log/zerologadapter/adapter.go index 6e8b4b94..b93036fe 100644 --- a/log/zerologadapter/adapter.go +++ b/log/zerologadapter/adapter.go @@ -4,7 +4,7 @@ package zerologadapter import ( "context" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" "github.com/rs/zerolog" ) diff --git a/log/zerologadapter/adapter_test.go b/log/zerologadapter/adapter_test.go index 3a11cbc0..152b2129 100644 --- a/log/zerologadapter/adapter_test.go +++ b/log/zerologadapter/adapter_test.go @@ -5,8 +5,8 @@ import ( "context" "testing" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/log/zerologadapter" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/log/zerologadapter" "github.com/rs/zerolog" ) diff --git a/messages.go b/messages.go index d7af6973..87c0aa22 100644 --- a/messages.go +++ b/messages.go @@ -3,7 +3,7 @@ package pgx import ( "database/sql/driver" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgtype" ) func convertDriverValuers(args []interface{}) ([]interface{}, error) { diff --git a/pgbouncer_test.go b/pgbouncer_test.go index abc349a1..eeae6db4 100644 --- a/pgbouncer_test.go +++ b/pgbouncer_test.go @@ -5,9 +5,9 @@ import ( "os" "testing" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" - "github.com/jackc/pgx/v4/pgconn/stmtcache" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgconn/stmtcache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgconn/benchmark_test.go b/pgconn/benchmark_test.go index 48955cf6..088a9bd9 100644 --- a/pgconn/benchmark_test.go +++ b/pgconn/benchmark_test.go @@ -7,7 +7,7 @@ import ( "strings" "testing" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/require" ) diff --git a/pgconn/config_test.go b/pgconn/config_test.go index 367c2d3e..3ffa384f 100644 --- a/pgconn/config_test.go +++ b/pgconn/config_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgconn/errors_test.go b/pgconn/errors_test.go index dafe4332..9d559346 100644 --- a/pgconn/errors_test.go +++ b/pgconn/errors_test.go @@ -3,7 +3,7 @@ package pgconn_test import ( "testing" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" ) diff --git a/pgconn/frontend_test.go b/pgconn/frontend_test.go index f1c3830c..9ea53b10 100644 --- a/pgconn/frontend_test.go +++ b/pgconn/frontend_test.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgconn/helper_test.go b/pgconn/helper_test.go index eb4eaa6b..0696f4ce 100644 --- a/pgconn/helper_test.go +++ b/pgconn/helper_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/pgconn/internal/ctxwatch/context_watcher_test.go b/pgconn/internal/ctxwatch/context_watcher_test.go index d9061812..39652995 100644 --- a/pgconn/internal/ctxwatch/context_watcher_test.go +++ b/pgconn/internal/ctxwatch/context_watcher_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgconn/internal/ctxwatch" + "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/stretchr/testify/require" ) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 9b0e2735..52577bb0 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -17,7 +17,7 @@ import ( "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgx/v4/pgconn/internal/ctxwatch" + "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" ) const ( diff --git a/pgconn/pgconn_stress_test.go b/pgconn/pgconn_stress_test.go index 83847593..3d72964f 100644 --- a/pgconn/pgconn_stress_test.go +++ b/pgconn/pgconn_stress_test.go @@ -8,7 +8,7 @@ import ( "strconv" "testing" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/require" ) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 79ded806..22f0c26f 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -20,7 +20,7 @@ import ( "github.com/jackc/pgmock" "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgconn/stmtcache/lru.go b/pgconn/stmtcache/lru.go index 7935a8c9..938ddee4 100644 --- a/pgconn/stmtcache/lru.go +++ b/pgconn/stmtcache/lru.go @@ -6,7 +6,7 @@ import ( "fmt" "sync/atomic" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5/pgconn" ) var lruCount uint64 diff --git a/pgconn/stmtcache/lru_test.go b/pgconn/stmtcache/lru_test.go index 9131defb..549e7670 100644 --- a/pgconn/stmtcache/lru_test.go +++ b/pgconn/stmtcache/lru_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgconn" - "github.com/jackc/pgx/v4/pgconn/stmtcache" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgconn/stmtcache" "github.com/stretchr/testify/require" ) diff --git a/pgconn/stmtcache/stmtcache.go b/pgconn/stmtcache/stmtcache.go index 91538cf5..a2582019 100644 --- a/pgconn/stmtcache/stmtcache.go +++ b/pgconn/stmtcache/stmtcache.go @@ -4,7 +4,7 @@ package stmtcache import ( "context" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5/pgconn" ) const ( diff --git a/pgtype/aclitem_array_test.go b/pgtype/aclitem_array_test.go index 736d0967..bdae9b56 100644 --- a/pgtype/aclitem_array_test.go +++ b/pgtype/aclitem_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestACLItemArrayTranscode(t *testing.T) { diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go index afc6a1e3..84388142 100644 --- a/pgtype/aclitem_test.go +++ b/pgtype/aclitem_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestACLItemTranscode(t *testing.T) { diff --git a/pgtype/array_test.go b/pgtype/array_test.go index 549467bf..77700ad6 100644 --- a/pgtype/array_test.go +++ b/pgtype/array_test.go @@ -4,7 +4,7 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" ) diff --git a/pgtype/array_type_test.go b/pgtype/array_type_test.go index 62b25dfa..3ea5bc79 100644 --- a/pgtype/array_type_test.go +++ b/pgtype/array_type_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/require" ) diff --git a/pgtype/bit_test.go b/pgtype/bit_test.go index 51c12765..2f07c3c9 100644 --- a/pgtype/bit_test.go +++ b/pgtype/bit_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestBitTranscode(t *testing.T) { diff --git a/pgtype/bool_array_test.go b/pgtype/bool_array_test.go index 9278c864..7de5612a 100644 --- a/pgtype/bool_array_test.go +++ b/pgtype/bool_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestBoolArrayTranscode(t *testing.T) { diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index f323c5e7..9a07491f 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestBoolTranscode(t *testing.T) { diff --git a/pgtype/box_test.go b/pgtype/box_test.go index d6a928c9..481723b5 100644 --- a/pgtype/box_test.go +++ b/pgtype/box_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestBoxTranscode(t *testing.T) { diff --git a/pgtype/bpchar_array_test.go b/pgtype/bpchar_array_test.go index 4714b261..0118ad7d 100644 --- a/pgtype/bpchar_array_test.go +++ b/pgtype/bpchar_array_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestBPCharArrayTranscode(t *testing.T) { diff --git a/pgtype/bpchar_test.go b/pgtype/bpchar_test.go index 68fbfc9f..ead26220 100644 --- a/pgtype/bpchar_test.go +++ b/pgtype/bpchar_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestChar3Transcode(t *testing.T) { diff --git a/pgtype/bytea_array_test.go b/pgtype/bytea_array_test.go index d081db11..08b69c26 100644 --- a/pgtype/bytea_array_test.go +++ b/pgtype/bytea_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestByteaArrayTranscode(t *testing.T) { diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index b87b3c96..21751e24 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestByteaTranscode(t *testing.T) { diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go index e915e534..3d3ad2a5 100644 --- a/pgtype/cid_test.go +++ b/pgtype/cid_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestCIDTranscode(t *testing.T) { diff --git a/pgtype/cidr_array_test.go b/pgtype/cidr_array_test.go index 93d3933d..550bf9d1 100644 --- a/pgtype/cidr_array_test.go +++ b/pgtype/cidr_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestCIDRArrayTranscode(t *testing.T) { diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go index 2e5a8c86..8f39644b 100644 --- a/pgtype/circle_test.go +++ b/pgtype/circle_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestCircleTranscode(t *testing.T) { diff --git a/pgtype/composite_bench_test.go b/pgtype/composite_bench_test.go index 92330905..ef57709b 100644 --- a/pgtype/composite_bench_test.go +++ b/pgtype/composite_bench_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/jackc/pgio" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" ) diff --git a/pgtype/composite_fields_test.go b/pgtype/composite_fields_test.go index 07b3954e..e73d8441 100644 --- a/pgtype/composite_fields_test.go +++ b/pgtype/composite_fields_test.go @@ -4,9 +4,9 @@ import ( "context" "testing" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgtype/composite_type_test.go b/pgtype/composite_type_test.go index 80dd4a5c..a41ad0f4 100644 --- a/pgtype/composite_type_test.go +++ b/pgtype/composite_type_test.go @@ -6,9 +6,9 @@ import ( "os" "testing" - pgx "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgtype/custom_composite_test.go b/pgtype/custom_composite_test.go index 0cc14442..e5f2166e 100644 --- a/pgtype/custom_composite_test.go +++ b/pgtype/custom_composite_test.go @@ -6,8 +6,8 @@ import ( "fmt" "os" - pgx "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgtype" + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" ) type MyType struct { diff --git a/pgtype/date_array_test.go b/pgtype/date_array_test.go index 93e423a0..39ee32ce 100644 --- a/pgtype/date_array_test.go +++ b/pgtype/date_array_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestDateArrayTranscode(t *testing.T) { diff --git a/pgtype/date_test.go b/pgtype/date_test.go index 0df84468..caccfc47 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestDateTranscode(t *testing.T) { diff --git a/pgtype/daterange_test.go b/pgtype/daterange_test.go index d0bb8d60..a6501372 100644 --- a/pgtype/daterange_test.go +++ b/pgtype/daterange_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestDaterangeTranscode(t *testing.T) { diff --git a/pgtype/enum_array_test.go b/pgtype/enum_array_test.go index 7b9c4d23..6e49aaaf 100644 --- a/pgtype/enum_array_test.go +++ b/pgtype/enum_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestEnumArrayTranscode(t *testing.T) { diff --git a/pgtype/enum_type_test.go b/pgtype/enum_type_test.go index 965f713a..903b742f 100644 --- a/pgtype/enum_type_test.go +++ b/pgtype/enum_type_test.go @@ -5,9 +5,9 @@ import ( "context" "testing" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgtype/float4_array_test.go b/pgtype/float4_array_test.go index d28cd38c..ecd65206 100644 --- a/pgtype/float4_array_test.go +++ b/pgtype/float4_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestFloat4ArrayTranscode(t *testing.T) { diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go index 3ad480f5..aab6e980 100644 --- a/pgtype/float4_test.go +++ b/pgtype/float4_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestFloat4Transcode(t *testing.T) { diff --git a/pgtype/float8_array_test.go b/pgtype/float8_array_test.go index 6fc85993..66a10784 100644 --- a/pgtype/float8_array_test.go +++ b/pgtype/float8_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestFloat8ArrayTranscode(t *testing.T) { diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go index 2bc8de0c..e7bd4444 100644 --- a/pgtype/float8_test.go +++ b/pgtype/float8_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestFloat8Transcode(t *testing.T) { diff --git a/pgtype/hstore_array_test.go b/pgtype/hstore_array_test.go index b164f598..7912b626 100644 --- a/pgtype/hstore_array_test.go +++ b/pgtype/hstore_array_test.go @@ -5,9 +5,9 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestHstoreArrayTranscode(t *testing.T) { diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index 6faf7496..dd80f0c5 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestHstoreTranscode(t *testing.T) { diff --git a/pgtype/inet_array_test.go b/pgtype/inet_array_test.go index b111746d..da7ee975 100644 --- a/pgtype/inet_array_test.go +++ b/pgtype/inet_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestInetArrayTranscode(t *testing.T) { diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index 56705e4d..d4716479 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/assert" ) diff --git a/pgtype/int2_array_test.go b/pgtype/int2_array_test.go index e5366edd..110968fc 100644 --- a/pgtype/int2_array_test.go +++ b/pgtype/int2_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestInt2ArrayTranscode(t *testing.T) { diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go index 26f43eec..58dcd141 100644 --- a/pgtype/int2_test.go +++ b/pgtype/int2_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestInt2Transcode(t *testing.T) { diff --git a/pgtype/int4_array_test.go b/pgtype/int4_array_test.go index bcabe8ca..906e4775 100644 --- a/pgtype/int4_array_test.go +++ b/pgtype/int4_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestInt4ArrayTranscode(t *testing.T) { diff --git a/pgtype/int4_test.go b/pgtype/int4_test.go index cdff4b44..118c3ac5 100644 --- a/pgtype/int4_test.go +++ b/pgtype/int4_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestInt4Transcode(t *testing.T) { diff --git a/pgtype/int4range_test.go b/pgtype/int4range_test.go index a45e4779..1a11e039 100644 --- a/pgtype/int4range_test.go +++ b/pgtype/int4range_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestInt4rangeTranscode(t *testing.T) { diff --git a/pgtype/int8_array_test.go b/pgtype/int8_array_test.go index c4de8bb1..2d875b24 100644 --- a/pgtype/int8_array_test.go +++ b/pgtype/int8_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestInt8ArrayTranscode(t *testing.T) { diff --git a/pgtype/int8_test.go b/pgtype/int8_test.go index 9f96a1e3..657eb702 100644 --- a/pgtype/int8_test.go +++ b/pgtype/int8_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestInt8Transcode(t *testing.T) { diff --git a/pgtype/int8range_test.go b/pgtype/int8range_test.go index aefa2f53..1fab1caa 100644 --- a/pgtype/int8range_test.go +++ b/pgtype/int8range_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestInt8rangeTranscode(t *testing.T) { diff --git a/pgtype/integration_benchmark_test.go b/pgtype/integration_benchmark_test.go index 0ee87ba3..ad9c0598 100644 --- a/pgtype/integration_benchmark_test.go +++ b/pgtype/integration_benchmark_test.go @@ -6,9 +6,9 @@ import ( "context" "testing" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *testing.B) { diff --git a/pgtype/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb index f642e6ca..1b091364 100644 --- a/pgtype/integration_benchmark_test.go.erb +++ b/pgtype/integration_benchmark_test.go.erb @@ -6,8 +6,8 @@ import ( "context" "testing" - "github.com/jackc/pgx/v4/pgtype/testutil" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5" ) <% diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go index 6754a222..a8241bf6 100644 --- a/pgtype/interval_test.go +++ b/pgtype/interval_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgtype/json_test.go b/pgtype/json_test.go index a42b7ab4..cb5162d3 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestJSONTranscode(t *testing.T) { diff --git a/pgtype/jsonb_array_test.go b/pgtype/jsonb_array_test.go index 172892bc..0fc4d40e 100644 --- a/pgtype/jsonb_array_test.go +++ b/pgtype/jsonb_array_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestJSONBArrayTranscode(t *testing.T) { diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 242014f3..3a0d62c2 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestJSONBTranscode(t *testing.T) { diff --git a/pgtype/line_test.go b/pgtype/line_test.go index 6f38d85f..b171a7a5 100644 --- a/pgtype/line_test.go +++ b/pgtype/line_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestLineTranscode(t *testing.T) { diff --git a/pgtype/lseg_test.go b/pgtype/lseg_test.go index 9122f76b..ce128784 100644 --- a/pgtype/lseg_test.go +++ b/pgtype/lseg_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestLsegTranscode(t *testing.T) { diff --git a/pgtype/macaddr_array_test.go b/pgtype/macaddr_array_test.go index 4941ad80..ac76a052 100644 --- a/pgtype/macaddr_array_test.go +++ b/pgtype/macaddr_array_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestMacaddrArrayTranscode(t *testing.T) { diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go index 9a78521b..5b9d8d88 100644 --- a/pgtype/macaddr_test.go +++ b/pgtype/macaddr_test.go @@ -6,8 +6,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestMacaddrTranscode(t *testing.T) { diff --git a/pgtype/name_test.go b/pgtype/name_test.go index b71ea490..89b16579 100644 --- a/pgtype/name_test.go +++ b/pgtype/name_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestNameTranscode(t *testing.T) { diff --git a/pgtype/numeric_array_test.go b/pgtype/numeric_array_test.go index 82a4fb6c..4542ed3e 100644 --- a/pgtype/numeric_array_test.go +++ b/pgtype/numeric_array_test.go @@ -6,8 +6,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestNumericArrayTranscode(t *testing.T) { diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 22bd22ef..ff53d92b 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -9,8 +9,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/require" ) @@ -289,7 +289,7 @@ func TestNumericAssignTo(t *testing.T) { {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &_i8, expected: _int8(42)}, {src: &pgtype.Numeric{Int: big.NewInt(0)}, dst: &pi8, expected: ((*int8)(nil))}, {src: &pgtype.Numeric{Int: big.NewInt(0)}, dst: &_pi8, expected: ((*_int8)(nil))}, - {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Valid: true}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgx/v4/pgtype/issues/27 + {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Valid: true}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgx/v5/pgtype/issues/27 {src: &pgtype.Numeric{Valid: true, NaN: true}, dst: &f64, expected: math.NaN()}, {src: &pgtype.Numeric{Valid: true, NaN: true}, dst: &f32, expected: float32(math.NaN())}, {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}, dst: &f64, expected: math.Inf(1)}, diff --git a/pgtype/numrange_test.go b/pgtype/numrange_test.go index 3e89dc73..a7792faf 100644 --- a/pgtype/numrange_test.go +++ b/pgtype/numrange_test.go @@ -4,8 +4,8 @@ import ( "math/big" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestNumrangeTranscode(t *testing.T) { diff --git a/pgtype/oid_value_test.go b/pgtype/oid_value_test.go index e3d2e014..aecfc149 100644 --- a/pgtype/oid_value_test.go +++ b/pgtype/oid_value_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestOIDValueTranscode(t *testing.T) { diff --git a/pgtype/path_test.go b/pgtype/path_test.go index af410540..8a218fe1 100644 --- a/pgtype/path_test.go +++ b/pgtype/path_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestPathTranscode(t *testing.T) { diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 6540842c..17b8afe1 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -7,9 +7,9 @@ import ( "net" "testing" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgtype" - _ "github.com/jackc/pgx/v4/stdlib" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + _ "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgtype/pgxtype/pgxtype.go b/pgtype/pgxtype/pgxtype.go index a16c0389..ba14c094 100644 --- a/pgtype/pgxtype/pgxtype.go +++ b/pgtype/pgxtype/pgxtype.go @@ -4,9 +4,9 @@ import ( "context" "errors" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" ) type Querier interface { diff --git a/pgtype/point_test.go b/pgtype/point_test.go index b8681cd5..bb4c2126 100644 --- a/pgtype/point_test.go +++ b/pgtype/point_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/require" ) diff --git a/pgtype/polygon_test.go b/pgtype/polygon_test.go index 25cbd7dc..4e7f69ce 100644 --- a/pgtype/polygon_test.go +++ b/pgtype/polygon_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestPolygonTranscode(t *testing.T) { diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go index a27cb098..cb9b6786 100644 --- a/pgtype/qchar_test.go +++ b/pgtype/qchar_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestQCharTranscode(t *testing.T) { diff --git a/pgtype/record_test.go b/pgtype/record_test.go index 6e052b71..921f0975 100644 --- a/pgtype/record_test.go +++ b/pgtype/record_test.go @@ -6,9 +6,9 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) var recordTests = []struct { diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go index 6007d7a4..bfe9b01f 100644 --- a/pgtype/testutil/testutil.go +++ b/pgtype/testutil/testutil.go @@ -8,9 +8,9 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgtype" - _ "github.com/jackc/pgx/v4/stdlib" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + _ "github.com/jackc/pgx/v5/stdlib" ) func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { diff --git a/pgtype/text_array_test.go b/pgtype/text_array_test.go index ce4b0d20..4fa1c39e 100644 --- a/pgtype/text_array_test.go +++ b/pgtype/text_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgtype/text_test.go b/pgtype/text_test.go index 17201764..dca6af4a 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestTextTranscode(t *testing.T) { diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go index 133e5a35..ef24005a 100644 --- a/pgtype/tid_test.go +++ b/pgtype/tid_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestTIDTranscode(t *testing.T) { diff --git a/pgtype/time_test.go b/pgtype/time_test.go index 008b1448..f710ed03 100644 --- a/pgtype/time_test.go +++ b/pgtype/time_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestTimeTranscode(t *testing.T) { diff --git a/pgtype/timestamp_array_test.go b/pgtype/timestamp_array_test.go index c354f0cf..f78bf14e 100644 --- a/pgtype/timestamp_array_test.go +++ b/pgtype/timestamp_array_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestTimestampArrayTranscode(t *testing.T) { diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index f906462d..00e1b2c7 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/require" ) diff --git a/pgtype/timestamptz_array_test.go b/pgtype/timestamptz_array_test.go index fd40cd35..e00b7d7f 100644 --- a/pgtype/timestamptz_array_test.go +++ b/pgtype/timestamptz_array_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestTimestamptzArrayTranscode(t *testing.T) { diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index 8815d363..6cfe9a53 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/require" ) diff --git a/pgtype/tsrange_test.go b/pgtype/tsrange_test.go index f24f824b..f7c4dc84 100644 --- a/pgtype/tsrange_test.go +++ b/pgtype/tsrange_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestTsrangeTranscode(t *testing.T) { diff --git a/pgtype/tstzrange_test.go b/pgtype/tstzrange_test.go index bf604ed5..ab1ab44f 100644 --- a/pgtype/tstzrange_test.go +++ b/pgtype/tstzrange_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/require" ) diff --git a/pgtype/uuid_array_test.go b/pgtype/uuid_array_test.go index b4ec2f86..b432d0f8 100644 --- a/pgtype/uuid_array_test.go +++ b/pgtype/uuid_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestUUIDArrayTranscode(t *testing.T) { diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 036c0dd8..9701db74 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -5,8 +5,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/require" ) diff --git a/pgtype/varbit_test.go b/pgtype/varbit_test.go index 1ca5357b..031d5fa8 100644 --- a/pgtype/varbit_test.go +++ b/pgtype/varbit_test.go @@ -3,8 +3,8 @@ package pgtype_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestVarbitTranscode(t *testing.T) { diff --git a/pgtype/varchar_array_test.go b/pgtype/varchar_array_test.go index c45162a0..2d437274 100644 --- a/pgtype/varchar_array_test.go +++ b/pgtype/varchar_array_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestVarcharArrayTranscode(t *testing.T) { diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go index 5b30753a..ee11fa41 100644 --- a/pgtype/xid_test.go +++ b/pgtype/xid_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/jackc/pgx/v4/pgtype" - "github.com/jackc/pgx/v4/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestXIDTranscode(t *testing.T) { diff --git a/pgtype/zeronull/float8.go b/pgtype/zeronull/float8.go index 3d9d4d22..a4efb1ed 100644 --- a/pgtype/zeronull/float8.go +++ b/pgtype/zeronull/float8.go @@ -3,7 +3,7 @@ package zeronull import ( "database/sql/driver" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgtype" ) type Float8 float64 diff --git a/pgtype/zeronull/float8_test.go b/pgtype/zeronull/float8_test.go index cdc51245..b0331faa 100644 --- a/pgtype/zeronull/float8_test.go +++ b/pgtype/zeronull/float8_test.go @@ -3,8 +3,8 @@ package zeronull_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype/testutil" - "github.com/jackc/pgx/v4/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype/zeronull" ) func TestFloat8Transcode(t *testing.T) { diff --git a/pgtype/zeronull/int2.go b/pgtype/zeronull/int2.go index 011e96d5..81e89ab3 100644 --- a/pgtype/zeronull/int2.go +++ b/pgtype/zeronull/int2.go @@ -3,7 +3,7 @@ package zeronull import ( "database/sql/driver" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgtype" ) type Int2 int16 diff --git a/pgtype/zeronull/int2_test.go b/pgtype/zeronull/int2_test.go index 9cbd75db..ff78a6e6 100644 --- a/pgtype/zeronull/int2_test.go +++ b/pgtype/zeronull/int2_test.go @@ -3,8 +3,8 @@ package zeronull_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype/testutil" - "github.com/jackc/pgx/v4/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype/zeronull" ) func TestInt2Transcode(t *testing.T) { diff --git a/pgtype/zeronull/int4.go b/pgtype/zeronull/int4.go index 9d34c163..4e06435a 100644 --- a/pgtype/zeronull/int4.go +++ b/pgtype/zeronull/int4.go @@ -3,7 +3,7 @@ package zeronull import ( "database/sql/driver" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgtype" ) type Int4 int32 diff --git a/pgtype/zeronull/int4_test.go b/pgtype/zeronull/int4_test.go index 456f15d2..3510aa9d 100644 --- a/pgtype/zeronull/int4_test.go +++ b/pgtype/zeronull/int4_test.go @@ -3,8 +3,8 @@ package zeronull_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype/testutil" - "github.com/jackc/pgx/v4/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype/zeronull" ) func TestInt4Transcode(t *testing.T) { diff --git a/pgtype/zeronull/int8.go b/pgtype/zeronull/int8.go index 185fdb8f..3c89a1ec 100644 --- a/pgtype/zeronull/int8.go +++ b/pgtype/zeronull/int8.go @@ -3,7 +3,7 @@ package zeronull import ( "database/sql/driver" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgtype" ) type Int8 int64 diff --git a/pgtype/zeronull/int8_test.go b/pgtype/zeronull/int8_test.go index ca261d36..97fe9cd0 100644 --- a/pgtype/zeronull/int8_test.go +++ b/pgtype/zeronull/int8_test.go @@ -3,8 +3,8 @@ package zeronull_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype/testutil" - "github.com/jackc/pgx/v4/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype/zeronull" ) func TestInt8Transcode(t *testing.T) { diff --git a/pgtype/zeronull/text.go b/pgtype/zeronull/text.go index 5fc9d94b..33ce367f 100644 --- a/pgtype/zeronull/text.go +++ b/pgtype/zeronull/text.go @@ -3,7 +3,7 @@ package zeronull import ( "database/sql/driver" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgtype" ) type Text string diff --git a/pgtype/zeronull/text_test.go b/pgtype/zeronull/text_test.go index 8595253c..e4293024 100644 --- a/pgtype/zeronull/text_test.go +++ b/pgtype/zeronull/text_test.go @@ -3,8 +3,8 @@ package zeronull_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype/testutil" - "github.com/jackc/pgx/v4/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype/zeronull" ) func TestTextTranscode(t *testing.T) { diff --git a/pgtype/zeronull/timestamp.go b/pgtype/zeronull/timestamp.go index 193bc959..d96dbf08 100644 --- a/pgtype/zeronull/timestamp.go +++ b/pgtype/zeronull/timestamp.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "time" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgtype" ) type Timestamp time.Time diff --git a/pgtype/zeronull/timestamp_test.go b/pgtype/zeronull/timestamp_test.go index 787c6de9..2eb072c6 100644 --- a/pgtype/zeronull/timestamp_test.go +++ b/pgtype/zeronull/timestamp_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype/testutil" - "github.com/jackc/pgx/v4/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype/zeronull" ) func TestTimestampTranscode(t *testing.T) { diff --git a/pgtype/zeronull/timestamptz.go b/pgtype/zeronull/timestamptz.go index 5ecefe64..46448607 100644 --- a/pgtype/zeronull/timestamptz.go +++ b/pgtype/zeronull/timestamptz.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "time" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgtype" ) type Timestamptz time.Time diff --git a/pgtype/zeronull/timestamptz_test.go b/pgtype/zeronull/timestamptz_test.go index dcbd0d58..e288b9e8 100644 --- a/pgtype/zeronull/timestamptz_test.go +++ b/pgtype/zeronull/timestamptz_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgtype/testutil" - "github.com/jackc/pgx/v4/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype/zeronull" ) func TestTimestamptzTranscode(t *testing.T) { diff --git a/pgtype/zeronull/uuid.go b/pgtype/zeronull/uuid.go index 2e54a933..8c0978b3 100644 --- a/pgtype/zeronull/uuid.go +++ b/pgtype/zeronull/uuid.go @@ -3,7 +3,7 @@ package zeronull import ( "database/sql/driver" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgtype" ) type UUID [16]byte diff --git a/pgtype/zeronull/uuid_test.go b/pgtype/zeronull/uuid_test.go index e79503c6..913698d9 100644 --- a/pgtype/zeronull/uuid_test.go +++ b/pgtype/zeronull/uuid_test.go @@ -3,8 +3,8 @@ package zeronull_test import ( "testing" - "github.com/jackc/pgx/v4/pgtype/testutil" - "github.com/jackc/pgx/v4/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype/zeronull" ) func TestUUIDTranscode(t *testing.T) { diff --git a/pgxpool/batch_results.go b/pgxpool/batch_results.go index 42f597a5..8bec35cb 100644 --- a/pgxpool/batch_results.go +++ b/pgxpool/batch_results.go @@ -1,8 +1,8 @@ package pgxpool import ( - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) type errBatchResults struct { diff --git a/pgxpool/bench_test.go b/pgxpool/bench_test.go index 9ec63ca3..704371db 100644 --- a/pgxpool/bench_test.go +++ b/pgxpool/bench_test.go @@ -5,8 +5,8 @@ import ( "os" "testing" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" ) diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index 8ef85052..c6f3b77b 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -5,10 +5,10 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5/pgxpool" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgxpool/conn.go b/pgxpool/conn.go index fee1ff2b..8798db4b 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -4,8 +4,8 @@ import ( "context" "time" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/puddle" ) diff --git a/pgxpool/conn_test.go b/pgxpool/conn_test.go index c03ae13e..ff34f969 100644 --- a/pgxpool/conn_test.go +++ b/pgxpool/conn_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" ) diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 84d05ad9..41fb4d5b 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -8,8 +8,8 @@ import ( "sync" "time" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/puddle" ) diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 54b688a1..20586b81 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgxpool/rows.go b/pgxpool/rows.go index cc45c76f..0c97dc91 100644 --- a/pgxpool/rows.go +++ b/pgxpool/rows.go @@ -2,8 +2,8 @@ package pgxpool import ( "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) type errRows struct { diff --git a/pgxpool/tx.go b/pgxpool/tx.go index 962b8725..a82b2176 100644 --- a/pgxpool/tx.go +++ b/pgxpool/tx.go @@ -3,8 +3,8 @@ package pgxpool import ( "context" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) // Tx represents a database transaction acquired from a Pool. diff --git a/pgxpool/tx_test.go b/pgxpool/tx_test.go index d66ad338..e32d3efe 100644 --- a/pgxpool/tx_test.go +++ b/pgxpool/tx_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" ) diff --git a/query_test.go b/query_test.go index 393235b7..7157e5dd 100644 --- a/query_test.go +++ b/query_test.go @@ -15,10 +15,10 @@ import ( "github.com/cockroachdb/apd" "github.com/gofrs/uuid" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" - "github.com/jackc/pgx/v4/pgconn/stmtcache" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgconn/stmtcache" + "github.com/jackc/pgx/v5/pgtype" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/rows.go b/rows.go index 14bc50ba..cc6e26d5 100644 --- a/rows.go +++ b/rows.go @@ -7,8 +7,8 @@ import ( "time" "github.com/jackc/pgproto3/v2" - "github.com/jackc/pgx/v4/pgconn" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" ) // Rows is the result set returned from *Conn.Query. Rows must be closed before diff --git a/stdlib/sql.go b/stdlib/sql.go index c9e36eef..39b7524a 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -63,9 +63,9 @@ import ( "sync" "time" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" ) // Only intrinsic types should be binary format with database/sql. diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 6b6440f7..e5eb47bf 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -13,9 +13,9 @@ import ( "time" "github.com/Masterminds/semver/v3" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" - "github.com/jackc/pgx/v4/stdlib" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/tx.go b/tx.go index 44311f1f..3ed0ca67 100644 --- a/tx.go +++ b/tx.go @@ -7,7 +7,7 @@ import ( "fmt" "strconv" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5/pgconn" ) // TxIsoLevel is the transaction isolation level (serializable, repeatable read, read committed or read uncommitted) diff --git a/tx_test.go b/tx_test.go index 85083830..23d76663 100644 --- a/tx_test.go +++ b/tx_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgconn" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/require" ) diff --git a/values.go b/values.go index 2978e5a3..2f328b82 100644 --- a/values.go +++ b/values.go @@ -8,7 +8,7 @@ import ( "time" "github.com/jackc/pgio" - "github.com/jackc/pgx/v4/pgtype" + "github.com/jackc/pgx/v5/pgtype" ) // PostgreSQL format codes diff --git a/values_test.go b/values_test.go index 47aacf89..27fbe977 100644 --- a/values_test.go +++ b/values_test.go @@ -10,7 +10,7 @@ import ( "testing" "time" - "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) From 85b08ac6633ba72f177407217bb5ccf54b58df04 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 13:30:36 -0600 Subject: [PATCH 0760/1158] Fix some previously broken comment links --- pgtype/text_array_test.go | 2 +- pgtype/timestamp_test.go | 2 +- pgtype/timestamptz_test.go | 2 +- pgtype/tstzrange_test.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pgtype/text_array_test.go b/pgtype/text_array_test.go index 4fa1c39e..22e2ca27 100644 --- a/pgtype/text_array_test.go +++ b/pgtype/text_array_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -// https://github.com/jackc/pgx/v4/pgtype/issues/78 +// https://github.com/jackc/pgtype/issues/78 func TestTextArrayDecodeTextNull(t *testing.T) { textArray := &pgtype.TextArray{} err := textArray.DecodeText(nil, []byte(`{abc,"NULL",NULL,def}`)) diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 00e1b2c7..85de5c94 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -99,7 +99,7 @@ func TestTimestampNanosecondsTruncated(t *testing.T) { } } -// https://github.com/jackc/pgx/v4/pgtype/issues/74 +// https://github.com/jackc/pgtype/issues/74 func TestTimestampDecodeTextInvalid(t *testing.T) { tstz := &pgtype.Timestamp{} err := tstz.DecodeText(nil, []byte(`eeeee`)) diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index 6cfe9a53..332fc8a7 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -99,7 +99,7 @@ func TestTimestamptzNanosecondsTruncated(t *testing.T) { } } -// https://github.com/jackc/pgx/v4/pgtype/issues/74 +// https://github.com/jackc/pgtype/issues/74 func TestTimestamptzDecodeTextInvalid(t *testing.T) { tstz := &pgtype.Timestamptz{} err := tstz.DecodeText(nil, []byte(`eeeee`)) diff --git a/pgtype/tstzrange_test.go b/pgtype/tstzrange_test.go index ab1ab44f..5d0b750f 100644 --- a/pgtype/tstzrange_test.go +++ b/pgtype/tstzrange_test.go @@ -41,7 +41,7 @@ func TestTstzrangeTranscode(t *testing.T) { }) } -// https://github.com/jackc/pgx/v4/pgtype/issues/74 +// https://github.com/jackc/pgtype/issues/74 func TestTstzRangeDecodeTextInvalid(t *testing.T) { tstzrange := &pgtype.Tstzrange{} err := tstzrange.DecodeText(nil, []byte(`[eeee,)`)) From 81168a61d1572e3ebc8e61cf5e505f25f540fde8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 13:32:50 -0600 Subject: [PATCH 0761/1158] Update go.mod go version to 1.17 --- go.mod | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 737db8cc..a61252d5 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/jackc/pgx/v5 -go 1.13 +go 1.17 require ( github.com/Masterminds/semver/v3 v3.1.1 @@ -14,8 +14,6 @@ require ( github.com/jackc/pgproto3/v2 v2.2.0 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/jackc/puddle v1.2.0 - github.com/lib/pq v1.10.4 // indirect - github.com/mattn/go-colorable v0.1.12 // indirect github.com/rs/zerolog v1.15.0 github.com/shopspring/decimal v1.2.0 github.com/sirupsen/logrus v1.4.2 @@ -25,3 +23,24 @@ require ( golang.org/x/text v0.3.6 gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec ) + +require ( + github.com/BurntSushi/toml v0.3.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-logfmt/logfmt v0.5.0 // indirect + github.com/go-stack/stack v1.8.0 // indirect + github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect + github.com/lib/pq v1.10.4 // indirect + github.com/mattn/go-colorable v0.1.12 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect + github.com/pkg/errors v0.8.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + go.uber.org/atomic v1.5.0 // indirect + go.uber.org/multierr v1.3.0 // indirect + go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee // indirect + golang.org/x/lint v0.0.0-20190930215403-16217165b5de // indirect + golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 // indirect + golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect + honnef.co/go/tools v0.0.1-2019.2.3 // indirect +) From 6b2a0d99a2c0582005172d6c1c38fed93efbb77c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 13:37:13 -0600 Subject: [PATCH 0762/1158] Run CI on v5-dev branch --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 36a9ec4e..b13e8c2f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,7 @@ name: CI on: push: - branches: [ master ] + branches: [ master, v5-dev ] pull_request: branches: [ master ] From 8c9646dbfe00af419799d787e1db0c09804a1c60 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 13:45:37 -0600 Subject: [PATCH 0763/1158] Remove github.com/cockroachdb/apd test dependency --- go.mod | 3 --- go.sum | 3 --- query_test.go | 3 +-- 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/go.mod b/go.mod index a61252d5..9b539ce1 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.17 require ( github.com/Masterminds/semver/v3 v3.1.1 - github.com/cockroachdb/apd v1.1.0 github.com/go-kit/log v0.1.0 github.com/gofrs/uuid v4.0.0+incompatible github.com/jackc/chunkreader/v2 v2.0.1 @@ -30,10 +29,8 @@ require ( github.com/go-logfmt/logfmt v0.5.0 // indirect github.com/go-stack/stack v1.8.0 // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect - github.com/lib/pq v1.10.4 // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect - github.com/pkg/errors v0.8.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect go.uber.org/atomic v1.5.0 // indirect go.uber.org/multierr v1.3.0 // indirect diff --git a/go.sum b/go.sum index 97b704ee..1f7e7b63 100644 --- a/go.sum +++ b/go.sum @@ -2,7 +2,6 @@ github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= -github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -73,8 +72,6 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.10.4 h1:SO9z7FRPzA03QhHKJrH5BXA6HU1rS4V2nIVrrNC1iYk= -github.com/lib/pq v1.10.4/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= diff --git a/query_test.go b/query_test.go index 7157e5dd..71e05197 100644 --- a/query_test.go +++ b/query_test.go @@ -13,7 +13,6 @@ import ( "testing" "time" - "github.com/cockroachdb/apd" "github.com/gofrs/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" @@ -1217,7 +1216,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes mustExec(t, conn, "create temporary table t(n numeric)") - var d *apd.Decimal + var d *sql.NullInt64 commandTag, err := conn.Exec(context.Background(), `insert into t(n) values($1)`, d) if err != nil { t.Fatal(err) From 9ae745219633043e2e714fa2c96687889f425668 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 14:07:52 -0600 Subject: [PATCH 0764/1158] Remove Go 1.16 from CI By the time v5 is released 1.17 will be the minimum supported version. May as well save some CI time in the mean while. --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b13e8c2f..af164815 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - go-version: [1.16, 1.17] + go-version: [1.17] pg-version: [10, 11, 12, 13, 14, cockroachdb] include: - pg-version: 10 From 5fbf907471b11b970cf0b7799410007ea83d3df5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 14:09:37 -0600 Subject: [PATCH 0765/1158] Temporarily remove cockroachdb from CI pgtype has a ton of tests that don't work on CockroachDB. And because of how the tests are structured it is difficult to skip just those tests. pgtype may have significant changes before v5 is released so delay updating these tests. --- .github/workflows/ci.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index af164815..bda3789b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,8 @@ jobs: strategy: matrix: go-version: [1.17] - pg-version: [10, 11, 12, 13, 14, cockroachdb] + pg-version: [10, 11, 12, 13, 14] + # pg-version: [10, 11, 12, 13, 14, cockroachdb] include: - pg-version: 10 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test @@ -57,9 +58,9 @@ jobs: pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - - pg-version: cockroachdb - pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" - pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" + # - pg-version: cockroachdb + # pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" + # pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" steps: From 066908d4f823e0e01fe688574d5a146241b8a8d5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 14:15:22 -0600 Subject: [PATCH 0766/1158] Temporarily remove all PG versions but 14 from CI Same issue as previous commit removing CockroachDB. numeric type only supports infinity on PG 14 and there is no easy way in the current test structure to skip tests based on server version. --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bda3789b..c2a67cf6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ jobs: strategy: matrix: go-version: [1.17] - pg-version: [10, 11, 12, 13, 14] + pg-version: [14] # pg-version: [10, 11, 12, 13, 14, cockroachdb] include: - pg-version: 10 From 1b416b36dc4a5b95b5ce4c241ca313dea05e3141 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 14:26:04 -0600 Subject: [PATCH 0767/1158] Finish temp removal of PG < 14 from CI --- .github/workflows/ci.yml | 64 ++++++++++++++++++++-------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c2a67cf6..a62d38e0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,38 +18,38 @@ jobs: pg-version: [14] # pg-version: [10, 11, 12, 13, 14, cockroachdb] include: - - pg-version: 10 - pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" - pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - - pg-version: 11 - pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" - pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - - pg-version: 12 - pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" - pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - - pg-version: 13 - pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" - pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + # - pg-version: 10 + # pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + # pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + # pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + # - pg-version: 11 + # pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + # pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + # pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + # - pg-version: 12 + # pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + # pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + # pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + # - pg-version: 13 + # pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + # pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + # pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + # pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 14 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test From 9ab821620f16afd4c8fa8ff259102e02e31a366e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 14:27:00 -0600 Subject: [PATCH 0768/1158] Remove github.com/Masterminds/semver/v3 test dependency --- go.mod | 1 - go.sum | 2 -- stdlib/sql_test.go | 18 +++++++----------- 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/go.mod b/go.mod index 9b539ce1..ea8840d7 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/jackc/pgx/v5 go 1.17 require ( - github.com/Masterminds/semver/v3 v3.1.1 github.com/go-kit/log v0.1.0 github.com/gofrs/uuid v4.0.0+incompatible github.com/jackc/chunkreader/v2 v2.0.1 diff --git a/go.sum b/go.sum index 1f7e7b63..481a6947 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= -github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index e5eb47bf..07498843 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -9,10 +9,10 @@ import ( "os" "reflect" "regexp" + "strconv" "testing" "time" - "github.com/Masterminds/semver/v3" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/stdlib" @@ -46,7 +46,7 @@ func skipCockroachDB(t testing.TB, db *sql.DB, msg string) { require.NoError(t, err) } -func skipPostgreSQLVersion(t testing.TB, db *sql.DB, constraintStr, msg string) { +func skipPostgreSQLVersionLessThan(t testing.TB, db *sql.DB, minVersion int64) { conn, err := db.Conn(context.Background()) require.NoError(t, err) defer conn.Close() @@ -54,25 +54,21 @@ func skipPostgreSQLVersion(t testing.TB, db *sql.DB, constraintStr, msg string) err = conn.Raw(func(driverConn interface{}) error { conn := driverConn.(*stdlib.Conn).Conn() serverVersionStr := conn.PgConn().ParameterStatus("server_version") - serverVersionStr = regexp.MustCompile(`^[0-9.]+`).FindString(serverVersionStr) + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) // if not PostgreSQL do nothing if serverVersionStr == "" { return nil } - serverVersion, err := semver.NewVersion(serverVersionStr) + serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64) if err != nil { return err } - c, err := semver.NewConstraint(constraintStr) - if err != nil { - return err + if serverVersion < minVersion { + t.Skipf("Test requires PostgreSQL v%d+", minVersion) } - if c.Check(serverVersion) { - t.Skip(msg) - } return nil }) require.NoError(t, err) @@ -1093,7 +1089,7 @@ func TestRegisterConnConfig(t *testing.T) { // https://github.com/jackc/pgx/issues/958 func TestConnQueryRowConstraintErrors(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { - skipPostgreSQLVersion(t, db, "< 11", "Test requires PG 11+") + skipPostgreSQLVersionLessThan(t, db, 11) skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") _, err := db.Exec(`create temporary table defer_test ( From 731312fea86d60ed23aced74335a8cd6bf8b4dbf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 14:32:32 -0600 Subject: [PATCH 0769/1158] Remove github.com/shopspring/decimal test dependency --- go.mod | 1 - go.sum | 2 -- query_test.go | 37 +++++++++---------------------------- 3 files changed, 9 insertions(+), 31 deletions(-) diff --git a/go.mod b/go.mod index ea8840d7..a540c5f5 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,6 @@ require ( github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/jackc/puddle v1.2.0 github.com/rs/zerolog v1.15.0 - github.com/shopspring/decimal v1.2.0 github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.7.0 go.uber.org/zap v1.13.0 diff --git a/go.sum b/go.sum index 481a6947..0ac9ca50 100644 --- a/go.sum +++ b/go.sum @@ -88,8 +88,6 @@ github.com/rs/zerolog v1.15.0 h1:uPRuwkWF4J6fGsJ2R0Gn2jB1EQiav9k3S6CSdygQJXY= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= -github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= -github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= diff --git a/query_test.go b/query_test.go index 71e05197..63040894 100644 --- a/query_test.go +++ b/query_test.go @@ -18,7 +18,6 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn/stmtcache" "github.com/jackc/pgx/v5/pgtype" - "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1154,55 +1153,37 @@ func TestReadingNullByteArrays(t *testing.T) { } } -// Use github.com/shopspring/decimal as real-world database/sql custom type -// to test against. func TestConnQueryDatabaseSQLScanner(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - var num decimal.Decimal + var num sql.NullFloat64 - err := conn.QueryRow(context.Background(), "select '1234.567'::decimal").Scan(&num) + err := conn.QueryRow(context.Background(), "select '1234.567'::float8").Scan(&num) if err != nil { t.Fatalf("Scan failed: %v", err) } - expected, err := decimal.NewFromString("1234.567") - if err != nil { - t.Fatal(err) - } - - if !num.Equals(expected) { - t.Errorf("Expected num to be %v, but it was %v", expected, num) - } + require.True(t, num.Valid) + require.Equal(t, 1234.567, num.Float64) ensureConnValid(t, conn) } -// Use github.com/shopspring/decimal as real-world database/sql custom type -// to test against. func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - expected, err := decimal.NewFromString("1234.567") - if err != nil { - t.Fatal(err) - } - var num decimal.Decimal + expected := sql.NullFloat64{Float64: 1234.567, Valid: true} + var actual sql.NullFloat64 - err = conn.QueryRow(context.Background(), "select $1::decimal", &expected).Scan(&num) - if err != nil { - t.Fatalf("Scan failed: %v", err) - } - - if !num.Equals(expected) { - t.Errorf("Expected num to be %v, but it was %v", expected, num) - } + err := conn.QueryRow(context.Background(), "select $1::float8", &expected).Scan(&actual) + require.NoError(t, err) + require.Equal(t, expected, actual) ensureConnValid(t, conn) } From ef2b70edadd1ad7366e12c10e55da40f6bb40286 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 14:37:02 -0600 Subject: [PATCH 0770/1158] Remove github.com/gofrs/uuid test dependency --- go.mod | 1 - go.sum | 2 -- query_test.go | 19 +++++-------------- 3 files changed, 5 insertions(+), 17 deletions(-) diff --git a/go.mod b/go.mod index a540c5f5..b9ae1f8e 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.17 require ( github.com/go-kit/log v0.1.0 - github.com/gofrs/uuid v4.0.0+incompatible github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 diff --git a/go.sum b/go.sum index 0ac9ca50..9fe60dde 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,6 @@ github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= -github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= diff --git a/query_test.go b/query_test.go index 63040894..e725bd40 100644 --- a/query_test.go +++ b/query_test.go @@ -13,7 +13,6 @@ import ( "testing" "time" - "github.com/gofrs/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn/stmtcache" @@ -1215,20 +1214,12 @@ func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t * conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - expected, err := uuid.FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8") - if err != nil { - t.Fatal(err) - } + var actual sql.NullString + err := conn.QueryRow(context.Background(), "select '6ba7b810-9dad-11d1-80b4-00c04fd430c8'::uuid").Scan(&actual) + require.NoError(t, err) - var u2 uuid.UUID - err = conn.QueryRow(context.Background(), "select $1::uuid", expected).Scan(&u2) - if err != nil { - t.Fatalf("Scan failed: %v", err) - } - - if expected != u2 { - t.Errorf("Expected u2 to be %v, but it was %v", expected, u2) - } + require.True(t, actual.Valid) + require.Equal(t, "6ba7b810-9dad-11d1-80b4-00c04fd430c8", actual.String) ensureConnValid(t, conn) } From 8e2e8a700950f7cd625e2870ce566aab97e7bcdf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 14:52:31 -0600 Subject: [PATCH 0771/1158] Remove external log adapters --- examples/url_shortener/main.go | 18 ++--- go.mod | 18 ----- go.sum | 42 ------------ log/kitlogadapter/adapter.go | 39 ----------- log/log15adapter/adapter.go | 49 -------------- log/logrusadapter/adapter.go | 42 ------------ log/zapadapter/adapter.go | 42 ------------ log/zerologadapter/adapter.go | 102 ----------------------------- log/zerologadapter/adapter_test.go | 98 --------------------------- 9 files changed, 5 insertions(+), 445 deletions(-) delete mode 100644 log/kitlogadapter/adapter.go delete mode 100644 log/log15adapter/adapter.go delete mode 100644 log/logrusadapter/adapter.go delete mode 100644 log/zapadapter/adapter.go delete mode 100644 log/zerologadapter/adapter.go delete mode 100644 log/zerologadapter/adapter_test.go diff --git a/examples/url_shortener/main.go b/examples/url_shortener/main.go index cb474f32..fb267b74 100644 --- a/examples/url_shortener/main.go +++ b/examples/url_shortener/main.go @@ -3,13 +3,12 @@ package main import ( "context" "io/ioutil" + "log" "net/http" "os" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/log/log15adapter" "github.com/jackc/pgx/v5/pgxpool" - log "gopkg.in/inconshreveable/log15.v2" ) var db *pgxpool.Pool @@ -71,28 +70,21 @@ func urlHandler(w http.ResponseWriter, req *http.Request) { } func main() { - logger := log15adapter.NewLogger(log.New("module", "pgx")) - poolConfig, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL")) if err != nil { - log.Crit("Unable to parse DATABASE_URL", "error", err) - os.Exit(1) + log.Fatalln("Unable to parse DATABASE_URL:", err) } - poolConfig.ConnConfig.Logger = logger - db, err = pgxpool.ConnectConfig(context.Background(), poolConfig) if err != nil { - log.Crit("Unable to create connection pool", "error", err) - os.Exit(1) + log.Fatalln("Unable to create connection pool:", err) } http.HandleFunc("/", urlHandler) - log.Info("Starting URL shortener on localhost:8080") + log.Println("Starting URL shortener on localhost:8080") err = http.ListenAndServe("localhost:8080", nil) if err != nil { - log.Crit("Unable to start web server", "error", err) - os.Exit(1) + log.Fatalln("Unable to start web server:" err) } } diff --git a/go.mod b/go.mod index b9ae1f8e..8cbffeb8 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/jackc/pgx/v5 go 1.17 require ( - github.com/go-kit/log v0.1.0 github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 @@ -11,30 +10,13 @@ require ( github.com/jackc/pgproto3/v2 v2.2.0 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/jackc/puddle v1.2.0 - github.com/rs/zerolog v1.15.0 - github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.7.0 - go.uber.org/zap v1.13.0 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 golang.org/x/text v0.3.6 - gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec ) require ( - github.com/BurntSushi/toml v0.3.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-logfmt/logfmt v0.5.0 // indirect - github.com/go-stack/stack v1.8.0 // indirect - github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect - github.com/mattn/go-colorable v0.1.12 // indirect - github.com/mattn/go-isatty v0.0.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - go.uber.org/atomic v1.5.0 // indirect - go.uber.org/multierr v1.3.0 // indirect - go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee // indirect - golang.org/x/lint v0.0.0-20190930215403-16217165b5de // indirect - golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 // indirect - golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5 // indirect gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect - honnef.co/go/tools v0.0.1-2019.2.3 // indirect ) diff --git a/go.sum b/go.sum index 9fe60dde..cca0fd29 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -7,13 +5,7 @@ github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7Do github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-kit/log v0.1.0 h1:DGJh0Sm43HbOeYDNnVZFl8BvcYVvjD5bqYJvp0REbwQ= -github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= -github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih4= -github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= -github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= @@ -55,9 +47,7 @@ github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0f github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.0 h1:DNDKdn/pDrWvDWyT2FYvpZVE81OAhWrjCv19I9n108Q= github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -69,25 +59,17 @@ github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= -github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= -github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= -github.com/rs/zerolog v1.15.0 h1:uPRuwkWF4J6fGsJ2R0Gn2jB1EQiav9k3S6CSdygQJXY= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= -github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -101,29 +83,17 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.5.0 h1:OI5t8sDa1Or+q8AeE+yKeB/SDYioSHAgcVljj9JIETY= -go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/multierr v1.3.0 h1:sFPn2GLc3poCkfrpIXGhBD2X0CMIo4Q/zSULXrj/+uc= -go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= -go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= -go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.13.0 h1:nR6NoDBgAf67s68NhaXbsojM+2gxp3S1hWkHDl27pVU= -go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -140,9 +110,6 @@ golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 h1:foEbQz/B0Oz6YIqu/69kfXPYeFQAuuMYFkjaqXzl5Wo= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -152,13 +119,8 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5 h1:hKsoRgsbwY1NafxrwTs+k64bikrLBkAgPir1TNCj3Zs= -golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -166,11 +128,7 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec h1:RlWgLqCMMIYYEVcAR5MDsuHlVkaIPDAF+5Dehzg8L5A= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= -honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/log/kitlogadapter/adapter.go b/log/kitlogadapter/adapter.go deleted file mode 100644 index 95fc6c49..00000000 --- a/log/kitlogadapter/adapter.go +++ /dev/null @@ -1,39 +0,0 @@ -package kitlogadapter - -import ( - "context" - - "github.com/go-kit/log" - kitlevel "github.com/go-kit/log/level" - "github.com/jackc/pgx/v5" -) - -type Logger struct { - l log.Logger -} - -func NewLogger(l log.Logger) *Logger { - return &Logger{l: l} -} - -func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - logger := l.l - for k, v := range data { - logger = log.With(logger, k, v) - } - - switch level { - case pgx.LogLevelTrace: - logger.Log("PGX_LOG_LEVEL", level, "msg", msg) - case pgx.LogLevelDebug: - kitlevel.Debug(logger).Log("msg", msg) - case pgx.LogLevelInfo: - kitlevel.Info(logger).Log("msg", msg) - case pgx.LogLevelWarn: - kitlevel.Warn(logger).Log("msg", msg) - case pgx.LogLevelError: - kitlevel.Error(logger).Log("msg", msg) - default: - logger.Log("INVALID_PGX_LOG_LEVEL", level, "error", msg) - } -} diff --git a/log/log15adapter/adapter.go b/log/log15adapter/adapter.go deleted file mode 100644 index dece65b4..00000000 --- a/log/log15adapter/adapter.go +++ /dev/null @@ -1,49 +0,0 @@ -// Package log15adapter provides a logger that writes to a github.com/inconshreveable/log15.Logger -// log. -package log15adapter - -import ( - "context" - - "github.com/jackc/pgx/v5" -) - -// Log15Logger interface defines the subset of -// github.com/inconshreveable/log15.Logger that this adapter uses. -type Log15Logger interface { - Debug(msg string, ctx ...interface{}) - Info(msg string, ctx ...interface{}) - Warn(msg string, ctx ...interface{}) - Error(msg string, ctx ...interface{}) - Crit(msg string, ctx ...interface{}) -} - -type Logger struct { - l Log15Logger -} - -func NewLogger(l Log15Logger) *Logger { - return &Logger{l: l} -} - -func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - logArgs := make([]interface{}, 0, len(data)) - for k, v := range data { - logArgs = append(logArgs, k, v) - } - - switch level { - case pgx.LogLevelTrace: - l.l.Debug(msg, append(logArgs, "PGX_LOG_LEVEL", level)...) - case pgx.LogLevelDebug: - l.l.Debug(msg, logArgs...) - case pgx.LogLevelInfo: - l.l.Info(msg, logArgs...) - case pgx.LogLevelWarn: - l.l.Warn(msg, logArgs...) - case pgx.LogLevelError: - l.l.Error(msg, logArgs...) - default: - l.l.Error(msg, append(logArgs, "INVALID_PGX_LOG_LEVEL", level)...) - } -} diff --git a/log/logrusadapter/adapter.go b/log/logrusadapter/adapter.go deleted file mode 100644 index 65a64230..00000000 --- a/log/logrusadapter/adapter.go +++ /dev/null @@ -1,42 +0,0 @@ -// Package logrusadapter provides a logger that writes to a github.com/sirupsen/logrus.Logger -// log. -package logrusadapter - -import ( - "context" - - "github.com/jackc/pgx/v5" - "github.com/sirupsen/logrus" -) - -type Logger struct { - l logrus.FieldLogger -} - -func NewLogger(l logrus.FieldLogger) *Logger { - return &Logger{l: l} -} - -func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - var logger logrus.FieldLogger - if data != nil { - logger = l.l.WithFields(data) - } else { - logger = l.l - } - - switch level { - case pgx.LogLevelTrace: - logger.WithField("PGX_LOG_LEVEL", level).Debug(msg) - case pgx.LogLevelDebug: - logger.Debug(msg) - case pgx.LogLevelInfo: - logger.Info(msg) - case pgx.LogLevelWarn: - logger.Warn(msg) - case pgx.LogLevelError: - logger.Error(msg) - default: - logger.WithField("INVALID_PGX_LOG_LEVEL", level).Error(msg) - } -} diff --git a/log/zapadapter/adapter.go b/log/zapadapter/adapter.go deleted file mode 100644 index 4dc47cd0..00000000 --- a/log/zapadapter/adapter.go +++ /dev/null @@ -1,42 +0,0 @@ -// Package zapadapter provides a logger that writes to a go.uber.org/zap.Logger. -package zapadapter - -import ( - "context" - - "github.com/jackc/pgx/v5" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -type Logger struct { - logger *zap.Logger -} - -func NewLogger(logger *zap.Logger) *Logger { - return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))} -} - -func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - fields := make([]zapcore.Field, len(data)) - i := 0 - for k, v := range data { - fields[i] = zap.Any(k, v) - i++ - } - - switch level { - case pgx.LogLevelTrace: - pl.logger.Debug(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...) - case pgx.LogLevelDebug: - pl.logger.Debug(msg, fields...) - case pgx.LogLevelInfo: - pl.logger.Info(msg, fields...) - case pgx.LogLevelWarn: - pl.logger.Warn(msg, fields...) - case pgx.LogLevelError: - pl.logger.Error(msg, fields...) - default: - pl.logger.Error(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...) - } -} diff --git a/log/zerologadapter/adapter.go b/log/zerologadapter/adapter.go deleted file mode 100644 index b93036fe..00000000 --- a/log/zerologadapter/adapter.go +++ /dev/null @@ -1,102 +0,0 @@ -// Package zerologadapter provides a logger that writes to a github.com/rs/zerolog. -package zerologadapter - -import ( - "context" - - "github.com/jackc/pgx/v5" - "github.com/rs/zerolog" -) - -type Logger struct { - logger zerolog.Logger - withFunc func(context.Context, zerolog.Context) zerolog.Context - fromContext bool - skipModule bool -} - -// option options for configuring the logger when creating a new logger. -type option func(logger *Logger) - -// WithContextFunc adds possibility to get request scoped values from the -// ctx.Context before logging lines. -func WithContextFunc(withFunc func(context.Context, zerolog.Context) zerolog.Context) option { - return func(logger *Logger) { - logger.withFunc = withFunc - } -} - -// WithoutPGXModule disables adding module:pgx to the default logger context. -func WithoutPGXModule() option { - return func(logger *Logger) { - logger.skipModule = true - } -} - -// NewLogger accepts a zerolog.Logger as input and returns a new custom pgx -// logging facade as output. -func NewLogger(logger zerolog.Logger, options ...option) *Logger { - l := Logger{ - logger: logger, - } - l.init(options) - return &l -} - -// NewContextLogger creates logger that extracts the zerolog.Logger from the -// context.Context by using `zerolog.Ctx`. The zerolog.DefaultContextLogger will -// be used if no logger is associated with the context. -func NewContextLogger(options ...option) *Logger { - l := Logger{ - fromContext: true, - } - l.init(options) - return &l -} - -func (pl *Logger) init(options []option) { - for _, opt := range options { - opt(pl) - } - if !pl.skipModule { - pl.logger = pl.logger.With().Str("module", "pgx").Logger() - } -} - -func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - var zlevel zerolog.Level - switch level { - case pgx.LogLevelNone: - zlevel = zerolog.NoLevel - case pgx.LogLevelError: - zlevel = zerolog.ErrorLevel - case pgx.LogLevelWarn: - zlevel = zerolog.WarnLevel - case pgx.LogLevelInfo: - zlevel = zerolog.InfoLevel - case pgx.LogLevelDebug: - zlevel = zerolog.DebugLevel - default: - zlevel = zerolog.DebugLevel - } - - var zctx zerolog.Context - if pl.fromContext { - logger := zerolog.Ctx(ctx) - zctx = logger.With() - } else { - zctx = pl.logger.With() - } - if pl.withFunc != nil { - zctx = pl.withFunc(ctx, zctx) - } - - pgxlog := zctx.Logger() - event := pgxlog.WithLevel(zlevel) - if event.Enabled() { - if pl.fromContext && !pl.skipModule { - event.Str("module", "pgx") - } - event.Fields(data).Msg(msg) - } -} diff --git a/log/zerologadapter/adapter_test.go b/log/zerologadapter/adapter_test.go deleted file mode 100644 index 152b2129..00000000 --- a/log/zerologadapter/adapter_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package zerologadapter_test - -import ( - "bytes" - "context" - "testing" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/log/zerologadapter" - "github.com/rs/zerolog" -) - -func TestLogger(t *testing.T) { - - t.Run("default", func(t *testing.T) { - var buf bytes.Buffer - zlogger := zerolog.New(&buf) - logger := zerologadapter.NewLogger(zlogger) - logger.Log(context.Background(), pgx.LogLevelInfo, "hello", map[string]interface{}{"one": "two"}) - const want = `{"level":"info","module":"pgx","one":"two","message":"hello"} -` - got := buf.String() - if got != want { - t.Errorf("%s != %s", got, want) - } - }) - - t.Run("disable pgx module", func(t *testing.T) { - var buf bytes.Buffer - zlogger := zerolog.New(&buf) - logger := zerologadapter.NewLogger(zlogger, zerologadapter.WithoutPGXModule()) - logger.Log(context.Background(), pgx.LogLevelInfo, "hello", nil) - const want = `{"level":"info","message":"hello"} -` - got := buf.String() - if got != want { - t.Errorf("%s != %s", got, want) - } - }) - - t.Run("from context", func(t *testing.T) { - var buf bytes.Buffer - zlogger := zerolog.New(&buf) - ctx := zlogger.WithContext(context.Background()) - logger := zerologadapter.NewContextLogger() - logger.Log(ctx, pgx.LogLevelInfo, "hello", map[string]interface{}{"one": "two"}) - const want = `{"level":"info","module":"pgx","one":"two","message":"hello"} -` - - got := buf.String() - if got != want { - t.Log(got) - t.Log(want) - t.Errorf("%s != %s", got, want) - } - }) - - var buf bytes.Buffer - type key string - var ck key - zlogger := zerolog.New(&buf) - logger := zerologadapter.NewLogger(zlogger, - zerologadapter.WithContextFunc(func(ctx context.Context, logWith zerolog.Context) zerolog.Context { - // You can use zerolog.hlog.IDFromCtx(ctx) or even - // zerolog.log.Ctx(ctx) to fetch the whole logger instance from the - // context if you want. - id, ok := ctx.Value(ck).(string) - if ok { - logWith = logWith.Str("req_id", id) - } - return logWith - }), - ) - - t.Run("no request id", func(t *testing.T) { - buf.Reset() - ctx := context.Background() - logger.Log(ctx, pgx.LogLevelInfo, "hello", nil) - const want = `{"level":"info","module":"pgx","message":"hello"} -` - got := buf.String() - if got != want { - t.Errorf("%s != %s", got, want) - } - }) - - t.Run("with request id", func(t *testing.T) { - buf.Reset() - ctx := context.WithValue(context.Background(), ck, "1") - logger.Log(ctx, pgx.LogLevelInfo, "hello", map[string]interface{}{"two": "2"}) - const want = `{"level":"info","module":"pgx","req_id":"1","two":"2","message":"hello"} -` - got := buf.String() - if got != want { - t.Errorf("%s != %s", got, want) - } - }) -} From 7c5dbde59e7235ba35bc30b4f9c17bd624ad00bb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 14:54:25 -0600 Subject: [PATCH 0772/1158] Upgrade remaining dependencies --- go.mod | 8 ++++---- go.sum | 15 ++++++++++----- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 8cbffeb8..a4a9b8b3 100644 --- a/go.mod +++ b/go.mod @@ -9,14 +9,14 @@ require ( github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.2.0 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b - github.com/jackc/puddle v1.2.0 + github.com/jackc/puddle v1.2.1 github.com/stretchr/testify v1.7.0 - golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 - golang.org/x/text v0.3.6 + golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b + golang.org/x/text v0.3.7 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index cca0fd29..b055f738 100644 --- a/go.sum +++ b/go.sum @@ -45,8 +45,8 @@ github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9 github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.2.0 h1:DNDKdn/pDrWvDWyT2FYvpZVE81OAhWrjCv19I9n108Q= -github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.2.1 h1:gI8os0wpRXFd4FiAY2dWiqRK037tjj3t7rKFeO4X5iw= +github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= @@ -92,13 +92,15 @@ golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b h1:QAqMVf3pSa6eeTsuklijukjXBlj7Es2QQplab+/RbQ4= +golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -109,6 +111,7 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -116,8 +119,9 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -130,5 +134,6 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From b2569172d814e3a5ed36950afea462f5e3b858f2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 14:55:02 -0600 Subject: [PATCH 0773/1158] Fix typo in example --- examples/url_shortener/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/url_shortener/main.go b/examples/url_shortener/main.go index fb267b74..bcee235e 100644 --- a/examples/url_shortener/main.go +++ b/examples/url_shortener/main.go @@ -85,6 +85,6 @@ func main() { log.Println("Starting URL shortener on localhost:8080") err = http.ListenAndServe("localhost:8080", nil) if err != nil { - log.Fatalln("Unable to start web server:" err) + log.Fatalln("Unable to start web server:", err) } } From d2dc20af8166db0099abff9857b385479675343a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Dec 2021 15:32:52 -0600 Subject: [PATCH 0774/1158] Link to extensions --- README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/README.md b/README.md index 03f23bdb..98381b01 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,19 @@ tern is a stand-alone SQL migration system. pgerrcode contains constants for the PostgreSQL error codes. +## Adapters for 3rd Party Types + +* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid) +* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal) + +## Adapters for 3rd Party Loggers + +* [github.com/jackc/pgx-go-kit-log](https://github.com/jackc/pgx-go-kit-log) +* [github.com/jackc/pgx-log15](https://github.com/jackc/pgx-log15) +* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus) +* [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap) +* [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog) + ## 3rd Party Libraries with PGX Support ### [github.com/georgysavva/scany](https://github.com/georgysavva/scany) From 5a5260b73dbca7c66d1573c6588c6a1730bc37ba Mon Sep 17 00:00:00 2001 From: James Hartig Date: Tue, 14 Dec 2021 13:40:04 -0500 Subject: [PATCH 0775/1158] feat: support port in ip from LookupFunc to override config Fixes #97 --- pgconn.go | 27 +++++++++++++++++++++------ pgconn_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/pgconn.go b/pgconn.go index 382ad33c..dad522c6 100644 --- a/pgconn.go +++ b/pgconn.go @@ -11,6 +11,7 @@ import ( "io" "math" "net" + "strconv" "strings" "sync" "time" @@ -44,7 +45,8 @@ type Notification struct { // DialFunc is a function that can be used to connect to a PostgreSQL server. type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) -// LookupFunc is a function that can be used to lookup IPs addrs from host. +// LookupFunc is a function that can be used to lookup IPs addrs from host. Optionally an ip:port combination can be +// returned in order to override the connection string's port. type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. @@ -196,11 +198,24 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba } for _, ip := range ips { - configs = append(configs, &FallbackConfig{ - Host: ip, - Port: fb.Port, - TLSConfig: fb.TLSConfig, - }) + splitIP, splitPort, err := net.SplitHostPort(ip) + if err == nil { + port, err := strconv.ParseUint(splitPort, 10, 16) + if err != nil { + return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err) + } + configs = append(configs, &FallbackConfig{ + Host: splitIP, + Port: uint16(port), + TLSConfig: fb.TLSConfig, + }) + } else { + configs = append(configs, &FallbackConfig{ + Host: ip, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + } } } diff --git a/pgconn_test.go b/pgconn_test.go index c20b7425..43e97eef 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -237,6 +237,40 @@ func TestConnectCustomLookup(t *testing.T) { closeConn(t, conn) } +func TestConnectCustomLookupWithPort(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + origPort := config.Port + // Chnage the config an invalid port so it will fail if used + config.Port = 0 + + looked := false + config.LookupFunc = func(ctx context.Context, host string) ([]string, error) { + looked = true + addrs, err := net.LookupHost(host) + if err != nil { + return nil, err + } + for i := range addrs { + addrs[i] = net.JoinHostPort(addrs[i], strconv.FormatUint(uint64(origPort), 10)) + } + return addrs, nil + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) +} + func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() From 58b7486343db477ee8d3ca8201aab28a5907c121 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 23 Dec 2021 13:12:54 -0600 Subject: [PATCH 0776/1158] Initial codec support for int2 and int2[] --- pgtype/array.go | 14 ++ pgtype/array_codec.go | 352 +++++++++++++++++++++++++++++++++++++ pgtype/array_codec_test.go | 106 +++++++++++ pgtype/convert.go | 136 ++++++++++++++ pgtype/int2_codec.go | 146 +++++++++++++++ pgtype/pgtype.go | 47 ++++- 6 files changed, 799 insertions(+), 2 deletions(-) create mode 100644 pgtype/array_codec.go create mode 100644 pgtype/array_codec_test.go create mode 100644 pgtype/int2_codec.go diff --git a/pgtype/array.go b/pgtype/array.go index 174007c1..29d6f803 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -28,6 +28,20 @@ type ArrayDimension struct { LowerBound int32 } +// cardinality returns the number of elements in an array of dimensions size. +func cardinality(dimensions []ArrayDimension) int { + if len(dimensions) == 0 { + return 0 + } + + elementCount := int(dimensions[0].Length) + for _, d := range dimensions[1:] { + elementCount *= int(d.Length) + } + + return elementCount +} + func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { if len(src) < 12 { return 0, fmt.Errorf("array header too short: %d", len(src)) diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go new file mode 100644 index 00000000..b72290a0 --- /dev/null +++ b/pgtype/array_codec.go @@ -0,0 +1,352 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + + "github.com/jackc/pgio" +) + +// ArrayGetter is a type that can be converted into a PostgreSQL array. +type ArrayGetter interface { + // Dimensions returns the array dimensions. If array is nil then nil is returned. + Dimensions() []ArrayDimension + + // Index returns the element at i. + Index(i int) interface{} +} + +// ArraySetter is a type can be set from a PostgreSQL array. +type ArraySetter interface { + // SetDimensions prepares the value such that ScanIndex can be called for each element. dimensions may be nil to + // indicate a NULL array. If unable to exactly preserve dimensions SetDimensions may return an error or silently + // flatten the array dimensions. + SetDimensions(dimensions []ArrayDimension) error + + // ScanIndex returns a value usable as a scan target for i. SetDimensions must be called before ScanIndex. + ScanIndex(i int) interface{} +} + +type int16Array []int16 + +func (a int16Array) Dimensions() []ArrayDimension { + if a == nil { + return nil + } + + return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} +} + +func (a int16Array) Index(i int) interface{} { + return a[i] +} + +func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + a = nil + return nil + } + + elementCount := cardinality(dimensions) + *a = make(int16Array, elementCount) + return nil +} + +func (a int16Array) ScanIndex(i int) interface{} { + return &a[i] +} + +func makeArrayGetter(a interface{}) (ArrayGetter, error) { + switch a := a.(type) { + case ArrayGetter: + return a, nil + case []int16: + return (*int16Array)(&a), nil + } + + return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a) +} + +func makeArraySetter(a interface{}) (ArraySetter, error) { + switch a := a.(type) { + case ArraySetter: + return a, nil + case *[]int16: + return (*int16Array)(a), nil + } + + return nil, fmt.Errorf("cannot convert %T to ArraySetter", a) +} + +// ArrayCodec is a codec for any array type. +type ArrayCodec struct { + ElementCodec Codec + ElementOID uint32 +} + +func (c *ArrayCodec) FormatSupported(format int16) bool { + return c.ElementCodec.FormatSupported(format) +} + +func (c *ArrayCodec) PreferredFormat() int16 { + return c.ElementCodec.PreferredFormat() +} + +func (c *ArrayCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + if value == nil { + return nil, nil + } + + array, err := makeArrayGetter(value) + if err != nil { + return nil, err + } + + switch format { + case BinaryFormatCode: + return c.encodeBinary(ci, oid, array, buf) + case TextFormatCode: + return c.encodeText(ci, oid, array, buf) + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } + +} + +func (c *ArrayCodec) encodeBinary(ci *ConnInfo, oid uint32, array ArrayGetter, buf []byte) (newBuf []byte, err error) { + dimensions := array.Dimensions() + if dimensions == nil { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: dimensions, + ElementOID: int32(c.ElementOID), + } + + containsNullIndex := len(buf) + 4 + + buf = arrayHeader.EncodeBinary(ci, buf) + + elementCount := cardinality(dimensions) + for i := 0; i < elementCount; i++ { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := c.ElementCodec.Encode(ci, c.ElementOID, BinaryFormatCode, array.Index(i), buf) + if err != nil { + return nil, err + } + if elemBuf == nil { + pgio.SetInt32(buf[containsNullIndex:], 1) + } else { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf []byte) (newBuf []byte, err error) { + dimensions := array.Dimensions() + if dimensions == nil { + return nil, nil + } + + if len(dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(dimensions)) + dimElemCounts[len(dimensions)-1] = int(dimensions[len(dimensions)-1].Length) + for i := len(dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + elementCount := cardinality(dimensions) + for i := 0; i < elementCount; i++ { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := c.ElementCodec.Encode(ci, c.ElementOID, TextFormatCode, array.Index(i), inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (c *ArrayCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + _, err := makeArraySetter(target) + if err != nil { + return nil + } + + return (*scanPlanArrayCodec)(c) +} + +func (c *ArrayCodec) decodeBinary(ci *ConnInfo, arrayOID uint32, src []byte, array ArraySetter) error { + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + // TODO - ArrayHeader.DecodeBinary should do this. But doing this there breaks old array code. Leave until old code + // can be removed. + if arrayHeader.Dimensions == nil { + arrayHeader.Dimensions = []ArrayDimension{} + } + + err = array.SetDimensions(arrayHeader.Dimensions) + if err != nil { + return err + } + + elementCount := cardinality(arrayHeader.Dimensions) + if elementCount == 0 { + return nil + } + + elementScanPlan := c.ElementCodec.PlanScan(ci, c.ElementOID, BinaryFormatCode, array.ScanIndex(0), false) + if elementScanPlan == nil { + elementScanPlan = ci.PlanScan(c.ElementOID, BinaryFormatCode, array.ScanIndex(0)) + } + + for i := 0; i < elementCount; i++ { + elem := array.ScanIndex(i) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elementScanPlan.Scan(ci, c.ElementOID, BinaryFormatCode, elemSrc, elem) + if err != nil { + return err + } + } + + return nil +} + +func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array ArraySetter) error { + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + // TODO - ParseUntypedTextArray should do this. But doing this there breaks old array code. Leave until old code + // can be removed. + if uta.Dimensions == nil { + uta.Dimensions = []ArrayDimension{} + } + + err = array.SetDimensions(uta.Dimensions) + if err != nil { + return err + } + + if len(uta.Elements) == 0 { + return nil + } + + elementScanPlan := c.ElementCodec.PlanScan(ci, c.ElementOID, TextFormatCode, array.ScanIndex(0), false) + if elementScanPlan == nil { + elementScanPlan = ci.PlanScan(c.ElementOID, TextFormatCode, array.ScanIndex(0)) + } + + for i, s := range uta.Elements { + elem := array.ScanIndex(i) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + + err = elementScanPlan.Scan(ci, c.ElementOID, TextFormatCode, elemSrc, elem) + if err != nil { + return err + } + } + + return nil +} + +type scanPlanArrayCodec ArrayCodec + +func (spac *scanPlanArrayCodec) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + c := (*ArrayCodec)(spac) + + array, err := makeArraySetter(dst) + if err != nil { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + + if src == nil { + return array.SetDimensions(nil) + } + + switch formatCode { + case BinaryFormatCode: + return c.decodeBinary(ci, oid, src, array) + case TextFormatCode: + return c.decodeText(ci, oid, src, array) + default: + return fmt.Errorf("unknown format code %d", formatCode) + } +} + +func (c ArrayCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + // var n int64 + // err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) + // return n, err + + return nil, fmt.Errorf("not implemented") +} + +func (c ArrayCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + // var n int16 + // err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) + // return n, err + + return nil, fmt.Errorf("not implemented") +} diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go new file mode 100644 index 00000000..f213d0ec --- /dev/null +++ b/pgtype/array_codec_test.go @@ -0,0 +1,106 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/stretchr/testify/assert" +) + +func TestArrayCodec(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + tests := []struct { + expected []int16 + }{ + {[]int16(nil)}, + {[]int16{}}, + {[]int16{1, 2, 3}}, + } + for i, tt := range tests { + var actual []int16 + err := conn.QueryRow( + context.Background(), + "select $1::smallint[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } +} + +// func TestArrayCodecValue(t *testing.T) { +// ArrayCodec := pgtype.NewArrayCodec("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }) + +// err := ArrayCodec.Set(nil) +// require.NoError(t, err) + +// gotValue := ArrayCodec.Get() +// require.Nil(t, gotValue) + +// slice := []string{"foo", "bar"} +// err = ArrayCodec.AssignTo(&slice) +// require.NoError(t, err) +// require.Nil(t, slice) + +// err = ArrayCodec.Set([]string{}) +// require.NoError(t, err) + +// gotValue = ArrayCodec.Get() +// require.Len(t, gotValue, 0) + +// err = ArrayCodec.AssignTo(&slice) +// require.NoError(t, err) +// require.EqualValues(t, []string{}, slice) + +// err = ArrayCodec.Set([]string{"baz", "quz"}) +// require.NoError(t, err) + +// gotValue = ArrayCodec.Get() +// require.Len(t, gotValue, 2) + +// err = ArrayCodec.AssignTo(&slice) +// require.NoError(t, err) +// require.EqualValues(t, []string{"baz", "quz"}, slice) +// } + +// func TestArrayCodecTranscode(t *testing.T) { +// conn := testutil.MustConnectPgx(t) +// defer testutil.MustCloseContext(t, conn) + +// conn.ConnInfo().RegisterDataType(pgtype.DataType{ +// Value: pgtype.NewArrayCodec("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), +// Name: "_text", +// OID: pgtype.TextArrayOID, +// }) + +// var dstStrings []string +// err := conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) +// require.NoError(t, err) + +// require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) +// } + +// func TestArrayCodecEmptyArrayDoesNotBreakArrayCodec(t *testing.T) { +// conn := testutil.MustConnectPgx(t) +// defer testutil.MustCloseContext(t, conn) + +// conn.ConnInfo().RegisterDataType(pgtype.DataType{ +// Value: pgtype.NewArrayCodec("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), +// Name: "_text", +// OID: pgtype.TextArrayOID, +// }) + +// var dstStrings []string +// err := conn.QueryRow(context.Background(), "select '{}'::text[]").Scan(&dstStrings) +// require.NoError(t, err) + +// require.EqualValues(t, []string{}, dstStrings) + +// err = conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) +// require.NoError(t, err) + +// require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) +// } diff --git a/pgtype/convert.go b/pgtype/convert.go index 21e208f5..ee5ba393 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -5,6 +5,7 @@ import ( "fmt" "math" "reflect" + "strconv" "time" ) @@ -452,6 +453,141 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { return nil, false } +func convertToInt64ForEncode(v interface{}) (n int64, valid bool, err error) { + if v == nil { + return 0, false, nil + } + + switch v := v.(type) { + case int8: + return int64(v), true, nil + case uint8: + return int64(v), true, nil + case int16: + return int64(v), true, nil + case uint16: + return int64(v), true, nil + case int32: + return int64(v), true, nil + case uint32: + return int64(v), true, nil + case int64: + return int64(v), true, nil + case uint64: + if v > math.MaxInt64 { + return 0, false, fmt.Errorf("%d is greater than maximum value for int64", v) + } + return int64(v), true, nil + case int: + return int64(v), true, nil + case uint: + if v > math.MaxInt64 { + return 0, false, fmt.Errorf("%d is greater than maximum value for int64", v) + } + return int64(v), true, nil + case string: + num, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, false, err + } + return num, true, nil + case float32: + if v > math.MaxInt64 { + return 0, false, fmt.Errorf("%f is greater than maximum value for int64", v) + } + return int64(v), true, nil + case float64: + if v > math.MaxInt64 { + return 0, false, fmt.Errorf("%f is greater than maximum value for int64", v) + } + return int64(v), true, nil + case *int8: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *uint8: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *int16: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *uint16: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *int32: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *uint32: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *int64: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *uint64: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *int: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *uint: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *string: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *float32: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *float64: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + + default: + if originalvalue, ok := underlyingNumberType(v); ok { + return convertToInt64ForEncode(originalvalue) + } + return 0, false, fmt.Errorf("cannot convert %v to int64", v) + } +} + func init() { kindTypes = map[reflect.Kind]reflect.Type{ reflect.Bool: reflect.TypeOf(false), diff --git a/pgtype/int2_codec.go b/pgtype/int2_codec.go new file mode 100644 index 00000000..7ea50870 --- /dev/null +++ b/pgtype/int2_codec.go @@ -0,0 +1,146 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Int2Codec struct{} + +func (Int2Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int2Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int2Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) + } + if !valid { + return nil, nil + } + + if n > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n) + } + if n < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n) + } + + switch format { + case BinaryFormatCode: + return pgio.AppendInt16(buf, int16(n)), nil + case TextFormatCode: + return append(buf, strconv.FormatInt(n, 10)...), nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } +} + +func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case BinaryFormatCode: + case TextFormatCode: + switch target.(type) { + case *int16: + return scanPlanTextToAnyInt16{} + case *int32: + return scanPlanTextToAnyInt32{} + case *int64: + return scanPlanTextToAnyInt64{} + } + } + + return nil +} + +func (c Int2Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n int64 + err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) + return n, err +} + +func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var n int16 + err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) + return n, err +} + +type scanPlanTextToAnyInt16 struct{} + +func (scanPlanTextToAnyInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + + *p = int16(n) + return nil +} + +type scanPlanTextToAnyInt32 struct{} + +func (scanPlanTextToAnyInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 32) + if err != nil { + return err + } + + *p = int32(n) + return nil +} + +type scanPlanTextToAnyInt64 struct{} + +func (scanPlanTextToAnyInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + *p = int64(n) + return nil +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d8dd5abf..b0b07663 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -2,7 +2,9 @@ package pgtype import ( "database/sql" + "database/sql/driver" "encoding/binary" + "errors" "fmt" "math" "net" @@ -173,6 +175,34 @@ type ResultDecoder interface { DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error } +type Encoder interface { + // Encode appends the encoded bytes of value to buf. If value is the SQL NULL then append nothing and return + // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data + // written. + Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) +} + +type Codec interface { + // FormatSupported returns true if the format is supported. + FormatSupported(int16) bool + + // PreferredFormat returns the preferred format. + PreferredFormat() int16 + + Encoder + + // PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If + // actualTarget is true then the returned ScanPlan may be optimized to directly scan into target. If no plan can be + // found then nil is returned. + PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan + + // DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface. + DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) + + // DecodeValue returns src decoded into its default format. + DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) +} + // ResultFormatPreferrer allows a type to specify its preferred result format instead of it being inferred from // whether it is also a BinaryDecoder. type ResultFormatPreferrer interface { @@ -229,6 +259,8 @@ type DataType struct { textDecoder TextDecoder binaryDecoder BinaryDecoder + Codec Codec + Name string OID uint32 } @@ -268,7 +300,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID}) - ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID}) + ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}}) ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID}) ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID}) ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) @@ -292,7 +324,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID}) - ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID}) + ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID}) ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID}) @@ -752,6 +784,15 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt // PlanScan prepares a plan to scan a value into dst. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { + if oid != 0 { + if dt, ok := ci.DataTypeForOID(oid); ok && dt.Codec != nil { + plan := dt.Codec.PlanScan(ci, oid, formatCode, dst, false) + if plan != nil { + return plan + } + } + } + switch formatCode { case BinaryFormatCode: switch dst.(type) { @@ -866,6 +907,8 @@ func NewValue(v Value) Value { } } +var ErrScanTargetTypeChanged = errors.New("scan target type changed") + var nameValues map[string]Value func init() { From c0a0be876d02652626514ed97b0349f245e3bf76 Mon Sep 17 00:00:00 2001 From: Blake Embrey Date: Wed, 22 Dec 2021 08:33:10 -0800 Subject: [PATCH 0777/1158] Fix TLS connection timeout --- pgconn.go | 32 ++++++++++----------- pgconn_test.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 16 deletions(-) diff --git a/pgconn.go b/pgconn.go index dad522c6..7e5d585b 100644 --- a/pgconn.go +++ b/pgconn.go @@ -241,13 +241,6 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.parameterStatuses = make(map[string]string) - if fallbackConfig.TLSConfig != nil { - if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { - pgConn.conn.Close() - return nil, &connectError{config: config, msg: "tls error", err: err} - } - } - pgConn.status = connStatusConnecting pgConn.contextWatcher = ctxwatch.NewContextWatcher( func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, @@ -257,6 +250,15 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() + if fallbackConfig.TLSConfig != nil { + tlsConn, err := startTLS(pgConn.conn, fallbackConfig.TLSConfig) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "tls error", err: err} + } + pgConn.conn = tlsConn + } + pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ @@ -344,24 +346,22 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } -func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { - err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103}) +func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { + err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { - return + return nil, err } response := make([]byte, 1) - if _, err = io.ReadFull(pgConn.conn, response); err != nil { - return + if _, err = io.ReadFull(conn, response); err != nil { + return nil, err } if response[0] != 'S' { - return errors.New("server refused TLS connection") + return nil, errors.New("server refused TLS connection") } - pgConn.conn = tls.Client(pgConn.conn, tlsConfig) - - return nil + return tls.Client(conn, tlsConfig), nil } func (pgConn *PgConn) txPasswordMessage(password string) (err error) { diff --git a/pgconn_test.go b/pgconn_test.go index 43e97eef..b22792fb 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -161,6 +161,84 @@ func TestConnectTimeout(t *testing.T) { } } +func TestConnectTimeoutStuckOnTLSHandshake(t *testing.T) { + t.Parallel() + tests := []struct { + name string + connect func(connStr string) error + }{ + { + name: "via context that times out", + connect: func(connStr string) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + defer cancel() + _, err := pgconn.Connect(ctx, connStr) + return err + }, + }, + { + name: "via config ConnectTimeout", + connect: func(connStr string) error { + conf, err := pgconn.ParseConfig(connStr) + require.NoError(t, err) + conf.ConnectTimeout = time.Millisecond * 10 + _, err = pgconn.ConnectConfig(context.Background(), conf) + return err + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error) + defer close(serverErrChan) + go func() { + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + var buf []byte + _, err = conn.Read(buf) + if err != nil { + serverErrChan <- err + return + } + + // Sleeping to hang the TLS handshake. + time.Sleep(time.Minute) + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("host=%s port=%s", host, port) + + errChan := make(chan error) + go func() { + err := tt.connect(connStr) + errChan <- err + }() + + select { + case err = <-errChan: + require.True(t, pgconn.Timeout(err), err) + case err = <-serverErrChan: + t.Fatalf("server failed with error: %s", err) + case <-time.After(time.Millisecond * 100): + t.Fatal("exceeded connection timeout without erroring out") + } + }) + } +} + func TestConnectInvalidUser(t *testing.T) { t.Parallel() From 024de4c8f330d2e0bd96dd3b99c866ae05a3f25a Mon Sep 17 00:00:00 2001 From: Blake Embrey Date: Thu, 23 Dec 2021 08:53:33 -0800 Subject: [PATCH 0778/1158] Unwatch and re-watch tls --- pgconn.go | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/pgconn.go b/pgconn.go index 7e5d585b..6fde4e50 100644 --- a/pgconn.go +++ b/pgconn.go @@ -230,7 +230,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) - pgConn.conn, err = config.DialFunc(ctx, network, address) + netConn, err := config.DialFunc(ctx, network, address) if err != nil { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { @@ -239,26 +239,28 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, &connectError{config: config, msg: "dial error", err: err} } - pgConn.parameterStatuses = make(map[string]string) - - pgConn.status = connStatusConnecting - pgConn.contextWatcher = ctxwatch.NewContextWatcher( - func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { pgConn.conn.SetDeadline(time.Time{}) }, - ) - + pgConn.contextWatcher = contextWatcher(netConn) pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() + pgConn.status = connStatusConnecting + pgConn.conn = netConn + if fallbackConfig.TLSConfig != nil { - tlsConn, err := startTLS(pgConn.conn, fallbackConfig.TLSConfig) + tlsConn, err := startTLS(netConn, fallbackConfig.TLSConfig) if err != nil { - pgConn.conn.Close() + netConn.Close() return nil, &connectError{config: config, msg: "tls error", err: err} } + + pgConn.contextWatcher.Unwatch() + pgConn.contextWatcher = contextWatcher(tlsConn) + pgConn.contextWatcher.Watch(ctx) + pgConn.conn = tlsConn } + pgConn.parameterStatuses = make(map[string]string) pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ @@ -346,6 +348,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } +func contextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { + return ctxwatch.NewContextWatcher( + func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { conn.SetDeadline(time.Time{}) }, + ) +} + func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { @@ -1709,10 +1718,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) { cleanupDone: make(chan struct{}), } - pgConn.contextWatcher = ctxwatch.NewContextWatcher( - func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { pgConn.conn.SetDeadline(time.Time{}) }, - ) + pgConn.contextWatcher = contextWatcher(pgConn.conn) return pgConn, nil } From 01a6923376d212fe4174b676f2667c817ee696f4 Mon Sep 17 00:00:00 2001 From: Blake Embrey Date: Thu, 23 Dec 2021 08:55:38 -0800 Subject: [PATCH 0779/1158] Rename fn to new --- pgconn.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pgconn.go b/pgconn.go index 6fde4e50..c437c119 100644 --- a/pgconn.go +++ b/pgconn.go @@ -239,7 +239,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, &connectError{config: config, msg: "dial error", err: err} } - pgConn.contextWatcher = contextWatcher(netConn) + pgConn.contextWatcher = newContextWatcher(netConn) pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() @@ -254,7 +254,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } pgConn.contextWatcher.Unwatch() - pgConn.contextWatcher = contextWatcher(tlsConn) + pgConn.contextWatcher = newContextWatcher(tlsConn) pgConn.contextWatcher.Watch(ctx) pgConn.conn = tlsConn @@ -348,7 +348,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } -func contextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { +func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { return ctxwatch.NewContextWatcher( func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { conn.SetDeadline(time.Time{}) }, @@ -1718,7 +1718,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) { cleanupDone: make(chan struct{}), } - pgConn.contextWatcher = contextWatcher(pgConn.conn) + pgConn.contextWatcher = newContextWatcher(pgConn.conn) return pgConn, nil } From b148a14bbee1d4667bd2be01ea2d12253d7d1931 Mon Sep 17 00:00:00 2001 From: Blake Embrey Date: Mon, 27 Dec 2021 10:29:21 -0800 Subject: [PATCH 0780/1158] Fix defer usage --- pgconn.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pgconn.go b/pgconn.go index c437c119..4bec872d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -239,27 +239,28 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, &connectError{config: config, msg: "dial error", err: err} } - pgConn.contextWatcher = newContextWatcher(netConn) - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() - pgConn.status = connStatusConnecting pgConn.conn = netConn + pgConn.contextWatcher = newContextWatcher(netConn) + pgConn.contextWatcher.Watch(ctx) + if fallbackConfig.TLSConfig != nil { tlsConn, err := startTLS(netConn, fallbackConfig.TLSConfig) + pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() return nil, &connectError{config: config, msg: "tls error", err: err} } - pgConn.contextWatcher.Unwatch() pgConn.contextWatcher = newContextWatcher(tlsConn) pgConn.contextWatcher.Watch(ctx) pgConn.conn = tlsConn } + defer pgConn.contextWatcher.Unwatch() + pgConn.parameterStatuses = make(map[string]string) pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) From a1852214fe10eeeb5b7f8634a8dd1e928884f2cf Mon Sep 17 00:00:00 2001 From: Blake Embrey Date: Mon, 27 Dec 2021 10:36:38 -0800 Subject: [PATCH 0781/1158] Keep status connecting after tls --- pgconn.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pgconn.go b/pgconn.go index 4bec872d..f8b8a659 100644 --- a/pgconn.go +++ b/pgconn.go @@ -239,9 +239,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, &connectError{config: config, msg: "dial error", err: err} } - pgConn.status = connStatusConnecting pgConn.conn = netConn - pgConn.contextWatcher = newContextWatcher(netConn) pgConn.contextWatcher.Watch(ctx) @@ -253,15 +251,15 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, &connectError{config: config, msg: "tls error", err: err} } + pgConn.conn = tlsConn pgConn.contextWatcher = newContextWatcher(tlsConn) pgConn.contextWatcher.Watch(ctx) - - pgConn.conn = tlsConn } defer pgConn.contextWatcher.Unwatch() pgConn.parameterStatuses = make(map[string]string) + pgConn.status = connStatusConnecting pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ From 3ce8a835e1d61d5b25c9e0305af1560ec01bc160 Mon Sep 17 00:00:00 2001 From: Oscar Date: Tue, 28 Dec 2021 14:43:01 +0100 Subject: [PATCH 0782/1158] add support for read-only, primary, standby, prefer-standby target_session_attributes --- config.go | 60 ++++++++++++++++++++++++++++++++++-- config_test.go | 83 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 139 insertions(+), 4 deletions(-) diff --git a/config.go b/config.go index 172e7478..c48905f9 100644 --- a/config.go +++ b/config.go @@ -329,10 +329,19 @@ func ParseConfig(connString string) (*Config, error) { } } - if settings["target_session_attrs"] == "read-write" { + switch tsa := settings["target_session_attrs"]; tsa { + case "read-write": config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite - } else if settings["target_session_attrs"] != "any" { - return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])} + case "read-only": + config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly + case "primary": + config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary + case "standby": + config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby + case "any", "prefer-standby": + // do nothing + default: + return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} } return config, nil @@ -727,3 +736,48 @@ func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgC return nil } + +// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible +// target_session_attrs=read-only. +func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error { + result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() + if result.Err != nil { + return result.Err + } + + if string(result.Rows[0][0]) != "on" { + return errors.New("connection is not read only") + } + + return nil +} + +// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible +// target_session_attrs=standby. +func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error { + result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() + if result.Err != nil { + return result.Err + } + + if string(result.Rows[0][0]) != "f" { + return errors.New("server is not in hot standby mode") + } + + return nil +} + +// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible +// target_session_attrs=primary. +func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error { + result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() + if result.Err != nil { + return result.Err + } + + if string(result.Rows[0][0]) == "t" { + return errors.New("server is in standby mode") + } + + return nil +} diff --git a/config_test.go b/config_test.go index d29173d1..da28782d 100644 --- a/config_test.go +++ b/config_test.go @@ -541,7 +541,7 @@ func TestParseConfig(t *testing.T) { }, }, { - name: "target_session_attrs", + name: "target_session_attrs read-write", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", config: &pgconn.Config{ User: "jack", @@ -554,6 +554,87 @@ func TestParseConfig(t *testing.T) { ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite, }, }, + { + name: "target_session_attrs read-only", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-only", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadOnly, + }, + }, + { + name: "target_session_attrs primary", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=primary", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPrimary, + }, + }, + { + name: "target_session_attrs standby", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=standby", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsStandby, + }, + }, + { + name: "target_session_attrs prefer-standby", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=prefer-standby", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "target_session_attrs any", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=any", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "target_session_attrs not set (any)", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, } for i, tt := range tests { From 3aaf3409ce9137e534e9433cc4256aaef7d88e7e Mon Sep 17 00:00:00 2001 From: Oscar Date: Tue, 28 Dec 2021 14:44:04 +0100 Subject: [PATCH 0783/1158] remove redundant map value type --- config.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/config.go b/config.go index c48905f9..406a2877 100644 --- a/config.go +++ b/config.go @@ -248,21 +248,21 @@ func ParseConfig(connString string) (*Config, error) { config.LookupFunc = makeDefaultResolver().LookupHost notRuntimeParams := map[string]struct{}{ - "host": struct{}{}, - "port": struct{}{}, - "database": struct{}{}, - "user": struct{}{}, - "password": struct{}{}, - "passfile": struct{}{}, - "connect_timeout": struct{}{}, - "sslmode": struct{}{}, - "sslkey": struct{}{}, - "sslcert": struct{}{}, - "sslrootcert": struct{}{}, - "target_session_attrs": struct{}{}, - "min_read_buffer_size": struct{}{}, - "service": struct{}{}, - "servicefile": struct{}{}, + "host": {}, + "port": {}, + "database": {}, + "user": {}, + "password": {}, + "passfile": {}, + "connect_timeout": {}, + "sslmode": {}, + "sslkey": {}, + "sslcert": {}, + "sslrootcert": {}, + "target_session_attrs": {}, + "min_read_buffer_size": {}, + "service": {}, + "servicefile": {}, } for k, v := range settings { From 109c4c2d95fd0925d0a4f842c0c9878ea2f16971 Mon Sep 17 00:00:00 2001 From: Oscar Date: Tue, 28 Dec 2021 14:56:49 +0100 Subject: [PATCH 0784/1158] fix standby mode validation --- config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.go b/config.go index 406a2877..0eab23af 100644 --- a/config.go +++ b/config.go @@ -760,7 +760,7 @@ func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgCon return result.Err } - if string(result.Rows[0][0]) != "f" { + if string(result.Rows[0][0]) != "t" { return errors.New("server is not in hot standby mode") } From 9fc8f9b3a8b0a82bc271c7450ec2ecbe6e5f21a4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 30 Dec 2021 18:12:47 -0600 Subject: [PATCH 0785/1158] Initial passing tests for main pgx package --- extended_query_builder.go | 37 +- pgtype/array_codec.go | 55 +-- pgtype/int2.go | 265 +---------- pgtype/int2_array.go | 896 -------------------------------------- pgtype/int2_codec.go | 223 +++++++++- pgtype/int2_test.go | 213 ++++----- pgtype/pgtype.go | 201 +++++---- pgtype/typed_array_gen.sh | 1 - pgtype/zzz.int2.go | 35 -- query_test.go | 73 ++-- rows.go | 49 ++- values.go | 93 ++-- 12 files changed, 574 insertions(+), 1567 deletions(-) delete mode 100644 pgtype/int2_array.go delete mode 100644 pgtype/zzz.int2.go diff --git a/extended_query_builder.go b/extended_query_builder.go index 1420a808..480e35d3 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -113,23 +113,34 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o } if dt, ok := ci.DataTypeForOID(oid); ok { - value := dt.Value - err := value.Set(arg) - if err != nil { - { - if arg, ok := arg.(driver.Valuer); ok { - v, err := callValuerValue(arg) - if err != nil { - return nil, err + if dt.Value != nil { + value := dt.Value + err := value.Set(arg) + if err != nil { + { + if arg, ok := arg.(driver.Valuer); ok { + v, err := callValuerValue(arg) + if err != nil { + return nil, err + } + return eqb.encodeExtendedParamValue(ci, oid, formatCode, v) } - return eqb.encodeExtendedParamValue(ci, oid, formatCode, v) } + + return nil, err } - - return nil, err + return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) + } else if dt.Codec != nil { + buf, err := dt.Codec.Encode(ci, oid, formatCode, arg, eqb.paramValueBytes) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil } - - return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) } // There is no data type registered for the destination OID, but maybe there is data type registered for the arg diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index b72290a0..16ce7382 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -28,57 +28,6 @@ type ArraySetter interface { ScanIndex(i int) interface{} } -type int16Array []int16 - -func (a int16Array) Dimensions() []ArrayDimension { - if a == nil { - return nil - } - - return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} -} - -func (a int16Array) Index(i int) interface{} { - return a[i] -} - -func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error { - if dimensions == nil { - a = nil - return nil - } - - elementCount := cardinality(dimensions) - *a = make(int16Array, elementCount) - return nil -} - -func (a int16Array) ScanIndex(i int) interface{} { - return &a[i] -} - -func makeArrayGetter(a interface{}) (ArrayGetter, error) { - switch a := a.(type) { - case ArrayGetter: - return a, nil - case []int16: - return (*int16Array)(&a), nil - } - - return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a) -} - -func makeArraySetter(a interface{}) (ArraySetter, error) { - switch a := a.(type) { - case ArraySetter: - return a, nil - case *[]int16: - return (*int16Array)(a), nil - } - - return nil, fmt.Errorf("cannot convert %T to ArraySetter", a) -} - // ArrayCodec is a codec for any array type. type ArrayCodec struct { ElementCodec Codec @@ -155,7 +104,8 @@ func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf return nil, nil } - if len(dimensions) == 0 { + elementCount := cardinality(dimensions) + if elementCount == 0 { return append(buf, '{', '}'), nil } @@ -173,7 +123,6 @@ func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf } inElemBuf := make([]byte, 0, 32) - elementCount := cardinality(dimensions) for i := 0; i < elementCount; i++ { if i > 0 { buf = append(buf, ',') diff --git a/pgtype/int2.go b/pgtype/int2.go index bbfee1cf..b7b7243f 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -2,12 +2,9 @@ package pgtype import ( "database/sql/driver" - "encoding/binary" "fmt" "math" "strconv" - - "github.com/jackc/pgio" ) type Int2 struct { @@ -15,231 +12,6 @@ type Int2 struct { Valid bool } -func (dst *Int2) Set(src interface{}) error { - if src == nil { - *dst = Int2{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case int8: - *dst = Int2{Int: int16(value), Valid: true} - case uint8: - *dst = Int2{Int: int16(value), Valid: true} - case int16: - *dst = Int2{Int: int16(value), Valid: true} - case uint16: - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case int32: - if value < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case uint32: - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case int64: - if value < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case uint64: - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case int: - if value < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case uint: - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case string: - num, err := strconv.ParseInt(value, 10, 16) - if err != nil { - return err - } - *dst = Int2{Int: int16(num), Valid: true} - case float32: - if value > math.MaxInt16 { - return fmt.Errorf("%f is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case float64: - if value > math.MaxInt16 { - return fmt.Errorf("%f is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case *int8: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *uint8: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *int16: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *uint16: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *int32: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *uint32: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *int64: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *uint64: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *int: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *uint: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *string: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *float32: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *float64: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Int2", value) - } - - return nil -} - -func (dst Int2) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.Int -} - -func (src *Int2) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Valid, dst) -} - -func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int2{} - return nil - } - - n, err := strconv.ParseInt(string(src), 10, 16) - if err != nil { - return err - } - - *dst = Int2{Int: int16(n), Valid: true} - return nil -} - -func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int2{} - return nil - } - - if len(src) != 2 { - return fmt.Errorf("invalid length for int2: %v", len(src)) - } - - n := int16(binary.BigEndian.Uint16(src)) - *dst = Int2{Int: n, Valid: true} - return nil -} - -func (src Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil -} - -func (src Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return pgio.AppendInt16(buf, src.Int), nil -} - // Scan implements the database/sql Scanner interface. func (dst *Int2) Scan(src interface{}) error { if src == nil { @@ -247,25 +19,36 @@ func (dst *Int2) Scan(src interface{}) error { return nil } + var n int64 + switch src := src.(type) { case int64: - if src < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", src) - } - if src > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", src) - } - *dst = Int2{Int: int16(src), Valid: true} - return nil + n = src case string: - return dst.DecodeText(nil, []byte(src)) + var err error + n, err = strconv.ParseInt(src, 10, 16) + if err != nil { + return err + } case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + var err error + n, err = strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) } - return fmt.Errorf("cannot scan %T", src) + if n < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + *dst = Int2{Int: int16(n), Valid: true} + + return nil } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go deleted file mode 100644 index d96240dc..00000000 --- a/pgtype/int2_array.go +++ /dev/null @@ -1,896 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type Int2Array struct { - Elements []Int2 - Dimensions []ArrayDimension - Valid bool -} - -func (dst *Int2Array) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Int2Array{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []int16: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int16: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint16: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint16: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []int32: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int32: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint32: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint32: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []int64: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int64: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint64: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint64: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []int: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Int2: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - *dst = Int2Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = Int2Array{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for Int2Array", src) - } - if elementsLength == 0 { - *dst = Int2Array{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Int2Array", src) - } - - *dst = Int2Array{ - Elements: make([]Int2, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Int2, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to Int2Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *Int2Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to Int2Array") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in Int2Array", err) - } - index++ - - return index, nil -} - -func (dst Int2Array) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *Int2Array) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from Int2Array") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from Int2Array") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int2Array{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Int2 - - if len(uta.Elements) > 0 { - elements = make([]Int2, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Int2 - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int2Array{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = Int2Array{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Int2, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("int2"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "int2") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int2Array) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int2Array) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/int2_codec.go b/pgtype/int2_codec.go index 7ea50870..c50b56d7 100644 --- a/pgtype/int2_codec.go +++ b/pgtype/int2_codec.go @@ -2,6 +2,7 @@ package pgtype import ( "database/sql/driver" + "encoding/binary" "fmt" "math" "strconv" @@ -46,16 +47,31 @@ func (Int2Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{ } func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { case BinaryFormatCode: case TextFormatCode: switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} case *int16: - return scanPlanTextToAnyInt16{} + return scanPlanTextAnyToInt16{} case *int32: - return scanPlanTextToAnyInt32{} + return scanPlanTextAnyToInt32{} case *int64: - return scanPlanTextToAnyInt64{} + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} } } @@ -68,8 +84,15 @@ func (c Int2Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16 } var n int64 - err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) - return n, err + scanPlan := c.PlanScan(ci, oid, format, &n, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil } func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { @@ -78,13 +101,61 @@ func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt } var n int16 - err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) - return n, err + scanPlan := c.PlanScan(ci, oid, format, &n, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil } -type scanPlanTextToAnyInt16 struct{} +type scanPlanBinaryInt2ToInt16 struct{} -func (scanPlanTextToAnyInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int16(binary.BigEndian.Uint16(src)) + return nil +} + +type scanPlanTextAnyToInt8 struct{} + +func (scanPlanTextAnyToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 8) + if err != nil { + return err + } + + *p = int8(n) + return nil +} + +type scanPlanTextAnyToInt16 struct{} + +func (scanPlanTextAnyToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -103,9 +174,9 @@ func (scanPlanTextToAnyInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, s return nil } -type scanPlanTextToAnyInt32 struct{} +type scanPlanTextAnyToInt32 struct{} -func (scanPlanTextToAnyInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -124,9 +195,9 @@ func (scanPlanTextToAnyInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, s return nil } -type scanPlanTextToAnyInt64 struct{} +type scanPlanTextAnyToInt64 struct{} -func (scanPlanTextToAnyInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -144,3 +215,129 @@ func (scanPlanTextToAnyInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, s *p = int64(n) return nil } + +type scanPlanTextAnyToInt struct{} + +func (scanPlanTextAnyToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 0) + if err != nil { + return err + } + + *p = int(n) + return nil +} + +type scanPlanTextAnyToUint8 struct{} + +func (scanPlanTextAnyToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 8) + if err != nil { + return err + } + + *p = uint8(n) + return nil +} + +type scanPlanTextAnyToUint16 struct{} + +func (scanPlanTextAnyToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 16) + if err != nil { + return err + } + + *p = uint16(n) + return nil +} + +type scanPlanTextAnyToUint32 struct{} + +func (scanPlanTextAnyToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *p = uint32(n) + return nil +} + +type scanPlanTextAnyToUint64 struct{} + +func (scanPlanTextAnyToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 64) + if err != nil { + return err + } + + *p = uint64(n) + return nil +} + +type scanPlanTextAnyToUint struct{} + +func (scanPlanTextAnyToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 0) + if err != nil { + return err + } + + *p = uint(n) + return nil +} diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go index 58dcd141..f5bdac89 100644 --- a/pgtype/int2_test.go +++ b/pgtype/int2_test.go @@ -1,144 +1,95 @@ package pgtype_test import ( + "context" + "fmt" "math" "reflect" "testing" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestInt2Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ - &pgtype.Int2{Int: math.MinInt16, Valid: true}, - &pgtype.Int2{Int: -1, Valid: true}, - &pgtype.Int2{Int: 0, Valid: true}, - &pgtype.Int2{Int: 1, Valid: true}, - &pgtype.Int2{Int: math.MaxInt16, Valid: true}, - &pgtype.Int2{Int: 0}, +type PgxTranscodeTestCase struct { + src interface{} + dst interface{} + test func(interface{}) bool +} + +func isExpectedEq(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + return a == v + } +} + +func testPgxCodec(t testing.TB, pgTypeName string, tests []PgxTranscodeTestCase) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for i, tt := range tests { + for _, format := range formats { + err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{format.code}, tt.src).Scan(tt.dst) + if err != nil { + t.Errorf("%s %d: %v", format.name, i, err) + } + + dst := reflect.ValueOf(tt.dst) + if dst.Kind() == reflect.Ptr { + dst = dst.Elem() + } + + if !tt.test(dst.Interface()) { + t.Errorf("%s %d: unexpected result for %v: %v", format.name, i, tt.src, dst.Interface()) + } + } + } +} + +func TestInt2Codec(t *testing.T) { + testPgxCodec(t, "int2", []PgxTranscodeTestCase{ + {int8(1), new(int16), isExpectedEq(int16(1))}, + {int16(1), new(int16), isExpectedEq(int16(1))}, + {int32(1), new(int16), isExpectedEq(int16(1))}, + {int64(1), new(int16), isExpectedEq(int16(1))}, + {uint8(1), new(int16), isExpectedEq(int16(1))}, + {uint16(1), new(int16), isExpectedEq(int16(1))}, + {uint32(1), new(int16), isExpectedEq(int16(1))}, + {uint64(1), new(int16), isExpectedEq(int16(1))}, + {int(1), new(int16), isExpectedEq(int16(1))}, + {uint(1), new(int16), isExpectedEq(int16(1))}, + {pgtype.Int2{Int: 1, Valid: true}, new(int16), isExpectedEq(int16(1))}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {math.MinInt16, new(int16), isExpectedEq(int16(math.MinInt16))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {0, new(int16), isExpectedEq(int16(0))}, + {1, new(int16), isExpectedEq(int16(1))}, + {math.MaxInt16, new(int16), isExpectedEq(int16(math.MaxInt16))}, + {1, new(pgtype.Int2), isExpectedEq(pgtype.Int2{Int: 1, Valid: true})}, + {pgtype.Int2{}, new(pgtype.Int2), isExpectedEq(pgtype.Int2{})}, + {nil, new(*int16), isExpectedEq((*int16)(nil))}, }) } - -func TestInt2Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int2 - }{ - {source: int8(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: int16(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: int32(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: int64(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: int8(-1), result: pgtype.Int2{Int: -1, Valid: true}}, - {source: int16(-1), result: pgtype.Int2{Int: -1, Valid: true}}, - {source: int32(-1), result: pgtype.Int2{Int: -1, Valid: true}}, - {source: int64(-1), result: pgtype.Int2{Int: -1, Valid: true}}, - {source: uint8(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: uint16(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: uint32(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: uint64(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: float32(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: float64(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: "1", result: pgtype.Int2{Int: 1, Valid: true}}, - {source: _int8(1), result: pgtype.Int2{Int: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.Int2 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt2AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - - simpleTests := []struct { - src pgtype.Int2 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i, expected: int(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Int2{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Int2{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Int2 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int2 - dst interface{} - }{ - {src: pgtype.Int2{Int: 150, Valid: true}, dst: &i8}, - {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui8}, - {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui16}, - {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui32}, - {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui64}, - {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui}, - {src: pgtype.Int2{Int: 0}, dst: &i16}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index b0b07663..4983c3a2 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -300,7 +300,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID}) - ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}}) + ci.RegisterDataType(DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}}) ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID}) ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID}) ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) @@ -324,7 +324,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID}) - ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) + ci.RegisterDataType(DataType{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID}) ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID}) @@ -398,20 +398,10 @@ func NewConnInfo() *ConnInfo { return ci } -func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) { - for name, oid := range nameOIDs { - var value Value - if t, ok := nameValues[name]; ok { - value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value) - } else { - value = &GenericText{} - } - ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid}) - } -} - func (ci *ConnInfo) RegisterDataType(t DataType) { - t.Value = NewValue(t.Value) + if t.Value != nil { + t.Value = NewValue(t.Value) + } ci.oidToDataType[t.OID] = &t ci.nameToDataType[t.Name] = &t @@ -463,8 +453,10 @@ func (ci *ConnInfo) buildReflectTypeToDataType() { ci.reflectTypeToDataType = make(map[reflect.Type]*DataType) for _, dt := range ci.oidToDataType { - if _, is := dt.Value.(TypeValue); !is { - ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt + if dt.Value != nil { + if _, is := dt.Value.(TypeValue); !is { + ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt + } } } @@ -583,8 +575,14 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode } else { switch formatCode { case BinaryFormatCode: + if dt.binaryDecoder == nil { + return fmt.Errorf("dt.binaryDecoder is nil") + } err = dt.binaryDecoder.DecodeBinary(ci, src) case TextFormatCode: + if dt.textDecoder == nil { + return fmt.Errorf("dt.textDecoder is nil") + } err = dt.textDecoder.DecodeText(ci, src) } } @@ -782,14 +780,105 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt return newPlan.Scan(ci, oid, formatCode, src, dst) } +type pointerPointerScanPlan struct { + dstType reflect.Type + next ScanPlan +} + +func (plan *pointerPointerScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if plan.dstType != reflect.TypeOf(dst) { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + + el := reflect.ValueOf(dst).Elem() + if src == nil { + el.Set(reflect.Zero(el.Type())) + return nil + } + + el.Set(reflect.New(el.Type().Elem())) + return plan.next.Scan(ci, oid, formatCode, src, el.Interface()) +} + +func tryPointerPointerScanPlan(dst interface{}) (plan *pointerPointerScanPlan, nextDst interface{}, ok bool) { + if dstValue := reflect.ValueOf(dst); dstValue.Kind() == reflect.Ptr { + elemValue := dstValue.Elem() + if elemValue.Kind() == reflect.Ptr { + plan = &pointerPointerScanPlan{dstType: dstValue.Type()} + return plan, reflect.Zero(elemValue.Type()).Interface(), true + } + } + + return nil, nil, false +} + +var elemKindToBasePointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ + reflect.Int: reflect.TypeOf(new(int)), + reflect.Int8: reflect.TypeOf(new(int8)), + reflect.Int16: reflect.TypeOf(new(int16)), + reflect.Int32: reflect.TypeOf(new(int32)), + reflect.Int64: reflect.TypeOf(new(int64)), + reflect.Uint: reflect.TypeOf(new(uint)), + reflect.Uint8: reflect.TypeOf(new(uint8)), + reflect.Uint16: reflect.TypeOf(new(uint16)), + reflect.Uint32: reflect.TypeOf(new(uint32)), + reflect.Uint64: reflect.TypeOf(new(uint64)), + reflect.Float32: reflect.TypeOf(new(float32)), + reflect.Float64: reflect.TypeOf(new(float64)), + reflect.String: reflect.TypeOf(new(string)), +} + +type baseTypeScanPlan struct { + dstType reflect.Type + nextDstType reflect.Type + next ScanPlan +} + +func (plan *baseTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if plan.dstType != reflect.TypeOf(dst) { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + + return plan.next.Scan(ci, oid, formatCode, src, reflect.ValueOf(dst).Convert(plan.nextDstType).Interface()) +} + +func tryBaseTypeScanPlan(dst interface{}) (plan *baseTypeScanPlan, nextDst interface{}, ok bool) { + dstValue := reflect.ValueOf(dst) + + if dstValue.Kind() == reflect.Ptr { + elemValue := dstValue.Elem() + nextDstType := elemKindToBasePointerTypes[elemValue.Kind()] + if nextDstType != nil { + return &baseTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true + } + } + + return nil, nil, false +} + // PlanScan prepares a plan to scan a value into dst. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { if oid != 0 { if dt, ok := ci.DataTypeForOID(oid); ok && dt.Codec != nil { - plan := dt.Codec.PlanScan(ci, oid, formatCode, dst, false) - if plan != nil { + if plan := dt.Codec.PlanScan(ci, oid, formatCode, dst, false); plan != nil { return plan } + + if pointerPointerPlan, nextDst, ok := tryPointerPointerScanPlan(dst); ok { + if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { + pointerPointerPlan.next = nextPlan + return pointerPointerPlan + } + } + + if baseTypePlan, nextDst, ok := tryBaseTypeScanPlan(dst); ok { + if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { + baseTypePlan.next = nextPlan + return baseTypePlan + } + } } } @@ -908,77 +997,3 @@ func NewValue(v Value) Value { } var ErrScanTargetTypeChanged = errors.New("scan target type changed") - -var nameValues map[string]Value - -func init() { - nameValues = map[string]Value{ - "_aclitem": &ACLItemArray{}, - "_bool": &BoolArray{}, - "_bpchar": &BPCharArray{}, - "_bytea": &ByteaArray{}, - "_cidr": &CIDRArray{}, - "_date": &DateArray{}, - "_float4": &Float4Array{}, - "_float8": &Float8Array{}, - "_inet": &InetArray{}, - "_int2": &Int2Array{}, - "_int4": &Int4Array{}, - "_int8": &Int8Array{}, - "_numeric": &NumericArray{}, - "_text": &TextArray{}, - "_timestamp": &TimestampArray{}, - "_timestamptz": &TimestamptzArray{}, - "_uuid": &UUIDArray{}, - "_varchar": &VarcharArray{}, - "_jsonb": &JSONBArray{}, - "aclitem": &ACLItem{}, - "bit": &Bit{}, - "bool": &Bool{}, - "box": &Box{}, - "bpchar": &BPChar{}, - "bytea": &Bytea{}, - "char": &QChar{}, - "cid": &CID{}, - "cidr": &CIDR{}, - "circle": &Circle{}, - "date": &Date{}, - "daterange": &Daterange{}, - "float4": &Float4{}, - "float8": &Float8{}, - "hstore": &Hstore{}, - "inet": &Inet{}, - "int2": &Int2{}, - "int4": &Int4{}, - "int4range": &Int4range{}, - "int8": &Int8{}, - "int8range": &Int8range{}, - "interval": &Interval{}, - "json": &JSON{}, - "jsonb": &JSONB{}, - "line": &Line{}, - "lseg": &Lseg{}, - "macaddr": &Macaddr{}, - "name": &Name{}, - "numeric": &Numeric{}, - "numrange": &Numrange{}, - "oid": &OIDValue{}, - "path": &Path{}, - "point": &Point{}, - "polygon": &Polygon{}, - "record": &Record{}, - "text": &Text{}, - "tid": &TID{}, - "timestamp": &Timestamp{}, - "timestamptz": &Timestamptz{}, - "tsrange": &Tsrange{}, - "_tsrange": &TsrangeArray{}, - "tstzrange": &Tstzrange{}, - "_tstzrange": &TstzrangeArray{}, - "unknown": &Unknown{}, - "uuid": &UUID{}, - "varbit": &Varbit{}, - "varchar": &Varchar{}, - "xid": &XID{}, - } -} diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index ea28be07..ae0e67cb 100755 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -1,4 +1,3 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go diff --git a/pgtype/zzz.int2.go b/pgtype/zzz.int2.go deleted file mode 100644 index f2d959f9..00000000 --- a/pgtype/zzz.int2.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Int2) BinaryFormatSupported() bool { - return true -} - -func (Int2) TextFormatSupported() bool { - return true -} - -func (Int2) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Int2) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Int2) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/query_test.go b/query_test.go index e725bd40..d9b35e28 100644 --- a/query_test.go +++ b/query_test.go @@ -920,65 +920,64 @@ func TestQueryRowCoreIntegerDecoding(t *testing.T) { } failedDecodeTests := []struct { - sql string - scanArg interface{} - expectedErr string + sql string + scanArg interface{} }{ // Check any integer type where value is outside Go:int8 range cannot be decoded - {"select 128::int2", &actual.i8, "is greater than"}, - {"select 128::int4", &actual.i8, "is greater than"}, - {"select 128::int8", &actual.i8, "is greater than"}, - {"select -129::int2", &actual.i8, "is less than"}, - {"select -129::int4", &actual.i8, "is less than"}, - {"select -129::int8", &actual.i8, "is less than"}, + {"select 128::int2", &actual.i8}, + {"select 128::int4", &actual.i8}, + {"select 128::int8", &actual.i8}, + {"select -129::int2", &actual.i8}, + {"select -129::int4", &actual.i8}, + {"select -129::int8", &actual.i8}, // Check any integer type where value is outside Go:int16 range cannot be decoded - {"select 32768::int4", &actual.i16, "is greater than"}, - {"select 32768::int8", &actual.i16, "is greater than"}, - {"select -32769::int4", &actual.i16, "is less than"}, - {"select -32769::int8", &actual.i16, "is less than"}, + {"select 32768::int4", &actual.i16}, + {"select 32768::int8", &actual.i16}, + {"select -32769::int4", &actual.i16}, + {"select -32769::int8", &actual.i16}, // Check any integer type where value is outside Go:int32 range cannot be decoded - {"select 2147483648::int8", &actual.i32, "is greater than"}, - {"select -2147483649::int8", &actual.i32, "is less than"}, + {"select 2147483648::int8", &actual.i32}, + {"select -2147483649::int8", &actual.i32}, // Check any integer type where value is outside Go:uint range cannot be decoded - {"select -1::int2", &actual.ui, "is less than"}, - {"select -1::int4", &actual.ui, "is less than"}, - {"select -1::int8", &actual.ui, "is less than"}, + {"select -1::int2", &actual.ui}, + {"select -1::int4", &actual.ui}, + {"select -1::int8", &actual.ui}, // Check any integer type where value is outside Go:uint8 range cannot be decoded - {"select 256::int2", &actual.ui8, "is greater than"}, - {"select 256::int4", &actual.ui8, "is greater than"}, - {"select 256::int8", &actual.ui8, "is greater than"}, - {"select -1::int2", &actual.ui8, "is less than"}, - {"select -1::int4", &actual.ui8, "is less than"}, - {"select -1::int8", &actual.ui8, "is less than"}, + {"select 256::int2", &actual.ui8}, + {"select 256::int4", &actual.ui8}, + {"select 256::int8", &actual.ui8}, + {"select -1::int2", &actual.ui8}, + {"select -1::int4", &actual.ui8}, + {"select -1::int8", &actual.ui8}, // Check any integer type where value is outside Go:uint16 cannot be decoded - {"select 65536::int4", &actual.ui16, "is greater than"}, - {"select 65536::int8", &actual.ui16, "is greater than"}, - {"select -1::int2", &actual.ui16, "is less than"}, - {"select -1::int4", &actual.ui16, "is less than"}, - {"select -1::int8", &actual.ui16, "is less than"}, + {"select 65536::int4", &actual.ui16}, + {"select 65536::int8", &actual.ui16}, + {"select -1::int2", &actual.ui16}, + {"select -1::int4", &actual.ui16}, + {"select -1::int8", &actual.ui16}, // Check any integer type where value is outside Go:uint32 range cannot be decoded - {"select 4294967296::int8", &actual.ui32, "is greater than"}, - {"select -1::int2", &actual.ui32, "is less than"}, - {"select -1::int4", &actual.ui32, "is less than"}, - {"select -1::int8", &actual.ui32, "is less than"}, + {"select 4294967296::int8", &actual.ui32}, + {"select -1::int2", &actual.ui32}, + {"select -1::int4", &actual.ui32}, + {"select -1::int8", &actual.ui32}, // Check any integer type where value is outside Go:uint64 range cannot be decoded - {"select -1::int2", &actual.ui64, "is less than"}, - {"select -1::int4", &actual.ui64, "is less than"}, - {"select -1::int8", &actual.ui64, "is less than"}, + {"select -1::int2", &actual.ui64}, + {"select -1::int4", &actual.ui64}, + {"select -1::int8", &actual.ui64}, } for i, tt := range failedDecodeTests { err := conn.QueryRow(context.Background(), tt.sql).Scan(tt.scanArg) if err == nil { t.Errorf("%d. Expected failure to decode, but unexpectedly succeeded: %v (sql -> %v)", i, err, tt.sql) - } else if !strings.Contains(err.Error(), tt.expectedErr) { + } else if !strings.Contains(err.Error(), "can't scan") { t.Errorf("%d. Expected failure to decode, but got: %v (sql -> %v)", i, err, tt.sql) } diff --git a/rows.go b/rows.go index cc6e26d5..0cc09ad9 100644 --- a/rows.go +++ b/rows.go @@ -246,31 +246,40 @@ func (rows *connRows) Values() ([]interface{}, error) { } if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok { - value := dt.Value + if dt.Value != nil { - switch fd.Format { - case TextFormatCode: - decoder, ok := value.(pgtype.TextDecoder) - if !ok { - decoder = &pgtype.GenericText{} + value := dt.Value + + switch fd.Format { + case TextFormatCode: + decoder, ok := value.(pgtype.TextDecoder) + if !ok { + decoder = &pgtype.GenericText{} + } + err := decoder.DecodeText(rows.connInfo, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, decoder.(pgtype.Value).Get()) + case BinaryFormatCode: + decoder, ok := value.(pgtype.BinaryDecoder) + if !ok { + decoder = &pgtype.GenericBinary{} + } + err := decoder.DecodeBinary(rows.connInfo, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, value.Get()) + default: + rows.fatal(errors.New("Unknown format code")) } - err := decoder.DecodeText(rows.connInfo, buf) + } else if dt.Codec != nil { + value, err := dt.Codec.DecodeValue(rows.connInfo, fd.DataTypeOID, fd.Format, buf) if err != nil { rows.fatal(err) } - values = append(values, decoder.(pgtype.Value).Get()) - case BinaryFormatCode: - decoder, ok := value.(pgtype.BinaryDecoder) - if !ok { - decoder = &pgtype.GenericBinary{} - } - err := decoder.DecodeBinary(rows.connInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, value.Get()) - default: - rows.fatal(errors.New("Unknown format code")) + values = append(values, value) } } else { switch fd.Format { diff --git a/values.go b/values.go index 2f328b82..00606689 100644 --- a/values.go +++ b/values.go @@ -115,19 +115,30 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e } if dt, found := ci.DataTypeForValue(arg); found { - v := dt.Value - err := v.Set(arg) - if err != nil { - return nil, err + if dt.Value != nil { + v := dt.Value + err := v.Set(arg) + if err != nil { + return nil, err + } + buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), nil + } else if dt.Codec != nil { + buf, err := dt.Codec.Encode(ci, 0, TextFormatCode, arg, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), nil } - buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil } if refVal.Kind() == reflect.Ptr { @@ -188,33 +199,47 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 } if dt, ok := ci.DataTypeForOID(oid); ok { - value := dt.Value - err := value.Set(arg) - if err != nil { - { - if arg, ok := arg.(driver.Valuer); ok { - v, err := callValuerValue(arg) - if err != nil { - return nil, err + if dt.Value != nil { + value := dt.Value + err := value.Set(arg) + if err != nil { + { + if arg, ok := arg.(driver.Valuer); ok { + v, err := callValuerValue(arg) + if err != nil { + return nil, err + } + return encodePreparedStatementArgument(ci, buf, oid, v) } - return encodePreparedStatementArgument(ci, buf, oid, v) } + + return nil, err } - return nil, err + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil + } else if dt.Codec != nil { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := dt.Codec.Encode(ci, oid, BinaryFormatCode, arg, buf) + if err != nil { + return nil, err + } + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil } - - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil } if strippedArg, ok := stripNamedType(&refVal); ok { From c39924d0c67d4f5142fa5cdf56414a702c149f3c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 31 Dec 2021 12:28:45 -0600 Subject: [PATCH 0786/1158] Improvements to ArrayCodec --- pgtype/array_codec_test.go | 116 ++++------ pgtype/array_getter_setter.go | 143 +++++++++++++ pgtype/array_getter_setter.go.erb | 117 ++++++++++ pgtype/int2_array_test.go | 342 ------------------------------ 4 files changed, 301 insertions(+), 417 deletions(-) create mode 100644 pgtype/array_getter_setter.go create mode 100644 pgtype/array_getter_setter.go.erb delete mode 100644 pgtype/int2_array_test.go diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index f213d0ec..c358586e 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -12,14 +12,13 @@ func TestArrayCodec(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) - tests := []struct { - expected []int16 + for i, tt := range []struct { + expected interface{} }{ {[]int16(nil)}, {[]int16{}}, {[]int16{1, 2, 3}}, - } - for i, tt := range tests { + } { var actual []int16 err := conn.QueryRow( context.Background(), @@ -29,78 +28,45 @@ func TestArrayCodec(t *testing.T) { assert.NoErrorf(t, err, "%d", i) assert.Equalf(t, tt.expected, actual, "%d", i) } + + newInt16 := func(n int16) *int16 { return &n } + + for i, tt := range []struct { + expected interface{} + }{ + {[]*int16{newInt16(1), nil, newInt16(3), nil, newInt16(5)}}, + } { + var actual []*int16 + err := conn.QueryRow( + context.Background(), + "select $1::smallint[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } } -// func TestArrayCodecValue(t *testing.T) { -// ArrayCodec := pgtype.NewArrayCodec("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }) +func TestArrayCodecAnySlice(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) -// err := ArrayCodec.Set(nil) -// require.NoError(t, err) + type _int16Slice []int16 -// gotValue := ArrayCodec.Get() -// require.Nil(t, gotValue) - -// slice := []string{"foo", "bar"} -// err = ArrayCodec.AssignTo(&slice) -// require.NoError(t, err) -// require.Nil(t, slice) - -// err = ArrayCodec.Set([]string{}) -// require.NoError(t, err) - -// gotValue = ArrayCodec.Get() -// require.Len(t, gotValue, 0) - -// err = ArrayCodec.AssignTo(&slice) -// require.NoError(t, err) -// require.EqualValues(t, []string{}, slice) - -// err = ArrayCodec.Set([]string{"baz", "quz"}) -// require.NoError(t, err) - -// gotValue = ArrayCodec.Get() -// require.Len(t, gotValue, 2) - -// err = ArrayCodec.AssignTo(&slice) -// require.NoError(t, err) -// require.EqualValues(t, []string{"baz", "quz"}, slice) -// } - -// func TestArrayCodecTranscode(t *testing.T) { -// conn := testutil.MustConnectPgx(t) -// defer testutil.MustCloseContext(t, conn) - -// conn.ConnInfo().RegisterDataType(pgtype.DataType{ -// Value: pgtype.NewArrayCodec("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), -// Name: "_text", -// OID: pgtype.TextArrayOID, -// }) - -// var dstStrings []string -// err := conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) -// require.NoError(t, err) - -// require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) -// } - -// func TestArrayCodecEmptyArrayDoesNotBreakArrayCodec(t *testing.T) { -// conn := testutil.MustConnectPgx(t) -// defer testutil.MustCloseContext(t, conn) - -// conn.ConnInfo().RegisterDataType(pgtype.DataType{ -// Value: pgtype.NewArrayCodec("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), -// Name: "_text", -// OID: pgtype.TextArrayOID, -// }) - -// var dstStrings []string -// err := conn.QueryRow(context.Background(), "select '{}'::text[]").Scan(&dstStrings) -// require.NoError(t, err) - -// require.EqualValues(t, []string{}, dstStrings) - -// err = conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) -// require.NoError(t, err) - -// require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) -// } + for i, tt := range []struct { + expected interface{} + }{ + {_int16Slice(nil)}, + {_int16Slice{}}, + {_int16Slice{1, 2, 3}}, + } { + var actual _int16Slice + err := conn.QueryRow( + context.Background(), + "select $1::smallint[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } +} diff --git a/pgtype/array_getter_setter.go b/pgtype/array_getter_setter.go new file mode 100644 index 00000000..72a6f0e7 --- /dev/null +++ b/pgtype/array_getter_setter.go @@ -0,0 +1,143 @@ +package pgtype + +import ( + "fmt" + "reflect" +) + +type int16Array []int16 + +func (a int16Array) Dimensions() []ArrayDimension { + if a == nil { + return nil + } + + return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} +} + +func (a int16Array) Index(i int) interface{} { + return a[i] +} + +func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + a = nil + return nil + } + + elementCount := cardinality(dimensions) + *a = make(int16Array, elementCount) + return nil +} + +func (a int16Array) ScanIndex(i int) interface{} { + return &a[i] +} + +type uint16Array []uint16 + +func (a uint16Array) Dimensions() []ArrayDimension { + if a == nil { + return nil + } + + return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} +} + +func (a uint16Array) Index(i int) interface{} { + return a[i] +} + +func (a *uint16Array) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + a = nil + return nil + } + + elementCount := cardinality(dimensions) + *a = make(uint16Array, elementCount) + return nil +} + +func (a uint16Array) ScanIndex(i int) interface{} { + return &a[i] +} + +type anySliceArray struct { + slice reflect.Value +} + +func (a anySliceArray) Dimensions() []ArrayDimension { + if a.slice.IsNil() { + return nil + } + + return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}} +} + +func (a anySliceArray) Index(i int) interface{} { + return a.slice.Index(i).Interface() +} + +func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error { + sliceType := a.slice.Type() + + if dimensions == nil { + a.slice.Set(reflect.Zero(sliceType)) + return nil + } + + elementCount := cardinality(dimensions) + slice := reflect.MakeSlice(sliceType, elementCount, elementCount) + a.slice.Set(slice) + return nil +} + +func (a anySliceArray) ScanIndex(i int) interface{} { + return a.slice.Index(i).Addr().Interface() +} + +func makeArrayGetter(a interface{}) (ArrayGetter, error) { + switch a := a.(type) { + case ArrayGetter: + return a, nil + + case []int16: + return (*int16Array)(&a), nil + + case []uint16: + return (*uint16Array)(&a), nil + + } + + reflectValue := reflect.ValueOf(a) + if reflectValue.Kind() == reflect.Slice { + return &anySliceArray{slice: reflectValue}, nil + } + + return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a) +} + +func makeArraySetter(a interface{}) (ArraySetter, error) { + switch a := a.(type) { + case ArraySetter: + return a, nil + + case *[]int16: + return (*int16Array)(a), nil + + case *[]uint16: + return (*uint16Array)(a), nil + + } + + value := reflect.ValueOf(a) + if value.Kind() == reflect.Ptr { + elemValue := value.Elem() + if elemValue.Kind() == reflect.Slice { + return &anySliceArray{slice: elemValue}, nil + } + } + + return nil, fmt.Errorf("cannot convert %T to ArraySetter", a) +} diff --git a/pgtype/array_getter_setter.go.erb b/pgtype/array_getter_setter.go.erb new file mode 100644 index 00000000..01b7d4fa --- /dev/null +++ b/pgtype/array_getter_setter.go.erb @@ -0,0 +1,117 @@ +package pgtype + +import ( + "fmt" + "reflect" +) + +<% + types = [ + ["int16Array", "int16"], + ["uint16Array", "uint16"], + ] +%> + +<% types.each do |array_type, element_type| %> + type <%= array_type %> []<%= element_type %> + + func (a <%= array_type %>) Dimensions() []ArrayDimension { + if a == nil { + return nil + } + + return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} + } + + func (a <%= array_type %>) Index(i int) interface{} { + return a[i] + } + + func (a *<%= array_type %>) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + a = nil + return nil + } + + elementCount := cardinality(dimensions) + *a = make(<%= array_type %>, elementCount) + return nil + } + + func (a <%= array_type %>) ScanIndex(i int) interface{} { + return &a[i] + } +<% end %> + +type anySliceArray struct { + slice reflect.Value +} + +func (a anySliceArray) Dimensions() []ArrayDimension { + if a.slice.IsNil() { + return nil + } + + return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}} +} + +func (a anySliceArray) Index(i int) interface{} { + return a.slice.Index(i).Interface() +} + +func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error { + sliceType := a.slice.Type() + + if dimensions == nil { + a.slice.Set(reflect.Zero(sliceType)) + return nil + } + + elementCount := cardinality(dimensions) + slice := reflect.MakeSlice(sliceType, elementCount, elementCount) + a.slice.Set(slice) + return nil +} + +func (a anySliceArray) ScanIndex(i int) interface{} { + return a.slice.Index(i).Addr().Interface() +} + +func makeArrayGetter(a interface{}) (ArrayGetter, error) { + switch a := a.(type) { + case ArrayGetter: + return a, nil + <% types.each do |array_type, element_type| %> + case []<%= element_type %>: + return (*<%= array_type %>)(&a), nil + <% end %> + } + + reflectValue := reflect.ValueOf(a) + if reflectValue.Kind() == reflect.Slice { + return &anySliceArray{slice: reflectValue}, nil + } + + return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a) +} + +func makeArraySetter(a interface{}) (ArraySetter, error) { + switch a := a.(type) { + case ArraySetter: + return a, nil + <% types.each do |array_type, element_type| %> + case *[]<%= element_type %>: + return (*<%= array_type %>)(a), nil + <% end %> + } + + value := reflect.ValueOf(a) + if value.Kind() == reflect.Ptr { + elemValue := value.Elem() + if elemValue.Kind() == reflect.Slice { + return &anySliceArray{slice: elemValue}, nil + } + } + + return nil, fmt.Errorf("cannot convert %T to ArraySetter", a) +} diff --git a/pgtype/int2_array_test.go b/pgtype/int2_array_test.go deleted file mode 100644 index 110968fc..00000000 --- a/pgtype/int2_array_test.go +++ /dev/null @@ -1,342 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestInt2ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int2[]", []interface{}{ - &pgtype.Int2Array{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.Int2Array{}, - &pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {}, - {Int: 6, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestInt2ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int2Array - }{ - { - source: []int64{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []int32{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []int16{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []int{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []uint64{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []uint32{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []uint16{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]int16)(nil)), - result: pgtype.Int2Array{}, - }, - { - source: [][]int16{{1}, {2}}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {Int: 5, Valid: true}, - {Int: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]int16{{1}, {2}}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {Int: 5, Valid: true}, - {Int: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Int2Array - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt2ArrayAssignTo(t *testing.T) { - var int16Slice []int16 - var uint16Slice []uint16 - var namedInt16Slice _int16Slice - var int16SliceDim2 [][]int16 - var int16SliceDim4 [][][][]int16 - var int16ArrayDim2 [2][1]int16 - var int16ArrayDim4 [2][1][1][3]int16 - - simpleTests := []struct { - src pgtype.Int2Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &int16Slice, - expected: []int16{1}, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &uint16Slice, - expected: []uint16{1}, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &namedInt16Slice, - expected: _int16Slice{1}, - }, - { - src: pgtype.Int2Array{}, - dst: &int16Slice, - expected: (([]int16)(nil)), - }, - { - src: pgtype.Int2Array{Valid: true}, - dst: &int16Slice, - expected: []int16{}, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [][]int16{{1}, {2}}, - dst: &int16SliceDim2, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {Int: 5, Valid: true}, - {Int: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &int16SliceDim4, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [2][1]int16{{1}, {2}}, - dst: &int16ArrayDim2, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {Int: 5, Valid: true}, - {Int: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &int16ArrayDim4, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int2Array - dst interface{} - }{ - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &int16Slice, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: -1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &uint16Slice, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &int16ArrayDim2, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &int16Slice, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &int16ArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} From b99d95470fbb88a39e5122abd485a45a0d4362b6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 31 Dec 2021 12:32:46 -0600 Subject: [PATCH 0787/1158] Fix tryBaseTypeScanPlan infinite recursion --- pgtype/pgtype.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 4983c3a2..3e807a89 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -850,7 +850,7 @@ func tryBaseTypeScanPlan(dst interface{}) (plan *baseTypeScanPlan, nextDst inter if dstValue.Kind() == reflect.Ptr { elemValue := dstValue.Elem() nextDstType := elemKindToBasePointerTypes[elemValue.Kind()] - if nextDstType != nil { + if nextDstType != nil && dstValue.Type() != nextDstType { return &baseTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true } } From 77b9b59622e211ac4102c2cbf526394416e4f67f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 31 Dec 2021 13:07:08 -0600 Subject: [PATCH 0788/1158] Generate text to int scan plans --- pgtype/int2_codec.go | 210 ---------------------------------- pgtype/int_scan_plans.go | 216 +++++++++++++++++++++++++++++++++++ pgtype/int_scan_plans.go.erb | 56 +++++++++ 3 files changed, 272 insertions(+), 210 deletions(-) create mode 100644 pgtype/int_scan_plans.go create mode 100644 pgtype/int_scan_plans.go.erb diff --git a/pgtype/int2_codec.go b/pgtype/int2_codec.go index c50b56d7..17c335a6 100644 --- a/pgtype/int2_codec.go +++ b/pgtype/int2_codec.go @@ -131,213 +131,3 @@ func (scanPlanBinaryInt2ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16 *p = int16(binary.BigEndian.Uint16(src)) return nil } - -type scanPlanTextAnyToInt8 struct{} - -func (scanPlanTextAnyToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*int8) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseInt(string(src), 10, 8) - if err != nil { - return err - } - - *p = int8(n) - return nil -} - -type scanPlanTextAnyToInt16 struct{} - -func (scanPlanTextAnyToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*int16) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseInt(string(src), 10, 16) - if err != nil { - return err - } - - *p = int16(n) - return nil -} - -type scanPlanTextAnyToInt32 struct{} - -func (scanPlanTextAnyToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*int32) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseInt(string(src), 10, 32) - if err != nil { - return err - } - - *p = int32(n) - return nil -} - -type scanPlanTextAnyToInt64 struct{} - -func (scanPlanTextAnyToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*int64) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseInt(string(src), 10, 64) - if err != nil { - return err - } - - *p = int64(n) - return nil -} - -type scanPlanTextAnyToInt struct{} - -func (scanPlanTextAnyToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*int) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseInt(string(src), 10, 0) - if err != nil { - return err - } - - *p = int(n) - return nil -} - -type scanPlanTextAnyToUint8 struct{} - -func (scanPlanTextAnyToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*uint8) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseUint(string(src), 10, 8) - if err != nil { - return err - } - - *p = uint8(n) - return nil -} - -type scanPlanTextAnyToUint16 struct{} - -func (scanPlanTextAnyToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*uint16) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseUint(string(src), 10, 16) - if err != nil { - return err - } - - *p = uint16(n) - return nil -} - -type scanPlanTextAnyToUint32 struct{} - -func (scanPlanTextAnyToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*uint32) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseUint(string(src), 10, 32) - if err != nil { - return err - } - - *p = uint32(n) - return nil -} - -type scanPlanTextAnyToUint64 struct{} - -func (scanPlanTextAnyToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*uint64) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseUint(string(src), 10, 64) - if err != nil { - return err - } - - *p = uint64(n) - return nil -} - -type scanPlanTextAnyToUint struct{} - -func (scanPlanTextAnyToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*uint) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseUint(string(src), 10, 0) - if err != nil { - return err - } - - *p = uint(n) - return nil -} diff --git a/pgtype/int_scan_plans.go b/pgtype/int_scan_plans.go new file mode 100644 index 00000000..e7fce506 --- /dev/null +++ b/pgtype/int_scan_plans.go @@ -0,0 +1,216 @@ +package pgtype + +import ( + "fmt" + "strconv" +) + +type scanPlanTextAnyToInt8 struct{} + +func (scanPlanTextAnyToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 8) + if err != nil { + return err + } + + *p = int8(n) + return nil +} + +type scanPlanTextAnyToUint8 struct{} + +func (scanPlanTextAnyToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 8) + if err != nil { + return err + } + + *p = uint8(n) + return nil +} + +type scanPlanTextAnyToInt16 struct{} + +func (scanPlanTextAnyToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + + *p = int16(n) + return nil +} + +type scanPlanTextAnyToUint16 struct{} + +func (scanPlanTextAnyToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 16) + if err != nil { + return err + } + + *p = uint16(n) + return nil +} + +type scanPlanTextAnyToInt32 struct{} + +func (scanPlanTextAnyToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 32) + if err != nil { + return err + } + + *p = int32(n) + return nil +} + +type scanPlanTextAnyToUint32 struct{} + +func (scanPlanTextAnyToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *p = uint32(n) + return nil +} + +type scanPlanTextAnyToInt64 struct{} + +func (scanPlanTextAnyToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + *p = int64(n) + return nil +} + +type scanPlanTextAnyToUint64 struct{} + +func (scanPlanTextAnyToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 64) + if err != nil { + return err + } + + *p = uint64(n) + return nil +} + +type scanPlanTextAnyToInt struct{} + +func (scanPlanTextAnyToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 0) + if err != nil { + return err + } + + *p = int(n) + return nil +} + +type scanPlanTextAnyToUint struct{} + +func (scanPlanTextAnyToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 0) + if err != nil { + return err + } + + *p = uint(n) + return nil +} diff --git a/pgtype/int_scan_plans.go.erb b/pgtype/int_scan_plans.go.erb new file mode 100644 index 00000000..abdd1329 --- /dev/null +++ b/pgtype/int_scan_plans.go.erb @@ -0,0 +1,56 @@ +package pgtype + +import ( + "fmt" + "strconv" +) + +<% [ + ["8", 8], + ["16", 16], + ["32", 32], + ["64", 64], + ["", 0] +].each do |type_suffix, bit_size| %> +type scanPlanTextAnyToInt<%= type_suffix %> struct{} + +func (scanPlanTextAnyToInt<%= type_suffix %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int<%= type_suffix %>) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, <%= bit_size %>) + if err != nil { + return err + } + + *p = int<%= type_suffix %>(n) + return nil +} + +type scanPlanTextAnyToUint<%= type_suffix %> struct{} + +func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint<%= type_suffix %>) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, <%= bit_size %>) + if err != nil { + return err + } + + *p = uint<%= type_suffix %>(n) + return nil +} +<% end %> From 19ae359e9e3ebe14181289a08480c66f15755af4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 31 Dec 2021 17:03:31 -0600 Subject: [PATCH 0789/1158] Add binary scan plans for int2 --- pgtype/int2_codec.go | 43 +++--- pgtype/int_scan_plans.go | 248 +++++++++++++++++++++++++++++++++++ pgtype/int_scan_plans.go.erb | 135 +++++++++++++++++++ 3 files changed, 405 insertions(+), 21 deletions(-) diff --git a/pgtype/int2_codec.go b/pgtype/int2_codec.go index 17c335a6..d436346c 100644 --- a/pgtype/int2_codec.go +++ b/pgtype/int2_codec.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "encoding/binary" "fmt" "math" "strconv" @@ -50,6 +49,28 @@ func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa switch format { case BinaryFormatCode: + switch target.(type) { + case *int8: + return scanPlanBinaryInt2ToInt8{} + case *int16: + return scanPlanBinaryInt2ToInt16{} + case *int32: + return scanPlanBinaryInt2ToInt32{} + case *int64: + return scanPlanBinaryInt2ToInt64{} + case *int: + return scanPlanBinaryInt2ToInt{} + case *uint8: + return scanPlanBinaryInt2ToUint8{} + case *uint16: + return scanPlanBinaryInt2ToUint16{} + case *uint32: + return scanPlanBinaryInt2ToUint32{} + case *uint64: + return scanPlanBinaryInt2ToUint64{} + case *uint: + return scanPlanBinaryInt2ToUint{} + } case TextFormatCode: switch target.(type) { case *int8: @@ -111,23 +132,3 @@ func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt } return n, nil } - -type scanPlanBinaryInt2ToInt16 struct{} - -func (scanPlanBinaryInt2ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - if len(src) != 2 { - return fmt.Errorf("invalid length for int2: %v", len(src)) - } - - p, ok := (dst).(*int16) - if !ok { - return ErrScanTargetTypeChanged - } - - *p = int16(binary.BigEndian.Uint16(src)) - return nil -} diff --git a/pgtype/int_scan_plans.go b/pgtype/int_scan_plans.go index e7fce506..1694e021 100644 --- a/pgtype/int_scan_plans.go +++ b/pgtype/int_scan_plans.go @@ -1,7 +1,9 @@ package pgtype import ( + "encoding/binary" "fmt" + "math" "strconv" ) @@ -214,3 +216,249 @@ func (scanPlanTextAnyToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, sr *p = uint(n) return nil } + +type scanPlanBinaryInt2ToInt8 struct{} + +func (scanPlanBinaryInt2ToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", n) + } else if n > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", n) + } + + *p = int8(n) + + return nil +} + +type scanPlanBinaryInt2ToUint8 struct{} + +func (scanPlanBinaryInt2ToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", n) + } + + if n > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", n) + } + + *p = uint8(n) + + return nil +} + +type scanPlanBinaryInt2ToInt16 struct{} + +func (scanPlanBinaryInt2ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int16(binary.BigEndian.Uint16(src)) + + return nil +} + +type scanPlanBinaryInt2ToUint16 struct{} + +func (scanPlanBinaryInt2ToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", n) + } + + *p = uint16(n) + + return nil +} + +type scanPlanBinaryInt2ToInt32 struct{} + +func (scanPlanBinaryInt2ToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int32(binary.BigEndian.Uint16(src)) + + return nil +} + +type scanPlanBinaryInt2ToUint32 struct{} + +func (scanPlanBinaryInt2ToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", n) + } + + *p = uint32(n) + + return nil +} + +type scanPlanBinaryInt2ToInt64 struct{} + +func (scanPlanBinaryInt2ToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int64(binary.BigEndian.Uint16(src)) + + return nil +} + +type scanPlanBinaryInt2ToUint64 struct{} + +func (scanPlanBinaryInt2ToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", n) + } + + *p = uint64(n) + + return nil +} + +type scanPlanBinaryInt2ToInt struct{} + +func (scanPlanBinaryInt2ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int(binary.BigEndian.Uint16(src)) + + return nil +} + +type scanPlanBinaryInt2ToUint struct{} + +func (scanPlanBinaryInt2ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint16(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint", n) + } + + *p = uint(n) + + return nil +} diff --git a/pgtype/int_scan_plans.go.erb b/pgtype/int_scan_plans.go.erb index abdd1329..448453f1 100644 --- a/pgtype/int_scan_plans.go.erb +++ b/pgtype/int_scan_plans.go.erb @@ -54,3 +54,138 @@ func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(ci *ConnInfo, oid uint32, fo return nil } <% end %> + +<% [ + [16, 8], + [16, 16], + [16, 32], + [16, 64], +].each do |src_bit_size, dst_bit_size| %> +<% src_byte_size = src_bit_size / 8 %> +type scanPlanBinaryInt<%= src_byte_size %>ToInt<%= dst_bit_size %> struct{} + +func (scanPlanBinaryInt<%= src_byte_size %>ToInt<%= dst_bit_size %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != <%= src_byte_size %> { + return fmt.Errorf("invalid length for int<%= src_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*int<%= dst_bit_size %>) + if !ok { + return ErrScanTargetTypeChanged + } + + <% if dst_bit_size < src_bit_size %> + n := int<%= src_bit_size %>(binary.BigEndian.Uint<%= src_bit_size %>(src)) + if n < math.MinInt<%= dst_bit_size %> { + return fmt.Errorf("%d is less than minimum value for int<%= dst_bit_size %>", n) + } else if n > math.MaxInt<%= dst_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for int<%= dst_bit_size %>", n) + } + + *p = int<%= dst_bit_size %>(n) + <% else %> + *p = int<%= dst_bit_size %>(binary.BigEndian.Uint<%= src_bit_size %>(src)) + <% end %> + + return nil +} + +type scanPlanBinaryInt<%= src_byte_size %>ToUint<%= dst_bit_size %> struct{} + +func (scanPlanBinaryInt<%= src_byte_size %>ToUint<%= dst_bit_size %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != <%= src_byte_size %> { + return fmt.Errorf("invalid length for uint<%= src_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*uint<%= dst_bit_size %>) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int<%= src_bit_size %>(binary.BigEndian.Uint<%= src_bit_size %>(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint<%= dst_bit_size %>", n) + } + <% if dst_bit_size < src_bit_size %> + if n > math.MaxUint<%= dst_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for uint<%= dst_bit_size %>", n) + } + <% end %> + *p = uint<%= dst_bit_size %>(n) + + return nil +} +<% end %> + +<% [16].each do |src_bit_size| %> +<% src_byte_size = src_bit_size / 8 %> +type scanPlanBinaryInt<%= src_byte_size %>ToInt struct{} + +func (scanPlanBinaryInt<%= src_byte_size %>ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != <%= src_byte_size %> { + return fmt.Errorf("invalid length for int<%= src_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + <% if 32 < src_bit_size %> + n := int64(binary.BigEndian.Uint<%= src_bit_size %>(src)) + if n < math.MinInt { + return fmt.Errorf("%d is less than minimum value for int", n) + } else if n > math.MaxInt { + return fmt.Errorf("%d is greater than maximum value for int", n) + } + + *p = int(n) + <% else %> + *p = int(binary.BigEndian.Uint<%= src_bit_size %>(src)) + <% end %> + + return nil +} + +type scanPlanBinaryInt<%= src_byte_size %>ToUint struct{} + +func (scanPlanBinaryInt<%= src_byte_size %>ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != <%= src_byte_size %> { + return fmt.Errorf("invalid length for uint<%= src_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint<%= src_bit_size %>(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint", n) + } + <% if 32 < src_bit_size %> + if uint64(n) > math.MaxUint { + return fmt.Errorf("%d is greater than maximum value for uint", n) + } + <% end %> + *p = uint(n) + + return nil +} +<% end %> From 1516a0d8db88d11410bdb1948a11d5d533603bd1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 31 Dec 2021 17:51:18 -0600 Subject: [PATCH 0790/1158] pgtype tests pass --- pgtype/int2.go | 18 ++++++++++++ pgtype/int2_codec.go | 8 ++++++ pgtype/int_scan_plans.go | 46 ++++++++++++++++++++++++++++++ pgtype/int_scan_plans.go.erb | 54 ++++++++++++++++++++++++++++++++++++ 4 files changed, 126 insertions(+) diff --git a/pgtype/int2.go b/pgtype/int2.go index b7b7243f..309fe6b6 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -12,6 +12,24 @@ type Int2 struct { Valid bool } +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int2) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = Int2{} + return nil + } + + if n < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + *dst = Int2{Int: int16(n), Valid: true} + + return nil +} + // Scan implements the database/sql Scanner interface. func (dst *Int2) Scan(src interface{}) error { if src == nil { diff --git a/pgtype/int2_codec.go b/pgtype/int2_codec.go index d436346c..55fd7a12 100644 --- a/pgtype/int2_codec.go +++ b/pgtype/int2_codec.go @@ -70,6 +70,8 @@ func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa return scanPlanBinaryInt2ToUint64{} case *uint: return scanPlanBinaryInt2ToUint{} + case Int64Scanner: + return scanPlanBinaryInt2ToInt64Scanner{} } case TextFormatCode: switch target.(type) { @@ -93,6 +95,8 @@ func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa return scanPlanTextAnyToUint64{} case *uint: return scanPlanTextAnyToUint{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} } } @@ -132,3 +136,7 @@ func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt } return n, nil } + +type Int64Scanner interface { + ScanInt64(v int64, valid bool) error +} diff --git a/pgtype/int_scan_plans.go b/pgtype/int_scan_plans.go index 1694e021..13b61c4a 100644 --- a/pgtype/int_scan_plans.go +++ b/pgtype/int_scan_plans.go @@ -462,3 +462,49 @@ func (scanPlanBinaryInt2ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, return nil } + +type scanPlanBinaryInt2ToInt64Scanner struct{} + +func (scanPlanBinaryInt2ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(0, false) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint16(src)) + + return s.ScanInt64(n, true) +} + +type scanPlanTextAnyToInt64Scanner struct{} + +func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(0, false) + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + err = s.ScanInt64(n, true) + if err != nil { + return err + } + + return nil +} diff --git a/pgtype/int_scan_plans.go.erb b/pgtype/int_scan_plans.go.erb index 448453f1..67bfa3f4 100644 --- a/pgtype/int_scan_plans.go.erb +++ b/pgtype/int_scan_plans.go.erb @@ -5,6 +5,7 @@ import ( "strconv" ) +<%# Any text to all integer types %> <% [ ["8", 8], ["16", 16], @@ -55,6 +56,7 @@ func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(ci *ConnInfo, oid uint32, fo } <% end %> +<%# PostgreSQL binary integers to fixed size Go integers %> <% [ [16, 8], [16, 16], @@ -125,6 +127,7 @@ func (scanPlanBinaryInt<%= src_byte_size %>ToUint<%= dst_bit_size %>) Scan(ci *C } <% end %> +<%# PostgreSQL binary integers to Go machine integers %> <% [16].each do |src_bit_size| %> <% src_byte_size = src_bit_size / 8 %> type scanPlanBinaryInt<%= src_byte_size %>ToInt struct{} @@ -189,3 +192,54 @@ func (scanPlanBinaryInt<%= src_byte_size %>ToUint) Scan(ci *ConnInfo, oid uint32 return nil } <% end %> + +<%# PostgreSQL binary integers to Go Int64Scanner %> +<% [16].each do |src_bit_size| %> +<% src_byte_size = src_bit_size / 8 %> +type scanPlanBinaryInt<%= src_byte_size %>ToInt64Scanner struct{} + +func (scanPlanBinaryInt<%= src_byte_size %>ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(0, false) + } + + if len(src) != <%= src_byte_size %> { + return fmt.Errorf("invalid length for int<%= src_byte_size %>: %v", len(src)) + } + + + n := int64(binary.BigEndian.Uint<%= src_bit_size %>(src)) + + return s.ScanInt64(n, true) +} +<% end %> + +type scanPlanTextAnyToInt64Scanner struct{} + +func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(0, false) + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + err = s.ScanInt64(n, true) + if err != nil { + return err + } + + return nil +} From 93cc21199f3ab90d3e17a21494e449670f8c1cf8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 31 Dec 2021 17:54:47 -0600 Subject: [PATCH 0791/1158] All tests passing --- pgtype/zeronull/int2.go | 69 ++++++++++------------------------------- 1 file changed, 17 insertions(+), 52 deletions(-) diff --git a/pgtype/zeronull/int2.go b/pgtype/zeronull/int2.go index 81e89ab3..2f63d8cc 100644 --- a/pgtype/zeronull/int2.go +++ b/pgtype/zeronull/int2.go @@ -2,70 +2,32 @@ package zeronull import ( "database/sql/driver" + "fmt" + "math" "github.com/jackc/pgx/v5/pgtype" ) type Int2 int16 -func (dst *Int2) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Int2 - err := nullable.DecodeText(ci, src) - if err != nil { - return err +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int2) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = 0 + return nil } - if nullable.Valid { - *dst = Int2(nullable.Int) - } else { - *dst = 0 + if n < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) } + if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + *dst = Int2(n) return nil } -func (dst *Int2) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Int2 - err := nullable.DecodeBinary(ci, src) - if err != nil { - return err - } - - if nullable.Valid { - *dst = Int2(nullable.Int) - } else { - *dst = 0 - } - - return nil -} - -func (src Int2) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if src == 0 { - return nil, nil - } - - nullable := pgtype.Int2{ - Int: int16(src), - Valid: true, - } - - return nullable.EncodeText(ci, buf) -} - -func (src Int2) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if src == 0 { - return nil, nil - } - - nullable := pgtype.Int2{ - Int: int16(src), - Valid: true, - } - - return nullable.EncodeBinary(ci, buf) -} - // Scan implements the database/sql Scanner interface. func (dst *Int2) Scan(src interface{}) error { if src == nil { @@ -86,5 +48,8 @@ func (dst *Int2) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Int2) Value() (driver.Value, error) { - return pgtype.EncodeValueText(src) + if src == 0 { + return nil, nil + } + return int64(src), nil } From 6c7f1593e8b270005a850c804e10320154a8f6da Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Jan 2022 10:41:20 -0600 Subject: [PATCH 0792/1158] Use rake to build generated code --- Rakefile | 10 ++++++++++ pgtype/int_scan_plans.go | 1 + 2 files changed, 11 insertions(+) create mode 100644 Rakefile diff --git a/Rakefile b/Rakefile new file mode 100644 index 00000000..70b857de --- /dev/null +++ b/Rakefile @@ -0,0 +1,10 @@ +require "erb" + +rule '.go' => '.go.erb' do |task| + erb = ERB.new(File.read(task.source)) + File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding)) + sh "goimports", "-w", task.name +end + +desc "Generate code" +task generate: ["pgtype/int_scan_plans.go"] diff --git a/pgtype/int_scan_plans.go b/pgtype/int_scan_plans.go index 13b61c4a..20cde1c2 100644 --- a/pgtype/int_scan_plans.go +++ b/pgtype/int_scan_plans.go @@ -1,3 +1,4 @@ +// Do not edit. Generated from pgtype/int_scan_plans.go.erb package pgtype import ( From 1b353297d5a62c3ef77570f011481c7078911bf5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Jan 2022 11:11:31 -0600 Subject: [PATCH 0793/1158] Prepare for generating int types --- Rakefile | 2 +- pgtype/{int_scan_plans.go => int.go} | 544 +++++++++++++++++++-------- pgtype/int.go.erb | 451 ++++++++++++++++++++++ pgtype/int2.go | 85 ----- pgtype/int2_codec.go | 142 ------- pgtype/int_scan_plans.go.erb | 245 ------------ 6 files changed, 830 insertions(+), 639 deletions(-) rename pgtype/{int_scan_plans.go => int.go} (68%) create mode 100644 pgtype/int.go.erb delete mode 100644 pgtype/int2.go delete mode 100644 pgtype/int2_codec.go delete mode 100644 pgtype/int_scan_plans.go.erb diff --git a/Rakefile b/Rakefile index 70b857de..275755bd 100644 --- a/Rakefile +++ b/Rakefile @@ -7,4 +7,4 @@ rule '.go' => '.go.erb' do |task| end desc "Generate code" -task generate: ["pgtype/int_scan_plans.go"] +task generate: ["pgtype/int.go"] diff --git a/pgtype/int_scan_plans.go b/pgtype/int.go similarity index 68% rename from pgtype/int_scan_plans.go rename to pgtype/int.go index 20cde1c2..4d7dea83 100644 --- a/pgtype/int_scan_plans.go +++ b/pgtype/int.go @@ -1,221 +1,223 @@ -// Do not edit. Generated from pgtype/int_scan_plans.go.erb +// Do not edit. Generated from pgtype/int.go.erb package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "math" "strconv" + + "github.com/jackc/pgio" ) -type scanPlanTextAnyToInt8 struct{} +type Int64Scanner interface { + ScanInt64(v int64, valid bool) error +} -func (scanPlanTextAnyToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) +type Int2 struct { + Int int16 + Valid bool +} + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int2) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = Int2{} + return nil } - p, ok := (dst).(*int8) - if !ok { - return ErrScanTargetTypeChanged + if n < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) } - - n, err := strconv.ParseInt(string(src), 10, 8) - if err != nil { - return err + if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) } + *dst = Int2{Int: int16(n), Valid: true} - *p = int8(n) return nil } -type scanPlanTextAnyToUint8 struct{} - -func (scanPlanTextAnyToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +// Scan implements the database/sql Scanner interface. +func (dst *Int2) Scan(src interface{}) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + *dst = Int2{} + return nil } - p, ok := (dst).(*uint8) - if !ok { - return ErrScanTargetTypeChanged + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + var err error + n, err = strconv.ParseInt(src, 10, 16) + if err != nil { + return err + } + case []byte: + var err error + n, err = strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) } - n, err := strconv.ParseUint(string(src), 10, 8) - if err != nil { - return err + if n < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) } + if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + *dst = Int2{Int: int16(n), Valid: true} - *p = uint8(n) return nil } -type scanPlanTextAnyToInt16 struct{} - -func (scanPlanTextAnyToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) +// Value implements the database/sql/driver Valuer interface. +func (src Int2) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil } + return int64(src.Int), nil +} - p, ok := (dst).(*int16) - if !ok { - return ErrScanTargetTypeChanged +func (src Int2) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil } + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil +} - n, err := strconv.ParseInt(string(src), 10, 16) +type Int2Codec struct{} + +func (Int2Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int2Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int2Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) if err != nil { - return err + return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) + } + if !valid { + return nil, nil + } + + if n > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n) + } + if n < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n) + } + + switch format { + case BinaryFormatCode: + return pgio.AppendInt16(buf, int16(n)), nil + case TextFormatCode: + return append(buf, strconv.FormatInt(n, 10)...), nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } +} + +func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *int8: + return scanPlanBinaryInt2ToInt8{} + case *int16: + return scanPlanBinaryInt2ToInt16{} + case *int32: + return scanPlanBinaryInt2ToInt32{} + case *int64: + return scanPlanBinaryInt2ToInt64{} + case *int: + return scanPlanBinaryInt2ToInt{} + case *uint8: + return scanPlanBinaryInt2ToUint8{} + case *uint16: + return scanPlanBinaryInt2ToUint16{} + case *uint32: + return scanPlanBinaryInt2ToUint32{} + case *uint64: + return scanPlanBinaryInt2ToUint64{} + case *uint: + return scanPlanBinaryInt2ToUint{} + case Int64Scanner: + return scanPlanBinaryInt2ToInt64Scanner{} + } + case TextFormatCode: + switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} + case *int16: + return scanPlanTextAnyToInt16{} + case *int32: + return scanPlanTextAnyToInt32{} + case *int64: + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } } - *p = int16(n) return nil } -type scanPlanTextAnyToUint16 struct{} - -func (scanPlanTextAnyToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (c Int2Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return nil, nil } - p, ok := (dst).(*uint16) - if !ok { - return ErrScanTargetTypeChanged + var n int64 + scanPlan := c.PlanScan(ci, oid, format, &n, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") } - - n, err := strconv.ParseUint(string(src), 10, 16) + err := scanPlan.Scan(ci, oid, format, src, &n) if err != nil { - return err + return nil, err } - - *p = uint16(n) - return nil + return n, nil } -type scanPlanTextAnyToInt32 struct{} - -func (scanPlanTextAnyToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return nil, nil } - p, ok := (dst).(*int32) - if !ok { - return ErrScanTargetTypeChanged + var n int16 + scanPlan := c.PlanScan(ci, oid, format, &n, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") } - - n, err := strconv.ParseInt(string(src), 10, 32) + err := scanPlan.Scan(ci, oid, format, src, &n) if err != nil { - return err + return nil, err } - - *p = int32(n) - return nil -} - -type scanPlanTextAnyToUint32 struct{} - -func (scanPlanTextAnyToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*uint32) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseUint(string(src), 10, 32) - if err != nil { - return err - } - - *p = uint32(n) - return nil -} - -type scanPlanTextAnyToInt64 struct{} - -func (scanPlanTextAnyToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*int64) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseInt(string(src), 10, 64) - if err != nil { - return err - } - - *p = int64(n) - return nil -} - -type scanPlanTextAnyToUint64 struct{} - -func (scanPlanTextAnyToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*uint64) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseUint(string(src), 10, 64) - if err != nil { - return err - } - - *p = uint64(n) - return nil -} - -type scanPlanTextAnyToInt struct{} - -func (scanPlanTextAnyToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*int) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseInt(string(src), 10, 0) - if err != nil { - return err - } - - *p = int(n) - return nil -} - -type scanPlanTextAnyToUint struct{} - -func (scanPlanTextAnyToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*uint) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseUint(string(src), 10, 0) - if err != nil { - return err - } - - *p = uint(n) - return nil + return n, nil } type scanPlanBinaryInt2ToInt8 struct{} @@ -485,6 +487,216 @@ func (scanPlanBinaryInt2ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod return s.ScanInt64(n, true) } +type scanPlanTextAnyToInt8 struct{} + +func (scanPlanTextAnyToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 8) + if err != nil { + return err + } + + *p = int8(n) + return nil +} + +type scanPlanTextAnyToUint8 struct{} + +func (scanPlanTextAnyToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 8) + if err != nil { + return err + } + + *p = uint8(n) + return nil +} + +type scanPlanTextAnyToInt16 struct{} + +func (scanPlanTextAnyToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + + *p = int16(n) + return nil +} + +type scanPlanTextAnyToUint16 struct{} + +func (scanPlanTextAnyToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 16) + if err != nil { + return err + } + + *p = uint16(n) + return nil +} + +type scanPlanTextAnyToInt32 struct{} + +func (scanPlanTextAnyToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 32) + if err != nil { + return err + } + + *p = int32(n) + return nil +} + +type scanPlanTextAnyToUint32 struct{} + +func (scanPlanTextAnyToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *p = uint32(n) + return nil +} + +type scanPlanTextAnyToInt64 struct{} + +func (scanPlanTextAnyToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + *p = int64(n) + return nil +} + +type scanPlanTextAnyToUint64 struct{} + +func (scanPlanTextAnyToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 64) + if err != nil { + return err + } + + *p = uint64(n) + return nil +} + +type scanPlanTextAnyToInt struct{} + +func (scanPlanTextAnyToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 0) + if err != nil { + return err + } + + *p = int(n) + return nil +} + +type scanPlanTextAnyToUint struct{} + +func (scanPlanTextAnyToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 0) + if err != nil { + return err + } + + *p = uint(n) + return nil +} + type scanPlanTextAnyToInt64Scanner struct{} func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb new file mode 100644 index 00000000..152dabed --- /dev/null +++ b/pgtype/int.go.erb @@ -0,0 +1,451 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Int64Scanner interface { + ScanInt64(v int64, valid bool) error +} + + +<% [2].each do |pg_byte_size| %> +<% pg_bit_size = pg_byte_size * 8 %> +type Int<%= pg_byte_size %> struct { + Int int<%= pg_bit_size %> + Valid bool +} + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int<%= pg_byte_size %>) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = Int<%= pg_byte_size %>{} + return nil + } + + if n < math.MinInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + } + if n > math.MaxInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + } + *dst = Int<%= pg_byte_size %>{Int: int<%= pg_bit_size %>(n), Valid: true} + + return nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int<%= pg_byte_size %>) Scan(src interface{}) error { + if src == nil { + *dst = Int<%= pg_byte_size %>{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + var err error + n, err = strconv.ParseInt(src, 10, <%= pg_bit_size %>) + if err != nil { + return err + } + case []byte: + var err error + n, err = strconv.ParseInt(string(src), 10, <%= pg_bit_size %>) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < math.MinInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + } + if n > math.MaxInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + } + *dst = Int<%= pg_byte_size %>{Int: int<%= pg_bit_size %>(n), Valid: true} + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int), nil +} + +func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil +} + +type Int<%= pg_byte_size %>Codec struct{} + +func (Int<%= pg_byte_size %>Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int<%= pg_byte_size %>Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int<%= pg_byte_size %>Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to int<%= pg_byte_size %>: %v", value, err) + } + if !valid { + return nil, nil + } + + if n > math.MaxInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n) + } + if n < math.MinInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n) + } + + switch format { + case BinaryFormatCode: + return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n)), nil + case TextFormatCode: + return append(buf, strconv.FormatInt(n, 10)...), nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } +} + +func (Int<%= pg_byte_size %>Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *int8: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt8{} + case *int16: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt16{} + case *int32: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt32{} + case *int64: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt64{} + case *int: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt{} + case *uint8: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint8{} + case *uint16: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint16{} + case *uint32: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint32{} + case *uint64: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint64{} + case *uint: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint{} + case Int64Scanner: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner{} + } + case TextFormatCode: + switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} + case *int16: + return scanPlanTextAnyToInt16{} + case *int32: + return scanPlanTextAnyToInt32{} + case *int64: + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +func (c Int<%= pg_byte_size %>Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n int64 + scanPlan := c.PlanScan(ci, oid, format, &n, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +func (c Int<%= pg_byte_size %>Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var n int<%= pg_bit_size %> + scanPlan := c.PlanScan(ci, oid, format, &n, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +<%# PostgreSQL binary format integer to fixed size Go integers %> +<% [8, 16, 32, 64].each do |dst_bit_size| %> +type scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %> struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for int<%= pg_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*int<%= dst_bit_size %>) + if !ok { + return ErrScanTargetTypeChanged + } + + <% if dst_bit_size < pg_bit_size %> + n := int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + if n < math.MinInt<%= dst_bit_size %> { + return fmt.Errorf("%d is less than minimum value for int<%= dst_bit_size %>", n) + } else if n > math.MaxInt<%= dst_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for int<%= dst_bit_size %>", n) + } + + *p = int<%= dst_bit_size %>(n) + <% else %> + *p = int<%= dst_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + <% end %> + + return nil +} + +type scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %> struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for uint<%= pg_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*uint<%= dst_bit_size %>) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint<%= dst_bit_size %>", n) + } + <% if dst_bit_size < pg_bit_size %> + if n > math.MaxUint<%= dst_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for uint<%= dst_bit_size %>", n) + } + <% end %> + *p = uint<%= dst_bit_size %>(n) + + return nil +} +<% end %> + +<%# PostgreSQL binary format integer to Go machine integers %> +type scanPlanBinaryInt<%= pg_byte_size %>ToInt struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for int<%= pg_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + <% if 32 < pg_bit_size %> + n := int64(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + if n < math.MinInt { + return fmt.Errorf("%d is less than minimum value for int", n) + } else if n > math.MaxInt { + return fmt.Errorf("%d is greater than maximum value for int", n) + } + + *p = int(n) + <% else %> + *p = int(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + <% end %> + + return nil +} + +type scanPlanBinaryInt<%= pg_byte_size %>ToUint struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for uint<%= pg_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint", n) + } + <% if 32 < pg_bit_size %> + if uint64(n) > math.MaxUint { + return fmt.Errorf("%d is greater than maximum value for uint", n) + } + <% end %> + *p = uint(n) + + return nil +} + +<%# PostgreSQL binary format integer to Go Int64Scanner %> +type scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(0, false) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for int<%= pg_byte_size %>: %v", len(src)) + } + + + n := int64(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + + return s.ScanInt64(n, true) +} +<% end %> + +<%# Any text to all integer types %> +<% [ + ["8", 8], + ["16", 16], + ["32", 32], + ["64", 64], + ["", 0] +].each do |type_suffix, bit_size| %> +type scanPlanTextAnyToInt<%= type_suffix %> struct{} + +func (scanPlanTextAnyToInt<%= type_suffix %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int<%= type_suffix %>) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, <%= bit_size %>) + if err != nil { + return err + } + + *p = int<%= type_suffix %>(n) + return nil +} + +type scanPlanTextAnyToUint<%= type_suffix %> struct{} + +func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint<%= type_suffix %>) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, <%= bit_size %>) + if err != nil { + return err + } + + *p = uint<%= type_suffix %>(n) + return nil +} +<% end %> + +type scanPlanTextAnyToInt64Scanner struct{} + +func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(0, false) + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + err = s.ScanInt64(n, true) + if err != nil { + return err + } + + return nil +} diff --git a/pgtype/int2.go b/pgtype/int2.go deleted file mode 100644 index 309fe6b6..00000000 --- a/pgtype/int2.go +++ /dev/null @@ -1,85 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "fmt" - "math" - "strconv" -) - -type Int2 struct { - Int int16 - Valid bool -} - -// ScanInt64 implements the Int64Scanner interface. -func (dst *Int2) ScanInt64(n int64, valid bool) error { - if !valid { - *dst = Int2{} - return nil - } - - if n < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) - } - if n > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) - } - *dst = Int2{Int: int16(n), Valid: true} - - return nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int2) Scan(src interface{}) error { - if src == nil { - *dst = Int2{} - return nil - } - - var n int64 - - switch src := src.(type) { - case int64: - n = src - case string: - var err error - n, err = strconv.ParseInt(src, 10, 16) - if err != nil { - return err - } - case []byte: - var err error - n, err = strconv.ParseInt(string(src), 10, 16) - if err != nil { - return err - } - default: - return fmt.Errorf("cannot scan %T", src) - } - - if n < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) - } - if n > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) - } - *dst = Int2{Int: int16(n), Valid: true} - - return nil -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int2) Value() (driver.Value, error) { - if !src.Valid { - return nil, nil - } - return int64(src.Int), nil -} - -func (src Int2) MarshalJSON() ([]byte, error) { - if !src.Valid { - return []byte("null"), nil - } - return []byte(strconv.FormatInt(int64(src.Int), 10)), nil -} diff --git a/pgtype/int2_codec.go b/pgtype/int2_codec.go deleted file mode 100644 index 55fd7a12..00000000 --- a/pgtype/int2_codec.go +++ /dev/null @@ -1,142 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "fmt" - "math" - "strconv" - - "github.com/jackc/pgio" -) - -type Int2Codec struct{} - -func (Int2Codec) FormatSupported(format int16) bool { - return format == TextFormatCode || format == BinaryFormatCode -} - -func (Int2Codec) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (Int2Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) - if err != nil { - return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) - } - if !valid { - return nil, nil - } - - if n > math.MaxInt16 { - return nil, fmt.Errorf("%d is greater than maximum value for int2", n) - } - if n < math.MinInt16 { - return nil, fmt.Errorf("%d is less than minimum value for int2", n) - } - - switch format { - case BinaryFormatCode: - return pgio.AppendInt16(buf, int16(n)), nil - case TextFormatCode: - return append(buf, strconv.FormatInt(n, 10)...), nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) - } -} - -func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { - - switch format { - case BinaryFormatCode: - switch target.(type) { - case *int8: - return scanPlanBinaryInt2ToInt8{} - case *int16: - return scanPlanBinaryInt2ToInt16{} - case *int32: - return scanPlanBinaryInt2ToInt32{} - case *int64: - return scanPlanBinaryInt2ToInt64{} - case *int: - return scanPlanBinaryInt2ToInt{} - case *uint8: - return scanPlanBinaryInt2ToUint8{} - case *uint16: - return scanPlanBinaryInt2ToUint16{} - case *uint32: - return scanPlanBinaryInt2ToUint32{} - case *uint64: - return scanPlanBinaryInt2ToUint64{} - case *uint: - return scanPlanBinaryInt2ToUint{} - case Int64Scanner: - return scanPlanBinaryInt2ToInt64Scanner{} - } - case TextFormatCode: - switch target.(type) { - case *int8: - return scanPlanTextAnyToInt8{} - case *int16: - return scanPlanTextAnyToInt16{} - case *int32: - return scanPlanTextAnyToInt32{} - case *int64: - return scanPlanTextAnyToInt64{} - case *int: - return scanPlanTextAnyToInt{} - case *uint8: - return scanPlanTextAnyToUint8{} - case *uint16: - return scanPlanTextAnyToUint16{} - case *uint32: - return scanPlanTextAnyToUint32{} - case *uint64: - return scanPlanTextAnyToUint64{} - case *uint: - return scanPlanTextAnyToUint{} - case Int64Scanner: - return scanPlanTextAnyToInt64Scanner{} - } - } - - return nil -} - -func (c Int2Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - if src == nil { - return nil, nil - } - - var n int64 - scanPlan := c.PlanScan(ci, oid, format, &n, true) - if scanPlan == nil { - return nil, fmt.Errorf("PlanScan did not find a plan") - } - err := scanPlan.Scan(ci, oid, format, src, &n) - if err != nil { - return nil, err - } - return n, nil -} - -func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { - if src == nil { - return nil, nil - } - - var n int16 - scanPlan := c.PlanScan(ci, oid, format, &n, true) - if scanPlan == nil { - return nil, fmt.Errorf("PlanScan did not find a plan") - } - err := scanPlan.Scan(ci, oid, format, src, &n) - if err != nil { - return nil, err - } - return n, nil -} - -type Int64Scanner interface { - ScanInt64(v int64, valid bool) error -} diff --git a/pgtype/int_scan_plans.go.erb b/pgtype/int_scan_plans.go.erb deleted file mode 100644 index 67bfa3f4..00000000 --- a/pgtype/int_scan_plans.go.erb +++ /dev/null @@ -1,245 +0,0 @@ -package pgtype - -import ( - "fmt" - "strconv" -) - -<%# Any text to all integer types %> -<% [ - ["8", 8], - ["16", 16], - ["32", 32], - ["64", 64], - ["", 0] -].each do |type_suffix, bit_size| %> -type scanPlanTextAnyToInt<%= type_suffix %> struct{} - -func (scanPlanTextAnyToInt<%= type_suffix %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*int<%= type_suffix %>) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseInt(string(src), 10, <%= bit_size %>) - if err != nil { - return err - } - - *p = int<%= type_suffix %>(n) - return nil -} - -type scanPlanTextAnyToUint<%= type_suffix %> struct{} - -func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - p, ok := (dst).(*uint<%= type_suffix %>) - if !ok { - return ErrScanTargetTypeChanged - } - - n, err := strconv.ParseUint(string(src), 10, <%= bit_size %>) - if err != nil { - return err - } - - *p = uint<%= type_suffix %>(n) - return nil -} -<% end %> - -<%# PostgreSQL binary integers to fixed size Go integers %> -<% [ - [16, 8], - [16, 16], - [16, 32], - [16, 64], -].each do |src_bit_size, dst_bit_size| %> -<% src_byte_size = src_bit_size / 8 %> -type scanPlanBinaryInt<%= src_byte_size %>ToInt<%= dst_bit_size %> struct{} - -func (scanPlanBinaryInt<%= src_byte_size %>ToInt<%= dst_bit_size %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - if len(src) != <%= src_byte_size %> { - return fmt.Errorf("invalid length for int<%= src_byte_size %>: %v", len(src)) - } - - p, ok := (dst).(*int<%= dst_bit_size %>) - if !ok { - return ErrScanTargetTypeChanged - } - - <% if dst_bit_size < src_bit_size %> - n := int<%= src_bit_size %>(binary.BigEndian.Uint<%= src_bit_size %>(src)) - if n < math.MinInt<%= dst_bit_size %> { - return fmt.Errorf("%d is less than minimum value for int<%= dst_bit_size %>", n) - } else if n > math.MaxInt<%= dst_bit_size %> { - return fmt.Errorf("%d is greater than maximum value for int<%= dst_bit_size %>", n) - } - - *p = int<%= dst_bit_size %>(n) - <% else %> - *p = int<%= dst_bit_size %>(binary.BigEndian.Uint<%= src_bit_size %>(src)) - <% end %> - - return nil -} - -type scanPlanBinaryInt<%= src_byte_size %>ToUint<%= dst_bit_size %> struct{} - -func (scanPlanBinaryInt<%= src_byte_size %>ToUint<%= dst_bit_size %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - if len(src) != <%= src_byte_size %> { - return fmt.Errorf("invalid length for uint<%= src_byte_size %>: %v", len(src)) - } - - p, ok := (dst).(*uint<%= dst_bit_size %>) - if !ok { - return ErrScanTargetTypeChanged - } - - n := int<%= src_bit_size %>(binary.BigEndian.Uint<%= src_bit_size %>(src)) - if n < 0 { - return fmt.Errorf("%d is less than minimum value for uint<%= dst_bit_size %>", n) - } - <% if dst_bit_size < src_bit_size %> - if n > math.MaxUint<%= dst_bit_size %> { - return fmt.Errorf("%d is greater than maximum value for uint<%= dst_bit_size %>", n) - } - <% end %> - *p = uint<%= dst_bit_size %>(n) - - return nil -} -<% end %> - -<%# PostgreSQL binary integers to Go machine integers %> -<% [16].each do |src_bit_size| %> -<% src_byte_size = src_bit_size / 8 %> -type scanPlanBinaryInt<%= src_byte_size %>ToInt struct{} - -func (scanPlanBinaryInt<%= src_byte_size %>ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - if len(src) != <%= src_byte_size %> { - return fmt.Errorf("invalid length for int<%= src_byte_size %>: %v", len(src)) - } - - p, ok := (dst).(*int) - if !ok { - return ErrScanTargetTypeChanged - } - - <% if 32 < src_bit_size %> - n := int64(binary.BigEndian.Uint<%= src_bit_size %>(src)) - if n < math.MinInt { - return fmt.Errorf("%d is less than minimum value for int", n) - } else if n > math.MaxInt { - return fmt.Errorf("%d is greater than maximum value for int", n) - } - - *p = int(n) - <% else %> - *p = int(binary.BigEndian.Uint<%= src_bit_size %>(src)) - <% end %> - - return nil -} - -type scanPlanBinaryInt<%= src_byte_size %>ToUint struct{} - -func (scanPlanBinaryInt<%= src_byte_size %>ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - if len(src) != <%= src_byte_size %> { - return fmt.Errorf("invalid length for uint<%= src_byte_size %>: %v", len(src)) - } - - p, ok := (dst).(*uint) - if !ok { - return ErrScanTargetTypeChanged - } - - n := int64(binary.BigEndian.Uint<%= src_bit_size %>(src)) - if n < 0 { - return fmt.Errorf("%d is less than minimum value for uint", n) - } - <% if 32 < src_bit_size %> - if uint64(n) > math.MaxUint { - return fmt.Errorf("%d is greater than maximum value for uint", n) - } - <% end %> - *p = uint(n) - - return nil -} -<% end %> - -<%# PostgreSQL binary integers to Go Int64Scanner %> -<% [16].each do |src_bit_size| %> -<% src_byte_size = src_bit_size / 8 %> -type scanPlanBinaryInt<%= src_byte_size %>ToInt64Scanner struct{} - -func (scanPlanBinaryInt<%= src_byte_size %>ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - s, ok := (dst).(Int64Scanner) - if !ok { - return ErrScanTargetTypeChanged - } - - if src == nil { - return s.ScanInt64(0, false) - } - - if len(src) != <%= src_byte_size %> { - return fmt.Errorf("invalid length for int<%= src_byte_size %>: %v", len(src)) - } - - - n := int64(binary.BigEndian.Uint<%= src_bit_size %>(src)) - - return s.ScanInt64(n, true) -} -<% end %> - -type scanPlanTextAnyToInt64Scanner struct{} - -func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - s, ok := (dst).(Int64Scanner) - if !ok { - return ErrScanTargetTypeChanged - } - - if src == nil { - return s.ScanInt64(0, false) - } - - n, err := strconv.ParseInt(string(src), 10, 64) - if err != nil { - return err - } - - err = s.ScanInt64(n, true) - if err != nil { - return err - } - - return nil -} From 0403c34ae3a56aa5517d928e577d7ce9c563d1e9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Jan 2022 11:22:14 -0600 Subject: [PATCH 0794/1158] Prepare for generating tests --- Rakefile | 2 +- pgtype/{int2_test.go => int_test.go} | 54 +--------------------------- pgtype/int_test.go.erb | 45 +++++++++++++++++++++++ pgtype/pgtype_test.go | 52 +++++++++++++++++++++++++++ 4 files changed, 99 insertions(+), 54 deletions(-) rename pgtype/{int2_test.go => int_test.go} (56%) create mode 100644 pgtype/int_test.go.erb diff --git a/Rakefile b/Rakefile index 275755bd..7076d2a0 100644 --- a/Rakefile +++ b/Rakefile @@ -7,4 +7,4 @@ rule '.go' => '.go.erb' do |task| end desc "Generate code" -task generate: ["pgtype/int.go"] +task generate: ["pgtype/int.go", "pgtype/int_test.go"] diff --git a/pgtype/int2_test.go b/pgtype/int_test.go similarity index 56% rename from pgtype/int2_test.go rename to pgtype/int_test.go index f5bdac89..3ba1306b 100644 --- a/pgtype/int2_test.go +++ b/pgtype/int_test.go @@ -1,65 +1,13 @@ +// Do not edit. Generated from pgtype/int_test.go.erb package pgtype_test import ( - "context" - "fmt" "math" - "reflect" "testing" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -type PgxTranscodeTestCase struct { - src interface{} - dst interface{} - test func(interface{}) bool -} - -func isExpectedEq(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { - return a == v - } -} - -func testPgxCodec(t testing.TB, pgTypeName string, tests []PgxTranscodeTestCase) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - code int16 - }{ - {name: "TextFormat", code: pgx.TextFormatCode}, - {name: "BinaryFormat", code: pgx.BinaryFormatCode}, - } - - for i, tt := range tests { - for _, format := range formats { - err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{format.code}, tt.src).Scan(tt.dst) - if err != nil { - t.Errorf("%s %d: %v", format.name, i, err) - } - - dst := reflect.ValueOf(tt.dst) - if dst.Kind() == reflect.Ptr { - dst = dst.Elem() - } - - if !tt.test(dst.Interface()) { - t.Errorf("%s %d: unexpected result for %v: %v", format.name, i, tt.src, dst.Interface()) - } - } - } -} - func TestInt2Codec(t *testing.T) { testPgxCodec(t, "int2", []PgxTranscodeTestCase{ {int8(1), new(int16), isExpectedEq(int16(1))}, diff --git a/pgtype/int_test.go.erb b/pgtype/int_test.go.erb new file mode 100644 index 00000000..be1f5358 --- /dev/null +++ b/pgtype/int_test.go.erb @@ -0,0 +1,45 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/v5/pgtype" +) + +<% [2].each do |pg_byte_size| %> +<% pg_bit_size = pg_byte_size * 8 %> +func TestInt<%= pg_byte_size %>Codec(t *testing.T) { + testPgxCodec(t, "int<%= pg_byte_size %>", []PgxTranscodeTestCase{ + {int8(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int16(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int32(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int64(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint8(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint16(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint32(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint64(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {pgtype.Int<%= pg_byte_size %>{Int: 1, Valid: true}, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {math.MinInt<%= pg_bit_size %>, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(math.MinInt<%= pg_bit_size %>))}, + {-1, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(-1))}, + {0, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(0))}, + {1, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {math.MaxInt<%= pg_bit_size %>, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(math.MaxInt<%= pg_bit_size %>))}, + {1, new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{Int: 1, Valid: true})}, + {pgtype.Int<%= pg_byte_size %>{}, new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{})}, + {nil, new(*int<%= pg_bit_size %>), isExpectedEq((*int<%= pg_bit_size %>)(nil))}, + }) +} +<% end %> diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 17b8afe1..43c6c24b 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -2,13 +2,17 @@ package pgtype_test import ( "bytes" + "context" "database/sql" "errors" + "fmt" "net" + "reflect" "testing" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" _ "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -299,3 +303,51 @@ func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { } } } + +type PgxTranscodeTestCase struct { + src interface{} + dst interface{} + test func(interface{}) bool +} + +func isExpectedEq(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + return a == v + } +} + +func testPgxCodec(t testing.TB, pgTypeName string, tests []PgxTranscodeTestCase) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for i, tt := range tests { + for _, format := range formats { + err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{format.code}, tt.src).Scan(tt.dst) + if err != nil { + t.Errorf("%s %d: %v", format.name, i, err) + } + + dst := reflect.ValueOf(tt.dst) + if dst.Kind() == reflect.Ptr { + dst = dst.Elem() + } + + if !tt.test(dst.Interface()) { + t.Errorf("%s %d: unexpected result for %v: %v", format.name, i, tt.src, dst.Interface()) + } + } + } +} From d2cf33ed40645729a0439cc1b33eb05c3ed0c17f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Jan 2022 11:25:26 -0600 Subject: [PATCH 0795/1158] Add UnmarshalJSON to generated ints --- pgtype/int.go | 17 +++++++++++++++++ pgtype/int.go.erb | 16 ++++++++++++++++ pgtype/int_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ pgtype/int_test.go.erb | 41 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 115 insertions(+) diff --git a/pgtype/int.go b/pgtype/int.go index 4d7dea83..9da79a89 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -4,6 +4,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "math" "strconv" @@ -92,6 +93,22 @@ func (src Int2) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int), 10)), nil } +func (dst *Int2) UnmarshalJSON(b []byte) error { + var n *int16 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int2{} + } else { + *dst = Int2{Int: *n, Valid: true} + } + + return nil +} + type Int2Codec struct{} func (Int2Codec) FormatSupported(format int16) bool { diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 152dabed..5c8e44fa 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -94,6 +94,22 @@ func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int), 10)), nil } +func (dst *Int<%= pg_byte_size %>) UnmarshalJSON(b []byte) error { + var n *int<%= pg_bit_size %> + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int<%= pg_byte_size %>{} + } else { + *dst = Int<%= pg_byte_size %>{Int: *n, Valid: true} + } + + return nil +} + type Int<%= pg_byte_size %>Codec struct{} func (Int<%= pg_byte_size %>Codec) FormatSupported(format int16) bool { diff --git a/pgtype/int_test.go b/pgtype/int_test.go index 3ba1306b..272ad69b 100644 --- a/pgtype/int_test.go +++ b/pgtype/int_test.go @@ -41,3 +41,44 @@ func TestInt2Codec(t *testing.T) { {nil, new(*int16), isExpectedEq((*int16)(nil))}, }) } + +func TestInt2MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int2 + result string + }{ + {source: pgtype.Int2{Int: 0}, result: "null"}, + {source: pgtype.Int2{Int: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt2UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int2 + }{ + {source: "null", result: pgtype.Int2{Int: 0}}, + {source: "1", result: pgtype.Int2{Int: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int2 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/int_test.go.erb b/pgtype/int_test.go.erb index be1f5358..28847a11 100644 --- a/pgtype/int_test.go.erb +++ b/pgtype/int_test.go.erb @@ -42,4 +42,45 @@ func TestInt<%= pg_byte_size %>Codec(t *testing.T) { {nil, new(*int<%= pg_bit_size %>), isExpectedEq((*int<%= pg_bit_size %>)(nil))}, }) } + +func TestInt<%= pg_byte_size %>MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int<%= pg_byte_size %> + result string + }{ + {source: pgtype.Int<%= pg_byte_size %>{Int: 0}, result: "null"}, + {source: pgtype.Int<%= pg_byte_size %>{Int: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt<%= pg_byte_size %>UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int<%= pg_byte_size %> + }{ + {source: "null", result: pgtype.Int<%= pg_byte_size %>{Int: 0}}, + {source: "1", result: pgtype.Int<%= pg_byte_size %>{Int: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int<%= pg_byte_size %> + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} <% end %> From ffa1fdd66e778f15f524f3bef862792c7af5b75b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Jan 2022 11:32:52 -0600 Subject: [PATCH 0796/1158] Temporarily remove range type support --- pgtype/daterange.go | 257 --------------------- pgtype/daterange_test.go | 133 ----------- pgtype/int4range.go | 257 --------------------- pgtype/int4range_test.go | 28 --- pgtype/int8range.go | 257 --------------------- pgtype/int8range_test.go | 28 --- pgtype/numrange.go | 257 --------------------- pgtype/numrange_test.go | 46 ---- pgtype/pgtype.go | 16 +- pgtype/range.go | 277 ----------------------- pgtype/range_test.go | 177 --------------- pgtype/tsrange.go | 257 --------------------- pgtype/tsrange_array.go | 457 -------------------------------------- pgtype/tsrange_test.go | 41 ---- pgtype/tstzrange.go | 257 --------------------- pgtype/tstzrange_array.go | 457 -------------------------------------- pgtype/tstzrange_test.go | 49 ---- pgtype/typed_range.go.erb | 259 --------------------- pgtype/typed_range_gen.sh | 7 - 19 files changed, 8 insertions(+), 3509 deletions(-) delete mode 100644 pgtype/daterange.go delete mode 100644 pgtype/daterange_test.go delete mode 100644 pgtype/int4range.go delete mode 100644 pgtype/int4range_test.go delete mode 100644 pgtype/int8range.go delete mode 100644 pgtype/int8range_test.go delete mode 100644 pgtype/numrange.go delete mode 100644 pgtype/numrange_test.go delete mode 100644 pgtype/range.go delete mode 100644 pgtype/range_test.go delete mode 100644 pgtype/tsrange.go delete mode 100644 pgtype/tsrange_array.go delete mode 100644 pgtype/tsrange_test.go delete mode 100644 pgtype/tstzrange.go delete mode 100644 pgtype/tstzrange_array.go delete mode 100644 pgtype/tstzrange_test.go delete mode 100644 pgtype/typed_range.go.erb delete mode 100644 pgtype/typed_range_gen.sh diff --git a/pgtype/daterange.go b/pgtype/daterange.go deleted file mode 100644 index 8b0c03f1..00000000 --- a/pgtype/daterange.go +++ /dev/null @@ -1,257 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "fmt" - - "github.com/jackc/pgio" -) - -type Daterange struct { - Lower Date - Upper Date - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (dst *Daterange) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Daterange{} - return nil - } - - switch value := src.(type) { - case Daterange: - *dst = value - case *Daterange: - *dst = *value - case string: - return dst.DecodeText(nil, []byte(value)) - default: - return fmt.Errorf("cannot convert %v to Daterange", src) - } - - return nil -} - -func (src Daterange) Get() interface{} { - if !src.Valid { - return nil - } - return src -} - -func (src *Daterange) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Daterange) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Daterange{} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = Daterange{Valid: true} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Daterange{} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = Daterange{Valid: true} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Daterange) Scan(src interface{}) error { - if src == nil { - *dst = Daterange{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Daterange) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/daterange_test.go b/pgtype/daterange_test.go deleted file mode 100644 index a6501372..00000000 --- a/pgtype/daterange_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package pgtype_test - -import ( - "testing" - "time" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestDaterangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ - &pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, - &pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - &pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, - Upper: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - &pgtype.Daterange{}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Daterange) - b := bb.(pgtype.Daterange) - - return a.Valid == b.Valid && - a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Valid == b.Lower.Valid && - a.Lower.InfinityModifier == b.Lower.InfinityModifier && - a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Valid == b.Upper.Valid && - a.Upper.InfinityModifier == b.Upper.InfinityModifier - }) -} - -func TestDaterangeNormalize(t *testing.T) { - testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ - { - SQL: "select daterange('2010-01-01', '2010-01-11', '(]')", - Value: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(2010, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}, - Upper: pgtype.Date{Time: time.Date(2010, 1, 12, 0, 0, 0, 0, time.UTC), Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - }, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Daterange) - b := bb.(pgtype.Daterange) - - return a.Valid == b.Valid && - a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Valid == b.Lower.Valid && - a.Lower.InfinityModifier == b.Lower.InfinityModifier && - a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Valid == b.Upper.Valid && - a.Upper.InfinityModifier == b.Upper.InfinityModifier - }) -} - -func TestDaterangeSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Daterange - }{ - { - source: nil, - result: pgtype.Daterange{}, - }, - { - source: &pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - result: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - }, - { - source: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - result: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - }, - { - source: "[1990-12-31,2028-01-01)", - result: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Daterange - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} diff --git a/pgtype/int4range.go b/pgtype/int4range.go deleted file mode 100644 index 49503c0d..00000000 --- a/pgtype/int4range.go +++ /dev/null @@ -1,257 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "fmt" - - "github.com/jackc/pgio" -) - -type Int4range struct { - Lower Int4 - Upper Int4 - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (dst *Int4range) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Int4range{} - return nil - } - - switch value := src.(type) { - case Int4range: - *dst = value - case *Int4range: - *dst = *value - case string: - return dst.DecodeText(nil, []byte(value)) - default: - return fmt.Errorf("cannot convert %v to Int4range", src) - } - - return nil -} - -func (src Int4range) Get() interface{} { - if !src.Valid { - return nil - } - return src -} - -func (src *Int4range) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Int4range) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int4range{} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = Int4range{Valid: true} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int4range{} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = Int4range{Valid: true} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int4range) Scan(src interface{}) error { - if src == nil { - *dst = Int4range{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int4range) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/int4range_test.go b/pgtype/int4range_test.go deleted file mode 100644 index 1a11e039..00000000 --- a/pgtype/int4range_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestInt4rangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int4range", []interface{}{ - &pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, - &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Valid: true}, Upper: pgtype.Int4{Int: 10, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, - &pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Valid: true}, Upper: pgtype.Int4{Int: -5, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, - &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Valid: true}, - &pgtype.Int4range{Upper: pgtype.Int4{Int: 1, Valid: true}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Valid: true}, - &pgtype.Int4range{}, - }) -} - -func TestInt4rangeNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select int4range(1, 10, '(]')", - Value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Valid: true}, Upper: pgtype.Int4{Int: 11, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, - }, - }) -} diff --git a/pgtype/int8range.go b/pgtype/int8range.go deleted file mode 100644 index a7cbcd12..00000000 --- a/pgtype/int8range.go +++ /dev/null @@ -1,257 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "fmt" - - "github.com/jackc/pgio" -) - -type Int8range struct { - Lower Int8 - Upper Int8 - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (dst *Int8range) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Int8range{} - return nil - } - - switch value := src.(type) { - case Int8range: - *dst = value - case *Int8range: - *dst = *value - case string: - return dst.DecodeText(nil, []byte(value)) - default: - return fmt.Errorf("cannot convert %v to Int8range", src) - } - - return nil -} - -func (src Int8range) Get() interface{} { - if !src.Valid { - return nil - } - return src -} - -func (src *Int8range) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int8range{} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = Int8range{Valid: true} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int8range{} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = Int8range{Valid: true} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int8range) Scan(src interface{}) error { - if src == nil { - *dst = Int8range{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int8range) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/int8range_test.go b/pgtype/int8range_test.go deleted file mode 100644 index 1fab1caa..00000000 --- a/pgtype/int8range_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestInt8rangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "Int8range", []interface{}{ - &pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, - &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Valid: true}, Upper: pgtype.Int8{Int: 10, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, - &pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Valid: true}, Upper: pgtype.Int8{Int: -5, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, - &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Valid: true}, - &pgtype.Int8range{Upper: pgtype.Int8{Int: 1, Valid: true}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Valid: true}, - &pgtype.Int8range{}, - }) -} - -func TestInt8rangeNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select Int8range(1, 10, '(]')", - Value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Valid: true}, Upper: pgtype.Int8{Int: 11, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, - }, - }) -} diff --git a/pgtype/numrange.go b/pgtype/numrange.go deleted file mode 100644 index f1118d83..00000000 --- a/pgtype/numrange.go +++ /dev/null @@ -1,257 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "fmt" - - "github.com/jackc/pgio" -) - -type Numrange struct { - Lower Numeric - Upper Numeric - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (dst *Numrange) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Numrange{} - return nil - } - - switch value := src.(type) { - case Numrange: - *dst = value - case *Numrange: - *dst = *value - case string: - return dst.DecodeText(nil, []byte(value)) - default: - return fmt.Errorf("cannot convert %v to Numrange", src) - } - - return nil -} - -func (src Numrange) Get() interface{} { - if !src.Valid { - return nil - } - return src -} - -func (src *Numrange) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Numrange{} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = Numrange{Valid: true} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Numrange{} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = Numrange{Valid: true} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Numrange) Scan(src interface{}) error { - if src == nil { - *dst = Numrange{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Numrange) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/numrange_test.go b/pgtype/numrange_test.go deleted file mode 100644 index a7792faf..00000000 --- a/pgtype/numrange_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package pgtype_test - -import ( - "math/big" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestNumrangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "numrange", []interface{}{ - &pgtype.Numrange{ - LowerType: pgtype.Empty, - UpperType: pgtype.Empty, - Valid: true, - }, - &pgtype.Numrange{ - Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Valid: true}, - Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - &pgtype.Numrange{ - Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Valid: true}, - Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - &pgtype.Numrange{ - Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Unbounded, - Valid: true, - }, - &pgtype.Numrange{ - Upper: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Valid: true}, - LowerType: pgtype.Unbounded, - UpperType: pgtype.Exclusive, - Valid: true, - }, - &pgtype.Numrange{}, - }) -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 3e807a89..50ca29c3 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -320,15 +320,15 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) ci.RegisterDataType(DataType{Value: &Circle{}, Name: "circle", OID: CircleOID}) ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID}) - ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) + // ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID}) ci.RegisterDataType(DataType{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID}) - ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) + // ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID}) - ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) + // ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID}) ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) @@ -338,7 +338,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) ci.RegisterDataType(DataType{Value: &Name{}, Name: "name", OID: NameOID}) ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) - ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) + // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID}) ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID}) ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID}) @@ -349,10 +349,10 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID}) ci.RegisterDataType(DataType{Value: &Timestamp{}, Name: "timestamp", OID: TimestampOID}) ci.RegisterDataType(DataType{Value: &Timestamptz{}, Name: "timestamptz", OID: TimestamptzOID}) - ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) - ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) - ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) - ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) + // ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) + // ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) + // ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) + // ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) ci.RegisterDataType(DataType{Value: &Unknown{}, Name: "unknown", OID: UnknownOID}) ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID}) ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID}) diff --git a/pgtype/range.go b/pgtype/range.go deleted file mode 100644 index e999f6a9..00000000 --- a/pgtype/range.go +++ /dev/null @@ -1,277 +0,0 @@ -package pgtype - -import ( - "bytes" - "encoding/binary" - "fmt" -) - -type BoundType byte - -const ( - Inclusive = BoundType('i') - Exclusive = BoundType('e') - Unbounded = BoundType('U') - Empty = BoundType('E') -) - -func (bt BoundType) String() string { - return string(bt) -} - -type UntypedTextRange struct { - Lower string - Upper string - LowerType BoundType - UpperType BoundType -} - -func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { - utr := &UntypedTextRange{} - if src == "empty" { - utr.LowerType = Empty - utr.UpperType = Empty - return utr, nil - } - - buf := bytes.NewBufferString(src) - - skipWhitespace(buf) - - r, _, err := buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("invalid lower bound: %v", err) - } - switch r { - case '(': - utr.LowerType = Exclusive - case '[': - utr.LowerType = Inclusive - default: - return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) - } - - r, _, err = buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("invalid lower value: %v", err) - } - buf.UnreadRune() - - if r == ',' { - utr.LowerType = Unbounded - } else { - utr.Lower, err = rangeParseValue(buf) - if err != nil { - return nil, fmt.Errorf("invalid lower value: %v", err) - } - } - - r, _, err = buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("missing range separator: %v", err) - } - if r != ',' { - return nil, fmt.Errorf("missing range separator: %v", r) - } - - r, _, err = buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("invalid upper value: %v", err) - } - - if r == ')' || r == ']' { - utr.UpperType = Unbounded - } else { - buf.UnreadRune() - utr.Upper, err = rangeParseValue(buf) - if err != nil { - return nil, fmt.Errorf("invalid upper value: %v", err) - } - - r, _, err = buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("missing upper bound: %v", err) - } - switch r { - case ')': - utr.UpperType = Exclusive - case ']': - utr.UpperType = Inclusive - default: - return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) - } - } - - skipWhitespace(buf) - - if buf.Len() > 0 { - return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) - } - - return utr, nil -} - -func rangeParseValue(buf *bytes.Buffer) (string, error) { - r, _, err := buf.ReadRune() - if err != nil { - return "", err - } - if r == '"' { - return rangeParseQuotedValue(buf) - } - buf.UnreadRune() - - s := &bytes.Buffer{} - - for { - r, _, err := buf.ReadRune() - if err != nil { - return "", err - } - - switch r { - case '\\': - r, _, err = buf.ReadRune() - if err != nil { - return "", err - } - case ',', '[', ']', '(', ')': - buf.UnreadRune() - return s.String(), nil - } - - s.WriteRune(r) - } -} - -func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { - s := &bytes.Buffer{} - - for { - r, _, err := buf.ReadRune() - if err != nil { - return "", err - } - - switch r { - case '\\': - r, _, err = buf.ReadRune() - if err != nil { - return "", err - } - case '"': - r, _, err = buf.ReadRune() - if err != nil { - return "", err - } - if r != '"' { - buf.UnreadRune() - return s.String(), nil - } - } - s.WriteRune(r) - } -} - -type UntypedBinaryRange struct { - Lower []byte - Upper []byte - LowerType BoundType - UpperType BoundType -} - -// 0 = () = 00000 -// 1 = empty = 00001 -// 2 = [) = 00010 -// 4 = (] = 00100 -// 6 = [] = 00110 -// 8 = ) = 01000 -// 12 = ] = 01100 -// 16 = ( = 10000 -// 18 = [ = 10010 -// 24 = = 11000 - -const emptyMask = 1 -const lowerInclusiveMask = 2 -const upperInclusiveMask = 4 -const lowerUnboundedMask = 8 -const upperUnboundedMask = 16 - -func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { - ubr := &UntypedBinaryRange{} - - if len(src) == 0 { - return nil, fmt.Errorf("range too short: %v", len(src)) - } - - rangeType := src[0] - rp := 1 - - if rangeType&emptyMask > 0 { - if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) - } - ubr.LowerType = Empty - ubr.UpperType = Empty - return ubr, nil - } - - if rangeType&lowerInclusiveMask > 0 { - ubr.LowerType = Inclusive - } else if rangeType&lowerUnboundedMask > 0 { - ubr.LowerType = Unbounded - } else { - ubr.LowerType = Exclusive - } - - if rangeType&upperInclusiveMask > 0 { - ubr.UpperType = Inclusive - } else if rangeType&upperUnboundedMask > 0 { - ubr.UpperType = Unbounded - } else { - ubr.UpperType = Exclusive - } - - if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { - if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) - } - return ubr, nil - } - - if len(src[rp:]) < 4 { - return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) - } - valueLen := int(binary.BigEndian.Uint32(src[rp:])) - rp += 4 - - val := src[rp : rp+valueLen] - rp += valueLen - - if ubr.LowerType != Unbounded { - ubr.Lower = val - } else { - ubr.Upper = val - if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) - } - return ubr, nil - } - - if ubr.UpperType != Unbounded { - if len(src[rp:]) < 4 { - return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) - } - valueLen := int(binary.BigEndian.Uint32(src[rp:])) - rp += 4 - ubr.Upper = src[rp : rp+valueLen] - rp += valueLen - } - - if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) - } - - return ubr, nil - -} diff --git a/pgtype/range_test.go b/pgtype/range_test.go deleted file mode 100644 index 9e16df59..00000000 --- a/pgtype/range_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package pgtype - -import ( - "bytes" - "testing" -) - -func TestParseUntypedTextRange(t *testing.T) { - tests := []struct { - src string - result UntypedTextRange - err error - }{ - { - src: `[1,2)`, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `[1,2]`, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, - err: nil, - }, - { - src: `(1,3)`, - result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: ` [1,2) `, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `[ foo , bar )`, - result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `["foo","bar")`, - result: UntypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `["f""oo","b""ar")`, - result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `["f""oo","b""ar")`, - result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `["","bar")`, - result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `[f\"oo\,,b\\ar\))`, - result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `empty`, - result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, - err: nil, - }, - } - - for i, tt := range tests { - r, err := ParseUntypedTextRange(tt.src) - if err != tt.err { - t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) - continue - } - - if r.LowerType != tt.result.LowerType { - t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) - } - - if r.UpperType != tt.result.UpperType { - t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) - } - - if r.Lower != tt.result.Lower { - t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) - } - - if r.Upper != tt.result.Upper { - t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) - } - } -} - -func TestParseUntypedBinaryRange(t *testing.T) { - tests := []struct { - src []byte - result UntypedBinaryRange - err error - }{ - { - src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: []byte{1}, - result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, - err: nil, - }, - { - src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, - err: nil, - }, - { - src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, - err: nil, - }, - { - src: []byte{8, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, - err: nil, - }, - { - src: []byte{12, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, - err: nil, - }, - { - src: []byte{16, 0, 0, 0, 2, 0, 4}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, - err: nil, - }, - { - src: []byte{18, 0, 0, 0, 2, 0, 4}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, - err: nil, - }, - { - src: []byte{24}, - result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, - err: nil, - }, - } - - for i, tt := range tests { - r, err := ParseUntypedBinaryRange(tt.src) - if err != tt.err { - t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) - continue - } - - if r.LowerType != tt.result.LowerType { - t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) - } - - if r.UpperType != tt.result.UpperType { - t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) - } - - if bytes.Compare(r.Lower, tt.result.Lower) != 0 { - t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) - } - - if bytes.Compare(r.Upper, tt.result.Upper) != 0 { - t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) - } - } -} diff --git a/pgtype/tsrange.go b/pgtype/tsrange.go deleted file mode 100644 index 7495d972..00000000 --- a/pgtype/tsrange.go +++ /dev/null @@ -1,257 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "fmt" - - "github.com/jackc/pgio" -) - -type Tsrange struct { - Lower Timestamp - Upper Timestamp - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (dst *Tsrange) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Tsrange{} - return nil - } - - switch value := src.(type) { - case Tsrange: - *dst = value - case *Tsrange: - *dst = *value - case string: - return dst.DecodeText(nil, []byte(value)) - default: - return fmt.Errorf("cannot convert %v to Tsrange", src) - } - - return nil -} - -func (src Tsrange) Get() interface{} { - if !src.Valid { - return nil - } - return src -} - -func (src *Tsrange) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Tsrange) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Tsrange{} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = Tsrange{Valid: true} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Tsrange{} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = Tsrange{Valid: true} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Tsrange) Scan(src interface{}) error { - if src == nil { - *dst = Tsrange{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Tsrange) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/tsrange_array.go b/pgtype/tsrange_array.go deleted file mode 100644 index 2af25f8d..00000000 --- a/pgtype/tsrange_array.go +++ /dev/null @@ -1,457 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type TsrangeArray struct { - Elements []Tsrange - Dimensions []ArrayDimension - Valid bool -} - -func (dst *TsrangeArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = TsrangeArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []Tsrange: - if value == nil { - *dst = TsrangeArray{} - } else if len(value) == 0 { - *dst = TsrangeArray{Valid: true} - } else { - *dst = TsrangeArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = TsrangeArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for TsrangeArray", src) - } - if elementsLength == 0 { - *dst = TsrangeArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to TsrangeArray", src) - } - - *dst = TsrangeArray{ - Elements: make([]Tsrange, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Tsrange, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to TsrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *TsrangeArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to TsrangeArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in TsrangeArray", err) - } - index++ - - return index, nil -} - -func (dst TsrangeArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *TsrangeArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]Tsrange: - *v = make([]Tsrange, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *TsrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from TsrangeArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from TsrangeArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *TsrangeArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TsrangeArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Tsrange - - if len(uta.Elements) > 0 { - elements = make([]Tsrange, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Tsrange - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = TsrangeArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *TsrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TsrangeArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = TsrangeArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Tsrange, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = TsrangeArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src TsrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src TsrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("tsrange"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "tsrange") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *TsrangeArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src TsrangeArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/tsrange_test.go b/pgtype/tsrange_test.go deleted file mode 100644 index f7c4dc84..00000000 --- a/pgtype/tsrange_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package pgtype_test - -import ( - "testing" - "time" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestTsrangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ - &pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, - &pgtype.Tsrange{ - Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, - Upper: pgtype.Timestamp{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - &pgtype.Tsrange{ - Lower: pgtype.Timestamp{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, - Upper: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - &pgtype.Tsrange{}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Tsrange) - b := bb.(pgtype.Tsrange) - - return a.Valid == b.Valid && - a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Valid == b.Lower.Valid && - a.Lower.InfinityModifier == b.Lower.InfinityModifier && - a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Valid == b.Upper.Valid && - a.Upper.InfinityModifier == b.Upper.InfinityModifier - }) -} diff --git a/pgtype/tstzrange.go b/pgtype/tstzrange.go deleted file mode 100644 index 3d4e2cde..00000000 --- a/pgtype/tstzrange.go +++ /dev/null @@ -1,257 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "fmt" - - "github.com/jackc/pgio" -) - -type Tstzrange struct { - Lower Timestamptz - Upper Timestamptz - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (dst *Tstzrange) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Tstzrange{} - return nil - } - - switch value := src.(type) { - case Tstzrange: - *dst = value - case *Tstzrange: - *dst = *value - case string: - return dst.DecodeText(nil, []byte(value)) - default: - return fmt.Errorf("cannot convert %v to Tstzrange", src) - } - - return nil -} - -func (src Tstzrange) Get() interface{} { - if !src.Valid { - return nil - } - return src -} - -func (src *Tstzrange) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Tstzrange) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Tstzrange{} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = Tstzrange{Valid: true} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Tstzrange{} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = Tstzrange{Valid: true} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Tstzrange) Scan(src interface{}) error { - if src == nil { - *dst = Tstzrange{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Tstzrange) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/tstzrange_array.go b/pgtype/tstzrange_array.go deleted file mode 100644 index 389d6b4c..00000000 --- a/pgtype/tstzrange_array.go +++ /dev/null @@ -1,457 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type TstzrangeArray struct { - Elements []Tstzrange - Dimensions []ArrayDimension - Valid bool -} - -func (dst *TstzrangeArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = TstzrangeArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []Tstzrange: - if value == nil { - *dst = TstzrangeArray{} - } else if len(value) == 0 { - *dst = TstzrangeArray{Valid: true} - } else { - *dst = TstzrangeArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = TstzrangeArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for TstzrangeArray", src) - } - if elementsLength == 0 { - *dst = TstzrangeArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to TstzrangeArray", src) - } - - *dst = TstzrangeArray{ - Elements: make([]Tstzrange, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Tstzrange, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to TstzrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *TstzrangeArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to TstzrangeArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in TstzrangeArray", err) - } - index++ - - return index, nil -} - -func (dst TstzrangeArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *TstzrangeArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]Tstzrange: - *v = make([]Tstzrange, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from TstzrangeArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from TstzrangeArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *TstzrangeArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TstzrangeArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Tstzrange - - if len(uta.Elements) > 0 { - elements = make([]Tstzrange, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Tstzrange - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = TstzrangeArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *TstzrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TstzrangeArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = TstzrangeArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Tstzrange, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = TstzrangeArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src TstzrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src TstzrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("tstzrange"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "tstzrange") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *TstzrangeArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src TstzrangeArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/tstzrange_test.go b/pgtype/tstzrange_test.go deleted file mode 100644 index 5d0b750f..00000000 --- a/pgtype/tstzrange_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package pgtype_test - -import ( - "testing" - "time" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/stretchr/testify/require" -) - -func TestTstzrangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ - &pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, - &pgtype.Tstzrange{ - Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, - Upper: pgtype.Timestamptz{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - &pgtype.Tstzrange{ - Lower: pgtype.Timestamptz{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, - Upper: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - &pgtype.Tstzrange{}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Tstzrange) - b := bb.(pgtype.Tstzrange) - - return a.Valid == b.Valid && - a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Valid == b.Lower.Valid && - a.Lower.InfinityModifier == b.Lower.InfinityModifier && - a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Valid == b.Upper.Valid && - a.Upper.InfinityModifier == b.Upper.InfinityModifier - }) -} - -// https://github.com/jackc/pgtype/issues/74 -func TestTstzRangeDecodeTextInvalid(t *testing.T) { - tstzrange := &pgtype.Tstzrange{} - err := tstzrange.DecodeText(nil, []byte(`[eeee,)`)) - require.Error(t, err) -} diff --git a/pgtype/typed_range.go.erb b/pgtype/typed_range.go.erb deleted file mode 100644 index 99d8c22d..00000000 --- a/pgtype/typed_range.go.erb +++ /dev/null @@ -1,259 +0,0 @@ -package pgtype - -import ( - "bytes" - "database/sql/driver" - "fmt" - "io" - - "github.com/jackc/pgio" -) - -type <%= range_type %> struct { - Lower <%= element_type %> - Upper <%= element_type %> - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (dst *<%= range_type %>) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = <%= range_type %>{} - return nil - } - - switch value := src.(type) { - case <%= range_type %>: - *dst = value - case *<%= range_type %>: - *dst = *value - case string: - return dst.DecodeText(nil, []byte(value)) - default: - return fmt.Errorf("cannot convert %v to <%= range_type %>", src) - } - - return nil -} - -func (src <%= range_type %>) Get() interface{} { - if !src.Valid { - return nil - } - return src -} - -func (src *<%= range_type %>) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = <%= range_type %>{} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = <%= range_type %>{Valid: true} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *<%= range_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = <%= range_type %>{} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = <%= range_type %>{Valid: true} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *<%= range_type %>) Scan(src interface{}) error { - if src == nil { - *dst = <%= range_type %>{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src <%= range_type %>) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/typed_range_gen.sh b/pgtype/typed_range_gen.sh deleted file mode 100644 index bedda292..00000000 --- a/pgtype/typed_range_gen.sh +++ /dev/null @@ -1,7 +0,0 @@ -erb range_type=Int4range element_type=Int4 typed_range.go.erb > int4range.go -erb range_type=Int8range element_type=Int8 typed_range.go.erb > int8range.go -erb range_type=Tsrange element_type=Timestamp typed_range.go.erb > tsrange.go -erb range_type=Tstzrange element_type=Timestamptz typed_range.go.erb > tstzrange.go -erb range_type=Daterange element_type=Date typed_range.go.erb > daterange.go -erb range_type=Numrange element_type=Numeric typed_range.go.erb > numrange.go -goimports -w *range.go From 40fb889605053b74090643a6d6d892affaa8940b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Jan 2022 11:41:08 -0600 Subject: [PATCH 0797/1158] Temporarily remove composite and record support --- pgtype/composite_bench_test.go | 192 --------- pgtype/composite_fields.go | 107 ----- pgtype/composite_fields_test.go | 273 ------------ pgtype/composite_type.go | 715 -------------------------------- pgtype/composite_type_test.go | 320 -------------- pgtype/custom_composite_test.go | 87 ---- pgtype/pgtype.go | 2 +- pgtype/record.go | 141 ------- pgtype/record_test.go | 184 -------- 9 files changed, 1 insertion(+), 2020 deletions(-) delete mode 100644 pgtype/composite_bench_test.go delete mode 100644 pgtype/composite_fields.go delete mode 100644 pgtype/composite_fields_test.go delete mode 100644 pgtype/composite_type.go delete mode 100644 pgtype/composite_type_test.go delete mode 100644 pgtype/custom_composite_test.go delete mode 100644 pgtype/record.go delete mode 100644 pgtype/record_test.go diff --git a/pgtype/composite_bench_test.go b/pgtype/composite_bench_test.go deleted file mode 100644 index ef57709b..00000000 --- a/pgtype/composite_bench_test.go +++ /dev/null @@ -1,192 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgio" - "github.com/jackc/pgx/v5/pgtype" - "github.com/stretchr/testify/require" -) - -type MyCompositeRaw struct { - A int32 - B *string -} - -func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - buf = pgio.AppendUint32(buf, 2) - - buf = pgio.AppendUint32(buf, pgtype.Int4OID) - buf = pgio.AppendInt32(buf, 4) - buf = pgio.AppendInt32(buf, src.A) - - buf = pgio.AppendUint32(buf, pgtype.TextOID) - if src.B != nil { - buf = pgio.AppendInt32(buf, int32(len(*src.B))) - buf = append(buf, (*src.B)...) - } else { - buf = pgio.AppendInt32(buf, -1) - } - - return buf, nil -} - -func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - a := pgtype.Int4{} - b := pgtype.Text{} - - scanner := pgtype.NewCompositeBinaryScanner(ci, src) - scanner.ScanDecoder(&a) - scanner.ScanDecoder(&b) - - if scanner.Err() != nil { - return scanner.Err() - } - - dst.A = a.Int - if b.Valid { - dst.B = &b.String - } else { - dst.B = nil - } - - return nil -} - -var x []byte - -func BenchmarkBinaryEncodingManual(b *testing.B) { - buf := make([]byte, 0, 128) - ci := pgtype.NewConnInfo() - v := MyCompositeRaw{4, ptrS("ABCDEFG")} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - buf, _ = v.EncodeBinary(ci, buf[:0]) - } - x = buf -} - -func BenchmarkBinaryEncodingHelper(b *testing.B) { - buf := make([]byte, 0, 128) - ci := pgtype.NewConnInfo() - v := MyType{4, ptrS("ABCDEFG")} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - buf, _ = v.EncodeBinary(ci, buf[:0]) - } - x = buf -} - -func BenchmarkBinaryEncodingComposite(b *testing.B) { - buf := make([]byte, 0, 128) - ci := pgtype.NewConnInfo() - f1 := 2 - f2 := ptrS("bar") - c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ - {"a", pgtype.Int4OID}, - {"b", pgtype.TextOID}, - }, ci) - require.NoError(b, err) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - c.Set([]interface{}{f1, f2}) - buf, _ = c.EncodeBinary(ci, buf[:0]) - } - x = buf -} - -func BenchmarkBinaryEncodingJSON(b *testing.B) { - buf := make([]byte, 0, 128) - ci := pgtype.NewConnInfo() - v := MyCompositeRaw{4, ptrS("ABCDEFG")} - j := pgtype.JSON{} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - j.Set(v) - buf, _ = j.EncodeBinary(ci, buf[:0]) - } - x = buf -} - -var dstRaw MyCompositeRaw - -func BenchmarkBinaryDecodingManual(b *testing.B) { - ci := pgtype.NewConnInfo() - buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) - dst := MyCompositeRaw{} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := dst.DecodeBinary(ci, buf) - E(err) - } - dstRaw = dst -} - -var dstMyType MyType - -func BenchmarkBinaryDecodingHelpers(b *testing.B) { - ci := pgtype.NewConnInfo() - buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) - dst := MyType{} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := dst.DecodeBinary(ci, buf) - E(err) - } - dstMyType = dst -} - -var gf1 int -var gf2 *string - -func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { - ci := pgtype.NewConnInfo() - buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) - var f1 int - var f2 *string - - c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ - {"a", pgtype.Int4OID}, - {"b", pgtype.TextOID}, - }, ci) - require.NoError(b, err) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := c.DecodeBinary(ci, buf) - if err != nil { - b.Fatal(err) - } - err = c.AssignTo([]interface{}{&f1, &f2}) - if err != nil { - b.Fatal(err) - } - } - gf1 = f1 - gf2 = f2 -} - -func BenchmarkBinaryDecodingJSON(b *testing.B) { - ci := pgtype.NewConnInfo() - j := pgtype.JSON{} - j.Set(MyCompositeRaw{4, ptrS("ABCDEFG")}) - buf, _ := j.EncodeBinary(ci, nil) - - j = pgtype.JSON{} - dst := MyCompositeRaw{} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := j.DecodeBinary(ci, buf) - E(err) - err = j.AssignTo(&dst) - E(err) - } - dstRaw = dst -} diff --git a/pgtype/composite_fields.go b/pgtype/composite_fields.go deleted file mode 100644 index e7ca89c7..00000000 --- a/pgtype/composite_fields.go +++ /dev/null @@ -1,107 +0,0 @@ -package pgtype - -import "fmt" - -// CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a -// nullable value use a *CompositeFields. It will be set to nil in case of null. -// -// CompositeFields implements EncodeBinary and EncodeText. However, functionality is limited due to CompositeFields not -// knowing the PostgreSQL schema of the composite type. Prefer using a registered CompositeType. -type CompositeFields []interface{} - -func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error { - if len(cf) == 0 { - return fmt.Errorf("cannot decode into empty CompositeFields") - } - - if src == nil { - return fmt.Errorf("cannot decode unexpected null into CompositeFields") - } - - scanner := NewCompositeBinaryScanner(ci, src) - - for _, f := range cf { - scanner.ScanValue(f) - } - - if scanner.Err() != nil { - return scanner.Err() - } - - return nil -} - -func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error { - if len(cf) == 0 { - return fmt.Errorf("cannot decode into empty CompositeFields") - } - - if src == nil { - return fmt.Errorf("cannot decode unexpected null into CompositeFields") - } - - scanner := NewCompositeTextScanner(ci, src) - - for _, f := range cf { - scanner.ScanValue(f) - } - - if scanner.Err() != nil { - return scanner.Err() - } - - return nil -} - -// EncodeText encodes composite fields into the text format. Prefer registering a CompositeType to using -// CompositeFields to encode directly. -func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - b := NewCompositeTextBuilder(ci, buf) - - for _, f := range cf { - if paramEncoder, ok := f.(ParamEncoder); ok { - b.AppendEncoder(paramEncoder) - } else { - b.AppendValue(f) - } - } - - return b.Finish() -} - -// EncodeBinary encodes composite fields into the binary format. Unlike CompositeType the schema of the destination is -// unknown. Prefer registering a CompositeType to using CompositeFields to encode directly. Because the binary -// composite format requires the OID of each field to be specified the only types that will work are those known to -// ConnInfo. -// -// In particular: -// -// * Nil cannot be used because there is no way to determine what type it. -// * Integer types must be exact matches. e.g. A Go int32 into a PostgreSQL bigint will fail. -// * No dereferencing will be done. e.g. *Text must be used instead of Text. -func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - b := NewCompositeBinaryBuilder(ci, buf) - - for _, f := range cf { - dt, ok := ci.DataTypeForValue(f) - if !ok { - return nil, fmt.Errorf("Unknown OID for %#v", f) - } - - if paramEncoder, ok := f.(ParamEncoder); ok { - b.AppendEncoder(dt.OID, paramEncoder) - } else { - err := dt.Value.Set(f) - if err != nil { - return nil, err - } - if paramEncoder, ok := dt.Value.(ParamEncoder); ok { - b.AppendEncoder(dt.OID, paramEncoder) - } else { - return nil, fmt.Errorf("Cannot encode binary format for %v", f) - } - } - } - - return b.Finish() -} diff --git a/pgtype/composite_fields_test.go b/pgtype/composite_fields_test.go deleted file mode 100644 index e73d8441..00000000 --- a/pgtype/composite_fields_test.go +++ /dev/null @@ -1,273 +0,0 @@ -package pgtype_test - -import ( - "context" - "testing" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCompositeFieldsDecode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - formats := []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} - - // Assorted values - { - var a int32 - var b string - var c float64 - - for _, format := range formats { - err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, &b, &c}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.EqualValuesf(t, 1, a, "Format: %v", format) - assert.EqualValuesf(t, "hi", b, "Format: %v", format) - assert.EqualValuesf(t, 2.1, c, "Format: %v", format) - } - } - - // nulls, string "null", and empty string fields - { - var a pgtype.Text - var b string - var c pgtype.Text - var d string - var e pgtype.Text - - for _, format := range formats { - err := conn.QueryRow(context.Background(), "select row(null,'null',null,'',null)", pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.Nilf(t, a.Get(), "Format: %v", format) - assert.EqualValuesf(t, "null", b, "Format: %v", format) - assert.Nilf(t, c.Get(), "Format: %v", format) - assert.EqualValuesf(t, "", d, "Format: %v", format) - assert.Nilf(t, e.Get(), "Format: %v", format) - } - } - - // null record - { - var a pgtype.Text - var b string - cf := pgtype.CompositeFields{&a, &b} - - for _, format := range formats { - // Cannot scan nil into - err := conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( - cf, - ) - if assert.Errorf(t, err, "Format: %v", format) { - continue - } - assert.NotNilf(t, cf, "Format: %v", format) - - // But can scan nil into *pgtype.CompositeFields - err = conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( - &cf, - ) - if assert.Errorf(t, err, "Format: %v", format) { - continue - } - assert.Nilf(t, cf, "Format: %v", format) - } - } - - // quotes and special characters - { - var a, b, c, d string - - for _, format := range formats { - err := conn.QueryRow(context.Background(), `select row('"', 'foo bar', 'foo''bar', 'baz)bar')`, pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, &b, &c, &d}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.Equalf(t, `"`, a, "Format: %v", format) - assert.Equalf(t, `foo bar`, b, "Format: %v", format) - assert.Equalf(t, `foo'bar`, c, "Format: %v", format) - assert.Equalf(t, `baz)bar`, d, "Format: %v", format) - } - } - - // arrays - { - var a []string - var b []int64 - - for _, format := range formats { - err := conn.QueryRow(context.Background(), `select row(array['foo', 'bar', 'baz'], array[1,2,3])`, pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, &b}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.EqualValuesf(t, []string{"foo", "bar", "baz"}, a, "Format: %v", format) - assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format) - } - } - - // Skip nil fields - { - var a int32 - var c float64 - - for _, format := range formats { - err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, nil, &c}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.EqualValuesf(t, 1, a, "Format: %v", format) - assert.EqualValuesf(t, 2.1, c, "Format: %v", format) - } - } -} - -func TestCompositeFieldsEncode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - _, err := conn.Exec(context.Background(), `drop type if exists cf_encode; - -create type cf_encode as ( - a text, - b int4, - c text, - d float8, - e text -);`) - require.NoError(t, err) - defer conn.Exec(context.Background(), "drop type cf_encode") - - // Use simple protocol to force text or binary encoding - simpleProtocols := []bool{true, false} - - // Assorted values - { - var a string - var b int32 - var c string - var d float64 - var e string - - for _, simpleProtocol := range simpleProtocols { - err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{"hi", int32(1), "ok", float64(2.1), "bye"}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, "ok", c, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 2.1, d, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, "bye", e, "Simple Protocol: %v", simpleProtocol) - } - } - } - - // untyped nil - { - var a pgtype.Text - var b int32 - var c string - var d pgtype.Float8 - var e pgtype.Text - - simpleProtocol := true - err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) - assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) - assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) - } - - // untyped nil cannot be represented in binary format because CompositeFields does not know the PostgreSQL schema - // of the composite type. - simpleProtocol = false - err = conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - assert.Errorf(t, err, "Simple Protocol: %v", simpleProtocol) - } - - // nulls, string "null", and empty string fields - { - var a pgtype.Text - var b int32 - var c string - var d pgtype.Float8 - var e pgtype.Text - - for _, simpleProtocol := range simpleProtocols { - err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{&pgtype.Text{}, int32(1), "null", &pgtype.Float8{}, &pgtype.Text{}}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) - assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) - assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) - } - } - } - - // quotes and special characters - { - var a string - var b int32 - var c string - var d float64 - var e string - - for _, simpleProtocol := range simpleProtocols { - err := conn.QueryRow( - context.Background(), - `select $1::cf_encode`, - pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{`"`, int32(42), `foo'bar`, float64(1.2), `baz)bar`}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.Equalf(t, `"`, a, "Simple Protocol: %v", simpleProtocol) - assert.Equalf(t, int32(42), b, "Simple Protocol: %v", simpleProtocol) - assert.Equalf(t, `foo'bar`, c, "Simple Protocol: %v", simpleProtocol) - assert.Equalf(t, float64(1.2), d, "Simple Protocol: %v", simpleProtocol) - assert.Equalf(t, `baz)bar`, e, "Simple Protocol: %v", simpleProtocol) - } - } - } -} diff --git a/pgtype/composite_type.go b/pgtype/composite_type.go deleted file mode 100644 index 85ab5910..00000000 --- a/pgtype/composite_type.go +++ /dev/null @@ -1,715 +0,0 @@ -package pgtype - -import ( - "encoding/binary" - "errors" - "fmt" - "reflect" - "strings" - - "github.com/jackc/pgio" -) - -type CompositeTypeField struct { - Name string - OID uint32 -} - -type CompositeType struct { - valid bool - - typeName string - - fields []CompositeTypeField - valueTranscoders []ValueTranscoder -} - -// NewCompositeType creates a CompositeType from fields and ci. ci is used to find the ValueTranscoders used -// for fields. All field OIDs must be previously registered in ci. -func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo) (*CompositeType, error) { - valueTranscoders := make([]ValueTranscoder, len(fields)) - - for i := range fields { - dt, ok := ci.DataTypeForOID(fields[i].OID) - if !ok { - return nil, fmt.Errorf("no data type registered for oid: %d", fields[i].OID) - } - - value := NewValue(dt.Value) - valueTranscoder, ok := value.(ValueTranscoder) - if !ok { - return nil, fmt.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID) - } - - valueTranscoders[i] = valueTranscoder - } - - return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: valueTranscoders}, nil -} - -// NewCompositeTypeValues creates a CompositeType from fields and values. fields and values must have the same length. -// Prefer NewCompositeType unless overriding the transcoding of fields is required. -func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values []ValueTranscoder) (*CompositeType, error) { - if len(fields) != len(values) { - return nil, errors.New("fields and valueTranscoders must have same length") - } - - return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: values}, nil -} - -func (src CompositeType) Get() interface{} { - if !src.valid { - return nil - } - - results := make(map[string]interface{}, len(src.valueTranscoders)) - for i := range src.valueTranscoders { - results[src.fields[i].Name] = src.valueTranscoders[i].Get() - } - return results -} - -func (ct *CompositeType) NewTypeValue() Value { - a := &CompositeType{ - typeName: ct.typeName, - fields: ct.fields, - valueTranscoders: make([]ValueTranscoder, len(ct.valueTranscoders)), - } - - for i := range ct.valueTranscoders { - a.valueTranscoders[i] = NewValue(ct.valueTranscoders[i]).(ValueTranscoder) - } - - return a -} - -func (ct *CompositeType) TypeName() string { - return ct.typeName -} - -func (ct *CompositeType) Fields() []CompositeTypeField { - return ct.fields -} - -func (dst *CompositeType) setNil() { - dst.valid = false -} - -func (dst *CompositeType) Set(src interface{}) error { - if src == nil { - dst.setNil() - return nil - } - - switch value := src.(type) { - case []interface{}: - if len(value) != len(dst.valueTranscoders) { - return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders)) - } - for i, v := range value { - if err := dst.valueTranscoders[i].Set(v); err != nil { - return err - } - } - dst.valid = true - case *[]interface{}: - if value == nil { - dst.setNil() - return nil - } - return dst.Set(*value) - default: - return fmt.Errorf("Can not convert %v to Composite", src) - } - - return nil -} - -// AssignTo should never be called on composite value directly -func (src CompositeType) AssignTo(dst interface{}) error { - if !src.valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case []interface{}: - if len(v) != len(src.valueTranscoders) { - return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders)) - } - for i := range src.valueTranscoders { - if v[i] == nil { - continue - } - - err := assignToOrSet(src.valueTranscoders[i], v[i]) - if err != nil { - return fmt.Errorf("unable to assign to dst[%d]: %v", i, err) - } - } - return nil - case *[]interface{}: - return src.AssignTo(*v) - default: - if isPtrStruct, err := src.assignToPtrStruct(dst); isPtrStruct { - return err - } - - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func assignToOrSet(src Value, dst interface{}) error { - assignToErr := src.AssignTo(dst) - if assignToErr != nil { - // Try to use get / set instead -- this avoids every type having to be able to AssignTo type of self. - setSucceeded := false - if setter, ok := dst.(Value); ok { - err := setter.Set(src.Get()) - setSucceeded = err == nil - } - if !setSucceeded { - return assignToErr - } - } - - return nil -} - -func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) { - dstValue := reflect.ValueOf(dst) - if dstValue.Kind() != reflect.Ptr { - return false, nil - } - - if dstValue.IsNil() { - return false, nil - } - - dstElemValue := dstValue.Elem() - dstElemType := dstElemValue.Type() - - if dstElemType.Kind() != reflect.Struct { - return false, nil - } - - exportedFields := make([]int, 0, dstElemType.NumField()) - for i := 0; i < dstElemType.NumField(); i++ { - sf := dstElemType.Field(i) - if sf.PkgPath == "" { - exportedFields = append(exportedFields, i) - } - } - - if len(exportedFields) != len(src.valueTranscoders) { - return false, nil - } - - for i := range exportedFields { - err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface()) - if err != nil { - return true, fmt.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err) - } - } - - return true, nil -} - -func (ct *CompositeType) BinaryFormatSupported() bool { - for _, vt := range ct.valueTranscoders { - if !vt.BinaryFormatSupported() { - return false - } - } - return true -} - -func (ct *CompositeType) TextFormatSupported() bool { - for _, vt := range ct.valueTranscoders { - if !vt.TextFormatSupported() { - return false - } - } - return true -} - -func (ct *CompositeType) PreferredFormat() int16 { - if ct.BinaryFormatSupported() { - return BinaryFormatCode - } - return TextFormatCode -} - -func (dst *CompositeType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - if src == nil { - dst.setNil() - return nil - } - - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src CompositeType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} - -func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { - if !src.valid { - return nil, nil - } - - b := NewCompositeBinaryBuilder(ci, buf) - for i := range src.valueTranscoders { - b.AppendEncoder(src.fields[i].OID, src.valueTranscoders[i]) - } - - return b.Finish() -} - -// DecodeBinary implements BinaryDecoder interface. -// Opposite to Record, fields in a composite act as a "schema" -// and decoding fails if SQL value can't be assigned due to -// type mismatch -func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { - scanner := NewCompositeBinaryScanner(ci, buf) - - for _, f := range dst.valueTranscoders { - scanner.ScanDecoder(f) - } - - if scanner.Err() != nil { - return scanner.Err() - } - - dst.valid = true - - return nil -} - -func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { - scanner := NewCompositeTextScanner(ci, buf) - - for _, f := range dst.valueTranscoders { - scanner.ScanDecoder(f) - } - - if scanner.Err() != nil { - return scanner.Err() - } - - dst.valid = true - - return nil -} - -func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { - if !src.valid { - return nil, nil - } - - b := NewCompositeTextBuilder(ci, buf) - for _, f := range src.valueTranscoders { - b.AppendEncoder(f) - } - - return b.Finish() -} - -type CompositeBinaryScanner struct { - ci *ConnInfo - rp int - src []byte - - fieldCount int32 - fieldBytes []byte - fieldOID uint32 - err error -} - -// NewCompositeBinaryScanner a scanner over a binary encoded composite balue. -func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner { - rp := 0 - if len(src[rp:]) < 4 { - return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)} - } - - fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) - rp += 4 - - return &CompositeBinaryScanner{ - ci: ci, - rp: rp, - src: src, - fieldCount: fieldCount, - } -} - -// ScanDecoder calls Next and decodes the result with d. -func (cfs *CompositeBinaryScanner) ScanDecoder(d ResultDecoder) { - if cfs.err != nil { - return - } - - if cfs.Next() { - cfs.err = d.DecodeResult(cfs.ci, 0, BinaryFormatCode, cfs.fieldBytes) - } else { - cfs.err = errors.New("read past end of composite") - } -} - -// ScanDecoder calls Next and scans the result into d. -func (cfs *CompositeBinaryScanner) ScanValue(d interface{}) { - if cfs.err != nil { - return - } - - if cfs.Next() { - cfs.err = cfs.ci.Scan(cfs.OID(), BinaryFormatCode, cfs.Bytes(), d) - } else { - cfs.err = errors.New("read past end of composite") - } -} - -// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After -// Next returns false, the Err method can be called to check if any errors occurred. -func (cfs *CompositeBinaryScanner) Next() bool { - if cfs.err != nil { - return false - } - - if cfs.rp == len(cfs.src) { - return false - } - - if len(cfs.src[cfs.rp:]) < 8 { - cfs.err = fmt.Errorf("Record incomplete %v", cfs.src) - return false - } - cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:]) - cfs.rp += 4 - - fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:]))) - cfs.rp += 4 - - if fieldLen >= 0 { - if len(cfs.src[cfs.rp:]) < fieldLen { - cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src) - return false - } - cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen] - cfs.rp += fieldLen - } else { - cfs.fieldBytes = nil - } - - return true -} - -func (cfs *CompositeBinaryScanner) FieldCount() int { - return int(cfs.fieldCount) -} - -// Bytes returns the bytes of the field most recently read by Scan(). -func (cfs *CompositeBinaryScanner) Bytes() []byte { - return cfs.fieldBytes -} - -// OID returns the OID of the field most recently read by Scan(). -func (cfs *CompositeBinaryScanner) OID() uint32 { - return cfs.fieldOID -} - -// Err returns any error encountered by the scanner. -func (cfs *CompositeBinaryScanner) Err() error { - return cfs.err -} - -type CompositeTextScanner struct { - ci *ConnInfo - rp int - src []byte - - fieldBytes []byte - err error -} - -// NewCompositeTextScanner a scanner over a text encoded composite value. -func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner { - if len(src) < 2 { - return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)} - } - - if src[0] != '(' { - return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")} - } - - if src[len(src)-1] != ')' { - return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")} - } - - return &CompositeTextScanner{ - ci: ci, - rp: 1, - src: src, - } -} - -// ScanDecoder calls Next and decodes the result with d. -func (cfs *CompositeTextScanner) ScanDecoder(d ResultDecoder) { - if cfs.err != nil { - return - } - - if cfs.Next() { - cfs.err = d.DecodeResult(cfs.ci, 0, TextFormatCode, cfs.fieldBytes) - } else { - cfs.err = errors.New("read past end of composite") - } -} - -// ScanDecoder calls Next and scans the result into d. -func (cfs *CompositeTextScanner) ScanValue(d interface{}) { - if cfs.err != nil { - return - } - - if cfs.Next() { - cfs.err = cfs.ci.Scan(0, TextFormatCode, cfs.Bytes(), d) - } else { - cfs.err = errors.New("read past end of composite") - } -} - -// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After -// Next returns false, the Err method can be called to check if any errors occurred. -func (cfs *CompositeTextScanner) Next() bool { - if cfs.err != nil { - return false - } - - if cfs.rp == len(cfs.src) { - return false - } - - switch cfs.src[cfs.rp] { - case ',', ')': // null - cfs.rp++ - cfs.fieldBytes = nil - return true - case '"': // quoted value - cfs.rp++ - cfs.fieldBytes = make([]byte, 0, 16) - for { - ch := cfs.src[cfs.rp] - - if ch == '"' { - cfs.rp++ - if cfs.src[cfs.rp] == '"' { - cfs.fieldBytes = append(cfs.fieldBytes, '"') - cfs.rp++ - } else { - break - } - } else if ch == '\\' { - cfs.rp++ - cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp]) - cfs.rp++ - } else { - cfs.fieldBytes = append(cfs.fieldBytes, ch) - cfs.rp++ - } - } - cfs.rp++ - return true - default: // unquoted value - start := cfs.rp - for { - ch := cfs.src[cfs.rp] - if ch == ',' || ch == ')' { - break - } - cfs.rp++ - } - cfs.fieldBytes = cfs.src[start:cfs.rp] - cfs.rp++ - return true - } -} - -// Bytes returns the bytes of the field most recently read by Scan(). -func (cfs *CompositeTextScanner) Bytes() []byte { - return cfs.fieldBytes -} - -// Err returns any error encountered by the scanner. -func (cfs *CompositeTextScanner) Err() error { - return cfs.err -} - -type CompositeBinaryBuilder struct { - ci *ConnInfo - buf []byte - startIdx int - fieldCount uint32 - err error -} - -func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder { - startIdx := len(buf) - buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields - return &CompositeBinaryBuilder{ci: ci, buf: buf, startIdx: startIdx} -} - -func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { - if b.err != nil { - return - } - - dt, ok := b.ci.DataTypeForOID(oid) - if !ok { - b.err = fmt.Errorf("unknown data type for OID: %d", oid) - return - } - - err := dt.Value.Set(field) - if err != nil { - b.err = err - return - } - - paramEncoder, ok := dt.Value.(ParamEncoder) - if !ok { - b.err = fmt.Errorf("unable to encode for OID: %d", oid) - return - } - - b.AppendEncoder(oid, paramEncoder) -} - -func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field ParamEncoder) { - if b.err != nil { - return - } - - b.buf = pgio.AppendUint32(b.buf, oid) - lengthPos := len(b.buf) - b.buf = pgio.AppendInt32(b.buf, -1) - fieldBuf, err := field.EncodeParam(b.ci, oid, BinaryFormatCode, b.buf) - if err != nil { - b.err = err - return - } - if fieldBuf != nil { - binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf))) - b.buf = fieldBuf - } - - b.fieldCount++ -} - -func (b *CompositeBinaryBuilder) Finish() ([]byte, error) { - if b.err != nil { - return nil, b.err - } - - binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount) - return b.buf, nil -} - -type CompositeTextBuilder struct { - ci *ConnInfo - buf []byte - startIdx int - fieldCount uint32 - err error - fieldBuf [32]byte -} - -func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder { - buf = append(buf, '(') // allocate room for number of fields - return &CompositeTextBuilder{ci: ci, buf: buf} -} - -func (b *CompositeTextBuilder) AppendValue(field interface{}) { - if b.err != nil { - return - } - - if field == nil { - b.buf = append(b.buf, ',') - return - } - - dt, ok := b.ci.DataTypeForValue(field) - if !ok { - b.err = fmt.Errorf("unknown data type for field: %v", field) - return - } - - err := dt.Value.Set(field) - if err != nil { - b.err = err - return - } - - paramEncoder, ok := dt.Value.(ParamEncoder) - if !ok { - b.err = fmt.Errorf("unable to encode for value: %v", field) - return - } - - b.AppendEncoder(paramEncoder) -} - -func (b *CompositeTextBuilder) AppendEncoder(field ParamEncoder) { - if b.err != nil { - return - } - - fieldBuf, err := field.EncodeParam(b.ci, 0, TextFormatCode, b.fieldBuf[0:0]) - if err != nil { - b.err = err - return - } - if fieldBuf != nil { - b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...) - } - - b.buf = append(b.buf, ',') -} - -func (b *CompositeTextBuilder) Finish() ([]byte, error) { - if b.err != nil { - return nil, b.err - } - - b.buf[len(b.buf)-1] = ')' - return b.buf, nil -} - -var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) - -func quoteCompositeField(src string) string { - return `"` + quoteCompositeReplacer.Replace(src) + `"` -} - -func quoteCompositeFieldIfNeeded(src string) string { - if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) { - return quoteCompositeField(src) - } - return src -} diff --git a/pgtype/composite_type_test.go b/pgtype/composite_type_test.go deleted file mode 100644 index a41ad0f4..00000000 --- a/pgtype/composite_type_test.go +++ /dev/null @@ -1,320 +0,0 @@ -package pgtype_test - -import ( - "context" - "fmt" - "os" - "testing" - - pgx "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCompositeTypeSetAndGet(t *testing.T) { - ci := pgtype.NewConnInfo() - ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ - {"a", pgtype.TextOID}, - {"b", pgtype.Int4OID}, - }, ci) - require.NoError(t, err) - assert.Equal(t, nil, ct.Get()) - - nilTests := []struct { - src interface{} - }{ - {nil}, // nil interface - {(*[]interface{})(nil)}, // typed nil - } - - for i, tt := range nilTests { - err := ct.Set(tt.src) - assert.NoErrorf(t, err, "%d", i) - assert.Equal(t, nil, ct.Get()) - } - - compatibleValuesTests := []struct { - src []interface{} - expected map[string]interface{} - }{ - { - src: []interface{}{"foo", int32(42)}, - expected: map[string]interface{}{"a": "foo", "b": int32(42)}, - }, - { - src: []interface{}{nil, nil}, - expected: map[string]interface{}{"a": nil, "b": nil}, - }, - { - src: []interface{}{&pgtype.Text{String: "hi", Valid: true}, &pgtype.Int4{Int: 7, Valid: true}}, - expected: map[string]interface{}{"a": "hi", "b": int32(7)}, - }, - } - - for i, tt := range compatibleValuesTests { - err := ct.Set(tt.src) - assert.NoErrorf(t, err, "%d", i) - assert.EqualValues(t, tt.expected, ct.Get()) - } -} - -func TestCompositeTypeAssignTo(t *testing.T) { - ci := pgtype.NewConnInfo() - ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ - {"a", pgtype.TextOID}, - {"b", pgtype.Int4OID}, - }, ci) - require.NoError(t, err) - - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - var a string - var b int32 - - err = ct.AssignTo([]interface{}{&a, &b}) - assert.NoError(t, err) - - assert.Equal(t, "foo", a) - assert.Equal(t, int32(42), b) - } - - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - var a pgtype.Text - var b pgtype.Int4 - - err = ct.AssignTo([]interface{}{&a, &b}) - assert.NoError(t, err) - - assert.Equal(t, pgtype.Text{String: "foo", Valid: true}, a) - assert.Equal(t, pgtype.Int4{Int: 42, Valid: true}, b) - } - - // Allow nil destination component as no-op - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - var b int32 - - err = ct.AssignTo([]interface{}{nil, &b}) - assert.NoError(t, err) - - assert.Equal(t, int32(42), b) - } - - // *[]interface{} dest when null - { - err := ct.Set(nil) - assert.NoError(t, err) - - var a pgtype.Text - var b pgtype.Int4 - dst := []interface{}{&a, &b} - - err = ct.AssignTo(&dst) - assert.NoError(t, err) - - assert.Nil(t, dst) - } - - // *[]interface{} dest when not null - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - var a pgtype.Text - var b pgtype.Int4 - dst := []interface{}{&a, &b} - - err = ct.AssignTo(&dst) - assert.NoError(t, err) - - assert.NotNil(t, dst) - assert.Equal(t, pgtype.Text{String: "foo", Valid: true}, a) - assert.Equal(t, pgtype.Int4{Int: 42, Valid: true}, b) - } - - // Struct fields positionally via reflection - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - s := struct { - A string - B int32 - }{} - - err = ct.AssignTo(&s) - if assert.NoError(t, err) { - assert.Equal(t, "foo", s.A) - assert.Equal(t, int32(42), s.B) - } - } -} - -func TestCompositeTypeTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - _, err := conn.Exec(context.Background(), `drop type if exists ct_test; - -create type ct_test as ( - a text, - b int4 -);`) - require.NoError(t, err) - defer conn.Exec(context.Background(), "drop type ct_test") - - var oid uint32 - err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid) - require.NoError(t, err) - - defer conn.Exec(context.Background(), "drop type ct_test") - - ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{ - {"a", pgtype.TextOID}, - {"b", pgtype.Int4OID}, - }, conn.ConnInfo()) - require.NoError(t, err) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) - - // Use simple protocol to force text or binary encoding - simpleProtocols := []bool{true, false} - - var a string - var b int32 - - for _, simpleProtocol := range simpleProtocols { - err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{"hi", int32(42)}, - ).Scan( - []interface{}{&a, &b}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol) - } - } -} - -// https://github.com/jackc/pgx/issues/874 -func TestCompositeTypeTextDecodeNested(t *testing.T) { - newCompositeType := func(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name} - } - - rowType, err := pgtype.NewCompositeTypeValues(name, fields, vals) - require.NoError(t, err) - return rowType - } - - dimensionsType := func() pgtype.ValueTranscoder { - return newCompositeType( - "dimensions", - []string{"width", "height"}, - &pgtype.Int4{}, - &pgtype.Int4{}, - ) - } - productImageType := func() pgtype.ValueTranscoder { - return newCompositeType( - "product_image_type", - []string{"source", "dimensions"}, - &pgtype.Text{}, - dimensionsType(), - ) - } - productImageSetType := newCompositeType( - "product_image_set_type", - []string{"name", "orig_image", "images"}, - &pgtype.Text{}, - productImageType(), - pgtype.NewArrayType("product_image", 0, func() pgtype.ValueTranscoder { - return productImageType() - }), - ) - - err := productImageSetType.DecodeText(nil, []byte(`(name,"(img1,""(11,11)"")","{""(img2,\\""(22,22)\\"")"",""(img3,\\""(33,33)\\"")""}")`)) - require.NoError(t, err) -} - -func Example_composite() { - conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - fmt.Println(err) - return - } - - defer conn.Close(context.Background()) - _, err = conn.Exec(context.Background(), `drop type if exists mytype;`) - if err != nil { - fmt.Println(err) - return - } - - _, err = conn.Exec(context.Background(), `create type mytype as ( - a int4, - b text -);`) - if err != nil { - fmt.Println(err) - return - } - defer conn.Exec(context.Background(), "drop type mytype") - - var oid uint32 - err = conn.QueryRow(context.Background(), `select 'mytype'::regtype::oid`).Scan(&oid) - if err != nil { - fmt.Println(err) - return - } - - ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{ - {"a", pgtype.Int4OID}, - {"b", pgtype.TextOID}, - }, conn.ConnInfo()) - if err != nil { - fmt.Println(err) - return - } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) - - var a int - var b *string - - err = conn.QueryRow(context.Background(), "select $1::mytype", []interface{}{2, "bar"}).Scan([]interface{}{&a, &b}) - if err != nil { - fmt.Println(err) - return - } - - fmt.Printf("First: a=%d b=%s\n", a, *b) - - err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan([]interface{}{&a, &b}) - if err != nil { - fmt.Println(err) - return - } - - fmt.Printf("Second: a=%d b=%v\n", a, b) - - scanTarget := []interface{}{&a, &b} - err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(&scanTarget) - E(err) - - fmt.Printf("Third: isNull=%v\n", scanTarget == nil) - - // Output: - // First: a=2 b=bar - // Second: a=1 b= - // Third: isNull=true -} diff --git a/pgtype/custom_composite_test.go b/pgtype/custom_composite_test.go deleted file mode 100644 index e5f2166e..00000000 --- a/pgtype/custom_composite_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package pgtype_test - -import ( - "context" - "errors" - "fmt" - "os" - - pgx "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" -) - -type MyType struct { - a int32 // NULL will cause decoding error - b *string // there can be NULL in this position in SQL -} - -func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs") - } - - if err := (pgtype.CompositeFields{&dst.a, &dst.b}).DecodeBinary(ci, src); err != nil { - return err - } - - return nil -} - -func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { - a := pgtype.Int4{src.a, true} - var b pgtype.Text - if src.b != nil { - b = pgtype.Text{*src.b, true} - } else { - b = pgtype.Text{} - } - - return (pgtype.CompositeFields{&a, &b}).EncodeBinary(ci, buf) -} - -func ptrS(s string) *string { - return &s -} - -func E(err error) { - if err != nil { - panic(err) - } -} - -// ExampleCustomCompositeTypes demonstrates how support for custom types mappable to SQL -// composites can be added. -func Example_customCompositeTypes() { - conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - E(err) - - defer conn.Close(context.Background()) - _, err = conn.Exec(context.Background(), `drop type if exists mytype; - -create type mytype as ( - a int4, - b text -);`) - E(err) - defer conn.Exec(context.Background(), "drop type mytype") - - var result *MyType - - // Demonstrates both passing and reading back composite values - err = conn.QueryRow(context.Background(), "select $1::mytype", - pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}). - Scan(&result) - E(err) - - fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b) - - // Because we scan into &*MyType, NULLs are handled generically by assigning nil to result - err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result) - E(err) - - fmt.Printf("Second row: %v\n", result) - - // Output: - // First row: a=1 b=foo - // Second row: -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 50ca29c3..6df3a582 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -343,7 +343,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID}) ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID}) ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID}) - ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) + // ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID}) ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID}) ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID}) diff --git a/pgtype/record.go b/pgtype/record.go deleted file mode 100644 index 5bb4d701..00000000 --- a/pgtype/record.go +++ /dev/null @@ -1,141 +0,0 @@ -package pgtype - -import ( - "fmt" - "reflect" -) - -// Record is the generic PostgreSQL record type such as is created with the -// "row" function. Record only implements BinaryEncoder and Value. The text -// format output format from PostgreSQL does not include type information and is -// therefore impossible to decode. No encoders are implemented because -// PostgreSQL does not support input of generic records. -type Record struct { - Fields []Value - Valid bool -} - -func (dst *Record) Set(src interface{}) error { - if src == nil { - *dst = Record{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case []Value: - *dst = Record{Fields: value, Valid: true} - default: - return fmt.Errorf("cannot convert %v to Record", src) - } - - return nil -} - -func (dst Record) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.Fields -} - -func (src *Record) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *[]Value: - *v = make([]Value, len(src.Fields)) - copy(*v, src.Fields) - return nil - case *[]interface{}: - *v = make([]interface{}, len(src.Fields)) - for i := range *v { - (*v)[i] = src.Fields[i].Get() - } - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDecoder, error) { - var binaryDecoder BinaryDecoder - - if dt, ok := ci.DataTypeForOID(fieldOID); ok { - binaryDecoder, _ = dt.Value.(BinaryDecoder) - } else { - return nil, fmt.Errorf("unknown oid while decoding record: %v", fieldOID) - } - - if binaryDecoder == nil { - return nil, fmt.Errorf("no binary decoder registered for: %v", fieldOID) - } - - // Duplicate struct to scan into - binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder) - *v = binaryDecoder.(Value) - return binaryDecoder, nil -} - -func (Record) BinaryFormatSupported() bool { - return true -} - -func (Record) TextFormatSupported() bool { - return false -} - -func (Record) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Record) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return fmt.Errorf("text format is not supported") - } - return fmt.Errorf("unknown format code %d", format) -} - -func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Record{} - return nil - } - - scanner := NewCompositeBinaryScanner(ci, src) - - fields := make([]Value, scanner.FieldCount()) - - for i := 0; scanner.Next(); i++ { - binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i]) - if err != nil { - return err - } - - if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil { - return err - } - } - - if scanner.Err() != nil { - return scanner.Err() - } - - *dst = Record{Fields: fields, Valid: true} - - return nil -} diff --git a/pgtype/record_test.go b/pgtype/record_test.go deleted file mode 100644 index 921f0975..00000000 --- a/pgtype/record_test.go +++ /dev/null @@ -1,184 +0,0 @@ -package pgtype_test - -import ( - "context" - "fmt" - "reflect" - "testing" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -var recordTests = []struct { - sql string - expected pgtype.Record -}{ - { - sql: `select row()`, - expected: pgtype.Record{ - Fields: []pgtype.Value{}, - Valid: true, - }, - }, - { - sql: `select row('foo'::text, 42::int4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Valid: true}, - &pgtype.Int4{Int: 42, Valid: true}, - }, - Valid: true, - }, - }, - { - sql: `select row(100.0::float4, 1.09::float4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Float4{Float: 100, Valid: true}, - &pgtype.Float4{Float: 1.09, Valid: true}, - }, - Valid: true, - }, - }, - { - sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Valid: true}, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {}, - {Int: 4, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, - Valid: true, - }, - &pgtype.Int4{Int: 42, Valid: true}, - }, - Valid: true, - }, - }, - { - sql: `select row(null)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Unknown{}, - }, - Valid: true, - }, - }, - { - sql: `select null::record`, - expected: pgtype.Record{}, - }, -} - -func TestRecordTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - for i, tt := range recordTests { - psName := fmt.Sprintf("test%d", i) - _, err := conn.Prepare(context.Background(), psName, tt.sql) - if err != nil { - t.Fatal(err) - } - - t.Run(tt.sql, func(t *testing.T) { - var result pgtype.Record - if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { - t.Errorf("%v", err) - return - } - - if !reflect.DeepEqual(tt.expected, result) { - t.Errorf("expected %#v, got %#v", tt.expected, result) - } - }) - - } -} - -func TestRecordWithUnknownOID(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - _, err := conn.Exec(context.Background(), `drop type if exists floatrange; - -create type floatrange as range ( - subtype = float8, - subtype_diff = float8mi -);`) - if err != nil { - t.Fatal(err) - } - defer conn.Exec(context.Background(), "drop type floatrange") - - var result pgtype.Record - err = conn.QueryRow(context.Background(), "select row('foo'::text, floatrange(1, 10), 'bar'::text)").Scan(&result) - if err == nil { - t.Errorf("expected error but none") - } -} - -func TestRecordAssignTo(t *testing.T) { - var valueSlice []pgtype.Value - var interfaceSlice []interface{} - - simpleTests := []struct { - src pgtype.Record - dst interface{} - expected interface{} - }{ - { - src: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Valid: true}, - &pgtype.Int4{Int: 42, Valid: true}, - }, - Valid: true, - }, - dst: &valueSlice, - expected: []pgtype.Value{ - &pgtype.Text{String: "foo", Valid: true}, - &pgtype.Int4{Int: 42, Valid: true}, - }, - }, - { - src: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Valid: true}, - &pgtype.Int4{Int: 42, Valid: true}, - }, - Valid: true, - }, - dst: &interfaceSlice, - expected: []interface{}{"foo", int32(42)}, - }, - { - src: pgtype.Record{}, - dst: &valueSlice, - expected: (([]pgtype.Value)(nil)), - }, - { - src: pgtype.Record{}, - dst: &interfaceSlice, - expected: (([]interface{})(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} From 0c0e28a70a1cebbb43a156ceca1bc7983a7512ef Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Jan 2022 17:16:03 -0600 Subject: [PATCH 0798/1158] Convert int4 and int8 to new system Note: purposely disabled some tests and composite support that needs to be restored later in v5 development. --- Rakefile | 2 +- bench_test.go | 67 -- conn_test.go | 2 + pgtype/array_test.go | 12 - pgtype/int.go | 1028 +++++++++++++++++++++- pgtype/int.go.erb | 10 +- pgtype/int4.go | 292 ------ pgtype/int4_array.go | 896 ------------------- pgtype/int4_array_test.go | 356 -------- pgtype/int4_test.go | 186 ---- pgtype/int8.go | 278 ------ pgtype/int8_array.go | 896 ------------------- pgtype/int8_array_test.go | 349 -------- pgtype/int8_test.go | 187 ---- pgtype/int_test.go | 165 ++++ pgtype/int_test.go.erb | 7 +- pgtype/integration_benchmark_test.go | 159 +--- pgtype/integration_benchmark_test.go.erb | 28 - pgtype/pgtype.go | 109 +-- pgtype/pgtype_test.go | 4 +- pgtype/pgxtype/pgxtype.go | 74 +- pgtype/typed_array_gen.sh | 1 - pgtype/zeronull/int.go | 148 ++++ pgtype/zeronull/int.go.erb | 58 ++ pgtype/zeronull/int2.go | 55 -- pgtype/zeronull/int2_test.go | 23 - pgtype/zeronull/int4.go | 90 -- pgtype/zeronull/int4_test.go | 23 - pgtype/zeronull/int8.go | 90 -- pgtype/zeronull/int8_test.go | 23 - pgtype/zeronull/int_test.go | 54 ++ pgtype/zeronull/int_test.go.erb | 26 + pgtype/zzz.int4.go | 35 - pgtype/zzz.int8.go | 35 - query_test.go | 65 +- values_test.go | 2 + 36 files changed, 1568 insertions(+), 4267 deletions(-) delete mode 100644 pgtype/int4.go delete mode 100644 pgtype/int4_array.go delete mode 100644 pgtype/int4_array_test.go delete mode 100644 pgtype/int4_test.go delete mode 100644 pgtype/int8.go delete mode 100644 pgtype/int8_array.go delete mode 100644 pgtype/int8_array_test.go delete mode 100644 pgtype/int8_test.go create mode 100644 pgtype/zeronull/int.go create mode 100644 pgtype/zeronull/int.go.erb delete mode 100644 pgtype/zeronull/int2.go delete mode 100644 pgtype/zeronull/int2_test.go delete mode 100644 pgtype/zeronull/int4.go delete mode 100644 pgtype/zeronull/int4_test.go delete mode 100644 pgtype/zeronull/int8.go delete mode 100644 pgtype/zeronull/int8_test.go create mode 100644 pgtype/zeronull/int_test.go create mode 100644 pgtype/zeronull/int_test.go.erb delete mode 100644 pgtype/zzz.int4.go delete mode 100644 pgtype/zzz.int8.go diff --git a/Rakefile b/Rakefile index 7076d2a0..4579034d 100644 --- a/Rakefile +++ b/Rakefile @@ -7,4 +7,4 @@ rule '.go' => '.go.erb' do |task| end desc "Generate code" -task generate: ["pgtype/int.go", "pgtype/int_test.go"] +task generate: ["pgtype/int.go", "pgtype/int_test.go", "pgtype/integration_benchmark_test.go", "pgtype/zeronull/int.go", "pgtype/zeronull/int_test.go"] diff --git a/bench_test.go b/bench_test.go index 06cfd0c4..9b14b7d3 100644 --- a/bench_test.go +++ b/bench_test.go @@ -1105,73 +1105,6 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) { } } -func BenchmarkSelectRowsExplicitDecoding(b *testing.B) { - conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(b, conn) - - rowCounts := getSelectRowsCounts(b) - - for _, rowCount := range rowCounts { - b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { - br := &BenchRowDecoder{} - for i := 0; i < b.N; i++ { - rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) - if err != nil { - b.Fatal(err) - } - - for rows.Next() { - rawValues := rows.RawValues() - - err = br.ID.DecodeBinary(conn.ConnInfo(), rawValues[0]) - if err != nil { - b.Fatal(err) - } - - err = br.FirstName.DecodeText(conn.ConnInfo(), rawValues[1]) - if err != nil { - b.Fatal(err) - } - - err = br.LastName.DecodeText(conn.ConnInfo(), rawValues[2]) - if err != nil { - b.Fatal(err) - } - - err = br.Sex.DecodeText(conn.ConnInfo(), rawValues[3]) - if err != nil { - b.Fatal(err) - } - - err = br.BirthDate.DecodeBinary(conn.ConnInfo(), rawValues[4]) - if err != nil { - b.Fatal(err) - } - - err = br.Weight.DecodeBinary(conn.ConnInfo(), rawValues[5]) - if err != nil { - b.Fatal(err) - } - - err = br.Height.DecodeBinary(conn.ConnInfo(), rawValues[6]) - if err != nil { - b.Fatal(err) - } - - err = br.UpdateTime.DecodeBinary(conn.ConnInfo(), rawValues[7]) - if err != nil { - b.Fatal(err) - } - } - - if rows.Err() != nil { - b.Fatal(rows.Err()) - } - } - }) - } -} - func BenchmarkSelectRowsPgConnExecText(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(b, conn) diff --git a/conn_test.go b/conn_test.go index 857fd828..55297e26 100644 --- a/conn_test.go +++ b/conn_test.go @@ -879,6 +879,8 @@ func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) { } func TestDomainType(t *testing.T) { + t.Skip("TODO - unskip later in v5") + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") diff --git a/pgtype/array_test.go b/pgtype/array_test.go index 77700ad6..82f5f229 100644 --- a/pgtype/array_test.go +++ b/pgtype/array_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/stretchr/testify/require" ) func TestParseUntypedTextArray(t *testing.T) { @@ -122,14 +121,3 @@ func TestParseUntypedTextArray(t *testing.T) { } } } - -// https://github.com/jackc/pgx/issues/881 -func TestArrayAssignToEmptyToNonSlice(t *testing.T) { - var a pgtype.Int4Array - err := a.Set([]int32{}) - require.NoError(t, err) - - var iface interface{} - err = a.AssignTo(&iface) - require.EqualError(t, err, "cannot assign *pgtype.Int4Array to *interface {}") -} diff --git a/pgtype/int.go b/pgtype/int.go index 9da79a89..c8ec7509 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -358,7 +358,7 @@ func (scanPlanBinaryInt2ToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16 return ErrScanTargetTypeChanged } - *p = int32(binary.BigEndian.Uint16(src)) + *p = int32(int16(binary.BigEndian.Uint16(src))) return nil } @@ -405,7 +405,7 @@ func (scanPlanBinaryInt2ToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16 return ErrScanTargetTypeChanged } - *p = int64(binary.BigEndian.Uint16(src)) + *p = int64(int16(binary.BigEndian.Uint16(src))) return nil } @@ -452,7 +452,7 @@ func (scanPlanBinaryInt2ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, return ErrScanTargetTypeChanged } - *p = int(binary.BigEndian.Uint16(src)) + *p = int(int16(binary.BigEndian.Uint16(src))) return nil } @@ -473,7 +473,7 @@ func (scanPlanBinaryInt2ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, return ErrScanTargetTypeChanged } - n := int64(binary.BigEndian.Uint16(src)) + n := int64(int16(binary.BigEndian.Uint16(src))) if n < 0 { return fmt.Errorf("%d is less than minimum value for uint", n) } @@ -504,6 +504,1026 @@ func (scanPlanBinaryInt2ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod return s.ScanInt64(n, true) } +type Int4 struct { + Int int32 + Valid bool +} + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int4) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = Int4{} + return nil + } + + if n < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n) + } + if n > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n) + } + *dst = Int4{Int: int32(n), Valid: true} + + return nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src interface{}) error { + if src == nil { + *dst = Int4{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + var err error + n, err = strconv.ParseInt(src, 10, 32) + if err != nil { + return err + } + case []byte: + var err error + n, err = strconv.ParseInt(string(src), 10, 32) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n) + } + if n > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n) + } + *dst = Int4{Int: int32(n), Valid: true} + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int), nil +} + +func (src Int4) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil +} + +func (dst *Int4) UnmarshalJSON(b []byte) error { + var n *int32 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int4{} + } else { + *dst = Int4{Int: *n, Valid: true} + } + + return nil +} + +type Int4Codec struct{} + +func (Int4Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int4Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int4Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to int4: %v", value, err) + } + if !valid { + return nil, nil + } + + if n > math.MaxInt32 { + return nil, fmt.Errorf("%d is greater than maximum value for int4", n) + } + if n < math.MinInt32 { + return nil, fmt.Errorf("%d is less than minimum value for int4", n) + } + + switch format { + case BinaryFormatCode: + return pgio.AppendInt32(buf, int32(n)), nil + case TextFormatCode: + return append(buf, strconv.FormatInt(n, 10)...), nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } +} + +func (Int4Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *int8: + return scanPlanBinaryInt4ToInt8{} + case *int16: + return scanPlanBinaryInt4ToInt16{} + case *int32: + return scanPlanBinaryInt4ToInt32{} + case *int64: + return scanPlanBinaryInt4ToInt64{} + case *int: + return scanPlanBinaryInt4ToInt{} + case *uint8: + return scanPlanBinaryInt4ToUint8{} + case *uint16: + return scanPlanBinaryInt4ToUint16{} + case *uint32: + return scanPlanBinaryInt4ToUint32{} + case *uint64: + return scanPlanBinaryInt4ToUint64{} + case *uint: + return scanPlanBinaryInt4ToUint{} + case Int64Scanner: + return scanPlanBinaryInt4ToInt64Scanner{} + } + case TextFormatCode: + switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} + case *int16: + return scanPlanTextAnyToInt16{} + case *int32: + return scanPlanTextAnyToInt32{} + case *int64: + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +func (c Int4Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n int64 + scanPlan := c.PlanScan(ci, oid, format, &n, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +func (c Int4Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var n int32 + scanPlan := c.PlanScan(ci, oid, format, &n, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryInt4ToInt8 struct{} + +func (scanPlanBinaryInt4ToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", n) + } else if n > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", n) + } + + *p = int8(n) + + return nil +} + +type scanPlanBinaryInt4ToUint8 struct{} + +func (scanPlanBinaryInt4ToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", n) + } + + if n > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", n) + } + + *p = uint8(n) + + return nil +} + +type scanPlanBinaryInt4ToInt16 struct{} + +func (scanPlanBinaryInt4ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", n) + } else if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", n) + } + + *p = int16(n) + + return nil +} + +type scanPlanBinaryInt4ToUint16 struct{} + +func (scanPlanBinaryInt4ToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", n) + } + + if n > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", n) + } + + *p = uint16(n) + + return nil +} + +type scanPlanBinaryInt4ToInt32 struct{} + +func (scanPlanBinaryInt4ToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int32(binary.BigEndian.Uint32(src)) + + return nil +} + +type scanPlanBinaryInt4ToUint32 struct{} + +func (scanPlanBinaryInt4ToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", n) + } + + *p = uint32(n) + + return nil +} + +type scanPlanBinaryInt4ToInt64 struct{} + +func (scanPlanBinaryInt4ToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int64(int32(binary.BigEndian.Uint32(src))) + + return nil +} + +type scanPlanBinaryInt4ToUint64 struct{} + +func (scanPlanBinaryInt4ToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", n) + } + + *p = uint64(n) + + return nil +} + +type scanPlanBinaryInt4ToInt struct{} + +func (scanPlanBinaryInt4ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int(int32(binary.BigEndian.Uint32(src))) + + return nil +} + +type scanPlanBinaryInt4ToUint struct{} + +func (scanPlanBinaryInt4ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(int32(binary.BigEndian.Uint32(src))) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint", n) + } + + *p = uint(n) + + return nil +} + +type scanPlanBinaryInt4ToInt64Scanner struct{} + +func (scanPlanBinaryInt4ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(0, false) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint32(src)) + + return s.ScanInt64(n, true) +} + +type Int8 struct { + Int int64 + Valid bool +} + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int8) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = Int8{} + return nil + } + + if n < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n) + } + if n > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n) + } + *dst = Int8{Int: int64(n), Valid: true} + + return nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src interface{}) error { + if src == nil { + *dst = Int8{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + var err error + n, err = strconv.ParseInt(src, 10, 64) + if err != nil { + return err + } + case []byte: + var err error + n, err = strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n) + } + if n > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n) + } + *dst = Int8{Int: int64(n), Valid: true} + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int), nil +} + +func (src Int8) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil +} + +func (dst *Int8) UnmarshalJSON(b []byte) error { + var n *int64 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int8{} + } else { + *dst = Int8{Int: *n, Valid: true} + } + + return nil +} + +type Int8Codec struct{} + +func (Int8Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int8Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int8Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to int8: %v", value, err) + } + if !valid { + return nil, nil + } + + if n > math.MaxInt64 { + return nil, fmt.Errorf("%d is greater than maximum value for int8", n) + } + if n < math.MinInt64 { + return nil, fmt.Errorf("%d is less than minimum value for int8", n) + } + + switch format { + case BinaryFormatCode: + return pgio.AppendInt64(buf, int64(n)), nil + case TextFormatCode: + return append(buf, strconv.FormatInt(n, 10)...), nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } +} + +func (Int8Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *int8: + return scanPlanBinaryInt8ToInt8{} + case *int16: + return scanPlanBinaryInt8ToInt16{} + case *int32: + return scanPlanBinaryInt8ToInt32{} + case *int64: + return scanPlanBinaryInt8ToInt64{} + case *int: + return scanPlanBinaryInt8ToInt{} + case *uint8: + return scanPlanBinaryInt8ToUint8{} + case *uint16: + return scanPlanBinaryInt8ToUint16{} + case *uint32: + return scanPlanBinaryInt8ToUint32{} + case *uint64: + return scanPlanBinaryInt8ToUint64{} + case *uint: + return scanPlanBinaryInt8ToUint{} + case Int64Scanner: + return scanPlanBinaryInt8ToInt64Scanner{} + } + case TextFormatCode: + switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} + case *int16: + return scanPlanTextAnyToInt16{} + case *int32: + return scanPlanTextAnyToInt32{} + case *int64: + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +func (c Int8Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n int64 + scanPlan := c.PlanScan(ci, oid, format, &n, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +func (c Int8Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var n int64 + scanPlan := c.PlanScan(ci, oid, format, &n, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryInt8ToInt8 struct{} + +func (scanPlanBinaryInt8ToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", n) + } else if n > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", n) + } + + *p = int8(n) + + return nil +} + +type scanPlanBinaryInt8ToUint8 struct{} + +func (scanPlanBinaryInt8ToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", n) + } + + if n > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", n) + } + + *p = uint8(n) + + return nil +} + +type scanPlanBinaryInt8ToInt16 struct{} + +func (scanPlanBinaryInt8ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", n) + } else if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", n) + } + + *p = int16(n) + + return nil +} + +type scanPlanBinaryInt8ToUint16 struct{} + +func (scanPlanBinaryInt8ToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", n) + } + + if n > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", n) + } + + *p = uint16(n) + + return nil +} + +type scanPlanBinaryInt8ToInt32 struct{} + +func (scanPlanBinaryInt8ToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", n) + } else if n > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", n) + } + + *p = int32(n) + + return nil +} + +type scanPlanBinaryInt8ToUint32 struct{} + +func (scanPlanBinaryInt8ToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", n) + } + + if n > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", n) + } + + *p = uint32(n) + + return nil +} + +type scanPlanBinaryInt8ToInt64 struct{} + +func (scanPlanBinaryInt8ToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int64(binary.BigEndian.Uint64(src)) + + return nil +} + +type scanPlanBinaryInt8ToUint64 struct{} + +func (scanPlanBinaryInt8ToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", n) + } + + *p = uint64(n) + + return nil +} + +type scanPlanBinaryInt8ToInt struct{} + +func (scanPlanBinaryInt8ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < math.MinInt { + return fmt.Errorf("%d is less than minimum value for int", n) + } else if n > math.MaxInt { + return fmt.Errorf("%d is greater than maximum value for int", n) + } + + *p = int(n) + + return nil +} + +type scanPlanBinaryInt8ToUint struct{} + +func (scanPlanBinaryInt8ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(int64(binary.BigEndian.Uint64(src))) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint", n) + } + + if uint64(n) > math.MaxUint { + return fmt.Errorf("%d is greater than maximum value for uint", n) + } + + *p = uint(n) + + return nil +} + +type scanPlanBinaryInt8ToInt64Scanner struct{} + +func (scanPlanBinaryInt8ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(0, false) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + + return s.ScanInt64(n, true) +} + type scanPlanTextAnyToInt8 struct{} func (scanPlanTextAnyToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 5c8e44fa..99659d4c 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -15,7 +15,7 @@ type Int64Scanner interface { } -<% [2].each do |pg_byte_size| %> +<% [2, 4, 8].each do |pg_byte_size| %> <% pg_bit_size = pg_byte_size * 8 %> type Int<%= pg_byte_size %> struct { Int int<%= pg_bit_size %> @@ -265,8 +265,10 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %>) Scan(ci *Con } *p = int<%= dst_bit_size %>(n) - <% else %> + <% elsif dst_bit_size == pg_bit_size %> *p = int<%= dst_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + <% else %> + *p = int<%= dst_bit_size %>(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) <% end %> return nil @@ -330,7 +332,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt) Scan(ci *ConnInfo, oid uint32, *p = int(n) <% else %> - *p = int(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + *p = int(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) <% end %> return nil @@ -352,7 +354,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToUint) Scan(ci *ConnInfo, oid uint32, return ErrScanTargetTypeChanged } - n := int64(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + n := int64(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) if n < 0 { return fmt.Errorf("%d is less than minimum value for uint", n) } diff --git a/pgtype/int4.go b/pgtype/int4.go deleted file mode 100644 index 6f1e61f3..00000000 --- a/pgtype/int4.go +++ /dev/null @@ -1,292 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "encoding/json" - "fmt" - "math" - "strconv" - - "github.com/jackc/pgio" -) - -type Int4 struct { - Int int32 - Valid bool -} - -func (dst *Int4) Set(src interface{}) error { - if src == nil { - *dst = Int4{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case int8: - *dst = Int4{Int: int32(value), Valid: true} - case uint8: - *dst = Int4{Int: int32(value), Valid: true} - case int16: - *dst = Int4{Int: int32(value), Valid: true} - case uint16: - *dst = Int4{Int: int32(value), Valid: true} - case int32: - *dst = Int4{Int: int32(value), Valid: true} - case uint32: - if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) - } - *dst = Int4{Int: int32(value), Valid: true} - case int64: - if value < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) - } - if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) - } - *dst = Int4{Int: int32(value), Valid: true} - case uint64: - if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) - } - *dst = Int4{Int: int32(value), Valid: true} - case int: - if value < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) - } - if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) - } - *dst = Int4{Int: int32(value), Valid: true} - case uint: - if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) - } - *dst = Int4{Int: int32(value), Valid: true} - case string: - num, err := strconv.ParseInt(value, 10, 32) - if err != nil { - return err - } - *dst = Int4{Int: int32(num), Valid: true} - case float32: - if value > math.MaxInt32 { - return fmt.Errorf("%f is greater than maximum value for Int4", value) - } - *dst = Int4{Int: int32(value), Valid: true} - case float64: - if value > math.MaxInt32 { - return fmt.Errorf("%f is greater than maximum value for Int4", value) - } - *dst = Int4{Int: int32(value), Valid: true} - case *int8: - if value == nil { - *dst = Int4{} - } else { - return dst.Set(*value) - } - case *uint8: - if value == nil { - *dst = Int4{} - } else { - return dst.Set(*value) - } - case *int16: - if value == nil { - *dst = Int4{} - } else { - return dst.Set(*value) - } - case *uint16: - if value == nil { - *dst = Int4{} - } else { - return dst.Set(*value) - } - case *int32: - if value == nil { - *dst = Int4{} - } else { - return dst.Set(*value) - } - case *uint32: - if value == nil { - *dst = Int4{} - } else { - return dst.Set(*value) - } - case *int64: - if value == nil { - *dst = Int4{} - } else { - return dst.Set(*value) - } - case *uint64: - if value == nil { - *dst = Int4{} - } else { - return dst.Set(*value) - } - case *int: - if value == nil { - *dst = Int4{} - } else { - return dst.Set(*value) - } - case *uint: - if value == nil { - *dst = Int4{} - } else { - return dst.Set(*value) - } - case *string: - if value == nil { - *dst = Int4{} - } else { - return dst.Set(*value) - } - case *float32: - if value == nil { - *dst = Int4{} - } else { - return dst.Set(*value) - } - case *float64: - if value == nil { - *dst = Int4{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Int4", value) - } - - return nil -} - -func (dst Int4) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.Int -} - -func (src *Int4) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Valid, dst) -} - -func (dst *Int4) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int4{} - return nil - } - - n, err := strconv.ParseInt(string(src), 10, 32) - if err != nil { - return err - } - - *dst = Int4{Int: int32(n), Valid: true} - return nil -} - -func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int4{} - return nil - } - - if len(src) != 4 { - return fmt.Errorf("invalid length for int4: %v", len(src)) - } - - n := int32(binary.BigEndian.Uint32(src)) - *dst = Int4{Int: n, Valid: true} - return nil -} - -func (src Int4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil -} - -func (src Int4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return pgio.AppendInt32(buf, src.Int), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int4) Scan(src interface{}) error { - if src == nil { - *dst = Int4{} - return nil - } - - switch src := src.(type) { - case int64: - if src < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", src) - } - if src > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", src) - } - *dst = Int4{Int: int32(src), Valid: true} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int4) Value() (driver.Value, error) { - if !src.Valid { - return nil, nil - } - return int64(src.Int), nil -} - -func (src Int4) MarshalJSON() ([]byte, error) { - if !src.Valid { - return []byte("null"), nil - } - return []byte(strconv.FormatInt(int64(src.Int), 10)), nil -} - -func (dst *Int4) UnmarshalJSON(b []byte) error { - var n *int32 - err := json.Unmarshal(b, &n) - if err != nil { - return err - } - - if n == nil { - *dst = Int4{} - } else { - *dst = Int4{Int: *n, Valid: true} - } - - return nil -} diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go deleted file mode 100644 index e725e7a8..00000000 --- a/pgtype/int4_array.go +++ /dev/null @@ -1,896 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type Int4Array struct { - Elements []Int4 - Dimensions []ArrayDimension - Valid bool -} - -func (dst *Int4Array) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Int4Array{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []int16: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int16: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint16: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint16: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []int32: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int32: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint32: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint32: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []int64: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int64: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint64: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint64: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []int: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Int4: - if value == nil { - *dst = Int4Array{} - } else if len(value) == 0 { - *dst = Int4Array{Valid: true} - } else { - *dst = Int4Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = Int4Array{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for Int4Array", src) - } - if elementsLength == 0 { - *dst = Int4Array{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Int4Array", src) - } - - *dst = Int4Array{ - Elements: make([]Int4, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Int4, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to Int4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *Int4Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to Int4Array") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in Int4Array", err) - } - index++ - - return index, nil -} - -func (dst Int4Array) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *Int4Array) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from Int4Array") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from Int4Array") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int4Array{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Int4 - - if len(uta.Elements) > 0 { - elements = make([]Int4, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Int4 - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = Int4Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int4Array{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = Int4Array{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Int4, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = Int4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("int4"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "int4") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int4Array) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int4Array) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/int4_array_test.go b/pgtype/int4_array_test.go deleted file mode 100644 index 906e4775..00000000 --- a/pgtype/int4_array_test.go +++ /dev/null @@ -1,356 +0,0 @@ -package pgtype_test - -import ( - "math" - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestInt4ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int4[]", []interface{}{ - &pgtype.Int4Array{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.Int4Array{}, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {}, - {Int: 6, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestInt4ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int4Array - expectedError bool - }{ - { - source: []int64{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []int32{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []int16{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []int{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []int{1, math.MaxInt32 + 1, 2}, - expectedError: true, - }, - { - source: []uint64{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []uint32{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []uint16{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]int32)(nil)), - result: pgtype.Int4Array{}, - }, - { - source: [][]int32{{1}, {2}}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {Int: 5, Valid: true}, - {Int: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]int32{{1}, {2}}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {Int: 5, Valid: true}, - {Int: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Int4Array - err := r.Set(tt.source) - if err != nil { - if tt.expectedError { - continue - } - t.Errorf("%d: %v", i, err) - } - - if tt.expectedError { - t.Errorf("%d: an error was expected, %v", i, tt) - continue - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt4ArrayAssignTo(t *testing.T) { - var int32Slice []int32 - var uint32Slice []uint32 - var namedInt32Slice _int32Slice - var int32SliceDim2 [][]int32 - var int32SliceDim4 [][][][]int32 - var int32ArrayDim2 [2][1]int32 - var int32ArrayDim4 [2][1][1][3]int32 - - simpleTests := []struct { - src pgtype.Int4Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &int32Slice, - expected: []int32{1}, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &uint32Slice, - expected: []uint32{1}, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &namedInt32Slice, - expected: _int32Slice{1}, - }, - { - src: pgtype.Int4Array{}, - dst: &int32Slice, - expected: (([]int32)(nil)), - }, - { - src: pgtype.Int4Array{Valid: true}, - dst: &int32Slice, - expected: []int32{}, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [][]int32{{1}, {2}}, - dst: &int32SliceDim2, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {Int: 5, Valid: true}, - {Int: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &int32SliceDim4, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [2][1]int32{{1}, {2}}, - dst: &int32ArrayDim2, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {Int: 5, Valid: true}, - {Int: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &int32ArrayDim4, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int4Array - dst interface{} - }{ - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &int32Slice, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: -1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &uint32Slice, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &int32ArrayDim2, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &int32Slice, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &int32ArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/int4_test.go b/pgtype/int4_test.go deleted file mode 100644 index 118c3ac5..00000000 --- a/pgtype/int4_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package pgtype_test - -import ( - "math" - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestInt4Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ - &pgtype.Int4{Int: math.MinInt32, Valid: true}, - &pgtype.Int4{Int: -1, Valid: true}, - &pgtype.Int4{Int: 0, Valid: true}, - &pgtype.Int4{Int: 1, Valid: true}, - &pgtype.Int4{Int: math.MaxInt32, Valid: true}, - &pgtype.Int4{Int: 0}, - }) -} - -func TestInt4Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int4 - }{ - {source: int8(1), result: pgtype.Int4{Int: 1, Valid: true}}, - {source: int16(1), result: pgtype.Int4{Int: 1, Valid: true}}, - {source: int32(1), result: pgtype.Int4{Int: 1, Valid: true}}, - {source: int64(1), result: pgtype.Int4{Int: 1, Valid: true}}, - {source: int8(-1), result: pgtype.Int4{Int: -1, Valid: true}}, - {source: int16(-1), result: pgtype.Int4{Int: -1, Valid: true}}, - {source: int32(-1), result: pgtype.Int4{Int: -1, Valid: true}}, - {source: int64(-1), result: pgtype.Int4{Int: -1, Valid: true}}, - {source: uint8(1), result: pgtype.Int4{Int: 1, Valid: true}}, - {source: uint16(1), result: pgtype.Int4{Int: 1, Valid: true}}, - {source: uint32(1), result: pgtype.Int4{Int: 1, Valid: true}}, - {source: uint64(1), result: pgtype.Int4{Int: 1, Valid: true}}, - {source: float32(1), result: pgtype.Int4{Int: 1, Valid: true}}, - {source: float64(1), result: pgtype.Int4{Int: 1, Valid: true}}, - {source: "1", result: pgtype.Int4{Int: 1, Valid: true}}, - {source: _int8(1), result: pgtype.Int4{Int: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.Int4 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt4AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - - simpleTests := []struct { - src pgtype.Int4 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, - {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, - {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, - {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, - {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i, expected: int(42)}, - {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, - {src: pgtype.Int4{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Int4{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Int4{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Int4 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int4{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, - {src: pgtype.Int4{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int4 - dst interface{} - }{ - {src: pgtype.Int4{Int: 150, Valid: true}, dst: &i8}, - {src: pgtype.Int4{Int: 40000, Valid: true}, dst: &i16}, - {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui8}, - {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui16}, - {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui32}, - {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui64}, - {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui}, - {src: pgtype.Int4{Int: 0}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - -func TestInt4MarshalJSON(t *testing.T) { - successfulTests := []struct { - source pgtype.Int4 - result string - }{ - {source: pgtype.Int4{Int: 0}, result: "null"}, - {source: pgtype.Int4{Int: 1, Valid: true}, result: "1"}, - } - for i, tt := range successfulTests { - r, err := tt.source.MarshalJSON() - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r) != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) - } - } -} - -func TestInt4UnmarshalJSON(t *testing.T) { - successfulTests := []struct { - source string - result pgtype.Int4 - }{ - {source: "null", result: pgtype.Int4{Int: 0}}, - {source: "1", result: pgtype.Int4{Int: 1, Valid: true}}, - } - for i, tt := range successfulTests { - var r pgtype.Int4 - err := r.UnmarshalJSON([]byte(tt.source)) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} diff --git a/pgtype/int8.go b/pgtype/int8.go deleted file mode 100644 index 794f92c6..00000000 --- a/pgtype/int8.go +++ /dev/null @@ -1,278 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "encoding/json" - "fmt" - "math" - "strconv" - - "github.com/jackc/pgio" -) - -type Int8 struct { - Int int64 - Valid bool -} - -func (dst *Int8) Set(src interface{}) error { - if src == nil { - *dst = Int8{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case int8: - *dst = Int8{Int: int64(value), Valid: true} - case uint8: - *dst = Int8{Int: int64(value), Valid: true} - case int16: - *dst = Int8{Int: int64(value), Valid: true} - case uint16: - *dst = Int8{Int: int64(value), Valid: true} - case int32: - *dst = Int8{Int: int64(value), Valid: true} - case uint32: - *dst = Int8{Int: int64(value), Valid: true} - case int64: - *dst = Int8{Int: int64(value), Valid: true} - case uint64: - if value > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", value) - } - *dst = Int8{Int: int64(value), Valid: true} - case int: - if int64(value) < math.MinInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", value) - } - if int64(value) > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", value) - } - *dst = Int8{Int: int64(value), Valid: true} - case uint: - if uint64(value) > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", value) - } - *dst = Int8{Int: int64(value), Valid: true} - case string: - num, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return err - } - *dst = Int8{Int: num, Valid: true} - case float32: - if value > math.MaxInt64 { - return fmt.Errorf("%f is greater than maximum value for Int8", value) - } - *dst = Int8{Int: int64(value), Valid: true} - case float64: - if value > math.MaxInt64 { - return fmt.Errorf("%f is greater than maximum value for Int8", value) - } - *dst = Int8{Int: int64(value), Valid: true} - case *int8: - if value == nil { - *dst = Int8{} - } else { - return dst.Set(*value) - } - case *uint8: - if value == nil { - *dst = Int8{} - } else { - return dst.Set(*value) - } - case *int16: - if value == nil { - *dst = Int8{} - } else { - return dst.Set(*value) - } - case *uint16: - if value == nil { - *dst = Int8{} - } else { - return dst.Set(*value) - } - case *int32: - if value == nil { - *dst = Int8{} - } else { - return dst.Set(*value) - } - case *uint32: - if value == nil { - *dst = Int8{} - } else { - return dst.Set(*value) - } - case *int64: - if value == nil { - *dst = Int8{} - } else { - return dst.Set(*value) - } - case *uint64: - if value == nil { - *dst = Int8{} - } else { - return dst.Set(*value) - } - case *int: - if value == nil { - *dst = Int8{} - } else { - return dst.Set(*value) - } - case *uint: - if value == nil { - *dst = Int8{} - } else { - return dst.Set(*value) - } - case *string: - if value == nil { - *dst = Int8{} - } else { - return dst.Set(*value) - } - case *float32: - if value == nil { - *dst = Int8{} - } else { - return dst.Set(*value) - } - case *float64: - if value == nil { - *dst = Int8{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Int8", value) - } - - return nil -} - -func (dst Int8) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.Int -} - -func (src *Int8) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Valid, dst) -} - -func (dst *Int8) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int8{} - return nil - } - - n, err := strconv.ParseInt(string(src), 10, 64) - if err != nil { - return err - } - - *dst = Int8{Int: n, Valid: true} - return nil -} - -func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int8{} - return nil - } - - if len(src) != 8 { - return fmt.Errorf("invalid length for int8: %v", len(src)) - } - - n := int64(binary.BigEndian.Uint64(src)) - - *dst = Int8{Int: n, Valid: true} - return nil -} - -func (src Int8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, strconv.FormatInt(src.Int, 10)...), nil -} - -func (src Int8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return pgio.AppendInt64(buf, src.Int), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int8) Scan(src interface{}) error { - if src == nil { - *dst = Int8{} - return nil - } - - switch src := src.(type) { - case int64: - *dst = Int8{Int: src, Valid: true} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int8) Value() (driver.Value, error) { - if !src.Valid { - return nil, nil - } - return int64(src.Int), nil -} - -func (src Int8) MarshalJSON() ([]byte, error) { - if !src.Valid { - return []byte("null"), nil - } - return []byte(strconv.FormatInt(src.Int, 10)), nil -} - -func (dst *Int8) UnmarshalJSON(b []byte) error { - var n *int64 - err := json.Unmarshal(b, &n) - if err != nil { - return err - } - - if n == nil { - *dst = Int8{} - } else { - *dst = Int8{Int: *n, Valid: true} - } - - return nil -} diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go deleted file mode 100644 index d6f38994..00000000 --- a/pgtype/int8_array.go +++ /dev/null @@ -1,896 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type Int8Array struct { - Elements []Int8 - Dimensions []ArrayDimension - Valid bool -} - -func (dst *Int8Array) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Int8Array{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []int16: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int16: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint16: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint16: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []int32: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int32: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint32: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint32: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []int64: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int64: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint64: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint64: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []int: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Int8: - if value == nil { - *dst = Int8Array{} - } else if len(value) == 0 { - *dst = Int8Array{Valid: true} - } else { - *dst = Int8Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = Int8Array{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for Int8Array", src) - } - if elementsLength == 0 { - *dst = Int8Array{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Int8Array", src) - } - - *dst = Int8Array{ - Elements: make([]Int8, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Int8, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to Int8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *Int8Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to Int8Array") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in Int8Array", err) - } - index++ - - return index, nil -} - -func (dst Int8Array) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *Int8Array) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from Int8Array") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from Int8Array") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int8Array{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Int8 - - if len(uta.Elements) > 0 { - elements = make([]Int8, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Int8 - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = Int8Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int8Array{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = Int8Array{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Int8, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = Int8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("int8"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "int8") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int8Array) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int8Array) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/int8_array_test.go b/pgtype/int8_array_test.go deleted file mode 100644 index 2d875b24..00000000 --- a/pgtype/int8_array_test.go +++ /dev/null @@ -1,349 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestInt8ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int8[]", []interface{}{ - &pgtype.Int8Array{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.Int8Array{}, - &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {}, - {Int: 6, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestInt8ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int8Array - }{ - { - source: []int64{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []int32{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []int16{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []int{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []uint64{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []uint32{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []uint16{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []uint{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]int64)(nil)), - result: pgtype.Int8Array{}, - }, - { - source: [][]int64{{1}, {2}}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {Int: 5, Valid: true}, - {Int: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]int64{{1}, {2}}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {Int: 5, Valid: true}, - {Int: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Int8Array - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt8ArrayAssignTo(t *testing.T) { - var int64Slice []int64 - var uint64Slice []uint64 - var namedInt64Slice _int64Slice - var int64SliceDim2 [][]int64 - var int64SliceDim4 [][][][]int64 - var int64ArrayDim2 [2][1]int64 - var int64ArrayDim4 [2][1][1][3]int64 - - simpleTests := []struct { - src pgtype.Int8Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &int64Slice, - expected: []int64{1}, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &uint64Slice, - expected: []uint64{1}, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &namedInt64Slice, - expected: _int64Slice{1}, - }, - { - src: pgtype.Int8Array{}, - dst: &int64Slice, - expected: (([]int64)(nil)), - }, - { - src: pgtype.Int8Array{Valid: true}, - dst: &int64Slice, - expected: []int64{}, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [][]int64{{1}, {2}}, - dst: &int64SliceDim2, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {Int: 5, Valid: true}, - {Int: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &int64SliceDim4, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [2][1]int64{{1}, {2}}, - dst: &int64ArrayDim2, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - {Int: 4, Valid: true}, - {Int: 5, Valid: true}, - {Int: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &int64ArrayDim4, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int8Array - dst interface{} - }{ - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &int64Slice, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: -1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &uint64Slice, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &int64ArrayDim2, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &int64Slice, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &int64ArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/int8_test.go b/pgtype/int8_test.go deleted file mode 100644 index 657eb702..00000000 --- a/pgtype/int8_test.go +++ /dev/null @@ -1,187 +0,0 @@ -package pgtype_test - -import ( - "math" - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestInt8Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ - &pgtype.Int8{Int: math.MinInt64, Valid: true}, - &pgtype.Int8{Int: -1, Valid: true}, - &pgtype.Int8{Int: 0, Valid: true}, - &pgtype.Int8{Int: 1, Valid: true}, - &pgtype.Int8{Int: math.MaxInt64, Valid: true}, - &pgtype.Int8{Int: 0}, - }) -} - -func TestInt8Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int8 - }{ - {source: int8(1), result: pgtype.Int8{Int: 1, Valid: true}}, - {source: int16(1), result: pgtype.Int8{Int: 1, Valid: true}}, - {source: int32(1), result: pgtype.Int8{Int: 1, Valid: true}}, - {source: int64(1), result: pgtype.Int8{Int: 1, Valid: true}}, - {source: int8(-1), result: pgtype.Int8{Int: -1, Valid: true}}, - {source: int16(-1), result: pgtype.Int8{Int: -1, Valid: true}}, - {source: int32(-1), result: pgtype.Int8{Int: -1, Valid: true}}, - {source: int64(-1), result: pgtype.Int8{Int: -1, Valid: true}}, - {source: uint8(1), result: pgtype.Int8{Int: 1, Valid: true}}, - {source: uint16(1), result: pgtype.Int8{Int: 1, Valid: true}}, - {source: uint32(1), result: pgtype.Int8{Int: 1, Valid: true}}, - {source: uint64(1), result: pgtype.Int8{Int: 1, Valid: true}}, - {source: float32(1), result: pgtype.Int8{Int: 1, Valid: true}}, - {source: float64(1), result: pgtype.Int8{Int: 1, Valid: true}}, - {source: "1", result: pgtype.Int8{Int: 1, Valid: true}}, - {source: _int8(1), result: pgtype.Int8{Int: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.Int8 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt8AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - - simpleTests := []struct { - src pgtype.Int8 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, - {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, - {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, - {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, - {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i, expected: int(42)}, - {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, - {src: pgtype.Int8{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Int8{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Int8{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Int8 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int8{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, - {src: pgtype.Int8{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int8 - dst interface{} - }{ - {src: pgtype.Int8{Int: 150, Valid: true}, dst: &i8}, - {src: pgtype.Int8{Int: 40000, Valid: true}, dst: &i16}, - {src: pgtype.Int8{Int: 5000000000, Valid: true}, dst: &i32}, - {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui8}, - {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui16}, - {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui32}, - {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui64}, - {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui}, - {src: pgtype.Int8{Int: 0}, dst: &i64}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - -func TestInt8MarshalJSON(t *testing.T) { - successfulTests := []struct { - source pgtype.Int8 - result string - }{ - {source: pgtype.Int8{Int: 0}, result: "null"}, - {source: pgtype.Int8{Int: 1, Valid: true}, result: "1"}, - } - for i, tt := range successfulTests { - r, err := tt.source.MarshalJSON() - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r) != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) - } - } -} - -func TestInt8UnmarshalJSON(t *testing.T) { - successfulTests := []struct { - source string - result pgtype.Int8 - }{ - {source: "null", result: pgtype.Int8{Int: 0}}, - {source: "1", result: pgtype.Int8{Int: 1, Valid: true}}, - } - for i, tt := range successfulTests { - var r pgtype.Int8 - err := r.UnmarshalJSON([]byte(tt.source)) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} diff --git a/pgtype/int_test.go b/pgtype/int_test.go index 272ad69b..77aa0589 100644 --- a/pgtype/int_test.go +++ b/pgtype/int_test.go @@ -31,6 +31,11 @@ func TestInt2Codec(t *testing.T) { {1, new(uint64), isExpectedEq(uint64(1))}, {1, new(int), isExpectedEq(int(1))}, {1, new(uint), isExpectedEq(uint(1))}, + {-1, new(int8), isExpectedEq(int8(-1))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {-1, new(int), isExpectedEq(int(-1))}, {math.MinInt16, new(int16), isExpectedEq(int16(math.MinInt16))}, {-1, new(int16), isExpectedEq(int16(-1))}, {0, new(int16), isExpectedEq(int16(0))}, @@ -82,3 +87,163 @@ func TestInt2UnmarshalJSON(t *testing.T) { } } } + +func TestInt4Codec(t *testing.T) { + testPgxCodec(t, "int4", []PgxTranscodeTestCase{ + {int8(1), new(int32), isExpectedEq(int32(1))}, + {int16(1), new(int32), isExpectedEq(int32(1))}, + {int32(1), new(int32), isExpectedEq(int32(1))}, + {int64(1), new(int32), isExpectedEq(int32(1))}, + {uint8(1), new(int32), isExpectedEq(int32(1))}, + {uint16(1), new(int32), isExpectedEq(int32(1))}, + {uint32(1), new(int32), isExpectedEq(int32(1))}, + {uint64(1), new(int32), isExpectedEq(int32(1))}, + {int(1), new(int32), isExpectedEq(int32(1))}, + {uint(1), new(int32), isExpectedEq(int32(1))}, + {pgtype.Int4{Int: 1, Valid: true}, new(int32), isExpectedEq(int32(1))}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {-1, new(int8), isExpectedEq(int8(-1))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {-1, new(int), isExpectedEq(int(-1))}, + {math.MinInt32, new(int32), isExpectedEq(int32(math.MinInt32))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {0, new(int32), isExpectedEq(int32(0))}, + {1, new(int32), isExpectedEq(int32(1))}, + {math.MaxInt32, new(int32), isExpectedEq(int32(math.MaxInt32))}, + {1, new(pgtype.Int4), isExpectedEq(pgtype.Int4{Int: 1, Valid: true})}, + {pgtype.Int4{}, new(pgtype.Int4), isExpectedEq(pgtype.Int4{})}, + {nil, new(*int32), isExpectedEq((*int32)(nil))}, + }) +} + +func TestInt4MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int4 + result string + }{ + {source: pgtype.Int4{Int: 0}, result: "null"}, + {source: pgtype.Int4{Int: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt4UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int4 + }{ + {source: "null", result: pgtype.Int4{Int: 0}}, + {source: "1", result: pgtype.Int4{Int: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int4 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt8Codec(t *testing.T) { + testPgxCodec(t, "int8", []PgxTranscodeTestCase{ + {int8(1), new(int64), isExpectedEq(int64(1))}, + {int16(1), new(int64), isExpectedEq(int64(1))}, + {int32(1), new(int64), isExpectedEq(int64(1))}, + {int64(1), new(int64), isExpectedEq(int64(1))}, + {uint8(1), new(int64), isExpectedEq(int64(1))}, + {uint16(1), new(int64), isExpectedEq(int64(1))}, + {uint32(1), new(int64), isExpectedEq(int64(1))}, + {uint64(1), new(int64), isExpectedEq(int64(1))}, + {int(1), new(int64), isExpectedEq(int64(1))}, + {uint(1), new(int64), isExpectedEq(int64(1))}, + {pgtype.Int8{Int: 1, Valid: true}, new(int64), isExpectedEq(int64(1))}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {-1, new(int8), isExpectedEq(int8(-1))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {-1, new(int), isExpectedEq(int(-1))}, + {math.MinInt64, new(int64), isExpectedEq(int64(math.MinInt64))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {0, new(int64), isExpectedEq(int64(0))}, + {1, new(int64), isExpectedEq(int64(1))}, + {math.MaxInt64, new(int64), isExpectedEq(int64(math.MaxInt64))}, + {1, new(pgtype.Int8), isExpectedEq(pgtype.Int8{Int: 1, Valid: true})}, + {pgtype.Int8{}, new(pgtype.Int8), isExpectedEq(pgtype.Int8{})}, + {nil, new(*int64), isExpectedEq((*int64)(nil))}, + }) +} + +func TestInt8MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int8 + result string + }{ + {source: pgtype.Int8{Int: 0}, result: "null"}, + {source: pgtype.Int8{Int: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt8UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int8 + }{ + {source: "null", result: pgtype.Int8{Int: 0}}, + {source: "1", result: pgtype.Int8{Int: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int8 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/int_test.go.erb b/pgtype/int_test.go.erb index 28847a11..afcc8b9c 100644 --- a/pgtype/int_test.go.erb +++ b/pgtype/int_test.go.erb @@ -7,7 +7,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -<% [2].each do |pg_byte_size| %> +<% [2, 4, 8].each do |pg_byte_size| %> <% pg_bit_size = pg_byte_size * 8 %> func TestInt<%= pg_byte_size %>Codec(t *testing.T) { testPgxCodec(t, "int<%= pg_byte_size %>", []PgxTranscodeTestCase{ @@ -32,6 +32,11 @@ func TestInt<%= pg_byte_size %>Codec(t *testing.T) { {1, new(uint64), isExpectedEq(uint64(1))}, {1, new(int), isExpectedEq(int(1))}, {1, new(uint), isExpectedEq(uint(1))}, + {-1, new(int8), isExpectedEq(int8(-1))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {-1, new(int), isExpectedEq(int(-1))}, {math.MinInt<%= pg_bit_size %>, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(math.MinInt<%= pg_bit_size %>))}, {-1, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(-1))}, {0, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(0))}, diff --git a/pgtype/integration_benchmark_test.go b/pgtype/integration_benchmark_test.go index ad9c0598..58934ead 100644 --- a/pgtype/integration_benchmark_test.go +++ b/pgtype/integration_benchmark_test.go @@ -1,5 +1,4 @@ -// Code generated by erb. DO NOT EDIT. - +// Do not edit. Generated from pgtype/integration_benchmark_test.go.erb package pgtype_test import ( @@ -1311,32 +1310,6 @@ func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testing } } -func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_ArrayType_10(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), - Name: "_int4", - OID: pgtype.Int4ArrayOID, - }) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) - } - } -} - func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testing.B) { conn := testutil.MustConnectPgx(b) defer testutil.MustCloseContext(b, conn) @@ -1357,32 +1330,6 @@ func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testi } } -func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_ArrayType_10(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), - Name: "_int4", - OID: pgtype.Int4ArrayOID, - }) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) - } - } -} - func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testing.B) { conn := testutil.MustConnectPgx(b) defer testutil.MustCloseContext(b, conn) @@ -1403,32 +1350,6 @@ func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testin } } -func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_ArrayType_100(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), - Name: "_int4", - OID: pgtype.Int4ArrayOID, - }) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) - } - } -} - func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testing.B) { conn := testutil.MustConnectPgx(b) defer testutil.MustCloseContext(b, conn) @@ -1449,32 +1370,6 @@ func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *test } } -func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_ArrayType_100(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), - Name: "_int4", - OID: pgtype.Int4ArrayOID, - }) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) - } - } -} - func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testing.B) { conn := testutil.MustConnectPgx(b) defer testutil.MustCloseContext(b, conn) @@ -1495,32 +1390,6 @@ func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testi } } -func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_ArrayType_1000(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), - Name: "_int4", - OID: pgtype.Int4ArrayOID, - }) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, 1000) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) - } - } -} - func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testing.B) { conn := testutil.MustConnectPgx(b) defer testutil.MustCloseContext(b, conn) @@ -1540,29 +1409,3 @@ func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *tes } } } - -func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_ArrayType_1000(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), - Name: "_int4", - OID: pgtype.Int4ArrayOID, - }) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, 1000) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) - } - } -} diff --git a/pgtype/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb index 1b091364..de9cabc9 100644 --- a/pgtype/integration_benchmark_test.go.erb +++ b/pgtype/integration_benchmark_test.go.erb @@ -1,5 +1,3 @@ -// Code generated by erb. DO NOT EDIT. - package pgtype_test import ( @@ -64,31 +62,5 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_Int4Array } } } - -func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_ArrayType_<%= array_size %>(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: pgtype.NewArrayType("_int4", pgtype.Int4OID, func() pgtype.ValueTranscoder { return &pgtype.Int4{} }), - Name: "_int4", - OID: pgtype.Int4ArrayOID, - }) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, <%= array_size %>) n`, - []interface{}{pgx.QueryResultFormats{<%= format_code %>}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) - } - } -} <% end %> <% end %> diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 6df3a582..fe3fae44 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -301,8 +301,8 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID}) ci.RegisterDataType(DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}}) - ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID}) - ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID}) + ci.RegisterDataType(DataType{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementCodec: Int4Codec{}, ElementOID: Int4OID}}) + ci.RegisterDataType(DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementCodec: Int8Codec{}, ElementOID: Int8OID}}) ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) ci.RegisterDataType(DataType{Value: &TextArray{}, Name: "_text", OID: TextArrayOID}) ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) @@ -325,9 +325,9 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID}) ci.RegisterDataType(DataType{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) - ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID}) + ci.RegisterDataType(DataType{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) // ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) - ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID}) + ci.RegisterDataType(DataType{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) // ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID}) ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) @@ -408,7 +408,9 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { { var formatCode int16 - if pfp, ok := t.Value.(FormatSupport); ok { + if t.Codec != nil { + formatCode = t.Codec.PreferredFormat() + } else if pfp, ok := t.Value.(FormatSupport); ok { formatCode = pfp.PreferredFormat() } else if _, ok := t.Value.(BinaryEncoder); ok { formatCode = BinaryFormatCode @@ -547,6 +549,13 @@ func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCod } dt := (*DataType)(plan) + if dt.Codec != nil { + sqlValue, err := dt.Codec.DecodeDatabaseSQLValue(ci, oid, formatCode, src) + if err != nil { + return err + } + return scanner.Scan(sqlValue) + } var err error switch formatCode { case BinaryFormatCode: @@ -650,46 +659,6 @@ func (scanPlanReflection) Scan(ci *ConnInfo, oid uint32, formatCode int16, src [ return scanUnknownType(oid, formatCode, src, dst) } -type scanPlanBinaryInt16 struct{} - -func (scanPlanBinaryInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - if len(src) != 2 { - return fmt.Errorf("invalid length for int2: %v", len(src)) - } - - if p, ok := (dst).(*int16); ok { - *p = int16(binary.BigEndian.Uint16(src)) - return nil - } - - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) -} - -type scanPlanBinaryInt32 struct{} - -func (scanPlanBinaryInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - if len(src) != 4 { - return fmt.Errorf("invalid length for int4: %v", len(src)) - } - - if p, ok := (dst).(*int32); ok { - *p = int32(binary.BigEndian.Uint32(src)) - return nil - } - - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) -} - type scanPlanBinaryInt64 struct{} func (scanPlanBinaryInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { @@ -860,28 +829,6 @@ func tryBaseTypeScanPlan(dst interface{}) (plan *baseTypeScanPlan, nextDst inter // PlanScan prepares a plan to scan a value into dst. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { - if oid != 0 { - if dt, ok := ci.DataTypeForOID(oid); ok && dt.Codec != nil { - if plan := dt.Codec.PlanScan(ci, oid, formatCode, dst, false); plan != nil { - return plan - } - - if pointerPointerPlan, nextDst, ok := tryPointerPointerScanPlan(dst); ok { - if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { - pointerPointerPlan.next = nextPlan - return pointerPointerPlan - } - } - - if baseTypePlan, nextDst, ok := tryBaseTypeScanPlan(dst); ok { - if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { - baseTypePlan.next = nextPlan - return baseTypePlan - } - } - } - } - switch formatCode { case BinaryFormatCode: switch dst.(type) { @@ -890,14 +837,6 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan case TextOID, VarcharOID: return scanPlanString{} } - case *int16: - if oid == Int2OID { - return scanPlanBinaryInt16{} - } - case *int32: - if oid == Int4OID { - return scanPlanBinaryInt32{} - } case *int64: if oid == Int8OID { return scanPlanBinaryInt64{} @@ -943,6 +882,26 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan } } + if dt != nil && dt.Codec != nil { + if plan := dt.Codec.PlanScan(ci, oid, formatCode, dst, false); plan != nil { + return plan + } + + if pointerPointerPlan, nextDst, ok := tryPointerPointerScanPlan(dst); ok { + if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { + pointerPointerPlan.next = nextPlan + return pointerPointerPlan + } + } + + if baseTypePlan, nextDst, ok := tryBaseTypeScanPlan(dst); ok { + if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { + baseTypePlan.next = nextPlan + return baseTypePlan + } + } + } + if dt != nil { if _, ok := dst.(sql.Scanner); ok { if _, found := ci.preferAssignToOverSQLScannerTypes[reflect.TypeOf(dst)]; !found { diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 43c6c24b..b9dbda9d 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -244,9 +244,7 @@ func TestScanPlanBinaryInt32ScanChangedType(t *testing.T) { var d pgtype.Int4 err = plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &d) - require.NoError(t, err) - require.EqualValues(t, 42, d.Int) - require.True(t, d.Valid) + require.EqualError(t, err, pgtype.ErrScanTargetTypeChanged.Error()) } func BenchmarkConnInfoScanInt4IntoGoInt32(b *testing.B) { diff --git a/pgtype/pgxtype/pgxtype.go b/pgtype/pgxtype/pgxtype.go index ba14c094..49df7059 100644 --- a/pgtype/pgxtype/pgxtype.go +++ b/pgtype/pgxtype/pgxtype.go @@ -53,15 +53,16 @@ func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeNa at := pgtype.NewArrayType(typeName, elementOID, newElement) return pgtype.DataType{Value: at, Name: typeName, OID: oid}, nil case "c": // composite - fields, err := GetCompositeFields(ctx, conn, oid) - if err != nil { - return pgtype.DataType{}, err - } - ct, err := pgtype.NewCompositeType(typeName, fields, ci) - if err != nil { - return pgtype.DataType{}, err - } - return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil + panic("TODO - restore composite support") + // fields, err := GetCompositeFields(ctx, conn, oid) + // if err != nil { + // return pgtype.DataType{}, err + // } + // ct, err := pgtype.NewCompositeType(typeName, fields, ci) + // if err != nil { + // return pgtype.DataType{}, err + // } + // return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil case "e": // enum members, err := GetEnumMembers(ctx, conn, oid) if err != nil { @@ -84,40 +85,41 @@ func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, return typelem, nil } +// TODO - restore composite support // GetCompositeFields gets the fields of a composite type. -func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) { - var typrelid uint32 +// func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) { +// var typrelid uint32 - err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid) - if err != nil { - return nil, err - } +// err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid) +// if err != nil { +// return nil, err +// } - var fields []pgtype.CompositeTypeField +// var fields []pgtype.CompositeTypeField - rows, err := conn.Query(ctx, `select attname, atttypid -from pg_attribute -where attrelid=$1 -order by attnum`, typrelid) - if err != nil { - return nil, err - } +// rows, err := conn.Query(ctx, `select attname, atttypid +// from pg_attribute +// where attrelid=$1 +// order by attnum`, typrelid) +// if err != nil { +// return nil, err +// } - for rows.Next() { - var f pgtype.CompositeTypeField - err := rows.Scan(&f.Name, &f.OID) - if err != nil { - return nil, err - } - fields = append(fields, f) - } +// for rows.Next() { +// var f pgtype.CompositeTypeField +// err := rows.Scan(&f.Name, &f.OID) +// if err != nil { +// return nil, err +// } +// fields = append(fields, f) +// } - if rows.Err() != nil { - return nil, rows.Err() - } +// if rows.Err() != nil { +// return nil, rows.Err() +// } - return fields, nil -} +// return fields, nil +// } // GetEnumMembers gets the possible values of the enum by oid. func GetEnumMembers(ctx context.Context, conn Querier, oid uint32) ([]string, error) { diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index ae0e67cb..3766c3f8 100755 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -1,4 +1,3 @@ -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go diff --git a/pgtype/zeronull/int.go b/pgtype/zeronull/int.go new file mode 100644 index 00000000..0149834b --- /dev/null +++ b/pgtype/zeronull/int.go @@ -0,0 +1,148 @@ +// Do not edit. Generated from pgtype/zeronull/int.go.erb +package zeronull + +import ( + "database/sql/driver" + "fmt" + "math" + + "github.com/jackc/pgx/v5/pgtype" +) + +type Int2 int16 + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int2) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = 0 + return nil + } + + if n < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + *dst = Int2(n) + + return nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int2) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int2 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int2(nullable.Int) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2) Value() (driver.Value, error) { + if src == 0 { + return nil, nil + } + return int64(src), nil +} + +type Int4 int32 + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int4) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = 0 + return nil + } + + if n < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n) + } + if n > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n) + } + *dst = Int4(n) + + return nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int4 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int4(nullable.Int) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4) Value() (driver.Value, error) { + if src == 0 { + return nil, nil + } + return int64(src), nil +} + +type Int8 int64 + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int8) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = 0 + return nil + } + + if n < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n) + } + if n > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n) + } + *dst = Int8(n) + + return nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int8 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int8(nullable.Int) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8) Value() (driver.Value, error) { + if src == 0 { + return nil, nil + } + return int64(src), nil +} diff --git a/pgtype/zeronull/int.go.erb b/pgtype/zeronull/int.go.erb new file mode 100644 index 00000000..935b56a9 --- /dev/null +++ b/pgtype/zeronull/int.go.erb @@ -0,0 +1,58 @@ +package zeronull + +import ( + "database/sql/driver" + "fmt" + "math" + + "github.com/jackc/pgx/v5/pgtype" +) + +<% [2, 4, 8].each do |pg_byte_size| %> +<% pg_bit_size = pg_byte_size * 8 %> +type Int<%= pg_byte_size %> int<%= pg_bit_size %> + +// ScanInt64 implements the Int64Scanner interface. +func (dst *Int<%= pg_byte_size %>) ScanInt64(n int64, valid bool) error { + if !valid { + *dst = 0 + return nil + } + + if n < math.MinInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + } + if n > math.MaxInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + } + *dst = Int<%= pg_byte_size %>(n) + + return nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int<%= pg_byte_size %>) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int<%= pg_byte_size %> + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int<%= pg_byte_size %>(nullable.Int) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { + if src == 0 { + return nil, nil + } + return int64(src), nil +} +<% end %> diff --git a/pgtype/zeronull/int2.go b/pgtype/zeronull/int2.go deleted file mode 100644 index 2f63d8cc..00000000 --- a/pgtype/zeronull/int2.go +++ /dev/null @@ -1,55 +0,0 @@ -package zeronull - -import ( - "database/sql/driver" - "fmt" - "math" - - "github.com/jackc/pgx/v5/pgtype" -) - -type Int2 int16 - -// ScanInt64 implements the Int64Scanner interface. -func (dst *Int2) ScanInt64(n int64, valid bool) error { - if !valid { - *dst = 0 - return nil - } - - if n < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) - } - if n > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) - } - *dst = Int2(n) - - return nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int2) Scan(src interface{}) error { - if src == nil { - *dst = 0 - return nil - } - - var nullable pgtype.Int2 - err := nullable.Scan(src) - if err != nil { - return err - } - - *dst = Int2(nullable.Int) - - return nil -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int2) Value() (driver.Value, error) { - if src == 0 { - return nil, nil - } - return int64(src), nil -} diff --git a/pgtype/zeronull/int2_test.go b/pgtype/zeronull/int2_test.go deleted file mode 100644 index ff78a6e6..00000000 --- a/pgtype/zeronull/int2_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package zeronull_test - -import ( - "testing" - - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/jackc/pgx/v5/pgtype/zeronull" -) - -func TestInt2Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ - (zeronull.Int2)(1), - (zeronull.Int2)(0), - }) -} - -func TestInt2ConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "int2", (zeronull.Int2)(0)) -} - -func TestInt2ConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "int2", (zeronull.Int2)(0)) -} diff --git a/pgtype/zeronull/int4.go b/pgtype/zeronull/int4.go deleted file mode 100644 index 4e06435a..00000000 --- a/pgtype/zeronull/int4.go +++ /dev/null @@ -1,90 +0,0 @@ -package zeronull - -import ( - "database/sql/driver" - - "github.com/jackc/pgx/v5/pgtype" -) - -type Int4 int32 - -func (dst *Int4) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Int4 - err := nullable.DecodeText(ci, src) - if err != nil { - return err - } - - if nullable.Valid { - *dst = Int4(nullable.Int) - } else { - *dst = 0 - } - - return nil -} - -func (dst *Int4) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Int4 - err := nullable.DecodeBinary(ci, src) - if err != nil { - return err - } - - if nullable.Valid { - *dst = Int4(nullable.Int) - } else { - *dst = 0 - } - - return nil -} - -func (src Int4) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if src == 0 { - return nil, nil - } - - nullable := pgtype.Int4{ - Int: int32(src), - Valid: true, - } - - return nullable.EncodeText(ci, buf) -} - -func (src Int4) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if src == 0 { - return nil, nil - } - - nullable := pgtype.Int4{ - Int: int32(src), - Valid: true, - } - - return nullable.EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int4) Scan(src interface{}) error { - if src == nil { - *dst = 0 - return nil - } - - var nullable pgtype.Int4 - err := nullable.Scan(src) - if err != nil { - return err - } - - *dst = Int4(nullable.Int) - - return nil -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int4) Value() (driver.Value, error) { - return pgtype.EncodeValueText(src) -} diff --git a/pgtype/zeronull/int4_test.go b/pgtype/zeronull/int4_test.go deleted file mode 100644 index 3510aa9d..00000000 --- a/pgtype/zeronull/int4_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package zeronull_test - -import ( - "testing" - - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/jackc/pgx/v5/pgtype/zeronull" -) - -func TestInt4Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ - (zeronull.Int4)(1), - (zeronull.Int4)(0), - }) -} - -func TestInt4ConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "int4", (zeronull.Int4)(0)) -} - -func TestInt4ConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "int4", (zeronull.Int4)(0)) -} diff --git a/pgtype/zeronull/int8.go b/pgtype/zeronull/int8.go deleted file mode 100644 index 3c89a1ec..00000000 --- a/pgtype/zeronull/int8.go +++ /dev/null @@ -1,90 +0,0 @@ -package zeronull - -import ( - "database/sql/driver" - - "github.com/jackc/pgx/v5/pgtype" -) - -type Int8 int64 - -func (dst *Int8) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Int8 - err := nullable.DecodeText(ci, src) - if err != nil { - return err - } - - if nullable.Valid { - *dst = Int8(nullable.Int) - } else { - *dst = 0 - } - - return nil -} - -func (dst *Int8) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Int8 - err := nullable.DecodeBinary(ci, src) - if err != nil { - return err - } - - if nullable.Valid { - *dst = Int8(nullable.Int) - } else { - *dst = 0 - } - - return nil -} - -func (src Int8) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if src == 0 { - return nil, nil - } - - nullable := pgtype.Int8{ - Int: int64(src), - Valid: true, - } - - return nullable.EncodeText(ci, buf) -} - -func (src Int8) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if src == 0 { - return nil, nil - } - - nullable := pgtype.Int8{ - Int: int64(src), - Valid: true, - } - - return nullable.EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int8) Scan(src interface{}) error { - if src == nil { - *dst = 0 - return nil - } - - var nullable pgtype.Int8 - err := nullable.Scan(src) - if err != nil { - return err - } - - *dst = Int8(nullable.Int) - - return nil -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int8) Value() (driver.Value, error) { - return pgtype.EncodeValueText(src) -} diff --git a/pgtype/zeronull/int8_test.go b/pgtype/zeronull/int8_test.go deleted file mode 100644 index 97fe9cd0..00000000 --- a/pgtype/zeronull/int8_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package zeronull_test - -import ( - "testing" - - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/jackc/pgx/v5/pgtype/zeronull" -) - -func TestInt8Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ - (zeronull.Int8)(1), - (zeronull.Int8)(0), - }) -} - -func TestInt8ConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "int8", (zeronull.Int8)(0)) -} - -func TestInt8ConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "int8", (zeronull.Int8)(0)) -} diff --git a/pgtype/zeronull/int_test.go b/pgtype/zeronull/int_test.go new file mode 100644 index 00000000..bd2ef0b2 --- /dev/null +++ b/pgtype/zeronull/int_test.go @@ -0,0 +1,54 @@ +// Do not edit. Generated from pgtype/zeronull/int_test.go.erb +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype/zeronull" +) + +func TestInt2Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ + (zeronull.Int2)(1), + (zeronull.Int2)(0), + }) +} + +func TestInt2ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int2", (zeronull.Int2)(0)) +} + +func TestInt2ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int2", (zeronull.Int2)(0)) +} + +func TestInt4Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ + (zeronull.Int4)(1), + (zeronull.Int4)(0), + }) +} + +func TestInt4ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int4", (zeronull.Int4)(0)) +} + +func TestInt4ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int4", (zeronull.Int4)(0)) +} + +func TestInt8Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ + (zeronull.Int8)(1), + (zeronull.Int8)(0), + }) +} + +func TestInt8ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int8", (zeronull.Int8)(0)) +} + +func TestInt8ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int8", (zeronull.Int8)(0)) +} diff --git a/pgtype/zeronull/int_test.go.erb b/pgtype/zeronull/int_test.go.erb new file mode 100644 index 00000000..51273710 --- /dev/null +++ b/pgtype/zeronull/int_test.go.erb @@ -0,0 +1,26 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype/zeronull" +) + +<% [2, 4, 8].each do |pg_byte_size| %> +<% pg_bit_size = pg_byte_size * 8 %> +func TestInt<%= pg_byte_size %>Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int<%= pg_byte_size %>", []interface{}{ + (zeronull.Int<%= pg_byte_size %>)(1), + (zeronull.Int<%= pg_byte_size %>)(0), + }) +} + +func TestInt<%= pg_byte_size %>ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int<%= pg_byte_size %>", (zeronull.Int<%= pg_byte_size %>)(0)) +} + +func TestInt<%= pg_byte_size %>ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int<%= pg_byte_size %>", (zeronull.Int<%= pg_byte_size %>)(0)) +} +<% end %> diff --git a/pgtype/zzz.int4.go b/pgtype/zzz.int4.go deleted file mode 100644 index bd7f9bda..00000000 --- a/pgtype/zzz.int4.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Int4) BinaryFormatSupported() bool { - return true -} - -func (Int4) TextFormatSupported() bool { - return true -} - -func (Int4) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Int4) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Int4) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.int8.go b/pgtype/zzz.int8.go deleted file mode 100644 index d6e98262..00000000 --- a/pgtype/zzz.int8.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Int8) BinaryFormatSupported() bool { - return true -} - -func (Int8) TextFormatSupported() bool { - return true -} - -func (Int8) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Int8) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Int8) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/query_test.go b/query_test.go index d9b35e28..c0fbebaf 100644 --- a/query_test.go +++ b/query_test.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "os" - "reflect" "strconv" "strings" "testing" @@ -263,70 +262,9 @@ func TestConnQueryReadRowMultipleTimes(t *testing.T) { require.Equal(t, int32(10), rowCount) } -// https://github.com/jackc/pgx/issues/386 -func TestConnQueryValuesWithMultipleComplexColumnsOfSameType(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - expected0 := &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {Int: 3, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, - Valid: true, - } - - expected1 := &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 4, Valid: true}, - {Int: 5, Valid: true}, - {Int: 6, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, - Valid: true, - } - - var rowCount int32 - - rows, err := conn.Query(context.Background(), "select '{1,2,3}'::bigint[], '{4,5,6}'::bigint[] from generate_series(1,$1) n", 10) - if err != nil { - t.Fatalf("conn.Query failed: %v", err) - } - defer rows.Close() - - for rows.Next() { - rowCount++ - - values, err := rows.Values() - if err != nil { - t.Fatalf("rows.Values failed: %v", err) - } - if len(values) != 2 { - t.Errorf("Expected rows.Values to return 2 values, but it returned %d", len(values)) - } - if !reflect.DeepEqual(values[0], *expected0) { - t.Errorf(`Expected values[0] to be %v, but it was %v`, *expected0, values[0]) - } - if !reflect.DeepEqual(values[1], *expected1) { - t.Errorf(`Expected values[1] to be %v, but it was %v`, *expected1, values[1]) - } - } - - if rows.Err() != nil { - t.Fatalf("conn.Query failed: %v", err) - } - - if rowCount != 10 { - t.Error("Select called onDataRow wrong number of times") - } -} - // https://github.com/jackc/pgx/issues/228 func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) { + t.Skip("TODO - unskip later in v5") t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -431,6 +369,7 @@ func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) { // Test that a connection stays valid when query results read incorrectly func TestConnQueryReadWrongTypeError(t *testing.T) { + t.Skip("TODO - unskip later in v5") t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) diff --git a/values_test.go b/values_test.go index 27fbe977..82a6496a 100644 --- a/values_test.go +++ b/values_test.go @@ -989,6 +989,7 @@ func TestEncodeTypeRename(t *testing.T) { // https://github.com/jackc/pgx/issues/810 func TestRowsScanNilThenScanValue(t *testing.T) { + t.Skip("TODO - unskip later in v5") t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { @@ -1021,6 +1022,7 @@ order by a nulls first } func TestScanIntoByteSlice(t *testing.T) { + t.Skip("TODO - unskip later in v5") t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) From 4b1121c2a9996ebfce3aa5e7102578f1300b36ae Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Jan 2022 18:18:47 -0600 Subject: [PATCH 0799/1158] Convert bool to Codec --- pgtype/bool.go | 342 +++++++++++++++++--------- pgtype/bool_array.go | 504 -------------------------------------- pgtype/bool_array_test.go | 283 --------------------- pgtype/bool_test.go | 94 +------ pgtype/pgtype.go | 4 +- pgtype/zzz.bool.go | 35 --- 6 files changed, 230 insertions(+), 1032 deletions(-) delete mode 100644 pgtype/bool_array.go delete mode 100644 pgtype/bool_array_test.go delete mode 100644 pgtype/zzz.bool.go diff --git a/pgtype/bool.go b/pgtype/bool.go index 4fcc67e3..4b6fbaf2 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -7,136 +7,27 @@ import ( "strconv" ) +type BoolScanner interface { + ScanBool(v bool, valid bool) error +} + type Bool struct { Bool bool Valid bool } -func (dst *Bool) Set(src interface{}) error { - if src == nil { +// ScanBool implements the BoolScanner interface. +func (dst *Bool) ScanBool(v bool, valid bool) error { + if !valid { *dst = Bool{} return nil } - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case bool: - *dst = Bool{Bool: value, Valid: true} - case string: - bb, err := strconv.ParseBool(value) - if err != nil { - return err - } - *dst = Bool{Bool: bb, Valid: true} - case *bool: - if value == nil { - *dst = Bool{} - } else { - return dst.Set(*value) - } - case *string: - if value == nil { - *dst = Bool{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingBoolType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Bool", value) - } + *dst = Bool{Bool: v, Valid: true} return nil } -func (dst Bool) Get() interface{} { - if !dst.Valid { - return nil - } - - return dst.Bool -} - -func (src *Bool) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *bool: - *v = src.Bool - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Bool{} - return nil - } - - if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) - } - - *dst = Bool{Bool: src[0] == 't', Valid: true} - return nil -} - -func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Bool{} - return nil - } - - if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) - } - - *dst = Bool{Bool: src[0] == 1, Valid: true} - return nil -} - -func (src Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if src.Bool { - buf = append(buf, 't') - } else { - buf = append(buf, 'f') - } - - return buf, nil -} - -func (src Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if src.Bool { - buf = append(buf, 1) - } else { - buf = append(buf, 0) - } - - return buf, nil -} - // Scan implements the database/sql Scanner interface. func (dst *Bool) Scan(src interface{}) error { if src == nil { @@ -149,11 +40,19 @@ func (dst *Bool) Scan(src interface{}) error { *dst = Bool{Bool: src, Valid: true} return nil case string: - return dst.DecodeText(nil, []byte(src)) + b, err := strconv.ParseBool(src) + if err != nil { + return err + } + *dst = Bool{Bool: b, Valid: true} + return nil case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + b, err := strconv.ParseBool(string(src)) + if err != nil { + return err + } + *dst = Bool{Bool: b, Valid: true} + return nil } return fmt.Errorf("cannot scan %T", src) @@ -195,3 +94,204 @@ func (dst *Bool) UnmarshalJSON(b []byte) error { return nil } + +type BoolCodec struct{} + +func (BoolCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (BoolCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (BoolCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + v, valid, err := convertToBoolForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to bool: %v", value, err) + } + if !valid { + return nil, nil + } + if value == nil { + return nil, nil + } + + switch format { + case BinaryFormatCode: + if v { + buf = append(buf, 1) + } else { + buf = append(buf, 0) + } + return buf, nil + case TextFormatCode: + if v { + buf = append(buf, 't') + } else { + buf = append(buf, 'f') + } + return buf, nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } +} + +func (BoolCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *bool: + return scanPlanBinaryBoolToBool{} + case BoolScanner: + return scanPlanBinaryBoolToBoolScanner{} + } + case TextFormatCode: + switch target.(type) { + case *bool: + return scanPlanTextAnyToBool{} + case BoolScanner: + return scanPlanTextAnyToBoolScanner{} + } + } + + return nil +} + +func (c BoolCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(ci, oid, format, src) +} + +func (c BoolCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var b bool + scanPlan := c.PlanScan(ci, oid, format, &b, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &b) + if err != nil { + return nil, err + } + return b, nil +} + +func convertToBoolForEncode(v interface{}) (b bool, valid bool, err error) { + if v == nil { + return false, false, nil + } + + switch v := v.(type) { + case bool: + return v, true, nil + case *bool: + if v == nil { + return false, false, nil + } + return *v, true, nil + case string: + bb, err := strconv.ParseBool(v) + if err != nil { + return false, false, err + } + return bb, true, nil + case *string: + if v == nil { + return false, false, nil + } + bb, err := strconv.ParseBool(*v) + if err != nil { + return false, false, err + } + return bb, true, nil + default: + if originalvalue, ok := underlyingBoolType(v); ok { + return convertToBoolForEncode(originalvalue) + } + return false, false, fmt.Errorf("cannot convert %v to bool", v) + } +} + +type scanPlanBinaryBoolToBool struct{} + +func (scanPlanBinaryBoolToBool) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + p, ok := (dst).(*bool) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = src[0] == 1 + + return nil +} + +type scanPlanTextAnyToBool struct{} + +func (scanPlanTextAnyToBool) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + p, ok := (dst).(*bool) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = src[0] == 't' + + return nil +} + +type scanPlanBinaryBoolToBoolScanner struct{} + +func (scanPlanBinaryBoolToBoolScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(BoolScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanBool(false, false) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + return s.ScanBool(src[0] == 1, true) +} + +type scanPlanTextAnyToBoolScanner struct{} + +func (scanPlanTextAnyToBoolScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(BoolScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanBool(false, false) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + return s.ScanBool(src[0] == 't', true) +} diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go deleted file mode 100644 index a282fd6b..00000000 --- a/pgtype/bool_array.go +++ /dev/null @@ -1,504 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type BoolArray struct { - Elements []Bool - Dimensions []ArrayDimension - Valid bool -} - -func (dst *BoolArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = BoolArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []bool: - if value == nil { - *dst = BoolArray{} - } else if len(value) == 0 { - *dst = BoolArray{Valid: true} - } else { - elements := make([]Bool, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BoolArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*bool: - if value == nil { - *dst = BoolArray{} - } else if len(value) == 0 { - *dst = BoolArray{Valid: true} - } else { - elements := make([]Bool, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BoolArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Bool: - if value == nil { - *dst = BoolArray{} - } else if len(value) == 0 { - *dst = BoolArray{Valid: true} - } else { - *dst = BoolArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = BoolArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for BoolArray", src) - } - if elementsLength == 0 { - *dst = BoolArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to BoolArray", src) - } - - *dst = BoolArray{ - Elements: make([]Bool, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Bool, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to BoolArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *BoolArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to BoolArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in BoolArray", err) - } - index++ - - return index, nil -} - -func (dst BoolArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *BoolArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]bool: - *v = make([]bool, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*bool: - *v = make([]*bool, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from BoolArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from BoolArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = BoolArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Bool - - if len(uta.Elements) > 0 { - elements = make([]Bool, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Bool - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = BoolArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = BoolArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = BoolArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Bool, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = BoolArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("bool"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "bool") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *BoolArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src BoolArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/bool_array_test.go b/pgtype/bool_array_test.go deleted file mode 100644 index 7de5612a..00000000 --- a/pgtype/bool_array_test.go +++ /dev/null @@ -1,283 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestBoolArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bool[]", []interface{}{ - &pgtype.BoolArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.BoolArray{}, - &pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {}, - {Bool: false, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestBoolArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.BoolArray - }{ - { - source: []bool{true}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]bool)(nil)), - result: pgtype.BoolArray{}, - }, - { - source: [][]bool{{true}, {false}}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]bool{{true}, {false}}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.BoolArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestBoolArrayAssignTo(t *testing.T) { - var boolSlice []bool - type _boolSlice []bool - var namedBoolSlice _boolSlice - var boolSliceDim2 [][]bool - var boolSliceDim4 [][][][]bool - var boolArrayDim2 [2][1]bool - var boolArrayDim4 [2][1][1][3]bool - - simpleTests := []struct { - src pgtype.BoolArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &boolSlice, - expected: []bool{true}, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &namedBoolSlice, - expected: _boolSlice{true}, - }, - { - src: pgtype.BoolArray{}, - dst: &boolSlice, - expected: (([]bool)(nil)), - }, - { - src: pgtype.BoolArray{Valid: true}, - dst: &boolSlice, - expected: []bool{}, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [][]bool{{true}, {false}}, - dst: &boolSliceDim2, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, - dst: &boolSliceDim4, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [2][1]bool{{true}, {false}}, - dst: &boolArrayDim2, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, - dst: &boolArrayDim4, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.BoolArray - dst interface{} - }{ - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &boolSlice, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &boolArrayDim2, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &boolSlice, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &boolArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index 9a07491f..ec8c31d9 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -1,101 +1,21 @@ package pgtype_test import ( - "reflect" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestBoolTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bool", []interface{}{ - &pgtype.Bool{Bool: false, Valid: true}, - &pgtype.Bool{Bool: true, Valid: true}, - &pgtype.Bool{Bool: false}, +func TestBoolCodec(t *testing.T) { + testPgxCodec(t, "bool", []PgxTranscodeTestCase{ + {true, new(bool), isExpectedEq(true)}, + {false, new(bool), isExpectedEq(false)}, + {true, new(pgtype.Bool), isExpectedEq(pgtype.Bool{Bool: true, Valid: true})}, + {pgtype.Bool{}, new(pgtype.Bool), isExpectedEq(pgtype.Bool{})}, + {nil, new(*bool), isExpectedEq((*bool)(nil))}, }) } -func TestBoolSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Bool - }{ - {source: true, result: pgtype.Bool{Bool: true, Valid: true}}, - {source: false, result: pgtype.Bool{Bool: false, Valid: true}}, - {source: "true", result: pgtype.Bool{Bool: true, Valid: true}}, - {source: "false", result: pgtype.Bool{Bool: false, Valid: true}}, - {source: "t", result: pgtype.Bool{Bool: true, Valid: true}}, - {source: "f", result: pgtype.Bool{Bool: false, Valid: true}}, - {source: _bool(true), result: pgtype.Bool{Bool: true, Valid: true}}, - {source: _bool(false), result: pgtype.Bool{Bool: false, Valid: true}}, - {source: nil, result: pgtype.Bool{}}, - } - - for i, tt := range successfulTests { - var r pgtype.Bool - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestBoolAssignTo(t *testing.T) { - var b bool - var _b _bool - var pb *bool - var _pb *_bool - - simpleTests := []struct { - src pgtype.Bool - dst interface{} - expected interface{} - }{ - {src: pgtype.Bool{Bool: false, Valid: true}, dst: &b, expected: false}, - {src: pgtype.Bool{Bool: true, Valid: true}, dst: &b, expected: true}, - {src: pgtype.Bool{Bool: false, Valid: true}, dst: &_b, expected: _bool(false)}, - {src: pgtype.Bool{Bool: true, Valid: true}, dst: &_b, expected: _bool(true)}, - {src: pgtype.Bool{Bool: false}, dst: &pb, expected: ((*bool)(nil))}, - {src: pgtype.Bool{Bool: false}, dst: &_pb, expected: ((*_bool)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Bool - dst interface{} - expected interface{} - }{ - {src: pgtype.Bool{Bool: true, Valid: true}, dst: &pb, expected: true}, - {src: pgtype.Bool{Bool: true, Valid: true}, dst: &_pb, expected: _bool(true)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} - func TestBoolMarshalJSON(t *testing.T) { successfulTests := []struct { source pgtype.Bool diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index fe3fae44..372f755b 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -292,7 +292,7 @@ func NewConnInfo() *ConnInfo { ci := newConnInfo() ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID}) - ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID}) + ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementCodec: BoolCodec{}, ElementOID: BoolOID}}) ci.RegisterDataType(DataType{Value: &BPCharArray{}, Name: "_bpchar", OID: BPCharArrayOID}) ci.RegisterDataType(DataType{Value: &ByteaArray{}, Name: "_bytea", OID: ByteaArrayOID}) ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID}) @@ -311,7 +311,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &VarcharArray{}, Name: "_varchar", OID: VarcharArrayOID}) ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID}) ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID}) - ci.RegisterDataType(DataType{Value: &Bool{}, Name: "bool", OID: BoolOID}) + ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) ci.RegisterDataType(DataType{Value: &Box{}, Name: "box", OID: BoxOID}) ci.RegisterDataType(DataType{Value: &BPChar{}, Name: "bpchar", OID: BPCharOID}) ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID}) diff --git a/pgtype/zzz.bool.go b/pgtype/zzz.bool.go deleted file mode 100644 index e6ed52de..00000000 --- a/pgtype/zzz.bool.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Bool) BinaryFormatSupported() bool { - return true -} - -func (Bool) TextFormatSupported() bool { - return true -} - -func (Bool) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Bool) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Bool) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} From 298a5f0dca6eac6ae3103ee39482c85702490df5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Jan 2022 20:27:35 -0600 Subject: [PATCH 0800/1158] Convert box to Codec --- pgtype/box.go | 260 +++++++++++++++++++++++++++++---------------- pgtype/box_test.go | 32 ++++-- pgtype/pgtype.go | 2 +- pgtype/zzz.box.go | 35 ------ 4 files changed, 194 insertions(+), 135 deletions(-) delete mode 100644 pgtype/zzz.box.go diff --git a/pgtype/box.go b/pgtype/box.go index 868b40a2..438a4f21 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -11,32 +11,189 @@ import ( "github.com/jackc/pgio" ) +type BoxScanner interface { + ScanBox(v Box) error +} + +type BoxValuer interface { + BoxValue() (Box, error) +} + type Box struct { P [2]Vec2 Valid bool } -func (dst *Box) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Box", src) +func (b *Box) ScanBox(v Box) error { + *b = v + return nil } -func (dst Box) Get() interface{} { - if !dst.Valid { - return nil - } - return dst +func (b Box) BoxValue() (Box, error) { + return b, nil } -func (src *Box) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { +// Scan implements the database/sql Scanner interface. +func (dst *Box) Scan(src interface{}) error { if src == nil { *dst = Box{} return nil } + switch src := src.(type) { + case string: + return scanPlanTextAnyToBoxScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Box) Value() (driver.Value, error) { + buf, err := BoxCodec{}.Encode(nil, 0, TextFormatCode, src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type BoxCodec struct{} + +func (BoxCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (BoxCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (BoxCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + if value == nil { + return nil, nil + } + + var box Box + if v, ok := value.(BoxValuer); ok { + b, err := v.BoxValue() + if err != nil { + return nil, err + } + box = b + } else { + return nil, fmt.Errorf("cannot convert %v to box: %v", value, err) + } + + if !box.Valid { + return nil, nil + } + + switch format { + case BinaryFormatCode: + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].Y)) + return buf, nil + case TextFormatCode: + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(box.P[0].X, 'f', -1, 64), + strconv.FormatFloat(box.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(box.P[1].X, 'f', -1, 64), + strconv.FormatFloat(box.P[1].Y, 'f', -1, 64), + )...) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } +} + +func (BoxCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case BoxScanner: + return scanPlanBinaryBoxToBoxScanner{} + } + case TextFormatCode: + switch target.(type) { + case BoxScanner: + return scanPlanTextAnyToBoxScanner{} + } + } + + return nil +} + +func (c BoxCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if format == TextFormatCode { + return string(src), nil + } else { + box, err := c.DecodeValue(ci, oid, format, src) + if err != nil { + return nil, err + } + buf, err := c.Encode(ci, oid, TextFormatCode, box, nil) + if err != nil { + return nil, err + } + return string(buf), nil + } +} + +func (c BoxCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var box Box + scanPlan := c.PlanScan(ci, oid, format, &box, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &box) + if err != nil { + return nil, err + } + return box, nil +} + +type scanPlanBinaryBoxToBoxScanner struct{} + +func (scanPlanBinaryBoxToBoxScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(BoxScanner) + + if src == nil { + return scanner.ScanBox(Box{}) + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for Box: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + return scanner.ScanBox(Box{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Valid: true, + }) +} + +type scanPlanTextAnyToBoxScanner struct{} + +func (scanPlanTextAnyToBoxScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(BoxScanner) + + if src == nil { + return scanner.ScanBox(Box{}) + } + if len(src) < 11 { return fmt.Errorf("invalid length for Box: %v", len(src)) } @@ -74,82 +231,5 @@ func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true} - return nil -} - -func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Box{} - return nil - } - - if len(src) != 32 { - return fmt.Errorf("invalid length for Box: %v", len(src)) - } - - x1 := binary.BigEndian.Uint64(src) - y1 := binary.BigEndian.Uint64(src[8:]) - x2 := binary.BigEndian.Uint64(src[16:]) - y2 := binary.BigEndian.Uint64(src[24:]) - - *dst = Box{ - P: [2]Vec2{ - {math.Float64frombits(x1), math.Float64frombits(y1)}, - {math.Float64frombits(x2), math.Float64frombits(y2)}, - }, - Valid: true, - } - return nil -} - -func (src Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, - strconv.FormatFloat(src.P[0].X, 'f', -1, 64), - strconv.FormatFloat(src.P[0].Y, 'f', -1, 64), - strconv.FormatFloat(src.P[1].X, 'f', -1, 64), - strconv.FormatFloat(src.P[1].Y, 'f', -1, 64), - )...) - return buf, nil -} - -func (src Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Box) Scan(src interface{}) error { - if src == nil { - *dst = Box{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Box) Value() (driver.Value, error) { - return EncodeValueText(src) + return scanner.ScanBox(Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true}) } diff --git a/pgtype/box_test.go b/pgtype/box_test.go index 481723b5..f4e26370 100644 --- a/pgtype/box_test.go +++ b/pgtype/box_test.go @@ -7,17 +7,31 @@ import ( "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestBoxTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "box", []interface{}{ - &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, - Valid: true, +func TestBoxCodec(t *testing.T) { + testPgxCodec(t, "box", []PgxTranscodeTestCase{ + { + pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, + Valid: true, + }, + new(pgtype.Box), + isExpectedEq(pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, + Valid: true, + }), }, - &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Valid: true, + { + pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {-13.14, -5.234}}, + Valid: true, + }, + new(pgtype.Box), + isExpectedEq(pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {-13.14, -5.234}}, + Valid: true, + }), }, - &pgtype.Box{}, + {nil, new(pgtype.Box), isExpectedEq(pgtype.Box{})}, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 372f755b..b72255ce 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -312,7 +312,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID}) ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID}) ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) - ci.RegisterDataType(DataType{Value: &Box{}, Name: "box", OID: BoxOID}) + ci.RegisterDataType(DataType{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) ci.RegisterDataType(DataType{Value: &BPChar{}, Name: "bpchar", OID: BPCharOID}) ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID}) ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) diff --git a/pgtype/zzz.box.go b/pgtype/zzz.box.go deleted file mode 100644 index 5ca2df43..00000000 --- a/pgtype/zzz.box.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Box) BinaryFormatSupported() bool { - return true -} - -func (Box) TextFormatSupported() bool { - return true -} - -func (Box) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Box) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Box) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} From 5c4560eed3a20eabb88a6ea086ccfc91c89b2295 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Jan 2022 20:30:57 -0600 Subject: [PATCH 0801/1158] Add box array --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index b72255ce..aedc6dd5 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -49,6 +49,7 @@ const ( BPCharArrayOID = 1014 VarcharArrayOID = 1015 Int8ArrayOID = 1016 + BoxArrayOID = 1020 Float4ArrayOID = 1021 Float8ArrayOID = 1022 ACLItemOID = 1033 @@ -303,6 +304,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}}) ci.RegisterDataType(DataType{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementCodec: Int4Codec{}, ElementOID: Int4OID}}) ci.RegisterDataType(DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementCodec: Int8Codec{}, ElementOID: Int8OID}}) + ci.RegisterDataType(DataType{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementCodec: BoxCodec{}, ElementOID: BoxOID}}) ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) ci.RegisterDataType(DataType{Value: &TextArray{}, Name: "_text", OID: TextArrayOID}) ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) From eb2c37a983044f67bdf46e64dd0584309dd2f1fa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Jan 2022 20:53:50 -0600 Subject: [PATCH 0802/1158] Convert circle to Codec --- pgtype/circle.go | 249 ++++++++++++++++++++++++++++-------------- pgtype/circle_test.go | 17 ++- pgtype/pgtype.go | 2 +- pgtype/zzz.circle.go | 35 ------ 4 files changed, 177 insertions(+), 126 deletions(-) delete mode 100644 pgtype/zzz.circle.go diff --git a/pgtype/circle.go b/pgtype/circle.go index 7524d7b9..ec136438 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -11,33 +11,184 @@ import ( "github.com/jackc/pgio" ) +type CircleScanner interface { + ScanCircle(v Circle) error +} + +type CircleValuer interface { + CircleValue() (Circle, error) +} + type Circle struct { P Vec2 R float64 Valid bool } -func (dst *Circle) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Circle", src) +func (c *Circle) ScanCircle(v Circle) error { + *c = v + return nil } -func (dst Circle) Get() interface{} { - if !dst.Valid { - return nil - } - return dst +func (c Circle) CircleValue() (Circle, error) { + return c, nil } -func (src *Circle) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { +// Scan implements the database/sql Scanner interface. +func (dst *Circle) Scan(src interface{}) error { if src == nil { *dst = Circle{} return nil } + switch src := src.(type) { + case string: + return scanPlanTextAnyToCircleScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Circle) Value() (driver.Value, error) { + buf, err := CircleCodec{}.Encode(nil, 0, TextFormatCode, src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type CircleCodec struct{} + +func (CircleCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (CircleCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (CircleCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + if value == nil { + return nil, nil + } + + var circle Circle + if v, ok := value.(CircleValuer); ok { + c, err := v.CircleValue() + if err != nil { + return nil, err + } + circle = c + } else { + return nil, fmt.Errorf("cannot convert %v to circle: %v", value, err) + } + + if !circle.Valid { + return nil, nil + } + + switch format { + case BinaryFormatCode: + buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(circle.R)) + return buf, nil + case TextFormatCode: + buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, + strconv.FormatFloat(circle.P.X, 'f', -1, 64), + strconv.FormatFloat(circle.P.Y, 'f', -1, 64), + strconv.FormatFloat(circle.R, 'f', -1, 64), + )...) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } +} + +func (CircleCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case CircleScanner: + return scanPlanBinaryCircleToCircleScanner{} + } + case TextFormatCode: + switch target.(type) { + case CircleScanner: + return scanPlanTextAnyToCircleScanner{} + } + } + + return nil +} + +func (c CircleCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if format == TextFormatCode { + return string(src), nil + } else { + circle, err := c.DecodeValue(ci, oid, format, src) + if err != nil { + return nil, err + } + buf, err := c.Encode(ci, oid, TextFormatCode, circle, nil) + if err != nil { + return nil, err + } + return string(buf), nil + } +} + +func (c CircleCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var circle Circle + scanPlan := c.PlanScan(ci, oid, format, &circle, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &circle) + if err != nil { + return nil, err + } + return circle, nil +} + +type scanPlanBinaryCircleToCircleScanner struct{} + +func (scanPlanBinaryCircleToCircleScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(CircleScanner) + + if src == nil { + return scanner.ScanCircle(Circle{}) + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + r := binary.BigEndian.Uint64(src[16:]) + + return scanner.ScanCircle(Circle{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + R: math.Float64frombits(r), + Valid: true, + }) +} + +type scanPlanTextAnyToCircleScanner struct{} + +func (scanPlanTextAnyToCircleScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(CircleScanner) + + if src == nil { + return scanner.ScanCircle(Circle{}) + } + if len(src) < 9 { return fmt.Errorf("invalid length for Circle: %v", len(src)) } @@ -64,77 +215,5 @@ func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Circle{P: Vec2{x, y}, R: r, Valid: true} - return nil -} - -func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Circle{} - return nil - } - - if len(src) != 24 { - return fmt.Errorf("invalid length for Circle: %v", len(src)) - } - - x := binary.BigEndian.Uint64(src) - y := binary.BigEndian.Uint64(src[8:]) - r := binary.BigEndian.Uint64(src[16:]) - - *dst = Circle{ - P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, - R: math.Float64frombits(r), - Valid: true, - } - return nil -} - -func (src Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, - strconv.FormatFloat(src.P.X, 'f', -1, 64), - strconv.FormatFloat(src.P.Y, 'f', -1, 64), - strconv.FormatFloat(src.R, 'f', -1, 64), - )...) - - return buf, nil -} - -func (src Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.R)) - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Circle) Scan(src interface{}) error { - if src == nil { - *dst = Circle{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Circle) Value() (driver.Value, error) { - return EncodeValueText(src) + return scanner.ScanCircle(Circle{P: Vec2{x, y}, R: r, Valid: true}) } diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go index 8f39644b..742ac688 100644 --- a/pgtype/circle_test.go +++ b/pgtype/circle_test.go @@ -4,13 +4,20 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestCircleTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "circle", []interface{}{ - &pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, - &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Valid: true}, - &pgtype.Circle{}, + testPgxCodec(t, "circle", []PgxTranscodeTestCase{ + { + pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, + new(pgtype.Circle), + isExpectedEq(pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}), + }, + { + pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, + new(pgtype.Circle), + isExpectedEq(pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}), + }, + {nil, new(pgtype.Circle), isExpectedEq(pgtype.Circle{})}, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index aedc6dd5..11f7ce0b 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -320,7 +320,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID}) ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) - ci.RegisterDataType(DataType{Value: &Circle{}, Name: "circle", OID: CircleOID}) + ci.RegisterDataType(DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID}) // ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) diff --git a/pgtype/zzz.circle.go b/pgtype/zzz.circle.go deleted file mode 100644 index b111c06d..00000000 --- a/pgtype/zzz.circle.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Circle) BinaryFormatSupported() bool { - return true -} - -func (Circle) TextFormatSupported() bool { - return true -} - -func (Circle) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Circle) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Circle) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} From ad6ee2bd56187870b342436c3b33b1cb4755c5fe Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Jan 2022 20:56:12 -0600 Subject: [PATCH 0803/1158] Add circle array --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 11f7ce0b..6221227d 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -38,6 +38,7 @@ const ( Float4OID = 700 Float8OID = 701 CircleOID = 718 + CircleArrayOID = 719 UnknownOID = 705 MacaddrOID = 829 InetOID = 869 @@ -305,6 +306,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementCodec: Int4Codec{}, ElementOID: Int4OID}}) ci.RegisterDataType(DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementCodec: Int8Codec{}, ElementOID: Int8OID}}) ci.RegisterDataType(DataType{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementCodec: BoxCodec{}, ElementOID: BoxOID}}) + ci.RegisterDataType(DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementCodec: CircleCodec{}, ElementOID: CircleOID}}) ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) ci.RegisterDataType(DataType{Value: &TextArray{}, Name: "_text", OID: TextArrayOID}) ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) From f7c0c31e8785764fc8ed38ddde4d5aa6cc695376 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Jan 2022 21:19:00 -0600 Subject: [PATCH 0804/1158] Extract DecodeValue helper --- pgtype/bool.go | 6 +----- pgtype/box.go | 6 +----- pgtype/circle.go | 6 +----- pgtype/int.go | 18 +++--------------- pgtype/int.go.erb | 6 +----- pgtype/pgtype.go | 8 ++++++++ 6 files changed, 15 insertions(+), 35 deletions(-) diff --git a/pgtype/bool.go b/pgtype/bool.go index 4b6fbaf2..36d29d40 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -169,11 +169,7 @@ func (c BoolCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt } var b bool - scanPlan := c.PlanScan(ci, oid, format, &b, true) - if scanPlan == nil { - return nil, fmt.Errorf("PlanScan did not find a plan") - } - err := scanPlan.Scan(ci, oid, format, src, &b) + err := codecScan(c, ci, oid, format, src, &b) if err != nil { return nil, err } diff --git a/pgtype/box.go b/pgtype/box.go index 438a4f21..67e24237 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -147,11 +147,7 @@ func (c BoxCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte } var box Box - scanPlan := c.PlanScan(ci, oid, format, &box, true) - if scanPlan == nil { - return nil, fmt.Errorf("PlanScan did not find a plan") - } - err := scanPlan.Scan(ci, oid, format, src, &box) + err := codecScan(c, ci, oid, format, src, &box) if err != nil { return nil, err } diff --git a/pgtype/circle.go b/pgtype/circle.go index ec136438..83c97453 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -145,11 +145,7 @@ func (c CircleCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []b } var circle Circle - scanPlan := c.PlanScan(ci, oid, format, &circle, true) - if scanPlan == nil { - return nil, fmt.Errorf("PlanScan did not find a plan") - } - err := scanPlan.Scan(ci, oid, format, src, &circle) + err := codecScan(c, ci, oid, format, src, &circle) if err != nil { return nil, err } diff --git a/pgtype/int.go b/pgtype/int.go index c8ec7509..609a58a3 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -226,11 +226,7 @@ func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt } var n int16 - scanPlan := c.PlanScan(ci, oid, format, &n, true) - if scanPlan == nil { - return nil, fmt.Errorf("PlanScan did not find a plan") - } - err := scanPlan.Scan(ci, oid, format, src, &n) + err := codecScan(c, ci, oid, format, src, &n) if err != nil { return nil, err } @@ -714,11 +710,7 @@ func (c Int4Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt } var n int32 - scanPlan := c.PlanScan(ci, oid, format, &n, true) - if scanPlan == nil { - return nil, fmt.Errorf("PlanScan did not find a plan") - } - err := scanPlan.Scan(ci, oid, format, src, &n) + err := codecScan(c, ci, oid, format, src, &n) if err != nil { return nil, err } @@ -1213,11 +1205,7 @@ func (c Int8Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt } var n int64 - scanPlan := c.PlanScan(ci, oid, format, &n, true) - if scanPlan == nil { - return nil, fmt.Errorf("PlanScan did not find a plan") - } - err := scanPlan.Scan(ci, oid, format, src, &n) + err := codecScan(c, ci, oid, format, src, &n) if err != nil { return nil, err } diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 99659d4c..6803c2ea 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -227,11 +227,7 @@ func (c Int<%= pg_byte_size %>Codec) DecodeValue(ci *ConnInfo, oid uint32, forma } var n int<%= pg_bit_size %> - scanPlan := c.PlanScan(ci, oid, format, &n, true) - if scanPlan == nil { - return nil, fmt.Errorf("PlanScan did not find a plan") - } - err := scanPlan.Scan(ci, oid, format, src, &n) + err := codecScan(c, ci, oid, format, src, &n) if err != nil { return nil, err } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 6221227d..d5e52830 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -960,3 +960,11 @@ func NewValue(v Value) Value { } var ErrScanTargetTypeChanged = errors.New("scan target type changed") + +func codecScan(codec Codec, ci *ConnInfo, oid uint32, format int16, src []byte, dst interface{}) error { + scanPlan := codec.PlanScan(ci, oid, format, dst, true) + if scanPlan == nil { + return fmt.Errorf("PlanScan did not find a plan") + } + return scanPlan.Scan(ci, oid, format, src, dst) +} From 6a32f938f1d733652a0f4cb8a2a6aeb5062b968f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Jan 2022 21:23:29 -0600 Subject: [PATCH 0805/1158] Extract codecDecodeToTextFormat --- pgtype/box.go | 14 +------------- pgtype/circle.go | 14 +------------- pgtype/int.go | 18 +++--------------- pgtype/int.go.erb | 6 +----- pgtype/pgtype.go | 16 ++++++++++++++++ 5 files changed, 22 insertions(+), 46 deletions(-) diff --git a/pgtype/box.go b/pgtype/box.go index 67e24237..7db7d5a2 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -126,19 +126,7 @@ func (BoxCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfac } func (c BoxCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - if format == TextFormatCode { - return string(src), nil - } else { - box, err := c.DecodeValue(ci, oid, format, src) - if err != nil { - return nil, err - } - buf, err := c.Encode(ci, oid, TextFormatCode, box, nil) - if err != nil { - return nil, err - } - return string(buf), nil - } + return codecDecodeToTextFormat(c, ci, oid, format, src) } func (c BoxCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { diff --git a/pgtype/circle.go b/pgtype/circle.go index 83c97453..f1f66175 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -124,19 +124,7 @@ func (CircleCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target inter } func (c CircleCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - if format == TextFormatCode { - return string(src), nil - } else { - circle, err := c.DecodeValue(ci, oid, format, src) - if err != nil { - return nil, err - } - buf, err := c.Encode(ci, oid, TextFormatCode, circle, nil) - if err != nil { - return nil, err - } - return string(buf), nil - } + return codecDecodeToTextFormat(c, ci, oid, format, src) } func (c CircleCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { diff --git a/pgtype/int.go b/pgtype/int.go index 609a58a3..21259beb 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -209,11 +209,7 @@ func (c Int2Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16 } var n int64 - scanPlan := c.PlanScan(ci, oid, format, &n, true) - if scanPlan == nil { - return nil, fmt.Errorf("PlanScan did not find a plan") - } - err := scanPlan.Scan(ci, oid, format, src, &n) + err := codecScan(c, ci, oid, format, src, &n) if err != nil { return nil, err } @@ -693,11 +689,7 @@ func (c Int4Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16 } var n int64 - scanPlan := c.PlanScan(ci, oid, format, &n, true) - if scanPlan == nil { - return nil, fmt.Errorf("PlanScan did not find a plan") - } - err := scanPlan.Scan(ci, oid, format, src, &n) + err := codecScan(c, ci, oid, format, src, &n) if err != nil { return nil, err } @@ -1188,11 +1180,7 @@ func (c Int8Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16 } var n int64 - scanPlan := c.PlanScan(ci, oid, format, &n, true) - if scanPlan == nil { - return nil, fmt.Errorf("PlanScan did not find a plan") - } - err := scanPlan.Scan(ci, oid, format, src, &n) + err := codecScan(c, ci, oid, format, src, &n) if err != nil { return nil, err } diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 6803c2ea..546494d4 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -210,11 +210,7 @@ func (c Int<%= pg_byte_size %>Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid ui } var n int64 - scanPlan := c.PlanScan(ci, oid, format, &n, true) - if scanPlan == nil { - return nil, fmt.Errorf("PlanScan did not find a plan") - } - err := scanPlan.Scan(ci, oid, format, src, &n) + err := codecScan(c, ci, oid, format, src, &n) if err != nil { return nil, err } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d5e52830..5e924e1b 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -968,3 +968,19 @@ func codecScan(codec Codec, ci *ConnInfo, oid uint32, format int16, src []byte, } return scanPlan.Scan(ci, oid, format, src, dst) } + +func codecDecodeToTextFormat(codec Codec, ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if format == TextFormatCode { + return string(src), nil + } else { + value, err := codec.DecodeValue(ci, oid, format, src) + if err != nil { + return nil, err + } + buf, err := codec.Encode(ci, oid, TextFormatCode, value, nil) + if err != nil { + return nil, err + } + return string(buf), nil + } +} From 80ae29d056efafc02740defe9358c0f8c6b1721b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Jan 2022 19:56:16 -0600 Subject: [PATCH 0806/1158] Inline Encoder interface to Codec --- pgtype/pgtype.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5e924e1b..db706369 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -177,13 +177,6 @@ type ResultDecoder interface { DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error } -type Encoder interface { - // Encode appends the encoded bytes of value to buf. If value is the SQL NULL then append nothing and return - // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data - // written. - Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) -} - type Codec interface { // FormatSupported returns true if the format is supported. FormatSupported(int16) bool @@ -191,7 +184,10 @@ type Codec interface { // PreferredFormat returns the preferred format. PreferredFormat() int16 - Encoder + // Encode appends the encoded bytes of value to buf. If value is the SQL NULL then append nothing and return + // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data + // written. + Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) // PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If // actualTarget is true then the returned ScanPlan may be optimized to directly scan into target. If no plan can be From b90f92d2d2543cd4d6d600546714da85bac7b530 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Jan 2022 19:58:40 -0600 Subject: [PATCH 0807/1158] Remove obsolute ArrayType --- pgtype/array_type.go | 368 -------------------------------------- pgtype/array_type_test.go | 84 --------- pgtype/pgxtype/pgxtype.go | 31 ++-- 3 files changed, 16 insertions(+), 467 deletions(-) delete mode 100644 pgtype/array_type.go delete mode 100644 pgtype/array_type_test.go diff --git a/pgtype/array_type.go b/pgtype/array_type.go deleted file mode 100644 index c4f162af..00000000 --- a/pgtype/array_type.go +++ /dev/null @@ -1,368 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -// ArrayType represents an array type. While it implements Value, this is only in service of its type conversion duties -// when registered as a data type in a ConnType. It should not be used directly as a Value. ArrayType is a convenience -// type for types that do not have an concrete array type. -type ArrayType struct { - elements []ValueTranscoder - dimensions []ArrayDimension - - typeName string - newElement func() ValueTranscoder - - elementOID uint32 - valid bool -} - -func NewArrayType(typeName string, elementOID uint32, newElement func() ValueTranscoder) *ArrayType { - return &ArrayType{typeName: typeName, elementOID: elementOID, newElement: newElement} -} - -func (at *ArrayType) NewTypeValue() Value { - return &ArrayType{ - elements: at.elements, - dimensions: at.dimensions, - valid: at.valid, - - typeName: at.typeName, - elementOID: at.elementOID, - newElement: at.newElement, - } -} - -func (at *ArrayType) TypeName() string { - return at.typeName -} - -func (dst *ArrayType) setNil() { - dst.elements = nil - dst.dimensions = nil - dst.valid = false -} - -func (dst *ArrayType) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - dst.setNil() - return nil - } - - sliceVal := reflect.ValueOf(src) - if sliceVal.Kind() != reflect.Slice { - return fmt.Errorf("cannot set non-slice") - } - - if sliceVal.IsNil() { - dst.setNil() - return nil - } - - dst.elements = make([]ValueTranscoder, sliceVal.Len()) - for i := range dst.elements { - v := dst.newElement() - err := v.Set(sliceVal.Index(i).Interface()) - if err != nil { - return err - } - - dst.elements[i] = v - } - dst.dimensions = []ArrayDimension{{Length: int32(len(dst.elements)), LowerBound: 1}} - dst.valid = true - - return nil -} - -func (src ArrayType) Get() interface{} { - if !src.valid { - return nil - } - - elementValues := make([]interface{}, len(src.elements)) - for i := range src.elements { - elementValues[i] = src.elements[i].Get() - } - return elementValues -} - -func (src *ArrayType) AssignTo(dst interface{}) error { - ptrSlice := reflect.ValueOf(dst) - if ptrSlice.Kind() != reflect.Ptr { - return fmt.Errorf("cannot assign to non-pointer") - } - - sliceVal := ptrSlice.Elem() - sliceType := sliceVal.Type() - - if sliceType.Kind() != reflect.Slice { - return fmt.Errorf("cannot assign to pointer to non-slice") - } - - if src.valid { - slice := reflect.MakeSlice(sliceType, len(src.elements), len(src.elements)) - elemType := sliceType.Elem() - - for i := range src.elements { - ptrElem := reflect.New(elemType) - err := src.elements[i].AssignTo(ptrElem.Interface()) - if err != nil { - return err - } - - slice.Index(i).Set(ptrElem.Elem()) - } - - sliceVal.Set(slice) - return nil - } else { - sliceVal.Set(reflect.Zero(sliceType)) - return nil - } -} - -func (ArrayType) BinaryFormatSupported() bool { - return true -} - -func (ArrayType) TextFormatSupported() bool { - return true -} - -func (ArrayType) PreferredFormat() int16 { - return TextFormatCode -} - -func (dst *ArrayType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - if src == nil { - dst.setNil() - return nil - } - - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src ArrayType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} - -func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []ValueTranscoder - - if len(uta.Elements) > 0 { - elements = make([]ValueTranscoder, len(uta.Elements)) - - for i, s := range uta.Elements { - elem := dst.newElement() - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeResult(ci, dst.elementOID, TextFormatCode, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - dst.elements = elements - dst.dimensions = uta.Dimensions - dst.valid = true - - return nil -} - -func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error { - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - var elements []ValueTranscoder - - if len(arrayHeader.Dimensions) == 0 { - dst.elements = elements - dst.dimensions = arrayHeader.Dimensions - dst.valid = true - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements = make([]ValueTranscoder, elementCount) - - for i := range elements { - elem := dst.newElement() - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elem.DecodeResult(ci, dst.elementOID, BinaryFormatCode, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - - dst.elements = elements - dst.dimensions = arrayHeader.Dimensions - dst.valid = true - - return nil -} - -func (src ArrayType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.valid { - return nil, nil - } - - if len(src.dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.dimensions)) - dimElemCounts[len(src.dimensions)-1] = int(src.dimensions[len(src.dimensions)-1].Length) - for i := len(src.dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeParam(ci, src.elementOID, TextFormatCode, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src ArrayType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.dimensions, - ElementOID: int32(src.elementOID), - } - - for i := range src.elements { - if src.elements[i].Get() == nil { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.elements[i].EncodeParam(ci, src.elementOID, BinaryFormatCode, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *ArrayType) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src ArrayType) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/array_type_test.go b/pgtype/array_type_test.go deleted file mode 100644 index 3ea5bc79..00000000 --- a/pgtype/array_type_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package pgtype_test - -import ( - "context" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/stretchr/testify/require" -) - -func TestArrayTypeValue(t *testing.T) { - arrayType := pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }) - - err := arrayType.Set(nil) - require.NoError(t, err) - - gotValue := arrayType.Get() - require.Nil(t, gotValue) - - slice := []string{"foo", "bar"} - err = arrayType.AssignTo(&slice) - require.NoError(t, err) - require.Nil(t, slice) - - err = arrayType.Set([]string{}) - require.NoError(t, err) - - gotValue = arrayType.Get() - require.Len(t, gotValue, 0) - - err = arrayType.AssignTo(&slice) - require.NoError(t, err) - require.EqualValues(t, []string{}, slice) - - err = arrayType.Set([]string{"baz", "quz"}) - require.NoError(t, err) - - gotValue = arrayType.Get() - require.Len(t, gotValue, 2) - - err = arrayType.AssignTo(&slice) - require.NoError(t, err) - require.EqualValues(t, []string{"baz", "quz"}, slice) -} - -func TestArrayTypeTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), - Name: "_text", - OID: pgtype.TextArrayOID, - }) - - var dstStrings []string - err := conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) - require.NoError(t, err) - - require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) -} - -func TestArrayTypeEmptyArrayDoesNotBreakArrayType(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), - Name: "_text", - OID: pgtype.TextArrayOID, - }) - - var dstStrings []string - err := conn.QueryRow(context.Background(), "select '{}'::text[]").Scan(&dstStrings) - require.NoError(t, err) - - require.EqualValues(t, []string{}, dstStrings) - - err = conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) - require.NoError(t, err) - - require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) -} diff --git a/pgtype/pgxtype/pgxtype.go b/pgtype/pgxtype/pgxtype.go index 49df7059..4f2c5796 100644 --- a/pgtype/pgxtype/pgxtype.go +++ b/pgtype/pgxtype/pgxtype.go @@ -34,24 +34,25 @@ func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeNa switch typtype { case "b": // array - elementOID, err := GetArrayElementOID(ctx, conn, oid) - if err != nil { - return pgtype.DataType{}, err - } + panic("TODO - restore array support") + // elementOID, err := GetArrayElementOID(ctx, conn, oid) + // if err != nil { + // return pgtype.DataType{}, err + // } - var element pgtype.ValueTranscoder - if dt, ok := ci.DataTypeForOID(elementOID); ok { - if element, ok = dt.Value.(pgtype.ValueTranscoder); !ok { - return pgtype.DataType{}, errors.New("array element OID not registered as ValueTranscoder") - } - } + // var element pgtype.ValueTranscoder + // if dt, ok := ci.DataTypeForOID(elementOID); ok { + // if element, ok = dt.Value.(pgtype.ValueTranscoder); !ok { + // return pgtype.DataType{}, errors.New("array element OID not registered as ValueTranscoder") + // } + // } - newElement := func() pgtype.ValueTranscoder { - return pgtype.NewValue(element).(pgtype.ValueTranscoder) - } + // newElement := func() pgtype.ValueTranscoder { + // return pgtype.NewValue(element).(pgtype.ValueTranscoder) + // } - at := pgtype.NewArrayType(typeName, elementOID, newElement) - return pgtype.DataType{Value: at, Name: typeName, OID: oid}, nil + // at := pgtype.NewArrayType(typeName, elementOID, newElement) + // return pgtype.DataType{Value: at, Name: typeName, OID: oid}, nil case "c": // composite panic("TODO - restore composite support") // fields, err := GetCompositeFields(ctx, conn, oid) From 1a189db041cc05770fc2f077f6c0195f6b480e4f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Jan 2022 19:59:32 -0600 Subject: [PATCH 0808/1158] Remove ValueTranscoder interface --- pgtype/pgtype.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index db706369..31cf6038 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -148,14 +148,6 @@ type TypeValue interface { TypeName() string } -// ValueTranscoder is a value that implements the text and binary encoding and decoding interfaces. -type ValueTranscoder interface { - Value - FormatSupport - ParamEncoder - ResultDecoder -} - type FormatSupport interface { BinaryFormatSupported() bool TextFormatSupported() bool From ac80fa5b33c08f4fce1c89ccfdb06d47ce6dd32f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Jan 2022 20:04:48 -0600 Subject: [PATCH 0809/1158] Remove proposed v5 type system before Codec --- pgtype/pgtype.go | 58 +++++++----------------------------- pgtype/pgtype_test.go | 10 ------- pgtype/zzz.aclitem.go | 35 ---------------------- pgtype/zzz.bit.go | 35 ---------------------- pgtype/zzz.bpchar.go | 35 ---------------------- pgtype/zzz.bytea.go | 35 ---------------------- pgtype/zzz.cid.go | 35 ---------------------- pgtype/zzz.cidr.go | 35 ---------------------- pgtype/zzz.date.go | 35 ---------------------- pgtype/zzz.float4.go | 35 ---------------------- pgtype/zzz.float8.go | 35 ---------------------- pgtype/zzz.generic_binary.go | 35 ---------------------- pgtype/zzz.generic_text.go | 35 ---------------------- pgtype/zzz.hstore.go | 35 ---------------------- pgtype/zzz.inet.go | 35 ---------------------- pgtype/zzz.interval.go | 35 ---------------------- pgtype/zzz.json.go | 35 ---------------------- pgtype/zzz.jsonb.go | 35 ---------------------- pgtype/zzz.line.go | 35 ---------------------- pgtype/zzz.lseg.go | 35 ---------------------- pgtype/zzz.macadder.go | 35 ---------------------- pgtype/zzz.name.go | 35 ---------------------- pgtype/zzz.numeric.go | 35 ---------------------- pgtype/zzz.oid.go | 35 ---------------------- pgtype/zzz.oid_value.go | 35 ---------------------- pgtype/zzz.path.go | 35 ---------------------- pgtype/zzz.pguint32.go | 35 ---------------------- pgtype/zzz.point.go | 35 ---------------------- pgtype/zzz.polygon.go | 35 ---------------------- pgtype/zzz.qchar.go | 35 ---------------------- pgtype/zzz.text.go | 35 ---------------------- pgtype/zzz.tid.go | 35 ---------------------- pgtype/zzz.time.go | 35 ---------------------- pgtype/zzz.timestamp.go | 35 ---------------------- pgtype/zzz.timestamptz.go | 35 ---------------------- pgtype/zzz.uuid.go | 35 ---------------------- pgtype/zzz.varbit.go | 35 ---------------------- pgtype/zzz.varchar.go | 35 ---------------------- pgtype/zzz.xid.go | 35 ---------------------- values.go | 4 +-- 40 files changed, 12 insertions(+), 1355 deletions(-) delete mode 100644 pgtype/zzz.aclitem.go delete mode 100644 pgtype/zzz.bit.go delete mode 100644 pgtype/zzz.bpchar.go delete mode 100644 pgtype/zzz.bytea.go delete mode 100644 pgtype/zzz.cid.go delete mode 100644 pgtype/zzz.cidr.go delete mode 100644 pgtype/zzz.date.go delete mode 100644 pgtype/zzz.float4.go delete mode 100644 pgtype/zzz.float8.go delete mode 100644 pgtype/zzz.generic_binary.go delete mode 100644 pgtype/zzz.generic_text.go delete mode 100644 pgtype/zzz.hstore.go delete mode 100644 pgtype/zzz.inet.go delete mode 100644 pgtype/zzz.interval.go delete mode 100644 pgtype/zzz.json.go delete mode 100644 pgtype/zzz.jsonb.go delete mode 100644 pgtype/zzz.line.go delete mode 100644 pgtype/zzz.lseg.go delete mode 100644 pgtype/zzz.macadder.go delete mode 100644 pgtype/zzz.name.go delete mode 100644 pgtype/zzz.numeric.go delete mode 100644 pgtype/zzz.oid.go delete mode 100644 pgtype/zzz.oid_value.go delete mode 100644 pgtype/zzz.path.go delete mode 100644 pgtype/zzz.pguint32.go delete mode 100644 pgtype/zzz.point.go delete mode 100644 pgtype/zzz.polygon.go delete mode 100644 pgtype/zzz.qchar.go delete mode 100644 pgtype/zzz.text.go delete mode 100644 pgtype/zzz.tid.go delete mode 100644 pgtype/zzz.time.go delete mode 100644 pgtype/zzz.timestamp.go delete mode 100644 pgtype/zzz.timestamptz.go delete mode 100644 pgtype/zzz.uuid.go delete mode 100644 pgtype/zzz.varbit.go delete mode 100644 pgtype/zzz.varchar.go delete mode 100644 pgtype/zzz.xid.go diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 31cf6038..3d863373 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -148,27 +148,6 @@ type TypeValue interface { TypeName() string } -type FormatSupport interface { - BinaryFormatSupported() bool - TextFormatSupported() bool - PreferredFormat() int16 -} - -type ParamEncoder interface { - // EncodeParam should append the encoded value of self to buf. If self is the - // SQL value NULL then append nothing and return (nil, nil). The caller of - // EncodeText is responsible for writing the correct NULL value or the - // length of the data written. - EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) -} - -type ResultDecoder interface { - // DecodeResult decodes src into ResultDecoder. If src is nil then the - // original SQL value is NULL. ResultDecoder takes ownership of src. The - // caller MUST not use it again. - DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error -} - type Codec interface { // FormatSupported returns true if the format is supported. FormatSupported(int16) bool @@ -244,8 +223,6 @@ func (e *nullAssignmentError) Error() string { type DataType struct { Value Value - resultDecoder ResultDecoder - textDecoder TextDecoder binaryDecoder BinaryDecoder @@ -402,18 +379,12 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { var formatCode int16 if t.Codec != nil { formatCode = t.Codec.PreferredFormat() - } else if pfp, ok := t.Value.(FormatSupport); ok { - formatCode = pfp.PreferredFormat() } else if _, ok := t.Value.(BinaryEncoder); ok { formatCode = BinaryFormatCode } ci.oidToFormatCode[t.OID] = formatCode } - if d, ok := t.Value.(ResultDecoder); ok { - t.resultDecoder = d - } - if d, ok := t.Value.(TextDecoder); ok { t.textDecoder = d } @@ -501,10 +472,6 @@ type ScanPlan interface { type scanPlanDstResultDecoder struct{} func (scanPlanDstResultDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if d, ok := (dst).(ResultDecoder); ok { - return d.DecodeResult(ci, oid, formatCode, src) - } - newPlan := ci.PlanScan(oid, formatCode, dst) return newPlan.Scan(ci, oid, formatCode, src, dst) } @@ -571,21 +538,18 @@ type scanPlanDataTypeAssignTo DataType func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { dt := (*DataType)(plan) var err error - if dt.resultDecoder != nil { - err = dt.resultDecoder.DecodeResult(ci, oid, formatCode, src) - } else { - switch formatCode { - case BinaryFormatCode: - if dt.binaryDecoder == nil { - return fmt.Errorf("dt.binaryDecoder is nil") - } - err = dt.binaryDecoder.DecodeBinary(ci, src) - case TextFormatCode: - if dt.textDecoder == nil { - return fmt.Errorf("dt.textDecoder is nil") - } - err = dt.textDecoder.DecodeText(ci, src) + + switch formatCode { + case BinaryFormatCode: + if dt.binaryDecoder == nil { + return fmt.Errorf("dt.binaryDecoder is nil") } + err = dt.binaryDecoder.DecodeBinary(ci, src) + case TextFormatCode: + if dt.textDecoder == nil { + return fmt.Errorf("dt.textDecoder is nil") + } + err = dt.textDecoder.DecodeText(ci, src) } if err != nil { return err diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index b9dbda9d..56064281 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -71,16 +71,6 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { return addr } -func TestConnInfoFormatCodeForOID(t *testing.T) { - ci := pgtype.NewConnInfo() - - // pgtype.JSONB implements BinaryEncoder but also implements ParamFormatPreferrer to override it to text. - assert.Equal(t, int16(pgtype.TextFormatCode), ci.FormatCodeForOID(pgtype.JSONBOID)) - - // pgtype.Int4 implements BinaryEncoder but does not implement ParamFormatPreferrer so it should be binary. - assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.FormatCodeForOID(pgtype.Int4OID)) -} - func TestConnInfoScanNilIsNoOp(t *testing.T) { ci := pgtype.NewConnInfo() diff --git a/pgtype/zzz.aclitem.go b/pgtype/zzz.aclitem.go deleted file mode 100644 index 6ac1f94a..00000000 --- a/pgtype/zzz.aclitem.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (ACLItem) BinaryFormatSupported() bool { - return true -} - -func (ACLItem) TextFormatSupported() bool { - return true -} - -func (ACLItem) PreferredFormat() int16 { - return TextFormatCode -} - -func (dst *ACLItem) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return fmt.Errorf("binary format not supported for %T", dst) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src ACLItem) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return nil, fmt.Errorf("binary format not supported for %T", src) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.bit.go b/pgtype/zzz.bit.go deleted file mode 100644 index e95df74d..00000000 --- a/pgtype/zzz.bit.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Bit) BinaryFormatSupported() bool { - return true -} - -func (Bit) TextFormatSupported() bool { - return true -} - -func (Bit) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Bit) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Bit) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.bpchar.go b/pgtype/zzz.bpchar.go deleted file mode 100644 index c3178670..00000000 --- a/pgtype/zzz.bpchar.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (BPChar) BinaryFormatSupported() bool { - return true -} - -func (BPChar) TextFormatSupported() bool { - return true -} - -func (BPChar) PreferredFormat() int16 { - return TextFormatCode -} - -func (dst *BPChar) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src BPChar) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.bytea.go b/pgtype/zzz.bytea.go deleted file mode 100644 index 4da5ad4f..00000000 --- a/pgtype/zzz.bytea.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Bytea) BinaryFormatSupported() bool { - return true -} - -func (Bytea) TextFormatSupported() bool { - return true -} - -func (Bytea) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Bytea) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Bytea) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.cid.go b/pgtype/zzz.cid.go deleted file mode 100644 index 4cb9671d..00000000 --- a/pgtype/zzz.cid.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (CID) BinaryFormatSupported() bool { - return true -} - -func (CID) TextFormatSupported() bool { - return true -} - -func (CID) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *CID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src CID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.cidr.go b/pgtype/zzz.cidr.go deleted file mode 100644 index 714908e0..00000000 --- a/pgtype/zzz.cidr.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (CIDR) BinaryFormatSupported() bool { - return true -} - -func (CIDR) TextFormatSupported() bool { - return true -} - -func (CIDR) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *CIDR) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src CIDR) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.date.go b/pgtype/zzz.date.go deleted file mode 100644 index 66132082..00000000 --- a/pgtype/zzz.date.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Date) BinaryFormatSupported() bool { - return true -} - -func (Date) TextFormatSupported() bool { - return true -} - -func (Date) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Date) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Date) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.float4.go b/pgtype/zzz.float4.go deleted file mode 100644 index b600805e..00000000 --- a/pgtype/zzz.float4.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Float4) BinaryFormatSupported() bool { - return true -} - -func (Float4) TextFormatSupported() bool { - return true -} - -func (Float4) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Float4) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Float4) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.float8.go b/pgtype/zzz.float8.go deleted file mode 100644 index dd3ba0fa..00000000 --- a/pgtype/zzz.float8.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Float8) BinaryFormatSupported() bool { - return true -} - -func (Float8) TextFormatSupported() bool { - return true -} - -func (Float8) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Float8) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Float8) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.generic_binary.go b/pgtype/zzz.generic_binary.go deleted file mode 100644 index b50f1f45..00000000 --- a/pgtype/zzz.generic_binary.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (GenericBinary) BinaryFormatSupported() bool { - return true -} - -func (GenericBinary) TextFormatSupported() bool { - return true -} - -func (GenericBinary) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *GenericBinary) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return fmt.Errorf("text format not supported for %T", dst) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src GenericBinary) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return nil, fmt.Errorf("text format not supported for %T", src) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.generic_text.go b/pgtype/zzz.generic_text.go deleted file mode 100644 index 5ab771cf..00000000 --- a/pgtype/zzz.generic_text.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (GenericText) BinaryFormatSupported() bool { - return true -} - -func (GenericText) TextFormatSupported() bool { - return true -} - -func (GenericText) PreferredFormat() int16 { - return TextFormatCode -} - -func (dst *GenericText) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return fmt.Errorf("binary format not supported for %T", dst) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src GenericText) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return nil, fmt.Errorf("binary format not supported for %T", src) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.hstore.go b/pgtype/zzz.hstore.go deleted file mode 100644 index ebd7bdee..00000000 --- a/pgtype/zzz.hstore.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Hstore) BinaryFormatSupported() bool { - return true -} - -func (Hstore) TextFormatSupported() bool { - return true -} - -func (Hstore) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Hstore) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Hstore) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.inet.go b/pgtype/zzz.inet.go deleted file mode 100644 index 51daeee6..00000000 --- a/pgtype/zzz.inet.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Inet) BinaryFormatSupported() bool { - return true -} - -func (Inet) TextFormatSupported() bool { - return true -} - -func (Inet) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Inet) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Inet) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.interval.go b/pgtype/zzz.interval.go deleted file mode 100644 index a34f2d59..00000000 --- a/pgtype/zzz.interval.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Interval) BinaryFormatSupported() bool { - return true -} - -func (Interval) TextFormatSupported() bool { - return true -} - -func (Interval) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Interval) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Interval) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.json.go b/pgtype/zzz.json.go deleted file mode 100644 index 40a736c9..00000000 --- a/pgtype/zzz.json.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (JSON) BinaryFormatSupported() bool { - return true -} - -func (JSON) TextFormatSupported() bool { - return true -} - -func (JSON) PreferredFormat() int16 { - return TextFormatCode -} - -func (dst *JSON) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src JSON) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.jsonb.go b/pgtype/zzz.jsonb.go deleted file mode 100644 index a07934b7..00000000 --- a/pgtype/zzz.jsonb.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (JSONB) BinaryFormatSupported() bool { - return true -} - -func (JSONB) TextFormatSupported() bool { - return true -} - -func (JSONB) PreferredFormat() int16 { - return TextFormatCode -} - -func (dst *JSONB) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src JSONB) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.line.go b/pgtype/zzz.line.go deleted file mode 100644 index 7365744b..00000000 --- a/pgtype/zzz.line.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Line) BinaryFormatSupported() bool { - return true -} - -func (Line) TextFormatSupported() bool { - return true -} - -func (Line) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Line) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Line) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.lseg.go b/pgtype/zzz.lseg.go deleted file mode 100644 index 1a95af09..00000000 --- a/pgtype/zzz.lseg.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Lseg) BinaryFormatSupported() bool { - return true -} - -func (Lseg) TextFormatSupported() bool { - return true -} - -func (Lseg) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Lseg) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Lseg) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.macadder.go b/pgtype/zzz.macadder.go deleted file mode 100644 index 5758d68f..00000000 --- a/pgtype/zzz.macadder.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Macaddr) BinaryFormatSupported() bool { - return true -} - -func (Macaddr) TextFormatSupported() bool { - return true -} - -func (Macaddr) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Macaddr) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Macaddr) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.name.go b/pgtype/zzz.name.go deleted file mode 100644 index 6949c337..00000000 --- a/pgtype/zzz.name.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Name) BinaryFormatSupported() bool { - return true -} - -func (Name) TextFormatSupported() bool { - return true -} - -func (Name) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Name) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Name) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.numeric.go b/pgtype/zzz.numeric.go deleted file mode 100644 index 838bed40..00000000 --- a/pgtype/zzz.numeric.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Numeric) BinaryFormatSupported() bool { - return true -} - -func (Numeric) TextFormatSupported() bool { - return true -} - -func (Numeric) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Numeric) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Numeric) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.oid.go b/pgtype/zzz.oid.go deleted file mode 100644 index bc3ba7d2..00000000 --- a/pgtype/zzz.oid.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (OID) BinaryFormatSupported() bool { - return true -} - -func (OID) TextFormatSupported() bool { - return true -} - -func (OID) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *OID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src OID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.oid_value.go b/pgtype/zzz.oid_value.go deleted file mode 100644 index 6fba9e44..00000000 --- a/pgtype/zzz.oid_value.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (OIDValue) BinaryFormatSupported() bool { - return true -} - -func (OIDValue) TextFormatSupported() bool { - return true -} - -func (OIDValue) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *OIDValue) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src OIDValue) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.path.go b/pgtype/zzz.path.go deleted file mode 100644 index d761ac40..00000000 --- a/pgtype/zzz.path.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Path) BinaryFormatSupported() bool { - return true -} - -func (Path) TextFormatSupported() bool { - return true -} - -func (Path) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Path) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Path) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.pguint32.go b/pgtype/zzz.pguint32.go deleted file mode 100644 index c869da8f..00000000 --- a/pgtype/zzz.pguint32.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (pguint32) BinaryFormatSupported() bool { - return true -} - -func (pguint32) TextFormatSupported() bool { - return true -} - -func (pguint32) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *pguint32) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src pguint32) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.point.go b/pgtype/zzz.point.go deleted file mode 100644 index 083ded95..00000000 --- a/pgtype/zzz.point.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Point) BinaryFormatSupported() bool { - return true -} - -func (Point) TextFormatSupported() bool { - return true -} - -func (Point) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Point) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Point) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.polygon.go b/pgtype/zzz.polygon.go deleted file mode 100644 index 2bfdbbd4..00000000 --- a/pgtype/zzz.polygon.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Polygon) BinaryFormatSupported() bool { - return true -} - -func (Polygon) TextFormatSupported() bool { - return true -} - -func (Polygon) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Polygon) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Polygon) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.qchar.go b/pgtype/zzz.qchar.go deleted file mode 100644 index adc0f462..00000000 --- a/pgtype/zzz.qchar.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (QChar) BinaryFormatSupported() bool { - return true -} - -func (QChar) TextFormatSupported() bool { - return true -} - -func (QChar) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *QChar) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return fmt.Errorf("text format not supported for %T", dst) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src QChar) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return nil, fmt.Errorf("text format not supported for %T", src) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.text.go b/pgtype/zzz.text.go deleted file mode 100644 index e1a3908f..00000000 --- a/pgtype/zzz.text.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Text) BinaryFormatSupported() bool { - return true -} - -func (Text) TextFormatSupported() bool { - return true -} - -func (Text) PreferredFormat() int16 { - return TextFormatCode -} - -func (dst *Text) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Text) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.tid.go b/pgtype/zzz.tid.go deleted file mode 100644 index 1a705277..00000000 --- a/pgtype/zzz.tid.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (TID) BinaryFormatSupported() bool { - return true -} - -func (TID) TextFormatSupported() bool { - return true -} - -func (TID) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *TID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src TID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.time.go b/pgtype/zzz.time.go deleted file mode 100644 index be9a96a7..00000000 --- a/pgtype/zzz.time.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Time) BinaryFormatSupported() bool { - return true -} - -func (Time) TextFormatSupported() bool { - return true -} - -func (Time) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Time) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Time) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.timestamp.go b/pgtype/zzz.timestamp.go deleted file mode 100644 index ce6135c7..00000000 --- a/pgtype/zzz.timestamp.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Timestamp) BinaryFormatSupported() bool { - return true -} - -func (Timestamp) TextFormatSupported() bool { - return true -} - -func (Timestamp) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Timestamp) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Timestamp) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.timestamptz.go b/pgtype/zzz.timestamptz.go deleted file mode 100644 index 1147b257..00000000 --- a/pgtype/zzz.timestamptz.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Timestamptz) BinaryFormatSupported() bool { - return true -} - -func (Timestamptz) TextFormatSupported() bool { - return true -} - -func (Timestamptz) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Timestamptz) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Timestamptz) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.uuid.go b/pgtype/zzz.uuid.go deleted file mode 100644 index a0aefaf6..00000000 --- a/pgtype/zzz.uuid.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (UUID) BinaryFormatSupported() bool { - return true -} - -func (UUID) TextFormatSupported() bool { - return true -} - -func (UUID) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *UUID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src UUID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.varbit.go b/pgtype/zzz.varbit.go deleted file mode 100644 index 2b090ebf..00000000 --- a/pgtype/zzz.varbit.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Varbit) BinaryFormatSupported() bool { - return true -} - -func (Varbit) TextFormatSupported() bool { - return true -} - -func (Varbit) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Varbit) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Varbit) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.varchar.go b/pgtype/zzz.varchar.go deleted file mode 100644 index 9771d412..00000000 --- a/pgtype/zzz.varchar.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Varchar) BinaryFormatSupported() bool { - return true -} - -func (Varchar) TextFormatSupported() bool { - return true -} - -func (Varchar) PreferredFormat() int16 { - return TextFormatCode -} - -func (dst *Varchar) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Varchar) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/zzz.xid.go b/pgtype/zzz.xid.go deleted file mode 100644 index 2754d98e..00000000 --- a/pgtype/zzz.xid.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (XID) BinaryFormatSupported() bool { - return true -} - -func (XID) TextFormatSupported() bool { - return true -} - -func (XID) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *XID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src XID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/values.go b/values.go index 00606689..e084a69b 100644 --- a/values.go +++ b/values.go @@ -252,9 +252,7 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 // argument to a prepared statement. It defaults to TextFormatCode if no // determination can be made. func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 { - switch arg := arg.(type) { - case pgtype.FormatSupport: - return arg.PreferredFormat() + switch arg.(type) { case pgtype.BinaryEncoder: return BinaryFormatCode case string, *string, pgtype.TextEncoder: From dcaf102f8e211f69b39f90125da25388b394b6de Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 5 Jan 2022 08:59:21 -0600 Subject: [PATCH 0810/1158] Introduce PlanEncode --- extended_query_builder.go | 2 +- pgtype/array_codec.go | 122 ++++++++++++++++++++--------------- pgtype/bool.go | 61 ++++++++++++------ pgtype/box.go | 73 +++++++++++---------- pgtype/circle.go | 69 +++++++++++--------- pgtype/int.go | 129 +++++++++++++++++++++++++++++++------- pgtype/int.go.erb | 43 ++++++++++--- pgtype/pgtype.go | 64 +++++++++++++++---- values.go | 4 +- 9 files changed, 388 insertions(+), 179 deletions(-) diff --git a/extended_query_builder.go b/extended_query_builder.go index 480e35d3..36447c99 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -131,7 +131,7 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o } return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) } else if dt.Codec != nil { - buf, err := dt.Codec.Encode(ci, oid, formatCode, arg, eqb.paramValueBytes) + buf, err := ci.Encode(oid, formatCode, arg, eqb.paramValueBytes) if err != nil { return nil, err } diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 16ce7382..e8c2b2ed 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -42,63 +42,29 @@ func (c *ArrayCodec) PreferredFormat() int16 { return c.ElementCodec.PreferredFormat() } -func (c *ArrayCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { - if value == nil { - return nil, nil +func (c *ArrayCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + return &encodePlanArrayCodecBinary{ac: c, ci: ci, oid: oid} + case TextFormatCode: + return &encodePlanArrayCodecText{ac: c, ci: ci, oid: oid} } + return nil +} + +type encodePlanArrayCodecText struct { + ac *ArrayCodec + ci *ConnInfo + oid uint32 +} + +func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { array, err := makeArrayGetter(value) if err != nil { return nil, err } - switch format { - case BinaryFormatCode: - return c.encodeBinary(ci, oid, array, buf) - case TextFormatCode: - return c.encodeText(ci, oid, array, buf) - default: - return nil, fmt.Errorf("unknown format code: %v", format) - } - -} - -func (c *ArrayCodec) encodeBinary(ci *ConnInfo, oid uint32, array ArrayGetter, buf []byte) (newBuf []byte, err error) { - dimensions := array.Dimensions() - if dimensions == nil { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: dimensions, - ElementOID: int32(c.ElementOID), - } - - containsNullIndex := len(buf) + 4 - - buf = arrayHeader.EncodeBinary(ci, buf) - - elementCount := cardinality(dimensions) - for i := 0; i < elementCount; i++ { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := c.ElementCodec.Encode(ci, c.ElementOID, BinaryFormatCode, array.Index(i), buf) - if err != nil { - return nil, err - } - if elemBuf == nil { - pgio.SetInt32(buf[containsNullIndex:], 1) - } else { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf []byte) (newBuf []byte, err error) { dimensions := array.Dimensions() if dimensions == nil { return nil, nil @@ -134,7 +100,11 @@ func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf } } - elemBuf, err := c.ElementCodec.Encode(ci, c.ElementOID, TextFormatCode, array.Index(i), inElemBuf) + encodePlan := p.ac.ElementCodec.PlanEncode(p.ci, p.ac.ElementOID, TextFormatCode, array.Index(i)) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", array.Index(i)) + } + elemBuf, err := encodePlan.Encode(array.Index(i), inElemBuf) if err != nil { return nil, err } @@ -154,6 +124,56 @@ func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf return buf, nil } +type encodePlanArrayCodecBinary struct { + ac *ArrayCodec + ci *ConnInfo + oid uint32 +} + +func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + array, err := makeArrayGetter(value) + if err != nil { + return nil, err + } + + dimensions := array.Dimensions() + if dimensions == nil { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: dimensions, + ElementOID: int32(p.ac.ElementOID), + } + + containsNullIndex := len(buf) + 4 + + buf = arrayHeader.EncodeBinary(p.ci, buf) + + elementCount := cardinality(dimensions) + for i := 0; i < elementCount; i++ { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + encodePlan := p.ac.ElementCodec.PlanEncode(p.ci, p.ac.ElementOID, BinaryFormatCode, array.Index(i)) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", array.Index(i)) + } + elemBuf, err := encodePlan.Encode(array.Index(i), buf) + if err != nil { + return nil, err + } + if elemBuf == nil { + pgio.SetInt32(buf[containsNullIndex:], 1) + } else { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + func (c *ArrayCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { _, err := makeArraySetter(target) if err != nil { diff --git a/pgtype/bool.go b/pgtype/bool.go index 36d29d40..d2c3cdc3 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -105,7 +105,20 @@ func (BoolCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (BoolCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { +func (BoolCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + return &encodePlanBoolCodecBinary{} + case TextFormatCode: + return &encodePlanBoolCodecText{} + } + + return nil +} + +type encodePlanBoolCodecBinary struct{} + +func (p *encodePlanBoolCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { v, valid, err := convertToBoolForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to bool: %v", value, err) @@ -117,24 +130,36 @@ func (BoolCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{ return nil, nil } - switch format { - case BinaryFormatCode: - if v { - buf = append(buf, 1) - } else { - buf = append(buf, 0) - } - return buf, nil - case TextFormatCode: - if v { - buf = append(buf, 't') - } else { - buf = append(buf, 'f') - } - return buf, nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + if v { + buf = append(buf, 1) + } else { + buf = append(buf, 0) } + + return buf, nil +} + +type encodePlanBoolCodecText struct{} + +func (p *encodePlanBoolCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v, valid, err := convertToBoolForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to bool: %v", value, err) + } + if !valid { + return nil, nil + } + if value == nil { + return nil, nil + } + + if v { + buf = append(buf, 't') + } else { + buf = append(buf, 'f') + } + + return buf, nil } func (BoolCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/box.go b/pgtype/box.go index 7db7d5a2..b5c30ed3 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -50,7 +50,7 @@ func (dst *Box) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Box) Value() (driver.Value, error) { - buf, err := BoxCodec{}.Encode(nil, 0, TextFormatCode, src, nil) + buf, err := BoxCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) if err != nil { return nil, err } @@ -67,44 +67,51 @@ func (BoxCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (BoxCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { - if value == nil { - return nil, nil - } - - var box Box - if v, ok := value.(BoxValuer); ok { - b, err := v.BoxValue() - if err != nil { - return nil, err - } - box = b - } else { - return nil, fmt.Errorf("cannot convert %v to box: %v", value, err) - } - - if !box.Valid { - return nil, nil +func (BoxCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(BoxValuer); !ok { + return nil } switch format { case BinaryFormatCode: - buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].Y)) - buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].Y)) - return buf, nil + return &encodePlanBoxCodecBinary{} case TextFormatCode: - buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, - strconv.FormatFloat(box.P[0].X, 'f', -1, 64), - strconv.FormatFloat(box.P[0].Y, 'f', -1, 64), - strconv.FormatFloat(box.P[1].X, 'f', -1, 64), - strconv.FormatFloat(box.P[1].Y, 'f', -1, 64), - )...) - return buf, nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + return &encodePlanBoxCodecText{} } + + return nil +} + +type encodePlanBoxCodecBinary struct{} + +func (p *encodePlanBoxCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + box, err := value.(BoxValuer).BoxValue() + if err != nil { + return nil, err + } + + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].Y)) + return buf, nil +} + +type encodePlanBoxCodecText struct{} + +func (p *encodePlanBoxCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + box, err := value.(BoxValuer).BoxValue() + if err != nil { + return nil, err + } + + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(box.P[0].X, 'f', -1, 64), + strconv.FormatFloat(box.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(box.P[1].X, 'f', -1, 64), + strconv.FormatFloat(box.P[1].Y, 'f', -1, 64), + )...) + return buf, nil } func (BoxCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/circle.go b/pgtype/circle.go index f1f66175..f214a070 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -51,7 +51,7 @@ func (dst *Circle) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Circle) Value() (driver.Value, error) { - buf, err := CircleCodec{}.Encode(nil, 0, TextFormatCode, src, nil) + buf, err := CircleCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) if err != nil { return nil, err } @@ -68,42 +68,49 @@ func (CircleCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (CircleCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { - if value == nil { - return nil, nil - } - - var circle Circle - if v, ok := value.(CircleValuer); ok { - c, err := v.CircleValue() - if err != nil { - return nil, err - } - circle = c - } else { - return nil, fmt.Errorf("cannot convert %v to circle: %v", value, err) - } - - if !circle.Valid { - return nil, nil +func (CircleCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(CircleValuer); !ok { + return nil } switch format { case BinaryFormatCode: - buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.X)) - buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.Y)) - buf = pgio.AppendUint64(buf, math.Float64bits(circle.R)) - return buf, nil + return &encodePlanCircleCodecBinary{} case TextFormatCode: - buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, - strconv.FormatFloat(circle.P.X, 'f', -1, 64), - strconv.FormatFloat(circle.P.Y, 'f', -1, 64), - strconv.FormatFloat(circle.R, 'f', -1, 64), - )...) - return buf, nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + return &encodePlanCircleCodecText{} } + + return nil +} + +type encodePlanCircleCodecBinary struct{} + +func (p *encodePlanCircleCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + circle, err := value.(CircleValuer).CircleValue() + if err != nil { + return nil, err + } + + buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(circle.R)) + return buf, nil +} + +type encodePlanCircleCodecText struct{} + +func (p *encodePlanCircleCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + circle, err := value.(CircleValuer).CircleValue() + if err != nil { + return nil, err + } + + buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, + strconv.FormatFloat(circle.P.X, 'f', -1, 64), + strconv.FormatFloat(circle.P.Y, 'f', -1, 64), + strconv.FormatFloat(circle.R, 'f', -1, 64), + )...) + return buf, nil } func (CircleCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/int.go b/pgtype/int.go index 21259beb..18b1ba90 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -119,7 +119,20 @@ func (Int2Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int2Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { +func (Int2Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + return &encodePlanInt2CodecBinary{} + case TextFormatCode: + return &encodePlanInt2CodecText{} + } + + return nil +} + +type encodePlanInt2CodecBinary struct{} + +func (p *encodePlanInt2CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) @@ -135,14 +148,28 @@ func (Int2Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{ return nil, fmt.Errorf("%d is less than minimum value for int2", n) } - switch format { - case BinaryFormatCode: - return pgio.AppendInt16(buf, int16(n)), nil - case TextFormatCode: - return append(buf, strconv.FormatInt(n, 10)...), nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + return pgio.AppendInt16(buf, int16(n)), nil +} + +type encodePlanInt2CodecText struct{} + +func (p *encodePlanInt2CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) } + if !valid { + return nil, nil + } + + if n > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n) + } + if n < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n) + } + + return append(buf, strconv.FormatInt(n, 10)...), nil } func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { @@ -599,7 +626,20 @@ func (Int4Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int4Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { +func (Int4Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + return &encodePlanInt4CodecBinary{} + case TextFormatCode: + return &encodePlanInt4CodecText{} + } + + return nil +} + +type encodePlanInt4CodecBinary struct{} + +func (p *encodePlanInt4CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int4: %v", value, err) @@ -615,14 +655,28 @@ func (Int4Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{ return nil, fmt.Errorf("%d is less than minimum value for int4", n) } - switch format { - case BinaryFormatCode: - return pgio.AppendInt32(buf, int32(n)), nil - case TextFormatCode: - return append(buf, strconv.FormatInt(n, 10)...), nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + return pgio.AppendInt32(buf, int32(n)), nil +} + +type encodePlanInt4CodecText struct{} + +func (p *encodePlanInt4CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to int4: %v", value, err) } + if !valid { + return nil, nil + } + + if n > math.MaxInt32 { + return nil, fmt.Errorf("%d is greater than maximum value for int4", n) + } + if n < math.MinInt32 { + return nil, fmt.Errorf("%d is less than minimum value for int4", n) + } + + return append(buf, strconv.FormatInt(n, 10)...), nil } func (Int4Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { @@ -1090,7 +1144,20 @@ func (Int8Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int8Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { +func (Int8Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + return &encodePlanInt8CodecBinary{} + case TextFormatCode: + return &encodePlanInt8CodecText{} + } + + return nil +} + +type encodePlanInt8CodecBinary struct{} + +func (p *encodePlanInt8CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int8: %v", value, err) @@ -1106,14 +1173,28 @@ func (Int8Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{ return nil, fmt.Errorf("%d is less than minimum value for int8", n) } - switch format { - case BinaryFormatCode: - return pgio.AppendInt64(buf, int64(n)), nil - case TextFormatCode: - return append(buf, strconv.FormatInt(n, 10)...), nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + return pgio.AppendInt64(buf, int64(n)), nil +} + +type encodePlanInt8CodecText struct{} + +func (p *encodePlanInt8CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to int8: %v", value, err) } + if !valid { + return nil, nil + } + + if n > math.MaxInt64 { + return nil, fmt.Errorf("%d is greater than maximum value for int8", n) + } + if n < math.MinInt64 { + return nil, fmt.Errorf("%d is less than minimum value for int8", n) + } + + return append(buf, strconv.FormatInt(n, 10)...), nil } func (Int8Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 546494d4..3f15dfce 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -120,7 +120,20 @@ func (Int<%= pg_byte_size %>Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int<%= pg_byte_size %>Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { +func (Int<%= pg_byte_size %>Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + return &encodePlanInt<%= pg_byte_size %>CodecBinary{} + case TextFormatCode: + return &encodePlanInt<%= pg_byte_size %>CodecText{} + } + + return nil +} + +type encodePlanInt<%= pg_byte_size %>CodecBinary struct{} + +func (p *encodePlanInt<%= pg_byte_size %>CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int<%= pg_byte_size %>: %v", value, err) @@ -136,14 +149,28 @@ func (Int<%= pg_byte_size %>Codec) Encode(ci *ConnInfo, oid uint32, format int16 return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n) } - switch format { - case BinaryFormatCode: - return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n)), nil - case TextFormatCode: - return append(buf, strconv.FormatInt(n, 10)...), nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n)), nil +} + +type encodePlanInt<%= pg_byte_size %>CodecText struct{} + +func (p *encodePlanInt<%= pg_byte_size %>CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to int<%= pg_byte_size %>: %v", value, err) } + if !valid { + return nil, nil + } + + if n > math.MaxInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n) + } + if n < math.MinInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n) + } + + return append(buf, strconv.FormatInt(n, 10)...), nil } func (Int<%= pg_byte_size %>Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 3d863373..a49d29c9 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -155,10 +155,9 @@ type Codec interface { // PreferredFormat returns the preferred format. PreferredFormat() int16 - // Encode appends the encoded bytes of value to buf. If value is the SQL NULL then append nothing and return - // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data - // written. - Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) + // PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be + // found then nil is returned. + PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan // PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If // actualTarget is true then the returned ScanPlan may be optimized to directly scan into target. If no plan can be @@ -172,12 +171,6 @@ type Codec interface { DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) } -// ResultFormatPreferrer allows a type to specify its preferred result format instead of it being inferred from -// whether it is also a BinaryDecoder. -type ResultFormatPreferrer interface { - PreferredResultFormat() int16 -} - type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the // original SQL value is NULL. BinaryDecoder takes ownership of src. The @@ -462,6 +455,14 @@ func (ci *ConnInfo) PreferAssignToOverSQLScannerForType(value interface{}) { ci.preferAssignToOverSQLScannerTypes[reflect.TypeOf(value)] = struct{}{} } +// EncodePlan is a precompiled plan to encode a particular type into a particular OID and format. +type EncodePlan interface { + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return + // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data + // written. + Encode(value interface{}, buf []byte) (newBuf []byte, err error) +} + // ScanPlan is a precompiled plan to scan into a type of destination. type ScanPlan interface { // Scan scans src into dst. If the dst type has changed in an incompatible way a ScanPlan should automatically @@ -929,10 +930,51 @@ func codecDecodeToTextFormat(codec Codec, ci *ConnInfo, oid uint32, format int16 if err != nil { return nil, err } - buf, err := codec.Encode(ci, oid, TextFormatCode, value, nil) + buf, err := ci.Encode(oid, TextFormatCode, value, nil) if err != nil { return nil, err } return string(buf), nil } } + +// PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be +// found then nil is returned. +func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan { + + var dt *DataType + + if oid == 0 { + if dataType, ok := ci.DataTypeForValue(value); ok { + dt = dataType + } + } else { + if dataType, ok := ci.DataTypeForOID(oid); ok { + dt = dataType + } + } + + if dt != nil && dt.Codec != nil { + if plan := dt.Codec.PlanEncode(ci, oid, format, value); plan != nil { + return plan + } + + } + + return nil +} + +// Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return +// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data +// written. +func (ci *ConnInfo) Encode(oid uint32, formatCode int16, value interface{}, buf []byte) (newBuf []byte, err error) { + if value == nil { + return nil, nil + } + + plan := ci.PlanEncode(oid, formatCode, value) + if plan == nil { + return nil, fmt.Errorf("unable to encode %v", value) + } + return plan.Encode(value, buf) +} diff --git a/values.go b/values.go index e084a69b..a60d4129 100644 --- a/values.go +++ b/values.go @@ -130,7 +130,7 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e } return string(buf), nil } else if dt.Codec != nil { - buf, err := dt.Codec.Encode(ci, 0, TextFormatCode, arg, nil) + buf, err := ci.Encode(0, TextFormatCode, arg, nil) if err != nil { return nil, err } @@ -230,7 +230,7 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 } else if dt.Codec != nil { sp := len(buf) buf = pgio.AppendInt32(buf, -1) - argBuf, err := dt.Codec.Encode(ci, oid, BinaryFormatCode, arg, buf) + argBuf, err := ci.Encode(oid, BinaryFormatCode, arg, buf) if err != nil { return nil, err } From 2b0afbb408347335ffe6dd5f24bffb16c1f49c70 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 09:33:08 -0600 Subject: [PATCH 0811/1158] Convert point to Codec --- pgtype/pgtype.go | 2 +- pgtype/point.go | 266 ++++++++++++++++++++++++++----------------- pgtype/point_test.go | 64 +++-------- 3 files changed, 174 insertions(+), 158 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index a49d29c9..5591d623 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -303,7 +303,7 @@ func NewConnInfo() *ConnInfo { // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID}) ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID}) - ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID}) + ci.RegisterDataType(DataType{Name: "point", OID: PointOID, Codec: PointCodec{}}) ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID}) // ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID}) diff --git a/pgtype/point.go b/pgtype/point.go index d35dbf03..b4236c8f 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -17,33 +17,28 @@ type Vec2 struct { Y float64 } +type PointScanner interface { + ScanPoint(v Point) error +} + +type PointValuer interface { + PointValue() (Point, error) +} + type Point struct { P Vec2 Valid bool } -func (dst *Point) Set(src interface{}) error { - if src == nil { - dst.Valid = false - return nil - } - err := fmt.Errorf("cannot convert %v to Point", src) - var p *Point - switch value := src.(type) { - case string: - p, err = parsePoint([]byte(value)) - case []byte: - p, err = parsePoint(value) - default: - return err - } - if err != nil { - return err - } - *dst = *p +func (p *Point) ScanPoint(v Point) error { + *p = v return nil } +func (p Point) PointValue() (Point, error) { + return p, nil +} + func parsePoint(src []byte) (*Point, error) { if src == nil || bytes.Compare(src, []byte("null")) == 0 { return &Point{}, nil @@ -73,87 +68,6 @@ func parsePoint(src []byte) (*Point, error) { return &Point{P: Vec2{x, y}, Valid: true}, nil } -func (dst Point) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *Point) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Point{} - return nil - } - - if len(src) < 5 { - return fmt.Errorf("invalid length for point: %v", len(src)) - } - - parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) - if len(parts) < 2 { - return fmt.Errorf("invalid format for point") - } - - x, err := strconv.ParseFloat(parts[0], 64) - if err != nil { - return err - } - - y, err := strconv.ParseFloat(parts[1], 64) - if err != nil { - return err - } - - *dst = Point{P: Vec2{x, y}, Valid: true} - return nil -} - -func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Point{} - return nil - } - - if len(src) != 16 { - return fmt.Errorf("invalid length for point: %v", len(src)) - } - - x := binary.BigEndian.Uint64(src) - y := binary.BigEndian.Uint64(src[8:]) - - *dst = Point{ - P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, - Valid: true, - } - return nil -} - -func (src Point) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, fmt.Sprintf(`(%s,%s)`, - strconv.FormatFloat(src.P.X, 'f', -1, 64), - strconv.FormatFloat(src.P.Y, 'f', -1, 64), - )...), nil -} - -func (src Point) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) - return buf, nil -} - // Scan implements the database/sql Scanner interface. func (dst *Point) Scan(src interface{}) error { if src == nil { @@ -163,11 +77,7 @@ func (dst *Point) Scan(src interface{}) error { switch src := src.(type) { case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + return scanPlanTextAnyToPointScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) } return fmt.Errorf("cannot scan %T", src) @@ -175,7 +85,11 @@ func (dst *Point) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Point) Value() (driver.Value, error) { - return EncodeValueText(src) + buf, err := PointCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err } func (src Point) MarshalJSON() ([]byte, error) { @@ -198,3 +112,143 @@ func (dst *Point) UnmarshalJSON(point []byte) error { *dst = *p return nil } + +type PointCodec struct{} + +func (PointCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (PointCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (PointCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(PointValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return &encodePlanPointCodecBinary{} + case TextFormatCode: + return &encodePlanPointCodecText{} + } + + return nil +} + +type encodePlanPointCodecBinary struct{} + +func (p *encodePlanPointCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + point, err := value.(PointValuer).PointValue() + if err != nil { + return nil, err + } + + buf = pgio.AppendUint64(buf, math.Float64bits(point.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(point.P.Y)) + return buf, nil +} + +type encodePlanPointCodecText struct{} + +func (p *encodePlanPointCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + point, err := value.(PointValuer).PointValue() + if err != nil { + return nil, err + } + + return append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(point.P.X, 'f', -1, 64), + strconv.FormatFloat(point.P.Y, 'f', -1, 64), + )...), nil +} + +func (PointCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case PointScanner: + return scanPlanBinaryPointToPointScanner{} + } + case TextFormatCode: + switch target.(type) { + case PointScanner: + return scanPlanTextAnyToPointScanner{} + } + } + + return nil +} + +func (c PointCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c PointCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var point Point + err := codecScan(c, ci, oid, format, src, &point) + if err != nil { + return nil, err + } + return point, nil +} + +type scanPlanBinaryPointToPointScanner struct{} + +func (scanPlanBinaryPointToPointScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(PointScanner) + + if src == nil { + return scanner.ScanPoint(Point{}) + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + + return scanner.ScanPoint(Point{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + Valid: true, + }) +} + +type scanPlanTextAnyToPointScanner struct{} + +func (scanPlanTextAnyToPointScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(PointScanner) + + if src == nil { + return scanner.ScanPoint(Point{}) + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + y, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + return scanner.ScanPoint(Point{P: Vec2{x, y}, Valid: true}) +} diff --git a/pgtype/point_test.go b/pgtype/point_test.go index bb4c2126..718e203a 100644 --- a/pgtype/point_test.go +++ b/pgtype/point_test.go @@ -5,63 +5,25 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/require" ) -func TestPointTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "point", []interface{}{ - &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Valid: true}, - &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Valid: true}, - &pgtype.Point{}, +func TestPointCodec(t *testing.T) { + testPgxCodec(t, "point", []PgxTranscodeTestCase{ + { + pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Valid: true}, + new(pgtype.Point), + isExpectedEq(pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Valid: true}), + }, + { + pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Valid: true}, + new(pgtype.Point), + isExpectedEq(pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Valid: true}), + }, + {nil, new(pgtype.Point), isExpectedEq(pgtype.Point{})}, }) } -func TestPoint_Set(t *testing.T) { - tests := []struct { - name string - arg interface{} - valid bool - wantErr bool - }{ - { - name: "first", - arg: "(12312.123123,123123.123123)", - valid: true, - wantErr: false, - }, - { - name: "second", - arg: "(1231s2.123123,123123.123123)", - valid: false, - wantErr: true, - }, - { - name: "third", - arg: []byte("(122.123123,123.123123)"), - valid: true, - wantErr: false, - }, - { - name: "third", - arg: nil, - valid: false, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - dst := &pgtype.Point{} - if err := dst.Set(tt.arg); (err != nil) != tt.wantErr { - t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) - } - if dst.Valid != tt.valid { - t.Errorf("Expected status: %v; got: %v", tt.valid, dst.Valid) - } - }) - } -} - func TestPoint_MarshalJSON(t *testing.T) { tests := []struct { name string From a7d4a22001977f7041038ef1d288c3dc6b0bad5a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 09:37:40 -0600 Subject: [PATCH 0812/1158] Add point array support --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5591d623..d14087cd 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -50,6 +50,7 @@ const ( BPCharArrayOID = 1014 VarcharArrayOID = 1015 Int8ArrayOID = 1016 + PointArrayOID = 1017 BoxArrayOID = 1020 Float4ArrayOID = 1021 Float8ArrayOID = 1022 @@ -265,6 +266,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementCodec: Int8Codec{}, ElementOID: Int8OID}}) ci.RegisterDataType(DataType{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementCodec: BoxCodec{}, ElementOID: BoxOID}}) ci.RegisterDataType(DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementCodec: CircleCodec{}, ElementOID: CircleOID}}) + ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementCodec: PointCodec{}, ElementOID: PointOID}}) ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) ci.RegisterDataType(DataType{Value: &TextArray{}, Name: "_text", OID: TextArrayOID}) ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) From fcc9dcc960ce72b86c050008c1be9d8c99a66adb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 12:59:04 -0600 Subject: [PATCH 0813/1158] Convert text to Codec This also entailed updating and deleting types that depended on Text. --- conn_test.go | 9 +- pgtype/bpchar.go | 92 ------- pgtype/bpchar_array.go | 504 ----------------------------------- pgtype/bpchar_array_test.go | 55 ---- pgtype/bpchar_test.go | 51 ---- pgtype/enum_array.go | 418 ----------------------------- pgtype/enum_array_test.go | 281 ------------------- pgtype/generic_text.go | 39 --- pgtype/hstore.go | 6 +- pgtype/name.go | 58 ---- pgtype/name_test.go | 98 ------- pgtype/pgtype.go | 110 +++++++- pgtype/text.go | 312 +++++++++++++--------- pgtype/text_array.go | 504 ----------------------------------- pgtype/text_array_test.go | 294 -------------------- pgtype/text_test.go | 160 ++++------- pgtype/unknown.go | 44 --- pgtype/varchar.go | 66 ----- pgtype/varchar_array.go | 504 ----------------------------------- pgtype/varchar_array_test.go | 282 -------------------- pgtype/zeronull/text.go | 65 +---- query_test.go | 4 +- rows.go | 23 +- 23 files changed, 366 insertions(+), 3613 deletions(-) delete mode 100644 pgtype/bpchar.go delete mode 100644 pgtype/bpchar_array.go delete mode 100644 pgtype/bpchar_array_test.go delete mode 100644 pgtype/bpchar_test.go delete mode 100644 pgtype/enum_array.go delete mode 100644 pgtype/enum_array_test.go delete mode 100644 pgtype/generic_text.go delete mode 100644 pgtype/name.go delete mode 100644 pgtype/name_test.go delete mode 100644 pgtype/text_array.go delete mode 100644 pgtype/text_array_test.go delete mode 100644 pgtype/unknown.go delete mode 100644 pgtype/varchar.go delete mode 100644 pgtype/varchar_array.go delete mode 100644 pgtype/varchar_array_test.go diff --git a/conn_test.go b/conn_test.go index 55297e26..0d7bcb31 100644 --- a/conn_test.go +++ b/conn_test.go @@ -91,13 +91,8 @@ func TestConnectWithPreferSimpleProtocol(t *testing.T) { var s pgtype.Text err := conn.QueryRow(context.Background(), "select $1::int4", 42).Scan(&s) - if err != nil { - t.Fatal(err) - } - - if s.Get() != "42" { - t.Fatalf(`expected "42", got %v`, s) - } + require.NoError(t, err) + require.Equal(t, pgtype.Text{String: "42", Valid: true}, s) ensureConnValid(t, conn) } diff --git a/pgtype/bpchar.go b/pgtype/bpchar.go deleted file mode 100644 index 2e899ea8..00000000 --- a/pgtype/bpchar.go +++ /dev/null @@ -1,92 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "fmt" -) - -// BPChar is fixed-length, blank padded char type -// character(n), char(n) -type BPChar Text - -// Set converts from src to dst. -func (dst *BPChar) Set(src interface{}) error { - return (*Text)(dst).Set(src) -} - -// Get returns underlying value -func (dst BPChar) Get() interface{} { - return (Text)(dst).Get() -} - -// AssignTo assigns from src to dst. -func (src *BPChar) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *rune: - runes := []rune(src.String) - if len(runes) == 1 { - *v = runes[0] - return nil - } - case *string: - *v = src.String - return nil - case *[]byte: - *v = make([]byte, len(src.String)) - copy(*v, src.String) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - - return fmt.Errorf("cannot decode %#v into %T", src, dst) -} - -func (BPChar) PreferredResultFormat() int16 { - return TextFormatCode -} - -func (dst *BPChar) DecodeText(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeText(ci, src) -} - -func (dst *BPChar) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeBinary(ci, src) -} - -func (BPChar) PreferredParamFormat() int16 { - return TextFormatCode -} - -func (src BPChar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Text)(src).EncodeText(ci, buf) -} - -func (src BPChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Text)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *BPChar) Scan(src interface{}) error { - return (*Text)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src BPChar) Value() (driver.Value, error) { - return (Text)(src).Value() -} - -func (src BPChar) MarshalJSON() ([]byte, error) { - return (Text)(src).MarshalJSON() -} - -func (dst *BPChar) UnmarshalJSON(b []byte) error { - return (*Text)(dst).UnmarshalJSON(b) -} diff --git a/pgtype/bpchar_array.go b/pgtype/bpchar_array.go deleted file mode 100644 index c73c78a3..00000000 --- a/pgtype/bpchar_array.go +++ /dev/null @@ -1,504 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type BPCharArray struct { - Elements []BPChar - Dimensions []ArrayDimension - Valid bool -} - -func (dst *BPCharArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = BPCharArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []string: - if value == nil { - *dst = BPCharArray{} - } else if len(value) == 0 { - *dst = BPCharArray{Valid: true} - } else { - elements := make([]BPChar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BPCharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*string: - if value == nil { - *dst = BPCharArray{} - } else if len(value) == 0 { - *dst = BPCharArray{Valid: true} - } else { - elements := make([]BPChar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BPCharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []BPChar: - if value == nil { - *dst = BPCharArray{} - } else if len(value) == 0 { - *dst = BPCharArray{Valid: true} - } else { - *dst = BPCharArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = BPCharArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for BPCharArray", src) - } - if elementsLength == 0 { - *dst = BPCharArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to BPCharArray", src) - } - - *dst = BPCharArray{ - Elements: make([]BPChar, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]BPChar, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to BPCharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *BPCharArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to BPCharArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in BPCharArray", err) - } - index++ - - return index, nil -} - -func (dst BPCharArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *BPCharArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from BPCharArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from BPCharArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = BPCharArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []BPChar - - if len(uta.Elements) > 0 { - elements = make([]BPChar, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem BPChar - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = BPCharArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *BPCharArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = BPCharArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = BPCharArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]BPChar, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = BPCharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("bpchar"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "bpchar") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *BPCharArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src BPCharArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/bpchar_array_test.go b/pgtype/bpchar_array_test.go deleted file mode 100644 index 0118ad7d..00000000 --- a/pgtype/bpchar_array_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestBPCharArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "char(8)[]", []interface{}{ - &pgtype.BPCharArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.BPCharArray{ - Elements: []pgtype.BPChar{ - pgtype.BPChar{String: "foo ", Valid: true}, - pgtype.BPChar{}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.BPCharArray{}, - &pgtype.BPCharArray{ - Elements: []pgtype.BPChar{ - pgtype.BPChar{String: "bar ", Valid: true}, - pgtype.BPChar{String: "NuLL ", Valid: true}, - pgtype.BPChar{String: `wow"quz\`, Valid: true}, - pgtype.BPChar{String: "1 ", Valid: true}, - pgtype.BPChar{String: "1 ", Valid: true}, - pgtype.BPChar{String: "null ", Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 3, LowerBound: 1}, - {Length: 2, LowerBound: 1}, - }, - Valid: true, - }, - &pgtype.BPCharArray{ - Elements: []pgtype.BPChar{ - pgtype.BPChar{String: " bar ", Valid: true}, - pgtype.BPChar{String: " baz ", Valid: true}, - pgtype.BPChar{String: " quz ", Valid: true}, - pgtype.BPChar{String: "foo ", Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} diff --git a/pgtype/bpchar_test.go b/pgtype/bpchar_test.go deleted file mode 100644 index ead26220..00000000 --- a/pgtype/bpchar_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestChar3Transcode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "char(3)", []interface{}{ - &pgtype.BPChar{String: "a ", Valid: true}, - &pgtype.BPChar{String: " a ", Valid: true}, - &pgtype.BPChar{String: "å—¨ ", Valid: true}, - &pgtype.BPChar{String: " ", Valid: true}, - &pgtype.BPChar{}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.BPChar) - b := bb.(pgtype.BPChar) - - return a.Valid == b.Valid && a.String == b.String - }) -} - -func TestBPCharAssignTo(t *testing.T) { - var ( - str string - run rune - ) - simpleTests := []struct { - src pgtype.BPChar - dst interface{} - expected interface{} - }{ - {src: pgtype.BPChar{String: "simple", Valid: true}, dst: &str, expected: "simple"}, - {src: pgtype.BPChar{String: "å—¨", Valid: true}, dst: &run, expected: 'å—¨'}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - -} diff --git a/pgtype/enum_array.go b/pgtype/enum_array.go deleted file mode 100644 index dbfb211d..00000000 --- a/pgtype/enum_array.go +++ /dev/null @@ -1,418 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "fmt" - "reflect" -) - -type EnumArray struct { - Elements []GenericText - Dimensions []ArrayDimension - Valid bool -} - -func (dst *EnumArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = EnumArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []string: - if value == nil { - *dst = EnumArray{} - } else if len(value) == 0 { - *dst = EnumArray{Valid: true} - } else { - elements := make([]GenericText, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = EnumArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*string: - if value == nil { - *dst = EnumArray{} - } else if len(value) == 0 { - *dst = EnumArray{Valid: true} - } else { - elements := make([]GenericText, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = EnumArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []GenericText: - if value == nil { - *dst = EnumArray{} - } else if len(value) == 0 { - *dst = EnumArray{Valid: true} - } else { - *dst = EnumArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = EnumArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for EnumArray", src) - } - if elementsLength == 0 { - *dst = EnumArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to EnumArray", src) - } - - *dst = EnumArray{ - Elements: make([]GenericText, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]GenericText, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to EnumArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *EnumArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to EnumArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in EnumArray", err) - } - index++ - - return index, nil -} - -func (dst EnumArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *EnumArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from EnumArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from EnumArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = EnumArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []GenericText - - if len(uta.Elements) > 0 { - elements = make([]GenericText, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem GenericText - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = EnumArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (src EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *EnumArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src EnumArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/enum_array_test.go b/pgtype/enum_array_test.go deleted file mode 100644 index 6e49aaaf..00000000 --- a/pgtype/enum_array_test.go +++ /dev/null @@ -1,281 +0,0 @@ -package pgtype_test - -import ( - "context" - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestEnumArrayTranscode(t *testing.T) { - setupConn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, setupConn) - - if _, err := setupConn.Exec(context.Background(), "drop type if exists color"); err != nil { - t.Fatal(err) - } - if _, err := setupConn.Exec(context.Background(), "create type color as enum ('red', 'green', 'blue')"); err != nil { - t.Fatal(err) - } - - testutil.TestSuccessfulTranscode(t, "color[]", []interface{}{ - &pgtype.EnumArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "red", Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.EnumArray{}, - &pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "red", Valid: true}, - {String: "green", Valid: true}, - {String: "blue", Valid: true}, - {String: "red", Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestEnumArrayArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.EnumArray - }{ - { - source: []string{"foo"}, - result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]string)(nil)), - result: pgtype.EnumArray{}, - }, - { - source: [][]string{{"foo"}, {"bar"}}, - result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "foo", Valid: true}, - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "wibble", Valid: true}, - {String: "wobble", Valid: true}, - {String: "wubble", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]string{{"foo"}, {"bar"}}, - result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "foo", Valid: true}, - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "wibble", Valid: true}, - {String: "wobble", Valid: true}, - {String: "wubble", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.EnumArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestEnumArrayArrayAssignTo(t *testing.T) { - var stringSlice []string - type _stringSlice []string - var namedStringSlice _stringSlice - var stringSliceDim2 [][]string - var stringSliceDim4 [][][][]string - var stringArrayDim2 [2][1]string - var stringArrayDim4 [2][1][1][3]string - - simpleTests := []struct { - src pgtype.EnumArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &stringSlice, - expected: []string{"foo"}, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &namedStringSlice, - expected: _stringSlice{"bar"}, - }, - { - src: pgtype.EnumArray{}, - dst: &stringSlice, - expected: (([]string)(nil)), - }, - { - src: pgtype.EnumArray{Valid: true}, - dst: &stringSlice, - expected: []string{}, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &stringSliceDim2, - expected: [][]string{{"foo"}, {"bar"}}, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "foo", Valid: true}, - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "wibble", Valid: true}, - {String: "wobble", Valid: true}, - {String: "wubble", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &stringSliceDim4, - expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &stringArrayDim2, - expected: [2][1]string{{"foo"}, {"bar"}}, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "foo", Valid: true}, - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "wibble", Valid: true}, - {String: "wobble", Valid: true}, - {String: "wubble", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &stringArrayDim4, - expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.EnumArray - dst interface{} - }{ - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &stringSlice, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &stringArrayDim2, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &stringSlice, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &stringArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/generic_text.go b/pgtype/generic_text.go deleted file mode 100644 index dbf5b47e..00000000 --- a/pgtype/generic_text.go +++ /dev/null @@ -1,39 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// GenericText is a placeholder for text format values that no other type exists -// to handle. -type GenericText Text - -func (dst *GenericText) Set(src interface{}) error { - return (*Text)(dst).Set(src) -} - -func (dst GenericText) Get() interface{} { - return (Text)(dst).Get() -} - -func (src *GenericText) AssignTo(dst interface{}) error { - return (*Text)(src).AssignTo(dst) -} - -func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeText(ci, src) -} - -func (src GenericText) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Text)(src).EncodeText(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *GenericText) Scan(src interface{}) error { - return (*Text)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src GenericText) Value() (driver.Value, error) { - return (Text)(src).Value() -} diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 25406a74..69c8a07b 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -159,7 +159,7 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { } var value Text - err := value.DecodeBinary(ci, valueBuf) + err := scanPlanTextAnyToTextScanner{}.Scan(ci, TextOID, TextFormatCode, valueBuf, &value) if err != nil { return err } @@ -189,7 +189,7 @@ func (src Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { buf = append(buf, quoteHstoreElementIfNeeded(k)...) buf = append(buf, "=>"...) - elemBuf, err := v.EncodeText(ci, inElemBuf) + elemBuf, err := ci.Encode(TextOID, TextFormatCode, v, inElemBuf) if err != nil { return nil, err } @@ -219,7 +219,7 @@ func (src Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { sp := len(buf) buf = pgio.AppendInt32(buf, -1) - elemBuf, err := v.EncodeText(ci, buf) + elemBuf, err := ci.Encode(TextOID, BinaryFormatCode, v, buf) if err != nil { return nil, err } diff --git a/pgtype/name.go b/pgtype/name.go deleted file mode 100644 index 7ce8d25e..00000000 --- a/pgtype/name.go +++ /dev/null @@ -1,58 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// Name is a type used for PostgreSQL's special 63-byte -// name data type, used for identifiers like table names. -// The pg_class.relname column is a good example of where the -// name data type is used. -// -// Note that the underlying Go data type of pgx.Name is string, -// so there is no way to enforce the 63-byte length. Inputting -// a longer name into PostgreSQL will result in silent truncation -// to 63 bytes. -// -// Also, if you have custom-compiled PostgreSQL and set -// NAMEDATALEN to a different value, obviously that number of -// bytes applies, rather than the default 63. -type Name Text - -func (dst *Name) Set(src interface{}) error { - return (*Text)(dst).Set(src) -} - -func (dst Name) Get() interface{} { - return (Text)(dst).Get() -} - -func (src *Name) AssignTo(dst interface{}) error { - return (*Text)(src).AssignTo(dst) -} - -func (dst *Name) DecodeText(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeText(ci, src) -} - -func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeBinary(ci, src) -} - -func (src Name) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Text)(src).EncodeText(ci, buf) -} - -func (src Name) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Text)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Name) Scan(src interface{}) error { - return (*Text)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Name) Value() (driver.Value, error) { - return (Text)(src).Value() -} diff --git a/pgtype/name_test.go b/pgtype/name_test.go deleted file mode 100644 index 89b16579..00000000 --- a/pgtype/name_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestNameTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "name", []interface{}{ - &pgtype.Name{String: "", Valid: true}, - &pgtype.Name{String: "foo", Valid: true}, - &pgtype.Name{}, - }) -} - -func TestNameSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Name - }{ - {source: "foo", result: pgtype.Name{String: "foo", Valid: true}}, - {source: _string("bar"), result: pgtype.Name{String: "bar", Valid: true}}, - {source: (*string)(nil), result: pgtype.Name{}}, - } - - for i, tt := range successfulTests { - var d pgtype.Name - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if d != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } -} - -func TestNameAssignTo(t *testing.T) { - var s string - var ps *string - - simpleTests := []struct { - src pgtype.Name - dst interface{} - expected interface{} - }{ - {src: pgtype.Name{String: "foo", Valid: true}, dst: &s, expected: "foo"}, - {src: pgtype.Name{}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Name - dst interface{} - expected interface{} - }{ - {src: pgtype.Name{String: "foo", Valid: true}, dst: &ps, expected: "foo"}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Name - dst interface{} - }{ - {src: pgtype.Name{}, dst: &s}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d14087cd..d6bca76d 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -254,7 +254,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID}) ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementCodec: BoolCodec{}, ElementOID: BoolOID}}) - ci.RegisterDataType(DataType{Value: &BPCharArray{}, Name: "_bpchar", OID: BPCharArrayOID}) + ci.RegisterDataType(DataType{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: BPCharOID}}) ci.RegisterDataType(DataType{Value: &ByteaArray{}, Name: "_bytea", OID: ByteaArrayOID}) ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID}) ci.RegisterDataType(DataType{Value: &DateArray{}, Name: "_date", OID: DateArrayOID}) @@ -268,16 +268,16 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementCodec: CircleCodec{}, ElementOID: CircleOID}}) ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementCodec: PointCodec{}, ElementOID: PointOID}}) ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) - ci.RegisterDataType(DataType{Value: &TextArray{}, Name: "_text", OID: TextArrayOID}) + ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: TextOID}}) ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID}) ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) - ci.RegisterDataType(DataType{Value: &VarcharArray{}, Name: "_varchar", OID: VarcharArrayOID}) + ci.RegisterDataType(DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: VarcharOID}}) ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID}) ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID}) ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) ci.RegisterDataType(DataType{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) - ci.RegisterDataType(DataType{Value: &BPChar{}, Name: "bpchar", OID: BPCharOID}) + ci.RegisterDataType(DataType{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID}) ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID}) @@ -300,7 +300,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Line{}, Name: "line", OID: LineOID}) ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID}) ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) - ci.RegisterDataType(DataType{Value: &Name{}, Name: "name", OID: NameOID}) + ci.RegisterDataType(DataType{Name: "name", OID: NameOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID}) @@ -308,7 +308,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "point", OID: PointOID, Codec: PointCodec{}}) ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID}) // ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) - ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID}) + ci.RegisterDataType(DataType{Name: "text", OID: TextOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID}) ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID}) ci.RegisterDataType(DataType{Value: &Timestamp{}, Name: "timestamp", OID: TimestampOID}) @@ -317,10 +317,10 @@ func NewConnInfo() *ConnInfo { // ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) // ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) // ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) - ci.RegisterDataType(DataType{Value: &Unknown{}, Name: "unknown", OID: UnknownOID}) + ci.RegisterDataType(DataType{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID}) ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID}) - ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID}) + ci.RegisterDataType(DataType{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID}) registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { @@ -786,6 +786,22 @@ func tryBaseTypeScanPlan(dst interface{}) (plan *baseTypeScanPlan, nextDst inter return nil, nil, false } +type pointerEmptyInterfaceScanPlan struct { + codec Codec +} + +func (plan *pointerEmptyInterfaceScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + value, err := plan.codec.DecodeValue(ci, oid, formatCode, src) + if err != nil { + return err + } + + ptrAny := dst.(*interface{}) + *ptrAny = value + + return nil +} + // PlanScan prepares a plan to scan a value into dst. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { switch formatCode { @@ -826,6 +842,8 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan } case TextDecoder: return scanPlanDstTextDecoder{} + case TextScanner: + return scanPlanTextAnyToTextScanner{} } } @@ -859,6 +877,10 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan return baseTypePlan } } + + if _, ok := dst.(*interface{}); ok { + return &pointerEmptyInterfaceScanPlan{codec: dt.Codec} + } } if dt != nil { @@ -961,11 +983,83 @@ func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) Enco return plan } + if derefPointerPlan, nextValue, ok := tryDerefPointerEncodePlan(value); ok { + if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { + derefPointerPlan.next = nextPlan + return derefPointerPlan + } + } + + if baseTypePlan, nextValue, ok := tryBaseTypeEncodePlan(value); ok { + if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { + baseTypePlan.next = nextPlan + return baseTypePlan + } + } + } return nil } +type derefPointerEncodePlan struct { + next EncodePlan +} + +func (plan *derefPointerEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + ptr := reflect.ValueOf(value) + + if ptr.IsNil() { + return nil, nil + } + + return plan.next.Encode(ptr.Elem().Interface(), buf) +} + +func tryDerefPointerEncodePlan(value interface{}) (plan *derefPointerEncodePlan, nextValue interface{}, ok bool) { + if valueType := reflect.TypeOf(value); valueType.Kind() == reflect.Ptr { + return &derefPointerEncodePlan{}, reflect.New(valueType.Elem()).Elem().Interface(), true + } + + return nil, nil, false +} + +var kindToBaseTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ + reflect.Int: reflect.TypeOf(int(0)), + reflect.Int8: reflect.TypeOf(int8(0)), + reflect.Int16: reflect.TypeOf(int16(0)), + reflect.Int32: reflect.TypeOf(int32(0)), + reflect.Int64: reflect.TypeOf(int64(0)), + reflect.Uint: reflect.TypeOf(uint(0)), + reflect.Uint8: reflect.TypeOf(uint8(0)), + reflect.Uint16: reflect.TypeOf(uint16(0)), + reflect.Uint32: reflect.TypeOf(uint32(0)), + reflect.Uint64: reflect.TypeOf(uint64(0)), + reflect.Float32: reflect.TypeOf(float32(0)), + reflect.Float64: reflect.TypeOf(float64(0)), + reflect.String: reflect.TypeOf(""), +} + +type baseTypeEncodePlan struct { + nextValueType reflect.Type + next EncodePlan +} + +func (plan *baseTypeEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(reflect.ValueOf(value).Convert(plan.nextValueType).Interface(), buf) +} + +func tryBaseTypeEncodePlan(value interface{}) (plan *baseTypeEncodePlan, nextValue interface{}, ok bool) { + refValue := reflect.ValueOf(value) + + nextValueType := kindToBaseTypes[refValue.Kind()] + if nextValueType != nil && refValue.Type() != nextValueType { + return &baseTypeEncodePlan{nextValueType: nextValueType}, refValue.Convert(nextValueType).Interface(), true + } + + return nil, nil, false +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. diff --git a/pgtype/text.go b/pgtype/text.go index 5d27c44f..3cb1cfa3 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -4,141 +4,29 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "unicode/utf8" ) +type TextScanner interface { + ScanText(v Text) error +} + +type TextValuer interface { + TextValue() (Text, error) +} + type Text struct { String string Valid bool } -func (dst *Text) Set(src interface{}) error { - if src == nil { - *dst = Text{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case string: - *dst = Text{String: value, Valid: true} - case *string: - if value == nil { - *dst = Text{} - } else { - *dst = Text{String: *value, Valid: true} - } - case []byte: - if value == nil { - *dst = Text{} - } else { - *dst = Text{String: string(value), Valid: true} - } - case fmt.Stringer: - if value == fmt.Stringer(nil) { - *dst = Text{} - } else { - *dst = Text{String: value.String(), Valid: true} - } - default: - // Cannot be part of the switch: If Value() returns nil on - // non-string, we should still try to checks the underlying type - // using reflection. - // - // For example the struct might implement driver.Valuer with - // pointer receiver and fmt.Stringer with value receiver. - if value, ok := src.(driver.Valuer); ok { - if value == driver.Valuer(nil) { - *dst = Text{} - return nil - } else { - v, err := value.Value() - if err != nil { - return fmt.Errorf("driver.Valuer Value() method failed: %w", err) - } - - // Handles also v == nil case. - if s, ok := v.(string); ok { - *dst = Text{String: s, Valid: true} - return nil - } - } - } - - if originalSrc, ok := underlyingStringType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Text", value) - } - +func (t *Text) ScanText(v Text) error { + *t = v return nil } -func (dst Text) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.String -} - -func (src *Text) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *string: - *v = src.String - return nil - case *[]byte: - *v = make([]byte, len(src.String)) - copy(*v, src.String) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func (Text) PreferredResultFormat() int16 { - return TextFormatCode -} - -func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Text{} - return nil - } - - *dst = Text{String: string(src), Valid: true} - return nil -} - -func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { - return dst.DecodeText(ci, src) -} - -func (Text) PreferredParamFormat() int16 { - return TextFormatCode -} - -func (src Text) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, src.String...), nil -} - -func (src Text) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return src.EncodeText(ci, buf) +func (t Text) TextValue() (Text, error) { + return t, nil } // Scan implements the database/sql Scanner interface. @@ -150,11 +38,11 @@ func (dst *Text) Scan(src interface{}) error { switch src := src.(type) { case string: - return dst.DecodeText(nil, []byte(src)) + *dst = Text{String: src, Valid: true} + return nil case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + *dst = Text{String: string(src), Valid: true} + return nil } return fmt.Errorf("cannot scan %T", src) @@ -191,3 +79,169 @@ func (dst *Text) UnmarshalJSON(b []byte) error { return nil } + +type TextCodec struct{} + +func (TextCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TextCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (TextCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch value.(type) { + case string: + return encodePlanTextCodecString{} + case []byte: + return encodePlanTextCodecByteSlice{} + case rune: + return encodePlanTextCodecRune{} + case fmt.Stringer: + return encodePlanTextCodecStringer{} + case TextValuer: + return encodePlanTextCodecTextValuer{} + } + } + + return nil +} + +type encodePlanTextCodecString struct{} + +func (encodePlanTextCodecString) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + s := value.(string) + buf = append(buf, s...) + return buf, nil +} + +type encodePlanTextCodecByteSlice struct{} + +func (encodePlanTextCodecByteSlice) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + s := value.([]byte) + buf = append(buf, s...) + return buf, nil +} + +type encodePlanTextCodecRune struct{} + +func (encodePlanTextCodecRune) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + r := value.(rune) + buf = append(buf, string(r)...) + return buf, nil +} + +type encodePlanTextCodecStringer struct{} + +func (encodePlanTextCodecStringer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + s := value.(fmt.Stringer) + buf = append(buf, s.String()...) + return buf, nil +} + +type encodePlanTextCodecTextValuer struct{} + +func (encodePlanTextCodecTextValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + text, err := value.(TextValuer).TextValue() + if err != nil { + return nil, err + } + + if !text.Valid { + return nil, nil + } + + buf = append(buf, text.String...) + return buf, nil +} + +func (TextCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case TextFormatCode, BinaryFormatCode: + switch target.(type) { + case *string: + return scanPlanTextAnyToString{} + case *[]byte: + return scanPlanAnyToNewByteSlice{} + case TextScanner: + return scanPlanTextAnyToTextScanner{} + case *rune: + return scanPlanTextAnyToRune{} + } + } + + return nil +} + +func (c TextCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(ci, oid, format, src) +} + +func (c TextCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + return string(src), nil +} + +type scanPlanTextAnyToString struct{} + +func (scanPlanTextAnyToString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p := (dst).(*string) + *p = string(src) + + return nil +} + +type scanPlanAnyToNewByteSlice struct{} + +func (scanPlanAnyToNewByteSlice) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + p := (dst).(*[]byte) + if src == nil { + *p = nil + } else { + *p = make([]byte, len(src)) + copy(*p, src) + } + + return nil +} + +type scanPlanTextAnyToRune struct{} + +func (scanPlanTextAnyToRune) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + r, size := utf8.DecodeRune(src) + if size != len(src) { + return fmt.Errorf("cannot scan %v into %T: more than one rune received", src, dst) + } + + p := (dst).(*rune) + *p = r + + return nil +} + +type scanPlanTextAnyToTextScanner struct{} + +func (scanPlanTextAnyToTextScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + return scanner.ScanText(Text{String: string(src), Valid: true}) +} diff --git a/pgtype/text_array.go b/pgtype/text_array.go deleted file mode 100644 index 7fcc1c4d..00000000 --- a/pgtype/text_array.go +++ /dev/null @@ -1,504 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type TextArray struct { - Elements []Text - Dimensions []ArrayDimension - Valid bool -} - -func (dst *TextArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = TextArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []string: - if value == nil { - *dst = TextArray{} - } else if len(value) == 0 { - *dst = TextArray{Valid: true} - } else { - elements := make([]Text, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TextArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*string: - if value == nil { - *dst = TextArray{} - } else if len(value) == 0 { - *dst = TextArray{Valid: true} - } else { - elements := make([]Text, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TextArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Text: - if value == nil { - *dst = TextArray{} - } else if len(value) == 0 { - *dst = TextArray{Valid: true} - } else { - *dst = TextArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = TextArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for TextArray", src) - } - if elementsLength == 0 { - *dst = TextArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to TextArray", src) - } - - *dst = TextArray{ - Elements: make([]Text, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Text, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to TextArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *TextArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to TextArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in TextArray", err) - } - index++ - - return index, nil -} - -func (dst TextArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *TextArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from TextArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from TextArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TextArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Text - - if len(uta.Elements) > 0 { - elements = make([]Text, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Text - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = TextArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TextArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = TextArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Text, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = TextArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("text"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "text") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *TextArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src TextArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/text_array_test.go b/pgtype/text_array_test.go deleted file mode 100644 index 22e2ca27..00000000 --- a/pgtype/text_array_test.go +++ /dev/null @@ -1,294 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// https://github.com/jackc/pgtype/issues/78 -func TestTextArrayDecodeTextNull(t *testing.T) { - textArray := &pgtype.TextArray{} - err := textArray.DecodeText(nil, []byte(`{abc,"NULL",NULL,def}`)) - require.NoError(t, err) - require.Len(t, textArray.Elements, 4) - assert.Equal(t, true, textArray.Elements[1].Valid) - assert.Equal(t, false, textArray.Elements[2].Valid) -} - -func TestTextArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "text[]", []interface{}{ - &pgtype.TextArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "foo", Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.TextArray{}, - &pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "bar ", Valid: true}, - {String: "NuLL", Valid: true}, - {String: `wow"quz\`, Valid: true}, - {String: "", Valid: true}, - {}, - {String: "null", Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "quz", Valid: true}, - {String: "foo", Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestTextArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.TextArray - }{ - { - source: []string{"foo"}, - result: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]string)(nil)), - result: pgtype.TextArray{}, - }, - { - source: [][]string{{"foo"}, {"bar"}}, - result: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - result: pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "foo", Valid: true}, - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "wibble", Valid: true}, - {String: "wobble", Valid: true}, - {String: "wubble", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]string{{"foo"}, {"bar"}}, - result: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - result: pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "foo", Valid: true}, - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "wibble", Valid: true}, - {String: "wobble", Valid: true}, - {String: "wubble", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.TextArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTextArrayAssignTo(t *testing.T) { - var stringSlice []string - type _stringSlice []string - var namedStringSlice _stringSlice - var stringSliceDim2 [][]string - var stringSliceDim4 [][][][]string - var stringArrayDim2 [2][1]string - var stringArrayDim4 [2][1][1][3]string - - simpleTests := []struct { - src pgtype.TextArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &stringSlice, - expected: []string{"foo"}, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &namedStringSlice, - expected: _stringSlice{"bar"}, - }, - { - src: pgtype.TextArray{}, - dst: &stringSlice, - expected: (([]string)(nil)), - }, - { - src: pgtype.TextArray{Valid: true}, - dst: &stringSlice, - expected: []string{}, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &stringSliceDim2, - expected: [][]string{{"foo"}, {"bar"}}, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "foo", Valid: true}, - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "wibble", Valid: true}, - {String: "wobble", Valid: true}, - {String: "wubble", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &stringSliceDim4, - expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &stringArrayDim2, - expected: [2][1]string{{"foo"}, {"bar"}}, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "foo", Valid: true}, - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "wibble", Valid: true}, - {String: "wobble", Valid: true}, - {String: "wubble", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &stringArrayDim4, - expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.TextArray - dst interface{} - }{ - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &stringSlice, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &stringArrayDim2, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &stringSlice, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &stringArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/text_test.go b/pgtype/text_test.go index dca6af4a..148aa97b 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -1,125 +1,71 @@ package pgtype_test import ( - "bytes" - "reflect" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestTextTranscode(t *testing.T) { +func TestTextCodec(t *testing.T) { for _, pgTypeName := range []string{"text", "varchar"} { - testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ - &pgtype.Text{String: "", Valid: true}, - &pgtype.Text{String: "foo", Valid: true}, - &pgtype.Text{}, + testPgxCodec(t, pgTypeName, []PgxTranscodeTestCase{ + { + pgtype.Text{String: "", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "", Valid: true}), + }, + { + pgtype.Text{String: "foo", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "foo", Valid: true}), + }, + {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + {"foo", new(string), isExpectedEq("foo")}, + {rune('R'), new(rune), isExpectedEq(rune('R'))}, }) } } -func TestTextSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Text - }{ - {source: "foo", result: pgtype.Text{String: "foo", Valid: true}}, - {source: _string("bar"), result: pgtype.Text{String: "bar", Valid: true}}, - {source: (*string)(nil), result: pgtype.Text{}}, - } - - for i, tt := range successfulTests { - var d pgtype.Text - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if d != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } +// name is PostgreSQL's special 63-byte data type, used for identifiers like table names. The pg_class.relname column +// is a good example of where the name data type is used. +// +// TextCodec does not do length checking. Inputting a longer name into PostgreSQL will result in silent truncation to +// 63 bytes. +// +// Length checking would be possible with a Codec specialized for "name" but it would be perfect because a +// custom-compiled PostgreSQL could have set NAMEDATALEN to a different value rather than the default 63. +// +// So this is simply a smoke test of the name type. +func TestTextCodecName(t *testing.T) { + testPgxCodec(t, "name", []PgxTranscodeTestCase{ + { + pgtype.Text{String: "", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "", Valid: true}), + }, + { + pgtype.Text{String: "foo", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "foo", Valid: true}), + }, + {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + {"foo", new(string), isExpectedEq("foo")}, + }) } -func TestTextAssignTo(t *testing.T) { - var s string - var ps *string - - stringTests := []struct { - src pgtype.Text - dst interface{} - expected interface{} - }{ - {src: pgtype.Text{String: "foo", Valid: true}, dst: &s, expected: "foo"}, - {src: pgtype.Text{}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range stringTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - var buf []byte - - bytesTests := []struct { - src pgtype.Text - dst *[]byte - expected []byte - }{ - {src: pgtype.Text{String: "foo", Valid: true}, dst: &buf, expected: []byte("foo")}, - {src: pgtype.Text{}, dst: &buf, expected: nil}, - } - - for i, tt := range bytesTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if bytes.Compare(*tt.dst, tt.expected) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Text - dst interface{} - expected interface{} - }{ - {src: pgtype.Text{String: "foo", Valid: true}, dst: &ps, expected: "foo"}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Text - dst interface{} - }{ - {src: pgtype.Text{}, dst: &s}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } +// Test fixed length char types like char(3) +func TestTextCodecBPChar(t *testing.T) { + testPgxCodec(t, "char(3)", []PgxTranscodeTestCase{ + { + pgtype.Text{String: "a ", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "a ", Valid: true}), + }, + {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + {" ", new(string), isExpectedEq(" ")}, + {"", new(string), isExpectedEq(" ")}, + {" å—¨ ", new(string), isExpectedEq(" å—¨ ")}, + }) } func TestTextMarshalJSON(t *testing.T) { diff --git a/pgtype/unknown.go b/pgtype/unknown.go deleted file mode 100644 index 0e576ee9..00000000 --- a/pgtype/unknown.go +++ /dev/null @@ -1,44 +0,0 @@ -package pgtype - -import "database/sql/driver" - -// Unknown represents the PostgreSQL unknown type. It is either a string literal -// or NULL. It is used when PostgreSQL does not know the type of a value. In -// general, this will only be used in pgx when selecting a null value without -// type information. e.g. SELECT NULL; -type Unknown struct { - String string - Valid bool -} - -func (dst *Unknown) Set(src interface{}) error { - return (*Text)(dst).Set(src) -} - -func (dst Unknown) Get() interface{} { - return (Text)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as Unknown is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *Unknown) AssignTo(dst interface{}) error { - return (*Text)(src).AssignTo(dst) -} - -func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeText(ci, src) -} - -func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeBinary(ci, src) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Unknown) Scan(src interface{}) error { - return (*Text)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Unknown) Value() (driver.Value, error) { - return (Text)(src).Value() -} diff --git a/pgtype/varchar.go b/pgtype/varchar.go deleted file mode 100644 index fea31d18..00000000 --- a/pgtype/varchar.go +++ /dev/null @@ -1,66 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -type Varchar Text - -// Set converts from src to dst. Note that as Varchar is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *Varchar) Set(src interface{}) error { - return (*Text)(dst).Set(src) -} - -func (dst Varchar) Get() interface{} { - return (Text)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as Varchar is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *Varchar) AssignTo(dst interface{}) error { - return (*Text)(src).AssignTo(dst) -} - -func (Varchar) PreferredResultFormat() int16 { - return TextFormatCode -} - -func (dst *Varchar) DecodeText(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeText(ci, src) -} - -func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeBinary(ci, src) -} - -func (Varchar) PreferredParamFormat() int16 { - return TextFormatCode -} - -func (src Varchar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Text)(src).EncodeText(ci, buf) -} - -func (src Varchar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Text)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Varchar) Scan(src interface{}) error { - return (*Text)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Varchar) Value() (driver.Value, error) { - return (Text)(src).Value() -} - -func (src Varchar) MarshalJSON() ([]byte, error) { - return (Text)(src).MarshalJSON() -} - -func (dst *Varchar) UnmarshalJSON(b []byte) error { - return (*Text)(dst).UnmarshalJSON(b) -} diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go deleted file mode 100644 index 3e0913dc..00000000 --- a/pgtype/varchar_array.go +++ /dev/null @@ -1,504 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type VarcharArray struct { - Elements []Varchar - Dimensions []ArrayDimension - Valid bool -} - -func (dst *VarcharArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = VarcharArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []string: - if value == nil { - *dst = VarcharArray{} - } else if len(value) == 0 { - *dst = VarcharArray{Valid: true} - } else { - elements := make([]Varchar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = VarcharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*string: - if value == nil { - *dst = VarcharArray{} - } else if len(value) == 0 { - *dst = VarcharArray{Valid: true} - } else { - elements := make([]Varchar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = VarcharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Varchar: - if value == nil { - *dst = VarcharArray{} - } else if len(value) == 0 { - *dst = VarcharArray{Valid: true} - } else { - *dst = VarcharArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = VarcharArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for VarcharArray", src) - } - if elementsLength == 0 { - *dst = VarcharArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to VarcharArray", src) - } - - *dst = VarcharArray{ - Elements: make([]Varchar, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Varchar, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to VarcharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *VarcharArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to VarcharArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in VarcharArray", err) - } - index++ - - return index, nil -} - -func (dst VarcharArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *VarcharArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from VarcharArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from VarcharArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = VarcharArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Varchar - - if len(uta.Elements) > 0 { - elements = make([]Varchar, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Varchar - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = VarcharArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = VarcharArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = VarcharArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Varchar, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = VarcharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("varchar"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "varchar") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *VarcharArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src VarcharArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/varchar_array_test.go b/pgtype/varchar_array_test.go deleted file mode 100644 index 2d437274..00000000 --- a/pgtype/varchar_array_test.go +++ /dev/null @@ -1,282 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestVarcharArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "varchar[]", []interface{}{ - &pgtype.VarcharArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "foo", Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.VarcharArray{}, - &pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "bar ", Valid: true}, - {String: "NuLL", Valid: true}, - {String: `wow"quz\`, Valid: true}, - {String: "", Valid: true}, - {}, - {String: "null", Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "quz", Valid: true}, - {String: "foo", Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestVarcharArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.VarcharArray - }{ - { - source: []string{"foo"}, - result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]string)(nil)), - result: pgtype.VarcharArray{}, - }, - { - source: [][]string{{"foo"}, {"bar"}}, - result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "foo", Valid: true}, - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "wibble", Valid: true}, - {String: "wobble", Valid: true}, - {String: "wubble", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]string{{"foo"}, {"bar"}}, - result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "foo", Valid: true}, - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "wibble", Valid: true}, - {String: "wobble", Valid: true}, - {String: "wubble", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.VarcharArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestVarcharArrayAssignTo(t *testing.T) { - var stringSlice []string - type _stringSlice []string - var namedStringSlice _stringSlice - var stringSliceDim2 [][]string - var stringSliceDim4 [][][][]string - var stringArrayDim2 [2][1]string - var stringArrayDim4 [2][1][1][3]string - - simpleTests := []struct { - src pgtype.VarcharArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &stringSlice, - expected: []string{"foo"}, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &namedStringSlice, - expected: _stringSlice{"bar"}, - }, - { - src: pgtype.VarcharArray{}, - dst: &stringSlice, - expected: (([]string)(nil)), - }, - { - src: pgtype.VarcharArray{Valid: true}, - dst: &stringSlice, - expected: []string{}, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &stringSliceDim2, - expected: [][]string{{"foo"}, {"bar"}}, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "foo", Valid: true}, - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "wibble", Valid: true}, - {String: "wobble", Valid: true}, - {String: "wubble", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &stringSliceDim4, - expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &stringArrayDim2, - expected: [2][1]string{{"foo"}, {"bar"}}, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "foo", Valid: true}, - {String: "bar", Valid: true}, - {String: "baz", Valid: true}, - {String: "wibble", Valid: true}, - {String: "wobble", Valid: true}, - {String: "wubble", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &stringArrayDim4, - expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.VarcharArray - dst interface{} - }{ - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &stringSlice, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &stringArrayDim2, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &stringSlice, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &stringArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/zeronull/text.go b/pgtype/zeronull/text.go index 33ce367f..fcbc16d7 100644 --- a/pgtype/zeronull/text.go +++ b/pgtype/zeronull/text.go @@ -8,68 +8,22 @@ import ( type Text string -func (dst *Text) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Text - err := nullable.DecodeText(ci, src) - if err != nil { - return err +// ScanText implements the TextScanner interface. +func (dst *Text) ScanText(v pgtype.Text) error { + if !v.Valid { + *dst = "" + return nil } - if nullable.Valid { - *dst = Text(nullable.String) - } else { - *dst = Text("") - } + *dst = Text(v.String) return nil } -func (dst *Text) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Text - err := nullable.DecodeBinary(ci, src) - if err != nil { - return err - } - - if nullable.Valid { - *dst = Text(nullable.String) - } else { - *dst = Text("") - } - - return nil -} - -func (src Text) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if src == Text("") { - return nil, nil - } - - nullable := pgtype.Text{ - String: string(src), - Valid: true, - } - - return nullable.EncodeText(ci, buf) -} - -func (src Text) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if src == Text("") { - return nil, nil - } - - nullable := pgtype.Text{ - String: string(src), - Valid: true, - } - - return nullable.EncodeBinary(ci, buf) -} - // Scan implements the database/sql Scanner interface. func (dst *Text) Scan(src interface{}) error { if src == nil { - *dst = Text("") + *dst = "" return nil } @@ -86,5 +40,8 @@ func (dst *Text) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Text) Value() (driver.Value, error) { - return pgtype.EncodeValueText(src) + if src == "" { + return nil, nil + } + return string(src), nil } diff --git a/query_test.go b/query_test.go index c0fbebaf..c22c2795 100644 --- a/query_test.go +++ b/query_test.go @@ -245,7 +245,7 @@ func TestConnQueryReadRowMultipleTimes(t *testing.T) { var a, b string var c int32 - var d pgtype.Unknown + var d pgtype.Text var e int32 err = rows.Scan(&a, &b, &c, &d, &e) @@ -958,6 +958,7 @@ func TestQueryRowCoreByteSlice(t *testing.T) { } func TestQueryRowErrors(t *testing.T) { + t.Skip("TODO - unskip later in v5") t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -1939,6 +1940,7 @@ func TestConnQueryFunc(t *testing.T) { } func TestConnQueryFuncScanError(t *testing.T) { + t.Skip("TODO - unskip later in v5") t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { diff --git a/rows.go b/rows.go index 0cc09ad9..62a19016 100644 --- a/rows.go +++ b/rows.go @@ -252,15 +252,15 @@ func (rows *connRows) Values() ([]interface{}, error) { switch fd.Format { case TextFormatCode: - decoder, ok := value.(pgtype.TextDecoder) - if !ok { - decoder = &pgtype.GenericText{} + if decoder, ok := value.(pgtype.TextDecoder); ok { + err := decoder.DecodeText(rows.connInfo, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, decoder.(pgtype.Value).Get()) + } else { + values = append(values, string(buf)) } - err := decoder.DecodeText(rows.connInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, decoder.(pgtype.Value).Get()) case BinaryFormatCode: decoder, ok := value.(pgtype.BinaryDecoder) if !ok { @@ -284,12 +284,7 @@ func (rows *connRows) Values() ([]interface{}, error) { } else { switch fd.Format { case TextFormatCode: - decoder := &pgtype.GenericText{} - err := decoder.DecodeText(rows.connInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, decoder.Get()) + values = append(values, string(buf)) case BinaryFormatCode: decoder := &pgtype.GenericBinary{} err := decoder.DecodeBinary(rows.connInfo, buf) From 58d2d8e453569e78ce2163e522b30c4e555d3fdf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 13:16:09 -0600 Subject: [PATCH 0814/1158] Add name array --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d6bca76d..6c26db16 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -43,6 +43,7 @@ const ( MacaddrOID = 829 InetOID = 869 BoolArrayOID = 1000 + NameArrayOID = 1003 Int2ArrayOID = 1005 Int4ArrayOID = 1007 TextArrayOID = 1009 @@ -267,6 +268,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementCodec: BoxCodec{}, ElementOID: BoxOID}}) ci.RegisterDataType(DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementCodec: CircleCodec{}, ElementOID: CircleOID}}) ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementCodec: PointCodec{}, ElementOID: PointOID}}) + ci.RegisterDataType(DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: NameOID}}) ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: TextOID}}) ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) From 6a6878bafd9977f74a568b56cbc4766d59a83619 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 13:29:47 -0600 Subject: [PATCH 0815/1158] Fix Box, Circle, and Point NULL --- pgtype/box.go | 12 ++++++++++++ pgtype/box_test.go | 1 + pgtype/circle.go | 12 ++++++++++++ pgtype/circle_test.go | 1 + pgtype/pgtype.go | 4 ++++ pgtype/point.go | 12 ++++++++++++ pgtype/point_test.go | 1 + 7 files changed, 43 insertions(+) diff --git a/pgtype/box.go b/pgtype/box.go index b5c30ed3..5e841ae5 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -50,6 +50,10 @@ func (dst *Box) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Box) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + buf, err := BoxCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) if err != nil { return nil, err @@ -90,6 +94,10 @@ func (p *encodePlanBoxCodecBinary) Encode(value interface{}, buf []byte) (newBuf return nil, err } + if !box.Valid { + return nil, nil + } + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].X)) buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].Y)) buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].X)) @@ -105,6 +113,10 @@ func (p *encodePlanBoxCodecText) Encode(value interface{}, buf []byte) (newBuf [ return nil, err } + if !box.Valid { + return nil, nil + } + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, strconv.FormatFloat(box.P[0].X, 'f', -1, 64), strconv.FormatFloat(box.P[0].Y, 'f', -1, 64), diff --git a/pgtype/box_test.go b/pgtype/box_test.go index f4e26370..8056e819 100644 --- a/pgtype/box_test.go +++ b/pgtype/box_test.go @@ -31,6 +31,7 @@ func TestBoxCodec(t *testing.T) { Valid: true, }), }, + {pgtype.Box{}, new(pgtype.Box), isExpectedEq(pgtype.Box{})}, {nil, new(pgtype.Box), isExpectedEq(pgtype.Box{})}, }) } diff --git a/pgtype/circle.go b/pgtype/circle.go index f214a070..5d9055e5 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -51,6 +51,10 @@ func (dst *Circle) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Circle) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + buf, err := CircleCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) if err != nil { return nil, err @@ -91,6 +95,10 @@ func (p *encodePlanCircleCodecBinary) Encode(value interface{}, buf []byte) (new return nil, err } + if !circle.Valid { + return nil, nil + } + buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.X)) buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.Y)) buf = pgio.AppendUint64(buf, math.Float64bits(circle.R)) @@ -105,6 +113,10 @@ func (p *encodePlanCircleCodecText) Encode(value interface{}, buf []byte) (newBu return nil, err } + if !circle.Valid { + return nil, nil + } + buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, strconv.FormatFloat(circle.P.X, 'f', -1, 64), strconv.FormatFloat(circle.P.Y, 'f', -1, 64), diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go index 742ac688..6fbf4c31 100644 --- a/pgtype/circle_test.go +++ b/pgtype/circle_test.go @@ -18,6 +18,7 @@ func TestCircleTranscode(t *testing.T) { new(pgtype.Circle), isExpectedEq(pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}), }, + {pgtype.Circle{}, new(pgtype.Circle), isExpectedEq(pgtype.Circle{})}, {nil, new(pgtype.Circle), isExpectedEq(pgtype.Circle{})}, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 6c26db16..6ead989e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -949,6 +949,10 @@ func codecScan(codec Codec, ci *ConnInfo, oid uint32, format int16, src []byte, } func codecDecodeToTextFormat(codec Codec, ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + if format == TextFormatCode { return string(src), nil } else { diff --git a/pgtype/point.go b/pgtype/point.go index b4236c8f..256bedc0 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -85,6 +85,10 @@ func (dst *Point) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Point) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + buf, err := PointCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) if err != nil { return nil, err @@ -146,6 +150,10 @@ func (p *encodePlanPointCodecBinary) Encode(value interface{}, buf []byte) (newB return nil, err } + if !point.Valid { + return nil, nil + } + buf = pgio.AppendUint64(buf, math.Float64bits(point.P.X)) buf = pgio.AppendUint64(buf, math.Float64bits(point.P.Y)) return buf, nil @@ -159,6 +167,10 @@ func (p *encodePlanPointCodecText) Encode(value interface{}, buf []byte) (newBuf return nil, err } + if !point.Valid { + return nil, nil + } + return append(buf, fmt.Sprintf(`(%s,%s)`, strconv.FormatFloat(point.P.X, 'f', -1, 64), strconv.FormatFloat(point.P.Y, 'f', -1, 64), diff --git a/pgtype/point_test.go b/pgtype/point_test.go index 718e203a..8046da92 100644 --- a/pgtype/point_test.go +++ b/pgtype/point_test.go @@ -20,6 +20,7 @@ func TestPointCodec(t *testing.T) { new(pgtype.Point), isExpectedEq(pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Valid: true}), }, + {pgtype.Point{}, new(pgtype.Point), isExpectedEq(pgtype.Point{})}, {nil, new(pgtype.Point), isExpectedEq(pgtype.Point{})}, }) } From 4aff33603dc1731ca0c448f2154f7ccf550d74c9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 13:37:23 -0600 Subject: [PATCH 0816/1158] Remove useless receivers --- pgtype/bool.go | 4 ++-- pgtype/box.go | 4 ++-- pgtype/circle.go | 4 ++-- pgtype/int.go | 12 ++++++------ pgtype/int.go.erb | 4 ++-- pgtype/point.go | 4 ++-- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pgtype/bool.go b/pgtype/bool.go index d2c3cdc3..60bec0f3 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -118,7 +118,7 @@ func (BoolCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interf type encodePlanBoolCodecBinary struct{} -func (p *encodePlanBoolCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBoolCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { v, valid, err := convertToBoolForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to bool: %v", value, err) @@ -141,7 +141,7 @@ func (p *encodePlanBoolCodecBinary) Encode(value interface{}, buf []byte) (newBu type encodePlanBoolCodecText struct{} -func (p *encodePlanBoolCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBoolCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { v, valid, err := convertToBoolForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to bool: %v", value, err) diff --git a/pgtype/box.go b/pgtype/box.go index 5e841ae5..8f96a016 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -88,7 +88,7 @@ func (BoxCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interfa type encodePlanBoxCodecBinary struct{} -func (p *encodePlanBoxCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBoxCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { box, err := value.(BoxValuer).BoxValue() if err != nil { return nil, err @@ -107,7 +107,7 @@ func (p *encodePlanBoxCodecBinary) Encode(value interface{}, buf []byte) (newBuf type encodePlanBoxCodecText struct{} -func (p *encodePlanBoxCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBoxCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { box, err := value.(BoxValuer).BoxValue() if err != nil { return nil, err diff --git a/pgtype/circle.go b/pgtype/circle.go index 5d9055e5..f2c591b4 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -89,7 +89,7 @@ func (CircleCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value inte type encodePlanCircleCodecBinary struct{} -func (p *encodePlanCircleCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanCircleCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { circle, err := value.(CircleValuer).CircleValue() if err != nil { return nil, err @@ -107,7 +107,7 @@ func (p *encodePlanCircleCodecBinary) Encode(value interface{}, buf []byte) (new type encodePlanCircleCodecText struct{} -func (p *encodePlanCircleCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanCircleCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { circle, err := value.(CircleValuer).CircleValue() if err != nil { return nil, err diff --git a/pgtype/int.go b/pgtype/int.go index 18b1ba90..785b481b 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -132,7 +132,7 @@ func (Int2Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interf type encodePlanInt2CodecBinary struct{} -func (p *encodePlanInt2CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt2CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) @@ -153,7 +153,7 @@ func (p *encodePlanInt2CodecBinary) Encode(value interface{}, buf []byte) (newBu type encodePlanInt2CodecText struct{} -func (p *encodePlanInt2CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt2CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) @@ -639,7 +639,7 @@ func (Int4Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interf type encodePlanInt4CodecBinary struct{} -func (p *encodePlanInt4CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt4CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int4: %v", value, err) @@ -660,7 +660,7 @@ func (p *encodePlanInt4CodecBinary) Encode(value interface{}, buf []byte) (newBu type encodePlanInt4CodecText struct{} -func (p *encodePlanInt4CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt4CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int4: %v", value, err) @@ -1157,7 +1157,7 @@ func (Int8Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interf type encodePlanInt8CodecBinary struct{} -func (p *encodePlanInt8CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt8CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int8: %v", value, err) @@ -1178,7 +1178,7 @@ func (p *encodePlanInt8CodecBinary) Encode(value interface{}, buf []byte) (newBu type encodePlanInt8CodecText struct{} -func (p *encodePlanInt8CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt8CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int8: %v", value, err) diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 3f15dfce..ba93c8e1 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -133,7 +133,7 @@ func (Int<%= pg_byte_size %>Codec) PlanEncode(ci *ConnInfo, oid uint32, format i type encodePlanInt<%= pg_byte_size %>CodecBinary struct{} -func (p *encodePlanInt<%= pg_byte_size %>CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt<%= pg_byte_size %>CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int<%= pg_byte_size %>: %v", value, err) @@ -154,7 +154,7 @@ func (p *encodePlanInt<%= pg_byte_size %>CodecBinary) Encode(value interface{}, type encodePlanInt<%= pg_byte_size %>CodecText struct{} -func (p *encodePlanInt<%= pg_byte_size %>CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt<%= pg_byte_size %>CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int<%= pg_byte_size %>: %v", value, err) diff --git a/pgtype/point.go b/pgtype/point.go index 256bedc0..2d3380fd 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -144,7 +144,7 @@ func (PointCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value inter type encodePlanPointCodecBinary struct{} -func (p *encodePlanPointCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanPointCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { point, err := value.(PointValuer).PointValue() if err != nil { return nil, err @@ -161,7 +161,7 @@ func (p *encodePlanPointCodecBinary) Encode(value interface{}, buf []byte) (newB type encodePlanPointCodecText struct{} -func (p *encodePlanPointCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanPointCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { point, err := value.(PointValuer).PointValue() if err != nil { return nil, err From 313569db5659e68f826bac9771fff6f01867b4f1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 13:38:56 -0600 Subject: [PATCH 0817/1158] Remove useless allocations --- pgtype/bool.go | 4 ++-- pgtype/box.go | 4 ++-- pgtype/circle.go | 4 ++-- pgtype/int.go | 12 ++++++------ pgtype/int.go.erb | 4 ++-- pgtype/point.go | 4 ++-- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pgtype/bool.go b/pgtype/bool.go index 60bec0f3..21cd9889 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -108,9 +108,9 @@ func (BoolCodec) PreferredFormat() int16 { func (BoolCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return &encodePlanBoolCodecBinary{} + return encodePlanBoolCodecBinary{} case TextFormatCode: - return &encodePlanBoolCodecText{} + return encodePlanBoolCodecText{} } return nil diff --git a/pgtype/box.go b/pgtype/box.go index 8f96a016..677d4bd2 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -78,9 +78,9 @@ func (BoxCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interfa switch format { case BinaryFormatCode: - return &encodePlanBoxCodecBinary{} + return encodePlanBoxCodecBinary{} case TextFormatCode: - return &encodePlanBoxCodecText{} + return encodePlanBoxCodecText{} } return nil diff --git a/pgtype/circle.go b/pgtype/circle.go index f2c591b4..ae8aa352 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -79,9 +79,9 @@ func (CircleCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value inte switch format { case BinaryFormatCode: - return &encodePlanCircleCodecBinary{} + return encodePlanCircleCodecBinary{} case TextFormatCode: - return &encodePlanCircleCodecText{} + return encodePlanCircleCodecText{} } return nil diff --git a/pgtype/int.go b/pgtype/int.go index 785b481b..5fee64a6 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -122,9 +122,9 @@ func (Int2Codec) PreferredFormat() int16 { func (Int2Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return &encodePlanInt2CodecBinary{} + return encodePlanInt2CodecBinary{} case TextFormatCode: - return &encodePlanInt2CodecText{} + return encodePlanInt2CodecText{} } return nil @@ -629,9 +629,9 @@ func (Int4Codec) PreferredFormat() int16 { func (Int4Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return &encodePlanInt4CodecBinary{} + return encodePlanInt4CodecBinary{} case TextFormatCode: - return &encodePlanInt4CodecText{} + return encodePlanInt4CodecText{} } return nil @@ -1147,9 +1147,9 @@ func (Int8Codec) PreferredFormat() int16 { func (Int8Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return &encodePlanInt8CodecBinary{} + return encodePlanInt8CodecBinary{} case TextFormatCode: - return &encodePlanInt8CodecText{} + return encodePlanInt8CodecText{} } return nil diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index ba93c8e1..419dddd2 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -123,9 +123,9 @@ func (Int<%= pg_byte_size %>Codec) PreferredFormat() int16 { func (Int<%= pg_byte_size %>Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return &encodePlanInt<%= pg_byte_size %>CodecBinary{} + return encodePlanInt<%= pg_byte_size %>CodecBinary{} case TextFormatCode: - return &encodePlanInt<%= pg_byte_size %>CodecText{} + return encodePlanInt<%= pg_byte_size %>CodecText{} } return nil diff --git a/pgtype/point.go b/pgtype/point.go index 2d3380fd..a9be4fdc 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -134,9 +134,9 @@ func (PointCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value inter switch format { case BinaryFormatCode: - return &encodePlanPointCodecBinary{} + return encodePlanPointCodecBinary{} case TextFormatCode: - return &encodePlanPointCodecText{} + return encodePlanPointCodecText{} } return nil From 1eee7987e123fc1c37e0730e1650df813c5ad549 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 16:24:05 -0600 Subject: [PATCH 0818/1158] Use TextCodec for aclitem type --- pgtype/aclitem.go | 127 ---------- pgtype/aclitem_array.go | 418 ------------------------------- pgtype/aclitem_array_test.go | 329 ------------------------ pgtype/aclitem_test.go | 97 ------- pgtype/pgtype.go | 4 +- pgtype/pgtype_test.go | 38 +-- pgtype/text_format_only_codec.go | 13 + pgtype/text_test.go | 57 +++++ 8 files changed, 93 insertions(+), 990 deletions(-) delete mode 100644 pgtype/aclitem.go delete mode 100644 pgtype/aclitem_array.go delete mode 100644 pgtype/aclitem_array_test.go delete mode 100644 pgtype/aclitem_test.go create mode 100644 pgtype/text_format_only_codec.go diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go deleted file mode 100644 index 0c1f23b5..00000000 --- a/pgtype/aclitem.go +++ /dev/null @@ -1,127 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "fmt" -) - -// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem -// might look like this: -// -// postgres=arwdDxt/postgres -// -// Note, however, that because the user/role name part of an aclitem is -// an identifier, it follows all the usual formatting rules for SQL -// identifiers: if it contains spaces and other special characters, -// it should appear in double-quotes: -// -// postgres=arwdDxt/"role with spaces" -// -type ACLItem struct { - String string - Valid bool -} - -func (dst *ACLItem) Set(src interface{}) error { - if src == nil { - *dst = ACLItem{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case string: - *dst = ACLItem{String: value, Valid: true} - case *string: - if value == nil { - *dst = ACLItem{} - } else { - *dst = ACLItem{String: *value, Valid: true} - } - default: - if originalSrc, ok := underlyingStringType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to ACLItem", value) - } - - return nil -} - -func (dst ACLItem) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.String -} - -func (src *ACLItem) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *string: - *v = src.String - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - - return fmt.Errorf("cannot decode %#v into %T", src, dst) -} - -func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = ACLItem{} - return nil - } - - *dst = ACLItem{String: string(src), Valid: true} - return nil -} - -func (src ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, src.String...), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *ACLItem) Scan(src interface{}) error { - if src == nil { - *dst = ACLItem{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src ACLItem) Value() (driver.Value, error) { - if !src.Valid { - return nil, nil - } - - return src.String, nil -} diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go deleted file mode 100644 index fc1128b7..00000000 --- a/pgtype/aclitem_array.go +++ /dev/null @@ -1,418 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "fmt" - "reflect" -) - -type ACLItemArray struct { - Elements []ACLItem - Dimensions []ArrayDimension - Valid bool -} - -func (dst *ACLItemArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = ACLItemArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []string: - if value == nil { - *dst = ACLItemArray{} - } else if len(value) == 0 { - *dst = ACLItemArray{Valid: true} - } else { - elements := make([]ACLItem, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ACLItemArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*string: - if value == nil { - *dst = ACLItemArray{} - } else if len(value) == 0 { - *dst = ACLItemArray{Valid: true} - } else { - elements := make([]ACLItem, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ACLItemArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []ACLItem: - if value == nil { - *dst = ACLItemArray{} - } else if len(value) == 0 { - *dst = ACLItemArray{Valid: true} - } else { - *dst = ACLItemArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = ACLItemArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for ACLItemArray", src) - } - if elementsLength == 0 { - *dst = ACLItemArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to ACLItemArray", src) - } - - *dst = ACLItemArray{ - Elements: make([]ACLItem, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]ACLItem, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to ACLItemArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *ACLItemArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to ACLItemArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in ACLItemArray", err) - } - index++ - - return index, nil -} - -func (dst ACLItemArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *ACLItemArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from ACLItemArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from ACLItemArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = ACLItemArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []ACLItem - - if len(uta.Elements) > 0 { - elements = make([]ACLItem, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem ACLItem - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (src ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *ACLItemArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src ACLItemArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/aclitem_array_test.go b/pgtype/aclitem_array_test.go deleted file mode 100644 index bdae9b56..00000000 --- a/pgtype/aclitem_array_test.go +++ /dev/null @@ -1,329 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestACLItemArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "aclitem[]", []interface{}{ - &pgtype.ACLItemArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.ACLItemArray{}, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}, - //{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}, - {String: `postgres=arwdDxt/postgres`, Valid: true}, // todo: remove after fixing above case - {String: "=r/postgres", Valid: true}, - {}, - {String: "=r/postgres", Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}, - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestACLItemArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.ACLItemArray - }{ - { - source: []string{"=r/postgres"}, - result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]string)(nil)), - result: pgtype.ACLItemArray{}, - }, - { - source: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, - result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]string{ - {{{ - "=r/postgres", - "postgres=arwdDxt/postgres", - "=r/postgres"}}}, - {{{ - "postgres=arwdDxt/postgres", - "=r/postgres", - "postgres=arwdDxt/postgres"}}}}, - result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}, - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}, - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, - result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]string{ - {{{ - "=r/postgres", - "postgres=arwdDxt/postgres", - "=r/postgres"}}}, - {{{ - "postgres=arwdDxt/postgres", - "=r/postgres", - "postgres=arwdDxt/postgres"}}}}, - result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}, - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}, - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.ACLItemArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestACLItemArrayAssignTo(t *testing.T) { - var stringSlice []string - type _stringSlice []string - var namedStringSlice _stringSlice - var stringSliceDim2 [][]string - var stringSliceDim4 [][][][]string - var stringArrayDim2 [2][1]string - var stringArrayDim4 [2][1][1][3]string - - simpleTests := []struct { - src pgtype.ACLItemArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &stringSlice, - expected: []string{"=r/postgres"}, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &namedStringSlice, - expected: _stringSlice{"=r/postgres"}, - }, - { - src: pgtype.ACLItemArray{}, - dst: &stringSlice, - expected: (([]string)(nil)), - }, - { - src: pgtype.ACLItemArray{Valid: true}, - dst: &stringSlice, - expected: []string{}, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &stringSliceDim2, - expected: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}, - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}, - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &stringSliceDim4, - expected: [][][][]string{ - {{{ - "=r/postgres", - "postgres=arwdDxt/postgres", - "=r/postgres"}}}, - {{{ - "postgres=arwdDxt/postgres", - "=r/postgres", - "postgres=arwdDxt/postgres"}}}}, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &stringArrayDim2, - expected: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}, - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}, - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &stringArrayDim4, - expected: [2][1][1][3]string{ - {{{ - "=r/postgres", - "postgres=arwdDxt/postgres", - "=r/postgres"}}}, - {{{ - "postgres=arwdDxt/postgres", - "=r/postgres", - "postgres=arwdDxt/postgres"}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.ACLItemArray - dst interface{} - }{ - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &stringSlice, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &stringArrayDim2, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &stringSlice, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Valid: true}, - {String: "postgres=arwdDxt/postgres", Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &stringArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go deleted file mode 100644 index 84388142..00000000 --- a/pgtype/aclitem_test.go +++ /dev/null @@ -1,97 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestACLItemTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ - &pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Valid: true}, - //&pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}, - &pgtype.ACLItem{}, - }) -} - -func TestACLItemSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.ACLItem - }{ - {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Valid: true}}, - {source: (*string)(nil), result: pgtype.ACLItem{}}, - } - - for i, tt := range successfulTests { - var d pgtype.ACLItem - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if d != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } -} - -func TestACLItemAssignTo(t *testing.T) { - var s string - var ps *string - - simpleTests := []struct { - src pgtype.ACLItem - dst interface{} - expected interface{} - }{ - {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Valid: true}, dst: &s, expected: "postgres=arwdDxt/postgres"}, - {src: pgtype.ACLItem{}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.ACLItem - dst interface{} - expected interface{} - }{ - {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Valid: true}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.ACLItem - dst interface{} - }{ - {src: pgtype.ACLItem{}, dst: &s}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 6ead989e..d4001392 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -253,7 +253,7 @@ func newConnInfo() *ConnInfo { func NewConnInfo() *ConnInfo { ci := newConnInfo() - ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID}) + ci.RegisterDataType(DataType{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementCodec: &TextFormatOnlyCodec{TextCodec{}}, ElementOID: ACLItemOID}}) ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementCodec: BoolCodec{}, ElementOID: BoolOID}}) ci.RegisterDataType(DataType{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: BPCharOID}}) ci.RegisterDataType(DataType{Value: &ByteaArray{}, Name: "_bytea", OID: ByteaArrayOID}) @@ -275,7 +275,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID}) ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) ci.RegisterDataType(DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: VarcharOID}}) - ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID}) + ci.RegisterDataType(DataType{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID}) ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) ci.RegisterDataType(DataType{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 56064281..703e1843 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -308,11 +308,6 @@ func testPgxCodec(t testing.TB, pgTypeName string, tests []PgxTranscodeTestCase) conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) - _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - formats := []struct { name string code int16 @@ -321,21 +316,30 @@ func testPgxCodec(t testing.TB, pgTypeName string, tests []PgxTranscodeTestCase) {name: "BinaryFormat", code: pgx.BinaryFormatCode}, } + for _, format := range formats { + testPgxCodecFormat(t, pgTypeName, tests, conn, format.name, format.code) + } +} + +func testPgxCodecFormat(t testing.TB, pgTypeName string, tests []PgxTranscodeTestCase, conn *pgx.Conn, formatName string, formatCode int16) { + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + for i, tt := range tests { - for _, format := range formats { - err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{format.code}, tt.src).Scan(tt.dst) - if err != nil { - t.Errorf("%s %d: %v", format.name, i, err) - } + err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{formatCode}, tt.src).Scan(tt.dst) + if err != nil { + t.Errorf("%s %d: %v", formatName, i, err) + } - dst := reflect.ValueOf(tt.dst) - if dst.Kind() == reflect.Ptr { - dst = dst.Elem() - } + dst := reflect.ValueOf(tt.dst) + if dst.Kind() == reflect.Ptr { + dst = dst.Elem() + } - if !tt.test(dst.Interface()) { - t.Errorf("%s %d: unexpected result for %v: %v", format.name, i, tt.src, dst.Interface()) - } + if !tt.test(dst.Interface()) { + t.Errorf("%s %d: unexpected result for %v: %v", formatName, i, tt.src, dst.Interface()) } } } diff --git a/pgtype/text_format_only_codec.go b/pgtype/text_format_only_codec.go new file mode 100644 index 00000000..d5e4cdb3 --- /dev/null +++ b/pgtype/text_format_only_codec.go @@ -0,0 +1,13 @@ +package pgtype + +type TextFormatOnlyCodec struct { + Codec +} + +func (c *TextFormatOnlyCodec) FormatSupported(format int16) bool { + return format == TextFormatCode && c.Codec.FormatSupported(format) +} + +func (TextFormatOnlyCodec) PreferredFormat() int16 { + return TextFormatCode +} diff --git a/pgtype/text_test.go b/pgtype/text_test.go index 148aa97b..27b01c15 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -1,9 +1,12 @@ package pgtype_test import ( + "context" "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/stretchr/testify/require" ) func TestTextCodec(t *testing.T) { @@ -68,6 +71,60 @@ func TestTextCodecBPChar(t *testing.T) { }) } +// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem +// might look like this: +// +// postgres=arwdDxt/postgres +// +// Note, however, that because the user/role name part of an aclitem is +// an identifier, it follows all the usual formatting rules for SQL +// identifiers: if it contains spaces and other special characters, +// it should appear in double-quotes: +// +// postgres=arwdDxt/"role with spaces" +// +// It only supports the text format. +func TestTextCodecACLItem(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + testPgxCodecFormat(t, "aclitem", []PgxTranscodeTestCase{ + { + pgtype.Text{String: "postgres=arwdDxt/postgres", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "postgres=arwdDxt/postgres", Valid: true}), + }, + {pgtype.Text{}, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + }, conn, "Text", pgtype.TextFormatCode) +} + +func TestTextCodecACLItemRoleWithSpecialCharacters(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + ctx := context.Background() + + // The tricky test user, below, has to actually exist so that it can be used in a test + // of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. + roleWithSpecialCharacters := ` tricky, ' } " \ test user ` + + commandTag, err := conn.Exec(ctx, `select * from pg_roles where rolname = $1`, roleWithSpecialCharacters) + require.NoError(t, err) + + if commandTag.RowsAffected() == 0 { + t.Skipf("Role with special characters does not exist.") + } + + testPgxCodecFormat(t, "aclitem", []PgxTranscodeTestCase{ + { + pgtype.Text{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}), + }, + }, conn, "Text", pgtype.TextFormatCode) +} + func TestTextMarshalJSON(t *testing.T) { successfulTests := []struct { source pgtype.Text From 17513d175aa26c6b27966a4a36975b307a9f8dc3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 16:49:58 -0600 Subject: [PATCH 0819/1158] Convert bit and varbit to Codec --- pgtype/bit.go | 45 --------- pgtype/bit_test.go | 25 ----- pgtype/bits.go | 208 ++++++++++++++++++++++++++++++++++++++++++ pgtype/bits_test.go | 65 +++++++++++++ pgtype/pgtype.go | 4 +- pgtype/varbit.go | 123 ------------------------- pgtype/varbit_test.go | 26 ------ 7 files changed, 275 insertions(+), 221 deletions(-) delete mode 100644 pgtype/bit.go delete mode 100644 pgtype/bit_test.go create mode 100644 pgtype/bits.go create mode 100644 pgtype/bits_test.go delete mode 100644 pgtype/varbit.go delete mode 100644 pgtype/varbit_test.go diff --git a/pgtype/bit.go b/pgtype/bit.go deleted file mode 100644 index c1709e6b..00000000 --- a/pgtype/bit.go +++ /dev/null @@ -1,45 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -type Bit Varbit - -func (dst *Bit) Set(src interface{}) error { - return (*Varbit)(dst).Set(src) -} - -func (dst Bit) Get() interface{} { - return (Varbit)(dst).Get() -} - -func (src *Bit) AssignTo(dst interface{}) error { - return (*Varbit)(src).AssignTo(dst) -} - -func (dst *Bit) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Varbit)(dst).DecodeBinary(ci, src) -} - -func (src Bit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Varbit)(src).EncodeBinary(ci, buf) -} - -func (dst *Bit) DecodeText(ci *ConnInfo, src []byte) error { - return (*Varbit)(dst).DecodeText(ci, src) -} - -func (src Bit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Varbit)(src).EncodeText(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Bit) Scan(src interface{}) error { - return (*Varbit)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Bit) Value() (driver.Value, error) { - return (Varbit)(src).Value() -} diff --git a/pgtype/bit_test.go b/pgtype/bit_test.go deleted file mode 100644 index 2f07c3c9..00000000 --- a/pgtype/bit_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestBitTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bit(40)", []interface{}{ - &pgtype.Varbit{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, - &pgtype.Varbit{}, - }) -} - -func TestBitNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select B'111111111'", - Value: &pgtype.Bit{Bytes: []byte{255, 128}, Len: 9, Valid: true}, - }, - }) -} diff --git a/pgtype/bits.go b/pgtype/bits.go new file mode 100644 index 00000000..9b499c35 --- /dev/null +++ b/pgtype/bits.go @@ -0,0 +1,208 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + + "github.com/jackc/pgio" +) + +type BitsScanner interface { + ScanBits(v Bits) error +} + +type BitsValuer interface { + BitsValue() (Bits, error) +} + +// Bits represents the PostgreSQL bit and varbit types. +type Bits struct { + Bytes []byte + Len int32 // Number of bits + Valid bool +} + +func (b *Bits) ScanBits(v Bits) error { + *b = v + return nil +} + +func (b Bits) BitsValue() (Bits, error) { + return b, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Bits) Scan(src interface{}) error { + if src == nil { + *dst = Bits{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToBitsScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bits) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := BitsCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type BitsCodec struct{} + +func (BitsCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (BitsCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (BitsCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(BitsValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanBitsCodecBinary{} + case TextFormatCode: + return encodePlanBitsCodecText{} + } + + return nil +} + +type encodePlanBitsCodecBinary struct{} + +func (encodePlanBitsCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + bits, err := value.(BitsValuer).BitsValue() + if err != nil { + return nil, err + } + + if !bits.Valid { + return nil, nil + } + + buf = pgio.AppendInt32(buf, bits.Len) + return append(buf, bits.Bytes...), nil +} + +type encodePlanBitsCodecText struct{} + +func (encodePlanBitsCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + bits, err := value.(BitsValuer).BitsValue() + if err != nil { + return nil, err + } + + if !bits.Valid { + return nil, nil + } + + for i := int32(0); i < bits.Len; i++ { + byteIdx := i / 8 + bitMask := byte(128 >> byte(i%8)) + char := byte('0') + if bits.Bytes[byteIdx]&bitMask > 0 { + char = '1' + } + buf = append(buf, char) + } + + return buf, nil +} + +func (BitsCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case BitsScanner: + return scanPlanBinaryBitsToBitsScanner{} + } + case TextFormatCode: + switch target.(type) { + case BitsScanner: + return scanPlanTextAnyToBitsScanner{} + } + } + + return nil +} + +func (c BitsCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c BitsCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var box Bits + err := codecScan(c, ci, oid, format, src, &box) + if err != nil { + return nil, err + } + return box, nil +} + +type scanPlanBinaryBitsToBitsScanner struct{} + +func (scanPlanBinaryBitsToBitsScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(BitsScanner) + + if src == nil { + return scanner.ScanBits(Bits{}) + } + + if len(src) < 4 { + return fmt.Errorf("invalid length for bit/varbit: %v", len(src)) + } + + bitLen := int32(binary.BigEndian.Uint32(src)) + rp := 4 + + return scanner.ScanBits(Bits{Bytes: src[rp:], Len: bitLen, Valid: true}) +} + +type scanPlanTextAnyToBitsScanner struct{} + +func (scanPlanTextAnyToBitsScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(BitsScanner) + + if src == nil { + return scanner.ScanBits(Bits{}) + } + + bitLen := len(src) + byteLen := bitLen / 8 + if bitLen%8 > 0 { + byteLen++ + } + buf := make([]byte, byteLen) + + for i, b := range src { + if b == '1' { + byteIdx := i / 8 + bitIdx := uint(i % 8) + buf[byteIdx] = buf[byteIdx] | (128 >> bitIdx) + } + } + + return scanner.ScanBits(Bits{Bytes: buf, Len: int32(bitLen), Valid: true}) +} diff --git a/pgtype/bits_test.go b/pgtype/bits_test.go new file mode 100644 index 00000000..a585ef8b --- /dev/null +++ b/pgtype/bits_test.go @@ -0,0 +1,65 @@ +package pgtype_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" +) + +func isExpectedEqBits(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + ab := a.(pgtype.Bits) + vb := v.(pgtype.Bits) + return bytes.Compare(ab.Bytes, vb.Bytes) == 0 && ab.Len == vb.Len && ab.Valid == vb.Valid + } +} + +func TestBitsCodecBit(t *testing.T) { + testPgxCodec(t, "bit(40)", []PgxTranscodeTestCase{ + { + pgtype.Bits{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}), + }, + { + pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}), + }, + {pgtype.Bits{}, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + {nil, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + }) +} + +func TestBitsCodecVarbit(t *testing.T) { + testPgxCodec(t, "varbit", []PgxTranscodeTestCase{ + { + pgtype.Bits{Bytes: []byte{}, Len: 0, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{}, Len: 0, Valid: true}), + }, + { + pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}), + }, + { + pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}), + }, + {pgtype.Bits{}, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + {nil, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + }) +} + +func TestBitsNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select B'111111111'", + Value: &pgtype.Bits{Bytes: []byte{255, 128}, Len: 9, Valid: true}, + }, + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d4001392..dc3fbedd 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -276,7 +276,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) ci.RegisterDataType(DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: VarcharOID}}) ci.RegisterDataType(DataType{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) - ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID}) + ci.RegisterDataType(DataType{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) ci.RegisterDataType(DataType{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) ci.RegisterDataType(DataType{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) @@ -321,7 +321,7 @@ func NewConnInfo() *ConnInfo { // ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) ci.RegisterDataType(DataType{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID}) - ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID}) + ci.RegisterDataType(DataType{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID}) diff --git a/pgtype/varbit.go b/pgtype/varbit.go deleted file mode 100644 index bc6fdac4..00000000 --- a/pgtype/varbit.go +++ /dev/null @@ -1,123 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - - "github.com/jackc/pgio" -) - -type Varbit struct { - Bytes []byte - Len int32 // Number of bits - Valid bool -} - -func (dst *Varbit) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Varbit", src) -} - -func (dst Varbit) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *Varbit) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Varbit) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Varbit{} - return nil - } - - bitLen := len(src) - byteLen := bitLen / 8 - if bitLen%8 > 0 { - byteLen++ - } - buf := make([]byte, byteLen) - - for i, b := range src { - if b == '1' { - byteIdx := i / 8 - bitIdx := uint(i % 8) - buf[byteIdx] = buf[byteIdx] | (128 >> bitIdx) - } - } - - *dst = Varbit{Bytes: buf, Len: int32(bitLen), Valid: true} - return nil -} - -func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Varbit{} - return nil - } - - if len(src) < 4 { - return fmt.Errorf("invalid length for varbit: %v", len(src)) - } - - bitLen := int32(binary.BigEndian.Uint32(src)) - rp := 4 - - *dst = Varbit{Bytes: src[rp:], Len: bitLen, Valid: true} - return nil -} - -func (src Varbit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - for i := int32(0); i < src.Len; i++ { - byteIdx := i / 8 - bitMask := byte(128 >> byte(i%8)) - char := byte('0') - if src.Bytes[byteIdx]&bitMask > 0 { - char = '1' - } - buf = append(buf, char) - } - - return buf, nil -} - -func (src Varbit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendInt32(buf, src.Len) - return append(buf, src.Bytes...), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Varbit) Scan(src interface{}) error { - if src == nil { - *dst = Varbit{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Varbit) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/varbit_test.go b/pgtype/varbit_test.go deleted file mode 100644 index 031d5fa8..00000000 --- a/pgtype/varbit_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestVarbitTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "varbit", []interface{}{ - &pgtype.Varbit{Bytes: []byte{}, Len: 0, Valid: true}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}, - &pgtype.Varbit{}, - }) -} - -func TestVarbitNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select B'111111111'", - Value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Valid: true}, - }, - }) -} From f5347987a6f433271ad54330ea9a27611a8357c2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 16:53:15 -0600 Subject: [PATCH 0820/1158] Add bit and varbit array support --- pgtype/pgtype.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index dc3fbedd..47e3518e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -70,7 +70,9 @@ const ( IntervalOID = 1186 NumericArrayOID = 1231 BitOID = 1560 + BitArrayOID = 1561 VarbitOID = 1562 + VarbitArrayOID = 1563 NumericOID = 1700 RecordOID = 2249 UUIDOID = 2950 @@ -275,6 +277,8 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID}) ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) ci.RegisterDataType(DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: VarcharOID}}) + ci.RegisterDataType(DataType{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementCodec: BitsCodec{}, ElementOID: BitOID}}) + ci.RegisterDataType(DataType{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementCodec: BitsCodec{}, ElementOID: VarbitOID}}) ci.RegisterDataType(DataType{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) ci.RegisterDataType(DataType{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) From c6f3e03a6103ca6c4a524ad33b26e69585f89db2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 17:01:32 -0600 Subject: [PATCH 0821/1158] BoolCodec EncodePlan actually plans --- pgtype/bool.go | 103 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 69 insertions(+), 34 deletions(-) diff --git a/pgtype/bool.go b/pgtype/bool.go index 21cd9889..3dd7efd3 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -8,7 +8,11 @@ import ( ) type BoolScanner interface { - ScanBool(v bool, valid bool) error + ScanBool(v Bool) error +} + +type BoolValuer interface { + BoolValue() (Bool, error) } type Bool struct { @@ -16,18 +20,15 @@ type Bool struct { Valid bool } -// ScanBool implements the BoolScanner interface. -func (dst *Bool) ScanBool(v bool, valid bool) error { - if !valid { - *dst = Bool{} - return nil - } - - *dst = Bool{Bool: v, Valid: true} - +func (b *Bool) ScanBool(v Bool) error { + *b = v return nil } +func (b Bool) BoolValue() (Bool, error) { + return b, nil +} + // Scan implements the database/sql Scanner interface. func (dst *Bool) Scan(src interface{}) error { if src == nil { @@ -108,27 +109,28 @@ func (BoolCodec) PreferredFormat() int16 { func (BoolCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return encodePlanBoolCodecBinary{} + switch value.(type) { + case bool: + return encodePlanBoolCodecBinaryBool{} + case BoolValuer: + return encodePlanBoolCodecBinaryBoolScanner{} + } case TextFormatCode: - return encodePlanBoolCodecText{} + switch value.(type) { + case bool: + return encodePlanBoolCodecTextBool{} + case BoolValuer: + return encodePlanBoolCodecTextBoolScanner{} + } } return nil } -type encodePlanBoolCodecBinary struct{} +type encodePlanBoolCodecBinaryBool struct{} -func (encodePlanBoolCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - v, valid, err := convertToBoolForEncode(value) - if err != nil { - return nil, fmt.Errorf("cannot convert %v to bool: %v", value, err) - } - if !valid { - return nil, nil - } - if value == nil { - return nil, nil - } +func (encodePlanBoolCodecBinaryBool) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v := value.(bool) if v { buf = append(buf, 1) @@ -139,20 +141,53 @@ func (encodePlanBoolCodecBinary) Encode(value interface{}, buf []byte) (newBuf [ return buf, nil } -type encodePlanBoolCodecText struct{} +type encodePlanBoolCodecTextBoolScanner struct{} -func (encodePlanBoolCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - v, valid, err := convertToBoolForEncode(value) +func (encodePlanBoolCodecTextBoolScanner) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + b, err := value.(BoolValuer).BoolValue() if err != nil { - return nil, fmt.Errorf("cannot convert %v to bool: %v", value, err) + return nil, err } - if !valid { + + if !b.Valid { return nil, nil } - if value == nil { + + if b.Bool { + buf = append(buf, 't') + } else { + buf = append(buf, 'f') + } + + return buf, nil +} + +type encodePlanBoolCodecBinaryBoolScanner struct{} + +func (encodePlanBoolCodecBinaryBoolScanner) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + b, err := value.(BoolValuer).BoolValue() + if err != nil { + return nil, err + } + + if !b.Valid { return nil, nil } + if b.Bool { + buf = append(buf, 1) + } else { + buf = append(buf, 0) + } + + return buf, nil +} + +type encodePlanBoolCodecTextBool struct{} + +func (encodePlanBoolCodecTextBool) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v := value.(bool) + if v { buf = append(buf, 't') } else { @@ -288,14 +323,14 @@ func (scanPlanBinaryBoolToBoolScanner) Scan(ci *ConnInfo, oid uint32, formatCode } if src == nil { - return s.ScanBool(false, false) + return s.ScanBool(Bool{}) } if len(src) != 1 { return fmt.Errorf("invalid length for bool: %v", len(src)) } - return s.ScanBool(src[0] == 1, true) + return s.ScanBool(Bool{Bool: src[0] == 1, Valid: true}) } type scanPlanTextAnyToBoolScanner struct{} @@ -307,12 +342,12 @@ func (scanPlanTextAnyToBoolScanner) Scan(ci *ConnInfo, oid uint32, formatCode in } if src == nil { - return s.ScanBool(false, false) + return s.ScanBool(Bool{}) } if len(src) != 1 { return fmt.Errorf("invalid length for bool: %v", len(src)) } - return s.ScanBool(src[0] == 't', true) + return s.ScanBool(Bool{Bool: src[0] == 't', Valid: true}) } From f573cde09c56ecd161b98b6f568539c4015c69d3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 18:33:08 -0600 Subject: [PATCH 0822/1158] Convert bytea to Codec --- pgtype/bytea.go | 296 +++++++++++++++-------- pgtype/bytea_array.go | 476 ------------------------------------- pgtype/bytea_array_test.go | 229 ------------------ pgtype/bytea_test.go | 131 +++++----- pgtype/generic_binary.go | 39 --- pgtype/pgtype.go | 8 +- rows.go | 27 +-- 7 files changed, 297 insertions(+), 909 deletions(-) delete mode 100644 pgtype/bytea_array.go delete mode 100644 pgtype/bytea_array_test.go delete mode 100644 pgtype/generic_binary.go diff --git a/pgtype/bytea.go b/pgtype/bytea.go index d4c4e436..2eb50610 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -6,141 +6,249 @@ import ( "fmt" ) -type Bytea struct { - Bytes []byte - Valid bool +type BytesScanner interface { + // ScanBytes receives a byte slice of driver memory that is only valid until the next database method call. + ScanBytes(v []byte) error } -func (dst *Bytea) Set(src interface{}) error { - if src == nil { - *dst = Bytea{} +type BytesValuer interface { + // BytesValue returns a byte slice of the byte data. The caller must not change the returned slice. + BytesValue() ([]byte, error) +} + +// DriverBytes is a byte slice that holds a reference to memory owned by the driver. It is only valid until the next +// database method call. e.g. Any call to a Rows or Conn method invalidates the slice. +type DriverBytes []byte + +func (b *DriverBytes) ScanBytes(v []byte) error { + *b = v + return nil +} + +// PreallocBytes is a byte slice of preallocated memory that scanned bytes will be copied to. If it is too small a new +// slice will be allocated. +type PreallocBytes []byte + +func (b *PreallocBytes) ScanBytes(v []byte) error { + if v == nil { + *b = nil return nil } - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } + if len(v) <= len(*b) { + *b = (*b)[:len(v)] + } else { + *b = make(PreallocBytes, len(v)) + } + copy(*b, v) + return nil +} + +// UndecodedBytes can be used as a scan target to get the raw bytes from PostgreSQL without any decoding. +type UndecodedBytes []byte + +type scanPlanAnyToUndecodedBytes struct{} + +func (scanPlanAnyToUndecodedBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + dstBuf := dst.(*UndecodedBytes) + if src == nil { + *dstBuf = nil + return nil } - switch value := src.(type) { - case []byte: - if value != nil { - *dst = Bytea{Bytes: value, Valid: true} - } else { - *dst = Bytea{} + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type ByteaCodec struct{} + +func (ByteaCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (ByteaCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (ByteaCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case []byte: + return encodePlanBytesCodecBinaryBytes{} + case BytesValuer: + return encodePlanBytesCodecBinaryBytesValuer{} } - default: - if originalSrc, ok := underlyingBytesType(src); ok { - return dst.Set(originalSrc) + case TextFormatCode: + switch value.(type) { + case []byte: + return encodePlanBytesCodecTextBytes{} + case BytesValuer: + return encodePlanBytesCodecTextBytesValuer{} } - return fmt.Errorf("cannot convert %v to Bytea", value) } return nil } -func (dst Bytea) Get() interface{} { - if !dst.Valid { - return nil +type encodePlanBytesCodecBinaryBytes struct{} + +func (encodePlanBytesCodecBinaryBytes) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + b := value.([]byte) + if b == nil { + return nil, nil } - return dst.Bytes + + return append(buf, b...), nil } -func (src *Bytea) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) +type encodePlanBytesCodecBinaryBytesValuer struct{} + +func (encodePlanBytesCodecBinaryBytesValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + b, err := value.(BytesValuer).BytesValue() + if err != nil { + return nil, err + } + if b == nil { + return nil, nil } - switch v := dst.(type) { - case *[]byte: - buf := make([]byte, len(src.Bytes)) - copy(buf, src.Bytes) - *v = buf - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) + return append(buf, b...), nil +} + +type encodePlanBytesCodecTextBytes struct{} + +func (encodePlanBytesCodecTextBytes) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + b := value.([]byte) + if b == nil { + return nil, nil + } + + buf = append(buf, `\x`...) + buf = append(buf, hex.EncodeToString(b)...) + return buf, nil +} + +type encodePlanBytesCodecTextBytesValuer struct{} + +func (encodePlanBytesCodecTextBytesValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + b, err := value.(BytesValuer).BytesValue() + if err != nil { + return nil, err + } + if b == nil { + return nil, nil + } + + buf = append(buf, `\x`...) + buf = append(buf, hex.EncodeToString(b)...) + return buf, nil +} + +func (ByteaCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *[]byte: + return scanPlanBinaryBytesToBytes{} + case BytesScanner: + return scanPlanBinaryBytesToBytesScanner{} + } + case TextFormatCode: + switch target.(type) { + case *[]byte: + return scanPlanTextByteaToBytes{} + case BytesScanner: + return scanPlanTextByteaToBytesScanner{} } - return fmt.Errorf("unable to assign to %T", dst) } + + return nil } -// DecodeText only supports the hex format. This has been the default since -// PostgreSQL 9.0. -func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { +type scanPlanBinaryBytesToBytes struct{} + +func (scanPlanBinaryBytesToBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + dstBuf := dst.(*[]byte) if src == nil { - *dst = Bytea{} + *dstBuf = nil return nil } + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanBinaryBytesToBytesScanner struct{} + +func (scanPlanBinaryBytesToBytesScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(BytesScanner) + return scanner.ScanBytes(src) +} + +type scanPlanTextByteaToBytes struct{} + +func (scanPlanTextByteaToBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil + return nil + } + + buf, err := decodeHexBytea(src) + if err != nil { + return err + } + *dstBuf = buf + + return nil +} + +type scanPlanTextByteaToBytesScanner struct{} + +func (scanPlanTextByteaToBytesScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(BytesScanner) + buf, err := decodeHexBytea(src) + if err != nil { + return err + } + return scanner.ScanBytes(buf) +} + +func decodeHexBytea(src []byte) ([]byte, error) { + if src == nil { + return nil, nil + } + if len(src) < 2 || src[0] != '\\' || src[1] != 'x' { - return fmt.Errorf("invalid hex format") + return nil, fmt.Errorf("invalid hex format") } buf := make([]byte, (len(src)-2)/2) _, err := hex.Decode(buf, src[2:]) if err != nil { - return err + return nil, err } - *dst = Bytea{Bytes: buf, Valid: true} - return nil -} - -func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Bytea{} - return nil - } - - *dst = Bytea{Bytes: src, Valid: true} - return nil -} - -func (src Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = append(buf, `\x`...) - buf = append(buf, hex.EncodeToString(src.Bytes)...) return buf, nil } -func (src Bytea) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, src.Bytes...), nil +func (c ByteaCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) } -// Scan implements the database/sql Scanner interface. -func (dst *Bytea) Scan(src interface{}) error { +func (c ByteaCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Bytea{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - buf := make([]byte, len(src)) - copy(buf, src) - *dst = Bytea{Bytes: buf, Valid: true} - return nil - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Bytea) Value() (driver.Value, error) { - if !src.Valid { return nil, nil } - return src.Bytes, nil + + var buf []byte + err := codecScan(c, ci, oid, format, src, &buf) + if err != nil { + return nil, err + } + return buf, nil } diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go deleted file mode 100644 index 7c539e21..00000000 --- a/pgtype/bytea_array.go +++ /dev/null @@ -1,476 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type ByteaArray struct { - Elements []Bytea - Dimensions []ArrayDimension - Valid bool -} - -func (dst *ByteaArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = ByteaArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case [][]byte: - if value == nil { - *dst = ByteaArray{} - } else if len(value) == 0 { - *dst = ByteaArray{Valid: true} - } else { - elements := make([]Bytea, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ByteaArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Bytea: - if value == nil { - *dst = ByteaArray{} - } else if len(value) == 0 { - *dst = ByteaArray{Valid: true} - } else { - *dst = ByteaArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = ByteaArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for ByteaArray", src) - } - if elementsLength == 0 { - *dst = ByteaArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to ByteaArray", src) - } - - *dst = ByteaArray{ - Elements: make([]Bytea, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Bytea, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to ByteaArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *ByteaArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to ByteaArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in ByteaArray", err) - } - index++ - - return index, nil -} - -func (dst ByteaArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *ByteaArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from ByteaArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from ByteaArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = ByteaArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Bytea - - if len(uta.Elements) > 0 { - elements = make([]Bytea, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Bytea - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = ByteaArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = ByteaArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = ByteaArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Bytea, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = ByteaArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("bytea"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "bytea") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *ByteaArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src ByteaArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/bytea_array_test.go b/pgtype/bytea_array_test.go deleted file mode 100644 index 08b69c26..00000000 --- a/pgtype/bytea_array_test.go +++ /dev/null @@ -1,229 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestByteaArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bytea[]", []interface{}{ - &pgtype.ByteaArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.ByteaArray{}, - &pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Valid: true}, - {Bytes: []byte{1, 2, 3}, Valid: true}, - {Bytes: []byte{}, Valid: true}, - {Bytes: []byte{1, 2, 3}, Valid: true}, - {}, - {Bytes: []byte{1}, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Valid: true}, - {Bytes: []byte{}, Valid: true}, - {Bytes: []byte{1, 2, 3}, Valid: true}, - {Bytes: []byte{1}, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestByteaArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.ByteaArray - }{ - { - source: [][]byte{{1, 2, 3}}, - result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([][]byte)(nil)), - result: pgtype.ByteaArray{}, - }, - { - source: [][][]byte{{{1}}, {{2}}}, - result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1}, Valid: true}, {Bytes: []byte{2}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, - result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Valid: true}, - {Bytes: []byte{4, 5, 6}, Valid: true}, - {Bytes: []byte{7, 8, 9}, Valid: true}, - {Bytes: []byte{10, 11, 12}, Valid: true}, - {Bytes: []byte{13, 14, 15}, Valid: true}, - {Bytes: []byte{16, 17, 18}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1][]byte{{{1}}, {{2}}}, - result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1}, Valid: true}, {Bytes: []byte{2}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, - result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Valid: true}, - {Bytes: []byte{4, 5, 6}, Valid: true}, - {Bytes: []byte{7, 8, 9}, Valid: true}, - {Bytes: []byte{10, 11, 12}, Valid: true}, - {Bytes: []byte{13, 14, 15}, Valid: true}, - {Bytes: []byte{16, 17, 18}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.ByteaArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestByteaArrayAssignTo(t *testing.T) { - var byteByteSlice [][]byte - var byteByteSliceDim2 [][][]byte - var byteByteSliceDim4 [][][][][]byte - var byteByteArraySliceDim2 [2][1][]byte - var byteByteArraySliceDim4 [2][1][1][3][]byte - - simpleTests := []struct { - src pgtype.ByteaArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &byteByteSlice, - expected: [][]byte{{1, 2, 3}}, - }, - { - src: pgtype.ByteaArray{}, - dst: &byteByteSlice, - expected: (([][]byte)(nil)), - }, - { - src: pgtype.ByteaArray{Valid: true}, - dst: &byteByteSlice, - expected: [][]byte{}, - }, - { - src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1}, Valid: true}, {Bytes: []byte{2}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &byteByteSliceDim2, - expected: [][][]byte{{{1}}, {{2}}}, - }, - { - src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Valid: true}, - {Bytes: []byte{4, 5, 6}, Valid: true}, - {Bytes: []byte{7, 8, 9}, Valid: true}, - {Bytes: []byte{10, 11, 12}, Valid: true}, - {Bytes: []byte{13, 14, 15}, Valid: true}, - {Bytes: []byte{16, 17, 18}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &byteByteSliceDim4, - expected: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, - }, - { - src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1}, Valid: true}, {Bytes: []byte{2}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &byteByteArraySliceDim2, - expected: [2][1][]byte{{{1}}, {{2}}}, - }, - { - src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Valid: true}, - {Bytes: []byte{4, 5, 6}, Valid: true}, - {Bytes: []byte{7, 8, 9}, Valid: true}, - {Bytes: []byte{10, 11, 12}, Valid: true}, - {Bytes: []byte{13, 14, 15}, Valid: true}, - {Bytes: []byte{16, 17, 18}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &byteByteArraySliceDim4, - expected: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index 21751e24..d99d28b6 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -1,73 +1,94 @@ package pgtype_test import ( - "reflect" + "bytes" + "context" "testing" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/stretchr/testify/require" ) -func TestByteaTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bytea", []interface{}{ - &pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, - &pgtype.Bytea{Bytes: []byte{}, Valid: true}, - &pgtype.Bytea{Bytes: nil}, +func isExpectedEqBytes(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + ab := a.([]byte) + vb := v.([]byte) + + if (ab == nil) != (vb == nil) { + return false + } + + if ab == nil { + return true + } + + return bytes.Compare(ab, vb) == 0 + } +} + +func TestByteaCodec(t *testing.T) { + testPgxCodec(t, "bytea", []PgxTranscodeTestCase{ + {[]byte{1, 2, 3}, new([]byte), isExpectedEqBytes([]byte{1, 2, 3})}, + {[]byte{}, new([]byte), isExpectedEqBytes([]byte{})}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, }) } -func TestByteaSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Bytea - }{ - {source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}}, - {source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Valid: true}}, - {source: []byte(nil), result: pgtype.Bytea{}}, - {source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}}, - {source: _byteSlice(nil), result: pgtype.Bytea{}}, - } +func TestDriverBytes(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) - for i, tt := range successfulTests { - var r pgtype.Bytea - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } + ctx := context.Background() - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestByteaAssignTo(t *testing.T) { var buf []byte - var _buf _byteSlice - var pbuf *[]byte - var _pbuf *_byteSlice + err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf)) + require.NoError(t, err) - simpleTests := []struct { - src pgtype.Bytea - dst interface{} - expected interface{} - }{ - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, dst: &buf, expected: []byte{1, 2, 3}}, - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, dst: &_buf, expected: _byteSlice{1, 2, 3}}, - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, dst: &pbuf, expected: &[]byte{1, 2, 3}}, - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}}, - {src: pgtype.Bytea{}, dst: &pbuf, expected: ((*[]byte)(nil))}, - {src: pgtype.Bytea{}, dst: &_pbuf, expected: ((*_byteSlice)(nil))}, - } + require.Len(t, buf, 2) + require.Equal(t, buf, []byte{1, 2}) + require.Equalf(t, cap(buf), len(buf), "cap(buf) is larger than len(buf)") - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } + // Don't actually have any way to be sure that the bytes are from the driver at the moment as underlying driver + // doesn't reuse buffers at the present. +} + +func TestPreallocBytes(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + ctx := context.Background() + + origBuf := []byte{5, 6, 7, 8} + buf := origBuf + err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.PreallocBytes)(&buf)) + require.NoError(t, err) + + require.Len(t, buf, 2) + require.Equal(t, 4, cap(buf)) + require.Equal(t, buf, []byte{1, 2}) + + require.Equal(t, []byte{1, 2, 7, 8}, origBuf) + + err = conn.QueryRow(ctx, `select $1::bytea`, []byte{3, 4, 5, 6, 7}).Scan((*pgtype.PreallocBytes)(&buf)) + require.NoError(t, err) + require.Len(t, buf, 5) + require.Equal(t, 5, cap(buf)) + + require.Equal(t, []byte{1, 2, 7, 8}, origBuf) +} + +func TestUndecodedBytes(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + ctx := context.Background() + + var buf []byte + err := conn.QueryRow(ctx, `select 1`).Scan((*pgtype.UndecodedBytes)(&buf)) + require.NoError(t, err) + + require.Len(t, buf, 4) + require.Equal(t, buf, []byte{0, 0, 0, 1}) } diff --git a/pgtype/generic_binary.go b/pgtype/generic_binary.go deleted file mode 100644 index 76a1d351..00000000 --- a/pgtype/generic_binary.go +++ /dev/null @@ -1,39 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// GenericBinary is a placeholder for binary format values that no other type exists -// to handle. -type GenericBinary Bytea - -func (dst *GenericBinary) Set(src interface{}) error { - return (*Bytea)(dst).Set(src) -} - -func (dst GenericBinary) Get() interface{} { - return (Bytea)(dst).Get() -} - -func (src *GenericBinary) AssignTo(dst interface{}) error { - return (*Bytea)(src).AssignTo(dst) -} - -func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Bytea)(dst).DecodeBinary(ci, src) -} - -func (src GenericBinary) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Bytea)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *GenericBinary) Scan(src interface{}) error { - return (*Bytea)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src GenericBinary) Value() (driver.Value, error) { - return (Bytea)(src).Value() -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 47e3518e..7fff7dd5 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -258,7 +258,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementCodec: &TextFormatOnlyCodec{TextCodec{}}, ElementOID: ACLItemOID}}) ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementCodec: BoolCodec{}, ElementOID: BoolOID}}) ci.RegisterDataType(DataType{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: BPCharOID}}) - ci.RegisterDataType(DataType{Value: &ByteaArray{}, Name: "_bytea", OID: ByteaArrayOID}) + ci.RegisterDataType(DataType{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementCodec: ByteaCodec{}, ElementOID: ByteaOID}}) ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID}) ci.RegisterDataType(DataType{Value: &DateArray{}, Name: "_date", OID: DateArrayOID}) ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) @@ -284,7 +284,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) ci.RegisterDataType(DataType{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) ci.RegisterDataType(DataType{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) - ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID}) + ci.RegisterDataType(DataType{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID}) ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) @@ -810,6 +810,10 @@ func (plan *pointerEmptyInterfaceScanPlan) Scan(ci *ConnInfo, oid uint32, format // PlanScan prepares a plan to scan a value into dst. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { + if _, ok := dst.(*UndecodedBytes); ok { + return scanPlanAnyToUndecodedBytes{} + } + switch formatCode { case BinaryFormatCode: switch dst.(type) { diff --git a/rows.go b/rows.go index 62a19016..8e9fdc70 100644 --- a/rows.go +++ b/rows.go @@ -262,15 +262,17 @@ func (rows *connRows) Values() ([]interface{}, error) { values = append(values, string(buf)) } case BinaryFormatCode: - decoder, ok := value.(pgtype.BinaryDecoder) - if !ok { - decoder = &pgtype.GenericBinary{} + if decoder, ok := value.(pgtype.BinaryDecoder); ok { + err := decoder.DecodeBinary(rows.connInfo, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, value.Get()) + } else { + newBuf := make([]byte, len(buf)) + copy(newBuf, buf) + values = append(values, newBuf) } - err := decoder.DecodeBinary(rows.connInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, value.Get()) default: rows.fatal(errors.New("Unknown format code")) } @@ -286,12 +288,9 @@ func (rows *connRows) Values() ([]interface{}, error) { case TextFormatCode: values = append(values, string(buf)) case BinaryFormatCode: - decoder := &pgtype.GenericBinary{} - err := decoder.DecodeBinary(rows.connInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, decoder.Get()) + newBuf := make([]byte, len(buf)) + copy(newBuf, buf) + values = append(values, newBuf) default: rows.fatal(errors.New("Unknown format code")) } From 6cb3439492c1c0d3ec4106d67bdd602536a01655 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 18:35:54 -0600 Subject: [PATCH 0823/1158] Fix encode plan names --- pgtype/bool.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pgtype/bool.go b/pgtype/bool.go index 3dd7efd3..d81c2417 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -113,14 +113,14 @@ func (BoolCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interf case bool: return encodePlanBoolCodecBinaryBool{} case BoolValuer: - return encodePlanBoolCodecBinaryBoolScanner{} + return encodePlanBoolCodecBinaryBoolValuer{} } case TextFormatCode: switch value.(type) { case bool: return encodePlanBoolCodecTextBool{} case BoolValuer: - return encodePlanBoolCodecTextBoolScanner{} + return encodePlanBoolCodecTextBoolValuer{} } } @@ -141,9 +141,9 @@ func (encodePlanBoolCodecBinaryBool) Encode(value interface{}, buf []byte) (newB return buf, nil } -type encodePlanBoolCodecTextBoolScanner struct{} +type encodePlanBoolCodecTextBoolValuer struct{} -func (encodePlanBoolCodecTextBoolScanner) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBoolCodecTextBoolValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { b, err := value.(BoolValuer).BoolValue() if err != nil { return nil, err @@ -162,9 +162,9 @@ func (encodePlanBoolCodecTextBoolScanner) Encode(value interface{}, buf []byte) return buf, nil } -type encodePlanBoolCodecBinaryBoolScanner struct{} +type encodePlanBoolCodecBinaryBoolValuer struct{} -func (encodePlanBoolCodecBinaryBoolScanner) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBoolCodecBinaryBoolValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { b, err := value.(BoolValuer).BoolValue() if err != nil { return nil, err From 6be0c3f6b224b099aa6f55cb0284c6f6ac753f46 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 20:51:28 -0600 Subject: [PATCH 0824/1158] Remove convertToBoolForEncode --- pgtype/bool.go | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/pgtype/bool.go b/pgtype/bool.go index d81c2417..71ce09b6 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -236,42 +236,6 @@ func (c BoolCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt return b, nil } -func convertToBoolForEncode(v interface{}) (b bool, valid bool, err error) { - if v == nil { - return false, false, nil - } - - switch v := v.(type) { - case bool: - return v, true, nil - case *bool: - if v == nil { - return false, false, nil - } - return *v, true, nil - case string: - bb, err := strconv.ParseBool(v) - if err != nil { - return false, false, err - } - return bb, true, nil - case *string: - if v == nil { - return false, false, nil - } - bb, err := strconv.ParseBool(*v) - if err != nil { - return false, false, err - } - return bb, true, nil - default: - if originalvalue, ok := underlyingBoolType(v); ok { - return convertToBoolForEncode(originalvalue) - } - return false, false, fmt.Errorf("cannot convert %v to bool", v) - } -} - type scanPlanBinaryBoolToBool struct{} func (scanPlanBinaryBoolToBool) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { From dc05bd9feef192632a6cc1b7be7b393a611ae384 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 20:51:44 -0600 Subject: [PATCH 0825/1158] Remove old code gen --- pgtype/new_pg_value.erb | 37 ------------------------------- pgtype/new_pg_value_gen.sh | 45 -------------------------------------- 2 files changed, 82 deletions(-) delete mode 100644 pgtype/new_pg_value.erb delete mode 100644 pgtype/new_pg_value_gen.sh diff --git a/pgtype/new_pg_value.erb b/pgtype/new_pg_value.erb deleted file mode 100644 index 71a0da7f..00000000 --- a/pgtype/new_pg_value.erb +++ /dev/null @@ -1,37 +0,0 @@ -package pgtype - -<% skip_binary ||= false %> -<% skip_text ||= false %> -<% prefer_text_format ||= false %> - -func (<%= go_type %>) BinaryFormatSupported() bool { - return true -} - -func (<%= go_type %>) TextFormatSupported() bool { - return true -} - -func (<%= go_type %>) PreferredFormat() int16 { - return <%= prefer_text_format ? "Text" : "Binary" %>FormatCode -} - -func (dst *<%= go_type %>) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - <% if skip_binary %> return fmt.Errorf("binary format not supported for %T", dst) <% else %> return dst.DecodeBinary(ci, src) <% end %> - case TextFormatCode: - <% if skip_text %> return fmt.Errorf("text format not supported for %T", dst) <% else %> return dst.DecodeText(ci, src) <% end %> - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src <%= go_type %>) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - <% if skip_binary %>return nil, fmt.Errorf("binary format not supported for %T", src)<% else %>return src.EncodeBinary(ci, buf)<% end %> - case TextFormatCode: - <% if skip_text %>return nil, fmt.Errorf("text format not supported for %T", src)<% else %>return src.EncodeText(ci, buf)<% end %> - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/pgtype/new_pg_value_gen.sh b/pgtype/new_pg_value_gen.sh deleted file mode 100644 index 3dad08de..00000000 --- a/pgtype/new_pg_value_gen.sh +++ /dev/null @@ -1,45 +0,0 @@ -erb go_type=ACLItem skip_binary=true prefer_text_format=true new_pg_value.erb > zzz.aclitem.go -erb go_type=Bit new_pg_value.erb > zzz.bit.go -erb go_type=Bool new_pg_value.erb > zzz.bool.go -erb go_type=Box new_pg_value.erb > zzz.box.go -erb go_type=BPChar prefer_text_format=true new_pg_value.erb > zzz.bpchar.go -erb go_type=Bytea new_pg_value.erb > zzz.bytea.go -erb go_type=CID new_pg_value.erb > zzz.cid.go -erb go_type=CIDR new_pg_value.erb > zzz.cidr.go -erb go_type=Circle new_pg_value.erb > zzz.circle.go -erb go_type=Date new_pg_value.erb > zzz.date.go -erb go_type=Float4 new_pg_value.erb > zzz.float4.go -erb go_type=Float8 new_pg_value.erb > zzz.float8.go -erb go_type=GenericBinary skip_text=true new_pg_value.erb > zzz.generic_binary.go -erb go_type=GenericText skip_binary=true prefer_text_format=true new_pg_value.erb > zzz.generic_text.go -erb go_type=Hstore new_pg_value.erb > zzz.hstore.go -erb go_type=Inet new_pg_value.erb > zzz.inet.go -erb go_type=Int2 new_pg_value.erb > zzz.int2.go -erb go_type=Int4 new_pg_value.erb > zzz.int4.go -erb go_type=Int8 new_pg_value.erb > zzz.int8.go -erb go_type=Interval new_pg_value.erb > zzz.interval.go -erb go_type=JSON prefer_text_format=true new_pg_value.erb > zzz.json.go -erb go_type=JSONB prefer_text_format=true new_pg_value.erb > zzz.jsonb.go -erb go_type=Line new_pg_value.erb > zzz.line.go -erb go_type=Lseg new_pg_value.erb > zzz.lseg.go -erb go_type=Macaddr new_pg_value.erb > zzz.macadder.go -erb go_type=Name new_pg_value.erb > zzz.name.go -erb go_type=Numeric new_pg_value.erb > zzz.numeric.go -erb go_type=OIDValue new_pg_value.erb > zzz.oid_value.go -erb go_type=OID new_pg_value.erb > zzz.oid.go -erb go_type=Path new_pg_value.erb > zzz.path.go -erb go_type=pguint32 new_pg_value.erb > zzz.pguint32.go -erb go_type=Point new_pg_value.erb > zzz.point.go -erb go_type=Polygon new_pg_value.erb > zzz.polygon.go -erb go_type=QChar skip_text=true new_pg_value.erb > zzz.qchar.go -erb go_type=Text prefer_text_format=true new_pg_value.erb > zzz.text.go -erb go_type=TID new_pg_value.erb > zzz.tid.go -erb go_type=Time new_pg_value.erb > zzz.time.go -erb go_type=Timestamp new_pg_value.erb > zzz.timestamp.go -erb go_type=Timestamptz new_pg_value.erb > zzz.timestamptz.go -# erb go_type=Unknown new_pg_value.erb > zzz.unknown.go -erb go_type=UUID new_pg_value.erb > zzz.uuid.go -erb go_type=Varbit new_pg_value.erb > zzz.varbit.go -erb go_type=Varchar prefer_text_format=true new_pg_value.erb > zzz.varchar.go -erb go_type=XID new_pg_value.erb > zzz.xid.go -goimports -w zzz.* From 8aaf235595222e7aa973a1b166e2b92df409fef4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 21:41:08 -0600 Subject: [PATCH 0826/1158] Standardize scanner and valuer for int types --- pgtype/int.go | 76 ++++++++++++++++++++++++++++------------------- pgtype/int.go.erb | 32 ++++++++++++-------- 2 files changed, 66 insertions(+), 42 deletions(-) diff --git a/pgtype/int.go b/pgtype/int.go index 5fee64a6..54898420 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -13,7 +13,11 @@ import ( ) type Int64Scanner interface { - ScanInt64(v int64, valid bool) error + ScanInt64(Int8) error +} + +type Int64Valuer interface { + Int64Value() (Int8, error) } type Int2 struct { @@ -22,23 +26,27 @@ type Int2 struct { } // ScanInt64 implements the Int64Scanner interface. -func (dst *Int2) ScanInt64(n int64, valid bool) error { - if !valid { +func (dst *Int2) ScanInt64(n Int8) error { + if !n.Valid { *dst = Int2{} return nil } - if n < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) + if n.Int < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n.Int) } - if n > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) + if n.Int > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n.Int) } - *dst = Int2{Int: int16(n), Valid: true} + *dst = Int2{Int: int16(n.Int), Valid: true} return nil } +func (n Int2) Int64Value() (Int8, error) { + return Int8{Int: int64(n.Int), Valid: n.Valid}, nil +} + // Scan implements the database/sql Scanner interface. func (dst *Int2) Scan(src interface{}) error { if src == nil { @@ -511,7 +519,7 @@ func (scanPlanBinaryInt2ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod } if src == nil { - return s.ScanInt64(0, false) + return s.ScanInt64(Int8{}) } if len(src) != 2 { @@ -520,7 +528,7 @@ func (scanPlanBinaryInt2ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod n := int64(binary.BigEndian.Uint16(src)) - return s.ScanInt64(n, true) + return s.ScanInt64(Int8{Int: n, Valid: true}) } type Int4 struct { @@ -529,23 +537,27 @@ type Int4 struct { } // ScanInt64 implements the Int64Scanner interface. -func (dst *Int4) ScanInt64(n int64, valid bool) error { - if !valid { +func (dst *Int4) ScanInt64(n Int8) error { + if !n.Valid { *dst = Int4{} return nil } - if n < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", n) + if n.Int < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n.Int) } - if n > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", n) + if n.Int > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n.Int) } - *dst = Int4{Int: int32(n), Valid: true} + *dst = Int4{Int: int32(n.Int), Valid: true} return nil } +func (n Int4) Int64Value() (Int8, error) { + return Int8{Int: int64(n.Int), Valid: n.Valid}, nil +} + // Scan implements the database/sql Scanner interface. func (dst *Int4) Scan(src interface{}) error { if src == nil { @@ -1029,7 +1041,7 @@ func (scanPlanBinaryInt4ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod } if src == nil { - return s.ScanInt64(0, false) + return s.ScanInt64(Int8{}) } if len(src) != 4 { @@ -1038,7 +1050,7 @@ func (scanPlanBinaryInt4ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod n := int64(binary.BigEndian.Uint32(src)) - return s.ScanInt64(n, true) + return s.ScanInt64(Int8{Int: n, Valid: true}) } type Int8 struct { @@ -1047,23 +1059,27 @@ type Int8 struct { } // ScanInt64 implements the Int64Scanner interface. -func (dst *Int8) ScanInt64(n int64, valid bool) error { - if !valid { +func (dst *Int8) ScanInt64(n Int8) error { + if !n.Valid { *dst = Int8{} return nil } - if n < math.MinInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", n) + if n.Int < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n.Int) } - if n > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", n) + if n.Int > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n.Int) } - *dst = Int8{Int: int64(n), Valid: true} + *dst = Int8{Int: int64(n.Int), Valid: true} return nil } +func (n Int8) Int64Value() (Int8, error) { + return Int8{Int: int64(n.Int), Valid: n.Valid}, nil +} + // Scan implements the database/sql Scanner interface. func (dst *Int8) Scan(src interface{}) error { if src == nil { @@ -1569,7 +1585,7 @@ func (scanPlanBinaryInt8ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod } if src == nil { - return s.ScanInt64(0, false) + return s.ScanInt64(Int8{}) } if len(src) != 8 { @@ -1578,7 +1594,7 @@ func (scanPlanBinaryInt8ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod n := int64(binary.BigEndian.Uint64(src)) - return s.ScanInt64(n, true) + return s.ScanInt64(Int8{Int: n, Valid: true}) } type scanPlanTextAnyToInt8 struct{} @@ -1800,7 +1816,7 @@ func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode i } if src == nil { - return s.ScanInt64(0, false) + return s.ScanInt64(Int8{}) } n, err := strconv.ParseInt(string(src), 10, 64) @@ -1808,7 +1824,7 @@ func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode i return err } - err = s.ScanInt64(n, true) + err = s.ScanInt64(Int8{Int: n, Valid: true}) if err != nil { return err } diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 419dddd2..0d88dd42 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -11,7 +11,11 @@ import ( ) type Int64Scanner interface { - ScanInt64(v int64, valid bool) error + ScanInt64(Int8) error +} + +type Int64Valuer interface { + Int64Value() (Int8, error) } @@ -23,23 +27,27 @@ type Int<%= pg_byte_size %> struct { } // ScanInt64 implements the Int64Scanner interface. -func (dst *Int<%= pg_byte_size %>) ScanInt64(n int64, valid bool) error { - if !valid { +func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error { + if !n.Valid { *dst = Int<%= pg_byte_size %>{} return nil } - if n < math.MinInt<%= pg_bit_size %> { - return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + if n.Int < math.MinInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int) } - if n > math.MaxInt<%= pg_bit_size %> { - return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + if n.Int > math.MaxInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int) } - *dst = Int<%= pg_byte_size %>{Int: int<%= pg_bit_size %>(n), Valid: true} + *dst = Int<%= pg_byte_size %>{Int: int<%= pg_bit_size %>(n.Int), Valid: true} return nil } +func (n Int<%= pg_byte_size %>) Int64Value() (Int8, error) { + return Int8{Int: int64(n.Int), Valid: n.Valid}, nil +} + // Scan implements the database/sql Scanner interface. func (dst *Int<%= pg_byte_size %>) Scan(src interface{}) error { if src == nil { @@ -397,7 +405,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(ci *ConnInfo, oid } if src == nil { - return s.ScanInt64(0, false) + return s.ScanInt64(Int8{}) } if len(src) != <%= pg_byte_size %> { @@ -407,7 +415,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(ci *ConnInfo, oid n := int64(binary.BigEndian.Uint<%= pg_bit_size %>(src)) - return s.ScanInt64(n, true) + return s.ScanInt64(Int8{Int: n, Valid: true}) } <% end %> @@ -471,7 +479,7 @@ func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode i } if src == nil { - return s.ScanInt64(0, false) + return s.ScanInt64(Int8{}) } n, err := strconv.ParseInt(string(src), 10, 64) @@ -479,7 +487,7 @@ func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode i return err } - err = s.ScanInt64(n, true) + err = s.ScanInt64(Int8{Int: n, Valid: true}) if err != nil { return err } From ad79dccd99a2506e049372314bc5b47f4a2bbc6c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 23:44:53 -0600 Subject: [PATCH 0827/1158] Builtin types are automatically wrapped if necessary --- pgtype/array_codec.go | 52 +++++-- pgtype/builtin_wrappers.go | 291 +++++++++++++++++++++++++++++++++++++ pgtype/convert.go | 136 ----------------- pgtype/int.go | 252 +++++++++++++++++++++----------- pgtype/int.go.erb | 84 +++++++---- pgtype/pgtype.go | 176 ++++++++++++++++++++++ 6 files changed, 725 insertions(+), 266 deletions(-) create mode 100644 pgtype/builtin_wrappers.go diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index e8c2b2ed..1e506a43 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "encoding/binary" "fmt" + "reflect" "github.com/jackc/pgio" ) @@ -88,6 +89,8 @@ func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf dimElemCounts[i] = int(dimensions[i].Length) * dimElemCounts[i+1] } + var encodePlan EncodePlan + var lastElemType reflect.Type inElemBuf := make([]byte, 0, 32) for i := 0; i < elementCount; i++ { if i > 0 { @@ -100,14 +103,23 @@ func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf } } - encodePlan := p.ac.ElementCodec.PlanEncode(p.ci, p.ac.ElementOID, TextFormatCode, array.Index(i)) - if encodePlan == nil { - return nil, fmt.Errorf("unable to encode %v", array.Index(i)) - } - elemBuf, err := encodePlan.Encode(array.Index(i), inElemBuf) - if err != nil { - return nil, err + elem := array.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.ci.PlanEncode(p.ac.ElementOID, TextFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", array.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, inElemBuf) + if err != nil { + return nil, err + } } + if elemBuf == nil { buf = append(buf, `NULL`...) } else { @@ -151,18 +163,30 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB buf = arrayHeader.EncodeBinary(p.ci, buf) elementCount := cardinality(dimensions) + + var encodePlan EncodePlan + var lastElemType reflect.Type for i := 0; i < elementCount; i++ { sp := len(buf) buf = pgio.AppendInt32(buf, -1) - encodePlan := p.ac.ElementCodec.PlanEncode(p.ci, p.ac.ElementOID, BinaryFormatCode, array.Index(i)) - if encodePlan == nil { - return nil, fmt.Errorf("unable to encode %v", array.Index(i)) - } - elemBuf, err := encodePlan.Encode(array.Index(i), buf) - if err != nil { - return nil, err + elem := array.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.ci.PlanEncode(p.ac.ElementOID, BinaryFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", array.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, buf) + if err != nil { + return nil, err + } } + if elemBuf == nil { pgio.SetInt32(buf[containsNullIndex:], 1) } else { diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go new file mode 100644 index 00000000..17fe4535 --- /dev/null +++ b/pgtype/builtin_wrappers.go @@ -0,0 +1,291 @@ +package pgtype + +import ( + "fmt" + "math" + "strconv" +) + +type int8Wrapper int8 + +func (n *int8Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int8") + } + + if v.Int < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", v.Int) + } + if v.Int > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", v.Int) + } + *n = int8Wrapper(v.Int) + + return nil +} + +func (n int8Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type int16Wrapper int16 + +func (n *int16Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int16") + } + + if v.Int < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", v.Int) + } + if v.Int > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", v.Int) + } + *n = int16Wrapper(v.Int) + + return nil +} + +func (n int16Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type int32Wrapper int32 + +func (n *int32Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int32") + } + + if v.Int < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", v.Int) + } + if v.Int > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", v.Int) + } + *n = int32Wrapper(v.Int) + + return nil +} + +func (n int32Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type int64Wrapper int64 + +func (n *int64Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int64") + } + + *n = int64Wrapper(v.Int) + + return nil +} + +func (n int64Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type intWrapper int + +func (n *intWrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int") + } + + if v.Int < math.MinInt { + return fmt.Errorf("%d is less than minimum value for int", v.Int) + } + if v.Int > math.MaxInt { + return fmt.Errorf("%d is greater than maximum value for int", v.Int) + } + + *n = intWrapper(v.Int) + + return nil +} + +func (n intWrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type uint8Wrapper uint8 + +func (n *uint8Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint8") + } + + if v.Int < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", v.Int) + } + if v.Int > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", v.Int) + } + *n = uint8Wrapper(v.Int) + + return nil +} + +func (n uint8Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type uint16Wrapper uint16 + +func (n *uint16Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint16") + } + + if v.Int < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", v.Int) + } + if v.Int > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", v.Int) + } + *n = uint16Wrapper(v.Int) + + return nil +} + +func (n uint16Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type uint32Wrapper uint32 + +func (n *uint32Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint32") + } + + if v.Int < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", v.Int) + } + if v.Int > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", v.Int) + } + *n = uint32Wrapper(v.Int) + + return nil +} + +func (n uint32Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type uint64Wrapper uint64 + +func (n *uint64Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint64") + } + + if v.Int < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", v.Int) + } + + *n = uint64Wrapper(v.Int) + + return nil +} + +func (n uint64Wrapper) Int64Value() (Int8, error) { + if uint64(n) > uint64(math.MaxInt64) { + return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", n) + } + + return Int8{Int: int64(n), Valid: true}, nil +} + +type uintWrapper uint + +func (n *uintWrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint64") + } + + if v.Int < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", v.Int) + } + + if uint64(v.Int) > math.MaxUint { + return fmt.Errorf("%d is greater than maximum value for uint", v.Int) + } + + *n = uintWrapper(v.Int) + + return nil +} + +func (n uintWrapper) Int64Value() (Int8, error) { + if uint64(n) > uint64(math.MaxInt64) { + return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", n) + } + + return Int8{Int: int64(n), Valid: true}, nil +} + +type float32Wrapper float32 + +func (n *float32Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float32") + } + + *n = float32Wrapper(v.Int) + + return nil +} + +func (n float32Wrapper) Int64Value() (Int8, error) { + if n > math.MaxInt64 { + return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", n) + } + + return Int8{Int: int64(n), Valid: true}, nil +} + +type float64Wrapper float64 + +func (n *float64Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float64") + } + + *n = float64Wrapper(v.Int) + + return nil +} + +func (n float64Wrapper) Int64Value() (Int8, error) { + if n > math.MaxInt64 { + return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", n) + } + + return Int8{Int: int64(n), Valid: true}, nil +} + +type stringWrapper string + +func (s *stringWrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *string") + } + + *s = stringWrapper(strconv.FormatInt(v.Int, 10)) + + return nil +} + +func (s stringWrapper) Int64Value() (Int8, error) { + num, err := strconv.ParseInt(string(s), 10, 64) + if err != nil { + return Int8{}, err + } + + return Int8{Int: int64(num), Valid: true}, nil +} diff --git a/pgtype/convert.go b/pgtype/convert.go index ee5ba393..21e208f5 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -5,7 +5,6 @@ import ( "fmt" "math" "reflect" - "strconv" "time" ) @@ -453,141 +452,6 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { return nil, false } -func convertToInt64ForEncode(v interface{}) (n int64, valid bool, err error) { - if v == nil { - return 0, false, nil - } - - switch v := v.(type) { - case int8: - return int64(v), true, nil - case uint8: - return int64(v), true, nil - case int16: - return int64(v), true, nil - case uint16: - return int64(v), true, nil - case int32: - return int64(v), true, nil - case uint32: - return int64(v), true, nil - case int64: - return int64(v), true, nil - case uint64: - if v > math.MaxInt64 { - return 0, false, fmt.Errorf("%d is greater than maximum value for int64", v) - } - return int64(v), true, nil - case int: - return int64(v), true, nil - case uint: - if v > math.MaxInt64 { - return 0, false, fmt.Errorf("%d is greater than maximum value for int64", v) - } - return int64(v), true, nil - case string: - num, err := strconv.ParseInt(v, 10, 64) - if err != nil { - return 0, false, err - } - return num, true, nil - case float32: - if v > math.MaxInt64 { - return 0, false, fmt.Errorf("%f is greater than maximum value for int64", v) - } - return int64(v), true, nil - case float64: - if v > math.MaxInt64 { - return 0, false, fmt.Errorf("%f is greater than maximum value for int64", v) - } - return int64(v), true, nil - case *int8: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *uint8: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *int16: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *uint16: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *int32: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *uint32: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *int64: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *uint64: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *int: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *uint: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *string: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *float32: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *float64: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - - default: - if originalvalue, ok := underlyingNumberType(v); ok { - return convertToInt64ForEncode(originalvalue) - } - return 0, false, fmt.Errorf("cannot convert %v to int64", v) - } -} - func init() { kindTypes = map[reflect.Kind]reflect.Type{ reflect.Bool: reflect.TypeOf(false), diff --git a/pgtype/int.go b/pgtype/int.go index 54898420..553d4dd0 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -130,54 +130,80 @@ func (Int2Codec) PreferredFormat() int16 { func (Int2Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return encodePlanInt2CodecBinary{} + switch value.(type) { + case int16: + return encodePlanInt2CodecBinaryInt16{} + case Int64Valuer: + return encodePlanInt2CodecBinaryInt64Valuer{} + } case TextFormatCode: - return encodePlanInt2CodecText{} + switch value.(type) { + case int16: + return encodePlanInt2CodecTextInt16{} + case Int64Valuer: + return encodePlanInt2CodecTextInt64Valuer{} + } } return nil } -type encodePlanInt2CodecBinary struct{} - -func (encodePlanInt2CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) - if err != nil { - return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) - } - if !valid { - return nil, nil - } - - if n > math.MaxInt16 { - return nil, fmt.Errorf("%d is greater than maximum value for int2", n) - } - if n < math.MinInt16 { - return nil, fmt.Errorf("%d is less than minimum value for int2", n) - } +type encodePlanInt2CodecBinaryInt16 struct{} +func (encodePlanInt2CodecBinaryInt16) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int16) return pgio.AppendInt16(buf, int16(n)), nil } -type encodePlanInt2CodecText struct{} +type encodePlanInt2CodecTextInt16 struct{} -func (encodePlanInt2CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) +func (encodePlanInt2CodecTextInt16) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int16) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt2CodecBinaryInt64Valuer struct{} + +func (encodePlanInt2CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() if err != nil { - return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) + return nil, err } - if !valid { + + if !n.Valid { return nil, nil } - if n > math.MaxInt16 { - return nil, fmt.Errorf("%d is greater than maximum value for int2", n) + if n.Int > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n.Int) } - if n < math.MinInt16 { - return nil, fmt.Errorf("%d is less than minimum value for int2", n) + if n.Int < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n.Int) } - return append(buf, strconv.FormatInt(n, 10)...), nil + return pgio.AppendInt16(buf, int16(n.Int)), nil +} + +type encodePlanInt2CodecTextInt64Valuer struct{} + +func (encodePlanInt2CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n.Int) + } + if n.Int < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n.Int) + } + + return append(buf, strconv.FormatInt(n.Int, 10)...), nil } func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { @@ -641,54 +667,80 @@ func (Int4Codec) PreferredFormat() int16 { func (Int4Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return encodePlanInt4CodecBinary{} + switch value.(type) { + case int32: + return encodePlanInt4CodecBinaryInt32{} + case Int64Valuer: + return encodePlanInt4CodecBinaryInt64Valuer{} + } case TextFormatCode: - return encodePlanInt4CodecText{} + switch value.(type) { + case int32: + return encodePlanInt4CodecTextInt32{} + case Int64Valuer: + return encodePlanInt4CodecTextInt64Valuer{} + } } return nil } -type encodePlanInt4CodecBinary struct{} - -func (encodePlanInt4CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) - if err != nil { - return nil, fmt.Errorf("cannot convert %v to int4: %v", value, err) - } - if !valid { - return nil, nil - } - - if n > math.MaxInt32 { - return nil, fmt.Errorf("%d is greater than maximum value for int4", n) - } - if n < math.MinInt32 { - return nil, fmt.Errorf("%d is less than minimum value for int4", n) - } +type encodePlanInt4CodecBinaryInt32 struct{} +func (encodePlanInt4CodecBinaryInt32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int32) return pgio.AppendInt32(buf, int32(n)), nil } -type encodePlanInt4CodecText struct{} +type encodePlanInt4CodecTextInt32 struct{} -func (encodePlanInt4CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) +func (encodePlanInt4CodecTextInt32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int32) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt4CodecBinaryInt64Valuer struct{} + +func (encodePlanInt4CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() if err != nil { - return nil, fmt.Errorf("cannot convert %v to int4: %v", value, err) + return nil, err } - if !valid { + + if !n.Valid { return nil, nil } - if n > math.MaxInt32 { - return nil, fmt.Errorf("%d is greater than maximum value for int4", n) + if n.Int > math.MaxInt32 { + return nil, fmt.Errorf("%d is greater than maximum value for int4", n.Int) } - if n < math.MinInt32 { - return nil, fmt.Errorf("%d is less than minimum value for int4", n) + if n.Int < math.MinInt32 { + return nil, fmt.Errorf("%d is less than minimum value for int4", n.Int) } - return append(buf, strconv.FormatInt(n, 10)...), nil + return pgio.AppendInt32(buf, int32(n.Int)), nil +} + +type encodePlanInt4CodecTextInt64Valuer struct{} + +func (encodePlanInt4CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int > math.MaxInt32 { + return nil, fmt.Errorf("%d is greater than maximum value for int4", n.Int) + } + if n.Int < math.MinInt32 { + return nil, fmt.Errorf("%d is less than minimum value for int4", n.Int) + } + + return append(buf, strconv.FormatInt(n.Int, 10)...), nil } func (Int4Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { @@ -1163,54 +1215,80 @@ func (Int8Codec) PreferredFormat() int16 { func (Int8Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return encodePlanInt8CodecBinary{} + switch value.(type) { + case int64: + return encodePlanInt8CodecBinaryInt64{} + case Int64Valuer: + return encodePlanInt8CodecBinaryInt64Valuer{} + } case TextFormatCode: - return encodePlanInt8CodecText{} + switch value.(type) { + case int64: + return encodePlanInt8CodecTextInt64{} + case Int64Valuer: + return encodePlanInt8CodecTextInt64Valuer{} + } } return nil } -type encodePlanInt8CodecBinary struct{} - -func (encodePlanInt8CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) - if err != nil { - return nil, fmt.Errorf("cannot convert %v to int8: %v", value, err) - } - if !valid { - return nil, nil - } - - if n > math.MaxInt64 { - return nil, fmt.Errorf("%d is greater than maximum value for int8", n) - } - if n < math.MinInt64 { - return nil, fmt.Errorf("%d is less than minimum value for int8", n) - } +type encodePlanInt8CodecBinaryInt64 struct{} +func (encodePlanInt8CodecBinaryInt64) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int64) return pgio.AppendInt64(buf, int64(n)), nil } -type encodePlanInt8CodecText struct{} +type encodePlanInt8CodecTextInt64 struct{} -func (encodePlanInt8CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) +func (encodePlanInt8CodecTextInt64) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int64) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt8CodecBinaryInt64Valuer struct{} + +func (encodePlanInt8CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() if err != nil { - return nil, fmt.Errorf("cannot convert %v to int8: %v", value, err) + return nil, err } - if !valid { + + if !n.Valid { return nil, nil } - if n > math.MaxInt64 { - return nil, fmt.Errorf("%d is greater than maximum value for int8", n) + if n.Int > math.MaxInt64 { + return nil, fmt.Errorf("%d is greater than maximum value for int8", n.Int) } - if n < math.MinInt64 { - return nil, fmt.Errorf("%d is less than minimum value for int8", n) + if n.Int < math.MinInt64 { + return nil, fmt.Errorf("%d is less than minimum value for int8", n.Int) } - return append(buf, strconv.FormatInt(n, 10)...), nil + return pgio.AppendInt64(buf, int64(n.Int)), nil +} + +type encodePlanInt8CodecTextInt64Valuer struct{} + +func (encodePlanInt8CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int > math.MaxInt64 { + return nil, fmt.Errorf("%d is greater than maximum value for int8", n.Int) + } + if n.Int < math.MinInt64 { + return nil, fmt.Errorf("%d is less than minimum value for int8", n.Int) + } + + return append(buf, strconv.FormatInt(n.Int, 10)...), nil } func (Int8Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 0d88dd42..6aecb761 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -131,54 +131,80 @@ func (Int<%= pg_byte_size %>Codec) PreferredFormat() int16 { func (Int<%= pg_byte_size %>Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return encodePlanInt<%= pg_byte_size %>CodecBinary{} + switch value.(type) { + case int<%= pg_bit_size %>: + return encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %>{} + case Int64Valuer: + return encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer{} + } case TextFormatCode: - return encodePlanInt<%= pg_byte_size %>CodecText{} + switch value.(type) { + case int<%= pg_bit_size %>: + return encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %>{} + case Int64Valuer: + return encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer{} + } } return nil } -type encodePlanInt<%= pg_byte_size %>CodecBinary struct{} - -func (encodePlanInt<%= pg_byte_size %>CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) - if err != nil { - return nil, fmt.Errorf("cannot convert %v to int<%= pg_byte_size %>: %v", value, err) - } - if !valid { - return nil, nil - } - - if n > math.MaxInt<%= pg_bit_size %> { - return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n) - } - if n < math.MinInt<%= pg_bit_size %> { - return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n) - } +type encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %> struct{} +func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %>) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int<%= pg_bit_size %>) return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n)), nil } -type encodePlanInt<%= pg_byte_size %>CodecText struct{} +type encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %> struct{} -func (encodePlanInt<%= pg_byte_size %>CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) +func (encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %>) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int<%= pg_bit_size %>) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer struct{} + +func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() if err != nil { - return nil, fmt.Errorf("cannot convert %v to int<%= pg_byte_size %>: %v", value, err) + return nil, err } - if !valid { + + if !n.Valid { return nil, nil } - if n > math.MaxInt<%= pg_bit_size %> { - return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n) + if n.Int > math.MaxInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n.Int) } - if n < math.MinInt<%= pg_bit_size %> { - return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n) + if n.Int < math.MinInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n.Int) } - return append(buf, strconv.FormatInt(n, 10)...), nil + return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n.Int)), nil +} + +type encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer struct{} + +func (encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int > math.MaxInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n.Int) + } + if n.Int < math.MinInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n.Int) + } + + return append(buf, strconv.FormatInt(n.Int, 10)...), nil } func (Int<%= pg_byte_size %>Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 7fff7dd5..bb0d2a9d 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1011,6 +1011,14 @@ func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) Enco } } + if wrapperPlan, nextValue, ok := tryWrapBuiltinTypeEncodePlan(value); ok { + if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + + } + } return nil @@ -1074,6 +1082,174 @@ func tryBaseTypeEncodePlan(value interface{}) (plan *baseTypeEncodePlan, nextVal return nil, nil, false } +type WrappedEncodePlanNextSetter interface { + SetNext(EncodePlan) + EncodePlan +} + +func tryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { + switch value.(type) { + case int8: + return &wrapInt8EncodePlan{}, int8Wrapper(value.(int8)), true + case int16: + return &wrapInt16EncodePlan{}, int16Wrapper(value.(int16)), true + case int32: + return &wrapInt32EncodePlan{}, int32Wrapper(value.(int32)), true + case int64: + return &wrapInt64EncodePlan{}, int64Wrapper(value.(int64)), true + case int: + return &wrapIntEncodePlan{}, intWrapper(value.(int)), true + case uint8: + return &wrapUint8EncodePlan{}, uint8Wrapper(value.(uint8)), true + case uint16: + return &wrapUint16EncodePlan{}, uint16Wrapper(value.(uint16)), true + case uint32: + return &wrapUint32EncodePlan{}, uint32Wrapper(value.(uint32)), true + case uint64: + return &wrapUint64EncodePlan{}, uint64Wrapper(value.(uint64)), true + case uint: + return &wrapUintEncodePlan{}, uintWrapper(value.(uint)), true + case float32: + return &wrapFloat32EncodePlan{}, float32Wrapper(value.(float32)), true + case float64: + return &wrapFloat64EncodePlan{}, float64Wrapper(value.(float64)), true + case string: + return &wrapStringEncodePlan{}, stringWrapper(value.(string)), true + } + + return nil, nil, false +} + +type wrapInt8EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt8EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt8EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int8Wrapper(value.(int8)), buf) +} + +type wrapInt16EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt16EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt16EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int16Wrapper(value.(int16)), buf) +} + +type wrapInt32EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt32EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt32EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int32Wrapper(value.(int32)), buf) +} + +type wrapInt64EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt64EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt64EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int64Wrapper(value.(int64)), buf) +} + +type wrapIntEncodePlan struct { + next EncodePlan +} + +func (plan *wrapIntEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapIntEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(intWrapper(value.(int)), buf) +} + +type wrapUint8EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint8EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint8EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint8Wrapper(value.(uint8)), buf) +} + +type wrapUint16EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint16EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint16EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint16Wrapper(value.(uint16)), buf) +} + +type wrapUint32EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint32EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint32EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint32Wrapper(value.(uint32)), buf) +} + +type wrapUint64EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint64EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint64EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint64Wrapper(value.(uint64)), buf) +} + +type wrapUintEncodePlan struct { + next EncodePlan +} + +func (plan *wrapUintEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUintEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uintWrapper(value.(uint)), buf) +} + +type wrapFloat32EncodePlan struct { + next EncodePlan +} + +func (plan *wrapFloat32EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapFloat32EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(float32Wrapper(value.(float32)), buf) +} + +type wrapFloat64EncodePlan struct { + next EncodePlan +} + +func (plan *wrapFloat64EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapFloat64EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(float64Wrapper(value.(float64)), buf) +} + +type wrapStringEncodePlan struct { + next EncodePlan +} + +func (plan *wrapStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapStringEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(stringWrapper(value.(string)), buf) +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. From b26618ac95e01a1fc4c16c1e019eb3b70262c745 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 9 Jan 2022 00:25:20 -0600 Subject: [PATCH 0828/1158] Prevent try underlying type from acting on a value This is necessary to prevent infinite recursion where a base type is wrapped and then unwrapped. --- pgtype/builtin_wrappers.go | 146 ++++++++++++++++++++++--------------- pgtype/pgtype.go | 41 +++++++---- 2 files changed, 113 insertions(+), 74 deletions(-) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 17fe4535..5453bf18 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -8,7 +8,9 @@ import ( type int8Wrapper int8 -func (n *int8Wrapper) ScanInt64(v Int8) error { +func (w int8Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int8Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *int8") } @@ -19,18 +21,20 @@ func (n *int8Wrapper) ScanInt64(v Int8) error { if v.Int > math.MaxInt8 { return fmt.Errorf("%d is greater than maximum value for int8", v.Int) } - *n = int8Wrapper(v.Int) + *w = int8Wrapper(v.Int) return nil } -func (n int8Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w int8Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type int16Wrapper int16 -func (n *int16Wrapper) ScanInt64(v Int8) error { +func (w int16Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int16Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *int16") } @@ -41,18 +45,20 @@ func (n *int16Wrapper) ScanInt64(v Int8) error { if v.Int > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for int16", v.Int) } - *n = int16Wrapper(v.Int) + *w = int16Wrapper(v.Int) return nil } -func (n int16Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w int16Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type int32Wrapper int32 -func (n *int32Wrapper) ScanInt64(v Int8) error { +func (w int32Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int32Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *int32") } @@ -63,34 +69,38 @@ func (n *int32Wrapper) ScanInt64(v Int8) error { if v.Int > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for int32", v.Int) } - *n = int32Wrapper(v.Int) + *w = int32Wrapper(v.Int) return nil } -func (n int32Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w int32Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type int64Wrapper int64 -func (n *int64Wrapper) ScanInt64(v Int8) error { +func (w int64Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int64Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *int64") } - *n = int64Wrapper(v.Int) + *w = int64Wrapper(v.Int) return nil } -func (n int64Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w int64Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type intWrapper int -func (n *intWrapper) ScanInt64(v Int8) error { +func (w intWrapper) SkipUnderlyingTypePlan() {} + +func (w *intWrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *int") } @@ -102,18 +112,20 @@ func (n *intWrapper) ScanInt64(v Int8) error { return fmt.Errorf("%d is greater than maximum value for int", v.Int) } - *n = intWrapper(v.Int) + *w = intWrapper(v.Int) return nil } -func (n intWrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w intWrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type uint8Wrapper uint8 -func (n *uint8Wrapper) ScanInt64(v Int8) error { +func (w uint8Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint8Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *uint8") } @@ -124,18 +136,20 @@ func (n *uint8Wrapper) ScanInt64(v Int8) error { if v.Int > math.MaxUint8 { return fmt.Errorf("%d is greater than maximum value for uint8", v.Int) } - *n = uint8Wrapper(v.Int) + *w = uint8Wrapper(v.Int) return nil } -func (n uint8Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w uint8Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type uint16Wrapper uint16 -func (n *uint16Wrapper) ScanInt64(v Int8) error { +func (w uint16Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint16Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *uint16") } @@ -146,18 +160,20 @@ func (n *uint16Wrapper) ScanInt64(v Int8) error { if v.Int > math.MaxUint16 { return fmt.Errorf("%d is greater than maximum value for uint16", v.Int) } - *n = uint16Wrapper(v.Int) + *w = uint16Wrapper(v.Int) return nil } -func (n uint16Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w uint16Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type uint32Wrapper uint32 -func (n *uint32Wrapper) ScanInt64(v Int8) error { +func (w uint32Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint32Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *uint32") } @@ -168,18 +184,20 @@ func (n *uint32Wrapper) ScanInt64(v Int8) error { if v.Int > math.MaxUint32 { return fmt.Errorf("%d is greater than maximum value for uint32", v.Int) } - *n = uint32Wrapper(v.Int) + *w = uint32Wrapper(v.Int) return nil } -func (n uint32Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w uint32Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type uint64Wrapper uint64 -func (n *uint64Wrapper) ScanInt64(v Int8) error { +func (w uint64Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint64Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *uint64") } @@ -188,22 +206,24 @@ func (n *uint64Wrapper) ScanInt64(v Int8) error { return fmt.Errorf("%d is less than minimum value for uint64", v.Int) } - *n = uint64Wrapper(v.Int) + *w = uint64Wrapper(v.Int) return nil } -func (n uint64Wrapper) Int64Value() (Int8, error) { - if uint64(n) > uint64(math.MaxInt64) { - return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", n) +func (w uint64Wrapper) Int64Value() (Int8, error) { + if uint64(w) > uint64(math.MaxInt64) { + return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", w) } - return Int8{Int: int64(n), Valid: true}, nil + return Int8{Int: int64(w), Valid: true}, nil } type uintWrapper uint -func (n *uintWrapper) ScanInt64(v Int8) error { +func (w uintWrapper) SkipUnderlyingTypePlan() {} + +func (w *uintWrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *uint64") } @@ -216,73 +236,79 @@ func (n *uintWrapper) ScanInt64(v Int8) error { return fmt.Errorf("%d is greater than maximum value for uint", v.Int) } - *n = uintWrapper(v.Int) + *w = uintWrapper(v.Int) return nil } -func (n uintWrapper) Int64Value() (Int8, error) { - if uint64(n) > uint64(math.MaxInt64) { - return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", n) +func (w uintWrapper) Int64Value() (Int8, error) { + if uint64(w) > uint64(math.MaxInt64) { + return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", w) } - return Int8{Int: int64(n), Valid: true}, nil + return Int8{Int: int64(w), Valid: true}, nil } type float32Wrapper float32 -func (n *float32Wrapper) ScanInt64(v Int8) error { +func (w float32Wrapper) SkipUnderlyingTypePlan() {} + +func (w *float32Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *float32") } - *n = float32Wrapper(v.Int) + *w = float32Wrapper(v.Int) return nil } -func (n float32Wrapper) Int64Value() (Int8, error) { - if n > math.MaxInt64 { - return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", n) +func (w float32Wrapper) Int64Value() (Int8, error) { + if w > math.MaxInt64 { + return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", w) } - return Int8{Int: int64(n), Valid: true}, nil + return Int8{Int: int64(w), Valid: true}, nil } type float64Wrapper float64 -func (n *float64Wrapper) ScanInt64(v Int8) error { +func (w float64Wrapper) SkipUnderlyingTypePlan() {} + +func (w *float64Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *float64") } - *n = float64Wrapper(v.Int) + *w = float64Wrapper(v.Int) return nil } -func (n float64Wrapper) Int64Value() (Int8, error) { - if n > math.MaxInt64 { - return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", n) +func (w float64Wrapper) Int64Value() (Int8, error) { + if w > math.MaxInt64 { + return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", w) } - return Int8{Int: int64(n), Valid: true}, nil + return Int8{Int: int64(w), Valid: true}, nil } type stringWrapper string -func (s *stringWrapper) ScanInt64(v Int8) error { +func (w stringWrapper) SkipUnderlyingTypePlan() {} + +func (w *stringWrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *string") } - *s = stringWrapper(strconv.FormatInt(v.Int, 10)) + *w = stringWrapper(strconv.FormatInt(v.Int, 10)) return nil } -func (s stringWrapper) Int64Value() (Int8, error) { - num, err := strconv.ParseInt(string(s), 10, 64) +func (w stringWrapper) Int64Value() (Int8, error) { + num, err := strconv.ParseInt(string(w), 10, 64) if err != nil { return Int8{}, err } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index bb0d2a9d..c9da1322 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -747,7 +747,12 @@ func tryPointerPointerScanPlan(dst interface{}) (plan *pointerPointerScanPlan, n return nil, nil, false } -var elemKindToBasePointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ +// SkipUnderlyingTypePlanner prevents PlanScan and PlanDecode from trying to use the underlying type. +type SkipUnderlyingTypePlanner interface { + SkipUnderlyingTypePlan() +} + +var elemKindToPointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ reflect.Int: reflect.TypeOf(new(int)), reflect.Int8: reflect.TypeOf(new(int8)), reflect.Int16: reflect.TypeOf(new(int16)), @@ -763,13 +768,13 @@ var elemKindToBasePointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind] reflect.String: reflect.TypeOf(new(string)), } -type baseTypeScanPlan struct { +type underlyingTypeScanPlan struct { dstType reflect.Type nextDstType reflect.Type next ScanPlan } -func (plan *baseTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (plan *underlyingTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if plan.dstType != reflect.TypeOf(dst) { newPlan := ci.PlanScan(oid, formatCode, dst) return newPlan.Scan(ci, oid, formatCode, src, dst) @@ -778,14 +783,18 @@ func (plan *baseTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, s return plan.next.Scan(ci, oid, formatCode, src, reflect.ValueOf(dst).Convert(plan.nextDstType).Interface()) } -func tryBaseTypeScanPlan(dst interface{}) (plan *baseTypeScanPlan, nextDst interface{}, ok bool) { +func tryUnderlyingTypeScanPlan(dst interface{}) (plan *underlyingTypeScanPlan, nextDst interface{}, ok bool) { + if _, ok := dst.(SkipUnderlyingTypePlanner); ok { + return nil, nil, false + } + dstValue := reflect.ValueOf(dst) if dstValue.Kind() == reflect.Ptr { elemValue := dstValue.Elem() - nextDstType := elemKindToBasePointerTypes[elemValue.Kind()] + nextDstType := elemKindToPointerTypes[elemValue.Kind()] if nextDstType != nil && dstValue.Type() != nextDstType { - return &baseTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true + return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true } } @@ -881,7 +890,7 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan } } - if baseTypePlan, nextDst, ok := tryBaseTypeScanPlan(dst); ok { + if baseTypePlan, nextDst, ok := tryUnderlyingTypeScanPlan(dst); ok { if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { baseTypePlan.next = nextPlan return baseTypePlan @@ -1004,7 +1013,7 @@ func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) Enco } } - if baseTypePlan, nextValue, ok := tryBaseTypeEncodePlan(value); ok { + if baseTypePlan, nextValue, ok := tryUnderlyingTypeEncodePlan(value); ok { if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { baseTypePlan.next = nextPlan return baseTypePlan @@ -1046,7 +1055,7 @@ func tryDerefPointerEncodePlan(value interface{}) (plan *derefPointerEncodePlan, return nil, nil, false } -var kindToBaseTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ +var kindToTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ reflect.Int: reflect.TypeOf(int(0)), reflect.Int8: reflect.TypeOf(int8(0)), reflect.Int16: reflect.TypeOf(int16(0)), @@ -1062,21 +1071,25 @@ var kindToBaseTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Typ reflect.String: reflect.TypeOf(""), } -type baseTypeEncodePlan struct { +type underlyingTypeEncodePlan struct { nextValueType reflect.Type next EncodePlan } -func (plan *baseTypeEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *underlyingTypeEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(reflect.ValueOf(value).Convert(plan.nextValueType).Interface(), buf) } -func tryBaseTypeEncodePlan(value interface{}) (plan *baseTypeEncodePlan, nextValue interface{}, ok bool) { +func tryUnderlyingTypeEncodePlan(value interface{}) (plan *underlyingTypeEncodePlan, nextValue interface{}, ok bool) { + if _, ok := value.(SkipUnderlyingTypePlanner); ok { + return nil, nil, false + } + refValue := reflect.ValueOf(value) - nextValueType := kindToBaseTypes[refValue.Kind()] + nextValueType := kindToTypes[refValue.Kind()] if nextValueType != nil && refValue.Type() != nextValueType { - return &baseTypeEncodePlan{nextValueType: nextValueType}, refValue.Convert(nextValueType).Interface(), true + return &underlyingTypeEncodePlan{nextValueType: nextValueType}, refValue.Convert(nextValueType).Interface(), true } return nil, nil, false From eec82c9433368eaf9855ad6ee0a4f36418c6318a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 9 Jan 2022 00:35:49 -0600 Subject: [PATCH 0829/1158] Replace CID, OID, OIDValue, and XID with Uint32 --- pgtype/cid.go | 61 -------- pgtype/cid_test.go | 102 ------------- pgtype/oid.go | 81 ----------- pgtype/oid_value.go | 55 ------- pgtype/oid_value_test.go | 95 ------------ pgtype/pgtype.go | 6 +- pgtype/pguint32.go | 148 ------------------- pgtype/uint32.go | 303 +++++++++++++++++++++++++++++++++++++++ pgtype/uint32_test.go | 19 +++ pgtype/xid.go | 64 --------- pgtype/xid_test.go | 102 ------------- stdlib/sql.go | 24 +--- 12 files changed, 327 insertions(+), 733 deletions(-) delete mode 100644 pgtype/cid.go delete mode 100644 pgtype/cid_test.go delete mode 100644 pgtype/oid.go delete mode 100644 pgtype/oid_value.go delete mode 100644 pgtype/oid_value_test.go delete mode 100644 pgtype/pguint32.go create mode 100644 pgtype/uint32.go create mode 100644 pgtype/uint32_test.go delete mode 100644 pgtype/xid.go delete mode 100644 pgtype/xid_test.go diff --git a/pgtype/cid.go b/pgtype/cid.go deleted file mode 100644 index b944748c..00000000 --- a/pgtype/cid.go +++ /dev/null @@ -1,61 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// CID is PostgreSQL's Command Identifier type. -// -// When one does -// -// select cmin, cmax, * from some_table; -// -// it is the data type of the cmin and cmax hidden system columns. -// -// It is currently implemented as an unsigned four byte integer. -// Its definition can be found in src/include/c.h as CommandId -// in the PostgreSQL sources. -type CID pguint32 - -// Set converts from src to dst. Note that as CID is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *CID) Set(src interface{}) error { - return (*pguint32)(dst).Set(src) -} - -func (dst CID) Get() interface{} { - return (pguint32)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as CID is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *CID) AssignTo(dst interface{}) error { - return (*pguint32)(src).AssignTo(dst) -} - -func (dst *CID) DecodeText(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeText(ci, src) -} - -func (dst *CID) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeBinary(ci, src) -} - -func (src CID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (pguint32)(src).EncodeText(ci, buf) -} - -func (src CID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (pguint32)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *CID) Scan(src interface{}) error { - return (*pguint32)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src CID) Value() (driver.Value, error) { - return (pguint32)(src).Value() -} diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go deleted file mode 100644 index 3d3ad2a5..00000000 --- a/pgtype/cid_test.go +++ /dev/null @@ -1,102 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestCIDTranscode(t *testing.T) { - pgTypeName := "cid" - values := []interface{}{ - &pgtype.CID{Uint: 42, Valid: true}, - &pgtype.CID{}, - } - eqFunc := func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - } - - testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) -} - -func TestCIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.CID - }{ - {source: uint32(1), result: pgtype.CID{Uint: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.CID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestCIDAssignTo(t *testing.T) { - var ui32 uint32 - var pui32 *uint32 - - simpleTests := []struct { - src pgtype.CID - dst interface{} - expected interface{} - }{ - {src: pgtype.CID{Uint: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.CID{}, dst: &pui32, expected: ((*uint32)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.CID - dst interface{} - expected interface{} - }{ - {src: pgtype.CID{Uint: 42, Valid: true}, dst: &pui32, expected: uint32(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.CID - dst interface{} - }{ - {src: pgtype.CID{}, dst: &ui32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/oid.go b/pgtype/oid.go deleted file mode 100644 index 31677e89..00000000 --- a/pgtype/oid.go +++ /dev/null @@ -1,81 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "strconv" - - "github.com/jackc/pgio" -) - -// OID (Object Identifier Type) is, according to -// https://www.postgresql.org/docs/current/static/datatype-oid.html, used -// internally by PostgreSQL as a primary key for various system tables. It is -// currently implemented as an unsigned four-byte integer. Its definition can be -// found in src/include/postgres_ext.h in the PostgreSQL sources. Because it is -// so frequently required to be in a NOT NULL condition OID cannot be NULL. To -// allow for NULL OIDs use OIDValue. -type OID uint32 - -func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - return fmt.Errorf("cannot decode nil into OID") - } - - n, err := strconv.ParseUint(string(src), 10, 32) - if err != nil { - return err - } - - *dst = OID(n) - return nil -} - -func (dst *OID) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - return fmt.Errorf("cannot decode nil into OID") - } - - if len(src) != 4 { - return fmt.Errorf("invalid length: %v", len(src)) - } - - n := binary.BigEndian.Uint32(src) - *dst = OID(n) - return nil -} - -func (src OID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return append(buf, strconv.FormatUint(uint64(src), 10)...), nil -} - -func (src OID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return pgio.AppendUint32(buf, uint32(src)), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *OID) Scan(src interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan NULL into %T", src) - } - - switch src := src.(type) { - case int64: - *dst = OID(src) - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src OID) Value() (driver.Value, error) { - return int64(src), nil -} diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go deleted file mode 100644 index 5dc9136c..00000000 --- a/pgtype/oid_value.go +++ /dev/null @@ -1,55 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// OIDValue (Object Identifier Type) is, according to -// https://www.postgresql.org/docs/current/static/datatype-OIDValue.html, used -// internally by PostgreSQL as a primary key for various system tables. It is -// currently implemented as an unsigned four-byte integer. Its definition can be -// found in src/include/postgres_ext.h in the PostgreSQL sources. -type OIDValue pguint32 - -// Set converts from src to dst. Note that as OIDValue is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *OIDValue) Set(src interface{}) error { - return (*pguint32)(dst).Set(src) -} - -func (dst OIDValue) Get() interface{} { - return (pguint32)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as OIDValue is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *OIDValue) AssignTo(dst interface{}) error { - return (*pguint32)(src).AssignTo(dst) -} - -func (dst *OIDValue) DecodeText(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeText(ci, src) -} - -func (dst *OIDValue) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeBinary(ci, src) -} - -func (src OIDValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (pguint32)(src).EncodeText(ci, buf) -} - -func (src OIDValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (pguint32)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *OIDValue) Scan(src interface{}) error { - return (*pguint32)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src OIDValue) Value() (driver.Value, error) { - return (pguint32)(src).Value() -} diff --git a/pgtype/oid_value_test.go b/pgtype/oid_value_test.go deleted file mode 100644 index aecfc149..00000000 --- a/pgtype/oid_value_test.go +++ /dev/null @@ -1,95 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestOIDValueTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ - &pgtype.OIDValue{Uint: 42, Valid: true}, - &pgtype.OIDValue{}, - }) -} - -func TestOIDValueSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.OIDValue - }{ - {source: uint32(1), result: pgtype.OIDValue{Uint: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.OIDValue - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestOIDValueAssignTo(t *testing.T) { - var ui32 uint32 - var pui32 *uint32 - - simpleTests := []struct { - src pgtype.OIDValue - dst interface{} - expected interface{} - }{ - {src: pgtype.OIDValue{Uint: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.OIDValue{}, dst: &pui32, expected: ((*uint32)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.OIDValue - dst interface{} - expected interface{} - }{ - {src: pgtype.OIDValue{Uint: 42, Valid: true}, dst: &pui32, expected: uint32(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.OIDValue - dst interface{} - }{ - {src: pgtype.OIDValue{}, dst: &ui32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index c9da1322..ef269365 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -286,7 +286,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) - ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID}) + ci.RegisterDataType(DataType{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) ci.RegisterDataType(DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID}) @@ -309,7 +309,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "name", OID: NameOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) - ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID}) + ci.RegisterDataType(DataType{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID}) ci.RegisterDataType(DataType{Name: "point", OID: PointOID, Codec: PointCodec{}}) ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID}) @@ -327,7 +327,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID}) ci.RegisterDataType(DataType{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) - ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID}) + ci.RegisterDataType(DataType{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { ci.RegisterDefaultPgType(value, name) diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go deleted file mode 100644 index e36ebb1f..00000000 --- a/pgtype/pguint32.go +++ /dev/null @@ -1,148 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "math" - "strconv" - - "github.com/jackc/pgio" -) - -// pguint32 is the core type that is used to implement PostgreSQL types such as -// CID and XID. -type pguint32 struct { - Uint uint32 - Valid bool -} - -// Set converts from src to dst. Note that as pguint32 is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *pguint32) Set(src interface{}) error { - switch value := src.(type) { - case int64: - if value < 0 { - return fmt.Errorf("%d is less than minimum value for pguint32", value) - } - if value > math.MaxUint32 { - return fmt.Errorf("%d is greater than maximum value for pguint32", value) - } - *dst = pguint32{Uint: uint32(value), Valid: true} - case uint32: - *dst = pguint32{Uint: value, Valid: true} - default: - return fmt.Errorf("cannot convert %v to pguint32", value) - } - - return nil -} - -func (dst pguint32) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.Uint -} - -// AssignTo assigns from src to dst. Note that as pguint32 is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *pguint32) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *uint32: - if src.Valid { - *v = src.Uint - } else { - return fmt.Errorf("cannot assign %v into %T", src, dst) - } - case **uint32: - if src.Valid { - n := src.Uint - *v = &n - } else { - *v = nil - } - } - - return nil -} - -func (dst *pguint32) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = pguint32{} - return nil - } - - n, err := strconv.ParseUint(string(src), 10, 32) - if err != nil { - return err - } - - *dst = pguint32{Uint: uint32(n), Valid: true} - return nil -} - -func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = pguint32{} - return nil - } - - if len(src) != 4 { - return fmt.Errorf("invalid length: %v", len(src)) - } - - n := binary.BigEndian.Uint32(src) - *dst = pguint32{Uint: n, Valid: true} - return nil -} - -func (src pguint32) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, strconv.FormatUint(uint64(src.Uint), 10)...), nil -} - -func (src pguint32) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return pgio.AppendUint32(buf, src.Uint), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *pguint32) Scan(src interface{}) error { - if src == nil { - *dst = pguint32{} - return nil - } - - switch src := src.(type) { - case uint32: - *dst = pguint32{Uint: src, Valid: true} - return nil - case int64: - *dst = pguint32{Uint: uint32(src), Valid: true} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src pguint32) Value() (driver.Value, error) { - if !src.Valid { - return nil, nil - } - return int64(src.Uint), nil -} diff --git a/pgtype/uint32.go b/pgtype/uint32.go new file mode 100644 index 00000000..ccf39471 --- /dev/null +++ b/pgtype/uint32.go @@ -0,0 +1,303 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Uint32Scanner interface { + ScanUint32(v Uint32) error +} + +type Uint32Valuer interface { + Uint32Value() (Uint32, error) +} + +// Uint32 is the core type that is used to represent PostgreSQL types such as OID, CID, and XID. +type Uint32 struct { + Uint uint32 + Valid bool +} + +func (n *Uint32) ScanUint32(v Uint32) error { + *n = v + return nil +} + +func (n Uint32) Uint32Value() (Uint32, error) { + return n, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Uint32) Scan(src interface{}) error { + if src == nil { + *dst = Uint32{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + un, err := strconv.ParseUint(src, 10, 32) + if err != nil { + return err + } + n = int64(un) + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < 0 { + return fmt.Errorf("%d is less than the minimum value for Uint32", n) + } + if n > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for Uint32", n) + } + + *dst = Uint32{Uint: uint32(n), Valid: true} + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Uint32) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Uint), nil +} + +type Uint32Codec struct{} + +func (Uint32Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Uint32Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Uint32Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case uint32: + return encodePlanUint32CodecBinaryUint32{} + case Uint32Valuer: + return encodePlanUint32CodecBinaryUint32Valuer{} + case Int64Valuer: + return encodePlanUint32CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case uint32: + return encodePlanUint32CodecTextUint32{} + case Int64Valuer: + return encodePlanUint32CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanUint32CodecBinaryUint32 struct{} + +func (encodePlanUint32CodecBinaryUint32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v := value.(uint32) + return pgio.AppendUint32(buf, v), nil +} + +type encodePlanUint32CodecBinaryUint32Valuer struct{} + +func (encodePlanUint32CodecBinaryUint32Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint32Valuer).Uint32Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return pgio.AppendUint32(buf, v.Uint), nil +} + +type encodePlanUint32CodecBinaryInt64Valuer struct{} + +func (encodePlanUint32CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint32", v.Int) + } + if v.Int > math.MaxUint32 { + return nil, fmt.Errorf("%d is greater than maximum value for uint32", v.Int) + } + + return pgio.AppendUint32(buf, uint32(v.Int)), nil +} + +type encodePlanUint32CodecTextUint32 struct{} + +func (encodePlanUint32CodecTextUint32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v := value.(uint32) + return append(buf, strconv.FormatUint(uint64(v), 10)...), nil +} + +type encodePlanUint32CodecTextUint32Valuer struct{} + +func (encodePlanUint32CodecTextUint32Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint32Valuer).Uint32Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return append(buf, strconv.FormatUint(uint64(v.Uint), 10)...), nil +} + +type encodePlanUint32CodecTextInt64Valuer struct{} + +func (encodePlanUint32CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint32", v.Int) + } + if v.Int > math.MaxUint32 { + return nil, fmt.Errorf("%d is greater than maximum value for uint32", v.Int) + } + + return append(buf, strconv.FormatInt(v.Int, 10)...), nil +} + +func (Uint32Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *uint32: + return scanPlanBinaryUint32ToUint32{} + case Uint32Scanner: + return scanPlanBinaryUint32ToUint32Scanner{} + } + case TextFormatCode: + switch target.(type) { + case *uint32: + return scanPlanTextAnyToUint32{} + case Uint32Scanner: + return scanPlanTextAnyToUint32Scanner{} + } + } + + return nil +} + +func (c Uint32Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n uint32 + err := codecScan(c, ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return int64(n), nil +} + +func (c Uint32Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var n uint32 + err := codecScan(c, ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryUint32ToUint32 struct{} + +func (scanPlanBinaryUint32ToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint32: %v", len(src)) + } + + p := (dst).(*uint32) + *p = binary.BigEndian.Uint32(src) + + return nil +} + +type scanPlanBinaryUint32ToUint32Scanner struct{} + +func (scanPlanBinaryUint32ToUint32Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(Uint32Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint32(Uint32{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint32: %v", len(src)) + } + + n := binary.BigEndian.Uint32(src) + + return s.ScanUint32(Uint32{Uint: n, Valid: true}) +} + +type scanPlanTextAnyToUint32Scanner struct{} + +func (scanPlanTextAnyToUint32Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(Uint32Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint32(Uint32{}) + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + return s.ScanUint32(Uint32{Uint: uint32(n), Valid: true}) +} diff --git a/pgtype/uint32_test.go b/pgtype/uint32_test.go new file mode 100644 index 00000000..8e58605d --- /dev/null +++ b/pgtype/uint32_test.go @@ -0,0 +1,19 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgtype" +) + +func TestUint32Codec(t *testing.T) { + testPgxCodec(t, "oid", []PgxTranscodeTestCase{ + { + pgtype.Uint32{Uint: pgtype.TextOID, Valid: true}, + new(pgtype.Uint32), + isExpectedEq(pgtype.Uint32{Uint: pgtype.TextOID, Valid: true}), + }, + {pgtype.Uint32{}, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})}, + {nil, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})}, + }) +} diff --git a/pgtype/xid.go b/pgtype/xid.go deleted file mode 100644 index f6d6b22d..00000000 --- a/pgtype/xid.go +++ /dev/null @@ -1,64 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// XID is PostgreSQL's Transaction ID type. -// -// In later versions of PostgreSQL, it is the type used for the backend_xid -// and backend_xmin columns of the pg_stat_activity system view. -// -// Also, when one does -// -// select xmin, xmax, * from some_table; -// -// it is the data type of the xmin and xmax hidden system columns. -// -// It is currently implemented as an unsigned four byte integer. -// Its definition can be found in src/include/postgres_ext.h as TransactionId -// in the PostgreSQL sources. -type XID pguint32 - -// Set converts from src to dst. Note that as XID is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *XID) Set(src interface{}) error { - return (*pguint32)(dst).Set(src) -} - -func (dst XID) Get() interface{} { - return (pguint32)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as XID is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *XID) AssignTo(dst interface{}) error { - return (*pguint32)(src).AssignTo(dst) -} - -func (dst *XID) DecodeText(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeText(ci, src) -} - -func (dst *XID) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeBinary(ci, src) -} - -func (src XID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (pguint32)(src).EncodeText(ci, buf) -} - -func (src XID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (pguint32)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *XID) Scan(src interface{}) error { - return (*pguint32)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src XID) Value() (driver.Value, error) { - return (pguint32)(src).Value() -} diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go deleted file mode 100644 index ee11fa41..00000000 --- a/pgtype/xid_test.go +++ /dev/null @@ -1,102 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestXIDTranscode(t *testing.T) { - pgTypeName := "xid" - values := []interface{}{ - &pgtype.XID{Uint: 42, Valid: true}, - &pgtype.XID{}, - } - eqFunc := func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - } - - testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) -} - -func TestXIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.XID - }{ - {source: uint32(1), result: pgtype.XID{Uint: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.XID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestXIDAssignTo(t *testing.T) { - var ui32 uint32 - var pui32 *uint32 - - simpleTests := []struct { - src pgtype.XID - dst interface{} - expected interface{} - }{ - {src: pgtype.XID{Uint: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.XID{}, dst: &pui32, expected: ((*uint32)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.XID - dst interface{} - expected interface{} - }{ - {src: pgtype.XID{Uint: 42, Valid: true}, dst: &pui32, expected: uint32(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.XID - dst interface{} - }{ - {src: pgtype.XID{}, dst: &ui32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/stdlib/sql.go b/stdlib/sql.go index 39b7524a..cbb8544e 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -615,8 +615,8 @@ func (r *Rows) Next(dest []driver.Value) error { err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) return d, err } - case pgtype.CIDOID: - var d pgtype.CID + case pgtype.CIDOID, pgtype.OIDOID, pgtype.XIDOID: + var d pgtype.Uint32 scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) @@ -690,16 +690,6 @@ func (r *Rows) Next(dest []driver.Value) error { } return d.Value() } - case pgtype.OIDOID: - var d pgtype.OIDValue - scanPlan := ci.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) - if err != nil { - return nil, err - } - return d.Value() - } case pgtype.TimestampOID: var d pgtype.Timestamp scanPlan := ci.PlanScan(dataTypeOID, format, &d) @@ -720,16 +710,6 @@ func (r *Rows) Next(dest []driver.Value) error { } return d.Value() } - case pgtype.XIDOID: - var d pgtype.XID - scanPlan := ci.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) - if err != nil { - return nil, err - } - return d.Value() - } default: var d string scanPlan := ci.PlanScan(dataTypeOID, format, &d) From f4a9d84e32e18702b22815bf6e9da54e6e5ce111 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 9 Jan 2022 00:41:25 -0600 Subject: [PATCH 0830/1158] Add CID, OID, and XID arrays --- pgtype/pgtype.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index ef269365..bcc8828d 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -48,6 +48,8 @@ const ( Int4ArrayOID = 1007 TextArrayOID = 1009 ByteaArrayOID = 1001 + XIDArrayOID = 1011 + CIDArrayOID = 1012 BPCharArrayOID = 1014 VarcharArrayOID = 1015 Int8ArrayOID = 1016 @@ -55,6 +57,7 @@ const ( BoxArrayOID = 1020 Float4ArrayOID = 1021 Float8ArrayOID = 1022 + OIDArrayOID = 1028 ACLItemOID = 1033 ACLItemArrayOID = 1034 InetArrayOID = 1041 @@ -279,6 +282,9 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: VarcharOID}}) ci.RegisterDataType(DataType{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementCodec: BitsCodec{}, ElementOID: BitOID}}) ci.RegisterDataType(DataType{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementCodec: BitsCodec{}, ElementOID: VarbitOID}}) + ci.RegisterDataType(DataType{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementCodec: Uint32Codec{}, ElementOID: CIDOID}}) + ci.RegisterDataType(DataType{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementCodec: Uint32Codec{}, ElementOID: OIDOID}}) + ci.RegisterDataType(DataType{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementCodec: Uint32Codec{}, ElementOID: XIDOID}}) ci.RegisterDataType(DataType{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) ci.RegisterDataType(DataType{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) From b57e0c419b8e5e273a4b1cbb8318995cb0da2e76 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 10 Jan 2022 21:02:20 -0600 Subject: [PATCH 0831/1158] Convert Date to Codec --- pgtype/builtin_wrappers.go | 25 ++ pgtype/date.go | 363 ++++++++++++++------------ pgtype/date_array.go | 505 ------------------------------------- pgtype/date_array_test.go | 327 ------------------------ pgtype/date_test.go | 127 ++-------- pgtype/pgtype.go | 76 ++++-- 6 files changed, 301 insertions(+), 1122 deletions(-) delete mode 100644 pgtype/date_array.go delete mode 100644 pgtype/date_array_test.go diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 5453bf18..b653260e 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -4,6 +4,7 @@ import ( "fmt" "math" "strconv" + "time" ) type int8Wrapper int8 @@ -315,3 +316,27 @@ func (w stringWrapper) Int64Value() (Int8, error) { return Int8{Int: int64(num), Valid: true}, nil } + +type timeWrapper time.Time + +func (w *timeWrapper) ScanDate(v Date) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Time") + } + + switch v.InfinityModifier { + case None: + *w = timeWrapper(v.Time) + return nil + case Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +func (w timeWrapper) DateValue() (Date, error) { + return Date{Time: time.Time(w), Valid: true}, nil +} diff --git a/pgtype/date.go b/pgtype/date.go index 5b7f47e6..fde66745 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -10,10 +10,27 @@ import ( "github.com/jackc/pgio" ) +type DateScanner interface { + ScanDate(v Date) error +} + +type DateValuer interface { + DateValue() (Date, error) +} + type Date struct { Time time.Time - Valid bool InfinityModifier InfinityModifier + Valid bool +} + +func (d *Date) ScanDate(v Date) error { + *d = v + return nil +} + +func (d Date) DateValue() (Date, error) { + return d, nil } const ( @@ -21,166 +38,6 @@ const ( infinityDayOffset = 2147483647 ) -func (dst *Date) Set(src interface{}) error { - if src == nil { - *dst = Date{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case time.Time: - *dst = Date{Time: value, Valid: true} - case string: - return dst.DecodeText(nil, []byte(value)) - case *time.Time: - if value == nil { - *dst = Date{} - } else { - return dst.Set(*value) - } - case *string: - if value == nil { - *dst = Date{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingTimeType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Date", value) - } - - return nil -} - -func (dst Date) Get() interface{} { - if !dst.Valid { - return nil - } - if dst.InfinityModifier != None { - return dst.InfinityModifier - } - return dst.Time -} - -func (src *Date) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *time.Time: - if src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Date{} - return nil - } - - sbuf := string(src) - switch sbuf { - case "infinity": - *dst = Date{Valid: true, InfinityModifier: Infinity} - case "-infinity": - *dst = Date{Valid: true, InfinityModifier: -Infinity} - default: - t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) - if err != nil { - return err - } - - *dst = Date{Time: t, Valid: true} - } - - return nil -} - -func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Date{} - return nil - } - - if len(src) != 4 { - return fmt.Errorf("invalid length for date: %v", len(src)) - } - - dayOffset := int32(binary.BigEndian.Uint32(src)) - - switch dayOffset { - case infinityDayOffset: - *dst = Date{Valid: true, InfinityModifier: Infinity} - case negativeInfinityDayOffset: - *dst = Date{Valid: true, InfinityModifier: -Infinity} - default: - t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) - *dst = Date{Time: t, Valid: true} - } - - return nil -} - -func (src Date) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - var s string - - switch src.InfinityModifier { - case None: - s = src.Time.Format("2006-01-02") - case Infinity: - s = "infinity" - case NegativeInfinity: - s = "-infinity" - } - - return append(buf, s...), nil -} - -func (src Date) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - var daysSinceDateEpoch int32 - switch src.InfinityModifier { - case None: - tUnix := time.Date(src.Time.Year(), src.Time.Month(), src.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() - dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() - - secSinceDateEpoch := tUnix - dateEpoch - daysSinceDateEpoch = int32(secSinceDateEpoch / 86400) - case Infinity: - daysSinceDateEpoch = infinityDayOffset - case NegativeInfinity: - daysSinceDateEpoch = negativeInfinityDayOffset - } - - return pgio.AppendInt32(buf, daysSinceDateEpoch), nil -} - // Scan implements the database/sql Scanner interface. func (dst *Date) Scan(src interface{}) error { if src == nil { @@ -190,11 +47,7 @@ func (dst *Date) Scan(src interface{}) error { switch src := src.(type) { case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + return scanPlanTextAnyToDateScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) case time.Time: *dst = Date{Time: src, Valid: true} return nil @@ -262,3 +115,181 @@ func (dst *Date) UnmarshalJSON(b []byte) error { return nil } + +type DateCodec struct{} + +func (DateCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (DateCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (DateCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(DateValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanDateCodecBinary{} + case TextFormatCode: + return encodePlanDateCodecText{} + } + + return nil +} + +type encodePlanDateCodecBinary struct{} + +func (encodePlanDateCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + date, err := value.(DateValuer).DateValue() + if err != nil { + return nil, err + } + + if !date.Valid { + return nil, nil + } + + var daysSinceDateEpoch int32 + switch date.InfinityModifier { + case None: + tUnix := time.Date(date.Time.Year(), date.Time.Month(), date.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() + dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + + secSinceDateEpoch := tUnix - dateEpoch + daysSinceDateEpoch = int32(secSinceDateEpoch / 86400) + case Infinity: + daysSinceDateEpoch = infinityDayOffset + case NegativeInfinity: + daysSinceDateEpoch = negativeInfinityDayOffset + } + + return pgio.AppendInt32(buf, daysSinceDateEpoch), nil +} + +type encodePlanDateCodecText struct{} + +func (encodePlanDateCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + date, err := value.(DateValuer).DateValue() + if err != nil { + return nil, err + } + + if !date.Valid { + return nil, nil + } + + var s string + + switch date.InfinityModifier { + case None: + s = date.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return append(buf, s...), nil +} + +func (DateCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case DateScanner: + return scanPlanBinaryDateToDateScanner{} + } + case TextFormatCode: + switch target.(type) { + case DateScanner: + return scanPlanTextAnyToDateScanner{} + } + } + + return nil +} + +type scanPlanBinaryDateToDateScanner struct{} + +func (scanPlanBinaryDateToDateScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(DateScanner) + + if src == nil { + return scanner.ScanDate(Date{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for date: %v", len(src)) + } + + dayOffset := int32(binary.BigEndian.Uint32(src)) + + switch dayOffset { + case infinityDayOffset: + return scanner.ScanDate(Date{InfinityModifier: Infinity, Valid: true}) + case negativeInfinityDayOffset: + return scanner.ScanDate(Date{InfinityModifier: -Infinity, Valid: true}) + default: + t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) + return scanner.ScanDate(Date{Time: t, Valid: true}) + } +} + +type scanPlanTextAnyToDateScanner struct{} + +func (scanPlanTextAnyToDateScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(DateScanner) + + if src == nil { + return scanner.ScanDate(Date{}) + } + + sbuf := string(src) + switch sbuf { + case "infinity": + return scanner.ScanDate(Date{InfinityModifier: Infinity, Valid: true}) + case "-infinity": + return scanner.ScanDate(Date{InfinityModifier: -Infinity, Valid: true}) + default: + t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) + if err != nil { + return err + } + + return scanner.ScanDate(Date{Time: t, Valid: true}) + } +} + +func (c DateCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c DateCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var date Date + err := codecScan(c, ci, oid, format, src, &date) + if err != nil { + return nil, err + } + + if date.Valid { + switch date.InfinityModifier { + case None: + return date.Time, nil + case Infinity: + return "infinity", nil + case NegativeInfinity: + return "-infinity", nil + } + } + + return nil, nil +} diff --git a/pgtype/date_array.go b/pgtype/date_array.go deleted file mode 100644 index 9d3b32e2..00000000 --- a/pgtype/date_array.go +++ /dev/null @@ -1,505 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - "time" - - "github.com/jackc/pgio" -) - -type DateArray struct { - Elements []Date - Dimensions []ArrayDimension - Valid bool -} - -func (dst *DateArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = DateArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []time.Time: - if value == nil { - *dst = DateArray{} - } else if len(value) == 0 { - *dst = DateArray{Valid: true} - } else { - elements := make([]Date, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = DateArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*time.Time: - if value == nil { - *dst = DateArray{} - } else if len(value) == 0 { - *dst = DateArray{Valid: true} - } else { - elements := make([]Date, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = DateArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Date: - if value == nil { - *dst = DateArray{} - } else if len(value) == 0 { - *dst = DateArray{Valid: true} - } else { - *dst = DateArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = DateArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for DateArray", src) - } - if elementsLength == 0 { - *dst = DateArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to DateArray", src) - } - - *dst = DateArray{ - Elements: make([]Date, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Date, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to DateArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *DateArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to DateArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in DateArray", err) - } - index++ - - return index, nil -} - -func (dst DateArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *DateArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from DateArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from DateArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = DateArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Date - - if len(uta.Elements) > 0 { - elements = make([]Date, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Date - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = DateArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = DateArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = DateArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Date, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = DateArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("date"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "date") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *DateArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src DateArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/date_array_test.go b/pgtype/date_array_test.go deleted file mode 100644 index 39ee32ce..00000000 --- a/pgtype/date_array_test.go +++ /dev/null @@ -1,327 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestDateArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "date[]", []interface{}{ - &pgtype.DateArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.DateArray{}, - &pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {}, - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestDateArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.DateArray - }{ - { - source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - result: pgtype.DateArray{ - Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]time.Time)(nil)), - result: pgtype.DateArray{}, - }, - { - source: [][]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - result: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - result: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - result: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - result: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.DateArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestDateArrayAssignTo(t *testing.T) { - var timeSlice []time.Time - var timeSliceDim2 [][]time.Time - var timeSliceDim4 [][][][]time.Time - var timeArrayDim2 [2][1]time.Time - var timeArrayDim4 [2][1][1][3]time.Time - - simpleTests := []struct { - src pgtype.DateArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &timeSlice, - expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - }, - { - src: pgtype.DateArray{}, - dst: &timeSlice, - expected: (([]time.Time)(nil)), - }, - { - src: pgtype.DateArray{Valid: true}, - dst: &timeSlice, - expected: []time.Time{}, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &timeSliceDim2, - expected: [][]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &timeSliceDim4, - expected: [][][][]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &timeArrayDim2, - expected: [2][1]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &timeArrayDim4, - expected: [2][1][1][3]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.DateArray - dst interface{} - }{ - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &timeSlice, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &timeArrayDim2, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &timeSlice, - }, - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &timeArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/date_test.go b/pgtype/date_test.go index caccfc47..268759c1 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -1,123 +1,36 @@ package pgtype_test import ( - "reflect" "testing" "time" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestDateTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "date", []interface{}{ - &pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Date{}, - &pgtype.Date{Valid: true, InfinityModifier: pgtype.Infinity}, - &pgtype.Date{Valid: true, InfinityModifier: -pgtype.Infinity}, - }, func(a, b interface{}) bool { - at := a.(pgtype.Date) - bt := b.(pgtype.Date) +func isExpectedEqTime(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + at := a.(time.Time) + vt := v.(time.Time) - return at.Time.Equal(bt.Time) && at.Valid == bt.Valid && at.InfinityModifier == bt.InfinityModifier + return at.Equal(vt) + } +} + +func TestDateCodec(t *testing.T) { + testPgxCodec(t, "date", []PgxTranscodeTestCase{ + {time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC))}, + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC))}, + {time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC))}, + {pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Date), isExpectedEq(pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true})}, + {pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Date), isExpectedEq(pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + {pgtype.Date{}, new(pgtype.Date), isExpectedEq(pgtype.Date{})}, + {nil, new(*time.Time), isExpectedEq((*time.Time)(nil))}, }) } -func TestDateSet(t *testing.T) { - type _time time.Time - - successfulTests := []struct { - source interface{} - result pgtype.Date - }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}}, - {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - {source: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}}, - {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - {source: "1999-12-31", result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}}, - } - - for i, tt := range successfulTests { - var d pgtype.Date - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if d != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } -} - -func TestDateAssignTo(t *testing.T) { - var tim time.Time - var ptim *time.Time - - simpleTests := []struct { - src pgtype.Date - dst interface{} - expected interface{} - }{ - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - {src: pgtype.Date{Time: time.Time{}}, dst: &ptim, expected: ((*time.Time)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Date - dst interface{} - expected interface{} - }{ - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Date - dst interface{} - }{ - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Valid: true}, dst: &tim}, - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Valid: true}, dst: &tim}, - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, dst: &tim}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - func TestDateMarshalJSON(t *testing.T) { successfulTests := []struct { source pgtype.Date diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index bcc8828d..e48bec51 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -263,7 +263,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: BPCharOID}}) ci.RegisterDataType(DataType{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementCodec: ByteaCodec{}, ElementOID: ByteaOID}}) ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID}) - ci.RegisterDataType(DataType{Value: &DateArray{}, Name: "_date", OID: DateArrayOID}) + ci.RegisterDataType(DataType{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementCodec: DateCodec{}, ElementOID: DateOID}}) ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID}) @@ -295,7 +295,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) ci.RegisterDataType(DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) - ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID}) + ci.RegisterDataType(DataType{Name: "date", OID: DateOID, Codec: DateCodec{}}) // ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) @@ -807,6 +807,30 @@ func tryUnderlyingTypeScanPlan(dst interface{}) (plan *underlyingTypeScanPlan, n return nil, nil, false } +type WrappedScanPlanNextSetter interface { + SetNext(ScanPlan) + ScanPlan +} + +func tryWrapBuiltinTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) { + switch dst := dst.(type) { + case *time.Time: + return &wrapTimeScanPlan{}, (*timeWrapper)(dst), true + } + + return nil, nil, false +} + +type wrapTimeScanPlan struct { + next ScanPlan +} + +func (plan *wrapTimeScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapTimeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*timeWrapper)(dst.(*time.Time))) +} + type pointerEmptyInterfaceScanPlan struct { codec Codec } @@ -903,6 +927,13 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan } } + if wrapperPlan, nextValue, ok := tryWrapBuiltinTypeScanPlan(dst); ok { + if nextPlan := ci.PlanScan(oid, formatCode, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + if _, ok := dst.(*interface{}); ok { return &pointerEmptyInterfaceScanPlan{codec: dt.Codec} } @@ -1031,7 +1062,6 @@ func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) Enco wrapperPlan.SetNext(nextPlan) return wrapperPlan } - } } @@ -1107,33 +1137,35 @@ type WrappedEncodePlanNextSetter interface { } func tryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { - switch value.(type) { + switch value := value.(type) { case int8: - return &wrapInt8EncodePlan{}, int8Wrapper(value.(int8)), true + return &wrapInt8EncodePlan{}, int8Wrapper(value), true case int16: - return &wrapInt16EncodePlan{}, int16Wrapper(value.(int16)), true + return &wrapInt16EncodePlan{}, int16Wrapper(value), true case int32: - return &wrapInt32EncodePlan{}, int32Wrapper(value.(int32)), true + return &wrapInt32EncodePlan{}, int32Wrapper(value), true case int64: - return &wrapInt64EncodePlan{}, int64Wrapper(value.(int64)), true + return &wrapInt64EncodePlan{}, int64Wrapper(value), true case int: - return &wrapIntEncodePlan{}, intWrapper(value.(int)), true + return &wrapIntEncodePlan{}, intWrapper(value), true case uint8: - return &wrapUint8EncodePlan{}, uint8Wrapper(value.(uint8)), true + return &wrapUint8EncodePlan{}, uint8Wrapper(value), true case uint16: - return &wrapUint16EncodePlan{}, uint16Wrapper(value.(uint16)), true + return &wrapUint16EncodePlan{}, uint16Wrapper(value), true case uint32: - return &wrapUint32EncodePlan{}, uint32Wrapper(value.(uint32)), true + return &wrapUint32EncodePlan{}, uint32Wrapper(value), true case uint64: - return &wrapUint64EncodePlan{}, uint64Wrapper(value.(uint64)), true + return &wrapUint64EncodePlan{}, uint64Wrapper(value), true case uint: - return &wrapUintEncodePlan{}, uintWrapper(value.(uint)), true + return &wrapUintEncodePlan{}, uintWrapper(value), true case float32: - return &wrapFloat32EncodePlan{}, float32Wrapper(value.(float32)), true + return &wrapFloat32EncodePlan{}, float32Wrapper(value), true case float64: - return &wrapFloat64EncodePlan{}, float64Wrapper(value.(float64)), true + return &wrapFloat64EncodePlan{}, float64Wrapper(value), true case string: - return &wrapStringEncodePlan{}, stringWrapper(value.(string)), true + return &wrapStringEncodePlan{}, stringWrapper(value), true + case time.Time: + return &wrapTimeEncodePlan{}, timeWrapper(value), true } return nil, nil, false @@ -1269,6 +1301,16 @@ func (plan *wrapStringEncodePlan) Encode(value interface{}, buf []byte) (newBuf return plan.next.Encode(stringWrapper(value.(string)), buf) } +type wrapTimeEncodePlan struct { + next EncodePlan +} + +func (plan *wrapTimeEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapTimeEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(timeWrapper(value.(time.Time)), buf) +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. From ae9be0b99ed53d3db94444a994b954fdf3f52d26 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 11 Jan 2022 20:46:10 -0600 Subject: [PATCH 0832/1158] Replace EnumType with EnumCodec --- bench_test.go | 3 +- pgtype/enum_codec.go | 114 +++++++++++++++++++++++++++ pgtype/enum_type.go | 158 -------------------------------------- pgtype/enum_type_test.go | 148 ----------------------------------- pgtype/pgxtype/pgxtype.go | 31 +------- 5 files changed, 116 insertions(+), 338 deletions(-) create mode 100644 pgtype/enum_codec.go delete mode 100644 pgtype/enum_type.go delete mode 100644 pgtype/enum_type_test.go diff --git a/bench_test.go b/bench_test.go index 9b14b7d3..c49c87f6 100644 --- a/bench_test.go +++ b/bench_test.go @@ -918,8 +918,7 @@ func BenchmarkSelectManyRegisteredEnum(b *testing.B) { err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "color").Scan(&oid) require.NoError(b, err) - et := pgtype.NewEnumType("color", []string{"blue", "green", "orange"}) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "color", OID: oid}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Name: "color", OID: oid, Codec: &pgtype.EnumCodec{}}) b.ResetTimer() var x, y, z string diff --git a/pgtype/enum_codec.go b/pgtype/enum_codec.go new file mode 100644 index 00000000..9a37f1dd --- /dev/null +++ b/pgtype/enum_codec.go @@ -0,0 +1,114 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +// EnumCodec is a codec that caches the strings it decodes. If the same string is read multiple times only one copy is +// allocated. These strings are only garbage collected when the EnumCodec is garbage collected. EnumCodec can be used +// for any text type not only enums, but it should only be used when there are a small number of possible values. +type EnumCodec struct { + membersMap map[string]string // map to quickly lookup member and reuse string instead of allocating +} + +func (EnumCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (EnumCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (EnumCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch value.(type) { + case string: + return encodePlanTextCodecString{} + case []byte: + return encodePlanTextCodecByteSlice{} + case rune: + return encodePlanTextCodecRune{} + case fmt.Stringer: + return encodePlanTextCodecStringer{} + case TextValuer: + return encodePlanTextCodecTextValuer{} + } + } + + return nil +} + +func (c *EnumCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch target.(type) { + case *string: + return &scanPlanTextAnyToEnumString{codec: c} + case *[]byte: + return scanPlanAnyToNewByteSlice{} + case TextScanner: + return &scanPlanTextAnyToEnumTextScanner{codec: c} + case *rune: + return scanPlanTextAnyToRune{} + } + } + + return nil +} + +func (c *EnumCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(ci, oid, format, src) +} + +func (c *EnumCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + return c.lookupAndCacheString(src), nil +} + +// lookupAndCacheString looks for src in the members map. If it is not found it is added to the map. +func (c *EnumCodec) lookupAndCacheString(src []byte) string { + if c.membersMap == nil { + c.membersMap = make(map[string]string) + } + + if s, found := c.membersMap[string(src)]; found { + return s + } else { + c.membersMap[s] = s + return s + } +} + +type scanPlanTextAnyToEnumString struct { + codec *EnumCodec +} + +func (plan *scanPlanTextAnyToEnumString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p := (dst).(*string) + *p = plan.codec.lookupAndCacheString(src) + + return nil +} + +type scanPlanTextAnyToEnumTextScanner struct { + codec *EnumCodec +} + +func (plan *scanPlanTextAnyToEnumTextScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + return scanner.ScanText(Text{String: plan.codec.lookupAndCacheString(src), Valid: true}) +} diff --git a/pgtype/enum_type.go b/pgtype/enum_type.go deleted file mode 100644 index 73ee3823..00000000 --- a/pgtype/enum_type.go +++ /dev/null @@ -1,158 +0,0 @@ -package pgtype - -import "fmt" - -// EnumType represents a enum type. While it implements Value, this is only in service of its type conversion duties -// when registered as a data type in a ConnType. It should not be used directly as a Value. -type EnumType struct { - value string - valid bool - - typeName string // PostgreSQL type name - members []string // enum members - membersMap map[string]string // map to quickly lookup member and reuse string instead of allocating -} - -// NewEnumType initializes a new EnumType. It retains a read-only reference to members. members must not be changed. -func NewEnumType(typeName string, members []string) *EnumType { - et := &EnumType{typeName: typeName, members: members} - et.membersMap = make(map[string]string, len(members)) - for _, m := range members { - et.membersMap[m] = m - } - return et -} - -func (et *EnumType) NewTypeValue() Value { - return &EnumType{ - value: et.value, - valid: et.valid, - - typeName: et.typeName, - members: et.members, - membersMap: et.membersMap, - } -} - -func (et *EnumType) TypeName() string { - return et.typeName -} - -func (et *EnumType) Members() []string { - return et.members -} - -// Set assigns src to dst. Set purposely does not check that src is a member. This allows continued error free -// operation in the event the PostgreSQL enum type is modified during a connection. -func (dst *EnumType) Set(src interface{}) error { - if src == nil { - dst.valid = false - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case string: - dst.value = value - dst.valid = true - case *string: - if value == nil { - dst.valid = false - } else { - dst.value = *value - dst.valid = true - } - case []byte: - if value == nil { - dst.valid = false - } else { - dst.value = string(value) - dst.valid = true - } - default: - if originalSrc, ok := underlyingStringType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to enum %s", value, dst.typeName) - } - - return nil -} - -func (dst EnumType) Get() interface{} { - if !dst.valid { - return nil - } - return dst.value -} - -func (src *EnumType) AssignTo(dst interface{}) error { - if !src.valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *string: - *v = src.value - return nil - case *[]byte: - *v = make([]byte, len(src.value)) - copy(*v, src.value) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func (EnumType) PreferredResultFormat() int16 { - return TextFormatCode -} - -func (dst *EnumType) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - dst.valid = false - return nil - } - - // Lookup the string in membersMap to avoid an allocation. - if s, found := dst.membersMap[string(src)]; found { - dst.value = s - } else { - // If an enum type is modified after the initial connection it is possible to receive an unexpected value. - // Gracefully handle this situation. Purposely NOT modifying members and membersMap to allow for sharing members - // and membersMap between connections. - dst.value = string(src) - } - dst.valid = true - - return nil -} - -func (dst *EnumType) DecodeBinary(ci *ConnInfo, src []byte) error { - return dst.DecodeText(ci, src) -} - -func (EnumType) PreferredParamFormat() int16 { - return TextFormatCode -} - -func (src EnumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.valid { - return nil, nil - } - - return append(buf, src.value...), nil -} - -func (src EnumType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return src.EncodeText(ci, buf) -} diff --git a/pgtype/enum_type_test.go b/pgtype/enum_type_test.go deleted file mode 100644 index 903b742f..00000000 --- a/pgtype/enum_type_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package pgtype_test - -import ( - "bytes" - "context" - "testing" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func setupEnum(t *testing.T, conn *pgx.Conn) *pgtype.EnumType { - _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") - require.NoError(t, err) - - _, err = conn.Exec(context.Background(), "create type pgtype_enum_color as enum ('blue', 'green', 'purple');") - require.NoError(t, err) - - var oid uint32 - err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "pgtype_enum_color").Scan(&oid) - require.NoError(t, err) - - et := pgtype.NewEnumType("pgtype_enum_color", []string{"blue", "green", "purple"}) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "pgtype_enum_color", OID: oid}) - - return et -} - -func cleanupEnum(t *testing.T, conn *pgx.Conn) { - _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") - require.NoError(t, err) -} - -func TestEnumTypeTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - setupEnum(t, conn) - defer cleanupEnum(t, conn) - - var dst string - err := conn.QueryRow(context.Background(), "select $1::pgtype_enum_color", "blue").Scan(&dst) - require.NoError(t, err) - require.EqualValues(t, "blue", dst) -} - -func TestEnumTypeSet(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - enumType := setupEnum(t, conn) - defer cleanupEnum(t, conn) - - successfulTests := []struct { - source interface{} - result interface{} - }{ - {source: "blue", result: "blue"}, - {source: _string("green"), result: "green"}, - {source: (*string)(nil), result: nil}, - } - - for i, tt := range successfulTests { - err := enumType.Set(tt.source) - assert.NoErrorf(t, err, "%d", i) - assert.Equalf(t, tt.result, enumType.Get(), "%d", i) - } -} - -func TestEnumTypeAssignTo(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - enumType := setupEnum(t, conn) - defer cleanupEnum(t, conn) - - { - var s string - - err := enumType.Set("blue") - require.NoError(t, err) - - err = enumType.AssignTo(&s) - require.NoError(t, err) - - assert.EqualValues(t, "blue", s) - } - - { - var ps *string - - err := enumType.Set("blue") - require.NoError(t, err) - - err = enumType.AssignTo(&ps) - require.NoError(t, err) - - assert.EqualValues(t, "blue", *ps) - } - - { - var ps *string - - err := enumType.Set(nil) - require.NoError(t, err) - - err = enumType.AssignTo(&ps) - require.NoError(t, err) - - assert.EqualValues(t, (*string)(nil), ps) - } - - var buf []byte - bytesTests := []struct { - src interface{} - dst *[]byte - expected []byte - }{ - {src: "blue", dst: &buf, expected: []byte("blue")}, - {src: nil, dst: &buf, expected: nil}, - } - - for i, tt := range bytesTests { - err := enumType.Set(tt.src) - require.NoError(t, err, "%d", i) - - err = enumType.AssignTo(tt.dst) - require.NoError(t, err, "%d", i) - - if bytes.Compare(*tt.dst, tt.expected) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) - } - } - - { - var s string - - err := enumType.Set(nil) - require.NoError(t, err) - - err = enumType.AssignTo(&s) - require.Error(t, err) - } - -} diff --git a/pgtype/pgxtype/pgxtype.go b/pgtype/pgxtype/pgxtype.go index 4f2c5796..6b5068e2 100644 --- a/pgtype/pgxtype/pgxtype.go +++ b/pgtype/pgxtype/pgxtype.go @@ -65,11 +65,7 @@ func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeNa // } // return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil case "e": // enum - members, err := GetEnumMembers(ctx, conn, oid) - if err != nil { - return pgtype.DataType{}, err - } - return pgtype.DataType{Value: pgtype.NewEnumType(typeName, members), Name: typeName, OID: oid}, nil + return pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil default: return pgtype.DataType{}, errors.New("unknown typtype") } @@ -121,28 +117,3 @@ func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, // return fields, nil // } - -// GetEnumMembers gets the possible values of the enum by oid. -func GetEnumMembers(ctx context.Context, conn Querier, oid uint32) ([]string, error) { - members := []string{} - - rows, err := conn.Query(ctx, "select enumlabel from pg_enum where enumtypid=$1 order by enumsortorder", oid) - if err != nil { - return nil, err - } - - for rows.Next() { - var m string - err := rows.Scan(&m) - if err != nil { - return nil, err - } - members = append(members, m) - } - - if rows.Err() != nil { - return nil, rows.Err() - } - - return members, nil -} From f743007fb48218dacc967c194410a25bbdb246ce Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 11 Jan 2022 20:49:20 -0600 Subject: [PATCH 0833/1158] Restore array support to pgxtype.LoadDataType --- pgtype/pgxtype/pgxtype.go | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/pgtype/pgxtype/pgxtype.go b/pgtype/pgxtype/pgxtype.go index 6b5068e2..6436f01b 100644 --- a/pgtype/pgxtype/pgxtype.go +++ b/pgtype/pgxtype/pgxtype.go @@ -34,25 +34,20 @@ func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeNa switch typtype { case "b": // array - panic("TODO - restore array support") - // elementOID, err := GetArrayElementOID(ctx, conn, oid) - // if err != nil { - // return pgtype.DataType{}, err - // } + elementOID, err := GetArrayElementOID(ctx, conn, oid) + if err != nil { + return pgtype.DataType{}, err + } - // var element pgtype.ValueTranscoder - // if dt, ok := ci.DataTypeForOID(elementOID); ok { - // if element, ok = dt.Value.(pgtype.ValueTranscoder); !ok { - // return pgtype.DataType{}, errors.New("array element OID not registered as ValueTranscoder") - // } - // } + var elementCodec pgtype.Codec + if dt, ok := ci.DataTypeForOID(elementOID); ok { + if dt.Codec == nil { + return pgtype.DataType{}, errors.New("array element OID not registered with Codec") + } + elementCodec = dt.Codec + } - // newElement := func() pgtype.ValueTranscoder { - // return pgtype.NewValue(element).(pgtype.ValueTranscoder) - // } - - // at := pgtype.NewArrayType(typeName, elementOID, newElement) - // return pgtype.DataType{Value: at, Name: typeName, OID: oid}, nil + return pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementOID: elementOID, ElementCodec: elementCodec}}, nil case "c": // composite panic("TODO - restore composite support") // fields, err := GetCompositeFields(ctx, conn, oid) From ccc7cc2931b8f0b3df41e1f5c1e7322d708b6623 Mon Sep 17 00:00:00 2001 From: Oleg Lomaka Date: Tue, 4 Jan 2022 16:25:19 +0200 Subject: [PATCH 0834/1158] Assign Numeric to *big.Rat --- numeric.go | 26 ++++++++++++++++++++++++++ numeric_test.go | 13 +++++++++++++ 2 files changed, 39 insertions(+) diff --git a/numeric.go b/numeric.go index a939625b..cd057749 100644 --- a/numeric.go +++ b/numeric.go @@ -369,6 +369,12 @@ func (src *Numeric) AssignTo(dst interface{}) error { return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = normalizedInt.Uint64() + case *big.Rat: + rat, err := src.toBigRat() + if err != nil { + return err + } + v.Set(rat) default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) @@ -406,6 +412,26 @@ func (dst *Numeric) toBigInt() (*big.Int, error) { return num, nil } +func (dst *Numeric) toBigRat() (*big.Rat, error) { + if dst.NaN { + return nil, fmt.Errorf("%v is not a number", dst) + } else if dst.InfinityModifier == Infinity { + return nil, fmt.Errorf("%v is infinity", dst) + } else if dst.InfinityModifier == NegativeInfinity { + return nil, fmt.Errorf("%v is -infinity", dst) + } + + num := new(big.Rat).SetInt(dst.Int) + if dst.Exp > 0 { + mul := new(big.Int).Exp(big10, big.NewInt(int64(dst.Exp)), nil) + num.Mul(num, new(big.Rat).SetInt(mul)) + } else if dst.Exp < 0 { + mul := new(big.Int).Exp(big10, big.NewInt(int64(-dst.Exp)), nil) + num.Quo(num, new(big.Rat).SetInt(mul)) + } + return num, nil +} + func (src *Numeric) toFloat64() (float64, error) { if src.NaN { return math.NaN(), nil diff --git a/numeric_test.go b/numeric_test.go index 455c3ac3..83334a04 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -263,6 +263,7 @@ func TestNumericAssignTo(t *testing.T) { var f64 float64 var pf32 *float32 var pf64 *float64 + var br = new(big.Rat) simpleTests := []struct { src *pgtype.Numeric @@ -293,6 +294,9 @@ func TestNumericAssignTo(t *testing.T) { {src: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, dst: &f32, expected: float32(math.Inf(1))}, {src: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.NegativeInfinity}, dst: &f64, expected: math.Inf(-1)}, {src: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.NegativeInfinity}, dst: &f32, expected: float32(math.Inf(-1))}, + {src: &pgtype.Numeric{Int: big.NewInt(-1023), Exp: -2, Status: pgtype.Present}, dst: br, expected: big.NewRat(-1023, 100)}, + {src: &pgtype.Numeric{Int: big.NewInt(-1023), Exp: 2, Status: pgtype.Present}, dst: br, expected: big.NewRat(-102300, 1)}, + {src: &pgtype.Numeric{Int: big.NewInt(23), Exp: 0, Status: pgtype.Present}, dst: br, expected: big.NewRat(23, 1)}, } for i, tt := range simpleTests { @@ -317,6 +321,11 @@ func TestNumericAssignTo(t *testing.T) { } else if !nanExpected && dst != tt.expected { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } + case big.Rat: + if (&dstTyped).Cmp(tt.expected.(*big.Rat)) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", + i, tt.src, tt.expected, dst) + } default: if dst != tt.expected { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) @@ -356,6 +365,10 @@ func TestNumericAssignTo(t *testing.T) { {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui64}, {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui}, {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &i32}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: br}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Present, NaN: true}, dst: br}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, dst: br}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Present, InfinityModifier: pgtype.NegativeInfinity}, dst: br}, } for i, tt := range errorTests { From 05598d4ca64c9cf6ca3e8ca2e13ec3ab28bcdd9f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 15 Jan 2022 09:48:21 -0600 Subject: [PATCH 0835/1158] Convert inet and cidr to codec --- pgtype/builtin_wrappers.go | 41 +++ pgtype/cidr.go | 31 --- pgtype/cidr_array.go | 533 ------------------------------------- pgtype/cidr_array_test.go | 319 ---------------------- pgtype/inet.go | 355 ++++++++++++------------ pgtype/inet_array.go | 533 ------------------------------------- pgtype/inet_array_test.go | 319 ---------------------- pgtype/inet_test.go | 156 +++-------- pgtype/pgtype.go | 74 ++++- 9 files changed, 318 insertions(+), 2043 deletions(-) delete mode 100644 pgtype/cidr.go delete mode 100644 pgtype/cidr_array.go delete mode 100644 pgtype/cidr_array_test.go delete mode 100644 pgtype/inet_array.go delete mode 100644 pgtype/inet_array_test.go diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index b653260e..df955f18 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -3,6 +3,7 @@ package pgtype import ( "fmt" "math" + "net" "strconv" "time" ) @@ -340,3 +341,43 @@ func (w *timeWrapper) ScanDate(v Date) error { func (w timeWrapper) DateValue() (Date, error) { return Date{Time: time.Time(w), Valid: true}, nil } + +type netIPNetWrapper net.IPNet + +func (w *netIPNetWrapper) ScanInet(v Inet) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *net.IPNet") + } + + *w = (netIPNetWrapper)(*v.IPNet) + return nil +} + +func (w netIPNetWrapper) InetValue() (Inet, error) { + return Inet{IPNet: (*net.IPNet)(&w), Valid: true}, nil +} + +type netIPWrapper net.IP + +func (w *netIPWrapper) ScanInet(v Inet) error { + if !v.Valid { + *w = nil + return nil + } + + if oneCount, bitCount := v.IPNet.Mask.Size(); oneCount != bitCount { + return fmt.Errorf("cannot scan %v to *net.IP", v) + } + *w = netIPWrapper(v.IPNet.IP) + return nil +} + +func (w netIPWrapper) InetValue() (Inet, error) { + if w == nil { + return Inet{}, nil + } + + bitCount := len(w) * 8 + mask := net.CIDRMask(bitCount, bitCount) + return Inet{IPNet: &net.IPNet{Mask: mask, IP: net.IP(w)}, Valid: true}, nil +} diff --git a/pgtype/cidr.go b/pgtype/cidr.go deleted file mode 100644 index 2241ca1c..00000000 --- a/pgtype/cidr.go +++ /dev/null @@ -1,31 +0,0 @@ -package pgtype - -type CIDR Inet - -func (dst *CIDR) Set(src interface{}) error { - return (*Inet)(dst).Set(src) -} - -func (dst CIDR) Get() interface{} { - return (Inet)(dst).Get() -} - -func (src *CIDR) AssignTo(dst interface{}) error { - return (*Inet)(src).AssignTo(dst) -} - -func (dst *CIDR) DecodeText(ci *ConnInfo, src []byte) error { - return (*Inet)(dst).DecodeText(ci, src) -} - -func (dst *CIDR) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Inet)(dst).DecodeBinary(ci, src) -} - -func (src CIDR) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Inet)(src).EncodeText(ci, buf) -} - -func (src CIDR) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Inet)(src).EncodeBinary(ci, buf) -} diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go deleted file mode 100644 index 48a6a4c1..00000000 --- a/pgtype/cidr_array.go +++ /dev/null @@ -1,533 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "net" - "reflect" - - "github.com/jackc/pgio" -) - -type CIDRArray struct { - Elements []CIDR - Dimensions []ArrayDimension - Valid bool -} - -func (dst *CIDRArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = CIDRArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []*net.IPNet: - if value == nil { - *dst = CIDRArray{} - } else if len(value) == 0 { - *dst = CIDRArray{Valid: true} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []net.IP: - if value == nil { - *dst = CIDRArray{} - } else if len(value) == 0 { - *dst = CIDRArray{Valid: true} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*net.IP: - if value == nil { - *dst = CIDRArray{} - } else if len(value) == 0 { - *dst = CIDRArray{Valid: true} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []CIDR: - if value == nil { - *dst = CIDRArray{} - } else if len(value) == 0 { - *dst = CIDRArray{Valid: true} - } else { - *dst = CIDRArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = CIDRArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for CIDRArray", src) - } - if elementsLength == 0 { - *dst = CIDRArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to CIDRArray", src) - } - - *dst = CIDRArray{ - Elements: make([]CIDR, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]CIDR, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to CIDRArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *CIDRArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to CIDRArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in CIDRArray", err) - } - index++ - - return index, nil -} - -func (dst CIDRArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *CIDRArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.IP: - *v = make([]*net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from CIDRArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from CIDRArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = CIDRArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []CIDR - - if len(uta.Elements) > 0 { - elements = make([]CIDR, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem CIDR - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = CIDRArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = CIDRArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = CIDRArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]CIDR, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = CIDRArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("cidr"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "cidr") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *CIDRArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src CIDRArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/cidr_array_test.go b/pgtype/cidr_array_test.go deleted file mode 100644 index 550bf9d1..00000000 --- a/pgtype/cidr_array_test.go +++ /dev/null @@ -1,319 +0,0 @@ -package pgtype_test - -import ( - "net" - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestCIDRArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "cidr[]", []interface{}{ - &pgtype.CIDRArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.CIDRArray{}, - &pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, - {}, - {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestCIDRArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.CIDRArray - }{ - { - source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]*net.IPNet)(nil)), - result: pgtype.CIDRArray{}, - }, - { - source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]net.IP)(nil)), - result: pgtype.CIDRArray{}, - }, - { - source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.CIDRArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestCIDRArrayAssignTo(t *testing.T) { - var ipnetSlice []*net.IPNet - var ipSlice []net.IP - var ipSliceDim2 [][]net.IP - var ipnetSliceDim4 [][][][]*net.IPNet - var ipArrayDim2 [2][1]net.IP - var ipnetArrayDim4 [2][1][1][3]*net.IPNet - - simpleTests := []struct { - src pgtype.CIDRArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{nil}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipSlice, - expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipSlice, - expected: []net.IP{nil}, - }, - { - src: pgtype.CIDRArray{}, - dst: &ipnetSlice, - expected: (([]*net.IPNet)(nil)), - }, - { - src: pgtype.CIDRArray{Valid: true}, - dst: &ipnetSlice, - expected: []*net.IPNet{}, - }, - { - src: pgtype.CIDRArray{}, - dst: &ipSlice, - expected: (([]net.IP)(nil)), - }, - { - src: pgtype.CIDRArray{Valid: true}, - dst: &ipSlice, - expected: []net.IP{}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &ipSliceDim2, - expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &ipnetSliceDim4, - expected: [][][][]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &ipArrayDim2, - expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &ipnetArrayDim4, - expected: [2][1][1][3]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/inet.go b/pgtype/inet.go index 4b3217a9..f88d1712 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -14,119 +14,208 @@ const ( defaultAFInet6 = 3 ) +type InetScanner interface { + ScanInet(v Inet) error +} + +type InetValuer interface { + InetValue() (Inet, error) +} + // Inet represents both inet and cidr PostgreSQL types. type Inet struct { IPNet *net.IPNet Valid bool } -func (dst *Inet) Set(src interface{}) error { +func (inet *Inet) ScanInet(v Inet) error { + *inet = v + return nil +} + +func (inet Inet) InetValue() (Inet, error) { + return inet, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Inet) Scan(src interface{}) error { if src == nil { *dst = Inet{} return nil } - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } + switch src := src.(type) { + case string: + return scanPlanTextAnyToInetScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) } - switch value := src.(type) { - case net.IPNet: - *dst = Inet{IPNet: &value, Valid: true} - case net.IP: - if len(value) == 0 { - *dst = Inet{} - } else { - bitCount := len(value) * 8 - mask := net.CIDRMask(bitCount, bitCount) - *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Valid: true} - } - case string: - ip, ipnet, err := net.ParseCIDR(value) - if err != nil { - ip = net.ParseIP(value) - if ip == nil { - return fmt.Errorf("unable to parse inet address: %s", value) - } - ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} - if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 - ipnet.Mask = net.CIDRMask(32, 32) - } - } - ipnet.IP = ip - *dst = Inet{IPNet: ipnet, Valid: true} - case *net.IPNet: - if value == nil { - *dst = Inet{} - } else { - return dst.Set(*value) - } - case *net.IP: - if value == nil { - *dst = Inet{} - } else { - return dst.Set(*value) - } - case *string: - if value == nil { - *dst = Inet{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingPtrType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Inet", value) + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Inet) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := InetCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type InetCodec struct{} + +func (InetCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (InetCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (InetCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(InetValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanInetCodecBinary{} + case TextFormatCode: + return encodePlanInetCodecText{} } return nil } -func (dst Inet) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.IPNet -} +type encodePlanInetCodecBinary struct{} -func (src *Inet) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) +func (encodePlanInetCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + inet, err := value.(InetValuer).InetValue() + if err != nil { + return nil, err } - switch v := dst.(type) { - case *net.IPNet: - *v = net.IPNet{ - IP: make(net.IP, len(src.IPNet.IP)), - Mask: make(net.IPMask, len(src.IPNet.Mask)), - } - copy(v.IP, src.IPNet.IP) - copy(v.Mask, src.IPNet.Mask) - return nil - case *net.IP: - if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = make(net.IP, len(src.IPNet.IP)) - copy(*v, src.IPNet.IP) - return nil + if !inet.Valid { + return nil, nil + } + + var family byte + switch len(inet.IPNet.IP) { + case net.IPv4len: + family = defaultAFInet + case net.IPv6len: + family = defaultAFInet6 default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) + return nil, fmt.Errorf("Unexpected IP length: %v", len(inet.IPNet.IP)) } + + buf = append(buf, family) + + ones, _ := inet.IPNet.Mask.Size() + buf = append(buf, byte(ones)) + + // is_cidr is ignored on server + buf = append(buf, 0) + + buf = append(buf, byte(len(inet.IPNet.IP))) + + return append(buf, inet.IPNet.IP...), nil } -func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { +type encodePlanInetCodecText struct{} + +func (encodePlanInetCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + inet, err := value.(InetValuer).InetValue() + if err != nil { + return nil, err + } + + if !inet.Valid { + return nil, nil + } + + return append(buf, inet.IPNet.String()...), nil +} + +func (InetCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case InetScanner: + return scanPlanBinaryInetToInetScanner{} + } + case TextFormatCode: + switch target.(type) { + case InetScanner: + return scanPlanTextAnyToInetScanner{} + } + } + + return nil +} + +func (c InetCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c InetCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Inet{} - return nil + return nil, nil + } + + var inet Inet + err := codecScan(c, ci, oid, format, src, &inet) + if err != nil { + return nil, err + } + + if !inet.Valid { + return nil, nil + } + + return inet.IPNet, nil +} + +type scanPlanBinaryInetToInetScanner struct{} + +func (scanPlanBinaryInetToInetScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(InetScanner) + + if src == nil { + return scanner.ScanInet(Inet{}) + } + + if len(src) != 8 && len(src) != 20 { + return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) + } + + // ignore family + bits := src[1] + // ignore is_cidr + addressLength := src[3] + + var ipnet net.IPNet + ipnet.IP = make(net.IP, int(addressLength)) + copy(ipnet.IP, src[4:]) + if ipv4 := ipnet.IP.To4(); ipv4 != nil { + ipnet.IP = ipv4 + } + ipnet.Mask = net.CIDRMask(int(bits), len(ipnet.IP)*8) + + return scanner.ScanInet(Inet{IPNet: &ipnet, Valid: true}) +} + +type scanPlanTextAnyToInetScanner struct{} + +func (scanPlanTextAnyToInetScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(InetScanner) + + if src == nil { + return scanner.ScanInet(Inet{}) } var ipnet *net.IPNet @@ -151,95 +240,5 @@ func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { *ipnet = net.IPNet{IP: ip, Mask: net.CIDRMask(ones, len(ip)*8)} } - *dst = Inet{IPNet: ipnet, Valid: true} - return nil -} - -func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Inet{} - return nil - } - - if len(src) != 8 && len(src) != 20 { - return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) - } - - // ignore family - bits := src[1] - // ignore is_cidr - addressLength := src[3] - - var ipnet net.IPNet - ipnet.IP = make(net.IP, int(addressLength)) - copy(ipnet.IP, src[4:]) - if ipv4 := ipnet.IP.To4(); ipv4 != nil { - ipnet.IP = ipv4 - } - ipnet.Mask = net.CIDRMask(int(bits), len(ipnet.IP)*8) - - *dst = Inet{IPNet: &ipnet, Valid: true} - - return nil -} - -func (src Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, src.IPNet.String()...), nil -} - -// EncodeBinary encodes src into w. -func (src Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - var family byte - switch len(src.IPNet.IP) { - case net.IPv4len: - family = defaultAFInet - case net.IPv6len: - family = defaultAFInet6 - default: - return nil, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) - } - - buf = append(buf, family) - - ones, _ := src.IPNet.Mask.Size() - buf = append(buf, byte(ones)) - - // is_cidr is ignored on server - buf = append(buf, 0) - - buf = append(buf, byte(len(src.IPNet.IP))) - - return append(buf, src.IPNet.IP...), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Inet) Scan(src interface{}) error { - if src == nil { - *dst = Inet{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Inet) Value() (driver.Value, error) { - return EncodeValueText(src) + return scanner.ScanInet(Inet{IPNet: ipnet, Valid: true}) } diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go deleted file mode 100644 index 7f41c4e5..00000000 --- a/pgtype/inet_array.go +++ /dev/null @@ -1,533 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "net" - "reflect" - - "github.com/jackc/pgio" -) - -type InetArray struct { - Elements []Inet - Dimensions []ArrayDimension - Valid bool -} - -func (dst *InetArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = InetArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []*net.IPNet: - if value == nil { - *dst = InetArray{} - } else if len(value) == 0 { - *dst = InetArray{Valid: true} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []net.IP: - if value == nil { - *dst = InetArray{} - } else if len(value) == 0 { - *dst = InetArray{Valid: true} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*net.IP: - if value == nil { - *dst = InetArray{} - } else if len(value) == 0 { - *dst = InetArray{Valid: true} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Inet: - if value == nil { - *dst = InetArray{} - } else if len(value) == 0 { - *dst = InetArray{Valid: true} - } else { - *dst = InetArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = InetArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for InetArray", src) - } - if elementsLength == 0 { - *dst = InetArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to InetArray", src) - } - - *dst = InetArray{ - Elements: make([]Inet, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Inet, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to InetArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *InetArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to InetArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in InetArray", err) - } - index++ - - return index, nil -} - -func (dst InetArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *InetArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.IP: - *v = make([]*net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from InetArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from InetArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = InetArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Inet - - if len(uta.Elements) > 0 { - elements = make([]Inet, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Inet - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = InetArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = InetArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = InetArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Inet, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = InetArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("inet"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "inet") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *InetArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src InetArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/inet_array_test.go b/pgtype/inet_array_test.go deleted file mode 100644 index da7ee975..00000000 --- a/pgtype/inet_array_test.go +++ /dev/null @@ -1,319 +0,0 @@ -package pgtype_test - -import ( - "net" - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestInetArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "inet[]", []interface{}{ - &pgtype.InetArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.InetArray{}, - &pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, - {}, - {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestInetArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.InetArray - }{ - { - source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]*net.IPNet)(nil)), - result: pgtype.InetArray{}, - }, - { - source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]net.IP)(nil)), - result: pgtype.InetArray{}, - }, - { - source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.InetArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInetArrayAssignTo(t *testing.T) { - var ipnetSlice []*net.IPNet - var ipSlice []net.IP - var ipSliceDim2 [][]net.IP - var ipnetSliceDim4 [][][][]*net.IPNet - var ipArrayDim2 [2][1]net.IP - var ipnetArrayDim4 [2][1][1][3]*net.IPNet - - simpleTests := []struct { - src pgtype.InetArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{nil}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipSlice, - expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipSlice, - expected: []net.IP{nil}, - }, - { - src: pgtype.InetArray{}, - dst: &ipnetSlice, - expected: (([]*net.IPNet)(nil)), - }, - { - src: pgtype.InetArray{Valid: true}, - dst: &ipnetSlice, - expected: []*net.IPNet{}, - }, - { - src: pgtype.InetArray{}, - dst: &ipSlice, - expected: (([]net.IP)(nil)), - }, - { - src: pgtype.InetArray{Valid: true}, - dst: &ipSlice, - expected: []net.IP{}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &ipSliceDim2, - expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &ipnetSliceDim4, - expected: [][][][]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &ipArrayDim2, - expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &ipnetArrayDim4, - expected: [2][1][1][3]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index d4716479..4ead4672 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -2,138 +2,48 @@ package pgtype_test import ( "net" - "reflect" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/stretchr/testify/assert" ) +func isExpectedEqIPNet(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + ap := a.(*net.IPNet) + vp := v.(net.IPNet) + + return ap.IP.Equal(vp.IP) && ap.Mask.String() == vp.Mask.String() + } +} + func TestInetTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "inet", []interface{}{ - &pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "127.0.0.1/8"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "12.34.56.65/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "192.168.1.16/24"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "255.0.0.0/8"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "255.255.255.255/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "::1/64"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "::/0"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "::1/128"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), Valid: true}, - &pgtype.Inet{}, + testPgxCodec(t, "inet", []PgxTranscodeTestCase{ + {mustParseInet(t, "0.0.0.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "0.0.0.0/32"))}, + {mustParseInet(t, "127.0.0.1/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "127.0.0.1/8"))}, + {mustParseInet(t, "12.34.56.65/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "12.34.56.65/32"))}, + {mustParseInet(t, "192.168.1.16/24"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "192.168.1.16/24"))}, + {mustParseInet(t, "255.0.0.0/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "255.0.0.0/8"))}, + {mustParseInet(t, "255.255.255.255/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "255.255.255.255/32"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e"))}, + {mustParseInet(t, "::1/64"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/64"))}, + {mustParseInet(t, "::/0"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/0"))}, + {mustParseInet(t, "::1/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/128"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e/64"))}, + {nil, new(pgtype.Inet), isExpectedEq(pgtype.Inet{})}, }) } func TestCidrTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "cidr", []interface{}{ - &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, - &pgtype.Inet{}, + testPgxCodec(t, "cidr", []PgxTranscodeTestCase{ + {mustParseInet(t, "0.0.0.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "0.0.0.0/32"))}, + {mustParseInet(t, "127.0.0.1/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "127.0.0.1/32"))}, + {mustParseInet(t, "12.34.56.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "12.34.56.0/32"))}, + {mustParseInet(t, "192.168.1.0/24"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "192.168.1.0/24"))}, + {mustParseInet(t, "255.0.0.0/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "255.0.0.0/8"))}, + {mustParseInet(t, "::/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/128"))}, + {mustParseInet(t, "::/0"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/0"))}, + {mustParseInet(t, "::1/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/128"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e/128"))}, + {nil, new(pgtype.Inet), isExpectedEq(pgtype.Inet{})}, }) } - -func TestInetSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Inet - }{ - {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4"), Mask: net.CIDRMask(24, 32)}, Valid: true}}, - {source: "10.0.0.1", result: pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Valid: true}}, - {source: "2607:f8b0:4009:80b::200e", result: pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Valid: true}}, - {source: net.ParseIP(""), result: pgtype.Inet{}}, - } - - for i, tt := range successfulTests { - var r pgtype.Inet - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - continue - } - - assert.Equalf(t, tt.result.Valid, r.Valid, "%d: Status", i) - if tt.result.Valid { - assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: IP", i) - assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: Mask", i) - } - } -} - -func TestInetAssignTo(t *testing.T) { - var ipnet net.IPNet - var pipnet *net.IPNet - var ip net.IP - var pip *net.IP - - simpleTests := []struct { - src pgtype.Inet - dst interface{} - expected interface{} - }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, - {src: pgtype.Inet{}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, - {src: pgtype.Inet{}, dst: &pip, expected: ((*net.IP)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %#v, but result was %#v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Inet - dst interface{} - expected interface{} - }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Inet - dst interface{} - }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Valid: true}, dst: &ip}, - {src: pgtype.Inet{}, dst: &ipnet}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index e48bec51..89c7b348 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -262,11 +262,11 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementCodec: BoolCodec{}, ElementOID: BoolOID}}) ci.RegisterDataType(DataType{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: BPCharOID}}) ci.RegisterDataType(DataType{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementCodec: ByteaCodec{}, ElementOID: ByteaOID}}) - ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID}) + ci.RegisterDataType(DataType{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementCodec: InetCodec{}, ElementOID: CIDROID}}) ci.RegisterDataType(DataType{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementCodec: DateCodec{}, ElementOID: DateOID}}) ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) - ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID}) + ci.RegisterDataType(DataType{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementCodec: InetCodec{}, ElementOID: InetOID}}) ci.RegisterDataType(DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}}) ci.RegisterDataType(DataType{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementCodec: Int4Codec{}, ElementOID: Int4OID}}) ci.RegisterDataType(DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementCodec: Int8Codec{}, ElementOID: Int8OID}}) @@ -293,13 +293,13 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) ci.RegisterDataType(DataType{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) - ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) + ci.RegisterDataType(DataType{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) ci.RegisterDataType(DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) ci.RegisterDataType(DataType{Name: "date", OID: DateOID, Codec: DateCodec{}}) // ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) - ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID}) + ci.RegisterDataType(DataType{Name: "inet", OID: InetOID, Codec: InetCodec{}}) ci.RegisterDataType(DataType{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) ci.RegisterDataType(DataType{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) // ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) @@ -336,15 +336,26 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { + // T ci.RegisterDefaultPgType(value, name) - valueType := reflect.TypeOf(value) + // *T + valueType := reflect.TypeOf(value) ci.RegisterDefaultPgType(reflect.New(valueType).Interface(), name) + // []T sliceType := reflect.SliceOf(valueType) ci.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName) + // *[]T ci.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName) + + // []*T + sliceOfPointerType := reflect.SliceOf(reflect.TypeOf(reflect.New(valueType).Interface())) + ci.RegisterDefaultPgType(reflect.MakeSlice(sliceOfPointerType, 0, 0).Interface(), arrayName) + + // *[]*T + ci.RegisterDefaultPgType(reflect.New(sliceOfPointerType).Interface(), arrayName) } // Integer types that directly map to a PostgreSQL type @@ -368,8 +379,7 @@ func NewConnInfo() *ConnInfo { registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil)) registerDefaultPgTypeVariants("inet", "_inet", net.IP{}) - ci.RegisterDefaultPgType((*net.IPNet)(nil), "cidr") - ci.RegisterDefaultPgType([]*net.IPNet(nil), "_cidr") + registerDefaultPgTypeVariants("cidr", "_cidr", net.IPNet{}) return ci } @@ -816,6 +826,10 @@ func tryWrapBuiltinTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter switch dst := dst.(type) { case *time.Time: return &wrapTimeScanPlan{}, (*timeWrapper)(dst), true + case *net.IPNet: + return &wrapNetIPNetScanPlan{}, (*netIPNetWrapper)(dst), true + case *net.IP: + return &wrapNetIPScanPlan{}, (*netIPWrapper)(dst), true } return nil, nil, false @@ -831,6 +845,26 @@ func (plan *wrapTimeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, s return plan.next.Scan(ci, oid, formatCode, src, (*timeWrapper)(dst.(*time.Time))) } +type wrapNetIPNetScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetIPNetScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetIPNetScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*netIPNetWrapper)(dst.(*net.IPNet))) +} + +type wrapNetIPScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetIPScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetIPScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*netIPWrapper)(dst.(*net.IP))) +} + type pointerEmptyInterfaceScanPlan struct { codec Codec } @@ -901,6 +935,7 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan if oid == 0 { if dataType, ok := ci.DataTypeForValue(dst); ok { dt = dataType + oid = dt.OID // Preserve assumed OID in case we are recursively called below. } } else { if dataType, ok := ci.DataTypeForOID(oid); ok { @@ -1031,6 +1066,7 @@ func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) Enco if oid == 0 { if dataType, ok := ci.DataTypeForValue(value); ok { dt = dataType + oid = dt.OID // Preserve assumed OID in case we are recursively called below. } } else { if dataType, ok := ci.DataTypeForOID(oid); ok { @@ -1166,6 +1202,10 @@ func tryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNext return &wrapStringEncodePlan{}, stringWrapper(value), true case time.Time: return &wrapTimeEncodePlan{}, timeWrapper(value), true + case net.IPNet: + return &wrapNetIPNetEncodePlan{}, netIPNetWrapper(value), true + case net.IP: + return &wrapNetIPEncodePlan{}, netIPWrapper(value), true } return nil, nil, false @@ -1311,6 +1351,26 @@ func (plan *wrapTimeEncodePlan) Encode(value interface{}, buf []byte) (newBuf [] return plan.next.Encode(timeWrapper(value.(time.Time)), buf) } +type wrapNetIPNetEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetIPNetEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetIPNetEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netIPNetWrapper(value.(net.IPNet)), buf) +} + +type wrapNetIPEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetIPEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetIPEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netIPWrapper(value.(net.IP)), buf) +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. From 313254c75d8aaf2abdac0e8520797fd22b36c690 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 15 Jan 2022 11:12:06 -0600 Subject: [PATCH 0836/1158] Convert float4 and float8 to Codec --- pgtype/float4.go | 397 ++++++++++++++-------------- pgtype/float4_array.go | 504 ------------------------------------ pgtype/float4_array_test.go | 282 -------------------- pgtype/float4_test.go | 149 +---------- pgtype/float8.go | 445 +++++++++++++++++-------------- pgtype/float8_array.go | 504 ------------------------------------ pgtype/float8_array_test.go | 258 ------------------ pgtype/float8_test.go | 149 +---------- pgtype/pgtype.go | 164 +++++++++++- pgtype/zeronull/float8.go | 72 ++---- 10 files changed, 644 insertions(+), 2280 deletions(-) delete mode 100644 pgtype/float4_array.go delete mode 100644 pgtype/float4_array_test.go delete mode 100644 pgtype/float8_array.go delete mode 100644 pgtype/float8_array_test.go diff --git a/pgtype/float4.go b/pgtype/float4.go index 36c46346..7699f656 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -15,198 +15,158 @@ type Float4 struct { Valid bool } -func (dst *Float4) Set(src interface{}) error { +// ScanFloat64 implements the Float64Scanner interface. +func (f *Float4) ScanFloat64(n Float8) error { + *f = Float4{Float: float32(n.Float), Valid: n.Valid} + return nil +} + +func (f Float4) Float64Value() (Float8, error) { + return Float8{Float: float64(f.Float), Valid: f.Valid}, nil +} + +// Scan implements the database/sql Scanner interface. +func (f *Float4) Scan(src interface{}) error { if src == nil { - *dst = Float4{} + *f = Float4{} return nil } - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case float32: - *dst = Float4{Float: value, Valid: true} + switch src := src.(type) { case float64: - *dst = Float4{Float: float32(value), Valid: true} - case int8: - *dst = Float4{Float: float32(value), Valid: true} - case uint8: - *dst = Float4{Float: float32(value), Valid: true} - case int16: - *dst = Float4{Float: float32(value), Valid: true} - case uint16: - *dst = Float4{Float: float32(value), Valid: true} - case int32: - f32 := float32(value) - if int32(f32) == value { - *dst = Float4{Float: f32, Valid: true} - } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) - } - case uint32: - f32 := float32(value) - if uint32(f32) == value { - *dst = Float4{Float: f32, Valid: true} - } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) - } - case int64: - f32 := float32(value) - if int64(f32) == value { - *dst = Float4{Float: f32, Valid: true} - } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) - } - case uint64: - f32 := float32(value) - if uint64(f32) == value { - *dst = Float4{Float: f32, Valid: true} - } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) - } - case int: - f32 := float32(value) - if int(f32) == value { - *dst = Float4{Float: f32, Valid: true} - } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) - } - case uint: - f32 := float32(value) - if uint(f32) == value { - *dst = Float4{Float: f32, Valid: true} - } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) - } + *f = Float4{Float: float32(src), Valid: true} + return nil case string: - num, err := strconv.ParseFloat(value, 32) + n, err := strconv.ParseFloat(string(src), 32) if err != nil { return err } - *dst = Float4{Float: float32(num), Valid: true} - case *float64: - if value == nil { - *dst = Float4{} - } else { - return dst.Set(*value) + *f = Float4{Float: float32(n), Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (f Float4) Value() (driver.Value, error) { + if !f.Valid { + return nil, nil + } + return float64(f.Float), nil +} + +type Float4Codec struct{} + +func (Float4Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Float4Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Float4Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case float32: + return encodePlanFloat4CodecBinaryFloat32{} + case Float64Valuer: + return encodePlanFloat4CodecBinaryFloat64Valuer{} + case Int64Valuer: + return encodePlanFloat4CodecBinaryInt64Valuer{} } - case *float32: - if value == nil { - *dst = Float4{} - } else { - return dst.Set(*value) + case TextFormatCode: + switch value.(type) { + case float32: + return encodePlanTextFloat32{} + case Float64Valuer: + return encodePlanTextFloat64Valuer{} + case Int64Valuer: + return encodePlanTextInt64Valuer{} } - case *int8: - if value == nil { - *dst = Float4{} - } else { - return dst.Set(*value) - } - case *uint8: - if value == nil { - *dst = Float4{} - } else { - return dst.Set(*value) - } - case *int16: - if value == nil { - *dst = Float4{} - } else { - return dst.Set(*value) - } - case *uint16: - if value == nil { - *dst = Float4{} - } else { - return dst.Set(*value) - } - case *int32: - if value == nil { - *dst = Float4{} - } else { - return dst.Set(*value) - } - case *uint32: - if value == nil { - *dst = Float4{} - } else { - return dst.Set(*value) - } - case *int64: - if value == nil { - *dst = Float4{} - } else { - return dst.Set(*value) - } - case *uint64: - if value == nil { - *dst = Float4{} - } else { - return dst.Set(*value) - } - case *int: - if value == nil { - *dst = Float4{} - } else { - return dst.Set(*value) - } - case *uint: - if value == nil { - *dst = Float4{} - } else { - return dst.Set(*value) - } - case *string: - if value == nil { - *dst = Float4{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Float8", value) } return nil } -func (dst Float4) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.Float +type encodePlanFloat4CodecBinaryFloat32 struct{} + +func (encodePlanFloat4CodecBinaryFloat32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(float32) + return pgio.AppendUint32(buf, math.Float32bits(n)), nil } -func (src *Float4) AssignTo(dst interface{}) error { - return float64AssignTo(float64(src.Float), src.Valid, dst) +type encodePlanTextFloat32 struct{} + +func (encodePlanTextFloat32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(float32) + return append(buf, strconv.FormatFloat(float64(n), 'f', -1, 32)...), nil } -func (dst *Float4) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Float4{} - return nil - } +type encodePlanFloat4CodecBinaryFloat64Valuer struct{} - n, err := strconv.ParseFloat(string(src), 32) +func (encodePlanFloat4CodecBinaryFloat64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() if err != nil { - return err + return nil, err + } + + if !n.Valid { + return nil, nil + } + + return pgio.AppendUint32(buf, math.Float32bits(float32(n.Float))), nil +} + +type encodePlanFloat4CodecBinaryInt64Valuer struct{} + +func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + f := float32(n.Int) + return pgio.AppendUint32(buf, math.Float32bits(f)), nil +} + +func (Float4Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *float32: + return scanPlanBinaryFloat4ToFloat32{} + case Float64Scanner: + return scanPlanBinaryFloat4ToFloat64Scanner{} + case Int64Scanner: + return scanPlanBinaryFloat4ToInt64Scanner{} + } + case TextFormatCode: + switch target.(type) { + case *float32: + return scanPlanTextAnyToFloat32{} + case Float64Scanner: + return scanPlanTextAnyToFloat64Scanner{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } } - *dst = Float4{Float: float32(n), Valid: true} return nil } -func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { +type scanPlanBinaryFloat4ToFloat32 struct{} + +func (scanPlanBinaryFloat4ToFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { - *dst = Float4{} - return nil + return fmt.Errorf("cannot scan null into %T", dst) } if len(src) != 4 { @@ -214,55 +174,92 @@ func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { } n := int32(binary.BigEndian.Uint32(src)) + f := (dst).(*float32) + *f = math.Float32frombits(uint32(n)) - *dst = Float4{Float: math.Float32frombits(uint32(n)), Valid: true} return nil } -func (src Float4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } +type scanPlanBinaryFloat4ToFloat64Scanner struct{} - buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 32)...) - return buf, nil -} +func (scanPlanBinaryFloat4ToFloat64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s := (dst).(Float64Scanner) -func (src Float4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendUint32(buf, math.Float32bits(src.Float)) - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Float4) Scan(src interface{}) error { if src == nil { - *dst = Float4{} - return nil + return s.ScanFloat64(Float8{}) } - switch src := src.(type) { - case float64: - *dst = Float4{Float: float32(src), Valid: true} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) } - return fmt.Errorf("cannot scan %T", src) + n := int32(binary.BigEndian.Uint32(src)) + return s.ScanFloat64(Float8{Float: float64(math.Float32frombits(uint32(n))), Valid: true}) } -// Value implements the database/sql/driver Valuer interface. -func (src Float4) Value() (driver.Value, error) { - if !src.Valid { +type scanPlanBinaryFloat4ToInt64Scanner struct{} + +func (scanPlanBinaryFloat4ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s := (dst).(Int64Scanner) + + if src == nil { + return s.ScanInt64(Int8{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) + } + + ui32 := int32(binary.BigEndian.Uint32(src)) + f32 := math.Float32frombits(uint32(ui32)) + i64 := int64(f32) + if f32 != float32(i64) { + return fmt.Errorf("cannot losslessly convert %v to int64", f32) + } + + return s.ScanInt64(Int8{Int: i64, Valid: true}) +} + +type scanPlanTextAnyToFloat32 struct{} + +func (scanPlanTextAnyToFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + n, err := strconv.ParseFloat(string(src), 32) + if err != nil { + return err + } + + f := (dst).(*float32) + *f = float32(n) + + return nil +} + +func (c Float4Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { return nil, nil } - return float64(src.Float), nil + + var n float64 + err := codecScan(c, ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +func (c Float4Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var n float32 + err := codecScan(c, ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil } diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go deleted file mode 100644 index dcf6c1f7..00000000 --- a/pgtype/float4_array.go +++ /dev/null @@ -1,504 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type Float4Array struct { - Elements []Float4 - Dimensions []ArrayDimension - Valid bool -} - -func (dst *Float4Array) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Float4Array{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []float32: - if value == nil { - *dst = Float4Array{} - } else if len(value) == 0 { - *dst = Float4Array{Valid: true} - } else { - elements := make([]Float4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*float32: - if value == nil { - *dst = Float4Array{} - } else if len(value) == 0 { - *dst = Float4Array{Valid: true} - } else { - elements := make([]Float4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Float4: - if value == nil { - *dst = Float4Array{} - } else if len(value) == 0 { - *dst = Float4Array{Valid: true} - } else { - *dst = Float4Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = Float4Array{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for Float4Array", src) - } - if elementsLength == 0 { - *dst = Float4Array{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Float4Array", src) - } - - *dst = Float4Array{ - Elements: make([]Float4, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Float4, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to Float4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *Float4Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to Float4Array") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in Float4Array", err) - } - index++ - - return index, nil -} - -func (dst Float4Array) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *Float4Array) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]float32: - *v = make([]float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float32: - *v = make([]*float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from Float4Array") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from Float4Array") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Float4Array{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Float4 - - if len(uta.Elements) > 0 { - elements = make([]Float4, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Float4 - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = Float4Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Float4Array{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = Float4Array{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Float4, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = Float4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("float4"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "float4") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Float4Array) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Float4Array) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/float4_array_test.go b/pgtype/float4_array_test.go deleted file mode 100644 index ecd65206..00000000 --- a/pgtype/float4_array_test.go +++ /dev/null @@ -1,282 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestFloat4ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "float4[]", []interface{}{ - &pgtype.Float4Array{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.Float4Array{}, - &pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Valid: true}, - {Float: 2, Valid: true}, - {Float: 3, Valid: true}, - {Float: 4, Valid: true}, - {}, - {Float: 6, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Valid: true}, - {Float: 2, Valid: true}, - {Float: 3, Valid: true}, - {Float: 4, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestFloat4ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Float4Array - }{ - { - source: []float32{1}, - result: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]float32)(nil)), - result: pgtype.Float4Array{}, - }, - { - source: [][]float32{{1}, {2}}, - result: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Valid: true}, - {Float: 2, Valid: true}, - {Float: 3, Valid: true}, - {Float: 4, Valid: true}, - {Float: 5, Valid: true}, - {Float: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]float32{{1}, {2}}, - result: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Valid: true}, - {Float: 2, Valid: true}, - {Float: 3, Valid: true}, - {Float: 4, Valid: true}, - {Float: 5, Valid: true}, - {Float: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Float4Array - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestFloat4ArrayAssignTo(t *testing.T) { - var float32Slice []float32 - var namedFloat32Slice _float32Slice - var float32SliceDim2 [][]float32 - var float32SliceDim4 [][][][]float32 - var float32ArrayDim2 [2][1]float32 - var float32ArrayDim4 [2][1][1][3]float32 - - simpleTests := []struct { - src pgtype.Float4Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1.23, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &float32Slice, - expected: []float32{1.23}, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1.23, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &namedFloat32Slice, - expected: _float32Slice{1.23}, - }, - { - src: pgtype.Float4Array{}, - dst: &float32Slice, - expected: (([]float32)(nil)), - }, - { - src: pgtype.Float4Array{Valid: true}, - dst: &float32Slice, - expected: []float32{}, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [][]float32{{1}, {2}}, - dst: &float32SliceDim2, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Valid: true}, - {Float: 2, Valid: true}, - {Float: 3, Valid: true}, - {Float: 4, Valid: true}, - {Float: 5, Valid: true}, - {Float: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &float32SliceDim4, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [2][1]float32{{1}, {2}}, - dst: &float32ArrayDim2, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Valid: true}, - {Float: 2, Valid: true}, - {Float: 3, Valid: true}, - {Float: 4, Valid: true}, - {Float: 5, Valid: true}, - {Float: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &float32ArrayDim4, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Float4Array - dst interface{} - }{ - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &float32Slice, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &float32ArrayDim2, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &float32Slice, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &float32ArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go index aab6e980..85b3b21d 100644 --- a/pgtype/float4_test.go +++ b/pgtype/float4_test.go @@ -1,149 +1,20 @@ package pgtype_test import ( - "reflect" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestFloat4Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "float4", []interface{}{ - &pgtype.Float4{Float: -1, Valid: true}, - &pgtype.Float4{Float: 0, Valid: true}, - &pgtype.Float4{Float: 0.00001, Valid: true}, - &pgtype.Float4{Float: 1, Valid: true}, - &pgtype.Float4{Float: 9999.99, Valid: true}, - &pgtype.Float4{Float: 0}, +func TestFloat4Codec(t *testing.T) { + testPgxCodec(t, "float4", []PgxTranscodeTestCase{ + {pgtype.Float4{Float: -1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float: -1, Valid: true})}, + {pgtype.Float4{Float: 0, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float: 0, Valid: true})}, + {pgtype.Float4{Float: 1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float: 1, Valid: true})}, + {float32(0.00001), new(float32), isExpectedEq(float32(0.00001))}, + {float32(9999.99), new(float32), isExpectedEq(float32(9999.99))}, + {pgtype.Float4{}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{})}, + {int64(1), new(int64), isExpectedEq(int64(1))}, + {nil, new(*float32), isExpectedEq((*float32)(nil))}, }) } - -func TestFloat4Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Float4 - }{ - {source: float32(1), result: pgtype.Float4{Float: 1, Valid: true}}, - {source: float64(1), result: pgtype.Float4{Float: 1, Valid: true}}, - {source: int8(1), result: pgtype.Float4{Float: 1, Valid: true}}, - {source: int16(1), result: pgtype.Float4{Float: 1, Valid: true}}, - {source: int32(1), result: pgtype.Float4{Float: 1, Valid: true}}, - {source: int64(1), result: pgtype.Float4{Float: 1, Valid: true}}, - {source: int8(-1), result: pgtype.Float4{Float: -1, Valid: true}}, - {source: int16(-1), result: pgtype.Float4{Float: -1, Valid: true}}, - {source: int32(-1), result: pgtype.Float4{Float: -1, Valid: true}}, - {source: int64(-1), result: pgtype.Float4{Float: -1, Valid: true}}, - {source: uint8(1), result: pgtype.Float4{Float: 1, Valid: true}}, - {source: uint16(1), result: pgtype.Float4{Float: 1, Valid: true}}, - {source: uint32(1), result: pgtype.Float4{Float: 1, Valid: true}}, - {source: uint64(1), result: pgtype.Float4{Float: 1, Valid: true}}, - {source: "1", result: pgtype.Float4{Float: 1, Valid: true}}, - {source: _int8(1), result: pgtype.Float4{Float: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.Float4 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestFloat4AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - var f32 float32 - var f64 float64 - var pf32 *float32 - var pf64 *float64 - - simpleTests := []struct { - src pgtype.Float4 - dst interface{} - expected interface{} - }{ - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &f32, expected: float32(42)}, - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &f64, expected: float64(42)}, - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &i16, expected: int16(42)}, - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &i32, expected: int32(42)}, - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &i64, expected: int64(42)}, - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &i, expected: int(42)}, - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui, expected: uint(42)}, - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Float4{Float: 0}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Float4{Float: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Float4 - dst interface{} - expected interface{} - }{ - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &pf32, expected: float32(42)}, - {src: pgtype.Float4{Float: 42, Valid: true}, dst: &pf64, expected: float64(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Float4 - dst interface{} - }{ - {src: pgtype.Float4{Float: 150, Valid: true}, dst: &i8}, - {src: pgtype.Float4{Float: 40000, Valid: true}, dst: &i16}, - {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui8}, - {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui16}, - {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui32}, - {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui64}, - {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui}, - {src: pgtype.Float4{Float: 0}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/float8.go b/pgtype/float8.go index 1038d283..86638ab1 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -10,178 +10,259 @@ import ( "github.com/jackc/pgio" ) +type Float64Scanner interface { + ScanFloat64(Float8) error +} + +type Float64Valuer interface { + Float64Value() (Float8, error) +} + type Float8 struct { Float float64 Valid bool } -func (dst *Float8) Set(src interface{}) error { +// ScanFloat64 implements the Float64Scanner interface. +func (f *Float8) ScanFloat64(n Float8) error { + *f = n + return nil +} + +func (f Float8) Float64Value() (Float8, error) { + return f, nil +} + +// Scan implements the database/sql Scanner interface. +func (f *Float8) Scan(src interface{}) error { if src == nil { - *dst = Float8{} + *f = Float8{} return nil } - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case float32: - *dst = Float8{Float: float64(value), Valid: true} + switch src := src.(type) { case float64: - *dst = Float8{Float: value, Valid: true} - case int8: - *dst = Float8{Float: float64(value), Valid: true} - case uint8: - *dst = Float8{Float: float64(value), Valid: true} - case int16: - *dst = Float8{Float: float64(value), Valid: true} - case uint16: - *dst = Float8{Float: float64(value), Valid: true} - case int32: - *dst = Float8{Float: float64(value), Valid: true} - case uint32: - *dst = Float8{Float: float64(value), Valid: true} - case int64: - f64 := float64(value) - if int64(f64) == value { - *dst = Float8{Float: f64, Valid: true} - } else { - return fmt.Errorf("%v cannot be exactly represented as float64", value) - } - case uint64: - f64 := float64(value) - if uint64(f64) == value { - *dst = Float8{Float: f64, Valid: true} - } else { - return fmt.Errorf("%v cannot be exactly represented as float64", value) - } - case int: - f64 := float64(value) - if int(f64) == value { - *dst = Float8{Float: f64, Valid: true} - } else { - return fmt.Errorf("%v cannot be exactly represented as float64", value) - } - case uint: - f64 := float64(value) - if uint(f64) == value { - *dst = Float8{Float: f64, Valid: true} - } else { - return fmt.Errorf("%v cannot be exactly represented as float64", value) - } + *f = Float8{Float: src, Valid: true} + return nil case string: - num, err := strconv.ParseFloat(value, 64) + n, err := strconv.ParseFloat(string(src), 64) if err != nil { return err } - *dst = Float8{Float: float64(num), Valid: true} - case *float64: - if value == nil { - *dst = Float8{} - } else { - return dst.Set(*value) + *f = Float8{Float: n, Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (f Float8) Value() (driver.Value, error) { + if !f.Valid { + return nil, nil + } + return f.Float, nil +} + +type Float8Codec struct{} + +func (Float8Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Float8Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Float8Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case float64: + return encodePlanFloat8CodecBinaryFloat64{} + case Float64Valuer: + return encodePlanFloat8CodecBinaryFloat64Valuer{} + case Int64Valuer: + return encodePlanFloat8CodecBinaryInt64Valuer{} } - case *float32: - if value == nil { - *dst = Float8{} - } else { - return dst.Set(*value) + case TextFormatCode: + switch value.(type) { + case float64: + return encodePlanTextFloat64{} + case Float64Valuer: + return encodePlanTextFloat64Valuer{} + case Int64Valuer: + return encodePlanTextInt64Valuer{} } - case *int8: - if value == nil { - *dst = Float8{} - } else { - return dst.Set(*value) - } - case *uint8: - if value == nil { - *dst = Float8{} - } else { - return dst.Set(*value) - } - case *int16: - if value == nil { - *dst = Float8{} - } else { - return dst.Set(*value) - } - case *uint16: - if value == nil { - *dst = Float8{} - } else { - return dst.Set(*value) - } - case *int32: - if value == nil { - *dst = Float8{} - } else { - return dst.Set(*value) - } - case *uint32: - if value == nil { - *dst = Float8{} - } else { - return dst.Set(*value) - } - case *int64: - if value == nil { - *dst = Float8{} - } else { - return dst.Set(*value) - } - case *uint64: - if value == nil { - *dst = Float8{} - } else { - return dst.Set(*value) - } - case *int: - if value == nil { - *dst = Float8{} - } else { - return dst.Set(*value) - } - case *uint: - if value == nil { - *dst = Float8{} - } else { - return dst.Set(*value) - } - case *string: - if value == nil { - *dst = Float8{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Float8", value) } return nil } -func (dst Float8) Get() interface{} { - if !dst.Valid { - return nil +type encodePlanFloat8CodecBinaryFloat64 struct{} + +func (encodePlanFloat8CodecBinaryFloat64) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(float64) + return pgio.AppendUint64(buf, math.Float64bits(n)), nil +} + +type encodePlanTextFloat64 struct{} + +func (encodePlanTextFloat64) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(float64) + return append(buf, strconv.FormatFloat(n, 'f', -1, 64)...), nil +} + +type encodePlanFloat8CodecBinaryFloat64Valuer struct{} + +func (encodePlanFloat8CodecBinaryFloat64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err } - return dst.Float + + if !n.Valid { + return nil, nil + } + + return pgio.AppendUint64(buf, math.Float64bits(n.Float)), nil } -func (src *Float8) AssignTo(dst interface{}) error { - return float64AssignTo(src.Float, src.Valid, dst) +type encodePlanTextFloat64Valuer struct{} + +func (encodePlanTextFloat64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + return append(buf, strconv.FormatFloat(n.Float, 'f', -1, 64)...), nil } -func (dst *Float8) DecodeText(ci *ConnInfo, src []byte) error { +type encodePlanFloat8CodecBinaryInt64Valuer struct{} + +func (encodePlanFloat8CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + f := float64(n.Int) + return pgio.AppendUint64(buf, math.Float64bits(f)), nil +} + +type encodePlanTextInt64Valuer struct{} + +func (encodePlanTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + return append(buf, strconv.FormatInt(n.Int, 10)...), nil +} + +func (Float8Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *float64: + return scanPlanBinaryFloat8ToFloat64{} + case Float64Scanner: + return scanPlanBinaryFloat8ToFloat64Scanner{} + case Int64Scanner: + return scanPlanBinaryFloat8ToInt64Scanner{} + } + case TextFormatCode: + switch target.(type) { + case *float64: + return scanPlanTextAnyToFloat64{} + case Float64Scanner: + return scanPlanTextAnyToFloat64Scanner{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +type scanPlanBinaryFloat8ToFloat64 struct{} + +func (scanPlanBinaryFloat8ToFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { - *dst = Float8{} - return nil + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for float8: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + f := (dst).(*float64) + *f = math.Float64frombits(uint64(n)) + + return nil +} + +type scanPlanBinaryFloat8ToFloat64Scanner struct{} + +func (scanPlanBinaryFloat8ToFloat64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s := (dst).(Float64Scanner) + + if src == nil { + return s.ScanFloat64(Float8{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for float8: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + return s.ScanFloat64(Float8{Float: math.Float64frombits(uint64(n)), Valid: true}) +} + +type scanPlanBinaryFloat8ToInt64Scanner struct{} + +func (scanPlanBinaryFloat8ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s := (dst).(Int64Scanner) + + if src == nil { + return s.ScanInt64(Int8{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for float8: %v", len(src)) + } + + ui64 := int64(binary.BigEndian.Uint64(src)) + f64 := math.Float64frombits(uint64(ui64)) + i64 := int64(f64) + if f64 != float64(i64) { + return fmt.Errorf("cannot losslessly convert %v to int64", f64) + } + + return s.ScanInt64(Int8{Int: i64, Valid: true}) +} + +type scanPlanTextAnyToFloat64 struct{} + +func (scanPlanTextAnyToFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) } n, err := strconv.ParseFloat(string(src), 64) @@ -189,70 +270,42 @@ func (dst *Float8) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Float8{Float: n, Valid: true} + f := (dst).(*float64) + *f = n + return nil } -func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { +type scanPlanTextAnyToFloat64Scanner struct{} + +func (scanPlanTextAnyToFloat64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s := (dst).(Float64Scanner) + if src == nil { - *dst = Float8{} - return nil + return s.ScanFloat64(Float8{}) } - if len(src) != 8 { - return fmt.Errorf("invalid length for float4: %v", len(src)) + n, err := strconv.ParseFloat(string(src), 64) + if err != nil { + return err } - n := int64(binary.BigEndian.Uint64(src)) - - *dst = Float8{Float: math.Float64frombits(uint64(n)), Valid: true} - return nil + return s.ScanFloat64(Float8{Float: n, Valid: true}) } -func (src Float8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 64)...) - return buf, nil +func (c Float8Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(ci, oid, format, src) } -func (src Float8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendUint64(buf, math.Float64bits(src.Float)) - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Float8) Scan(src interface{}) error { +func (c Float8Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Float8{} - return nil - } - - switch src := src.(type) { - case float64: - *dst = Float8{Float: src, Valid: true} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Float8) Value() (driver.Value, error) { - if !src.Valid { return nil, nil } - return src.Float, nil + + var n float64 + err := codecScan(c, ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil } diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go deleted file mode 100644 index 5e85e236..00000000 --- a/pgtype/float8_array.go +++ /dev/null @@ -1,504 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type Float8Array struct { - Elements []Float8 - Dimensions []ArrayDimension - Valid bool -} - -func (dst *Float8Array) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Float8Array{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []float64: - if value == nil { - *dst = Float8Array{} - } else if len(value) == 0 { - *dst = Float8Array{Valid: true} - } else { - elements := make([]Float8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*float64: - if value == nil { - *dst = Float8Array{} - } else if len(value) == 0 { - *dst = Float8Array{Valid: true} - } else { - elements := make([]Float8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Float8: - if value == nil { - *dst = Float8Array{} - } else if len(value) == 0 { - *dst = Float8Array{Valid: true} - } else { - *dst = Float8Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = Float8Array{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for Float8Array", src) - } - if elementsLength == 0 { - *dst = Float8Array{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Float8Array", src) - } - - *dst = Float8Array{ - Elements: make([]Float8, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Float8, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to Float8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *Float8Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to Float8Array") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in Float8Array", err) - } - index++ - - return index, nil -} - -func (dst Float8Array) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *Float8Array) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]float64: - *v = make([]float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float64: - *v = make([]*float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from Float8Array") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from Float8Array") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Float8Array{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Float8 - - if len(uta.Elements) > 0 { - elements = make([]Float8, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Float8 - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = Float8Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Float8Array{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = Float8Array{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Float8, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = Float8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("float8"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "float8") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Float8Array) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Float8Array) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/float8_array_test.go b/pgtype/float8_array_test.go deleted file mode 100644 index 66a10784..00000000 --- a/pgtype/float8_array_test.go +++ /dev/null @@ -1,258 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestFloat8ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "float8[]", []interface{}{ - &pgtype.Float8Array{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.Float8Array{}, - &pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Valid: true}, - {Float: 2, Valid: true}, - {Float: 3, Valid: true}, - {Float: 4, Valid: true}, - {}, - {Float: 6, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Valid: true}, - {Float: 2, Valid: true}, - {Float: 3, Valid: true}, - {Float: 4, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestFloat8ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Float8Array - }{ - { - source: []float64{1}, - result: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]float64)(nil)), - result: pgtype.Float8Array{}, - }, - { - source: [][]float64{{1}, {2}}, - result: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Valid: true}, - {Float: 2, Valid: true}, - {Float: 3, Valid: true}, - {Float: 4, Valid: true}, - {Float: 5, Valid: true}, - {Float: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Float8Array - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestFloat8ArrayAssignTo(t *testing.T) { - var float64Slice []float64 - var namedFloat64Slice _float64Slice - var float64SliceDim2 [][]float64 - var float64SliceDim4 [][][][]float64 - var float64ArrayDim2 [2][1]float64 - var float64ArrayDim4 [2][1][1][3]float64 - - simpleTests := []struct { - src pgtype.Float8Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1.23, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &float64Slice, - expected: []float64{1.23}, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1.23, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &namedFloat64Slice, - expected: _float64Slice{1.23}, - }, - { - src: pgtype.Float8Array{}, - dst: &float64Slice, - expected: (([]float64)(nil)), - }, - { - src: pgtype.Float8Array{Valid: true}, - dst: &float64Slice, - expected: []float64{}, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [][]float64{{1}, {2}}, - dst: &float64SliceDim2, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Valid: true}, - {Float: 2, Valid: true}, - {Float: 3, Valid: true}, - {Float: 4, Valid: true}, - {Float: 5, Valid: true}, - {Float: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &float64SliceDim4, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [2][1]float64{{1}, {2}}, - dst: &float64ArrayDim2, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Valid: true}, - {Float: 2, Valid: true}, - {Float: 3, Valid: true}, - {Float: 4, Valid: true}, - {Float: 5, Valid: true}, - {Float: 6, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [2][1][1][3]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - dst: &float64ArrayDim4, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Float8Array - dst interface{} - }{ - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &float64Slice, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &float64ArrayDim2, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &float64Slice, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &float64ArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go index e7bd4444..3c7660b8 100644 --- a/pgtype/float8_test.go +++ b/pgtype/float8_test.go @@ -1,149 +1,20 @@ package pgtype_test import ( - "reflect" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestFloat8Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ - &pgtype.Float8{Float: -1, Valid: true}, - &pgtype.Float8{Float: 0, Valid: true}, - &pgtype.Float8{Float: 0.00001, Valid: true}, - &pgtype.Float8{Float: 1, Valid: true}, - &pgtype.Float8{Float: 9999.99, Valid: true}, - &pgtype.Float8{Float: 0}, +func TestFloat8Codec(t *testing.T) { + testPgxCodec(t, "float8", []PgxTranscodeTestCase{ + {pgtype.Float8{Float: -1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float: -1, Valid: true})}, + {pgtype.Float8{Float: 0, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float: 0, Valid: true})}, + {pgtype.Float8{Float: 1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float: 1, Valid: true})}, + {float64(0.00001), new(float64), isExpectedEq(float64(0.00001))}, + {float64(9999.99), new(float64), isExpectedEq(float64(9999.99))}, + {pgtype.Float8{}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{})}, + {int64(1), new(int64), isExpectedEq(int64(1))}, + {nil, new(*float64), isExpectedEq((*float64)(nil))}, }) } - -func TestFloat8Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Float8 - }{ - {source: float32(1), result: pgtype.Float8{Float: 1, Valid: true}}, - {source: float64(1), result: pgtype.Float8{Float: 1, Valid: true}}, - {source: int8(1), result: pgtype.Float8{Float: 1, Valid: true}}, - {source: int16(1), result: pgtype.Float8{Float: 1, Valid: true}}, - {source: int32(1), result: pgtype.Float8{Float: 1, Valid: true}}, - {source: int64(1), result: pgtype.Float8{Float: 1, Valid: true}}, - {source: int8(-1), result: pgtype.Float8{Float: -1, Valid: true}}, - {source: int16(-1), result: pgtype.Float8{Float: -1, Valid: true}}, - {source: int32(-1), result: pgtype.Float8{Float: -1, Valid: true}}, - {source: int64(-1), result: pgtype.Float8{Float: -1, Valid: true}}, - {source: uint8(1), result: pgtype.Float8{Float: 1, Valid: true}}, - {source: uint16(1), result: pgtype.Float8{Float: 1, Valid: true}}, - {source: uint32(1), result: pgtype.Float8{Float: 1, Valid: true}}, - {source: uint64(1), result: pgtype.Float8{Float: 1, Valid: true}}, - {source: "1", result: pgtype.Float8{Float: 1, Valid: true}}, - {source: _int8(1), result: pgtype.Float8{Float: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.Float8 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestFloat8AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - var f32 float32 - var f64 float64 - var pf32 *float32 - var pf64 *float64 - - simpleTests := []struct { - src pgtype.Float8 - dst interface{} - expected interface{} - }{ - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &f32, expected: float32(42)}, - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &f64, expected: float64(42)}, - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &i16, expected: int16(42)}, - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &i32, expected: int32(42)}, - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &i64, expected: int64(42)}, - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &i, expected: int(42)}, - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui, expected: uint(42)}, - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Float8{Float: 0}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Float8{Float: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Float8 - dst interface{} - expected interface{} - }{ - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &pf32, expected: float32(42)}, - {src: pgtype.Float8{Float: 42, Valid: true}, dst: &pf64, expected: float64(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Float8 - dst interface{} - }{ - {src: pgtype.Float8{Float: 150, Valid: true}, dst: &i8}, - {src: pgtype.Float8{Float: 40000, Valid: true}, dst: &i16}, - {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui8}, - {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui16}, - {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui32}, - {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui64}, - {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui}, - {src: pgtype.Float8{Float: 0}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 89c7b348..cb351fbd 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -264,8 +264,8 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementCodec: ByteaCodec{}, ElementOID: ByteaOID}}) ci.RegisterDataType(DataType{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementCodec: InetCodec{}, ElementOID: CIDROID}}) ci.RegisterDataType(DataType{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementCodec: DateCodec{}, ElementOID: DateOID}}) - ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) - ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) + ci.RegisterDataType(DataType{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementCodec: Float4Codec{}, ElementOID: Float4OID}}) + ci.RegisterDataType(DataType{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementCodec: Float8Codec{}, ElementOID: Float8OID}}) ci.RegisterDataType(DataType{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementCodec: InetCodec{}, ElementOID: InetOID}}) ci.RegisterDataType(DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}}) ci.RegisterDataType(DataType{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementCodec: Int4Codec{}, ElementOID: Int4OID}}) @@ -297,8 +297,8 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) ci.RegisterDataType(DataType{Name: "date", OID: DateOID, Codec: DateCodec{}}) // ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) - ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) - ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) + ci.RegisterDataType(DataType{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) + ci.RegisterDataType(DataType{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) ci.RegisterDataType(DataType{Name: "inet", OID: InetOID, Codec: InetCodec{}}) ci.RegisterDataType(DataType{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) ci.RegisterDataType(DataType{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) @@ -824,6 +824,32 @@ type WrappedScanPlanNextSetter interface { func tryWrapBuiltinTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) { switch dst := dst.(type) { + case *int8: + return &wrapInt8ScanPlan{}, (*int8Wrapper)(dst), true + case *int16: + return &wrapInt16ScanPlan{}, (*int16Wrapper)(dst), true + case *int32: + return &wrapInt32ScanPlan{}, (*int32Wrapper)(dst), true + case *int64: + return &wrapInt64ScanPlan{}, (*int64Wrapper)(dst), true + case *int: + return &wrapIntScanPlan{}, (*intWrapper)(dst), true + case *uint8: + return &wrapUint8ScanPlan{}, (*uint8Wrapper)(dst), true + case *uint16: + return &wrapUint16ScanPlan{}, (*uint16Wrapper)(dst), true + case *uint32: + return &wrapUint32ScanPlan{}, (*uint32Wrapper)(dst), true + case *uint64: + return &wrapUint64ScanPlan{}, (*uint64Wrapper)(dst), true + case *uint: + return &wrapUintScanPlan{}, (*uintWrapper)(dst), true + case *float32: + return &wrapFloat32ScanPlan{}, (*float32Wrapper)(dst), true + case *float64: + return &wrapFloat64ScanPlan{}, (*float64Wrapper)(dst), true + case *string: + return &wrapStringScanPlan{}, (*stringWrapper)(dst), true case *time.Time: return &wrapTimeScanPlan{}, (*timeWrapper)(dst), true case *net.IPNet: @@ -835,6 +861,136 @@ func tryWrapBuiltinTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter return nil, nil, false } +type wrapInt8ScanPlan struct { + next ScanPlan +} + +func (plan *wrapInt8ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapInt8ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*int8Wrapper)(dst.(*int8))) +} + +type wrapInt16ScanPlan struct { + next ScanPlan +} + +func (plan *wrapInt16ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapInt16ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*int16Wrapper)(dst.(*int16))) +} + +type wrapInt32ScanPlan struct { + next ScanPlan +} + +func (plan *wrapInt32ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapInt32ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*int32Wrapper)(dst.(*int32))) +} + +type wrapInt64ScanPlan struct { + next ScanPlan +} + +func (plan *wrapInt64ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapInt64ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*int64Wrapper)(dst.(*int64))) +} + +type wrapIntScanPlan struct { + next ScanPlan +} + +func (plan *wrapIntScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapIntScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*intWrapper)(dst.(*int))) +} + +type wrapUint8ScanPlan struct { + next ScanPlan +} + +func (plan *wrapUint8ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUint8ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*uint8Wrapper)(dst.(*uint8))) +} + +type wrapUint16ScanPlan struct { + next ScanPlan +} + +func (plan *wrapUint16ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUint16ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*uint16Wrapper)(dst.(*uint16))) +} + +type wrapUint32ScanPlan struct { + next ScanPlan +} + +func (plan *wrapUint32ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUint32ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*uint32Wrapper)(dst.(*uint32))) +} + +type wrapUint64ScanPlan struct { + next ScanPlan +} + +func (plan *wrapUint64ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUint64ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*uint64Wrapper)(dst.(*uint64))) +} + +type wrapUintScanPlan struct { + next ScanPlan +} + +func (plan *wrapUintScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUintScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*uintWrapper)(dst.(*uint))) +} + +type wrapFloat32ScanPlan struct { + next ScanPlan +} + +func (plan *wrapFloat32ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapFloat32ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*float32Wrapper)(dst.(*float32))) +} + +type wrapFloat64ScanPlan struct { + next ScanPlan +} + +func (plan *wrapFloat64ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapFloat64ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*float64Wrapper)(dst.(*float64))) +} + +type wrapStringScanPlan struct { + next ScanPlan +} + +func (plan *wrapStringScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapStringScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*stringWrapper)(dst.(*string))) +} + type wrapTimeScanPlan struct { next ScanPlan } diff --git a/pgtype/zeronull/float8.go b/pgtype/zeronull/float8.go index a4efb1ed..1c29d331 100644 --- a/pgtype/zeronull/float8.go +++ b/pgtype/zeronull/float8.go @@ -8,68 +8,29 @@ import ( type Float8 float64 -func (dst *Float8) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Float8 - err := nullable.DecodeText(ci, src) - if err != nil { - return err +// ScanFloat64 implements the Float64Scanner interface. +func (f *Float8) ScanFloat64(n pgtype.Float8) error { + if !n.Valid { + *f = 0 + return nil } - if nullable.Valid { - *dst = Float8(nullable.Float) - } else { - *dst = 0 - } + *f = Float8(n.Float) return nil } -func (dst *Float8) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Float8 - err := nullable.DecodeBinary(ci, src) - if err != nil { - return err +func (f Float8) Float64Value() (pgtype.Float8, error) { + if f == 0 { + return pgtype.Float8{}, nil } - - if nullable.Valid { - *dst = Float8(nullable.Float) - } else { - *dst = 0 - } - - return nil -} - -func (src Float8) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if src == 0 { - return nil, nil - } - - nullable := pgtype.Float8{ - Float: float64(src), - Valid: true, - } - - return nullable.EncodeText(ci, buf) -} - -func (src Float8) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if src == 0 { - return nil, nil - } - - nullable := pgtype.Float8{ - Float: float64(src), - Valid: true, - } - - return nullable.EncodeBinary(ci, buf) + return pgtype.Float8{Float: float64(f), Valid: true}, nil } // Scan implements the database/sql Scanner interface. -func (dst *Float8) Scan(src interface{}) error { +func (f *Float8) Scan(src interface{}) error { if src == nil { - *dst = 0 + *f = 0 return nil } @@ -79,12 +40,15 @@ func (dst *Float8) Scan(src interface{}) error { return err } - *dst = Float8(nullable.Float) + *f = Float8(nullable.Float) return nil } // Value implements the database/sql/driver Valuer interface. -func (src Float8) Value() (driver.Value, error) { - return pgtype.EncodeValueText(src) +func (f Float8) Value() (driver.Value, error) { + if f == 0 { + return nil, nil + } + return float64(f), nil } From a6863a7dd2f9efa3dc2a13f396c49cc7da6f7c13 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 15 Jan 2022 17:47:37 -0600 Subject: [PATCH 0837/1158] Convert Hstore to Codec --- pgtype/builtin_wrappers.go | 37 +++ pgtype/hstore.go | 357 ++++++++++++++------------- pgtype/hstore_array.go | 476 ------------------------------------ pgtype/hstore_array_test.go | 436 --------------------------------- pgtype/hstore_test.go | 334 ++++++++++++------------- pgtype/pgtype.go | 48 ++++ 6 files changed, 432 insertions(+), 1256 deletions(-) delete mode 100644 pgtype/hstore_array.go delete mode 100644 pgtype/hstore_array_test.go diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index df955f18..15d4e083 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -381,3 +381,40 @@ func (w netIPWrapper) InetValue() (Inet, error) { mask := net.CIDRMask(bitCount, bitCount) return Inet{IPNet: &net.IPNet{Mask: mask, IP: net.IP(w)}, Valid: true}, nil } + +type mapStringToPointerStringWrapper map[string]*string + +func (w *mapStringToPointerStringWrapper) ScanHstore(v Hstore) error { + *w = mapStringToPointerStringWrapper(v) + return nil +} + +func (w mapStringToPointerStringWrapper) HstoreValue() (Hstore, error) { + return Hstore(w), nil +} + +type mapStringToStringWrapper map[string]string + +func (w *mapStringToStringWrapper) ScanHstore(v Hstore) error { + *w = make(mapStringToStringWrapper, len(v)) + for k, v := range v { + if v == nil { + return fmt.Errorf("cannot scan NULL to string") + } + (*w)[k] = *v + } + return nil +} + +func (w mapStringToStringWrapper) HstoreValue() (Hstore, error) { + if w == nil { + return nil, nil + } + + hstore := make(Hstore, len(w)) + for k, v := range w { + s := v + hstore[k] = &s + } + return hstore, nil +} diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 69c8a07b..6ff8164c 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -13,114 +13,168 @@ import ( "github.com/jackc/pgio" ) +type HstoreScanner interface { + ScanHstore(v Hstore) error +} + +type HstoreValuer interface { + HstoreValue() (Hstore, error) +} + // Hstore represents an hstore column that can be null or have null values // associated with its keys. -type Hstore struct { - Map map[string]Text - Valid bool -} - -func (dst *Hstore) Set(src interface{}) error { - if src == nil { - *dst = Hstore{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case map[string]string: - m := make(map[string]Text, len(value)) - for k, v := range value { - m[k] = Text{String: v, Valid: true} - } - *dst = Hstore{Map: m, Valid: true} - case map[string]*string: - m := make(map[string]Text, len(value)) - for k, v := range value { - if v == nil { - m[k] = Text{} - } else { - m[k] = Text{String: *v, Valid: true} - } - } - *dst = Hstore{Map: m, Valid: true} - default: - return fmt.Errorf("cannot convert %v to Hstore", src) - } +type Hstore map[string]*string +func (h *Hstore) ScanHstore(v Hstore) error { + *h = v return nil } -func (dst Hstore) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.Map +func (h Hstore) HstoreValue() (Hstore, error) { + return h, nil } -func (src *Hstore) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *map[string]string: - *v = make(map[string]string, len(src.Map)) - for k, val := range src.Map { - if !val.Valid { - return fmt.Errorf("cannot decode %#v into %T", src, dst) - } - (*v)[k] = val.String - } - return nil - case *map[string]*string: - *v = make(map[string]*string, len(src.Map)) - for k, val := range src.Map { - if val.Valid { - (*v)[k] = &val.String - } else { - (*v)[k] = nil - } - } - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { +// Scan implements the database/sql Scanner interface. +func (h *Hstore) Scan(src interface{}) error { if src == nil { - *dst = Hstore{} + *h = nil return nil } - keys, values, err := parseHstore(string(src)) + switch src := src.(type) { + case string: + return scanPlanTextAnyToHstoreScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), h) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (h Hstore) Value() (driver.Value, error) { + if h == nil { + return nil, nil + } + + buf, err := HstoreCodec{}.PlanEncode(nil, 0, TextFormatCode, h).Encode(h, nil) if err != nil { - return err + return nil, err + } + return string(buf), err +} + +type HstoreCodec struct{} + +func (HstoreCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (HstoreCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (HstoreCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(HstoreValuer); !ok { + return nil } - m := make(map[string]Text, len(keys)) - for i := range keys { - m[keys[i]] = values[i] + switch format { + case BinaryFormatCode: + return encodePlanHstoreCodecBinary{} + case TextFormatCode: + return encodePlanHstoreCodecText{} } - *dst = Hstore{Map: m, Valid: true} return nil } -func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { +type encodePlanHstoreCodecBinary struct{} + +func (encodePlanHstoreCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + hstore, err := value.(HstoreValuer).HstoreValue() + if err != nil { + return nil, err + } + + if hstore == nil { + return nil, nil + } + + buf = pgio.AppendInt32(buf, int32(len(hstore))) + + for k, v := range hstore { + buf = pgio.AppendInt32(buf, int32(len(k))) + buf = append(buf, k...) + + if v == nil { + buf = pgio.AppendInt32(buf, -1) + } else { + buf = pgio.AppendInt32(buf, int32(len(*v))) + buf = append(buf, (*v)...) + } + } + + return buf, nil +} + +type encodePlanHstoreCodecText struct{} + +func (encodePlanHstoreCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + hstore, err := value.(HstoreValuer).HstoreValue() + if err != nil { + return nil, err + } + + if hstore == nil { + return nil, nil + } + + firstPair := true + + for k, v := range hstore { + if firstPair { + firstPair = false + } else { + buf = append(buf, ',') + } + + buf = append(buf, quoteHstoreElementIfNeeded(k)...) + buf = append(buf, "=>"...) + + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, quoteHstoreElementIfNeeded(*v)...) + } + } + + return buf, nil +} + +func (HstoreCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case HstoreScanner: + return scanPlanBinaryHstoreToHstoreScanner{} + } + case TextFormatCode: + switch target.(type) { + case HstoreScanner: + return scanPlanTextAnyToHstoreScanner{} + } + } + + return nil +} + +type scanPlanBinaryHstoreToHstoreScanner struct{} + +func (scanPlanBinaryHstoreToHstoreScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(HstoreScanner) + if src == nil { - *dst = Hstore{} - return nil + return scanner.ScanHstore(Hstore{}) } rp := 0 @@ -131,7 +185,7 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 - m := make(map[string]Text, pairCount) + hstore := make(Hstore, pairCount) for i := 0; i < pairCount; i++ { if len(src[rp:]) < 4 { @@ -163,73 +217,58 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { if err != nil { return err } - m[key] = value + + if value.Valid { + hstore[key] = &value.String + } else { + hstore[key] = nil + } } - *dst = Hstore{Map: m, Valid: true} - - return nil + return scanner.ScanHstore(hstore) } -func (src Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { +type scanPlanTextAnyToHstoreScanner struct{} + +func (scanPlanTextAnyToHstoreScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(HstoreScanner) + + if src == nil { + return scanner.ScanHstore(Hstore{}) + } + + keys, values, err := parseHstore(string(src)) + if err != nil { + return err + } + + m := make(Hstore, len(keys)) + for i := range keys { + if values[i].Valid { + m[keys[i]] = &values[i].String + } else { + m[keys[i]] = nil + } + } + + return scanner.ScanHstore(m) +} + +func (c HstoreCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c HstoreCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { return nil, nil } - firstPair := true - - inElemBuf := make([]byte, 0, 32) - for k, v := range src.Map { - if firstPair { - firstPair = false - } else { - buf = append(buf, ',') - } - - buf = append(buf, quoteHstoreElementIfNeeded(k)...) - buf = append(buf, "=>"...) - - elemBuf, err := ci.Encode(TextOID, TextFormatCode, v, inElemBuf) - if err != nil { - return nil, err - } - - if elemBuf == nil { - buf = append(buf, "NULL"...) - } else { - buf = append(buf, quoteHstoreElementIfNeeded(string(elemBuf))...) - } + var hstore Hstore + err := codecScan(c, ci, oid, format, src, &hstore) + if err != nil { + return nil, err } - - return buf, nil -} - -func (src Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendInt32(buf, int32(len(src.Map))) - - var err error - for k, v := range src.Map { - buf = pgio.AppendInt32(buf, int32(len(k))) - buf = append(buf, k...) - - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := ci.Encode(TextOID, BinaryFormatCode, v, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, err + return hstore, nil } var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) @@ -420,27 +459,3 @@ func parseHstore(s string) (k []string, v []Text, err error) { v = values return } - -// Scan implements the database/sql Scanner interface. -func (dst *Hstore) Scan(src interface{}) error { - if src == nil { - *dst = Hstore{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Hstore) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go deleted file mode 100644 index 0ca5d4fb..00000000 --- a/pgtype/hstore_array.go +++ /dev/null @@ -1,476 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type HstoreArray struct { - Elements []Hstore - Dimensions []ArrayDimension - Valid bool -} - -func (dst *HstoreArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = HstoreArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []map[string]string: - if value == nil { - *dst = HstoreArray{} - } else if len(value) == 0 { - *dst = HstoreArray{Valid: true} - } else { - elements := make([]Hstore, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = HstoreArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Hstore: - if value == nil { - *dst = HstoreArray{} - } else if len(value) == 0 { - *dst = HstoreArray{Valid: true} - } else { - *dst = HstoreArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = HstoreArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for HstoreArray", src) - } - if elementsLength == 0 { - *dst = HstoreArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to HstoreArray", src) - } - - *dst = HstoreArray{ - Elements: make([]Hstore, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Hstore, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to HstoreArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *HstoreArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to HstoreArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in HstoreArray", err) - } - index++ - - return index, nil -} - -func (dst HstoreArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *HstoreArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]map[string]string: - *v = make([]map[string]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from HstoreArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from HstoreArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = HstoreArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Hstore - - if len(uta.Elements) > 0 { - elements = make([]Hstore, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Hstore - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = HstoreArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = HstoreArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = HstoreArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Hstore, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = HstoreArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("hstore"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "hstore") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *HstoreArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src HstoreArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/hstore_array_test.go b/pgtype/hstore_array_test.go deleted file mode 100644 index 7912b626..00000000 --- a/pgtype/hstore_array_test.go +++ /dev/null @@ -1,436 +0,0 @@ -package pgtype_test - -import ( - "context" - "reflect" - "testing" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestHstoreArrayTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - var hstoreOID uint32 - err := conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='hstore';").Scan(&hstoreOID) - if err != nil { - t.Fatalf("did not find hstore OID, %v", err) - } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Hstore{}, Name: "hstore", OID: hstoreOID}) - - var hstoreArrayOID uint32 - err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='_hstore';").Scan(&hstoreArrayOID) - if err != nil { - t.Fatalf("did not find _hstore OID, %v", err) - } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.HstoreArray{}, Name: "_hstore", OID: hstoreArrayOID}) - - text := func(s string) pgtype.Text { - return pgtype.Text{String: s, Valid: true} - } - - values := []pgtype.Hstore{ - {Map: map[string]pgtype.Text{}, Valid: true}, - {Map: map[string]pgtype.Text{"foo": text("bar")}, Valid: true}, - {Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Valid: true}, - {Map: map[string]pgtype.Text{"NULL": text("bar")}, Valid: true}, - {Map: map[string]pgtype.Text{"foo": text("NULL")}, Valid: true}, - {}, - } - - specialStrings := []string{ - `"`, - `'`, - `\`, - `\\`, - `=>`, - ` `, - `\ / / \\ => " ' " '`, - } - for _, s := range specialStrings { - // Special key values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Valid: true}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Valid: true}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Valid: true}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Valid: true}) // is key - - // Special value values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Valid: true}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Valid: true}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Valid: true}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Valid: true}) // is key - } - - src := &pgtype.HstoreArray{ - Elements: values, - Dimensions: []pgtype.ArrayDimension{{Length: int32(len(values)), LowerBound: 1}}, - Valid: true, - } - - _, err = conn.Prepare(context.Background(), "test", "select $1::hstore[]") - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for _, fc := range formats { - queryResultFormats := pgx.QueryResultFormats{fc.formatCode} - vEncoder := testutil.ForceEncoder(src, fc.formatCode) - if vEncoder == nil { - t.Logf("%#v does not implement %v", src, fc.name) - continue - } - - var result pgtype.HstoreArray - err := conn.QueryRow(context.Background(), "test", queryResultFormats, vEncoder).Scan(&result) - if err != nil { - t.Errorf("%v: %v", fc.name, err) - continue - } - - if result.Valid != src.Valid { - t.Errorf("%v: expected Valid %v, got %v", fc.formatCode, src.Valid, result.Valid) - continue - } - - if len(result.Elements) != len(src.Elements) { - t.Errorf("%v: expected %v elements, got %v", fc.formatCode, len(src.Elements), len(result.Elements)) - continue - } - - for i := range result.Elements { - a := src.Elements[i] - b := result.Elements[i] - - if a.Valid != b.Valid { - t.Errorf("%v element idx %d: expected Valid %v, got %v", fc.formatCode, i, a.Valid, b.Valid) - } - - if len(a.Map) != len(b.Map) { - t.Errorf("%v element idx %d: expected %v pairs, got %v", fc.formatCode, i, len(a.Map), len(b.Map)) - } - - for k := range a.Map { - if a.Map[k] != b.Map[k] { - t.Errorf("%v element idx %d: expected key %v to be %v, got %v", fc.formatCode, i, k, a.Map[k], b.Map[k]) - } - } - } - } -} - -func TestHstoreArraySet(t *testing.T) { - successfulTests := []struct { - src interface{} - result pgtype.HstoreArray - }{ - { - src: []map[string]string{{"foo": "bar"}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - }, - { - src: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true, - }, - }, - { - src: [][][][]map[string]string{ - {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, - {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true, - }, - }, - { - src: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true, - }, - }, - { - src: [2][1][1][3]map[string]string{ - {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, - {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true, - }, - }, - } - - for i, tt := range successfulTests { - var dst pgtype.HstoreArray - err := dst.Set(tt.src) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(dst, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) - } - } -} - -func TestHstoreArrayAssignTo(t *testing.T) { - var hstoreSlice []map[string]string - var hstoreSliceDim2 [][]map[string]string - var hstoreSliceDim4 [][][][]map[string]string - var hstoreArrayDim2 [2][1]map[string]string - var hstoreArrayDim4 [2][1][1][3]map[string]string - - simpleTests := []struct { - src pgtype.HstoreArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &hstoreSlice, - expected: []map[string]string{{"foo": "bar"}}}, - { - src: pgtype.HstoreArray{}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), - }, - { - src: pgtype.HstoreArray{Valid: true}, dst: &hstoreSlice, expected: []map[string]string{}, - }, - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &hstoreSliceDim2, - expected: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, - }, - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true, - }, - dst: &hstoreSliceDim4, - expected: [][][][]map[string]string{ - {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, - {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, - }, - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &hstoreArrayDim2, - expected: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, - }, - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, - Valid: true, - }, - { - Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, - Valid: true, - }, - }, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true, - }, - dst: &hstoreArrayDim4, - expected: [2][1][1][3]map[string]string{ - {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, - {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index dd80f0c5..edd94db7 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -1,31 +1,124 @@ package pgtype_test import ( - "reflect" + "context" "testing" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestHstoreTranscode(t *testing.T) { - text := func(s string) pgtype.Text { - return pgtype.Text{String: s, Valid: true} +func isExpectedEqMapStringString(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + am := a.(map[string]string) + vm := v.(map[string]string) + + if len(am) != len(vm) { + return false + } + + for k, v := range am { + if vm[k] != v { + return false + } + } + + return true + } +} + +func isExpectedEqMapStringPointerString(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + am := a.(map[string]*string) + vm := v.(map[string]*string) + + if len(am) != len(vm) { + return false + } + + for k, v := range am { + if (vm[k] == nil) != (v == nil) { + return false + } + + if v != nil && *vm[k] != *v { + return false + } + } + + return true + } +} + +func TestHstoreCodec(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + var hstoreOID uint32 + err := conn.QueryRow(context.Background(), `select oid from pg_type where typname = 'hstore'`).Scan(&hstoreOID) + if err != nil { + t.Skipf("Skipping: cannot find hstore OID") } - values := []interface{}{ - &pgtype.Hstore{Map: map[string]pgtype.Text{}, Valid: true}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(""), "bar": text(""), "baz": text("123")}, Valid: true}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Valid: true}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Valid: true}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Valid: true}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Valid: true}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"": text("bar")}, Valid: true}, - &pgtype.Hstore{ - Map: map[string]pgtype.Text{"a": text("a"), "b": {}, "c": text("c"), "d": {}, "e": text("e")}, - Valid: true, + conn.ConnInfo().RegisterDataType(pgtype.DataType{Name: "hstore", OID: hstoreOID, Codec: pgtype.HstoreCodec{}}) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + fs := func(s string) *string { + return &s + } + + tests := []PgxTranscodeTestCase{ + { + map[string]string{}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{}), }, - &pgtype.Hstore{}, + { + map[string]string{"foo": "", "bar": "", "baz": "123"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": "", "bar": "", "baz": "123"}), + }, + { + map[string]string{"NULL": "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"NULL": "bar"}), + }, + { + map[string]string{"bar": "NULL"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"bar": "NULL"}), + }, + { + map[string]string{"": "foo"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"": "foo"}), + }, + { + map[string]*string{}, + new(map[string]*string), + isExpectedEqMapStringPointerString(map[string]*string{}), + }, + { + map[string]*string{"foo": fs("bar"), "baq": fs("quz")}, + new(map[string]*string), + isExpectedEqMapStringPointerString(map[string]*string{"foo": fs("bar"), "baq": fs("quz")}), + }, + { + map[string]*string{"foo": nil, "baq": fs("quz")}, + new(map[string]*string), + isExpectedEqMapStringPointerString(map[string]*string{"foo": nil, "baq": fs("quz")}), + }, + {nil, new(*map[string]string), isExpectedEq((*map[string]string)(nil))}, + {nil, new(*map[string]*string), isExpectedEq((*map[string]*string)(nil))}, + {nil, new(*pgtype.Hstore), isExpectedEq((*pgtype.Hstore)(nil))}, } specialStrings := []string{ @@ -39,166 +132,61 @@ func TestHstoreTranscode(t *testing.T) { } for _, s := range specialStrings { // Special key values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Valid: true}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Valid: true}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Valid: true}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Valid: true}) // is key + + // at beginning + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{s + "foo": "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{s + "foo": "bar"}), + }) + // in middle + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{"foo" + s + "bar": "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo" + s + "bar": "bar"}), + }) + // at end + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{"foo" + s: "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo" + s: "bar"}), + }) + // is key + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{s: "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{s: "bar"}), + }) // Special value values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Valid: true}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Valid: true}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Valid: true}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Valid: true}) // is key + + // at beginning + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{"foo": s + "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": s + "bar"}), + }) + // in middle + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{"foo": "foo" + s + "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": "foo" + s + "bar"}), + }) + // at end + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{"foo": "foo" + s}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": "foo" + s}), + }) + // is key + tests = append(tests, PgxTranscodeTestCase{ + map[string]string{"foo": s}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": s}), + }) } - testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { - a := ai.(pgtype.Hstore) - b := bi.(pgtype.Hstore) - - if len(a.Map) != len(b.Map) || a.Valid != b.Valid { - return false - } - - for k := range a.Map { - if a.Map[k] != b.Map[k] { - return false - } - } - - return true - }) -} - -func TestHstoreTranscodeNullable(t *testing.T) { - text := func(s string, valid bool) pgtype.Text { - return pgtype.Text{String: s, Valid: valid} - } - - values := []interface{}{ - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("", false)}, Valid: true}, - } - - specialStrings := []string{ - `"`, - `'`, - `\`, - `\\`, - `=>`, - ` `, - `\ / / \\ => " ' " '`, - } - for _, s := range specialStrings { - // Special key values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("", false)}, Valid: true}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("", false)}, Valid: true}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("", false)}, Valid: true}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("", false)}, Valid: true}) // is key - } - - testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { - a := ai.(pgtype.Hstore) - b := bi.(pgtype.Hstore) - - if len(a.Map) != len(b.Map) || a.Valid != b.Valid { - return false - } - - for k := range a.Map { - if a.Map[k] != b.Map[k] { - return false - } - } - - return true - }) -} - -func TestHstoreSet(t *testing.T) { - successfulTests := []struct { - src map[string]string - result pgtype.Hstore - }{ - {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, Valid: true}}, - } - - for i, tt := range successfulTests { - var dst pgtype.Hstore - err := dst.Set(tt.src) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(dst, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) - } - } -} - -func TestHstoreSetNullable(t *testing.T) { - successfulTests := []struct { - src map[string]*string - result pgtype.Hstore - }{ - {src: map[string]*string{"foo": nil}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {}}, Valid: true}}, - } - - for i, tt := range successfulTests { - var dst pgtype.Hstore - err := dst.Set(tt.src) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(dst, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) - } - } -} - -func TestHstoreAssignTo(t *testing.T) { - var m map[string]string - - simpleTests := []struct { - src pgtype.Hstore - dst *map[string]string - expected map[string]string - }{ - {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, Valid: true}, dst: &m, expected: map[string]string{"foo": "bar"}}, - {src: pgtype.Hstore{}, dst: &m, expected: ((map[string]string)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(*tt.dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } -} - -func TestHstoreAssignToNullable(t *testing.T) { - var m map[string]*string - - simpleTests := []struct { - src pgtype.Hstore - dst *map[string]*string - expected map[string]*string - }{ - {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {}}, Valid: true}, dst: &m, expected: map[string]*string{"foo": nil}}, - {src: pgtype.Hstore{}, dst: &m, expected: ((map[string]*string)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(*tt.dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } + for _, format := range formats { + testPgxCodecFormat(t, "hstore", tests, conn, format.name, format.code) } } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index cb351fbd..5d0ed882 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -856,6 +856,10 @@ func tryWrapBuiltinTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter return &wrapNetIPNetScanPlan{}, (*netIPNetWrapper)(dst), true case *net.IP: return &wrapNetIPScanPlan{}, (*netIPWrapper)(dst), true + case *map[string]*string: + return &wrapMapStringToPointerStringScanPlan{}, (*mapStringToPointerStringWrapper)(dst), true + case *map[string]string: + return &wrapMapStringToStringScanPlan{}, (*mapStringToStringWrapper)(dst), true } return nil, nil, false @@ -1021,6 +1025,26 @@ func (plan *wrapNetIPScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, return plan.next.Scan(ci, oid, formatCode, src, (*netIPWrapper)(dst.(*net.IP))) } +type wrapMapStringToPointerStringScanPlan struct { + next ScanPlan +} + +func (plan *wrapMapStringToPointerStringScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapMapStringToPointerStringScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*mapStringToPointerStringWrapper)(dst.(*map[string]*string))) +} + +type wrapMapStringToStringScanPlan struct { + next ScanPlan +} + +func (plan *wrapMapStringToStringScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapMapStringToStringScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*mapStringToStringWrapper)(dst.(*map[string]string))) +} + type pointerEmptyInterfaceScanPlan struct { codec Codec } @@ -1362,6 +1386,10 @@ func tryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNext return &wrapNetIPNetEncodePlan{}, netIPNetWrapper(value), true case net.IP: return &wrapNetIPEncodePlan{}, netIPWrapper(value), true + case map[string]*string: + return &wrapMapStringToPointerStringEncodePlan{}, mapStringToPointerStringWrapper(value), true + case map[string]string: + return &wrapMapStringToStringEncodePlan{}, mapStringToStringWrapper(value), true } return nil, nil, false @@ -1527,6 +1555,26 @@ func (plan *wrapNetIPEncodePlan) Encode(value interface{}, buf []byte) (newBuf [ return plan.next.Encode(netIPWrapper(value.(net.IP)), buf) } +type wrapMapStringToPointerStringEncodePlan struct { + next EncodePlan +} + +func (plan *wrapMapStringToPointerStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapMapStringToPointerStringEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(mapStringToPointerStringWrapper(value.(map[string]*string)), buf) +} + +type wrapMapStringToStringEncodePlan struct { + next EncodePlan +} + +func (plan *wrapMapStringToStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapMapStringToStringEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(mapStringToStringWrapper(value.(map[string]string)), buf) +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. From 67720623f8115048369ce6ff3ae77bcb54467a5a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 15 Jan 2022 18:43:52 -0600 Subject: [PATCH 0838/1158] Extract plan wrapper concept --- pgtype/pgtype.go | 73 ++++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 37 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5d0ed882..c0d02197 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -730,11 +730,15 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt return newPlan.Scan(ci, oid, formatCode, src, dst) } +type tryWrapScanPlanFunc func(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) + type pointerPointerScanPlan struct { dstType reflect.Type next ScanPlan } +func (plan *pointerPointerScanPlan) SetNext(next ScanPlan) { plan.next = next } + func (plan *pointerPointerScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if plan.dstType != reflect.TypeOf(dst) { newPlan := ci.PlanScan(oid, formatCode, dst) @@ -751,7 +755,7 @@ func (plan *pointerPointerScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode in return plan.next.Scan(ci, oid, formatCode, src, el.Interface()) } -func tryPointerPointerScanPlan(dst interface{}) (plan *pointerPointerScanPlan, nextDst interface{}, ok bool) { +func tryPointerPointerScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) { if dstValue := reflect.ValueOf(dst); dstValue.Kind() == reflect.Ptr { elemValue := dstValue.Elem() if elemValue.Kind() == reflect.Ptr { @@ -790,6 +794,8 @@ type underlyingTypeScanPlan struct { next ScanPlan } +func (plan *underlyingTypeScanPlan) SetNext(next ScanPlan) { plan.next = next } + func (plan *underlyingTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if plan.dstType != reflect.TypeOf(dst) { newPlan := ci.PlanScan(oid, formatCode, dst) @@ -799,7 +805,7 @@ func (plan *underlyingTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode in return plan.next.Scan(ci, oid, formatCode, src, reflect.ValueOf(dst).Convert(plan.nextDstType).Interface()) } -func tryUnderlyingTypeScanPlan(dst interface{}) (plan *underlyingTypeScanPlan, nextDst interface{}, ok bool) { +func tryUnderlyingTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) { if _, ok := dst.(SkipUnderlyingTypePlanner); ok { return nil, nil, false } @@ -1128,24 +1134,18 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan return plan } - if pointerPointerPlan, nextDst, ok := tryPointerPointerScanPlan(dst); ok { - if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { - pointerPointerPlan.next = nextPlan - return pointerPointerPlan - } + tryWrappers := []tryWrapScanPlanFunc{ + tryPointerPointerScanPlan, + tryUnderlyingTypeScanPlan, + tryWrapBuiltinTypeScanPlan, } - if baseTypePlan, nextDst, ok := tryUnderlyingTypeScanPlan(dst); ok { - if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { - baseTypePlan.next = nextPlan - return baseTypePlan - } - } - - if wrapperPlan, nextValue, ok := tryWrapBuiltinTypeScanPlan(dst); ok { - if nextPlan := ci.PlanScan(oid, formatCode, nextValue); nextPlan != nil { - wrapperPlan.SetNext(nextPlan) - return wrapperPlan + for _, f := range tryWrappers { + if wrapperPlan, nextDst, ok := f(dst); ok { + if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } } } @@ -1259,36 +1259,33 @@ func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) Enco return plan } - if derefPointerPlan, nextValue, ok := tryDerefPointerEncodePlan(value); ok { - if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { - derefPointerPlan.next = nextPlan - return derefPointerPlan - } + tryWrappers := []tryWrapEncodePlanFunc{ + tryDerefPointerEncodePlan, + tryUnderlyingTypeEncodePlan, + tryWrapBuiltinTypeEncodePlan, } - if baseTypePlan, nextValue, ok := tryUnderlyingTypeEncodePlan(value); ok { - if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { - baseTypePlan.next = nextPlan - return baseTypePlan + for _, f := range tryWrappers { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } } } - - if wrapperPlan, nextValue, ok := tryWrapBuiltinTypeEncodePlan(value); ok { - if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { - wrapperPlan.SetNext(nextPlan) - return wrapperPlan - } - } - } return nil } +type tryWrapEncodePlanFunc func(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) + type derefPointerEncodePlan struct { next EncodePlan } +func (plan *derefPointerEncodePlan) SetNext(next EncodePlan) { plan.next = next } + func (plan *derefPointerEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { ptr := reflect.ValueOf(value) @@ -1299,7 +1296,7 @@ func (plan *derefPointerEncodePlan) Encode(value interface{}, buf []byte) (newBu return plan.next.Encode(ptr.Elem().Interface(), buf) } -func tryDerefPointerEncodePlan(value interface{}) (plan *derefPointerEncodePlan, nextValue interface{}, ok bool) { +func tryDerefPointerEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { if valueType := reflect.TypeOf(value); valueType.Kind() == reflect.Ptr { return &derefPointerEncodePlan{}, reflect.New(valueType.Elem()).Elem().Interface(), true } @@ -1328,11 +1325,13 @@ type underlyingTypeEncodePlan struct { next EncodePlan } +func (plan *underlyingTypeEncodePlan) SetNext(next EncodePlan) { plan.next = next } + func (plan *underlyingTypeEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(reflect.ValueOf(value).Convert(plan.nextValueType).Interface(), buf) } -func tryUnderlyingTypeEncodePlan(value interface{}) (plan *underlyingTypeEncodePlan, nextValue interface{}, ok bool) { +func tryUnderlyingTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { if _, ok := value.(SkipUnderlyingTypePlanner); ok { return nil, nil, false } From 5472ce9f10de431102fe221e76384392ea90f771 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 15 Jan 2022 18:45:42 -0600 Subject: [PATCH 0839/1158] Reorder Box functions --- pgtype/box.go | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/pgtype/box.go b/pgtype/box.go index 677d4bd2..80e1bd19 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -144,23 +144,6 @@ func (BoxCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfac return nil } -func (c BoxCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) -} - -func (c BoxCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { - if src == nil { - return nil, nil - } - - var box Box - err := codecScan(c, ci, oid, format, src, &box) - if err != nil { - return nil, err - } - return box, nil -} - type scanPlanBinaryBoxToBoxScanner struct{} func (scanPlanBinaryBoxToBoxScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { @@ -236,3 +219,20 @@ func (scanPlanTextAnyToBoxScanner) Scan(ci *ConnInfo, oid uint32, formatCode int return scanner.ScanBox(Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true}) } + +func (c BoxCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c BoxCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var box Box + err := codecScan(c, ci, oid, format, src, &box) + if err != nil { + return nil, err + } + return box, nil +} From 77e4b01553277c03e56a3303f3f8f935509beccd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 15 Jan 2022 18:46:28 -0600 Subject: [PATCH 0840/1158] Convert Interval to Codec --- pgtype/builtin_wrappers.go | 16 +++ pgtype/interval.go | 274 ++++++++++++++++++++++--------------- pgtype/interval_test.go | 188 ++++++++++++++++--------- pgtype/pgtype.go | 32 ++++- 4 files changed, 331 insertions(+), 179 deletions(-) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 15d4e083..5689b321 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -342,6 +342,22 @@ func (w timeWrapper) DateValue() (Date, error) { return Date{Time: time.Time(w), Valid: true}, nil } +type durationWrapper time.Duration + +func (w *durationWrapper) ScanInterval(v Interval) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Interval") + } + + us := int64(v.Months)*microsecondsPerMonth + int64(v.Days)*microsecondsPerDay + v.Microseconds + *w = durationWrapper(time.Duration(us) * time.Microsecond) + return nil +} + +func (w durationWrapper) IntervalValue() (Interval, error) { + return Interval{Microseconds: int64(w) / 1000, Valid: true}, nil +} + type netIPNetWrapper net.IPNet func (w *netIPNetWrapper) ScanInet(v Inet) error { diff --git a/pgtype/interval.go b/pgtype/interval.go index a92cd41f..41216f37 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -6,7 +6,6 @@ import ( "fmt" "strconv" "strings" - "time" "github.com/jackc/pgio" ) @@ -19,6 +18,14 @@ const ( microsecondsPerMonth = 30 * microsecondsPerDay ) +type IntervalScanner interface { + ScanInterval(v Interval) error +} + +type IntervalValuer interface { + IntervalValue() (Interval, error) +} + type Interval struct { Microseconds int64 Days int32 @@ -26,61 +33,169 @@ type Interval struct { Valid bool } -func (dst *Interval) Set(src interface{}) error { +func (interval *Interval) ScanInterval(v Interval) error { + *interval = v + return nil +} + +func (interval Interval) IntervalValue() (Interval, error) { + return interval, nil +} + +// Scan implements the database/sql Scanner interface. +func (interval *Interval) Scan(src interface{}) error { if src == nil { - *dst = Interval{} + *interval = Interval{} return nil } - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } + switch src := src.(type) { + case string: + return scanPlanTextAnyToIntervalScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), interval) } - switch value := src.(type) { - case time.Duration: - *dst = Interval{Microseconds: int64(value) / 1000, Valid: true} - default: - if originalSrc, ok := underlyingPtrType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Interval", value) + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (interval Interval) Value() (driver.Value, error) { + if !interval.Valid { + return nil, nil + } + + buf, err := IntervalCodec{}.PlanEncode(nil, 0, TextFormatCode, interval).Encode(interval, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type IntervalCodec struct{} + +func (IntervalCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (IntervalCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (IntervalCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(IntervalValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanIntervalCodecBinary{} + case TextFormatCode: + return encodePlanIntervalCodecText{} } return nil } -func (dst Interval) Get() interface{} { - if !dst.Valid { - return nil +type encodePlanIntervalCodecBinary struct{} + +func (encodePlanIntervalCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + interval, err := value.(IntervalValuer).IntervalValue() + if err != nil { + return nil, err } - return dst + + if !interval.Valid { + return nil, nil + } + + buf = pgio.AppendInt64(buf, interval.Microseconds) + buf = pgio.AppendInt32(buf, interval.Days) + buf = pgio.AppendInt32(buf, interval.Months) + return buf, nil } -func (src *Interval) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) +type encodePlanIntervalCodecText struct{} + +func (encodePlanIntervalCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + interval, err := value.(IntervalValuer).IntervalValue() + if err != nil { + return nil, err } - switch v := dst.(type) { - case *time.Duration: - us := int64(src.Months)*microsecondsPerMonth + int64(src.Days)*microsecondsPerDay + src.Microseconds - *v = time.Duration(us) * time.Microsecond - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) + if !interval.Valid { + return nil, nil + } + + if interval.Months != 0 { + buf = append(buf, strconv.FormatInt(int64(interval.Months), 10)...) + buf = append(buf, " mon "...) + } + + if interval.Days != 0 { + buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...) + buf = append(buf, " day "...) + } + + absMicroseconds := interval.Microseconds + if absMicroseconds < 0 { + absMicroseconds = -absMicroseconds + buf = append(buf, '-') + } + + hours := absMicroseconds / microsecondsPerHour + minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute + seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond + microseconds := absMicroseconds % microsecondsPerSecond + + timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) + buf = append(buf, timeStr...) + return buf, nil +} + +func (IntervalCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case IntervalScanner: + return scanPlanBinaryIntervalToIntervalScanner{} + } + case TextFormatCode: + switch target.(type) { + case IntervalScanner: + return scanPlanTextAnyToIntervalScanner{} } - return fmt.Errorf("unable to assign to %T", dst) } + + return nil } -func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { +type scanPlanBinaryIntervalToIntervalScanner struct{} + +func (scanPlanBinaryIntervalToIntervalScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(IntervalScanner) + if src == nil { - *dst = Interval{} - return nil + return scanner.ScanInterval(Interval{}) + } + + if len(src) != 16 { + return fmt.Errorf("Received an invalid size for a interval: %d", len(src)) + } + + microseconds := int64(binary.BigEndian.Uint64(src)) + days := int32(binary.BigEndian.Uint32(src[8:])) + months := int32(binary.BigEndian.Uint32(src[12:])) + + return scanner.ScanInterval(Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true}) +} + +type scanPlanTextAnyToIntervalScanner struct{} + +func (scanPlanTextAnyToIntervalScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(IntervalScanner) + + if src == nil { + return scanner.ScanInterval(Interval{}) } var microseconds int64 @@ -156,89 +271,22 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true} - return nil + return scanner.ScanInterval(Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true}) } -func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { +func (c IntervalCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c IntervalCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Interval{} - return nil - } - - if len(src) != 16 { - return fmt.Errorf("Received an invalid size for a interval: %d", len(src)) - } - - microseconds := int64(binary.BigEndian.Uint64(src)) - days := int32(binary.BigEndian.Uint32(src[8:])) - months := int32(binary.BigEndian.Uint32(src[12:])) - - *dst = Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true} - return nil -} - -func (src Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - if src.Months != 0 { - buf = append(buf, strconv.FormatInt(int64(src.Months), 10)...) - buf = append(buf, " mon "...) + var interval Interval + err := codecScan(c, ci, oid, format, src, &interval) + if err != nil { + return nil, err } - - if src.Days != 0 { - buf = append(buf, strconv.FormatInt(int64(src.Days), 10)...) - buf = append(buf, " day "...) - } - - absMicroseconds := src.Microseconds - if absMicroseconds < 0 { - absMicroseconds = -absMicroseconds - buf = append(buf, '-') - } - - hours := absMicroseconds / microsecondsPerHour - minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute - seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond - microseconds := absMicroseconds % microsecondsPerSecond - - timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) - return append(buf, timeStr...), nil -} - -// EncodeBinary encodes src into w. -func (src Interval) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendInt64(buf, src.Microseconds) - buf = pgio.AppendInt32(buf, src.Days) - return pgio.AppendInt32(buf, src.Months), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Interval) Scan(src interface{}) error { - if src == nil { - *dst = Interval{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Interval) Value() (driver.Value, error) { - return EncodeValueText(src) + return interval, nil } diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go index a8241bf6..75733ff1 100644 --- a/pgtype/interval_test.go +++ b/pgtype/interval_test.go @@ -5,70 +5,132 @@ import ( "time" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestIntervalTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "interval", []interface{}{ - &pgtype.Interval{Microseconds: 1, Valid: true}, - &pgtype.Interval{Microseconds: 1000000, Valid: true}, - &pgtype.Interval{Microseconds: 1000001, Valid: true}, - &pgtype.Interval{Microseconds: 123202800000000, Valid: true}, - &pgtype.Interval{Days: 1, Valid: true}, - &pgtype.Interval{Months: 1, Valid: true}, - &pgtype.Interval{Months: 12, Valid: true}, - &pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Valid: true}, - &pgtype.Interval{Microseconds: -1, Valid: true}, - &pgtype.Interval{Microseconds: -1000000, Valid: true}, - &pgtype.Interval{Microseconds: -1000001, Valid: true}, - &pgtype.Interval{Microseconds: -123202800000000, Valid: true}, - &pgtype.Interval{Days: -1, Valid: true}, - &pgtype.Interval{Months: -1, Valid: true}, - &pgtype.Interval{Months: -12, Valid: true}, - &pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Valid: true}, - &pgtype.Interval{}, +func TestIntervalCodec(t *testing.T) { + testPgxCodec(t, "interval", []PgxTranscodeTestCase{ + { + pgtype.Interval{Microseconds: 1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1, Valid: true}), + }, + { + pgtype.Interval{Microseconds: 1000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000000, Valid: true}), + }, + { + pgtype.Interval{Microseconds: 1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000001, Valid: true}), + }, + { + pgtype.Interval{Microseconds: 123202800000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 123202800000000, Valid: true}), + }, + { + pgtype.Interval{Days: 1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Days: 1, Valid: true}), + }, + { + pgtype.Interval{Months: 1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 1, Valid: true}), + }, + { + pgtype.Interval{Months: 12, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 12, Valid: true}), + }, + { + pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -1, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -1000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -1000000, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -1000001, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -123202800000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -123202800000000, Valid: true}), + }, + { + pgtype.Interval{Days: -1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Days: -1, Valid: true}), + }, + { + pgtype.Interval{Months: -1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -1, Valid: true}), + }, + { + pgtype.Interval{Months: -12, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -12, Valid: true}), + }, + { + pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Valid: true}), + }, + { + "1 second", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000000, Valid: true}), + }, + { + "1.000001 second", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000001, Valid: true}), + }, + { + "34223 hours", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 123202800000000, Valid: true}), + }, + { + "1 day", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Days: 1, Valid: true}), + }, + { + "1 month", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 1, Valid: true}), + }, + { + "1 year", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 12, Valid: true}), + }, + { + "-13 mon", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -13, Valid: true}), + }, + {time.Hour, new(time.Duration), isExpectedEq(time.Hour)}, + { + pgtype.Interval{Months: 1, Days: 1, Valid: true}, + new(time.Duration), + isExpectedEq(time.Duration(2678400000000000)), + }, + {pgtype.Interval{}, new(pgtype.Interval), isExpectedEq(pgtype.Interval{})}, + {nil, new(pgtype.Interval), isExpectedEq(pgtype.Interval{})}, }) } - -func TestIntervalNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select '1 second'::interval", - Value: &pgtype.Interval{Microseconds: 1000000, Valid: true}, - }, - { - SQL: "select '1.000001 second'::interval", - Value: &pgtype.Interval{Microseconds: 1000001, Valid: true}, - }, - { - SQL: "select '34223 hours'::interval", - Value: &pgtype.Interval{Microseconds: 123202800000000, Valid: true}, - }, - { - SQL: "select '1 day'::interval", - Value: &pgtype.Interval{Days: 1, Valid: true}, - }, - { - SQL: "select '1 month'::interval", - Value: &pgtype.Interval{Months: 1, Valid: true}, - }, - { - SQL: "select '1 year'::interval", - Value: &pgtype.Interval{Months: 12, Valid: true}, - }, - { - SQL: "select '-13 mon'::interval", - Value: &pgtype.Interval{Months: -13, Valid: true}, - }, - }) -} - -func TestIntervalLossyConversionToDuration(t *testing.T) { - interval := &pgtype.Interval{Months: 1, Days: 1, Valid: true} - var d time.Duration - err := interval.AssignTo(&d) - require.NoError(t, err) - assert.EqualValues(t, int64(2678400000000000), d.Nanoseconds()) -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index c0d02197..5ac3b50d 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -305,7 +305,7 @@ func NewConnInfo() *ConnInfo { // ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) ci.RegisterDataType(DataType{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) // ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) - ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID}) + ci.RegisterDataType(DataType{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) ci.RegisterDataType(DataType{Value: &JSONBArray{}, Name: "_jsonb", OID: JSONBArrayOID}) @@ -858,6 +858,8 @@ func tryWrapBuiltinTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter return &wrapStringScanPlan{}, (*stringWrapper)(dst), true case *time.Time: return &wrapTimeScanPlan{}, (*timeWrapper)(dst), true + case *time.Duration: + return &wrapDurationScanPlan{}, (*durationWrapper)(dst), true case *net.IPNet: return &wrapNetIPNetScanPlan{}, (*netIPNetWrapper)(dst), true case *net.IP: @@ -1011,6 +1013,16 @@ func (plan *wrapTimeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, s return plan.next.Scan(ci, oid, formatCode, src, (*timeWrapper)(dst.(*time.Time))) } +type wrapDurationScanPlan struct { + next ScanPlan +} + +func (plan *wrapDurationScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapDurationScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*durationWrapper)(dst.(*time.Duration))) +} + type wrapNetIPNetScanPlan struct { next ScanPlan } @@ -1143,8 +1155,10 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan for _, f := range tryWrappers { if wrapperPlan, nextDst, ok := f(dst); ok { if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { - wrapperPlan.SetNext(nextPlan) - return wrapperPlan + if _, ok := nextPlan.(*scanPlanDataTypeAssignTo); !ok { // avoid fallthrough -- this will go away when old system removed. + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } } } } @@ -1381,6 +1395,8 @@ func tryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNext return &wrapStringEncodePlan{}, stringWrapper(value), true case time.Time: return &wrapTimeEncodePlan{}, timeWrapper(value), true + case time.Duration: + return &wrapDurationEncodePlan{}, durationWrapper(value), true case net.IPNet: return &wrapNetIPNetEncodePlan{}, netIPNetWrapper(value), true case net.IP: @@ -1534,6 +1550,16 @@ func (plan *wrapTimeEncodePlan) Encode(value interface{}, buf []byte) (newBuf [] return plan.next.Encode(timeWrapper(value.(time.Time)), buf) } +type wrapDurationEncodePlan struct { + next EncodePlan +} + +func (plan *wrapDurationEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapDurationEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(durationWrapper(value.(time.Duration)), buf) +} + type wrapNetIPNetEncodePlan struct { next EncodePlan } From bff036b366ba0d55116754449c16fe41e85d5329 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 15 Jan 2022 18:48:10 -0600 Subject: [PATCH 0841/1158] Add interval array support --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5ac3b50d..300037df 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -71,6 +71,7 @@ const ( TimestamptzOID = 1184 TimestamptzArrayOID = 1185 IntervalOID = 1186 + IntervalArrayOID = 1187 NumericArrayOID = 1231 BitOID = 1560 BitArrayOID = 1561 @@ -270,6 +271,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}}) ci.RegisterDataType(DataType{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementCodec: Int4Codec{}, ElementOID: Int4OID}}) ci.RegisterDataType(DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementCodec: Int8Codec{}, ElementOID: Int8OID}}) + ci.RegisterDataType(DataType{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementCodec: IntervalCodec{}, ElementOID: IntervalOID}}) ci.RegisterDataType(DataType{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementCodec: BoxCodec{}, ElementOID: BoxOID}}) ci.RegisterDataType(DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementCodec: CircleCodec{}, ElementOID: CircleOID}}) ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementCodec: PointCodec{}, ElementOID: PointOID}}) From 06593ffb10bda20ae035b6c21707a0c2c1eb9b14 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 18 Jan 2022 11:29:19 -0600 Subject: [PATCH 0842/1158] Convert line to Codec --- pgtype/line.go | 241 ++++++++++++++++++++++++++++++-------------- pgtype/line_test.go | 32 ++++-- pgtype/pgtype.go | 2 +- 3 files changed, 189 insertions(+), 86 deletions(-) diff --git a/pgtype/line.go b/pgtype/line.go index c3192b2a..db584862 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -11,34 +11,177 @@ import ( "github.com/jackc/pgio" ) +type LineScanner interface { + ScanLine(v Line) error +} + +type LineValuer interface { + LineValue() (Line, error) +} + type Line struct { A, B, C float64 Valid bool } -func (dst *Line) Set(src interface{}) error { +func (line *Line) ScanLine(v Line) error { + *line = v + return nil +} + +func (line Line) LineValue() (Line, error) { + return line, nil +} + +func (line *Line) Set(src interface{}) error { return fmt.Errorf("cannot convert %v to Line", src) } -func (dst Line) Get() interface{} { - if !dst.Valid { +// Scan implements the database/sql Scanner interface. +func (line *Line) Scan(src interface{}) error { + if src == nil { + *line = Line{} return nil } - return dst + + switch src := src.(type) { + case string: + return scanPlanTextAnyToLineScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), line) + } + + return fmt.Errorf("cannot scan %T", src) } -func (src *Line) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) +// Value implements the database/sql/driver Valuer interface. +func (line Line) Value() (driver.Value, error) { + if !line.Valid { + return nil, nil + } + + buf, err := LineCodec{}.PlanEncode(nil, 0, TextFormatCode, line).Encode(line, nil) + if err != nil { + return nil, err + } + return string(buf), err } -func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Line{} +type LineCodec struct{} + +func (LineCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (LineCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (LineCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(LineValuer); !ok { return nil } + switch format { + case BinaryFormatCode: + return encodePlanLineCodecBinary{} + case TextFormatCode: + return encodePlanLineCodecText{} + } + + return nil +} + +type encodePlanLineCodecBinary struct{} + +func (encodePlanLineCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + line, err := value.(LineValuer).LineValue() + if err != nil { + return nil, err + } + + if !line.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(line.A)) + buf = pgio.AppendUint64(buf, math.Float64bits(line.B)) + buf = pgio.AppendUint64(buf, math.Float64bits(line.C)) + return buf, nil +} + +type encodePlanLineCodecText struct{} + +func (encodePlanLineCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + line, err := value.(LineValuer).LineValue() + if err != nil { + return nil, err + } + + if !line.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`{%s,%s,%s}`, + strconv.FormatFloat(line.A, 'f', -1, 64), + strconv.FormatFloat(line.B, 'f', -1, 64), + strconv.FormatFloat(line.C, 'f', -1, 64), + )...) + return buf, nil +} + +func (LineCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case LineScanner: + return scanPlanBinaryLineToLineScanner{} + } + case TextFormatCode: + switch target.(type) { + case LineScanner: + return scanPlanTextAnyToLineScanner{} + } + } + + return nil +} + +type scanPlanBinaryLineToLineScanner struct{} + +func (scanPlanBinaryLineToLineScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(LineScanner) + + if src == nil { + return scanner.ScanLine(Line{}) + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for line: %v", len(src)) + } + + a := binary.BigEndian.Uint64(src) + b := binary.BigEndian.Uint64(src[8:]) + c := binary.BigEndian.Uint64(src[16:]) + + return scanner.ScanLine(Line{ + A: math.Float64frombits(a), + B: math.Float64frombits(b), + C: math.Float64frombits(c), + Valid: true, + }) +} + +type scanPlanTextAnyToLineScanner struct{} + +func (scanPlanTextAnyToLineScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(LineScanner) + + if src == nil { + return scanner.ScanLine(Line{}) + } + if len(src) < 7 { - return fmt.Errorf("invalid length for Line: %v", len(src)) + return fmt.Errorf("invalid length for line: %v", len(src)) } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 3) @@ -61,78 +204,22 @@ func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Line{A: a, B: b, C: c, Valid: true} - return nil + return scanner.ScanLine(Line{A: a, B: b, C: c, Valid: true}) } -func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { +func (c LineCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c LineCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Line{} - return nil - } - - if len(src) != 24 { - return fmt.Errorf("invalid length for Line: %v", len(src)) - } - - a := binary.BigEndian.Uint64(src) - b := binary.BigEndian.Uint64(src[8:]) - c := binary.BigEndian.Uint64(src[16:]) - - *dst = Line{ - A: math.Float64frombits(a), - B: math.Float64frombits(b), - C: math.Float64frombits(c), - Valid: true, - } - return nil -} - -func (src Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - buf = append(buf, fmt.Sprintf(`{%s,%s,%s}`, - strconv.FormatFloat(src.A, 'f', -1, 64), - strconv.FormatFloat(src.B, 'f', -1, 64), - strconv.FormatFloat(src.C, 'f', -1, 64), - )...) - - return buf, nil -} - -func (src Line) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil + var line Line + err := codecScan(c, ci, oid, format, src, &line) + if err != nil { + return nil, err } - - buf = pgio.AppendUint64(buf, math.Float64bits(src.A)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.B)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.C)) - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Line) Scan(src interface{}) error { - if src == nil { - *dst = Line{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Line) Value() (driver.Value, error) { - return EncodeValueText(src) + return line, nil } diff --git a/pgtype/line_test.go b/pgtype/line_test.go index b171a7a5..669d9b8d 100644 --- a/pgtype/line_test.go +++ b/pgtype/line_test.go @@ -10,6 +10,7 @@ import ( func TestLineTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) + defer conn.Close(context.Background()) if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { t.Skip("Skipping due to no line type") } @@ -24,15 +25,30 @@ func TestLineTranscode(t *testing.T) { t.Skip("Skipping due to unimplemented line type in PG 9.3") } - testutil.TestSuccessfulTranscode(t, "line", []interface{}{ - &pgtype.Line{ - A: 1.23, B: 4.56, C: 7.89012345, - Valid: true, + testPgxCodec(t, "line", []PgxTranscodeTestCase{ + { + pgtype.Line{ + A: 1.23, B: 4.56, C: 7.89012345, + Valid: true, + }, + new(pgtype.Line), + isExpectedEq(pgtype.Line{ + A: 1.23, B: 4.56, C: 7.89012345, + Valid: true, + }), }, - &pgtype.Line{ - A: -1.23, B: -4.56, C: -7.89, - Valid: true, + { + pgtype.Line{ + A: -1.23, B: -4.56, C: -7.89, + Valid: true, + }, + new(pgtype.Line), + isExpectedEq(pgtype.Line{ + A: -1.23, B: -4.56, C: -7.89, + Valid: true, + }), }, - &pgtype.Line{}, + {pgtype.Line{}, new(pgtype.Line), isExpectedEq(pgtype.Line{})}, + {nil, new(pgtype.Line), isExpectedEq(pgtype.Line{})}, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 300037df..605d9132 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -311,7 +311,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) ci.RegisterDataType(DataType{Value: &JSONBArray{}, Name: "_jsonb", OID: JSONBArrayOID}) - ci.RegisterDataType(DataType{Value: &Line{}, Name: "line", OID: LineOID}) + ci.RegisterDataType(DataType{Name: "line", OID: LineOID, Codec: LineCodec{}}) ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID}) ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) ci.RegisterDataType(DataType{Name: "name", OID: NameOID, Codec: TextCodec{}}) From 97d8a408ea6501877550c2d0a0a96be2b395df22 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 18 Jan 2022 11:30:39 -0600 Subject: [PATCH 0843/1158] Add line array --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 605d9132..7c94c809 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -33,6 +33,7 @@ const ( BoxOID = 603 PolygonOID = 604 LineOID = 628 + LineArrayOID = 629 CIDROID = 650 CIDRArrayOID = 651 Float4OID = 700 @@ -273,6 +274,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementCodec: Int8Codec{}, ElementOID: Int8OID}}) ci.RegisterDataType(DataType{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementCodec: IntervalCodec{}, ElementOID: IntervalOID}}) ci.RegisterDataType(DataType{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementCodec: BoxCodec{}, ElementOID: BoxOID}}) + ci.RegisterDataType(DataType{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementCodec: LineCodec{}, ElementOID: LineOID}}) ci.RegisterDataType(DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementCodec: CircleCodec{}, ElementOID: CircleOID}}) ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementCodec: PointCodec{}, ElementOID: PointOID}}) ci.RegisterDataType(DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: NameOID}}) From 869213a315c957e4210efb09b640ae68b5d671ee Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 18 Jan 2022 11:38:35 -0600 Subject: [PATCH 0844/1158] Convert lseg to Codec --- pgtype/lseg.go | 247 +++++++++++++++++++++++++++++--------------- pgtype/lseg_test.go | 32 ++++-- pgtype/pgtype.go | 2 +- 3 files changed, 189 insertions(+), 92 deletions(-) diff --git a/pgtype/lseg.go b/pgtype/lseg.go index 649863ca..26730e85 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -11,34 +11,177 @@ import ( "github.com/jackc/pgio" ) +type LsegScanner interface { + ScanLseg(v Lseg) error +} + +type LsegValuer interface { + LsegValue() (Lseg, error) +} + type Lseg struct { P [2]Vec2 Valid bool } -func (dst *Lseg) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Lseg", src) +func (lseg *Lseg) ScanLseg(v Lseg) error { + *lseg = v + return nil } -func (dst Lseg) Get() interface{} { - if !dst.Valid { +func (lseg Lseg) LsegValue() (Lseg, error) { + return lseg, nil +} + +// Scan implements the database/sql Scanner interface. +func (lseg *Lseg) Scan(src interface{}) error { + if src == nil { + *lseg = Lseg{} return nil } - return dst + + switch src := src.(type) { + case string: + return scanPlanTextAnyToLsegScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), lseg) + } + + return fmt.Errorf("cannot scan %T", src) } -func (src *Lseg) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) +// Value implements the database/sql/driver Valuer interface. +func (lseg Lseg) Value() (driver.Value, error) { + if !lseg.Valid { + return nil, nil + } + + buf, err := LsegCodec{}.PlanEncode(nil, 0, TextFormatCode, lseg).Encode(lseg, nil) + if err != nil { + return nil, err + } + return string(buf), err } -func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Lseg{} +type LsegCodec struct{} + +func (LsegCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (LsegCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (LsegCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(LsegValuer); !ok { return nil } + switch format { + case BinaryFormatCode: + return encodePlanLsegCodecBinary{} + case TextFormatCode: + return encodePlanLsegCodecText{} + } + + return nil +} + +type encodePlanLsegCodecBinary struct{} + +func (encodePlanLsegCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + lseg, err := value.(LsegValuer).LsegValue() + if err != nil { + return nil, err + } + + if !lseg.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[1].Y)) + return buf, nil +} + +type encodePlanLsegCodecText struct{} + +func (encodePlanLsegCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + lseg, err := value.(LsegValuer).LsegValue() + if err != nil { + return nil, err + } + + if !lseg.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(lseg.P[0].X, 'f', -1, 64), + strconv.FormatFloat(lseg.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(lseg.P[1].X, 'f', -1, 64), + strconv.FormatFloat(lseg.P[1].Y, 'f', -1, 64), + )...) + return buf, nil +} + +func (LsegCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case LsegScanner: + return scanPlanBinaryLsegToLsegScanner{} + } + case TextFormatCode: + switch target.(type) { + case LsegScanner: + return scanPlanTextAnyToLsegScanner{} + } + } + + return nil +} + +type scanPlanBinaryLsegToLsegScanner struct{} + +func (scanPlanBinaryLsegToLsegScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(LsegScanner) + + if src == nil { + return scanner.ScanLseg(Lseg{}) + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for lseg: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + return scanner.ScanLseg(Lseg{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Valid: true, + }) +} + +type scanPlanTextAnyToLsegScanner struct{} + +func (scanPlanTextAnyToLsegScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(LsegScanner) + + if src == nil { + return scanner.ScanLseg(Lseg{}) + } + if len(src) < 11 { - return fmt.Errorf("invalid length for Lseg: %v", len(src)) + return fmt.Errorf("invalid length for lseg: %v", len(src)) } str := string(src[2:]) @@ -74,82 +217,22 @@ func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true} - return nil + return scanner.ScanLseg(Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true}) } -func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { +func (c LsegCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c LsegCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Lseg{} - return nil - } - - if len(src) != 32 { - return fmt.Errorf("invalid length for Lseg: %v", len(src)) - } - - x1 := binary.BigEndian.Uint64(src) - y1 := binary.BigEndian.Uint64(src[8:]) - x2 := binary.BigEndian.Uint64(src[16:]) - y2 := binary.BigEndian.Uint64(src[24:]) - - *dst = Lseg{ - P: [2]Vec2{ - {math.Float64frombits(x1), math.Float64frombits(y1)}, - {math.Float64frombits(x2), math.Float64frombits(y2)}, - }, - Valid: true, - } - return nil -} - -func (src Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, - strconv.FormatFloat(src.P[0].X, 'f', -1, 64), - strconv.FormatFloat(src.P[0].Y, 'f', -1, 64), - strconv.FormatFloat(src.P[1].X, 'f', -1, 64), - strconv.FormatFloat(src.P[1].Y, 'f', -1, 64), - )...) - - return buf, nil -} - -func (src Lseg) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil + var lseg Lseg + err := codecScan(c, ci, oid, format, src, &lseg) + if err != nil { + return nil, err } - - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Lseg) Scan(src interface{}) error { - if src == nil { - *dst = Lseg{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Lseg) Value() (driver.Value, error) { - return EncodeValueText(src) + return lseg, nil } diff --git a/pgtype/lseg_test.go b/pgtype/lseg_test.go index ce128784..1866439f 100644 --- a/pgtype/lseg_test.go +++ b/pgtype/lseg_test.go @@ -4,19 +4,33 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestLsegTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "lseg", []interface{}{ - &pgtype.Lseg{ - P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, - Valid: true, + testPgxCodec(t, "lseg", []PgxTranscodeTestCase{ + { + pgtype.Lseg{ + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, + Valid: true, + }, + new(pgtype.Lseg), + isExpectedEq(pgtype.Lseg{ + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, + Valid: true, + }), }, - &pgtype.Lseg{ - P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Valid: true, + { + pgtype.Lseg{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Valid: true, + }, + new(pgtype.Lseg), + isExpectedEq(pgtype.Lseg{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Valid: true, + }), }, - &pgtype.Lseg{}, + {pgtype.Lseg{}, new(pgtype.Lseg), isExpectedEq(pgtype.Lseg{})}, + {nil, new(pgtype.Lseg), isExpectedEq(pgtype.Lseg{})}, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 7c94c809..9948c87a 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -314,7 +314,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) ci.RegisterDataType(DataType{Value: &JSONBArray{}, Name: "_jsonb", OID: JSONBArrayOID}) ci.RegisterDataType(DataType{Name: "line", OID: LineOID, Codec: LineCodec{}}) - ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID}) + ci.RegisterDataType(DataType{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) ci.RegisterDataType(DataType{Name: "name", OID: NameOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) From 0ae8de35c838e44e93e22ad9ba3ab7440ba41313 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 18 Jan 2022 11:39:58 -0600 Subject: [PATCH 0845/1158] Add lseg array --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 9948c87a..af91dcaf 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -55,6 +55,7 @@ const ( VarcharArrayOID = 1015 Int8ArrayOID = 1016 PointArrayOID = 1017 + LsegArrayOID = 1018 BoxArrayOID = 1020 Float4ArrayOID = 1021 Float8ArrayOID = 1022 @@ -275,6 +276,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementCodec: IntervalCodec{}, ElementOID: IntervalOID}}) ci.RegisterDataType(DataType{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementCodec: BoxCodec{}, ElementOID: BoxOID}}) ci.RegisterDataType(DataType{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementCodec: LineCodec{}, ElementOID: LineOID}}) + ci.RegisterDataType(DataType{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementCodec: LsegCodec{}, ElementOID: LsegOID}}) ci.RegisterDataType(DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementCodec: CircleCodec{}, ElementOID: CircleOID}}) ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementCodec: PointCodec{}, ElementOID: PointOID}}) ci.RegisterDataType(DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: NameOID}}) From 5ff0ad548be0d589750f3293b349edf97130dd53 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 18 Jan 2022 11:51:08 -0600 Subject: [PATCH 0846/1158] Convert path to Codec --- pgtype/path.go | 315 ++++++++++++++++++++++++++++---------------- pgtype/path_test.go | 73 +++++++--- pgtype/pgtype.go | 2 +- 3 files changed, 260 insertions(+), 130 deletions(-) diff --git a/pgtype/path.go b/pgtype/path.go index 7ac38c68..be7daaa0 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -11,33 +11,213 @@ import ( "github.com/jackc/pgio" ) +type PathScanner interface { + ScanPath(v Path) error +} + +type PathValuer interface { + PathValue() (Path, error) +} + type Path struct { P []Vec2 Closed bool Valid bool } -func (dst *Path) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Path", src) +func (path *Path) ScanPath(v Path) error { + *path = v + return nil } -func (dst Path) Get() interface{} { - if !dst.Valid { +func (path Path) PathValue() (Path, error) { + return path, nil +} + +// Scan implements the database/sql Scanner interface. +func (path *Path) Scan(src interface{}) error { + if src == nil { + *path = Path{} return nil } - return dst + + switch src := src.(type) { + case string: + return scanPlanTextAnyToPathScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), path) + } + + return fmt.Errorf("cannot scan %T", src) } -func (src *Path) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) +// Value implements the database/sql/driver Valuer interface. +func (path Path) Value() (driver.Value, error) { + if !path.Valid { + return nil, nil + } + + buf, err := PathCodec{}.PlanEncode(nil, 0, TextFormatCode, path).Encode(path, nil) + if err != nil { + return nil, err + } + + return string(buf), err } -func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Path{} +type PathCodec struct{} + +func (PathCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (PathCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (PathCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(PathValuer); !ok { return nil } + switch format { + case BinaryFormatCode: + return encodePlanPathCodecBinary{} + case TextFormatCode: + return encodePlanPathCodecText{} + } + + return nil +} + +type encodePlanPathCodecBinary struct{} + +func (encodePlanPathCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + path, err := value.(PathValuer).PathValue() + if err != nil { + return nil, err + } + + if !path.Valid { + return nil, nil + } + + var closeByte byte + if path.Closed { + closeByte = 1 + } + buf = append(buf, closeByte) + + buf = pgio.AppendInt32(buf, int32(len(path.P))) + + for _, p := range path.P { + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + } + + return buf, nil +} + +type encodePlanPathCodecText struct{} + +func (encodePlanPathCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + path, err := value.(PathValuer).PathValue() + if err != nil { + return nil, err + } + + if !path.Valid { + return nil, nil + } + + var startByte, endByte byte + if path.Closed { + startByte = '(' + endByte = ')' + } else { + startByte = '[' + endByte = ']' + } + buf = append(buf, startByte) + + for i, p := range path.P { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(p.X, 'f', -1, 64), + strconv.FormatFloat(p.Y, 'f', -1, 64), + )...) + } + + buf = append(buf, endByte) + + return buf, nil +} + +func (PathCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case PathScanner: + return scanPlanBinaryPathToPathScanner{} + } + case TextFormatCode: + switch target.(type) { + case PathScanner: + return scanPlanTextAnyToPathScanner{} + } + } + + return nil +} + +type scanPlanBinaryPathToPathScanner struct{} + +func (scanPlanBinaryPathToPathScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(PathScanner) + + if src == nil { + return scanner.ScanPath(Path{}) + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for Path: %v", len(src)) + } + + closed := src[0] == 1 + pointCount := int(binary.BigEndian.Uint32(src[1:])) + + rp := 5 + + if 5+pointCount*16 != len(src) { + return fmt.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + return scanner.ScanPath(Path{ + P: points, + Closed: closed, + Valid: true, + }) +} + +type scanPlanTextAnyToPathScanner struct{} + +func (scanPlanTextAnyToPathScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(PathScanner) + + if src == nil { + return scanner.ScanPath(Path{}) + } + if len(src) < 7 { return fmt.Errorf("invalid length for Path: %v", len(src)) } @@ -71,115 +251,22 @@ func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Path{P: points, Closed: closed, Valid: true} - return nil + return scanner.ScanPath(Path{P: points, Closed: closed, Valid: true}) } -func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { +func (c PathCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c PathCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Path{} - return nil - } - - if len(src) < 5 { - return fmt.Errorf("invalid length for Path: %v", len(src)) - } - - closed := src[0] == 1 - pointCount := int(binary.BigEndian.Uint32(src[1:])) - - rp := 5 - - if 5+pointCount*16 != len(src) { - return fmt.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) - } - - points := make([]Vec2, pointCount) - for i := 0; i < len(points); i++ { - x := binary.BigEndian.Uint64(src[rp:]) - rp += 8 - y := binary.BigEndian.Uint64(src[rp:]) - rp += 8 - points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} - } - - *dst = Path{ - P: points, - Closed: closed, - Valid: true, - } - return nil -} - -func (src Path) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - var startByte, endByte byte - if src.Closed { - startByte = '(' - endByte = ')' - } else { - startByte = '[' - endByte = ']' + var path Path + err := codecScan(c, ci, oid, format, src, &path) + if err != nil { + return nil, err } - buf = append(buf, startByte) - - for i, p := range src.P { - if i > 0 { - buf = append(buf, ',') - } - buf = append(buf, fmt.Sprintf(`(%s,%s)`, - strconv.FormatFloat(p.X, 'f', -1, 64), - strconv.FormatFloat(p.Y, 'f', -1, 64), - )...) - } - - return append(buf, endByte), nil -} - -func (src Path) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - var closeByte byte - if src.Closed { - closeByte = 1 - } - buf = append(buf, closeByte) - - buf = pgio.AppendInt32(buf, int32(len(src.P))) - - for _, p := range src.P { - buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) - buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Path) Scan(src interface{}) error { - if src == nil { - *dst = Path{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Path) Value() (driver.Value, error) { - return EncodeValueText(src) + return path, nil } diff --git a/pgtype/path_test.go b/pgtype/path_test.go index 8a218fe1..291fa9d4 100644 --- a/pgtype/path_test.go +++ b/pgtype/path_test.go @@ -4,26 +4,69 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) +func isExpectedEqPath(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + ap := a.(pgtype.Path) + vp := v.(pgtype.Path) + + if !(ap.Valid == vp.Valid && ap.Closed == vp.Closed && len(ap.P) == len(vp.P)) { + return false + } + + for i := range ap.P { + if ap.P[i] != vp.P[i] { + return false + } + } + + return true + } +} + func TestPathTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "path", []interface{}{ - &pgtype.Path{ - P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, - Closed: false, - Valid: true, + testPgxCodec(t, "path", []PgxTranscodeTestCase{ + { + pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, + Closed: false, + Valid: true, + }, + new(pgtype.Path), + isExpectedEqPath(pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, + Closed: false, + Valid: true, + }), }, - &pgtype.Path{ - P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, - Closed: true, - Valid: true, + { + pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, + Closed: true, + Valid: true, + }, + new(pgtype.Path), + isExpectedEqPath(pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, + Closed: true, + Valid: true, + }), }, - &pgtype.Path{ - P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Closed: true, - Valid: true, + { + pgtype.Path{ + P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Closed: true, + Valid: true, + }, + new(pgtype.Path), + isExpectedEqPath(pgtype.Path{ + P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Closed: true, + Valid: true, + }), }, - &pgtype.Path{}, + {pgtype.Path{}, new(pgtype.Path), isExpectedEqPath(pgtype.Path{})}, + {nil, new(pgtype.Path), isExpectedEqPath(pgtype.Path{})}, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index af91dcaf..3b967b82 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -322,7 +322,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) ci.RegisterDataType(DataType{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) - ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID}) + ci.RegisterDataType(DataType{Name: "path", OID: PathOID, Codec: PathCodec{}}) ci.RegisterDataType(DataType{Name: "point", OID: PointOID, Codec: PointCodec{}}) ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID}) // ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) From 11d96fb92867b6bad8ddc5e2857b60655735708c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 18 Jan 2022 11:52:44 -0600 Subject: [PATCH 0847/1158] Add path array --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 3b967b82..9b995dd0 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -56,6 +56,7 @@ const ( Int8ArrayOID = 1016 PointArrayOID = 1017 LsegArrayOID = 1018 + PathArrayOID = 1019 BoxArrayOID = 1020 Float4ArrayOID = 1021 Float8ArrayOID = 1022 @@ -277,6 +278,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementCodec: BoxCodec{}, ElementOID: BoxOID}}) ci.RegisterDataType(DataType{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementCodec: LineCodec{}, ElementOID: LineOID}}) ci.RegisterDataType(DataType{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementCodec: LsegCodec{}, ElementOID: LsegOID}}) + ci.RegisterDataType(DataType{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementCodec: PathCodec{}, ElementOID: PathOID}}) ci.RegisterDataType(DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementCodec: CircleCodec{}, ElementOID: CircleOID}}) ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementCodec: PointCodec{}, ElementOID: PointOID}}) ci.RegisterDataType(DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: NameOID}}) From abd7e98f31f6e9420de13a019e9e80e967757d57 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 18 Jan 2022 12:04:17 -0600 Subject: [PATCH 0848/1158] Convert polygon to Codec --- pgtype/pgtype.go | 2 +- pgtype/polygon.go | 322 +++++++++++++++++++++++------------------ pgtype/polygon_test.go | 118 ++++++--------- 3 files changed, 223 insertions(+), 219 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 9b995dd0..f6cc2d43 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -326,7 +326,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) ci.RegisterDataType(DataType{Name: "path", OID: PathOID, Codec: PathCodec{}}) ci.RegisterDataType(DataType{Name: "point", OID: PointOID, Codec: PointCodec{}}) - ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID}) + ci.RegisterDataType(DataType{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) // ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) ci.RegisterDataType(DataType{Name: "text", OID: TextOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID}) diff --git a/pgtype/polygon.go b/pgtype/polygon.go index 956920e6..47dbfed9 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -11,81 +11,195 @@ import ( "github.com/jackc/pgio" ) +type PolygonScanner interface { + ScanPolygon(v Polygon) error +} + +type PolygonValuer interface { + PolygonValue() (Polygon, error) +} + type Polygon struct { P []Vec2 Valid bool } -// Set converts src to dest. -// -// src can be nil, string, []float64, and []pgtype.Vec2. -// -// If src is string the format must be ((x1,y1),(x2,y2),...,(xn,yn)). -// Important that there are no spaces in it. -func (dst *Polygon) Set(src interface{}) error { - if src == nil { - dst.Valid = false - return nil - } - err := fmt.Errorf("cannot convert %v to Polygon", src) - var p *Polygon - switch value := src.(type) { - case string: - p, err = stringToPolygon(value) - case []Vec2: - p = &Polygon{Valid: true, P: value} - err = nil - case []float64: - p, err = float64ToPolygon(value) - default: - return err - } - if err != nil { - return err - } - *dst = *p +func (p *Polygon) ScanPolygon(v Polygon) error { + *p = v return nil } -func stringToPolygon(src string) (*Polygon, error) { - p := &Polygon{} - err := p.DecodeText(nil, []byte(src)) - return p, err -} - -func float64ToPolygon(src []float64) (*Polygon, error) { - p := &Polygon{} - if len(src) == 0 { - return p, nil - } - if len(src)%2 != 0 { - return p, fmt.Errorf("invalid length for polygon: %v", len(src)) - } - p.Valid = true - p.P = make([]Vec2, 0) - for i := 0; i < len(src); i += 2 { - p.P = append(p.P, Vec2{X: src[i], Y: src[i+1]}) - } +func (p Polygon) PolygonValue() (Polygon, error) { return p, nil } -func (dst Polygon) Get() interface{} { - if !dst.Valid { +// Scan implements the database/sql Scanner interface. +func (p *Polygon) Scan(src interface{}) error { + if src == nil { + *p = Polygon{} return nil } - return dst + + switch src := src.(type) { + case string: + return scanPlanTextAnyToPolygonScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), p) + } + + return fmt.Errorf("cannot scan %T", src) } -func (src *Polygon) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) +// Value implements the database/sql/driver Valuer interface. +func (p Polygon) Value() (driver.Value, error) { + if !p.Valid { + return nil, nil + } + + buf, err := PolygonCodec{}.PlanEncode(nil, 0, TextFormatCode, p).Encode(p, nil) + if err != nil { + return nil, err + } + + return string(buf), err } -func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Polygon{} +type PolygonCodec struct{} + +func (PolygonCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (PolygonCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (PolygonCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(PolygonValuer); !ok { return nil } + switch format { + case BinaryFormatCode: + return encodePlanPolygonCodecBinary{} + case TextFormatCode: + return encodePlanPolygonCodecText{} + } + + return nil +} + +type encodePlanPolygonCodecBinary struct{} + +func (encodePlanPolygonCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + polygon, err := value.(PolygonValuer).PolygonValue() + if err != nil { + return nil, err + } + + if !polygon.Valid { + return nil, nil + } + + buf = pgio.AppendInt32(buf, int32(len(polygon.P))) + + for _, p := range polygon.P { + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + } + + return buf, nil +} + +type encodePlanPolygonCodecText struct{} + +func (encodePlanPolygonCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + polygon, err := value.(PolygonValuer).PolygonValue() + if err != nil { + return nil, err + } + + if !polygon.Valid { + return nil, nil + } + + buf = append(buf, '(') + + for i, p := range polygon.P { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(p.X, 'f', -1, 64), + strconv.FormatFloat(p.Y, 'f', -1, 64), + )...) + } + + buf = append(buf, ')') + + return buf, nil +} + +func (PolygonCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case PolygonScanner: + return scanPlanBinaryPolygonToPolygonScanner{} + } + case TextFormatCode: + switch target.(type) { + case PolygonScanner: + return scanPlanTextAnyToPolygonScanner{} + } + } + + return nil +} + +type scanPlanBinaryPolygonToPolygonScanner struct{} + +func (scanPlanBinaryPolygonToPolygonScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(PolygonScanner) + + if src == nil { + return scanner.ScanPolygon(Polygon{}) + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for polygon: %v", len(src)) + } + + pointCount := int(binary.BigEndian.Uint32(src)) + rp := 4 + + if 4+pointCount*16 != len(src) { + return fmt.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + return scanner.ScanPolygon(Polygon{ + P: points, + Valid: true, + }) +} + +type scanPlanTextAnyToPolygonScanner struct{} + +func (scanPlanTextAnyToPolygonScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(PolygonScanner) + + if src == nil { + return scanner.ScanPolygon(Polygon{}) + } + if len(src) < 7 { return fmt.Errorf("invalid length for Polygon: %v", len(src)) } @@ -118,98 +232,22 @@ func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Polygon{P: points, Valid: true} - return nil + return scanner.ScanPolygon(Polygon{P: points, Valid: true}) } -func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { +func (c PolygonCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c PolygonCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Polygon{} - return nil - } - - if len(src) < 5 { - return fmt.Errorf("invalid length for Polygon: %v", len(src)) - } - - pointCount := int(binary.BigEndian.Uint32(src)) - rp := 4 - - if 4+pointCount*16 != len(src) { - return fmt.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) - } - - points := make([]Vec2, pointCount) - for i := 0; i < len(points); i++ { - x := binary.BigEndian.Uint64(src[rp:]) - rp += 8 - y := binary.BigEndian.Uint64(src[rp:]) - rp += 8 - points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} - } - - *dst = Polygon{ - P: points, - Valid: true, - } - return nil -} - -func (src Polygon) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - buf = append(buf, '(') - - for i, p := range src.P { - if i > 0 { - buf = append(buf, ',') - } - buf = append(buf, fmt.Sprintf(`(%s,%s)`, - strconv.FormatFloat(p.X, 'f', -1, 64), - strconv.FormatFloat(p.Y, 'f', -1, 64), - )...) + var polygon Polygon + err := codecScan(c, ci, oid, format, src, &polygon) + if err != nil { + return nil, err } - - return append(buf, ')'), nil -} - -func (src Polygon) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendInt32(buf, int32(len(src.P))) - - for _, p := range src.P { - buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) - buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Polygon) Scan(src interface{}) error { - if src == nil { - *dst = Polygon{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Polygon) Value() (driver.Value, error) { - return EncodeValueText(src) + return polygon, nil } diff --git a/pgtype/polygon_test.go b/pgtype/polygon_test.go index 4e7f69ce..c0912b31 100644 --- a/pgtype/polygon_test.go +++ b/pgtype/polygon_test.go @@ -4,86 +4,52 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestPolygonTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "polygon", []interface{}{ - &pgtype.Polygon{ - P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, - Valid: true, - }, - &pgtype.Polygon{ - P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, - Valid: true, - }, - &pgtype.Polygon{}, - }) +func isExpectedEqPolygon(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + ap := a.(pgtype.Polygon) + vp := v.(pgtype.Polygon) + + if !(ap.Valid == vp.Valid && len(ap.P) == len(vp.P)) { + return false + } + + for i := range ap.P { + if ap.P[i] != vp.P[i] { + return false + } + } + + return true + } } -func TestPolygon_Set(t *testing.T) { - tests := []struct { - name string - arg interface{} - valid bool - wantErr bool - }{ +func TestPolygonTranscode(t *testing.T) { + testPgxCodec(t, "polygon", []PgxTranscodeTestCase{ { - name: "string", - arg: "((3.14,1.678901234),(7.1,5.234),(5.0,3.234))", - valid: true, - wantErr: false, - }, { - name: "[]float64", - arg: []float64{1, 2, 3.45, 6.78, 1.23, 4.567, 8.9, 1.0}, - valid: true, - wantErr: false, - }, { - name: "[]Vec2", - arg: []pgtype.Vec2{{1, 2}, {2.3, 4.5}, {6.78, 9.123}}, - valid: true, - wantErr: false, - }, { - name: "null", - arg: nil, - valid: false, - wantErr: false, - }, { - name: "invalid_string_1", - arg: "((3.14,1.678901234),(7.1,5.234),(5.0,3.234x))", - valid: false, - wantErr: true, - }, { - name: "invalid_string_2", - arg: "(3,4)", - valid: false, - wantErr: true, - }, { - name: "invalid_[]float64", - arg: []float64{1, 2, 3.45, 6.78, 1.23, 4.567, 8.9}, - valid: false, - wantErr: true, - }, { - name: "invalid_type", - arg: []int{1, 2, 3, 6}, - valid: false, - wantErr: true, - }, { - name: "empty_[]float64", - arg: []float64{}, - valid: false, - wantErr: false, + pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, + Valid: true, + }, + new(pgtype.Polygon), + isExpectedEqPolygon(pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, + Valid: true, + }), }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - dst := &pgtype.Polygon{} - if err := dst.Set(tt.arg); (err != nil) != tt.wantErr { - t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) - } - if dst.Valid != tt.valid { - t.Errorf("Expected valid: %v; got: %v", tt.valid, dst.Valid) - } - }) - } + { + pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, + Valid: true, + }, + new(pgtype.Polygon), + isExpectedEqPolygon(pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, + Valid: true, + }), + }, + {pgtype.Polygon{}, new(pgtype.Polygon), isExpectedEqPolygon(pgtype.Polygon{})}, + {nil, new(pgtype.Polygon), isExpectedEqPolygon(pgtype.Polygon{})}, + }) } From 8728acfca60401b5151ce3ac5ce91951caa5546b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 18 Jan 2022 12:05:28 -0600 Subject: [PATCH 0849/1158] Add polygon array --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index f6cc2d43..a6d77356 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -60,6 +60,7 @@ const ( BoxArrayOID = 1020 Float4ArrayOID = 1021 Float8ArrayOID = 1022 + PolygonArrayOID = 1027 OIDArrayOID = 1028 ACLItemOID = 1033 ACLItemArrayOID = 1034 @@ -281,6 +282,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementCodec: PathCodec{}, ElementOID: PathOID}}) ci.RegisterDataType(DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementCodec: CircleCodec{}, ElementOID: CircleOID}}) ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementCodec: PointCodec{}, ElementOID: PointOID}}) + ci.RegisterDataType(DataType{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementCodec: PolygonCodec{}, ElementOID: PolygonOID}}) ci.RegisterDataType(DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: NameOID}}) ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: TextOID}}) From 8b27725f5ba3f107557ef9ba213ece97f4a82322 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 18 Jan 2022 16:04:25 -0600 Subject: [PATCH 0850/1158] Convert json and jsonb to Codec --- pgtype/json.go | 240 +++++++----------- pgtype/json_test.go | 193 +++----------- pgtype/jsonb.go | 135 ++++++---- pgtype/jsonb_array.go | 504 ------------------------------------- pgtype/jsonb_array_test.go | 36 --- pgtype/jsonb_test.go | 143 +---------- pgtype/pgtype.go | 8 +- stdlib/sql.go | 16 +- values.go | 26 -- 9 files changed, 236 insertions(+), 1065 deletions(-) delete mode 100644 pgtype/jsonb_array.go delete mode 100644 pgtype/jsonb_array_test.go diff --git a/pgtype/json.go b/pgtype/json.go index 580e8505..510b638e 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -3,187 +3,129 @@ package pgtype import ( "database/sql/driver" "encoding/json" - "errors" - "fmt" + "reflect" ) -type JSON struct { - Bytes []byte - Valid bool +type JSONCodec struct{} + +func (JSONCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode } -func (dst *JSON) Set(src interface{}) error { - if src == nil { - *dst = JSON{} - return nil - } +func (JSONCodec) PreferredFormat() int16 { + return TextFormatCode +} - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case string: - *dst = JSON{Bytes: []byte(value), Valid: true} - case *string: - if value == nil { - *dst = JSON{} - } else { - *dst = JSON{Bytes: []byte(*value), Valid: true} - } +func (JSONCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch value.(type) { case []byte: - if value == nil { - *dst = JSON{} - } else { - *dst = JSON{Bytes: value, Valid: true} - } - // Encode* methods are defined on *JSON. If JSON is passed directly then the - // struct itself would be encoded instead of Bytes. This is clearly a footgun - // so detect and return an error. See https://github.com/jackc/pgx/issues/350. - case JSON: - return errors.New("use pointer to pgtype.JSON instead of value") - // Same as above but for JSONB (because they share implementation) - case JSONB: - return errors.New("use pointer to pgtype.JSONB instead of value") - + return encodePlanJSONCodecEitherFormatByteSlice{} default: - buf, err := json.Marshal(value) - if err != nil { - return err - } - *dst = JSON{Bytes: buf, Valid: true} + return encodePlanJSONCodecEitherFormatMarshal{} } - - return nil } -func (dst JSON) Get() interface{} { - if !dst.Valid { - return nil +type encodePlanJSONCodecEitherFormatByteSlice struct{} + +func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + jsonBytes := value.([]byte) + if jsonBytes == nil { + return nil, nil } - var i interface{} - err := json.Unmarshal(dst.Bytes, &i) + buf = append(buf, jsonBytes...) + return buf, nil +} + +type encodePlanJSONCodecEitherFormatMarshal struct{} + +func (encodePlanJSONCodecEitherFormatMarshal) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + jsonBytes, err := json.Marshal(value) if err != nil { - return dst + return nil, err } - return i + + buf = append(buf, jsonBytes...) + return buf, nil } -func (src *JSON) AssignTo(dst interface{}) error { - switch v := dst.(type) { +func (JSONCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch target.(type) { case *string: - if src.Valid { - *v = string(src.Bytes) - } else { - return fmt.Errorf("cannot assign non-valid to %T", dst) - } - case **string: - if src.Valid { - s := string(src.Bytes) - *v = &s - return nil - } else { - *v = nil - return nil - } + return scanPlanAnyToString{} case *[]byte: - if !src.Valid { - *v = nil - } else { - buf := make([]byte, len(src.Bytes)) - copy(buf, src.Bytes) - *v = buf - } + return scanPlanJSONToByteSlice{} + case BytesScanner: + return scanPlanBinaryBytesToBytesScanner{} default: - data := src.Bytes - if data == nil || !src.Valid { - data = []byte("null") + return scanPlanJSONToJSONUnmarshal{} + } + +} + +type scanPlanAnyToString struct{} + +func (scanPlanAnyToString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + p := dst.(*string) + *p = string(src) + return nil +} + +type scanPlanJSONToByteSlice struct{} + +func (scanPlanJSONToByteSlice) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil + return nil + } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanJSONToBytesScanner struct{} + +func (scanPlanJSONToBytesScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(BytesScanner) + return scanner.ScanBytes(src) +} + +type scanPlanJSONToJSONUnmarshal struct{} + +func (scanPlanJSONToJSONUnmarshal) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() == reflect.Ptr { + el := dstValue.Elem() + switch el.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map: + el.Set(reflect.Zero(el.Type())) + return nil + } } - - return json.Unmarshal(data, dst) } - return nil + return json.Unmarshal(src, dst) } -func (JSON) PreferredResultFormat() int16 { - return TextFormatCode -} - -func (dst *JSON) DecodeText(ci *ConnInfo, src []byte) error { +func (c JSONCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { - *dst = JSON{} - return nil - } - - *dst = JSON{Bytes: src, Valid: true} - return nil -} - -func (dst *JSON) DecodeBinary(ci *ConnInfo, src []byte) error { - return dst.DecodeText(ci, src) -} - -func (JSON) PreferredParamFormat() int16 { - return TextFormatCode -} - -func (src JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - return append(buf, src.Bytes...), nil + dstBuf := make([]byte, len(src)) + copy(dstBuf, src) + return dstBuf, nil } -func (src JSON) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return src.EncodeText(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *JSON) Scan(src interface{}) error { +func (c JSONCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = JSON{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src JSON) Value() (driver.Value, error) { - if !src.Valid { return nil, nil } - return src.Bytes, nil -} - -func (src JSON) MarshalJSON() ([]byte, error) { - if !src.Valid { - return []byte("null"), nil - } - return src.Bytes, nil -} - -func (dst *JSON) UnmarshalJSON(b []byte) error { - if b == nil || string(b) == "null" { - *dst = JSON{} - } else { - *dst = JSON{Bytes: b, Valid: true} - } - return nil + var dst interface{} + err := json.Unmarshal(src, &dst) + return dst, err } diff --git a/pgtype/json_test.go b/pgtype/json_test.go index cb5162d3..156217ac 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -1,177 +1,52 @@ package pgtype_test import ( - "bytes" - "reflect" "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestJSONTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "json", []interface{}{ - &pgtype.JSON{Bytes: []byte("{}"), Valid: true}, - &pgtype.JSON{Bytes: []byte("null"), Valid: true}, - &pgtype.JSON{Bytes: []byte("42"), Valid: true}, - &pgtype.JSON{Bytes: []byte(`"hello"`), Valid: true}, - &pgtype.JSON{}, - }) -} +func isExpectedEqMap(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + aa := a.(map[string]interface{}) + bb := v.(map[string]interface{}) -func TestJSONSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.JSON - }{ - {source: "{}", result: pgtype.JSON{Bytes: []byte("{}"), Valid: true}}, - {source: []byte("{}"), result: pgtype.JSON{Bytes: []byte("{}"), Valid: true}}, - {source: ([]byte)(nil), result: pgtype.JSON{}}, - {source: (*string)(nil), result: pgtype.JSON{}}, - {source: []int{1, 2, 3}, result: pgtype.JSON{Bytes: []byte("[1,2,3]"), Valid: true}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Valid: true}}, - } - - for i, tt := range successfulTests { - var d pgtype.JSON - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) + if (aa == nil) != (bb == nil) { + return false } - if !reflect.DeepEqual(d, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + if aa == nil { + return true } + + if len(aa) != len(bb) { + return false + } + + for k := range aa { + if aa[k] != bb[k] { + return false + } + } + + return true } } -func TestJSONAssignTo(t *testing.T) { - var s string - var ps *string - var b []byte - - rawStringTests := []struct { - src pgtype.JSON - dst *string - expected string - }{ - {src: pgtype.JSON{Bytes: []byte("{}"), Valid: true}, dst: &s, expected: "{}"}, - } - - for i, tt := range rawStringTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - rawBytesTests := []struct { - src pgtype.JSON - dst *[]byte - expected []byte - }{ - {src: pgtype.JSON{Bytes: []byte("{}"), Valid: true}, dst: &b, expected: []byte("{}")}, - {src: pgtype.JSON{}, dst: &b, expected: (([]byte)(nil))}, - } - - for i, tt := range rawBytesTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if bytes.Compare(tt.expected, *tt.dst) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - var mapDst map[string]interface{} - type structDst struct { +func TestJSONCodec(t *testing.T) { + type jsonStruct struct { Name string `json:"name"` Age int `json:"age"` } - var strDst structDst - unmarshalTests := []struct { - src pgtype.JSON - dst interface{} - expected interface{} - }{ - {src: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Valid: true}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.JSON{Bytes: []byte(`{"name":"John","age":42}`), Valid: true}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, - } - for i, tt := range unmarshalTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.JSON - dst **string - expected *string - }{ - {src: pgtype.JSON{}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } -} - -func TestJSONMarshalJSON(t *testing.T) { - successfulTests := []struct { - source pgtype.JSON - result string - }{ - {source: pgtype.JSON{}, result: "null"}, - {source: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Valid: true}, result: "{\"a\": 1}"}, - } - for i, tt := range successfulTests { - r, err := tt.source.MarshalJSON() - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r) != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) - } - } -} - -func TestJSONUnmarshalJSON(t *testing.T) { - successfulTests := []struct { - source string - result pgtype.JSON - }{ - {source: "null", result: pgtype.JSON{}}, - {source: "{\"a\": 1}", result: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Valid: true}}, - } - for i, tt := range successfulTests { - var r pgtype.JSON - err := r.UnmarshalJSON([]byte(tt.source)) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if string(r.Bytes) != string(tt.result.Bytes) || r.Valid != tt.result.Valid { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } + testPgxCodec(t, "json", []PgxTranscodeTestCase{ + {[]byte("{}"), new([]byte), isExpectedEqBytes([]byte("{}"))}, + {[]byte("null"), new([]byte), isExpectedEqBytes([]byte("null"))}, + {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, + {[]byte(`"hello"`), new([]byte), isExpectedEqBytes([]byte(`"hello"`))}, + {[]byte(`"hello"`), new(string), isExpectedEq(`"hello"`)}, + {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, + {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, + {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, + }) } diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 38d56499..6e329150 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -2,35 +2,64 @@ package pgtype import ( "database/sql/driver" + "encoding/json" "fmt" ) -type JSONB JSON +type JSONBCodec struct{} -func (dst *JSONB) Set(src interface{}) error { - return (*JSON)(dst).Set(src) +func (JSONBCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode } -func (dst JSONB) Get() interface{} { - return (JSON)(dst).Get() -} - -func (src *JSONB) AssignTo(dst interface{}) error { - return (*JSON)(src).AssignTo(dst) -} - -func (JSONB) PreferredResultFormat() int16 { +func (JSONBCodec) PreferredFormat() int16 { return TextFormatCode } -func (dst *JSONB) DecodeText(ci *ConnInfo, src []byte) error { - return (*JSON)(dst).DecodeText(ci, src) +func (JSONBCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + plan := JSONCodec{}.PlanEncode(ci, oid, TextFormatCode, value) + if plan != nil { + return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan} + } + case TextFormatCode: + return JSONCodec{}.PlanEncode(ci, oid, format, value) + } + + return nil } -func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { +type encodePlanJSONBCodecBinaryWrapper struct { + textPlan EncodePlan +} + +func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + buf = append(buf, 1) + return plan.textPlan.Encode(value, buf) +} + +func (JSONBCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case BinaryFormatCode: + plan := JSONCodec{}.PlanScan(ci, oid, TextFormatCode, target, actualTarget) + if plan != nil { + return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan} + } + case TextFormatCode: + return JSONCodec{}.PlanScan(ci, oid, format, target, actualTarget) + } + + return nil +} + +type scanPlanJSONBCodecBinaryUnwrapper struct { + textPlan ScanPlan +} + +func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { - *dst = JSONB{} - return nil + return plan.textPlan.Scan(ci, oid, formatCode, src, dst) } if len(src) == 0 { @@ -41,42 +70,58 @@ func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { return fmt.Errorf("unknown jsonb version number %d", src[0]) } - *dst = JSONB{Bytes: src[1:], Valid: true} - return nil - + return plan.textPlan.Scan(ci, oid, formatCode, src[1:], dst) } -func (JSONB) PreferredParamFormat() int16 { - return TextFormatCode -} - -func (src JSONB) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (JSON)(src).EncodeText(ci, buf) -} - -func (src JSONB) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { +func (c JSONBCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { return nil, nil } - buf = append(buf, 1) - return append(buf, src.Bytes...), nil + switch format { + case BinaryFormatCode: + if len(src) == 0 { + return nil, fmt.Errorf("jsonb too short") + } + + if src[0] != 1 { + return nil, fmt.Errorf("unknown jsonb version number %d", src[0]) + } + + dstBuf := make([]byte, len(src)-1) + copy(dstBuf, src[1:]) + return dstBuf, nil + case TextFormatCode: + dstBuf := make([]byte, len(src)) + copy(dstBuf, src) + return dstBuf, nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } } -// Scan implements the database/sql Scanner interface. -func (dst *JSONB) Scan(src interface{}) error { - return (*JSON)(dst).Scan(src) -} +func (c JSONBCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } -// Value implements the database/sql/driver Valuer interface. -func (src JSONB) Value() (driver.Value, error) { - return (JSON)(src).Value() -} + switch format { + case BinaryFormatCode: + if len(src) == 0 { + return nil, fmt.Errorf("jsonb too short") + } -func (src JSONB) MarshalJSON() ([]byte, error) { - return (JSON)(src).MarshalJSON() -} + if src[0] != 1 { + return nil, fmt.Errorf("unknown jsonb version number %d", src[0]) + } -func (dst *JSONB) UnmarshalJSON(b []byte) error { - return (*JSON)(dst).UnmarshalJSON(b) + src = src[1:] + case TextFormatCode: + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } + + var dst interface{} + err := json.Unmarshal(src, &dst) + return dst, err } diff --git a/pgtype/jsonb_array.go b/pgtype/jsonb_array.go deleted file mode 100644 index 81ed9f29..00000000 --- a/pgtype/jsonb_array.go +++ /dev/null @@ -1,504 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type JSONBArray struct { - Elements []JSONB - Dimensions []ArrayDimension - Valid bool -} - -func (dst *JSONBArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = JSONBArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []string: - if value == nil { - *dst = JSONBArray{} - } else if len(value) == 0 { - *dst = JSONBArray{Valid: true} - } else { - elements := make([]JSONB, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = JSONBArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case [][]byte: - if value == nil { - *dst = JSONBArray{} - } else if len(value) == 0 { - *dst = JSONBArray{Valid: true} - } else { - elements := make([]JSONB, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = JSONBArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []JSONB: - if value == nil { - *dst = JSONBArray{} - } else if len(value) == 0 { - *dst = JSONBArray{Valid: true} - } else { - *dst = JSONBArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = JSONBArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for JSONBArray", src) - } - if elementsLength == 0 { - *dst = JSONBArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to JSONBArray", src) - } - - *dst = JSONBArray{ - Elements: make([]JSONB, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]JSONB, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to JSONBArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *JSONBArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to JSONBArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in JSONBArray", err) - } - index++ - - return index, nil -} - -func (dst JSONBArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *JSONBArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from JSONBArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from JSONBArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *JSONBArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = JSONBArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []JSONB - - if len(uta.Elements) > 0 { - elements = make([]JSONB, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem JSONB - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = JSONBArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *JSONBArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = JSONBArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = JSONBArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]JSONB, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = JSONBArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src JSONBArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src JSONBArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("jsonb"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "jsonb") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *JSONBArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src JSONBArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/jsonb_array_test.go b/pgtype/jsonb_array_test.go deleted file mode 100644 index 0fc4d40e..00000000 --- a/pgtype/jsonb_array_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestJSONBArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "jsonb[]", []interface{}{ - &pgtype.JSONBArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.JSONBArray{ - Elements: []pgtype.JSONB{ - {Bytes: []byte(`"foo"`), Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.JSONBArray{}, - &pgtype.JSONBArray{ - Elements: []pgtype.JSONB{ - {Bytes: []byte(`"foo"`), Valid: true}, - {Bytes: []byte("null"), Valid: true}, - {Bytes: []byte("42"), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, - Valid: true, - }, - }) -} diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 3a0d62c2..282caeb1 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -1,142 +1,25 @@ package pgtype_test import ( - "bytes" - "reflect" "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestJSONBTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - if _, ok := conn.ConnInfo().DataTypeForName("jsonb"); !ok { - t.Skip("Skipping due to no jsonb type") - } - - testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ - &pgtype.JSONB{Bytes: []byte("{}"), Valid: true}, - &pgtype.JSONB{Bytes: []byte("null"), Valid: true}, - &pgtype.JSONB{Bytes: []byte("42"), Valid: true}, - &pgtype.JSONB{Bytes: []byte(`"hello"`), Valid: true}, - &pgtype.JSONB{}, - }) -} - -func TestJSONBSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.JSONB - }{ - {source: "{}", result: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}}, - {source: []byte("{}"), result: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}}, - {source: ([]byte)(nil), result: pgtype.JSONB{}}, - {source: (*string)(nil), result: pgtype.JSONB{}}, - {source: []int{1, 2, 3}, result: pgtype.JSONB{Bytes: []byte("[1,2,3]"), Valid: true}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Valid: true}}, - } - - for i, tt := range successfulTests { - var d pgtype.JSONB - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(d, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } -} - -func TestJSONBAssignTo(t *testing.T) { - var s string - var ps *string - var b []byte - - rawStringTests := []struct { - src pgtype.JSONB - dst *string - expected string - }{ - {src: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}, dst: &s, expected: "{}"}, - } - - for i, tt := range rawStringTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - rawBytesTests := []struct { - src pgtype.JSONB - dst *[]byte - expected []byte - }{ - {src: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}, dst: &b, expected: []byte("{}")}, - {src: pgtype.JSONB{}, dst: &b, expected: (([]byte)(nil))}, - } - - for i, tt := range rawBytesTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if bytes.Compare(tt.expected, *tt.dst) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - var mapDst map[string]interface{} - type structDst struct { + type jsonStruct struct { Name string `json:"name"` Age int `json:"age"` } - var strDst structDst - unmarshalTests := []struct { - src pgtype.JSONB - dst interface{} - expected interface{} - }{ - {src: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Valid: true}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.JSONB{Bytes: []byte(`{"name":"John","age":42}`), Valid: true}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, - } - for i, tt := range unmarshalTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.JSONB - dst **string - expected *string - }{ - {src: pgtype.JSONB{}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } + testPgxCodec(t, "jsonb", []PgxTranscodeTestCase{ + {[]byte("{}"), new([]byte), isExpectedEqBytes([]byte("{}"))}, + {[]byte("null"), new([]byte), isExpectedEqBytes([]byte("null"))}, + {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, + {[]byte(`"hello"`), new([]byte), isExpectedEqBytes([]byte(`"hello"`))}, + {[]byte(`"hello"`), new(string), isExpectedEq(`"hello"`)}, + {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, + {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, + {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, + }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index a6d77356..744671ab 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -27,6 +27,7 @@ const ( XIDOID = 28 CIDOID = 29 JSONOID = 114 + JSONArrayOID = 199 PointOID = 600 LsegOID = 601 PathOID = 602 @@ -289,6 +290,8 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID}) ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) + ci.RegisterDataType(DataType{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementCodec: JSONBCodec{}, ElementOID: JSONBOID}}) + ci.RegisterDataType(DataType{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementCodec: JSONCodec{}, ElementOID: JSONOID}}) ci.RegisterDataType(DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: VarcharOID}}) ci.RegisterDataType(DataType{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementCodec: BitsCodec{}, ElementOID: BitOID}}) ci.RegisterDataType(DataType{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementCodec: BitsCodec{}, ElementOID: VarbitOID}}) @@ -316,9 +319,8 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) // ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) ci.RegisterDataType(DataType{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) - ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) - ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) - ci.RegisterDataType(DataType{Value: &JSONBArray{}, Name: "_jsonb", OID: JSONBArrayOID}) + ci.RegisterDataType(DataType{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) + ci.RegisterDataType(DataType{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) ci.RegisterDataType(DataType{Name: "line", OID: LineOID, Codec: LineCodec{}}) ci.RegisterDataType(DataType{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) diff --git a/stdlib/sql.go b/stdlib/sql.go index cbb8544e..40693ded 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -670,25 +670,15 @@ func (r *Rows) Next(dest []driver.Value) error { err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) return d, err } - case pgtype.JSONOID: - var d pgtype.JSON + case pgtype.JSONOID, pgtype.JSONBOID: + var d []byte scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) if err != nil { return nil, err } - return d.Value() - } - case pgtype.JSONBOID: - var d pgtype.JSONB - scanPlan := ci.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) - if err != nil { - return nil, err - } - return d.Value() + return d, nil } case pgtype.TimestampOID: var d pgtype.Timestamp diff --git a/values.go b/values.go index a60d4129..b5ce4f7c 100644 --- a/values.go +++ b/values.go @@ -35,32 +35,6 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e } switch arg := arg.(type) { - - // https://github.com/jackc/pgx/issues/409 Changed JSON and JSONB to surface - // []byte to database/sql instead of string. But that caused problems with the - // simple protocol because the driver.Valuer case got taken before the - // pgtype.TextEncoder case. And driver.Valuer needed to be first in the usual - // case because of https://github.com/jackc/pgx/issues/339. So instead we - // special case JSON and JSONB. - case *pgtype.JSON: - buf, err := arg.EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil - case *pgtype.JSONB: - buf, err := arg.EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil - case driver.Valuer: return callValuerValue(arg) case pgtype.TextEncoder: From 99fb8cf2f33d61b19dc417af17d02fc82eac306e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 18 Jan 2022 21:49:38 -0600 Subject: [PATCH 0851/1158] Convert timestamp and timestamptz to Codec --- pgtype/builtin_wrappers.go | 44 +++ pgtype/pgtype.go | 8 +- pgtype/timestamp.go | 410 ++++++++++++++----------- pgtype/timestamp_array.go | 505 ------------------------------- pgtype/timestamp_array_test.go | 307 ------------------- pgtype/timestamp_test.go | 186 ++---------- pgtype/timestamptz.go | 437 ++++++++++++++------------ pgtype/timestamptz_array.go | 505 ------------------------------- pgtype/timestamptz_array_test.go | 343 --------------------- pgtype/timestamptz_test.go | 181 ++--------- pgtype/zeronull/timestamp.go | 82 ++--- pgtype/zeronull/timestamptz.go | 106 +++---- 12 files changed, 627 insertions(+), 2487 deletions(-) delete mode 100644 pgtype/timestamp_array.go delete mode 100644 pgtype/timestamp_array_test.go delete mode 100644 pgtype/timestamptz_array.go delete mode 100644 pgtype/timestamptz_array_test.go diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 5689b321..3c8d23fb 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -342,6 +342,50 @@ func (w timeWrapper) DateValue() (Date, error) { return Date{Time: time.Time(w), Valid: true}, nil } +func (w *timeWrapper) ScanTimestamp(v Timestamp) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Time") + } + + switch v.InfinityModifier { + case None: + *w = timeWrapper(v.Time) + return nil + case Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +func (w timeWrapper) TimestampValue() (Timestamp, error) { + return Timestamp{Time: time.Time(w), Valid: true}, nil +} + +func (w *timeWrapper) ScanTimestamptz(v Timestamptz) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Time") + } + + switch v.InfinityModifier { + case None: + *w = timeWrapper(v.Time) + return nil + case Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +func (w timeWrapper) TimestamptzValue() (Timestamptz, error) { + return Timestamptz{Time: time.Time(w), Valid: true}, nil +} + type durationWrapper time.Duration func (w *durationWrapper) ScanInterval(v Interval) error { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 744671ab..5072d061 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -287,8 +287,8 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: NameOID}}) ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: TextOID}}) - ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) - ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID}) + ci.RegisterDataType(DataType{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementCodec: TimestampCodec{}, ElementOID: TimestampOID}}) + ci.RegisterDataType(DataType{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementCodec: TimestamptzCodec{}, ElementOID: TimestamptzOID}}) ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) ci.RegisterDataType(DataType{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementCodec: JSONBCodec{}, ElementOID: JSONBOID}}) ci.RegisterDataType(DataType{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementCodec: JSONCodec{}, ElementOID: JSONOID}}) @@ -335,8 +335,8 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "text", OID: TextOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID}) ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID}) - ci.RegisterDataType(DataType{Value: &Timestamp{}, Name: "timestamp", OID: TimestampOID}) - ci.RegisterDataType(DataType{Value: &Timestamptz{}, Name: "timestamptz", OID: TimestamptzOID}) + ci.RegisterDataType(DataType{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) + ci.RegisterDataType(DataType{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) // ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) // ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) // ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 882cd41a..374aafe4 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -11,203 +11,42 @@ import ( const pgTimestampFormat = "2006-01-02 15:04:05.999999999" -// Timestamp represents the PostgreSQL timestamp type. The PostgreSQL -// timestamp does not have a time zone. This presents a problem when -// translating to and from time.Time which requires a time zone. It is highly -// recommended to use timestamptz whenever possible. Timestamp methods either -// convert to UTC or return an error on non-UTC times. +type TimestampScanner interface { + ScanTimestamp(v Timestamp) error +} + +type TimestampValuer interface { + TimestampValue() (Timestamp, error) +} + +// Timestamp represents the PostgreSQL timestamp type. type Timestamp struct { - Time time.Time // Time must always be in UTC. - Valid bool + Time time.Time // Time zone will be ignored when encoding to PostgreSQL. InfinityModifier InfinityModifier + Valid bool } -// Set converts src into a Timestamp and stores in dst. If src is a -// time.Time in a non-UTC time zone, the time zone is discarded. -func (dst *Timestamp) Set(src interface{}) error { - if src == nil { - *dst = Timestamp{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case time.Time: - *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Valid: true} - case *time.Time: - if value == nil { - *dst = Timestamp{} - } else { - return dst.Set(*value) - } - case InfinityModifier: - *dst = Timestamp{InfinityModifier: value, Valid: true} - default: - if originalSrc, ok := underlyingTimeType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Timestamp", value) - } - +func (ts *Timestamp) ScanTimestamp(v Timestamp) error { + *ts = v return nil } -func (dst Timestamp) Get() interface{} { - if !dst.Valid { - return nil - } - if dst.InfinityModifier != None { - return dst.InfinityModifier - } - return dst.Time -} - -func (src *Timestamp) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *time.Time: - if src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -// DecodeText decodes from src into dst. The decoded time is considered to -// be in UTC. -func (dst *Timestamp) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Timestamp{} - return nil - } - - sbuf := string(src) - switch sbuf { - case "infinity": - *dst = Timestamp{Valid: true, InfinityModifier: Infinity} - case "-infinity": - *dst = Timestamp{Valid: true, InfinityModifier: -Infinity} - default: - tim, err := time.Parse(pgTimestampFormat, sbuf) - if err != nil { - return err - } - - *dst = Timestamp{Time: tim, Valid: true} - } - - return nil -} - -// DecodeBinary decodes from src into dst. The decoded time is considered to -// be in UTC. -func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Timestamp{} - return nil - } - - if len(src) != 8 { - return fmt.Errorf("invalid length for timestamp: %v", len(src)) - } - - microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) - - switch microsecSinceY2K { - case infinityMicrosecondOffset: - *dst = Timestamp{Valid: true, InfinityModifier: Infinity} - case negativeInfinityMicrosecondOffset: - *dst = Timestamp{Valid: true, InfinityModifier: -Infinity} - default: - tim := time.Unix( - microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, - (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), - ).UTC() - *dst = Timestamp{Time: tim, Valid: true} - } - - return nil -} - -// EncodeText writes the text encoding of src into w. If src.Time is not in -// the UTC time zone it returns an error. -func (src Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - if src.Time.Location() != time.UTC { - return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") - } - - var s string - - switch src.InfinityModifier { - case None: - s = src.Time.Truncate(time.Microsecond).Format(pgTimestampFormat) - case Infinity: - s = "infinity" - case NegativeInfinity: - s = "-infinity" - } - - return append(buf, s...), nil -} - -// EncodeBinary writes the binary encoding of src into w. If src.Time is not in -// the UTC time zone it returns an error. -func (src Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - if src.Time.Location() != time.UTC { - return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") - } - - var microsecSinceY2K int64 - switch src.InfinityModifier { - case None: - microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 - microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K - case Infinity: - microsecSinceY2K = infinityMicrosecondOffset - case NegativeInfinity: - microsecSinceY2K = negativeInfinityMicrosecondOffset - } - - return pgio.AppendInt64(buf, microsecSinceY2K), nil +func (ts Timestamp) TimestampValue() (Timestamp, error) { + return ts, nil } // Scan implements the database/sql Scanner interface. -func (dst *Timestamp) Scan(src interface{}) error { +func (ts *Timestamp) Scan(src interface{}) error { if src == nil { - *dst = Timestamp{} + *ts = Timestamp{} return nil } switch src := src.(type) { case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + return scanPlanTextTimestampToTimestampScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), ts) case time.Time: - *dst = Timestamp{Time: src, Valid: true} + *ts = Timestamp{Time: src, Valid: true} return nil } @@ -215,13 +54,214 @@ func (dst *Timestamp) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Timestamp) Value() (driver.Value, error) { - if !src.Valid { +func (ts Timestamp) Value() (driver.Value, error) { + if !ts.Valid { return nil, nil } - if src.InfinityModifier != None { - return src.InfinityModifier.String(), nil + if ts.InfinityModifier != None { + return ts.InfinityModifier.String(), nil } - return src.Time, nil + return ts.Time, nil +} + +type TimestampCodec struct{} + +func (TimestampCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TimestampCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TimestampCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(TimestampValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTimestampCodecBinary{} + case TextFormatCode: + return encodePlanTimestampCodecText{} + } + + return nil +} + +type encodePlanTimestampCodecBinary struct{} + +func (encodePlanTimestampCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + ts, err := value.(TimestampValuer).TimestampValue() + if err != nil { + return nil, err + } + + if !ts.Valid { + return nil, nil + } + + var microsecSinceY2K int64 + switch ts.InfinityModifier { + case None: + t := discardTimeZone(ts.Time) + microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + buf = pgio.AppendInt64(buf, microsecSinceY2K) + + return buf, nil +} + +type encodePlanTimestampCodecText struct{} + +func (encodePlanTimestampCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + ts, err := value.(TimestampValuer).TimestampValue() + if err != nil { + return nil, err + } + + var s string + + switch ts.InfinityModifier { + case None: + t := discardTimeZone(ts.Time) + s = t.Truncate(time.Microsecond).Format(pgTimestampFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + buf = append(buf, s...) + + return buf, nil +} + +func discardTimeZone(t time.Time) time.Time { + if t.Location() != time.UTC { + return time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC) + } + + return t +} + +func (TimestampCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case TimestampScanner: + return scanPlanBinaryTimestampToTimestampScanner{} + } + case TextFormatCode: + switch target.(type) { + case TimestampScanner: + return scanPlanTextTimestampToTimestampScanner{} + } + } + + return nil +} + +type scanPlanBinaryTimestampToTimestampScanner struct{} + +func (scanPlanBinaryTimestampToTimestampScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TimestampScanner) + + if src == nil { + return scanner.ScanTimestamp(Timestamp{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamp: %v", len(src)) + } + + var ts Timestamp + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + ts = Timestamp{Valid: true, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + ts = Timestamp{Valid: true, InfinityModifier: -Infinity} + default: + tim := time.Unix( + microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, + (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + ).UTC() + ts = Timestamp{Time: tim, Valid: true} + } + + return scanner.ScanTimestamp(ts) +} + +type scanPlanTextTimestampToTimestampScanner struct{} + +func (scanPlanTextTimestampToTimestampScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TimestampScanner) + + if src == nil { + return scanner.ScanTimestamp(Timestamp{}) + } + + var ts Timestamp + sbuf := string(src) + switch sbuf { + case "infinity": + ts = Timestamp{Valid: true, InfinityModifier: Infinity} + case "-infinity": + ts = Timestamp{Valid: true, InfinityModifier: -Infinity} + default: + tim, err := time.Parse(pgTimestampFormat, sbuf) + if err != nil { + return err + } + + ts = Timestamp{Time: tim, Valid: true} + } + + return scanner.ScanTimestamp(ts) +} + +func (c TimestampCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var ts Timestamp + err := codecScan(c, ci, oid, format, src, &ts) + if err != nil { + return nil, err + } + + if ts.InfinityModifier != None { + return ts.InfinityModifier.String(), nil + } + + return ts.Time, nil +} + +func (c TimestampCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var ts Timestamp + err := codecScan(c, ci, oid, format, src, &ts) + if err != nil { + return nil, err + } + + if ts.InfinityModifier != None { + return ts.InfinityModifier, nil + } + + return ts.Time, nil } diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go deleted file mode 100644 index fbf7c48a..00000000 --- a/pgtype/timestamp_array.go +++ /dev/null @@ -1,505 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - "time" - - "github.com/jackc/pgio" -) - -type TimestampArray struct { - Elements []Timestamp - Dimensions []ArrayDimension - Valid bool -} - -func (dst *TimestampArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = TimestampArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []time.Time: - if value == nil { - *dst = TimestampArray{} - } else if len(value) == 0 { - *dst = TimestampArray{Valid: true} - } else { - elements := make([]Timestamp, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestampArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*time.Time: - if value == nil { - *dst = TimestampArray{} - } else if len(value) == 0 { - *dst = TimestampArray{Valid: true} - } else { - elements := make([]Timestamp, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestampArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Timestamp: - if value == nil { - *dst = TimestampArray{} - } else if len(value) == 0 { - *dst = TimestampArray{Valid: true} - } else { - *dst = TimestampArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = TimestampArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for TimestampArray", src) - } - if elementsLength == 0 { - *dst = TimestampArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to TimestampArray", src) - } - - *dst = TimestampArray{ - Elements: make([]Timestamp, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Timestamp, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to TimestampArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *TimestampArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to TimestampArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in TimestampArray", err) - } - index++ - - return index, nil -} - -func (dst TimestampArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *TimestampArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from TimestampArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from TimestampArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TimestampArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Timestamp - - if len(uta.Elements) > 0 { - elements = make([]Timestamp, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Timestamp - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = TimestampArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TimestampArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = TimestampArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Timestamp, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = TimestampArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("timestamp"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "timestamp") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *TimestampArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src TimestampArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/timestamp_array_test.go b/pgtype/timestamp_array_test.go deleted file mode 100644 index f78bf14e..00000000 --- a/pgtype/timestamp_array_test.go +++ /dev/null @@ -1,307 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestTimestampArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp[]", []interface{}{ - &pgtype.TimestampArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.TimestampArray{}, - &pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {}, - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }, func(a, b interface{}) bool { - ata := a.(pgtype.TimestampArray) - bta := b.(pgtype.TimestampArray) - - if len(ata.Elements) != len(bta.Elements) || ata.Valid != bta.Valid { - return false - } - - for i := range ata.Elements { - ae, be := ata.Elements[i], bta.Elements[i] - if !(ae.Time.Equal(be.Time) && ae.Valid == be.Valid && ae.InfinityModifier == be.InfinityModifier) { - return false - } - } - - return true - }) -} - -func TestTimestampArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.TimestampArray - }{ - { - source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - result: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]time.Time)(nil)), - result: pgtype.TimestampArray{}, - }, - { - source: [][]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - result: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - result: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.TimestampArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTimestampArrayAssignTo(t *testing.T) { - var timeSlice []time.Time - var timeSliceDim2 [][]time.Time - var timeSliceDim4 [][][][]time.Time - var timeArrayDim2 [2][1]time.Time - var timeArrayDim4 [2][1][1][3]time.Time - - simpleTests := []struct { - src pgtype.TimestampArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &timeSlice, - expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - }, - { - src: pgtype.TimestampArray{}, - dst: &timeSlice, - expected: (([]time.Time)(nil)), - }, - { - src: pgtype.TimestampArray{Valid: true}, - dst: &timeSlice, - expected: []time.Time{}, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &timeSliceDim2, - expected: [][]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &timeSliceDim4, - expected: [][][][]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &timeArrayDim2, - expected: [2][1]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &timeArrayDim4, - expected: [2][1][1][3]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.TimestampArray - dst interface{} - }{ - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &timeSlice, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &timeArrayDim2, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &timeSlice, - }, - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &timeArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 85de5c94..1caca58b 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -2,7 +2,6 @@ package pgtype_test import ( "context" - "reflect" "testing" "time" @@ -11,41 +10,35 @@ import ( "github.com/stretchr/testify/require" ) -func TestTimestampTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ - &pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - &pgtype.Timestamp{}, - &pgtype.Timestamp{Valid: true, InfinityModifier: pgtype.Infinity}, - &pgtype.Timestamp{Valid: true, InfinityModifier: -pgtype.Infinity}, - }, func(a, b interface{}) bool { - at := a.(pgtype.Timestamp) - bt := b.(pgtype.Timestamp) +func TestTimestampCodec(t *testing.T) { + testPgxCodec(t, "timestamp", []PgxTranscodeTestCase{ + {time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC))}, + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC))}, + {time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC))}, - return at.Time.Equal(bt.Time) && at.Valid == bt.Valid && at.InfinityModifier == bt.InfinityModifier + // Nanosecond truncation + {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC))}, + {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC))}, + + {pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Timestamp), isExpectedEq(pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true})}, + {pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Timestamp), isExpectedEq(pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + {pgtype.Timestamp{}, new(pgtype.Timestamp), isExpectedEq(pgtype.Timestamp{})}, + {nil, new(*time.Time), isExpectedEq((*time.Time)(nil))}, }) } // https://github.com/jackc/pgx/v4/pgtype/pull/128 func TestTimestampTranscodeBigTimeBinary(t *testing.T) { conn := testutil.MustConnectPgx(t) - if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { - t.Skip("Skipping due to no line type") - } defer testutil.MustCloseContext(t, conn) in := &pgtype.Timestamp{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} var out pgtype.Timestamp - err := conn.QueryRow(context.Background(), "select $1::timestamptz", in).Scan(&out) + err := conn.QueryRow(context.Background(), "select $1::timestamp", in).Scan(&out) if err != nil { t.Fatal(err) } @@ -54,146 +47,11 @@ func TestTimestampTranscodeBigTimeBinary(t *testing.T) { require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) } -func TestTimestampNanosecondsTruncated(t *testing.T) { - tests := []struct { - input time.Time - expected time.Time - }{ - {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.UTC), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC)}, - {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.UTC), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC)}, - } - for i, tt := range tests { - { - ts := pgtype.Timestamp{Time: tt.input, Valid: true} - buf, err := ts.EncodeText(nil, nil) - if err != nil { - t.Errorf("%d. EncodeText failed - %v", i, err) - } - - ts.DecodeText(nil, buf) - if err != nil { - t.Errorf("%d. DecodeText failed - %v", i, err) - } - - if !(ts.Valid && ts.Time.Equal(tt.expected)) { - t.Errorf("%d. EncodeText did not truncate nanoseconds", i) - } - } - - { - ts := pgtype.Timestamp{Time: tt.input, Valid: true} - buf, err := ts.EncodeBinary(nil, nil) - if err != nil { - t.Errorf("%d. EncodeBinary failed - %v", i, err) - } - - ts.DecodeBinary(nil, buf) - if err != nil { - t.Errorf("%d. DecodeBinary failed - %v", i, err) - } - - if !(ts.Valid && ts.Time.Equal(tt.expected)) { - t.Errorf("%d. EncodeBinary did not truncate nanoseconds", i) - } - } - } -} - // https://github.com/jackc/pgtype/issues/74 -func TestTimestampDecodeTextInvalid(t *testing.T) { - tstz := &pgtype.Timestamp{} - err := tstz.DecodeText(nil, []byte(`eeeee`)) +func TestTimestampCodecDecodeTextInvalid(t *testing.T) { + c := &pgtype.TimestampCodec{} + var ts pgtype.Timestamp + plan := c.PlanScan(nil, pgtype.TimestampOID, pgtype.TextFormatCode, &ts, false) + err := plan.Scan(nil, pgtype.TimestampOID, pgtype.TextFormatCode, []byte(`eeeee`), &ts) require.Error(t, err) } - -func TestTimestampSet(t *testing.T) { - type _time time.Time - - successfulTests := []struct { - source interface{} - result pgtype.Timestamp - }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), Valid: true}}, - {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), Valid: true}}, - {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - {source: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - {source: pgtype.Infinity, result: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}}, - {source: pgtype.NegativeInfinity, result: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.Timestamp - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTimestampAssignTo(t *testing.T) { - var tim time.Time - var ptim *time.Time - - simpleTests := []struct { - src pgtype.Timestamp - dst interface{} - expected interface{} - }{ - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, - {src: pgtype.Timestamp{Time: time.Time{}}, dst: &ptim, expected: ((*time.Time)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Timestamp - dst interface{} - expected interface{} - }{ - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Timestamp - dst interface{} - }{ - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Valid: true}, dst: &tim}, - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Valid: true}, dst: &tim}, - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, dst: &tim}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 2a711ffa..eec1dca5 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -20,88 +20,252 @@ const ( infinityMicrosecondOffset = 9223372036854775807 ) -type Timestamptz struct { - Time time.Time - Valid bool - InfinityModifier InfinityModifier +type TimestamptzScanner interface { + ScanTimestamptz(v Timestamptz) error } -func (dst *Timestamptz) Set(src interface{}) error { +type TimestamptzValuer interface { + TimestamptzValue() (Timestamptz, error) +} + +// Timestamptz represents the PostgreSQL timestamptz type. +type Timestamptz struct { + Time time.Time + InfinityModifier InfinityModifier + Valid bool +} + +func (tstz *Timestamptz) ScanTimestamptz(v Timestamptz) error { + *tstz = v + return nil +} + +func (tstz Timestamptz) TimestamptzValue() (Timestamptz, error) { + return tstz, nil +} + +// Scan implements the database/sql Scanner interface. +func (tstz *Timestamptz) Scan(src interface{}) error { if src == nil { - *dst = Timestamptz{} + *tstz = Timestamptz{} return nil } - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } + switch src := src.(type) { + case string: + return scanPlanTextTimestamptzToTimestamptzScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), tstz) + case time.Time: + *tstz = Timestamptz{Time: src, Valid: true} + return nil } - switch value := src.(type) { - case time.Time: - *dst = Timestamptz{Time: value, Valid: true} - case *time.Time: - if value == nil { - *dst = Timestamptz{} - } else { - return dst.Set(*value) - } - case InfinityModifier: - *dst = Timestamptz{InfinityModifier: value, Valid: true} + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (tstz Timestamptz) Value() (driver.Value, error) { + if !tstz.Valid { + return nil, nil + } + + if tstz.InfinityModifier != None { + return tstz.InfinityModifier.String(), nil + } + return tstz.Time, nil +} + +func (tstz Timestamptz) MarshalJSON() ([]byte, error) { + if !tstz.Valid { + return []byte("null"), nil + } + + var s string + + switch tstz.InfinityModifier { + case None: + s = tstz.Time.Format(time.RFC3339Nano) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (tstz *Timestamptz) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *tstz = Timestamptz{} + return nil + } + + switch *s { + case "infinity": + *tstz = Timestamptz{Valid: true, InfinityModifier: Infinity} + case "-infinity": + *tstz = Timestamptz{Valid: true, InfinityModifier: -Infinity} default: - if originalSrc, ok := underlyingTimeType(src); ok { - return dst.Set(originalSrc) + // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz + tim, err := time.Parse(time.RFC3339Nano, *s) + if err != nil { + return err } - return fmt.Errorf("cannot convert %v to Timestamptz", value) + + *tstz = Timestamptz{Time: tim, Valid: true} } return nil } -func (dst Timestamptz) Get() interface{} { - if !dst.Valid { - return nil - } - if dst.InfinityModifier != None { - return dst.InfinityModifier - } - return dst.Time +type TimestamptzCodec struct{} + +func (TimestamptzCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode } -func (src *Timestamptz) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *time.Time: - if src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } +func (TimestamptzCodec) PreferredFormat() int16 { + return BinaryFormatCode } -func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { +func (TimestamptzCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(TimestamptzValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTimestamptzCodecBinary{} + case TextFormatCode: + return encodePlanTimestamptzCodecText{} + } + + return nil +} + +type encodePlanTimestamptzCodecBinary struct{} + +func (encodePlanTimestamptzCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + ts, err := value.(TimestamptzValuer).TimestamptzValue() + if err != nil { + return nil, err + } + + if !ts.Valid { + return nil, nil + } + + var microsecSinceY2K int64 + switch ts.InfinityModifier { + case None: + microsecSinceUnixEpoch := ts.Time.Unix()*1000000 + int64(ts.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + buf = pgio.AppendInt64(buf, microsecSinceY2K) + + return buf, nil +} + +type encodePlanTimestamptzCodecText struct{} + +func (encodePlanTimestamptzCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + ts, err := value.(TimestamptzValuer).TimestamptzValue() + if err != nil { + return nil, err + } + + var s string + + switch ts.InfinityModifier { + case None: + s = ts.Time.UTC().Truncate(time.Microsecond).Format(pgTimestamptzSecondFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + buf = append(buf, s...) + + return buf, nil +} + +func (TimestamptzCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case TimestamptzScanner: + return scanPlanBinaryTimestamptzToTimestamptzScanner{} + } + case TextFormatCode: + switch target.(type) { + case TimestamptzScanner: + return scanPlanTextTimestamptzToTimestamptzScanner{} + } + } + + return nil +} + +type scanPlanBinaryTimestamptzToTimestamptzScanner struct{} + +func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TimestamptzScanner) + if src == nil { - *dst = Timestamptz{} - return nil + return scanner.ScanTimestamptz(Timestamptz{}) } + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamptz: %v", len(src)) + } + + var tstz Timestamptz + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + tstz = Timestamptz{Valid: true, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + tstz = Timestamptz{Valid: true, InfinityModifier: -Infinity} + default: + tim := time.Unix( + microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, + (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + ) + tstz = Timestamptz{Time: tim, Valid: true} + } + + return scanner.ScanTimestamptz(tstz) +} + +type scanPlanTextTimestamptzToTimestamptzScanner struct{} + +func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TimestamptzScanner) + + if src == nil { + return scanner.ScanTimestamptz(Timestamptz{}) + } + + var tstz Timestamptz sbuf := string(src) switch sbuf { case "infinity": - *dst = Timestamptz{Valid: true, InfinityModifier: Infinity} + tstz = Timestamptz{Valid: true, InfinityModifier: Infinity} case "-infinity": - *dst = Timestamptz{Valid: true, InfinityModifier: -Infinity} + tstz = Timestamptz{Valid: true, InfinityModifier: -Infinity} default: var format string if len(sbuf) >= 9 && (sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+') { @@ -117,157 +281,44 @@ func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Timestamptz{Time: tim, Valid: true} + tstz = Timestamptz{Time: tim, Valid: true} } - return nil + return scanner.ScanTimestamptz(tstz) } -func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { +func (c TimestamptzCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { - *dst = Timestamptz{} - return nil - } - - if len(src) != 8 { - return fmt.Errorf("invalid length for timestamptz: %v", len(src)) - } - - microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) - - switch microsecSinceY2K { - case infinityMicrosecondOffset: - *dst = Timestamptz{Valid: true, InfinityModifier: Infinity} - case negativeInfinityMicrosecondOffset: - *dst = Timestamptz{Valid: true, InfinityModifier: -Infinity} - default: - tim := time.Unix( - microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, - (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), - ) - *dst = Timestamptz{Time: tim, Valid: true} - } - - return nil -} - -func (src Timestamptz) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - var s string - - switch src.InfinityModifier { - case None: - s = src.Time.UTC().Truncate(time.Microsecond).Format(pgTimestamptzSecondFormat) - case Infinity: - s = "infinity" - case NegativeInfinity: - s = "-infinity" - } - - return append(buf, s...), nil -} - -func (src Timestamptz) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - var microsecSinceY2K int64 - switch src.InfinityModifier { - case None: - microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 - microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K - case Infinity: - microsecSinceY2K = infinityMicrosecondOffset - case NegativeInfinity: - microsecSinceY2K = negativeInfinityMicrosecondOffset - } - - return pgio.AppendInt64(buf, microsecSinceY2K), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Timestamptz) Scan(src interface{}) error { - if src == nil { - *dst = Timestamptz{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - case time.Time: - *dst = Timestamptz{Time: src, Valid: true} - return nil - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Timestamptz) Value() (driver.Value, error) { - if !src.Valid { - return nil, nil - } - - if src.InfinityModifier != None { - return src.InfinityModifier.String(), nil - } - return src.Time, nil -} - -func (src Timestamptz) MarshalJSON() ([]byte, error) { - if !src.Valid { - return []byte("null"), nil - } - - var s string - - switch src.InfinityModifier { - case None: - s = src.Time.Format(time.RFC3339Nano) - case Infinity: - s = "infinity" - case NegativeInfinity: - s = "-infinity" - } - - return json.Marshal(s) -} - -func (dst *Timestamptz) UnmarshalJSON(b []byte) error { - var s *string - err := json.Unmarshal(b, &s) + var tstz Timestamptz + err := codecScan(c, ci, oid, format, src, &tstz) if err != nil { - return err + return nil, err } - if s == nil { - *dst = Timestamptz{} - return nil + if tstz.InfinityModifier != None { + return tstz.InfinityModifier.String(), nil } - switch *s { - case "infinity": - *dst = Timestamptz{Valid: true, InfinityModifier: Infinity} - case "-infinity": - *dst = Timestamptz{Valid: true, InfinityModifier: -Infinity} - default: - // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz - tim, err := time.Parse(time.RFC3339Nano, *s) - if err != nil { - return err - } + return tstz.Time, nil +} - *dst = Timestamptz{Time: tim, Valid: true} +func (c TimestamptzCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil } - return nil + var tstz Timestamptz + err := codecScan(c, ci, oid, format, src, &tstz) + if err != nil { + return nil, err + } + + if tstz.InfinityModifier != None { + return tstz.InfinityModifier, nil + } + + return tstz.Time, nil } diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go deleted file mode 100644 index 4523b251..00000000 --- a/pgtype/timestamptz_array.go +++ /dev/null @@ -1,505 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - "time" - - "github.com/jackc/pgio" -) - -type TimestamptzArray struct { - Elements []Timestamptz - Dimensions []ArrayDimension - Valid bool -} - -func (dst *TimestamptzArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = TimestamptzArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []time.Time: - if value == nil { - *dst = TimestamptzArray{} - } else if len(value) == 0 { - *dst = TimestamptzArray{Valid: true} - } else { - elements := make([]Timestamptz, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestamptzArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*time.Time: - if value == nil { - *dst = TimestamptzArray{} - } else if len(value) == 0 { - *dst = TimestamptzArray{Valid: true} - } else { - elements := make([]Timestamptz, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestamptzArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Timestamptz: - if value == nil { - *dst = TimestamptzArray{} - } else if len(value) == 0 { - *dst = TimestamptzArray{Valid: true} - } else { - *dst = TimestamptzArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = TimestamptzArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for TimestamptzArray", src) - } - if elementsLength == 0 { - *dst = TimestamptzArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to TimestamptzArray", src) - } - - *dst = TimestamptzArray{ - Elements: make([]Timestamptz, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Timestamptz, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to TimestamptzArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *TimestamptzArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to TimestamptzArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in TimestamptzArray", err) - } - index++ - - return index, nil -} - -func (dst TimestamptzArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *TimestamptzArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from TimestamptzArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from TimestamptzArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TimestamptzArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Timestamptz - - if len(uta.Elements) > 0 { - elements = make([]Timestamptz, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Timestamptz - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = TimestamptzArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TimestamptzArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = TimestamptzArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Timestamptz, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = TimestamptzArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("timestamptz"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "timestamptz") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *TimestamptzArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src TimestamptzArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/timestamptz_array_test.go b/pgtype/timestamptz_array_test.go deleted file mode 100644 index e00b7d7f..00000000 --- a/pgtype/timestamptz_array_test.go +++ /dev/null @@ -1,343 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestTimestamptzArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz[]", []interface{}{ - &pgtype.TimestamptzArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.TimestamptzArray{}, - &pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {}, - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }, func(a, b interface{}) bool { - ata := a.(pgtype.TimestamptzArray) - bta := b.(pgtype.TimestamptzArray) - - if len(ata.Elements) != len(bta.Elements) || ata.Valid != bta.Valid { - return false - } - - for i := range ata.Elements { - ae, be := ata.Elements[i], bta.Elements[i] - if !(ae.Time.Equal(be.Time) && ae.Valid == be.Valid && ae.InfinityModifier == be.InfinityModifier) { - return false - } - } - - return true - }) -} - -func TestTimestamptzArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.TimestamptzArray - }{ - { - source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - result: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]time.Time)(nil)), - result: pgtype.TimestamptzArray{}, - }, - { - source: [][]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - result: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - result: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - result: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - result: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.TimestamptzArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTimestamptzArrayAssignTo(t *testing.T) { - var timeSlice []time.Time - var timeSliceDim2 [][]time.Time - var timeSliceDim4 [][][][]time.Time - var timeArrayDim2 [2][1]time.Time - var timeArrayDim4 [2][1][1][3]time.Time - - simpleTests := []struct { - src pgtype.TimestamptzArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &timeSlice, - expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - }, - { - src: pgtype.TimestamptzArray{}, - dst: &timeSlice, - expected: (([]time.Time)(nil)), - }, - { - src: pgtype.TimestamptzArray{Valid: true}, - dst: &timeSlice, - expected: []time.Time{}, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &timeSliceDim2, - expected: [][]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &timeSliceDim4, - expected: [][][][]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &timeArrayDim2, - expected: [2][1]time.Time{ - {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &timeArrayDim4, - expected: [2][1][1][3]time.Time{ - {{{ - time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), - time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), - time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, - {{{ - time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), - time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), - time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.TimestamptzArray - dst interface{} - }{ - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &timeSlice, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &timeArrayDim2, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &timeSlice, - }, - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, - {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &timeArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index 332fc8a7..2a45d2cb 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -2,7 +2,6 @@ package pgtype_test import ( "context" - "reflect" "testing" "time" @@ -11,35 +10,29 @@ import ( "github.com/stretchr/testify/require" ) -func TestTimestamptzTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ - &pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, - &pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, - &pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, - &pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, - &pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, - &pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, - &pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Valid: true}, - &pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, - &pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Valid: true}, - &pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, - &pgtype.Timestamptz{}, - &pgtype.Timestamptz{Valid: true, InfinityModifier: pgtype.Infinity}, - &pgtype.Timestamptz{Valid: true, InfinityModifier: -pgtype.Infinity}, - }, func(a, b interface{}) bool { - at := a.(pgtype.Timestamptz) - bt := b.(pgtype.Timestamptz) +func TestTimestamptzCodec(t *testing.T) { + testPgxCodec(t, "timestamptz", []PgxTranscodeTestCase{ + {time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local))}, + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local))}, + {time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local))}, - return at.Time.Equal(bt.Time) && at.Valid == bt.Valid && at.InfinityModifier == bt.InfinityModifier + // Nanosecond truncation + {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.Local), new(time.Time), isExpectedEqTime(time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local))}, + {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.Local), new(time.Time), isExpectedEqTime(time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local))}, + + {pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Timestamptz), isExpectedEq(pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true})}, + {pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Timestamptz), isExpectedEq(pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + {pgtype.Timestamptz{}, new(pgtype.Timestamptz), isExpectedEq(pgtype.Timestamptz{})}, + {nil, new(*time.Time), isExpectedEq((*time.Time)(nil))}, }) } // https://github.com/jackc/pgx/v4/pgtype/pull/128 func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) { conn := testutil.MustConnectPgx(t) - if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { - t.Skip("Skipping due to no line type") - } defer testutil.MustCloseContext(t, conn) in := &pgtype.Timestamptz{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} @@ -54,149 +47,15 @@ func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) { require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) } -func TestTimestamptzNanosecondsTruncated(t *testing.T) { - tests := []struct { - input time.Time - expected time.Time - }{ - {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.Local), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local)}, - {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.Local), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local)}, - } - for i, tt := range tests { - { - tstz := pgtype.Timestamptz{Time: tt.input, Valid: true} - buf, err := tstz.EncodeText(nil, nil) - if err != nil { - t.Errorf("%d. EncodeText failed - %v", i, err) - } - - tstz.DecodeText(nil, buf) - if err != nil { - t.Errorf("%d. DecodeText failed - %v", i, err) - } - - if !(tstz.Valid && tstz.Time.Equal(tt.expected)) { - t.Errorf("%d. EncodeText did not truncate nanoseconds", i) - } - } - - { - tstz := pgtype.Timestamptz{Time: tt.input, Valid: true} - buf, err := tstz.EncodeBinary(nil, nil) - if err != nil { - t.Errorf("%d. EncodeBinary failed - %v", i, err) - } - - tstz.DecodeBinary(nil, buf) - if err != nil { - t.Errorf("%d. DecodeBinary failed - %v", i, err) - } - - if !(tstz.Valid && tstz.Time.Equal(tt.expected)) { - t.Errorf("%d. EncodeBinary did not truncate nanoseconds", i) - } - } - } -} - // https://github.com/jackc/pgtype/issues/74 func TestTimestamptzDecodeTextInvalid(t *testing.T) { - tstz := &pgtype.Timestamptz{} - err := tstz.DecodeText(nil, []byte(`eeeee`)) + c := &pgtype.TimestamptzCodec{} + var tstz pgtype.Timestamptz + plan := c.PlanScan(nil, pgtype.TimestamptzOID, pgtype.TextFormatCode, &tstz, false) + err := plan.Scan(nil, pgtype.TimestamptzOID, pgtype.TextFormatCode, []byte(`eeeee`), &tstz) require.Error(t, err) } -func TestTimestamptzSet(t *testing.T) { - type _time time.Time - - successfulTests := []struct { - source interface{} - result pgtype.Timestamptz - }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, - {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Valid: true}}, - {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, - {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), Valid: true}}, - {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, - {source: pgtype.Infinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}}, - {source: pgtype.NegativeInfinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.Timestamptz - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTimestamptzAssignTo(t *testing.T) { - var tim time.Time - var ptim *time.Time - - simpleTests := []struct { - src pgtype.Timestamptz - dst interface{} - expected interface{} - }{ - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - {src: pgtype.Timestamptz{Time: time.Time{}}, dst: &ptim, expected: ((*time.Time)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Timestamptz - dst interface{} - expected interface{} - }{ - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Timestamptz - dst interface{} - }{ - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Valid: true}, dst: &tim}, - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Valid: true}, dst: &tim}, - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, dst: &tim}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - func TestTimestamptzMarshalJSON(t *testing.T) { successfulTests := []struct { source pgtype.Timestamptz diff --git a/pgtype/zeronull/timestamp.go b/pgtype/zeronull/timestamp.go index d96dbf08..1c2a1a63 100644 --- a/pgtype/zeronull/timestamp.go +++ b/pgtype/zeronull/timestamp.go @@ -2,6 +2,7 @@ package zeronull import ( "database/sql/driver" + "fmt" "time" "github.com/jackc/pgx/v5/pgtype" @@ -9,68 +10,37 @@ import ( type Timestamp time.Time -func (dst *Timestamp) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Timestamp - err := nullable.DecodeText(ci, src) - if err != nil { - return err +func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error { + if !v.Valid { + *ts = Timestamp{} + return nil } - if nullable.Valid { - *dst = Timestamp(nullable.Time) - } else { - *dst = Timestamp{} + switch v.InfinityModifier { + case pgtype.None: + *ts = Timestamp(v.Time) + return nil + case pgtype.Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case pgtype.NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) } - - return nil } -func (dst *Timestamp) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Timestamp - err := nullable.DecodeBinary(ci, src) - if err != nil { - return err +func (ts Timestamp) TimestampValue() (pgtype.Timestamp, error) { + if time.Time(ts).IsZero() { + return pgtype.Timestamp{}, nil } - if nullable.Valid { - *dst = Timestamp(nullable.Time) - } else { - *dst = Timestamp{} - } - - return nil -} - -func (src Timestamp) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if (src == Timestamp{}) { - return nil, nil - } - - nullable := pgtype.Timestamp{ - Time: time.Time(src), - Valid: true, - } - - return nullable.EncodeText(ci, buf) -} - -func (src Timestamp) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if (src == Timestamp{}) { - return nil, nil - } - - nullable := pgtype.Timestamp{ - Time: time.Time(src), - Valid: true, - } - - return nullable.EncodeBinary(ci, buf) + return pgtype.Timestamp{Time: time.Time(ts), Valid: true}, nil } // Scan implements the database/sql Scanner interface. -func (dst *Timestamp) Scan(src interface{}) error { +func (ts *Timestamp) Scan(src interface{}) error { if src == nil { - *dst = Timestamp{} + *ts = Timestamp{} return nil } @@ -80,12 +50,16 @@ func (dst *Timestamp) Scan(src interface{}) error { return err } - *dst = Timestamp(nullable.Time) + *ts = Timestamp(nullable.Time) return nil } // Value implements the database/sql/driver Valuer interface. -func (src Timestamp) Value() (driver.Value, error) { - return pgtype.EncodeValueText(src) +func (ts Timestamp) Value() (driver.Value, error) { + if time.Time(ts).IsZero() { + return nil, nil + } + + return time.Time(ts), nil } diff --git a/pgtype/zeronull/timestamptz.go b/pgtype/zeronull/timestamptz.go index 46448607..c5378059 100644 --- a/pgtype/zeronull/timestamptz.go +++ b/pgtype/zeronull/timestamptz.go @@ -2,6 +2,7 @@ package zeronull import ( "database/sql/driver" + "fmt" "time" "github.com/jackc/pgx/v5/pgtype" @@ -9,83 +10,56 @@ import ( type Timestamptz time.Time -func (dst *Timestamptz) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Timestamptz - err := nullable.DecodeText(ci, src) - if err != nil { - return err - } - - if nullable.Valid { - *dst = Timestamptz(nullable.Time) - } else { - *dst = Timestamptz{} - } - - return nil -} - -func (dst *Timestamptz) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.Timestamptz - err := nullable.DecodeBinary(ci, src) - if err != nil { - return err - } - - if nullable.Valid { - *dst = Timestamptz(nullable.Time) - } else { - *dst = Timestamptz{} - } - - return nil -} - -func (src Timestamptz) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if (src == Timestamptz{}) { - return nil, nil - } - - nullable := pgtype.Timestamptz{ - Time: time.Time(src), - Valid: true, - } - - return nullable.EncodeText(ci, buf) -} - -func (src Timestamptz) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if (src == Timestamptz{}) { - return nil, nil - } - - nullable := pgtype.Timestamptz{ - Time: time.Time(src), - Valid: true, - } - - return nullable.EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Timestamptz) Scan(src interface{}) error { - if src == nil { - *dst = Timestamptz{} +func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error { + if !v.Valid { + *ts = Timestamptz{} return nil } - var nullable pgtype.Timestamptz + switch v.InfinityModifier { + case pgtype.None: + *ts = Timestamptz(v.Time) + return nil + case pgtype.Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case pgtype.NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +func (ts Timestamptz) TimestamptzValue() (pgtype.Timestamptz, error) { + if time.Time(ts).IsZero() { + return pgtype.Timestamptz{}, nil + } + + return pgtype.Timestamptz{Time: time.Time(ts), Valid: true}, nil +} + +// Scan implements the database/sql Scanner interface. +func (ts *Timestamptz) Scan(src interface{}) error { + if src == nil { + *ts = Timestamptz{} + return nil + } + + var nullable pgtype.Timestamp err := nullable.Scan(src) if err != nil { return err } - *dst = Timestamptz(nullable.Time) + *ts = Timestamptz(nullable.Time) return nil } // Value implements the database/sql/driver Valuer interface. -func (src Timestamptz) Value() (driver.Value, error) { - return pgtype.EncodeValueText(src) +func (ts Timestamptz) Value() (driver.Value, error) { + if time.Time(ts).IsZero() { + return nil, nil + } + + return time.Time(ts), nil } From 05d532b5df9e43f11a521f61529404aedb578212 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Jan 2022 16:40:44 -0600 Subject: [PATCH 0852/1158] Fix connect when receiving NoticeResponse refs #102 --- pgconn.go | 2 +- pgconn_test.go | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index f8b8a659..7bf2f20e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -335,7 +335,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } return pgConn, nil - case *pgproto3.ParameterStatus: + case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse: // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.conn.Close() diff --git a/pgconn_test.go b/pgconn_test.go index b22792fb..32186fc6 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1298,6 +1298,7 @@ func TestConnOnNotice(t *testing.T) { config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { msg = notice.Message } + config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the message we expect. pgConn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) @@ -1954,7 +1955,11 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the messages we expect. + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) defer closeConn(t, pgConn) From bcf4931a7e4c54c639f8b3d805d5dffb9aba6df7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Jan 2022 17:56:07 -0600 Subject: [PATCH 0853/1158] Convert "char" to Codec --- pgtype/pgtype.go | 2 +- pgtype/qchar.go | 208 +++++++++++++++++++++---------------------- pgtype/qchar_test.go | 141 ++--------------------------- 3 files changed, 111 insertions(+), 240 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5072d061..476450a2 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -304,7 +304,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) ci.RegisterDataType(DataType{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) - ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) + ci.RegisterDataType(DataType{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) ci.RegisterDataType(DataType{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) ci.RegisterDataType(DataType{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) ci.RegisterDataType(DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) diff --git a/pgtype/qchar.go b/pgtype/qchar.go index e56bf142..28c91110 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -1,145 +1,141 @@ package pgtype import ( + "database/sql/driver" "fmt" "math" - "strconv" ) -// QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C +// QCharCodec is for PostgreSQL's special 8-bit-only "char" type more akin to the C // language's char type, or Go's byte type. (Note that the name in PostgreSQL // itself is "char", in double-quotes, and not char.) It gets used a lot in // PostgreSQL's system tables to hold a single ASCII character value (eg // pg_class.relkind). It is named Qchar for quoted char to disambiguate from SQL // standard type char. -// -// Not all possible values of QChar are representable in the text format. -// Therefore, QChar does not implement TextEncoder and TextDecoder. In -// addition, database/sql Scanner and database/sql/driver Value are not -// implemented. -type QChar struct { - Int int8 - Valid bool +type QCharCodec struct{} + +func (QCharCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode } -func (dst *QChar) Set(src interface{}) error { - if src == nil { - *dst = QChar{} - return nil - } +func (QCharCodec) PreferredFormat() int16 { + return BinaryFormatCode +} - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) +func (QCharCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch value.(type) { + case byte: + return encodePlanQcharCodecByte{} + case rune: + return encodePlanQcharCodecRune{} } } - switch value := src.(type) { - case int8: - *dst = QChar{Int: value, Valid: true} - case uint8: - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case int16: - if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case uint16: - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case int32: - if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case uint32: - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case int64: - if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case uint64: - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case int: - if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case uint: - if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Valid: true} - case string: - num, err := strconv.ParseInt(value, 10, 8) - if err != nil { - return err - } - *dst = QChar{Int: int8(num), Valid: true} - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to QChar", value) - } - return nil } -func (dst QChar) Get() interface{} { - if !dst.Valid { - return nil +type encodePlanQcharCodecByte struct{} + +func (encodePlanQcharCodecByte) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + b := value.(byte) + buf = append(buf, b) + return buf, nil +} + +type encodePlanQcharCodecRune struct{} + +func (encodePlanQcharCodecRune) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + r := value.(rune) + if r > math.MaxUint8 { + return nil, fmt.Errorf(`%v cannot be encoded to "char"`, r) } - return dst.Int + b := byte(r) + buf = append(buf, b) + return buf, nil } -func (src *QChar) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Valid, dst) +func (QCharCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch target.(type) { + case *byte: + return scanPlanQcharCodecByte{} + case *rune: + return scanPlanQcharCodecRune{} + } + } + + return nil } -func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { +type scanPlanQcharCodecByte struct{} + +func (scanPlanQcharCodecByte) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { - *dst = QChar{} - return nil + return fmt.Errorf("cannot scan null into %T", dst) } - if len(src) != 1 { + if len(src) > 1 { return fmt.Errorf(`invalid length for "char": %v`, len(src)) } - *dst = QChar{Int: int8(src[0]), Valid: true} + b := dst.(*byte) + // In the text format the zero value is returned as a zero byte value instead of 0 + if len(src) == 0 { + *b = 0 + } else { + *b = src[0] + } + return nil } -func (src QChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { +type scanPlanQcharCodecRune struct{} + +func (scanPlanQcharCodecRune) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) > 1 { + return fmt.Errorf(`invalid length for "char": %v`, len(src)) + } + + r := dst.(*rune) + // In the text format the zero value is returned as a zero byte value instead of 0 + if len(src) == 0 { + *r = 0 + } else { + *r = rune(src[0]) + } + + return nil +} + +func (c QCharCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { return nil, nil } - return append(buf, byte(src.Int)), nil + var r rune + err := codecScan(c, ci, oid, format, src, &r) + if err != nil { + return nil, err + } + return string(r), nil +} + +func (c QCharCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var r rune + err := codecScan(c, ci, oid, format, src, &r) + if err != nil { + return nil, err + } + return r, nil } diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go index cb9b6786..ec555eb2 100644 --- a/pgtype/qchar_test.go +++ b/pgtype/qchar_test.go @@ -2,142 +2,17 @@ package pgtype_test import ( "math" - "reflect" "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestQCharTranscode(t *testing.T) { - testutil.TestPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ - &pgtype.QChar{Int: math.MinInt8, Valid: true}, - &pgtype.QChar{Int: -1, Valid: true}, - &pgtype.QChar{Int: 0, Valid: true}, - &pgtype.QChar{Int: 1, Valid: true}, - &pgtype.QChar{Int: math.MaxInt8, Valid: true}, - &pgtype.QChar{Int: 0}, - }, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func TestQCharSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.QChar - }{ - {source: int8(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: int16(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: int32(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: int64(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: int8(-1), result: pgtype.QChar{Int: -1, Valid: true}}, - {source: int16(-1), result: pgtype.QChar{Int: -1, Valid: true}}, - {source: int32(-1), result: pgtype.QChar{Int: -1, Valid: true}}, - {source: int64(-1), result: pgtype.QChar{Int: -1, Valid: true}}, - {source: uint8(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: uint16(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: uint32(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: uint64(1), result: pgtype.QChar{Int: 1, Valid: true}}, - {source: "1", result: pgtype.QChar{Int: 1, Valid: true}}, - {source: _int8(1), result: pgtype.QChar{Int: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.QChar - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestQCharAssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - - simpleTests := []struct { - src pgtype.QChar - dst interface{} - expected interface{} - }{ - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i, expected: int(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.QChar{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.QChar{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.QChar - dst interface{} - expected interface{} - }{ - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, - {src: pgtype.QChar{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, +func TestQcharTranscode(t *testing.T) { + var tests []PgxTranscodeTestCase + for i := 0; i <= math.MaxUint8; i++ { + tests = append(tests, PgxTranscodeTestCase{rune(i), new(rune), isExpectedEq(rune(i))}) + tests = append(tests, PgxTranscodeTestCase{byte(i), new(byte), isExpectedEq(byte(i))}) } + tests = append(tests, PgxTranscodeTestCase{nil, new(*rune), isExpectedEq((*rune)(nil))}) + tests = append(tests, PgxTranscodeTestCase{nil, new(*byte), isExpectedEq((*byte)(nil))}) - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.QChar - dst interface{} - }{ - {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui8}, - {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui16}, - {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui32}, - {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui64}, - {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui}, - {src: pgtype.QChar{Int: 0}, dst: &i16}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } + testPgxCodec(t, `"char"`, tests) } From b2e5c4ff6e47347a5ccb5adfd470cdb419860c7f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Jan 2022 18:00:43 -0600 Subject: [PATCH 0854/1158] Add "char" array --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 476450a2..cb028677 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -45,6 +45,7 @@ const ( MacaddrOID = 829 InetOID = 869 BoolArrayOID = 1000 + QCharArrayOID = 1003 NameArrayOID = 1003 Int2ArrayOID = 1005 Int4ArrayOID = 1007 @@ -285,6 +286,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementCodec: PointCodec{}, ElementOID: PointOID}}) ci.RegisterDataType(DataType{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementCodec: PolygonCodec{}, ElementOID: PolygonOID}}) ci.RegisterDataType(DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: NameOID}}) + ci.RegisterDataType(DataType{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementCodec: QCharCodec{}, ElementOID: QCharOID}}) ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: TextOID}}) ci.RegisterDataType(DataType{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementCodec: TimestampCodec{}, ElementOID: TimestampOID}}) From 97443487ce93a473d1ba1569f70ef8797904838e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Jan 2022 20:07:09 -0600 Subject: [PATCH 0855/1158] Convert macaddr to Codec --- pgtype/builtin_wrappers.go | 13 + pgtype/macaddr.go | 228 ++++++++-------- pgtype/macaddr_array.go | 505 ----------------------------------- pgtype/macaddr_array_test.go | 262 ------------------ pgtype/macaddr_test.go | 103 +++---- pgtype/pgtype.go | 2 +- 6 files changed, 164 insertions(+), 949 deletions(-) delete mode 100644 pgtype/macaddr_array.go delete mode 100644 pgtype/macaddr_array_test.go diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 3c8d23fb..873afe53 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -299,6 +299,19 @@ type stringWrapper string func (w stringWrapper) SkipUnderlyingTypePlan() {} +func (w *stringWrapper) ScanText(v Text) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *string") + } + + *w = stringWrapper(v.String) + return nil +} + +func (w stringWrapper) TextValue() (Text, error) { + return Text{String: string(w), Valid: true}, nil +} + func (w *stringWrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *string") diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index 8d6ab720..0ac003ae 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -2,92 +2,135 @@ package pgtype import ( "database/sql/driver" - "fmt" "net" ) -type Macaddr struct { - Addr net.HardwareAddr - Valid bool +type MacaddrCodec struct{} + +func (MacaddrCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode } -func (dst *Macaddr) Set(src interface{}) error { - if src == nil { - *dst = Macaddr{} - return nil - } +func (MacaddrCodec) PreferredFormat() int16 { + return BinaryFormatCode +} - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } +func (MacaddrCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case net.HardwareAddr: + return encodePlanMacaddrCodecBinaryHardwareAddr{} + case TextValuer: + return encodePlanMacAddrCodecTextValuer{} - switch value := src.(type) { - case net.HardwareAddr: - addr := make(net.HardwareAddr, len(value)) - copy(addr, value) - *dst = Macaddr{Addr: addr, Valid: true} - case string: - addr, err := net.ParseMAC(value) - if err != nil { - return err } - *dst = Macaddr{Addr: addr, Valid: true} - case *net.HardwareAddr: - if value == nil { - *dst = Macaddr{} - } else { - return dst.Set(*value) + case TextFormatCode: + switch value.(type) { + case net.HardwareAddr: + return encodePlanMacaddrCodecTextHardwareAddr{} + case TextValuer: + return encodePlanTextCodecTextValuer{} } - case *string: - if value == nil { - *dst = Macaddr{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingPtrType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Macaddr", value) } return nil } -func (dst Macaddr) Get() interface{} { - if !dst.Valid { - return nil +type encodePlanMacaddrCodecBinaryHardwareAddr struct{} + +func (encodePlanMacaddrCodecBinaryHardwareAddr) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + addr := value.(net.HardwareAddr) + if addr == nil { + return nil, nil } - return dst.Addr + + return append(buf, addr...), nil } -func (src *Macaddr) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) +type encodePlanMacAddrCodecTextValuer struct{} + +func (encodePlanMacAddrCodecTextValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + t, err := value.(TextValuer).TextValue() + if err != nil { + return nil, err + } + if !t.Valid { + return nil, nil } - switch v := dst.(type) { - case *net.HardwareAddr: - *v = make(net.HardwareAddr, len(src.Addr)) - copy(*v, src.Addr) - return nil - case *string: - *v = src.Addr.String() - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) + addr, err := net.ParseMAC(t.String) + if err != nil { + return nil, err + } + + return append(buf, addr...), nil +} + +type encodePlanMacaddrCodecTextHardwareAddr struct{} + +func (encodePlanMacaddrCodecTextHardwareAddr) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + addr := value.(net.HardwareAddr) + if addr == nil { + return nil, nil + } + + return append(buf, addr.String()...), nil +} + +func (MacaddrCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case *net.HardwareAddr: + return scanPlanBinaryMacaddrToHardwareAddr{} + case TextScanner: + return scanPlanBinaryMacaddrToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *net.HardwareAddr: + return scanPlanTextMacaddrToHardwareAddr{} + case TextScanner: + return scanPlanTextAnyToTextScanner{} } - return fmt.Errorf("unable to assign to %T", dst) } + + return nil } -func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { +type scanPlanBinaryMacaddrToHardwareAddr struct{} + +func (scanPlanBinaryMacaddrToHardwareAddr) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + dstBuf := dst.(*net.HardwareAddr) if src == nil { - *dst = Macaddr{} + *dstBuf = nil + return nil + } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanBinaryMacaddrToTextScanner struct{} + +func (scanPlanBinaryMacaddrToTextScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TextScanner) + if src == nil { + return scanner.ScanText(Text{}) + } + + return scanner.ScanText(Text{String: net.HardwareAddr(src).String(), Valid: true}) +} + +type scanPlanTextMacaddrToHardwareAddr struct{} + +func (scanPlanTextMacaddrToHardwareAddr) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + p := dst.(*net.HardwareAddr) + + if src == nil { + *p = nil return nil } @@ -96,65 +139,24 @@ func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Macaddr{Addr: addr, Valid: true} - return nil -} - -func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Macaddr{} - return nil - } - - if len(src) != 6 { - return fmt.Errorf("Received an invalid size for a macaddr: %d", len(src)) - } - - addr := make(net.HardwareAddr, 6) - copy(addr, src) - - *dst = Macaddr{Addr: addr, Valid: true} + *p = addr return nil } -func (src Macaddr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, src.Addr.String()...), nil +func (c MacaddrCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) } -// EncodeBinary encodes src into w. -func (src Macaddr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, src.Addr...), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Macaddr) Scan(src interface{}) error { +func (c MacaddrCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Macaddr{} - return nil + return nil, nil } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + var addr net.HardwareAddr + err := codecScan(c, ci, oid, format, src, &addr) + if err != nil { + return nil, err } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Macaddr) Value() (driver.Value, error) { - return EncodeValueText(src) + return addr, nil } diff --git a/pgtype/macaddr_array.go b/pgtype/macaddr_array.go deleted file mode 100644 index 78a93a2d..00000000 --- a/pgtype/macaddr_array.go +++ /dev/null @@ -1,505 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "net" - "reflect" - - "github.com/jackc/pgio" -) - -type MacaddrArray struct { - Elements []Macaddr - Dimensions []ArrayDimension - Valid bool -} - -func (dst *MacaddrArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = MacaddrArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []net.HardwareAddr: - if value == nil { - *dst = MacaddrArray{} - } else if len(value) == 0 { - *dst = MacaddrArray{Valid: true} - } else { - elements := make([]Macaddr, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = MacaddrArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*net.HardwareAddr: - if value == nil { - *dst = MacaddrArray{} - } else if len(value) == 0 { - *dst = MacaddrArray{Valid: true} - } else { - elements := make([]Macaddr, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = MacaddrArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Macaddr: - if value == nil { - *dst = MacaddrArray{} - } else if len(value) == 0 { - *dst = MacaddrArray{Valid: true} - } else { - *dst = MacaddrArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = MacaddrArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for MacaddrArray", src) - } - if elementsLength == 0 { - *dst = MacaddrArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to MacaddrArray", src) - } - - *dst = MacaddrArray{ - Elements: make([]Macaddr, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Macaddr, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to MacaddrArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *MacaddrArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to MacaddrArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in MacaddrArray", err) - } - index++ - - return index, nil -} - -func (dst MacaddrArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *MacaddrArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]net.HardwareAddr: - *v = make([]net.HardwareAddr, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.HardwareAddr: - *v = make([]*net.HardwareAddr, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from MacaddrArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from MacaddrArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *MacaddrArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = MacaddrArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Macaddr - - if len(uta.Elements) > 0 { - elements = make([]Macaddr, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Macaddr - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = MacaddrArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *MacaddrArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = MacaddrArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = MacaddrArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Macaddr, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = MacaddrArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src MacaddrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("macaddr"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "macaddr") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *MacaddrArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src MacaddrArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/macaddr_array_test.go b/pgtype/macaddr_array_test.go deleted file mode 100644 index ac76a052..00000000 --- a/pgtype/macaddr_array_test.go +++ /dev/null @@ -1,262 +0,0 @@ -package pgtype_test - -import ( - "net" - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestMacaddrArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "macaddr[]", []interface{}{ - &pgtype.MacaddrArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.MacaddrArray{}, - }) -} - -func TestMacaddrArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.MacaddrArray - }{ - { - source: []net.HardwareAddr{mustParseMacaddr(t, "01:23:45:67:89:ab")}, - result: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]net.HardwareAddr)(nil)), - result: pgtype.MacaddrArray{}, - }, - { - source: [][]net.HardwareAddr{ - {mustParseMacaddr(t, "01:23:45:67:89:ab")}, - {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, - result: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]net.HardwareAddr{ - {{{ - mustParseMacaddr(t, "01:23:45:67:89:ab"), - mustParseMacaddr(t, "cd:ef:01:23:45:67"), - mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, - {{{ - mustParseMacaddr(t, "45:67:89:ab:cd:ef"), - mustParseMacaddr(t, "fe:dc:ba:98:76:54"), - mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, - result: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}, - {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Valid: true}, - {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Valid: true}, - {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Valid: true}, - {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]net.HardwareAddr{ - {mustParseMacaddr(t, "01:23:45:67:89:ab")}, - {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, - result: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]net.HardwareAddr{ - {{{ - mustParseMacaddr(t, "01:23:45:67:89:ab"), - mustParseMacaddr(t, "cd:ef:01:23:45:67"), - mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, - {{{ - mustParseMacaddr(t, "45:67:89:ab:cd:ef"), - mustParseMacaddr(t, "fe:dc:ba:98:76:54"), - mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, - result: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}, - {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Valid: true}, - {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Valid: true}, - {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Valid: true}, - {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.MacaddrArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestMacaddrArrayAssignTo(t *testing.T) { - var macaddrSlice []net.HardwareAddr - var macaddrSliceDim2 [][]net.HardwareAddr - var macaddrSliceDim4 [][][][]net.HardwareAddr - var macaddrArrayDim2 [2][1]net.HardwareAddr - var macaddrArrayDim4 [2][1][1][3]net.HardwareAddr - - simpleTests := []struct { - src pgtype.MacaddrArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &macaddrSlice, - expected: []net.HardwareAddr{mustParseMacaddr(t, "01:23:45:67:89:ab")}, - }, - { - src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &macaddrSlice, - expected: []net.HardwareAddr{nil}, - }, - { - src: pgtype.MacaddrArray{}, - dst: &macaddrSlice, - expected: (([]net.HardwareAddr)(nil)), - }, - { - src: pgtype.MacaddrArray{Valid: true}, - dst: &macaddrSlice, - expected: []net.HardwareAddr{}, - }, - { - src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &macaddrSliceDim2, - expected: [][]net.HardwareAddr{ - {mustParseMacaddr(t, "01:23:45:67:89:ab")}, - {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, - }, - { - src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}, - {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Valid: true}, - {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Valid: true}, - {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Valid: true}, - {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &macaddrSliceDim4, - expected: [][][][]net.HardwareAddr{ - {{{ - mustParseMacaddr(t, "01:23:45:67:89:ab"), - mustParseMacaddr(t, "cd:ef:01:23:45:67"), - mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, - {{{ - mustParseMacaddr(t, "45:67:89:ab:cd:ef"), - mustParseMacaddr(t, "fe:dc:ba:98:76:54"), - mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, - }, - { - src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &macaddrArrayDim2, - expected: [2][1]net.HardwareAddr{ - {mustParseMacaddr(t, "01:23:45:67:89:ab")}, - {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, - }, - { - src: pgtype.MacaddrArray{ - Elements: []pgtype.Macaddr{ - {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, - {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}, - {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Valid: true}, - {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Valid: true}, - {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Valid: true}, - {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &macaddrArrayDim4, - expected: [2][1][1][3]net.HardwareAddr{ - {{{ - mustParseMacaddr(t, "01:23:45:67:89:ab"), - mustParseMacaddr(t, "cd:ef:01:23:45:67"), - mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, - {{{ - mustParseMacaddr(t, "45:67:89:ab:cd:ef"), - mustParseMacaddr(t, "fe:dc:ba:98:76:54"), - mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go index 5b9d8d88..3e59c580 100644 --- a/pgtype/macaddr_test.go +++ b/pgtype/macaddr_test.go @@ -3,76 +3,43 @@ package pgtype_test import ( "bytes" "net" - "reflect" "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestMacaddrTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "macaddr", []interface{}{ - &pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, - &pgtype.Macaddr{}, +func isExpectedEqHardwareAddr(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + aa := a.(net.HardwareAddr) + vv := v.(net.HardwareAddr) + + if (aa == nil) != (vv == nil) { + return false + } + + if aa == nil { + return true + } + + return bytes.Compare(aa, vv) == 0 + } +} + +func TestMacaddrCodec(t *testing.T) { + testPgxCodec(t, "macaddr", []PgxTranscodeTestCase{ + { + mustParseMacaddr(t, "01:23:45:67:89:ab"), + new(net.HardwareAddr), + isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab")), + }, + { + "01:23:45:67:89:ab", + new(net.HardwareAddr), + isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab")), + }, + { + mustParseMacaddr(t, "01:23:45:67:89:ab"), + new(string), + isExpectedEq("01:23:45:67:89:ab"), + }, + {nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))}, }) } - -func TestMacaddrSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Macaddr - }{ - { - source: mustParseMacaddr(t, "01:23:45:67:89:ab"), - result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, - }, - { - source: "01:23:45:67:89:ab", - result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Macaddr - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestMacaddrAssignTo(t *testing.T) { - { - src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true} - var dst net.HardwareAddr - expected := mustParseMacaddr(t, "01:23:45:67:89:ab") - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if bytes.Compare([]byte(dst), []byte(expected)) != 0 { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true} - var dst string - expected := "01:23:45:67:89:ab" - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index cb028677..0cc0c062 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -325,7 +325,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) ci.RegisterDataType(DataType{Name: "line", OID: LineOID, Codec: LineCodec{}}) ci.RegisterDataType(DataType{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) - ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) + ci.RegisterDataType(DataType{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) ci.RegisterDataType(DataType{Name: "name", OID: NameOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) From 06f4e47750af9f1c026a1cdc3c0e281fe23d39fa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Jan 2022 20:10:43 -0600 Subject: [PATCH 0856/1158] Add macaddr array --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 0cc0c062..7c1b6c07 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -66,6 +66,7 @@ const ( OIDArrayOID = 1028 ACLItemOID = 1033 ACLItemArrayOID = 1034 + MacaddrArrayOID = 1040 InetArrayOID = 1041 BPCharOID = 1042 VarcharOID = 1043 @@ -291,6 +292,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: TextOID}}) ci.RegisterDataType(DataType{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementCodec: TimestampCodec{}, ElementOID: TimestampOID}}) ci.RegisterDataType(DataType{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementCodec: TimestamptzCodec{}, ElementOID: TimestamptzOID}}) + ci.RegisterDataType(DataType{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementCodec: MacaddrCodec{}, ElementOID: MacaddrOID}}) ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) ci.RegisterDataType(DataType{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementCodec: JSONBCodec{}, ElementOID: JSONBOID}}) ci.RegisterDataType(DataType{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementCodec: JSONCodec{}, ElementOID: JSONOID}}) From b10eb89fe4631ac56de742168ffb08be559e2ddd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Jan 2022 20:22:53 -0600 Subject: [PATCH 0857/1158] Use wrapper to treat fmt.String as pgtype.TextValuer --- pgtype/builtin_wrappers.go | 8 ++++++++ pgtype/enum_codec.go | 2 -- pgtype/pgtype.go | 12 ++++++++++++ pgtype/text.go | 2 -- pgtype/text_test.go | 7 +++++++ 5 files changed, 27 insertions(+), 4 deletions(-) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 873afe53..5b6a032a 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -491,3 +491,11 @@ func (w mapStringToStringWrapper) HstoreValue() (Hstore, error) { } return hstore, nil } + +type fmtStringerWrapper struct { + s fmt.Stringer +} + +func (w fmtStringerWrapper) TextValue() (Text, error) { + return Text{String: w.s.String(), Valid: true}, nil +} diff --git a/pgtype/enum_codec.go b/pgtype/enum_codec.go index 9a37f1dd..d405245f 100644 --- a/pgtype/enum_codec.go +++ b/pgtype/enum_codec.go @@ -30,8 +30,6 @@ func (EnumCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interf return encodePlanTextCodecByteSlice{} case rune: return encodePlanTextCodecRune{} - case fmt.Stringer: - return encodePlanTextCodecStringer{} case TextValuer: return encodePlanTextCodecTextValuer{} } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 7c1b6c07..bcbb9f97 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1421,6 +1421,8 @@ func tryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNext return &wrapMapStringToPointerStringEncodePlan{}, mapStringToPointerStringWrapper(value), true case map[string]string: return &wrapMapStringToStringEncodePlan{}, mapStringToStringWrapper(value), true + case fmt.Stringer: + return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{value}, true } return nil, nil, false @@ -1616,6 +1618,16 @@ func (plan *wrapMapStringToStringEncodePlan) Encode(value interface{}, buf []byt return plan.next.Encode(mapStringToStringWrapper(value.(map[string]string)), buf) } +type wrapFmtStringerEncodePlan struct { + next EncodePlan +} + +func (plan *wrapFmtStringerEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapFmtStringerEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(fmtStringerWrapper{value.(fmt.Stringer)}, buf) +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. diff --git a/pgtype/text.go b/pgtype/text.go index 3cb1cfa3..3c73cc15 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -100,8 +100,6 @@ func (TextCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interf return encodePlanTextCodecByteSlice{} case rune: return encodePlanTextCodecRune{} - case fmt.Stringer: - return encodePlanTextCodecStringer{} case TextValuer: return encodePlanTextCodecTextValuer{} } diff --git a/pgtype/text_test.go b/pgtype/text_test.go index 27b01c15..f45978a7 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -9,6 +9,12 @@ import ( "github.com/stretchr/testify/require" ) +type someFmtStringer struct{} + +func (someFmtStringer) String() string { + return "some fmt.Stringer" +} + func TestTextCodec(t *testing.T) { for _, pgTypeName := range []string{"text", "varchar"} { testPgxCodec(t, pgTypeName, []PgxTranscodeTestCase{ @@ -24,6 +30,7 @@ func TestTextCodec(t *testing.T) { }, {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, {"foo", new(string), isExpectedEq("foo")}, + {someFmtStringer{}, new(string), isExpectedEq("some fmt.Stringer")}, {rune('R'), new(rune), isExpectedEq(rune('R'))}, }) } From 7a3bc454e0ad5f080c323507c353fcad132ebd64 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Jan 2022 20:40:37 -0600 Subject: [PATCH 0858/1158] Convert TID to Codec --- pgtype/pgtype.go | 2 +- pgtype/tid.go | 233 +++++++++++++++++++++++++++++---------------- pgtype/tid_test.go | 66 +++---------- 3 files changed, 167 insertions(+), 134 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index bcbb9f97..9b1ef595 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -337,7 +337,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) // ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) ci.RegisterDataType(DataType{Name: "text", OID: TextOID, Codec: TextCodec{}}) - ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID}) + ci.RegisterDataType(DataType{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID}) ci.RegisterDataType(DataType{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) ci.RegisterDataType(DataType{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) diff --git a/pgtype/tid.go b/pgtype/tid.go index 0108d219..624b3c2a 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -10,6 +10,14 @@ import ( "github.com/jackc/pgio" ) +type TIDScanner interface { + ScanTID(v TID) error +} + +type TIDValuer interface { + TIDValue() (TID, error) +} + // TID is PostgreSQL's Tuple Identifier type. // // When one does @@ -27,40 +35,148 @@ type TID struct { Valid bool } -func (dst *TID) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to TID", src) +func (b *TID) ScanTID(v TID) error { + *b = v + return nil } -func (dst TID) Get() interface{} { - if !dst.Valid { - return nil - } - return dst +func (b TID) TIDValue() (TID, error) { + return b, nil } -func (src *TID) AssignTo(dst interface{}) error { - if !src.Valid { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - - switch v := dst.(type) { - case *string: - *v = fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { +// Scan implements the database/sql Scanner interface. +func (dst *TID) Scan(src interface{}) error { if src == nil { *dst = TID{} return nil } + switch src := src.(type) { + case string: + return scanPlanTextAnyToTIDScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TID) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := TIDCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type TIDCodec struct{} + +func (TIDCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TIDCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TIDCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(TIDValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTIDCodecBinary{} + case TextFormatCode: + return encodePlanTIDCodecText{} + } + + return nil +} + +type encodePlanTIDCodecBinary struct{} + +func (encodePlanTIDCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + tid, err := value.(TIDValuer).TIDValue() + if err != nil { + return nil, err + } + + if !tid.Valid { + return nil, nil + } + + buf = pgio.AppendUint32(buf, tid.BlockNumber) + buf = pgio.AppendUint16(buf, tid.OffsetNumber) + return buf, nil +} + +type encodePlanTIDCodecText struct{} + +func (encodePlanTIDCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + tid, err := value.(TIDValuer).TIDValue() + if err != nil { + return nil, err + } + + if !tid.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`(%d,%d)`, tid.BlockNumber, tid.OffsetNumber)...) + return buf, nil +} + +func (TIDCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case TIDScanner: + return scanPlanBinaryTIDToTIDScanner{} + } + case TextFormatCode: + switch target.(type) { + case TIDScanner: + return scanPlanTextAnyToTIDScanner{} + } + } + + return nil +} + +type scanPlanBinaryTIDToTIDScanner struct{} + +func (scanPlanBinaryTIDToTIDScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TIDScanner) + + if src == nil { + return scanner.ScanTID(TID{}) + } + + if len(src) != 6 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + return scanner.ScanTID(TID{ + BlockNumber: binary.BigEndian.Uint32(src), + OffsetNumber: binary.BigEndian.Uint16(src[4:]), + Valid: true, + }) +} + +type scanPlanTextAnyToTIDScanner struct{} + +func (scanPlanTextAnyToTIDScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TIDScanner) + + if src == nil { + return scanner.ScanTID(TID{}) + } + if len(src) < 5 { return fmt.Errorf("invalid length for tid: %v", len(src)) } @@ -80,67 +196,22 @@ func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Valid: true} - return nil + return scanner.ScanTID(TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Valid: true}) } -func (dst *TID) DecodeBinary(ci *ConnInfo, src []byte) error { +func (c TIDCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c TIDCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = TID{} - return nil - } - - if len(src) != 6 { - return fmt.Errorf("invalid length for tid: %v", len(src)) - } - - *dst = TID{ - BlockNumber: binary.BigEndian.Uint32(src), - OffsetNumber: binary.BigEndian.Uint16(src[4:]), - Valid: true, - } - return nil -} - -func (src TID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - buf = append(buf, fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber)...) - return buf, nil -} - -func (src TID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil + var tid TID + err := codecScan(c, ci, oid, format, src, &tid) + if err != nil { + return nil, err } - - buf = pgio.AppendUint32(buf, src.BlockNumber) - buf = pgio.AppendUint16(buf, src.OffsetNumber) - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *TID) Scan(src interface{}) error { - if src == nil { - *dst = TID{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src TID) Value() (driver.Value, error) { - return EncodeValueText(src) + return tid, nil } diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go index ef24005a..4203a471 100644 --- a/pgtype/tid_test.go +++ b/pgtype/tid_test.go @@ -1,62 +1,24 @@ package pgtype_test import ( - "reflect" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestTIDTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "tid", []interface{}{ - &pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, - &pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, - &pgtype.TID{}, +func TestTIDCodec(t *testing.T) { + testPgxCodec(t, "tid", []PgxTranscodeTestCase{ + { + pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, + new(pgtype.TID), + isExpectedEq(pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}), + }, + { + pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, + new(pgtype.TID), + isExpectedEq(pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}), + }, + {pgtype.TID{}, new(pgtype.TID), isExpectedEq(pgtype.TID{})}, + {nil, new(pgtype.TID), isExpectedEq(pgtype.TID{})}, }) } - -func TestTIDAssignTo(t *testing.T) { - var s string - var sp *string - - simpleTests := []struct { - src pgtype.TID - dst interface{} - expected interface{} - }{ - {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, dst: &s, expected: "(42,43)"}, - {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, dst: &s, expected: "(4294967295,65535)"}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.TID - dst interface{} - expected interface{} - }{ - {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, dst: &sp, expected: "(42,43)"}, - {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, dst: &sp, expected: "(4294967295,65535)"}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} From 5ca29a014e52b891a8832d1a2990228fe679f8f1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Jan 2022 20:41:56 -0600 Subject: [PATCH 0859/1158] Add tid array --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 9b1ef595..5e7e2501 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -50,6 +50,7 @@ const ( Int2ArrayOID = 1005 Int4ArrayOID = 1007 TextArrayOID = 1009 + TIDArrayOID = 1010 ByteaArrayOID = 1001 XIDArrayOID = 1011 CIDArrayOID = 1012 @@ -293,6 +294,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementCodec: TimestampCodec{}, ElementOID: TimestampOID}}) ci.RegisterDataType(DataType{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementCodec: TimestamptzCodec{}, ElementOID: TimestamptzOID}}) ci.RegisterDataType(DataType{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementCodec: MacaddrCodec{}, ElementOID: MacaddrOID}}) + ci.RegisterDataType(DataType{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementCodec: TIDCodec{}, ElementOID: TIDOID}}) ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) ci.RegisterDataType(DataType{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementCodec: JSONBCodec{}, ElementOID: JSONBOID}}) ci.RegisterDataType(DataType{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementCodec: JSONCodec{}, ElementOID: JSONOID}}) From c8b87644014fb3b6d0ab2d02affcd057f9f15624 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Jan 2022 20:59:36 -0600 Subject: [PATCH 0860/1158] Allow scanning tid to string --- pgtype/tid.go | 24 ++++++++++++++++++++++++ pgtype/tid_test.go | 10 ++++++++++ 2 files changed, 34 insertions(+) diff --git a/pgtype/tid.go b/pgtype/tid.go index 624b3c2a..450cfbc9 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -137,6 +137,8 @@ func (TIDCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfac switch target.(type) { case TIDScanner: return scanPlanBinaryTIDToTIDScanner{} + case TextScanner: + return scanPlanBinaryTIDToTextScanner{} } case TextFormatCode: switch target.(type) { @@ -168,6 +170,28 @@ func (scanPlanBinaryTIDToTIDScanner) Scan(ci *ConnInfo, oid uint32, formatCode i }) } +type scanPlanBinaryTIDToTextScanner struct{} + +func (scanPlanBinaryTIDToTextScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + if len(src) != 6 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + blockNumber := binary.BigEndian.Uint32(src) + offsetNumber := binary.BigEndian.Uint16(src[4:]) + + return scanner.ScanText(Text{ + String: fmt.Sprintf(`(%d,%d)`, blockNumber, offsetNumber), + Valid: true, + }) +} + type scanPlanTextAnyToTIDScanner struct{} func (scanPlanTextAnyToTIDScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go index 4203a471..0d3afe5e 100644 --- a/pgtype/tid_test.go +++ b/pgtype/tid_test.go @@ -18,6 +18,16 @@ func TestTIDCodec(t *testing.T) { new(pgtype.TID), isExpectedEq(pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}), }, + { + pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, + new(string), + isExpectedEq("(42,43)"), + }, + { + pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, + new(string), + isExpectedEq("(4294967295,65535)"), + }, {pgtype.TID{}, new(pgtype.TID), isExpectedEq(pgtype.TID{})}, {nil, new(pgtype.TID), isExpectedEq(pgtype.TID{})}, }) From 61b4fb76895c2677afba96fa30b9db3c3725da2b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 21 Jan 2022 16:50:30 -0600 Subject: [PATCH 0861/1158] Convert time to Codec --- pgtype/builtin_wrappers.go | 32 +++++ pgtype/pgtype.go | 2 +- pgtype/time.go | 273 +++++++++++++++++++------------------ pgtype/time_test.go | 140 +++++-------------- 4 files changed, 211 insertions(+), 236 deletions(-) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 5b6a032a..3020f9bb 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -399,6 +399,38 @@ func (w timeWrapper) TimestamptzValue() (Timestamptz, error) { return Timestamptz{Time: time.Time(w), Valid: true}, nil } +func (w *timeWrapper) ScanTime(v Time) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Time") + } + + // 24:00:00 is max allowed time in PostgreSQL, but time.Time will normalize that to 00:00:00 the next day. + var maxRepresentableByTime int64 = 24*60*60*1000000 - 1 + if v.Microseconds > maxRepresentableByTime { + return fmt.Errorf("%d microseconds cannot be represented as time.Time", v.Microseconds) + } + + usec := v.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + ns := usec * 1000 + *w = timeWrapper(time.Date(2000, 1, 1, int(hours), int(minutes), int(seconds), int(ns), time.UTC)) + return nil +} + +func (w timeWrapper) TimeValue() (Time, error) { + t := time.Time(w) + usec := int64(t.Hour())*microsecondsPerHour + + int64(t.Minute())*microsecondsPerMinute + + int64(t.Second())*microsecondsPerSecond + + int64(t.Nanosecond())/1000 + return Time{Microseconds: usec, Valid: true}, nil +} + type durationWrapper time.Duration func (w *durationWrapper) ScanInterval(v Interval) error { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5e7e2501..058abf5e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -340,7 +340,7 @@ func NewConnInfo() *ConnInfo { // ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) ci.RegisterDataType(DataType{Name: "text", OID: TextOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) - ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID}) + ci.RegisterDataType(DataType{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) ci.RegisterDataType(DataType{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) ci.RegisterDataType(DataType{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) // ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) diff --git a/pgtype/time.go b/pgtype/time.go index 3252a633..47dabe99 100644 --- a/pgtype/time.go +++ b/pgtype/time.go @@ -5,11 +5,18 @@ import ( "encoding/binary" "fmt" "strconv" - "time" "github.com/jackc/pgio" ) +type TimeScanner interface { + ScanTime(v Time) error +} + +type TimeValuer interface { + TimeValue() (Time, error) +} + // Time represents the PostgreSQL time type. The PostgreSQL time is a time of day without time zone. // // Time is represented as the number of microseconds since midnight in the same way that PostgreSQL does. Other time @@ -20,86 +27,151 @@ type Time struct { Valid bool } -// Set converts src into a Time and stores in dst. -func (dst *Time) Set(src interface{}) error { +func (t *Time) ScanTime(v Time) error { + *t = v + return nil +} + +func (t Time) TimeValue() (Time, error) { + return t, nil +} + +// Scan implements the database/sql Scanner interface. +func (t *Time) Scan(src interface{}) error { if src == nil { - *dst = Time{} + *t = Time{} return nil } - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } + switch src := src.(type) { + case string: + return scanPlanTextAnyToTimeScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), t) } - switch value := src.(type) { - case time.Time: - usec := int64(value.Hour())*microsecondsPerHour + - int64(value.Minute())*microsecondsPerMinute + - int64(value.Second())*microsecondsPerSecond + - int64(value.Nanosecond())/1000 - *dst = Time{Microseconds: usec, Valid: true} - case *time.Time: - if value == nil { - *dst = Time{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingTimeType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Time", value) + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (t Time) Value() (driver.Value, error) { + if !t.Valid { + return nil, nil + } + + buf, err := TimeCodec{}.PlanEncode(nil, 0, TextFormatCode, t).Encode(t, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type TimeCodec struct{} + +func (TimeCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TimeCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TimeCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(TimeValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTimeCodecBinary{} + case TextFormatCode: + return encodePlanTimeCodecText{} } return nil } -func (dst Time) Get() interface{} { - if !dst.Valid { - return nil +type encodePlanTimeCodecBinary struct{} + +func (encodePlanTimeCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + t, err := value.(TimeValuer).TimeValue() + if err != nil { + return nil, err } - return dst.Microseconds + + if !t.Valid { + return nil, nil + } + + return pgio.AppendInt64(buf, t.Microseconds), nil } -func (src *Time) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) +type encodePlanTimeCodecText struct{} + +func (encodePlanTimeCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + t, err := value.(TimeValuer).TimeValue() + if err != nil { + return nil, err } - switch v := dst.(type) { - case *time.Time: - // 24:00:00 is max allowed time in PostgreSQL, but time.Time will normalize that to 00:00:00 the next day. - var maxRepresentableByTime int64 = 24*60*60*1000000 - 1 - if src.Microseconds > maxRepresentableByTime { - return fmt.Errorf("%d microseconds cannot be represented as time.Time", src.Microseconds) - } - - usec := src.Microseconds - hours := usec / microsecondsPerHour - usec -= hours * microsecondsPerHour - minutes := usec / microsecondsPerMinute - usec -= minutes * microsecondsPerMinute - seconds := usec / microsecondsPerSecond - usec -= seconds * microsecondsPerSecond - ns := usec * 1000 - *v = time.Date(2000, 1, 1, int(hours), int(minutes), int(seconds), int(ns), time.UTC) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) + if !t.Valid { + return nil, nil } + + usec := t.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + + s := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, usec) + + return append(buf, s...), nil } -// DecodeText decodes from src into dst. -func (dst *Time) DecodeText(ci *ConnInfo, src []byte) error { +func (TimeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case TimeScanner: + return scanPlanBinaryTimeToTimeScanner{} + } + case TextFormatCode: + switch target.(type) { + case TimeScanner: + return scanPlanTextAnyToTimeScanner{} + } + } + + return nil +} + +type scanPlanBinaryTimeToTimeScanner struct{} + +func (scanPlanBinaryTimeToTimeScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TimeScanner) + if src == nil { - *dst = Time{} - return nil + return scanner.ScanTime(Time{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for time: %v", len(src)) + } + + usec := int64(binary.BigEndian.Uint64(src)) + + return scanner.ScanTime(Time{Microseconds: usec, Valid: true}) +} + +type scanPlanTextAnyToTimeScanner struct{} + +func (scanPlanTextAnyToTimeScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TimeScanner) + + if src == nil { + return scanner.ScanTime(Time{}) } s := string(src) @@ -140,79 +212,22 @@ func (dst *Time) DecodeText(ci *ConnInfo, src []byte) error { usec += n } - *dst = Time{Microseconds: usec, Valid: true} - - return nil + return scanner.ScanTime(Time{Microseconds: usec, Valid: true}) } -// DecodeBinary decodes from src into dst. -func (dst *Time) DecodeBinary(ci *ConnInfo, src []byte) error { +func (c TimeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c TimeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Time{} - return nil - } - - if len(src) != 8 { - return fmt.Errorf("invalid length for time: %v", len(src)) - } - - usec := int64(binary.BigEndian.Uint64(src)) - *dst = Time{Microseconds: usec, Valid: true} - - return nil -} - -// EncodeText writes the text encoding of src into w. -func (src Time) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - usec := src.Microseconds - hours := usec / microsecondsPerHour - usec -= hours * microsecondsPerHour - minutes := usec / microsecondsPerMinute - usec -= minutes * microsecondsPerMinute - seconds := usec / microsecondsPerSecond - usec -= seconds * microsecondsPerSecond - - s := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, usec) - - return append(buf, s...), nil -} - -// EncodeBinary writes the binary encoding of src into w. If src.Time is not in -// the UTC time zone it returns an error. -func (src Time) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil + var t Time + err := codecScan(c, ci, oid, format, src, &t) + if err != nil { + return nil, err } - - return pgio.AppendInt64(buf, src.Microseconds), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Time) Scan(src interface{}) error { - if src == nil { - *dst = Time{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - case time.Time: - return dst.Set(src) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Time) Value() (driver.Value, error) { - return EncodeValueText(src) + return t, nil } diff --git a/pgtype/time_test.go b/pgtype/time_test.go index f710ed03..8394a951 100644 --- a/pgtype/time_test.go +++ b/pgtype/time_test.go @@ -1,117 +1,45 @@ package pgtype_test import ( - "reflect" "testing" "time" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestTimeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "time", []interface{}{ - &pgtype.Time{Microseconds: 0, Valid: true}, - &pgtype.Time{Microseconds: 1, Valid: true}, - &pgtype.Time{Microseconds: 86399999999, Valid: true}, - &pgtype.Time{Microseconds: 86400000000, Valid: true}, - &pgtype.Time{}, +func TestTimeCodec(t *testing.T) { + testPgxCodec(t, "time", []PgxTranscodeTestCase{ + { + pgtype.Time{Microseconds: 0, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 0, Valid: true}), + }, + { + pgtype.Time{Microseconds: 1, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 1, Valid: true}), + }, + { + pgtype.Time{Microseconds: 86399999999, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 86399999999, Valid: true}), + }, + { + pgtype.Time{Microseconds: 86400000000, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 86400000000, Valid: true}), + }, + { + time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 0, Valid: true}), + }, + { + pgtype.Time{Microseconds: 0, Valid: true}, + new(time.Time), + isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)), + }, + {pgtype.Time{}, new(pgtype.Time), isExpectedEq(pgtype.Time{})}, + {nil, new(pgtype.Time), isExpectedEq(pgtype.Time{})}, }) } - -func TestTimeSet(t *testing.T) { - type _time time.Time - - successfulTests := []struct { - source interface{} - result pgtype.Time - }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 0, Valid: true}}, - {source: time.Date(1900, 1, 1, 1, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 3600000000, Valid: true}}, - {source: time.Date(1900, 1, 1, 0, 1, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 60000000, Valid: true}}, - {source: time.Date(1900, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Time{Microseconds: 1000000, Valid: true}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 1, time.UTC), result: pgtype.Time{Microseconds: 0, Valid: true}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 1000, time.UTC), result: pgtype.Time{Microseconds: 1, Valid: true}}, - {source: time.Date(1999, 12, 31, 23, 59, 59, 999999999, time.UTC), result: pgtype.Time{Microseconds: 86399999999, Valid: true}}, - {source: time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local), result: pgtype.Time{Microseconds: 2, Valid: true}}, - {source: func(t time.Time) *time.Time { return &t }(time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local)), result: pgtype.Time{Microseconds: 2, Valid: true}}, - {source: nil, result: pgtype.Time{}}, - {source: (*time.Time)(nil), result: pgtype.Time{}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 3000, time.UTC)), result: pgtype.Time{Microseconds: 3, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.Time - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTimeAssignTo(t *testing.T) { - var tim time.Time - var ptim *time.Time - - simpleTests := []struct { - src pgtype.Time - dst interface{} - expected interface{} - }{ - {src: pgtype.Time{Microseconds: 0, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 3600000000, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 60000000, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 1, 0, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 1000000, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 1, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 1000, time.UTC)}, - {src: pgtype.Time{Microseconds: 86399999999, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 23, 59, 59, 999999000, time.UTC)}, - {src: pgtype.Time{Microseconds: 0}, dst: &ptim, expected: ((*time.Time)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Time - dst interface{} - expected interface{} - }{ - {src: pgtype.Time{Microseconds: 0, Valid: true}, dst: &ptim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Time - dst interface{} - }{ - {src: pgtype.Time{Microseconds: 86400000000, Valid: true}, dst: &tim}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} From 0056156904301c6cf65de73354034ee0968233d1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 21 Jan 2022 16:51:53 -0600 Subject: [PATCH 0862/1158] Add time array --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 058abf5e..a7c97a73 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -76,6 +76,7 @@ const ( TimestampOID = 1114 TimestampArrayOID = 1115 DateArrayOID = 1182 + TimeArrayOID = 1183 TimestamptzOID = 1184 TimestamptzArrayOID = 1185 IntervalOID = 1186 @@ -304,6 +305,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementCodec: Uint32Codec{}, ElementOID: CIDOID}}) ci.RegisterDataType(DataType{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementCodec: Uint32Codec{}, ElementOID: OIDOID}}) ci.RegisterDataType(DataType{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementCodec: Uint32Codec{}, ElementOID: XIDOID}}) + ci.RegisterDataType(DataType{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementCodec: TimeCodec{}, ElementOID: TimeOID}}) ci.RegisterDataType(DataType{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) ci.RegisterDataType(DataType{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) From b9b5e35d0fd1309bcdf3b5b2611f73bd8fa9f2d8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 09:31:59 -0600 Subject: [PATCH 0863/1158] Convert numeric to Codec --- conn_test.go | 2 +- pgtype/builtin_wrappers.go | 28 + pgtype/numeric.go | 1103 +++++++++++++++------------------- pgtype/numeric_array.go | 672 --------------------- pgtype/numeric_array_test.go | 305 ---------- pgtype/numeric_test.go | 468 +++------------ pgtype/pgtype.go | 4 +- 7 files changed, 625 insertions(+), 1957 deletions(-) delete mode 100644 pgtype/numeric_array.go delete mode 100644 pgtype/numeric_array_test.go diff --git a/conn_test.go b/conn_test.go index 0d7bcb31..cedd6a7a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -901,7 +901,7 @@ func TestDomainType(t *testing.T) { if err != nil { t.Fatalf("did not find uint64 OID, %v", err) } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Numeric{}, Name: "uint64", OID: uint64OID}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}}) // String is still an acceptable argument after registration err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 3020f9bb..3becf06b 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -273,6 +273,20 @@ func (w float32Wrapper) Int64Value() (Int8, error) { return Int8{Int: int64(w), Valid: true}, nil } +func (w *float32Wrapper) ScanFloat64(v Float8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float32") + } + + *w = float32Wrapper(v.Float) + + return nil +} + +func (w float32Wrapper) Float64Value() (Float8, error) { + return Float8{Float: float64(w), Valid: true}, nil +} + type float64Wrapper float64 func (w float64Wrapper) SkipUnderlyingTypePlan() {} @@ -295,6 +309,20 @@ func (w float64Wrapper) Int64Value() (Int8, error) { return Int8{Int: int64(w), Valid: true}, nil } +func (w *float64Wrapper) ScanFloat64(v Float8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float64") + } + + *w = float64Wrapper(v.Float) + + return nil +} + +func (w float64Wrapper) Float64Value() (Float8, error) { + return Float8{Float: float64(w), Valid: true}, nil +} + type stringWrapper string func (w stringWrapper) SkipUnderlyingTypePlan() {} diff --git a/pgtype/numeric.go b/pgtype/numeric.go index b24f433c..435c9618 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -55,422 +55,69 @@ var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) +type NumericScanner interface { + ScanNumeric(v Numeric) error +} + +type NumericValuer interface { + NumericValue() (Numeric, error) +} + type Numeric struct { Int *big.Int Exp int32 NaN bool InfinityModifier InfinityModifier Valid bool - - NumericDecoderWrapper func(interface{}) NumericDecoder - Getter func(Numeric) interface{} } -func (n *Numeric) NewTypeValue() Value { - return &Numeric{ - NumericDecoderWrapper: n.NumericDecoderWrapper, - Getter: n.Getter, - } -} - -func (n *Numeric) TypeName() string { - return "numeric" -} - -func (dst *Numeric) setNil() { - dst.Int = nil - dst.Exp = 0 - dst.NaN = false - dst.Valid = false -} - -func (dst *Numeric) setNaN() { - dst.Int = nil - dst.Exp = 0 - dst.NaN = true - dst.Valid = true -} - -func (dst *Numeric) setNumber(i *big.Int, exp int32) { - dst.Int = i - dst.Exp = exp - dst.NaN = false - dst.Valid = true -} - -func (dst *Numeric) Set(src interface{}) error { - if src == nil { - dst.setNil() - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case float32: - if math.IsNaN(float64(value)) { - dst.setNaN() - return nil - } else if math.IsInf(float64(value), 1) { - *dst = Numeric{Valid: true, InfinityModifier: Infinity} - return nil - } else if math.IsInf(float64(value), -1) { - *dst = Numeric{Valid: true, InfinityModifier: NegativeInfinity} - return nil - } - num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) - if err != nil { - return err - } - dst.setNumber(num, exp) - case float64: - if math.IsNaN(value) { - dst.setNaN() - return nil - } else if math.IsInf(value, 1) { - *dst = Numeric{Valid: true, InfinityModifier: Infinity} - return nil - } else if math.IsInf(value, -1) { - *dst = Numeric{Valid: true, InfinityModifier: NegativeInfinity} - return nil - } - num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) - if err != nil { - return err - } - dst.setNumber(num, exp) - case int8: - dst.setNumber(big.NewInt(int64(value)), 0) - case uint8: - dst.setNumber(big.NewInt(int64(value)), 0) - case int16: - dst.setNumber(big.NewInt(int64(value)), 0) - case uint16: - dst.setNumber(big.NewInt(int64(value)), 0) - case int32: - dst.setNumber(big.NewInt(int64(value)), 0) - case uint32: - dst.setNumber(big.NewInt(int64(value)), 0) - case int64: - dst.setNumber(big.NewInt(value), 0) - case uint64: - dst.setNumber((&big.Int{}).SetUint64(value), 0) - case int: - dst.setNumber(big.NewInt(int64(value)), 0) - case uint: - dst.setNumber((&big.Int{}).SetUint64(uint64(value)), 0) - case string: - num, exp, err := parseNumericString(value) - if err != nil { - return err - } - dst.setNumber(num, exp) - case *float64: - if value == nil { - dst.setNil() - } else { - return dst.Set(*value) - } - case *float32: - if value == nil { - dst.setNil() - } else { - return dst.Set(*value) - } - case *int8: - if value == nil { - dst.setNil() - } else { - return dst.Set(*value) - } - case *uint8: - if value == nil { - dst.setNil() - } else { - return dst.Set(*value) - } - case *int16: - if value == nil { - dst.setNil() - } else { - return dst.Set(*value) - } - case *uint16: - if value == nil { - dst.setNil() - } else { - return dst.Set(*value) - } - case *int32: - if value == nil { - dst.setNil() - } else { - return dst.Set(*value) - } - case *uint32: - if value == nil { - dst.setNil() - } else { - return dst.Set(*value) - } - case *int64: - if value == nil { - dst.setNil() - } else { - return dst.Set(*value) - } - case *uint64: - if value == nil { - dst.setNil() - } else { - return dst.Set(*value) - } - case *int: - if value == nil { - dst.setNil() - } else { - return dst.Set(*value) - } - case *uint: - if value == nil { - dst.setNil() - } else { - return dst.Set(*value) - } - case *string: - if value == nil { - dst.setNil() - } else { - return dst.Set(*value) - } - case InfinityModifier: - *dst = Numeric{InfinityModifier: value, Valid: true} - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Numeric", value) - } - +func (n *Numeric) ScanNumeric(v Numeric) error { + *n = v return nil } -func (dst Numeric) Get() interface{} { - if dst.Getter != nil { - return dst.Getter(dst) - } - - if !dst.Valid { - return nil - } - - if dst.InfinityModifier != None { - return dst.InfinityModifier - } - return dst +func (n Numeric) NumericValue() (Numeric, error) { + return n, nil } -type NumericDecoder interface { - DecodeNumeric(*Numeric) error -} - -func (src *Numeric) AssignTo(dst interface{}) error { - if d, ok := dst.(NumericDecoder); ok { - return d.DecodeNumeric(src) - } else { - if src.NumericDecoderWrapper != nil { - d = src.NumericDecoderWrapper(dst) - if d != nil { - return d.DecodeNumeric(src) - } - } - } - - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *float32: - f, err := src.toFloat64() - if err != nil { - return err - } - return float64AssignTo(f, src.Valid, dst) - case *float64: - f, err := src.toFloat64() - if err != nil { - return err - } - return float64AssignTo(f, src.Valid, dst) - case *int: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = int(normalizedInt.Int64()) - case *int8: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt8) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt8) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = int8(normalizedInt.Int64()) - case *int16: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt16) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt16) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = int16(normalizedInt.Int64()) - case *int32: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt32) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt32) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = int32(normalizedInt.Int64()) - case *int64: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt64) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt64) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = normalizedInt.Int64() - case *uint: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = uint(normalizedInt.Uint64()) - case *uint8: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint8) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = uint8(normalizedInt.Uint64()) - case *uint16: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint16) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = uint16(normalizedInt.Uint64()) - case *uint32: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint32) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = uint32(normalizedInt.Uint64()) - case *uint64: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint64) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = normalizedInt.Uint64() - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - - return nil -} - -func (dst *Numeric) toBigInt() (*big.Int, error) { - if dst.Exp == 0 { - return dst.Int, nil +func (n *Numeric) toBigInt() (*big.Int, error) { + if n.Exp == 0 { + return n.Int, nil } num := &big.Int{} - num.Set(dst.Int) - if dst.Exp > 0 { + num.Set(n.Int) + if n.Exp > 0 { mul := &big.Int{} - mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil) + mul.Exp(big10, big.NewInt(int64(n.Exp)), nil) num.Mul(num, mul) return num, nil } div := &big.Int{} - div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil) + div.Exp(big10, big.NewInt(int64(-n.Exp)), nil) remainder := &big.Int{} num.DivMod(num, div, remainder) if remainder.Cmp(big0) != 0 { - return nil, fmt.Errorf("cannot convert %v to integer", dst) + return nil, fmt.Errorf("cannot convert %v to integer", n) } return num, nil } -func (src *Numeric) toFloat64() (float64, error) { - if src.NaN { +func (n *Numeric) toFloat64() (float64, error) { + if n.NaN { return math.NaN(), nil - } else if src.InfinityModifier == Infinity { + } else if n.InfinityModifier == Infinity { return math.Inf(1), nil - } else if src.InfinityModifier == NegativeInfinity { + } else if n.InfinityModifier == NegativeInfinity { return math.Inf(-1), nil } buf := make([]byte, 0, 32) - buf = append(buf, src.Int.String()...) + buf = append(buf, n.Int.String()...) buf = append(buf, 'e') - buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) + buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...) f, err := strconv.ParseFloat(string(buf), 64) if err != nil { @@ -479,32 +126,6 @@ func (src *Numeric) toFloat64() (float64, error) { return f, nil } -func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - dst.setNil() - return nil - } - - if string(src) == "NaN" { - dst.setNaN() - return nil - } else if string(src) == "Infinity" { - *dst = Numeric{Valid: true, InfinityModifier: Infinity} - return nil - } else if string(src) == "-Infinity" { - *dst = Numeric{Valid: true, InfinityModifier: NegativeInfinity} - return nil - } - - num, exp, err := parseNumericString(string(src)) - if err != nil { - return err - } - - dst.setNumber(num, exp) - return nil -} - func parseNumericString(str string) (n *big.Int, exp int32, err error) { parts := strings.SplitN(str, ".", 2) digits := strings.Join(parts, "") @@ -526,12 +147,388 @@ func parseNumericString(str string) (n *big.Int, exp int32, err error) { return accum, exp, nil } -func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { +func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { + digits := len(src) / 2 + if digits > 4 { + digits = 4 + } + + rp := 0 + + for i := 0; i < digits; i++ { + if i > 0 { + accum *= nbase + } + accum += int64(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return accum, rp, digits +} + +// Scan implements the database/sql Scanner interface. +func (n *Numeric) Scan(src interface{}) error { if src == nil { - dst.setNil() + *n = Numeric{} return nil } + switch src := src.(type) { + case string: + return scanPlanTextAnyToNumericScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), n) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (n Numeric) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + + buf, err := NumericCodec{}.PlanEncode(nil, 0, TextFormatCode, n).Encode(n, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +func (n Numeric) MarshalJSON() ([]byte, error) { + if !n.Valid { + return []byte("null"), nil + } + + if n.NaN { + return []byte(`"NaN"`), nil + } + + intStr := n.Int.String() + buf := &bytes.Buffer{} + exp := int(n.Exp) + if exp > 0 { + buf.WriteString(intStr) + for i := 0; i < exp; i++ { + buf.WriteByte('0') + } + } else if exp < 0 { + if len(intStr) <= -exp { + buf.WriteString("0.") + leadingZeros := -exp - len(intStr) + for i := 0; i < leadingZeros; i++ { + buf.WriteByte('0') + } + buf.WriteString(intStr) + } else if len(intStr) > -exp { + dpPos := len(intStr) + exp + buf.WriteString(intStr[:dpPos]) + buf.WriteByte('.') + buf.WriteString(intStr[dpPos:]) + } + } else { + buf.WriteString(intStr) + } + + return buf.Bytes(), nil +} + +type NumericCodec struct{} + +func (NumericCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (NumericCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (NumericCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case NumericValuer: + return encodePlanNumericCodecBinaryNumericValuer{} + case Float64Valuer: + return encodePlanNumericCodecBinaryFloat64Valuer{} + case Int64Valuer: + return encodePlanNumericCodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case NumericValuer: + return encodePlanNumericCodecTextNumericValuer{} + case Float64Valuer: + return encodePlanNumericCodecTextFloat64Valuer{} + case Int64Valuer: + return encodePlanNumericCodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanNumericCodecBinaryNumericValuer struct{} + +func (encodePlanNumericCodecBinaryNumericValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(NumericValuer).NumericValue() + if err != nil { + return nil, err + } + + return encodeNumericBinary(n, buf) +} + +type encodePlanNumericCodecBinaryFloat64Valuer struct{} + +func (encodePlanNumericCodecBinaryFloat64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if math.IsNaN(n.Float) { + return encodeNumericBinary(Numeric{NaN: true, Valid: true}, buf) + } else if math.IsInf(n.Float, 1) { + return encodeNumericBinary(Numeric{InfinityModifier: Infinity, Valid: true}, buf) + } else if math.IsInf(n.Float, -1) { + return encodeNumericBinary(Numeric{InfinityModifier: NegativeInfinity, Valid: true}, buf) + } + num, exp, err := parseNumericString(strconv.FormatFloat(n.Float, 'f', -1, 64)) + if err != nil { + return nil, err + } + + return encodeNumericBinary(Numeric{Int: num, Exp: exp, Valid: true}, buf) +} + +type encodePlanNumericCodecBinaryInt64Valuer struct{} + +func (encodePlanNumericCodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + return encodeNumericBinary(Numeric{Int: big.NewInt(n.Int), Valid: true}, buf) +} + +func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) { + if !n.Valid { + return nil, nil + } + + if n.NaN { + buf = pgio.AppendUint64(buf, pgNumericNaN) + return buf, nil + } else if n.InfinityModifier == Infinity { + buf = pgio.AppendUint64(buf, pgNumericPosInf) + return buf, nil + } else if n.InfinityModifier == NegativeInfinity { + buf = pgio.AppendUint64(buf, pgNumericNegInf) + return buf, nil + } + + var sign int16 + if n.Int.Cmp(big0) < 0 { + sign = 16384 + } + + absInt := &big.Int{} + wholePart := &big.Int{} + fracPart := &big.Int{} + remainder := &big.Int{} + absInt.Abs(n.Int) + + // Normalize absInt and exp to where exp is always a multiple of 4. This makes + // converting to 16-bit base 10,000 digits easier. + var exp int32 + switch n.Exp % 4 { + case 1, -3: + exp = n.Exp - 1 + absInt.Mul(absInt, big10) + case 2, -2: + exp = n.Exp - 2 + absInt.Mul(absInt, big100) + case 3, -1: + exp = n.Exp - 3 + absInt.Mul(absInt, big1000) + default: + exp = n.Exp + } + + if exp < 0 { + divisor := &big.Int{} + divisor.Exp(big10, big.NewInt(int64(-exp)), nil) + wholePart.DivMod(absInt, divisor, fracPart) + fracPart.Add(fracPart, divisor) + } else { + wholePart = absInt + } + + var wholeDigits, fracDigits []int16 + + for wholePart.Cmp(big0) != 0 { + wholePart.DivMod(wholePart, bigNBase, remainder) + wholeDigits = append(wholeDigits, int16(remainder.Int64())) + } + + if fracPart.Cmp(big0) != 0 { + for fracPart.Cmp(big1) != 0 { + fracPart.DivMod(fracPart, bigNBase, remainder) + fracDigits = append(fracDigits, int16(remainder.Int64())) + } + } + + buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) + + var weight int16 + if len(wholeDigits) > 0 { + weight = int16(len(wholeDigits) - 1) + if exp > 0 { + weight += int16(exp / 4) + } + } else { + weight = int16(exp/4) - 1 + int16(len(fracDigits)) + } + buf = pgio.AppendInt16(buf, weight) + + buf = pgio.AppendInt16(buf, sign) + + var dscale int16 + if n.Exp < 0 { + dscale = int16(-n.Exp) + } + buf = pgio.AppendInt16(buf, dscale) + + for i := len(wholeDigits) - 1; i >= 0; i-- { + buf = pgio.AppendInt16(buf, wholeDigits[i]) + } + + for i := len(fracDigits) - 1; i >= 0; i-- { + buf = pgio.AppendInt16(buf, fracDigits[i]) + } + + return buf, nil +} + +type encodePlanNumericCodecTextNumericValuer struct{} + +func (encodePlanNumericCodecTextNumericValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(NumericValuer).NumericValue() + if err != nil { + return nil, err + } + + return encodeNumericText(n, buf) +} + +type encodePlanNumericCodecTextFloat64Valuer struct{} + +func (encodePlanNumericCodecTextFloat64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if math.IsNaN(n.Float) { + return encodeNumericBinary(Numeric{NaN: true, Valid: true}, buf) + } else if math.IsInf(n.Float, 1) { + return encodeNumericBinary(Numeric{InfinityModifier: Infinity, Valid: true}, buf) + } else if math.IsInf(n.Float, -1) { + return encodeNumericBinary(Numeric{InfinityModifier: NegativeInfinity, Valid: true}, buf) + } + num, exp, err := parseNumericString(strconv.FormatFloat(n.Float, 'f', -1, 64)) + if err != nil { + return nil, err + } + + return encodeNumericText(Numeric{Int: num, Exp: exp, Valid: true}, buf) +} + +type encodePlanNumericCodecTextInt64Valuer struct{} + +func (encodePlanNumericCodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + return encodeNumericText(Numeric{Int: big.NewInt(n.Int), Valid: true}, buf) +} + +func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { + if !n.Valid { + return nil, nil + } + + if n.NaN { + buf = append(buf, "NaN"...) + return buf, nil + } else if n.InfinityModifier == Infinity { + buf = append(buf, "Infinity"...) + return buf, nil + } else if n.InfinityModifier == NegativeInfinity { + buf = append(buf, "-Infinity"...) + return buf, nil + } + + buf = append(buf, n.Int.String()...) + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...) + return buf, nil +} + +func (NumericCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case NumericScanner: + return scanPlanBinaryNumericToNumericScanner{} + case Float64Scanner: + return scanPlanBinaryNumericToFloat64Scanner{} + case Int64Scanner: + return scanPlanBinaryNumericToInt64Scanner{} + } + case TextFormatCode: + switch target.(type) { + case NumericScanner: + return scanPlanTextAnyToNumericScanner{} + case Float64Scanner: + return scanPlanTextAnyToFloat64Scanner{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +type scanPlanBinaryNumericToNumericScanner struct{} + +func (scanPlanBinaryNumericToNumericScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(NumericScanner) + + if src == nil { + return scanner.ScanNumeric(Numeric{}) + } + if len(src) < 8 { return fmt.Errorf("numeric incomplete %v", src) } @@ -547,19 +544,15 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { rp += 2 if sign == pgNumericNaNSign { - dst.setNaN() - return nil + return scanner.ScanNumeric(Numeric{NaN: true, Valid: true}) } else if sign == pgNumericPosInfSign { - *dst = Numeric{Valid: true, InfinityModifier: Infinity} - return nil + return scanner.ScanNumeric(Numeric{InfinityModifier: Infinity, Valid: true}) } else if sign == pgNumericNegInfSign { - *dst = Numeric{Valid: true, InfinityModifier: NegativeInfinity} - return nil + return scanner.ScanNumeric(Numeric{InfinityModifier: NegativeInfinity, Valid: true}) } if ndigits == 0 { - dst.setNumber(big.NewInt(0), 0) - return nil + return scanner.ScanNumeric(Numeric{Int: big.NewInt(0), Valid: true}) } if len(src[rp:]) < int(ndigits)*2 { @@ -630,219 +623,117 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { accum.Neg(accum) } - dst.setNumber(accum, exp) - - return nil - + return scanner.ScanNumeric(Numeric{Int: accum, Exp: exp, Valid: true}) } -func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { - digits := len(src) / 2 - if digits > 4 { - digits = 4 - } +type scanPlanBinaryNumericToFloat64Scanner struct{} - rp := 0 +func (scanPlanBinaryNumericToFloat64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(Float64Scanner) - for i := 0; i < digits; i++ { - if i > 0 { - accum *= nbase - } - accum += int64(binary.BigEndian.Uint16(src[rp:])) - rp += 2 - } - - return accum, rp, digits -} - -func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if src.NaN { - buf = append(buf, "NaN"...) - return buf, nil - } else if src.InfinityModifier == Infinity { - buf = append(buf, "Infinity"...) - return buf, nil - } else if src.InfinityModifier == NegativeInfinity { - buf = append(buf, "-Infinity"...) - return buf, nil - } - - buf = append(buf, src.Int.String()...) - buf = append(buf, 'e') - buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) - return buf, nil -} - -func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if src.NaN { - buf = pgio.AppendUint64(buf, pgNumericNaN) - return buf, nil - } else if src.InfinityModifier == Infinity { - buf = pgio.AppendUint64(buf, pgNumericPosInf) - return buf, nil - } else if src.InfinityModifier == NegativeInfinity { - buf = pgio.AppendUint64(buf, pgNumericNegInf) - return buf, nil - } - - var sign int16 - if src.Int.Cmp(big0) < 0 { - sign = 16384 - } - - absInt := &big.Int{} - wholePart := &big.Int{} - fracPart := &big.Int{} - remainder := &big.Int{} - absInt.Abs(src.Int) - - // Normalize absInt and exp to where exp is always a multiple of 4. This makes - // converting to 16-bit base 10,000 digits easier. - var exp int32 - switch src.Exp % 4 { - case 1, -3: - exp = src.Exp - 1 - absInt.Mul(absInt, big10) - case 2, -2: - exp = src.Exp - 2 - absInt.Mul(absInt, big100) - case 3, -1: - exp = src.Exp - 3 - absInt.Mul(absInt, big1000) - default: - exp = src.Exp - } - - if exp < 0 { - divisor := &big.Int{} - divisor.Exp(big10, big.NewInt(int64(-exp)), nil) - wholePart.DivMod(absInt, divisor, fracPart) - fracPart.Add(fracPart, divisor) - } else { - wholePart = absInt - } - - var wholeDigits, fracDigits []int16 - - for wholePart.Cmp(big0) != 0 { - wholePart.DivMod(wholePart, bigNBase, remainder) - wholeDigits = append(wholeDigits, int16(remainder.Int64())) - } - - if fracPart.Cmp(big0) != 0 { - for fracPart.Cmp(big1) != 0 { - fracPart.DivMod(fracPart, bigNBase, remainder) - fracDigits = append(fracDigits, int16(remainder.Int64())) - } - } - - buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) - - var weight int16 - if len(wholeDigits) > 0 { - weight = int16(len(wholeDigits) - 1) - if exp > 0 { - weight += int16(exp / 4) - } - } else { - weight = int16(exp/4) - 1 + int16(len(fracDigits)) - } - buf = pgio.AppendInt16(buf, weight) - - buf = pgio.AppendInt16(buf, sign) - - var dscale int16 - if src.Exp < 0 { - dscale = int16(-src.Exp) - } - buf = pgio.AppendInt16(buf, dscale) - - for i := len(wholeDigits) - 1; i >= 0; i-- { - buf = pgio.AppendInt16(buf, wholeDigits[i]) - } - - for i := len(fracDigits) - 1; i >= 0; i-- { - buf = pgio.AppendInt16(buf, fracDigits[i]) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Numeric) Scan(src interface{}) error { if src == nil { - dst.setNil() - return nil + return scanner.ScanFloat64(Float8{}) } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + var n Numeric + + err := scanPlanBinaryNumericToNumericScanner{}.Scan(ci, oid, formatCode, src, &n) + if err != nil { + return err } - return fmt.Errorf("cannot scan %T", src) + f64, err := n.toFloat64() + if err != nil { + return err + } + + return scanner.ScanFloat64(Float8{Float: f64, Valid: true}) } -// Value implements the database/sql/driver Valuer interface. -func (src Numeric) Value() (driver.Value, error) { - if !src.Valid { +type scanPlanBinaryNumericToInt64Scanner struct{} + +func (scanPlanBinaryNumericToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(Int64Scanner) + + if src == nil { + return scanner.ScanInt64(Int8{}) + } + + var n Numeric + + err := scanPlanBinaryNumericToNumericScanner{}.Scan(ci, oid, formatCode, src, &n) + if err != nil { + return err + } + + bigInt, err := n.toBigInt() + if err != nil { + return err + } + + if !bigInt.IsInt64() { + return fmt.Errorf("%v is out of range for int64", bigInt) + } + + return scanner.ScanInt64(Int8{Int: bigInt.Int64(), Valid: true}) +} + +type scanPlanTextAnyToNumericScanner struct{} + +func (scanPlanTextAnyToNumericScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(NumericScanner) + + if src == nil { + return scanner.ScanNumeric(Numeric{}) + } + + if string(src) == "NaN" { + return scanner.ScanNumeric(Numeric{NaN: true, Valid: true}) + } else if string(src) == "Infinity" { + return scanner.ScanNumeric(Numeric{InfinityModifier: Infinity, Valid: true}) + } else if string(src) == "-Infinity" { + return scanner.ScanNumeric(Numeric{InfinityModifier: NegativeInfinity, Valid: true}) + } + + num, exp, err := parseNumericString(string(src)) + if err != nil { + return err + } + + return scanner.ScanNumeric(Numeric{Int: num, Exp: exp, Valid: true}) +} + +func (c NumericCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { return nil, nil } - buf, err := src.EncodeText(nil, nil) + if format == TextFormatCode { + return string(src), nil + } + + var n Numeric + err := codecScan(c, ci, oid, format, src, &n) if err != nil { return nil, err } + buf, err := ci.Encode(oid, TextFormatCode, n, nil) + if err != nil { + return nil, err + } return string(buf), nil } -func (src Numeric) MarshalJSON() ([]byte, error) { - if !src.Valid { - return []byte("null"), nil +func (c NumericCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil } - if src.NaN { - return []byte(`"NaN"`), nil + var n Numeric + err := codecScan(c, ci, oid, format, src, &n) + if err != nil { + return nil, err } - - intStr := src.Int.String() - buf := &bytes.Buffer{} - exp := int(src.Exp) - if exp > 0 { - buf.WriteString(intStr) - for i := 0; i < exp; i++ { - buf.WriteByte('0') - } - } else if exp < 0 { - if len(intStr) <= -exp { - buf.WriteString("0.") - leadingZeros := -exp - len(intStr) - for i := 0; i < leadingZeros; i++ { - buf.WriteByte('0') - } - buf.WriteString(intStr) - } else if len(intStr) > -exp { - dpPos := len(intStr) + exp - buf.WriteString(intStr[:dpPos]) - buf.WriteByte('.') - buf.WriteString(intStr[dpPos:]) - } - } else { - buf.WriteString(intStr) - } - - return buf.Bytes(), nil + return n, nil } diff --git a/pgtype/numeric_array.go b/pgtype/numeric_array.go deleted file mode 100644 index 3e9298b6..00000000 --- a/pgtype/numeric_array.go +++ /dev/null @@ -1,672 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type NumericArray struct { - Elements []Numeric - Dimensions []ArrayDimension - Valid bool -} - -func (dst *NumericArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = NumericArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []float32: - if value == nil { - *dst = NumericArray{} - } else if len(value) == 0 { - *dst = NumericArray{Valid: true} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*float32: - if value == nil { - *dst = NumericArray{} - } else if len(value) == 0 { - *dst = NumericArray{Valid: true} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []float64: - if value == nil { - *dst = NumericArray{} - } else if len(value) == 0 { - *dst = NumericArray{Valid: true} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*float64: - if value == nil { - *dst = NumericArray{} - } else if len(value) == 0 { - *dst = NumericArray{Valid: true} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []int64: - if value == nil { - *dst = NumericArray{} - } else if len(value) == 0 { - *dst = NumericArray{Valid: true} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int64: - if value == nil { - *dst = NumericArray{} - } else if len(value) == 0 { - *dst = NumericArray{Valid: true} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint64: - if value == nil { - *dst = NumericArray{} - } else if len(value) == 0 { - *dst = NumericArray{Valid: true} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint64: - if value == nil { - *dst = NumericArray{} - } else if len(value) == 0 { - *dst = NumericArray{Valid: true} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Numeric: - if value == nil { - *dst = NumericArray{} - } else if len(value) == 0 { - *dst = NumericArray{Valid: true} - } else { - *dst = NumericArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = NumericArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for NumericArray", src) - } - if elementsLength == 0 { - *dst = NumericArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to NumericArray", src) - } - - *dst = NumericArray{ - Elements: make([]Numeric, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Numeric, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to NumericArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *NumericArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to NumericArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in NumericArray", err) - } - index++ - - return index, nil -} - -func (dst NumericArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *NumericArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]float32: - *v = make([]float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float32: - *v = make([]*float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]float64: - *v = make([]float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float64: - *v = make([]*float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from NumericArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from NumericArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = NumericArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Numeric - - if len(uta.Elements) > 0 { - elements = make([]Numeric, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Numeric - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = NumericArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = NumericArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = NumericArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Numeric, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = NumericArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("numeric"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "numeric") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *NumericArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src NumericArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/numeric_array_test.go b/pgtype/numeric_array_test.go deleted file mode 100644 index 4542ed3e..00000000 --- a/pgtype/numeric_array_test.go +++ /dev/null @@ -1,305 +0,0 @@ -package pgtype_test - -import ( - "math" - "math/big" - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestNumericArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "numeric[]", []interface{}{ - &pgtype.NumericArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.NumericArray{}, - &pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Valid: true}, - {Int: big.NewInt(2), Valid: true}, - {Int: big.NewInt(3), Valid: true}, - {Int: big.NewInt(4), Valid: true}, - {}, - {Int: big.NewInt(6), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Valid: true}, - {Int: big.NewInt(2), Valid: true}, - {Int: big.NewInt(3), Valid: true}, - {Int: big.NewInt(4), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestNumericArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.NumericArray - }{ - { - source: []float32{1}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []float32{float32(math.Copysign(0, -1))}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(0), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []float64{1}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []float64{math.Copysign(0, -1)}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(0), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]float32)(nil)), - result: pgtype.NumericArray{}, - }, - { - source: [][]float32{{1}, {2}}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Valid: true}, - {Int: big.NewInt(2), Valid: true}, - {Int: big.NewInt(3), Valid: true}, - {Int: big.NewInt(4), Valid: true}, - {Int: big.NewInt(5), Valid: true}, - {Int: big.NewInt(6), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]float32{{1}, {2}}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Valid: true}, - {Int: big.NewInt(2), Valid: true}, - {Int: big.NewInt(3), Valid: true}, - {Int: big.NewInt(4), Valid: true}, - {Int: big.NewInt(5), Valid: true}, - {Int: big.NewInt(6), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.NumericArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestNumericArrayAssignTo(t *testing.T) { - var float32Slice []float32 - var float64Slice []float64 - var float32SliceDim2 [][]float32 - var float32SliceDim4 [][][][]float32 - var float32ArrayDim2 [2][1]float32 - var float32ArrayDim4 [2][1][1][3]float32 - - simpleTests := []struct { - src pgtype.NumericArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &float32Slice, - expected: []float32{1}, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &float64Slice, - expected: []float64{1}, - }, - { - src: pgtype.NumericArray{}, - dst: &float32Slice, - expected: (([]float32)(nil)), - }, - { - src: pgtype.NumericArray{Valid: true}, - dst: &float32Slice, - expected: []float32{}, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &float32SliceDim2, - expected: [][]float32{{1}, {2}}, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Valid: true}, - {Int: big.NewInt(2), Valid: true}, - {Int: big.NewInt(3), Valid: true}, - {Int: big.NewInt(4), Valid: true}, - {Int: big.NewInt(5), Valid: true}, - {Int: big.NewInt(6), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &float32SliceDim4, - expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &float32ArrayDim2, - expected: [2][1]float32{{1}, {2}}, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Valid: true}, - {Int: big.NewInt(2), Valid: true}, - {Int: big.NewInt(3), Valid: true}, - {Int: big.NewInt(4), Valid: true}, - {Int: big.NewInt(5), Valid: true}, - {Int: big.NewInt(6), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &float32ArrayDim4, - expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.NumericArray - dst interface{} - }{ - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &float32Slice, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &float32ArrayDim2, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &float32Slice, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &float32ArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index ff53d92b..06d81ffe 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -6,7 +6,7 @@ import ( "math" "math/big" "math/rand" - "reflect" + "strconv" "testing" "github.com/jackc/pgx/v5/pgtype" @@ -14,34 +14,6 @@ import ( "github.com/stretchr/testify/require" ) -// For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) -func numericEqual(left, right *pgtype.Numeric) bool { - return left.Valid == right.Valid && - left.Exp == right.Exp && - ((left.Int == nil && right.Int == nil) || (left.Int != nil && right.Int != nil && left.Int.Cmp(right.Int) == 0)) && - left.NaN == right.NaN -} - -// For test purposes only. -func numericNormalizedEqual(left, right *pgtype.Numeric) bool { - if left.Valid != right.Valid { - return false - } - - normLeft := &pgtype.Numeric{Int: (&big.Int{}).Set(left.Int), Valid: left.Valid} - normRight := &pgtype.Numeric{Int: (&big.Int{}).Set(right.Int), Valid: right.Valid} - - if left.Exp < right.Exp { - mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(right.Exp-left.Exp)), nil) - normRight.Int.Mul(normRight.Int, mul) - } else if left.Exp > right.Exp { - mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(left.Exp-right.Exp)), nil) - normLeft.Int.Mul(normLeft.Int, mul) - } - - return normLeft.Int.Cmp(normRight.Int) == 0 -} - func mustParseBigInt(t *testing.T, src string) *big.Int { i := &big.Int{} if _, ok := i.SetString(src, 10); !ok { @@ -50,368 +22,122 @@ func mustParseBigInt(t *testing.T, src string) *big.Int { return i } -func TestNumericNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select '0'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Valid: true}, - }, - { - SQL: "select '1'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Valid: true}, - }, - { - SQL: "select '10.00'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Valid: true}, - }, - { - SQL: "select '1e-3'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Valid: true}, - }, - { - SQL: "select '-1'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Valid: true}, - }, - { - SQL: "select '10000'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Valid: true}, - }, - { - SQL: "select '3.14'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Valid: true}, - }, - { - SQL: "select '1.1'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Valid: true}, - }, - { - SQL: "select '100010001'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Valid: true}, - }, - { - SQL: "select '100010001.0001'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Valid: true}, - }, - { - SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", - Value: &pgtype.Numeric{ - Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), - Exp: -41, - Valid: true, - }, - }, - { - SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", - Value: &pgtype.Numeric{ - Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), - Exp: -196, - Valid: true, - }, - }, - { - SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", - Value: &pgtype.Numeric{ - Int: mustParseBigInt(t, "123"), - Exp: -186, - Valid: true, - }, - }, - }) +func isExpectedEqNumeric(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + aa := a.(pgtype.Numeric) + vv := v.(pgtype.Numeric) + + if aa.Valid != vv.Valid { + return false + } + + // If NULL doesn't matter what the rest of the values are. + if !aa.Valid { + return true + } + + if !(aa.NaN == vv.NaN && aa.InfinityModifier == vv.InfinityModifier) { + return false + } + + // If NaN or InfinityModifier are set then Int and Exp don't matter. + if aa.NaN || aa.InfinityModifier != pgtype.None { + return true + } + + aaInt := (&big.Int{}).Set(aa.Int) + vvInt := (&big.Int{}).Set(vv.Int) + + if aa.Exp < vv.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(vv.Exp-aa.Exp)), nil) + vvInt.Mul(vvInt, mul) + } else if aa.Exp > vv.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(aa.Exp-vv.Exp)), nil) + aaInt.Mul(aaInt, mul) + } + + return aaInt.Cmp(vvInt) == 0 + } } -func TestNumericTranscode(t *testing.T) { +func mustParseNumeric(t *testing.T, src string) pgtype.Numeric { + var n pgtype.Numeric + plan := pgtype.NumericCodec{}.PlanScan(nil, pgtype.NumericOID, pgtype.TextFormatCode, &n, false) + require.NotNil(t, plan) + err := plan.Scan(nil, pgtype.NumericOID, pgtype.TextFormatCode, []byte(src), &n) + require.NoError(t, err) + return n +} + +func TestNumericCodec(t *testing.T) { max := new(big.Int).Exp(big.NewInt(10), big.NewInt(147454), nil) max.Add(max, big.NewInt(1)) - longestNumeric := &pgtype.Numeric{Int: max, Exp: -16383, Valid: true} + longestNumeric := pgtype.Numeric{Int: max, Exp: -16383, Valid: true} - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ - &pgtype.Numeric{NaN: true, Valid: true}, - &pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, - &pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, - - &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(1), Exp: 6, Valid: true}, - - // preserves significant zeroes - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -1, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -2, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -3, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -4, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -5, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -6, Valid: true}, - - &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -7, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -8, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -9, Valid: true}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -1500, Valid: true}, - &pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Valid: true}, - &pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Valid: true}, - &pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Valid: true}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3723409723490243842378942378901237502734019231380123"), Exp: 81, Valid: true}, - &pgtype.Numeric{Int: mustParseBigInt(t, "723409723490243842378942378901237502734019231380123"), Exp: 82, Valid: true}, - &pgtype.Numeric{Int: mustParseBigInt(t, "23409723490243842378942378901237502734019231380123"), Exp: 83, Valid: true}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3409723490243842378942378901237502734019231380123"), Exp: 84, Valid: true}, - &pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Valid: true}, - &pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Valid: true}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3423409823409243892349028349023482934092340892390101"), Exp: -91, Valid: true}, - &pgtype.Numeric{Int: mustParseBigInt(t, "423409823409243892349028349023482934092340892390101"), Exp: -92, Valid: true}, - &pgtype.Numeric{Int: mustParseBigInt(t, "23409823409243892349028349023482934092340892390101"), Exp: -93, Valid: true}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3409823409243892349028349023482934092340892390101"), Exp: -94, Valid: true}, - - longestNumeric, - - &pgtype.Numeric{}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Numeric) - b := bb.(pgtype.Numeric) - - return numericEqual(&a, &b) + testPgxCodec(t, "numeric", []PgxTranscodeTestCase{ + {mustParseNumeric(t, "1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, + {mustParseNumeric(t, "3.14159"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "3.14159"))}, + {mustParseNumeric(t, "100010001"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "100010001"))}, + {mustParseNumeric(t, "100010001.0001"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "100010001.0001"))}, + {mustParseNumeric(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"))}, + {mustParseNumeric(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"))}, + {mustParseNumeric(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"))}, + {pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 81, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 81, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 82, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 82, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 83, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 83, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 84, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 84, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -91, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -91, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -92, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -92, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -93, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -93, Valid: true})}, + {pgtype.Numeric{NaN: true, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{NaN: true, Valid: true})}, + {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true})}, + {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + {longestNumeric, new(pgtype.Numeric), isExpectedEqNumeric(longestNumeric)}, + {mustParseNumeric(t, "1"), new(int64), isExpectedEq(int64(1))}, + {math.NaN(), new(float64), func(a interface{}) bool { return math.IsNaN(a.(float64)) }}, + {float32(math.NaN()), new(float32), func(a interface{}) bool { return math.IsNaN(float64(a.(float32))) }}, + {math.Inf(1), new(float64), isExpectedEq(math.Inf(1))}, + {float32(math.Inf(1)), new(float32), isExpectedEq(float32(math.Inf(1)))}, + {math.Inf(-1), new(float64), isExpectedEq(math.Inf(-1))}, + {float32(math.Inf(-1)), new(float32), isExpectedEq(float32(math.Inf(-1)))}, + {int64(-1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "-1"))}, + {int64(0), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0"))}, + {int64(1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, + {int64(math.MinInt64), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MinInt64, 10)))}, + {int64(math.MinInt64 + 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MinInt64+1, 10)))}, + {int64(math.MaxInt64), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64, 10)))}, + {int64(math.MaxInt64 - 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64-1, 10)))}, + {pgtype.Numeric{}, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, + {nil, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, }) - } -func TestNumericTranscodeFuzz(t *testing.T) { +func TestNumericCodecFuzz(t *testing.T) { r := rand.New(rand.NewSource(0)) max := &big.Int{} max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) - values := make([]interface{}, 0, 2000) + tests := make([]PgxTranscodeTestCase, 0, 2000) for i := 0; i < 10; i++ { for j := -50; j < 50; j++ { num := (&big.Int{}).Rand(r, max) + + n := pgtype.Numeric{Int: num, Exp: int32(j), Valid: true} + tests = append(tests, PgxTranscodeTestCase{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) + negNum := &big.Int{} negNum.Neg(num) - values = append(values, &pgtype.Numeric{Int: num, Exp: int32(j), Valid: true}) - values = append(values, &pgtype.Numeric{Int: negNum, Exp: int32(j), Valid: true}) + n = pgtype.Numeric{Int: negNum, Exp: int32(j), Valid: true} + tests = append(tests, PgxTranscodeTestCase{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) } } - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, - func(aa, bb interface{}) bool { - a := aa.(pgtype.Numeric) - b := bb.(pgtype.Numeric) - - return numericNormalizedEqual(&a, &b) - }) -} - -func TestNumericSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result *pgtype.Numeric - }{ - {source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, - {source: float32(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Valid: true}}, - {source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, - {source: float64(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Valid: true}}, - {source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, - {source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, - {source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, - {source: int64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, - {source: int8(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}}, - {source: int16(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}}, - {source: int32(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}}, - {source: int64(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}}, - {source: uint8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, - {source: uint16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, - {source: uint32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, - {source: uint64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, - {source: "1", result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, - {source: _int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, - {source: float64(1000), result: &pgtype.Numeric{Int: big.NewInt(1), Exp: 3, Valid: true}}, - {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Valid: true}}, - {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Valid: true}}, - {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Valid: true}}, - {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Valid: true, NaN: true}}, - {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Valid: true, NaN: true}}, - {source: pgtype.Infinity, result: &pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}}, - {source: math.Inf(1), result: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}}, - {source: float32(math.Inf(1)), result: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}}, - {source: pgtype.NegativeInfinity, result: &pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, - {source: math.Inf(-1), result: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.NegativeInfinity}}, - {source: float32(math.Inf(1)), result: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}}, - } - - for i, tt := range successfulTests { - r := &pgtype.Numeric{} - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !numericEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestNumericAssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - var f32 float32 - var f64 float64 - var pf32 *float32 - var pf64 *float64 - - simpleTests := []struct { - src *pgtype.Numeric - dst interface{} - expected interface{} - }{ - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &f32, expected: float32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &f64, expected: float64(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Valid: true}, dst: &f32, expected: float32(4.2)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Valid: true}, dst: &f64, expected: float64(4.2)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &i16, expected: int16(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &i32, expected: int32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &i64, expected: int64(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: 3, Valid: true}, dst: &i64, expected: int64(42000)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &i, expected: int(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui8, expected: uint8(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui16, expected: uint16(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui64, expected: uint64(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui, expected: uint(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &_i8, expected: _int8(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(0)}, dst: &pi8, expected: ((*int8)(nil))}, - {src: &pgtype.Numeric{Int: big.NewInt(0)}, dst: &_pi8, expected: ((*_int8)(nil))}, - {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Valid: true}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgx/v5/pgtype/issues/27 - {src: &pgtype.Numeric{Valid: true, NaN: true}, dst: &f64, expected: math.NaN()}, - {src: &pgtype.Numeric{Valid: true, NaN: true}, dst: &f32, expected: float32(math.NaN())}, - {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}, dst: &f64, expected: math.Inf(1)}, - {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}, dst: &f32, expected: float32(math.Inf(1))}, - {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.NegativeInfinity}, dst: &f64, expected: math.Inf(-1)}, - {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.NegativeInfinity}, dst: &f32, expected: float32(math.Inf(-1))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - dst := reflect.ValueOf(tt.dst).Elem().Interface() - switch dstTyped := dst.(type) { - case float32: - nanExpected := math.IsNaN(float64(tt.expected.(float32))) - if nanExpected && !math.IsNaN(float64(dstTyped)) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } else if !nanExpected && dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - case float64: - nanExpected := math.IsNaN(tt.expected.(float64)) - if nanExpected && !math.IsNaN(dstTyped) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } else if !nanExpected && dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - default: - if dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - } - - pointerAllocTests := []struct { - src *pgtype.Numeric - dst interface{} - expected interface{} - }{ - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &pf32, expected: float32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &pf64, expected: float64(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src *pgtype.Numeric - dst interface{} - }{ - {src: &pgtype.Numeric{Int: big.NewInt(150), Valid: true}, dst: &i8}, - {src: &pgtype.Numeric{Int: big.NewInt(40000), Valid: true}, dst: &i16}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui8}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui16}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui32}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui64}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui}, - {src: &pgtype.Numeric{Int: big.NewInt(0)}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - -func TestNumericEncodeDecodeBinary(t *testing.T) { - ci := pgtype.NewConnInfo() - tests := []interface{}{ - 123, - 0.000012345, - 1.00002345, - math.NaN(), - float32(math.NaN()), - math.Inf(1), - float32(math.Inf(1)), - math.Inf(-1), - float32(math.Inf(-1)), - } - - for i, tt := range tests { - toString := func(n *pgtype.Numeric) string { - ci := pgtype.NewConnInfo() - text, err := n.EncodeText(ci, nil) - if err != nil { - t.Errorf("%d (EncodeText): %v", i, err) - } - return string(text) - } - numeric := &pgtype.Numeric{} - numeric.Set(tt) - - encoded, err := numeric.EncodeBinary(ci, nil) - if err != nil { - t.Errorf("%d (EncodeBinary): %v", i, err) - } - decoded := &pgtype.Numeric{} - err = decoded.DecodeBinary(ci, encoded) - if err != nil { - t.Errorf("%d (DecodeBinary): %v", i, err) - } - - text0 := toString(numeric) - text1 := toString(decoded) - - if text0 != text1 { - t.Errorf("%d: expected %v to equal to %v, but doesn't", i, text0, text1) - } - } + testPgxCodec(t, "numeric", tests) } func TestNumericMarshalJSON(t *testing.T) { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index a7c97a73..b1a9201f 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -290,7 +290,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementCodec: PolygonCodec{}, ElementOID: PolygonOID}}) ci.RegisterDataType(DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: NameOID}}) ci.RegisterDataType(DataType{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementCodec: QCharCodec{}, ElementOID: QCharOID}}) - ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) + ci.RegisterDataType(DataType{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementCodec: NumericCodec{}, ElementOID: NumericOID}}) ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: TextOID}}) ci.RegisterDataType(DataType{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementCodec: TimestampCodec{}, ElementOID: TimestampOID}}) ci.RegisterDataType(DataType{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementCodec: TimestamptzCodec{}, ElementOID: TimestamptzOID}}) @@ -333,7 +333,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) ci.RegisterDataType(DataType{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) ci.RegisterDataType(DataType{Name: "name", OID: NameOID, Codec: TextCodec{}}) - ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) + ci.RegisterDataType(DataType{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) ci.RegisterDataType(DataType{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) ci.RegisterDataType(DataType{Name: "path", OID: PathOID, Codec: PathCodec{}}) From 740263c0d4ee9d17bb5ed56550839925413fc98b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 10:53:47 -0600 Subject: [PATCH 0864/1158] Convert UUID to Codec --- pgtype/builtin_wrappers.go | 36 +++ pgtype/pgtype.go | 52 +++- pgtype/uuid.go | 354 +++++++++++------------ pgtype/uuid_array.go | 560 ------------------------------------- pgtype/uuid_array_test.go | 368 ------------------------ pgtype/uuid_test.go | 167 ++--------- pgtype/zeronull/uuid.go | 78 ++---- values_test.go | 2 + 8 files changed, 309 insertions(+), 1308 deletions(-) delete mode 100644 pgtype/uuid_array.go delete mode 100644 pgtype/uuid_array_test.go diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 3becf06b..abf21b82 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -559,3 +559,39 @@ type fmtStringerWrapper struct { func (w fmtStringerWrapper) TextValue() (Text, error) { return Text{String: w.s.String(), Valid: true}, nil } + +type byte16Wrapper [16]byte + +func (w *byte16Wrapper) ScanUUID(v UUID) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *[16]byte") + } + *w = byte16Wrapper(v.Bytes) + return nil +} + +func (w byte16Wrapper) UUIDValue() (UUID, error) { + return UUID{Bytes: [16]byte(w), Valid: true}, nil +} + +type byteSliceWrapper []byte + +func (w *byteSliceWrapper) ScanUUID(v UUID) error { + if !v.Valid { + *w = nil + return nil + } + *w = make(byteSliceWrapper, 16) + copy(*w, v.Bytes[:]) + return nil +} + +func (w byteSliceWrapper) UUIDValue() (UUID, error) { + if w == nil { + return UUID{}, nil + } + + uuid := UUID{Valid: true} + copy(uuid.Bytes[:], w) + return uuid, nil +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index b1a9201f..b5e1181d 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -296,7 +296,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementCodec: TimestamptzCodec{}, ElementOID: TimestamptzOID}}) ci.RegisterDataType(DataType{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementCodec: MacaddrCodec{}, ElementOID: MacaddrOID}}) ci.RegisterDataType(DataType{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementCodec: TIDCodec{}, ElementOID: TIDOID}}) - ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) + ci.RegisterDataType(DataType{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementCodec: UUIDCodec{}, ElementOID: UUIDOID}}) ci.RegisterDataType(DataType{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementCodec: JSONBCodec{}, ElementOID: JSONBOID}}) ci.RegisterDataType(DataType{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementCodec: JSONCodec{}, ElementOID: JSONOID}}) ci.RegisterDataType(DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: VarcharOID}}) @@ -350,7 +350,7 @@ func NewConnInfo() *ConnInfo { // ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) // ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) ci.RegisterDataType(DataType{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) - ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID}) + ci.RegisterDataType(DataType{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) ci.RegisterDataType(DataType{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) @@ -888,6 +888,10 @@ func tryWrapBuiltinTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter return &wrapMapStringToPointerStringScanPlan{}, (*mapStringToPointerStringWrapper)(dst), true case *map[string]string: return &wrapMapStringToStringScanPlan{}, (*mapStringToStringWrapper)(dst), true + case *[16]byte: + return &wrapByte16ScanPlan{}, (*byte16Wrapper)(dst), true + case *[]byte: + return &wrapByteSliceScanPlan{}, (*byteSliceWrapper)(dst), true } return nil, nil, false @@ -1083,6 +1087,26 @@ func (plan *wrapMapStringToStringScanPlan) Scan(ci *ConnInfo, oid uint32, format return plan.next.Scan(ci, oid, formatCode, src, (*mapStringToStringWrapper)(dst.(*map[string]string))) } +type wrapByte16ScanPlan struct { + next ScanPlan +} + +func (plan *wrapByte16ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapByte16ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*byte16Wrapper)(dst.(*[16]byte))) +} + +type wrapByteSliceScanPlan struct { + next ScanPlan +} + +func (plan *wrapByteSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapByteSliceScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*byteSliceWrapper)(dst.(*[]byte))) +} + type pointerEmptyInterfaceScanPlan struct { codec Codec } @@ -1425,6 +1449,10 @@ func tryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNext return &wrapMapStringToPointerStringEncodePlan{}, mapStringToPointerStringWrapper(value), true case map[string]string: return &wrapMapStringToStringEncodePlan{}, mapStringToStringWrapper(value), true + case [16]byte: + return &wrapByte16EncodePlan{}, byte16Wrapper(value), true + case []byte: + return &wrapByteSliceEncodePlan{}, byteSliceWrapper(value), true case fmt.Stringer: return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{value}, true } @@ -1622,6 +1650,26 @@ func (plan *wrapMapStringToStringEncodePlan) Encode(value interface{}, buf []byt return plan.next.Encode(mapStringToStringWrapper(value.(map[string]string)), buf) } +type wrapByte16EncodePlan struct { + next EncodePlan +} + +func (plan *wrapByte16EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapByte16EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(byte16Wrapper(value.([16]byte)), buf) +} + +type wrapByteSliceEncodePlan struct { + next EncodePlan +} + +func (plan *wrapByteSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapByteSliceEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(byteSliceWrapper(value.([]byte)), buf) +} + type wrapFmtStringerEncodePlan struct { next EncodePlan } diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 4533aa06..288fc454 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -7,144 +7,26 @@ import ( "fmt" ) +type UUIDScanner interface { + ScanUUID(v UUID) error +} + +type UUIDValuer interface { + UUIDValue() (UUID, error) +} + type UUID struct { Bytes [16]byte Valid bool - - UUIDDecoderWrapper func(interface{}) UUIDDecoder - Getter func(UUID) interface{} } -func (n *UUID) NewTypeValue() Value { - return &UUID{ - UUIDDecoderWrapper: n.UUIDDecoderWrapper, - Getter: n.Getter, - } -} - -func (n *UUID) TypeName() string { - return "uuid" -} - -func (dst *UUID) setNil() { - dst.Bytes = [16]byte{} - dst.Valid = false -} - -func (dst *UUID) setByteArray(value [16]byte) { - dst.Bytes = value - dst.Valid = true -} - -func (dst *UUID) setByteSlice(value []byte) error { - if value != nil { - if len(value) != 16 { - return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) - } - copy(dst.Bytes[:], value) - dst.Valid = true - } else { - dst.setNil() - } - +func (b *UUID) ScanUUID(v UUID) error { + *b = v return nil } -func (dst *UUID) setString(value string) error { - uuid, err := parseUUID(value) - if err != nil { - return err - } - dst.setByteArray(uuid) - return nil -} - -func (dst *UUID) Set(src interface{}) error { - if src == nil { - dst.setNil() - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case [16]byte: - dst.setByteArray(value) - case []byte: - return dst.setByteSlice(value) - case string: - return dst.setString(value) - case *string: - if value == nil { - dst.setNil() - } else { - return dst.setString(*value) - } - default: - if originalSrc, ok := underlyingUUIDType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to UUID", value) - } - - return nil -} - -func (dst UUID) Get() interface{} { - if dst.Getter != nil { - return dst.Getter(dst) - } - - if !dst.Valid { - return nil - } - - return dst.Bytes -} - -type UUIDDecoder interface { - DecodeUUID(*UUID) error -} - -func (src *UUID) AssignTo(dst interface{}) error { - if d, ok := dst.(UUIDDecoder); ok { - return d.DecodeUUID(src) - } else { - if src.UUIDDecoderWrapper != nil { - d = src.UUIDDecoderWrapper(dst) - if d != nil { - return d.DecodeUUID(src) - } - } - } - - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *[16]byte: - *v = src.Bytes - return nil - case *[]byte: - *v = make([]byte, 16) - copy(*v, src.Bytes[:]) - return nil - case *string: - *v = encodeUUID(src.Bytes) - return nil - default: - if nextDst, retry := GetAssignToDstType(v); retry { - return src.AssignTo(nextDst) - } - } - - return nil +func (b UUID) UUIDValue() (UUID, error) { + return b, nil } // parseUUID converts a string UUID in standard form to a byte array. @@ -173,68 +55,21 @@ func encodeUUID(src [16]byte) string { return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16]) } -func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - dst.setNil() - return nil - } - - if len(src) != 36 { - return fmt.Errorf("invalid length for UUID: %v", len(src)) - } - - buf, err := parseUUID(string(src)) - if err != nil { - return err - } - - dst.setByteArray(buf) - return nil -} - -func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - dst.setNil() - return nil - } - - if len(src) != 16 { - return fmt.Errorf("invalid length for UUID: %v", len(src)) - } - - return dst.setByteSlice(src) -} - -func (src UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, encodeUUID(src.Bytes)...), nil -} - -func (src UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, src.Bytes[:]...), nil -} - // Scan implements the database/sql Scanner interface. func (dst *UUID) Scan(src interface{}) error { if src == nil { - dst.setNil() + *dst = UUID{} return nil } switch src := src.(type) { case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + buf, err := parseUUID(src) + if err != nil { + return err + } + *dst = UUID{Bytes: buf, Valid: true} + return nil } return fmt.Errorf("cannot scan %T", src) @@ -242,7 +77,11 @@ func (dst *UUID) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src UUID) Value() (driver.Value, error) { - return EncodeValueText(src) + if !src.Valid { + return nil, nil + } + + return encodeUUID(src.Bytes), nil } func (src UUID) MarshalJSON() ([]byte, error) { @@ -259,10 +98,151 @@ func (src UUID) MarshalJSON() ([]byte, error) { func (dst *UUID) UnmarshalJSON(src []byte) error { if bytes.Compare(src, []byte("null")) == 0 { - return dst.Set(nil) + *dst = UUID{} + return nil } if len(src) != 38 { return fmt.Errorf("invalid length for UUID: %v", len(src)) } - return dst.Set(string(src[1 : len(src)-1])) + buf, err := parseUUID(string(src[1 : len(src)-1])) + if err != nil { + return err + } + *dst = UUID{Bytes: buf, Valid: true} + return nil +} + +type UUIDCodec struct{} + +func (UUIDCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (UUIDCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (UUIDCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(UUIDValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanUUIDCodecBinaryUUIDValuer{} + case TextFormatCode: + return encodePlanUUIDCodecTextUUIDValuer{} + } + + return nil +} + +type encodePlanUUIDCodecBinaryUUIDValuer struct{} + +func (encodePlanUUIDCodecBinaryUUIDValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + uuid, err := value.(UUIDValuer).UUIDValue() + if err != nil { + return nil, err + } + + if !uuid.Valid { + return nil, nil + } + + return append(buf, uuid.Bytes[:]...), nil +} + +type encodePlanUUIDCodecTextUUIDValuer struct{} + +func (encodePlanUUIDCodecTextUUIDValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + uuid, err := value.(UUIDValuer).UUIDValue() + if err != nil { + return nil, err + } + + if !uuid.Valid { + return nil, nil + } + + return append(buf, encodeUUID(uuid.Bytes)...), nil +} + +func (UUIDCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case UUIDScanner: + return scanPlanBinaryUUIDToUUIDScanner{} + } + case TextFormatCode: + switch target.(type) { + case UUIDScanner: + return scanPlanTextAnyToUUIDScanner{} + } + } + + return nil +} + +type scanPlanBinaryUUIDToUUIDScanner struct{} + +func (scanPlanBinaryUUIDToUUIDScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(UUIDScanner) + + if src == nil { + return scanner.ScanUUID(UUID{}) + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) + } + + uuid := UUID{Valid: true} + copy(uuid.Bytes[:], src) + + return scanner.ScanUUID(uuid) +} + +type scanPlanTextAnyToUUIDScanner struct{} + +func (scanPlanTextAnyToUUIDScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(UUIDScanner) + + if src == nil { + return scanner.ScanUUID(UUID{}) + } + + buf, err := parseUUID(string(src)) + if err != nil { + return err + } + + return scanner.ScanUUID(UUID{Bytes: buf, Valid: true}) +} + +func (c UUIDCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var uuid UUID + err := codecScan(c, ci, oid, format, src, &uuid) + if err != nil { + return nil, err + } + + return encodeUUID(uuid.Bytes), nil +} + +func (c UUIDCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var uuid UUID + err := codecScan(c, ci, oid, format, src, &uuid) + if err != nil { + return nil, err + } + return uuid.Bytes, nil } diff --git a/pgtype/uuid_array.go b/pgtype/uuid_array.go deleted file mode 100644 index 98904f9f..00000000 --- a/pgtype/uuid_array.go +++ /dev/null @@ -1,560 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type UUIDArray struct { - Elements []UUID - Dimensions []ArrayDimension - Valid bool -} - -func (dst *UUIDArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = UUIDArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case [][16]byte: - if value == nil { - *dst = UUIDArray{} - } else if len(value) == 0 { - *dst = UUIDArray{Valid: true} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case [][]byte: - if value == nil { - *dst = UUIDArray{} - } else if len(value) == 0 { - *dst = UUIDArray{Valid: true} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []string: - if value == nil { - *dst = UUIDArray{} - } else if len(value) == 0 { - *dst = UUIDArray{Valid: true} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*string: - if value == nil { - *dst = UUIDArray{} - } else if len(value) == 0 { - *dst = UUIDArray{Valid: true} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []UUID: - if value == nil { - *dst = UUIDArray{} - } else if len(value) == 0 { - *dst = UUIDArray{Valid: true} - } else { - *dst = UUIDArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = UUIDArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for UUIDArray", src) - } - if elementsLength == 0 { - *dst = UUIDArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to UUIDArray", src) - } - - *dst = UUIDArray{ - Elements: make([]UUID, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]UUID, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to UUIDArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *UUIDArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to UUIDArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in UUIDArray", err) - } - index++ - - return index, nil -} - -func (dst UUIDArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *UUIDArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[][16]byte: - *v = make([][16]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from UUIDArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from UUIDArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *UUIDArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = UUIDArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []UUID - - if len(uta.Elements) > 0 { - elements = make([]UUID, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem UUID - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = UUIDArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *UUIDArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = UUIDArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = UUIDArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]UUID, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = UUIDArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("uuid"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "uuid") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *UUIDArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src UUIDArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/uuid_array_test.go b/pgtype/uuid_array_test.go deleted file mode 100644 index b432d0f8..00000000 --- a/pgtype/uuid_array_test.go +++ /dev/null @@ -1,368 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestUUIDArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "uuid[]", []interface{}{ - &pgtype.UUIDArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.UUIDArray{}, - &pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, - {}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestUUIDArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.UUIDArray - }{ - { - source: nil, - result: pgtype.UUIDArray{}, - }, - { - source: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][16]byte{}, - result: pgtype.UUIDArray{Valid: true}, - }, - { - source: ([][16]byte)(nil), - result: pgtype.UUIDArray{}, - }, - { - source: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][]byte{ - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, - nil, - {32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, - }, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, - {}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 4}}, - Valid: true}, - }, - { - source: [][]byte{}, - result: pgtype.UUIDArray{Valid: true}, - }, - { - source: ([][]byte)(nil), - result: pgtype.UUIDArray{}, - }, - { - source: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: []string{}, - result: pgtype.UUIDArray{Valid: true}, - }, - { - source: ([]string)(nil), - result: pgtype.UUIDArray{}, - }, - { - source: [][][16]byte{{ - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]string{ - {{{ - "00010203-0405-0607-0809-0a0b0c0d0e0f", - "10111213-1415-1617-1819-1a1b1c1d1e1f", - "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, - {{{ - "30313233-3435-3637-3839-3a3b3c3d3e3f", - "40414243-4445-4647-4849-4a4b4c4d4e4f", - "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, - {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1][16]byte{{ - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]string{ - {{{ - "00010203-0405-0607-0809-0a0b0c0d0e0f", - "10111213-1415-1617-1819-1a1b1c1d1e1f", - "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, - {{{ - "30313233-3435-3637-3839-3a3b3c3d3e3f", - "40414243-4445-4647-4849-4a4b4c4d4e4f", - "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, - {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.UUIDArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestUUIDArrayAssignTo(t *testing.T) { - var byteArraySlice [][16]byte - var byteSliceSlice [][]byte - var stringSlice []string - var byteSlice []byte - var byteArraySliceDim2 [][][16]byte - var stringSliceDim4 [][][][]string - var byteArrayDim2 [2][1][16]byte - var stringArrayDim4 [2][1][1][3]string - - simpleTests := []struct { - src pgtype.UUIDArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &byteArraySlice, - expected: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - }, - { - src: pgtype.UUIDArray{}, - dst: &byteArraySlice, - expected: ([][16]byte)(nil), - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &byteSliceSlice, - expected: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - }, - { - src: pgtype.UUIDArray{}, - dst: &byteSliceSlice, - expected: ([][]byte)(nil), - }, - { - src: pgtype.UUIDArray{Valid: true}, - dst: &byteSlice, - expected: []byte{}, - }, - { - src: pgtype.UUIDArray{Valid: true}, - dst: &stringSlice, - expected: []string{}, - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &stringSlice, - expected: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, - }, - { - src: pgtype.UUIDArray{}, - dst: &stringSlice, - expected: ([]string)(nil), - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &byteArraySliceDim2, - expected: [][][16]byte{{ - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, - {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &stringSliceDim4, - expected: [][][][]string{ - {{{ - "00010203-0405-0607-0809-0a0b0c0d0e0f", - "10111213-1415-1617-1819-1a1b1c1d1e1f", - "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, - {{{ - "30313233-3435-3637-3839-3a3b3c3d3e3f", - "40414243-4445-4647-4849-4a4b4c4d4e4f", - "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &byteArrayDim2, - expected: [2][1][16]byte{{ - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, - {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &stringArrayDim4, - expected: [2][1][1][3]string{ - {{{ - "00010203-0405-0607-0809-0a0b0c0d0e0f", - "10111213-1415-1617-1819-1a1b1c1d1e1f", - "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, - {{{ - "30313233-3435-3637-3839-3a3b3c3d3e3f", - "40414243-4445-4647-4849-4a4b4c4d4e4f", - "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 9701db74..71de8d67 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -1,153 +1,46 @@ package pgtype_test import ( - "bytes" "reflect" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/require" ) -func TestUUIDTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - &pgtype.UUID{}, +func TestUUIDCodec(t *testing.T) { + testPgxCodec(t, "uuid", []PgxTranscodeTestCase{ + { + pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + "00010203-0405-0607-0809-0a0b0c0d0e0f", + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + "000102030405060708090a0b0c0d0e0f", + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + {pgtype.UUID{}, new([]byte), isExpectedEqBytes([]byte(nil))}, + {pgtype.UUID{}, new(pgtype.UUID), isExpectedEq(pgtype.UUID{})}, + {nil, new(pgtype.UUID), isExpectedEq(pgtype.UUID{})}, }) } -type SomeUUIDWrapper struct { - SomeUUIDType -} - -type SomeUUIDType [16]byte - -func TestUUIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.UUID - }{ - { - source: nil, - result: pgtype.UUID{}, - }, - { - source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - }, - { - source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - }, - { - source: SomeUUIDType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - }, - { - source: ([]byte)(nil), - result: pgtype.UUID{}, - }, - { - source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - }, - { - source: "000102030405060708090a0b0c0d0e0f", - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.UUID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r.Bytes != tt.result.Bytes || r.Valid != tt.result.Valid { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestUUIDAssignTo(t *testing.T) { - { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} - var dst [16]byte - expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} - var dst []byte - expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if bytes.Compare(dst, expected) != 0 { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} - var dst SomeUUIDType - expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} - var dst string - expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} - var dst SomeUUIDWrapper - expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst.SomeUUIDType != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } -} - func TestUUID_MarshalJSON(t *testing.T) { tests := []struct { name string diff --git a/pgtype/zeronull/uuid.go b/pgtype/zeronull/uuid.go index 8c0978b3..a87003e1 100644 --- a/pgtype/zeronull/uuid.go +++ b/pgtype/zeronull/uuid.go @@ -8,68 +8,29 @@ import ( type UUID [16]byte -func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.UUID - err := nullable.DecodeText(ci, src) - if err != nil { - return err +// ScanUUID implements the UUIDScanner interface. +func (u *UUID) ScanUUID(v pgtype.UUID) error { + if !v.Valid { + *u = UUID{} + return nil } - if nullable.Valid { - *dst = UUID(nullable.Bytes) - } else { - *dst = UUID{} - } + *u = UUID(v.Bytes) return nil } -func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - var nullable pgtype.UUID - err := nullable.DecodeBinary(ci, src) - if err != nil { - return err +func (u UUID) UUIDValue() (pgtype.UUID, error) { + if u == (UUID{}) { + return pgtype.UUID{}, nil } - - if nullable.Valid { - *dst = UUID(nullable.Bytes) - } else { - *dst = UUID{} - } - - return nil -} - -func (src UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if (src == UUID{}) { - return nil, nil - } - - nullable := pgtype.UUID{ - Bytes: [16]byte(src), - Valid: true, - } - - return nullable.EncodeText(ci, buf) -} - -func (src UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if (src == UUID{}) { - return nil, nil - } - - nullable := pgtype.UUID{ - Bytes: [16]byte(src), - Valid: true, - } - - return nullable.EncodeBinary(ci, buf) + return pgtype.UUID{Bytes: u, Valid: true}, nil } // Scan implements the database/sql Scanner interface. -func (dst *UUID) Scan(src interface{}) error { +func (u *UUID) Scan(src interface{}) error { if src == nil { - *dst = UUID{} + *u = UUID{} return nil } @@ -79,12 +40,21 @@ func (dst *UUID) Scan(src interface{}) error { return err } - *dst = UUID(nullable.Bytes) + *u = UUID(nullable.Bytes) return nil } // Value implements the database/sql/driver Valuer interface. -func (src UUID) Value() (driver.Value, error) { - return pgtype.EncodeValueText(src) +func (u UUID) Value() (driver.Value, error) { + if u == (UUID{}) { + return nil, nil + } + + buf, err := pgtype.UUIDCodec{}.PlanEncode(nil, pgtype.UUIDOID, pgtype.TextFormatCode, u).Encode(u, nil) + if err != nil { + return nil, err + } + + return string(buf), nil } diff --git a/values_test.go b/values_test.go index 82a6496a..080bb305 100644 --- a/values_test.go +++ b/values_test.go @@ -243,6 +243,8 @@ func mustParseCIDR(t *testing.T, s string) *net.IPNet { } func TestStringToNotTextTypeTranscode(t *testing.T) { + t.Skip("TODO - unskip later in v5") // Should this even be a thing... i.e. anything is scanable to a string to a string + t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { From 8d2c87b5e5864b84cada19ab6833a2a3abf04f59 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 10:54:54 -0600 Subject: [PATCH 0865/1158] Remove old typed array code gen --- pgtype/typed_array.go.erb | 481 -------------------------------------- pgtype/typed_array_gen.sh | 26 --- 2 files changed, 507 deletions(-) delete mode 100644 pgtype/typed_array.go.erb delete mode 100755 pgtype/typed_array_gen.sh diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb deleted file mode 100644 index e1ead59c..00000000 --- a/pgtype/typed_array.go.erb +++ /dev/null @@ -1,481 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "bytes" - "fmt" - "io" - - "github.com/jackc/pgio" -) - -type <%= pgtype_array_type %> struct { - Elements []<%= pgtype_element_type %> - Dimensions []ArrayDimension - Valid bool -} - -func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = <%= pgtype_array_type %>{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - <% go_array_types.split(",").each do |t| %> - <% if t != "[]#{pgtype_element_type}" %> - case <%= t %>: - if value == nil { - *dst = <%= pgtype_array_type %>{} - } else if len(value) == 0 { - *dst = <%= pgtype_array_type %>{Valid: true} - } else { - elements := make([]<%= pgtype_element_type %>, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = <%= pgtype_array_type %>{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - <% end %> - <% end %> - case []<%= pgtype_element_type %>: - if value == nil { - *dst = <%= pgtype_array_type %>{} - } else if len(value) == 0 { - *dst = <%= pgtype_array_type %>{Valid: true} - } else { - *dst = <%= pgtype_array_type %>{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = <%= pgtype_array_type %>{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for <%= pgtype_array_type %>", src) - } - if elementsLength == 0 { - *dst = <%= pgtype_array_type %>{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to <%= pgtype_array_type %>", src) - } - - *dst = <%= pgtype_array_type %> { - Elements: make([]<%= pgtype_element_type %>, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]<%= pgtype_element_type %>, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to <%= pgtype_array_type %>, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *<%= pgtype_array_type %>) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to <%= pgtype_array_type %>") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in <%= pgtype_array_type %>", err) - } - index++ - - return index, nil -} - -func (dst <%= pgtype_array_type %>) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1{ - // Attempt to match to select common types: - switch v := dst.(type) { - <% go_array_types.split(",").each do |t| %> - case *<%= t %>: - *v = make(<%= t %>, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - <% end %> - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr(){ - return 0, fmt.Errorf("cannot assign all values from <%= pgtype_array_type %>") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from <%= pgtype_array_type %>") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = <%= pgtype_array_type %>{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []<%= pgtype_element_type %> - - if len(uta.Elements) > 0 { - elements = make([]<%= pgtype_element_type %>, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem <%= pgtype_element_type %> - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -<% if binary_format == "true" %> -func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = <%= pgtype_array_type %>{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = <%= pgtype_array_type %>{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]<%= pgtype_element_type %>, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp:rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} -<% end %> - -func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `<%= text_null %>`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -<% if binary_format == "true" %> - func (src <%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil - } -<% end %> - -// Scan implements the database/sql Scanner interface. -func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src <%= pgtype_array_type %>) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh deleted file mode 100755 index 3766c3f8..00000000 --- a/pgtype/typed_array_gen.sh +++ /dev/null @@ -1,26 +0,0 @@ -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time,[]*time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go -erb pgtype_array_type=TsrangeArray pgtype_element_type=Tsrange go_array_types=[]Tsrange element_type_name=tsrange text_null=NULL binary_format=true typed_array.go.erb > tsrange_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time,[]*time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32,[]*float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64,[]*float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go -erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr,[]*net.HardwareAddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go -erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string,[]*string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string,[]*string element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go -erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string,[]*string element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go -erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string,[]*string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go -erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go -erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go -erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go -erb pgtype_array_type=JSONBArray pgtype_element_type=JSONB go_array_types=[]string,[][]byte element_type_name=jsonb text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go - -# While the binary format is theoretically possible it is only practical to use the text format. -erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go - -goimports -w *_array.go From ad785d813409568d1bf57d67706f45494ca2924f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 10:56:56 -0600 Subject: [PATCH 0866/1158] Remove TypeValue interface --- pgtype/pgtype.go | 31 ++----------------------------- 1 file changed, 2 insertions(+), 29 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index b5e1181d..481c58be 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -149,22 +149,6 @@ type Value interface { AssignTo(dst interface{}) error } -// TypeValue is a Value where instances can represent different PostgreSQL types. This can be useful for -// representing types such as enums, composites, and arrays. -// -// In general, instances of TypeValue should not be used to directly represent a value. It should only be used as an -// encoder and decoder internal to ConnInfo. -type TypeValue interface { - Value - - // NewTypeValue creates a TypeValue including references to internal type information. e.g. the list of members - // in an EnumType. - NewTypeValue() Value - - // TypeName returns the PostgreSQL name of this type. - TypeName() string -} - type Codec interface { // FormatSupported returns true if the format is supported. FormatSupported(int16) bool @@ -456,9 +440,7 @@ func (ci *ConnInfo) buildReflectTypeToDataType() { for _, dt := range ci.oidToDataType { if dt.Value != nil { - if _, is := dt.Value.(TypeValue); !is { - ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt - } + ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt } } @@ -476,11 +458,6 @@ func (ci *ConnInfo) DataTypeForValue(v interface{}) (*DataType, bool) { ci.buildReflectTypeToDataType() } - if tv, ok := v.(TypeValue); ok { - dt, ok := ci.nameToDataType[tv.TypeName()] - return dt, ok - } - dt, ok := ci.reflectTypeToDataType[reflect.TypeOf(v)] return dt, ok } @@ -1258,11 +1235,7 @@ func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) // NewValue returns a new instance of the same type as v. func NewValue(v Value) Value { - if tv, ok := v.(TypeValue); ok { - return tv.NewTypeValue() - } else { - return reflect.New(reflect.ValueOf(v).Elem().Type()).Interface().(Value) - } + return reflect.New(reflect.ValueOf(v).Elem().Type()).Interface().(Value) } var ErrScanTargetTypeChanged = errors.New("scan target type changed") From eb0a4c96264854d3e603eb90536c25104244a6ad Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 11:21:12 -0600 Subject: [PATCH 0867/1158] Replace some old database/sql compatibility --- pgtype/builtin_wrappers.go | 2 ++ pgtype/database_sql.go | 41 ---------------------------- pgtype/pgtype.go | 56 ++++++++++---------------------------- 3 files changed, 16 insertions(+), 83 deletions(-) delete mode 100644 pgtype/database_sql.go diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index abf21b82..9df28f55 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -461,6 +461,8 @@ func (w timeWrapper) TimeValue() (Time, error) { type durationWrapper time.Duration +func (w durationWrapper) SkipUnderlyingTypePlan() {} + func (w *durationWrapper) ScanInterval(v Interval) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *time.Interval") diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go deleted file mode 100644 index 9d1cf822..00000000 --- a/pgtype/database_sql.go +++ /dev/null @@ -1,41 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "errors" -) - -func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { - if valuer, ok := src.(driver.Valuer); ok { - return valuer.Value() - } - - if textEncoder, ok := src.(TextEncoder); ok { - buf, err := textEncoder.EncodeText(ci, nil) - if err != nil { - return nil, err - } - return string(buf), nil - } - - if binaryEncoder, ok := src.(BinaryEncoder); ok { - buf, err := binaryEncoder.EncodeBinary(ci, nil) - if err != nil { - return nil, err - } - return buf, nil - } - - return nil, errors.New("cannot convert to database/sql compatible value") -} - -func EncodeValueText(src TextEncoder) (interface{}, error) { - buf, err := src.EncodeText(nil, make([]byte, 0, 32)) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), err -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 481c58be..ee2730ee 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -520,41 +520,6 @@ func (plan scanPlanDstTextDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int return newPlan.Scan(ci, oid, formatCode, src, dst) } -type scanPlanDataTypeSQLScanner DataType - -func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - scanner, ok := dst.(sql.Scanner) - if !ok { - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) - } - - dt := (*DataType)(plan) - if dt.Codec != nil { - sqlValue, err := dt.Codec.DecodeDatabaseSQLValue(ci, oid, formatCode, src) - if err != nil { - return err - } - return scanner.Scan(sqlValue) - } - var err error - switch formatCode { - case BinaryFormatCode: - err = dt.binaryDecoder.DecodeBinary(ci, src) - case TextFormatCode: - err = dt.textDecoder.DecodeText(ci, src) - } - if err != nil { - return err - } - - sqlSrc, err := DatabaseSQLValue(ci, dt.Value) - if err != nil { - return err - } - return scanner.Scan(sqlSrc) -} - type scanPlanDataTypeAssignTo DataType func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { @@ -596,6 +561,18 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode return assignToErr } +type scanPlanCodecSQLScanner struct{ c Codec } + +func (plan *scanPlanCodecSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + value, err := plan.c.DecodeDatabaseSQLValue(ci, oid, formatCode, src) + if err != nil { + return err + } + + scanner := dst.(sql.Scanner) + return scanner.Scan(value) +} + type scanPlanSQLScanner struct{} func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { @@ -1176,7 +1153,7 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan for _, f := range tryWrappers { if wrapperPlan, nextDst, ok := f(dst); ok { if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { - if _, ok := nextPlan.(*scanPlanDataTypeAssignTo); !ok { // avoid fallthrough -- this will go away when old system removed. + if _, ok := nextPlan.(scanPlanReflection); !ok { // avoid fallthrough -- this will go away when old system removed. wrapperPlan.SetNext(nextPlan) return wrapperPlan } @@ -1187,15 +1164,10 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan if _, ok := dst.(*interface{}); ok { return &pointerEmptyInterfaceScanPlan{codec: dt.Codec} } - } - if dt != nil { if _, ok := dst.(sql.Scanner); ok { - if _, found := ci.preferAssignToOverSQLScannerTypes[reflect.TypeOf(dst)]; !found { - return (*scanPlanDataTypeSQLScanner)(dt) - } + return &scanPlanCodecSQLScanner{c: dt.Codec} } - return (*scanPlanDataTypeAssignTo)(dt) } if _, ok := dst.(sql.Scanner); ok { From 3a90c6c8795d5866ae403e11ff92ec03edd2540a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 12:07:35 -0600 Subject: [PATCH 0868/1158] Removed TextEncoder and BinaryEncoder Restructured / fixed a lot of tests along the way. --- extended_query_builder.go | 50 ---- messages.go | 4 - pgtype/bits_test.go | 13 +- pgtype/bool_test.go | 3 +- pgtype/box_test.go | 14 +- pgtype/bytea_test.go | 2 +- pgtype/circle_test.go | 3 +- pgtype/date_test.go | 3 +- pgtype/float4_test.go | 3 +- pgtype/float8_test.go | 3 +- pgtype/hstore_test.go | 20 +- pgtype/inet_test.go | 5 +- pgtype/int_test.go | 7 +- pgtype/int_test.go.erb | 2 +- pgtype/interval_test.go | 3 +- pgtype/json_test.go | 4 +- pgtype/jsonb_test.go | 4 +- pgtype/line_test.go | 2 +- pgtype/lseg_test.go | 3 +- pgtype/macaddr_test.go | 4 +- pgtype/numeric_test.go | 10 +- pgtype/path_test.go | 3 +- pgtype/pgtype.go | 22 -- pgtype/pgtype_test.go | 50 ---- pgtype/point_test.go | 3 +- pgtype/polygon_test.go | 3 +- pgtype/qchar_test.go | 14 +- pgtype/testutil/testutil.go | 392 ++-------------------------- pgtype/text_test.go | 10 +- pgtype/tid_test.go | 3 +- pgtype/time_test.go | 3 +- pgtype/timestamp_test.go | 2 +- pgtype/timestamptz_test.go | 2 +- pgtype/uint32_test.go | 3 +- pgtype/uuid_test.go | 3 +- pgtype/zeronull/float8.go | 2 + pgtype/zeronull/float8_test.go | 33 ++- pgtype/zeronull/int.go | 6 + pgtype/zeronull/int.go.erb | 2 + pgtype/zeronull/int_test.go | 81 +++--- pgtype/zeronull/int_test.go.erb | 27 +- pgtype/zeronull/text.go | 2 + pgtype/zeronull/text_test.go | 27 +- pgtype/zeronull/timestamp.go | 2 + pgtype/zeronull/timestamp_test.go | 41 +-- pgtype/zeronull/timestamptz.go | 2 + pgtype/zeronull/timestamptz_test.go | 41 +-- pgtype/zeronull/uuid.go | 2 + pgtype/zeronull/uuid_test.go | 27 +- values.go | 83 +----- 50 files changed, 295 insertions(+), 758 deletions(-) diff --git a/extended_query_builder.go b/extended_query_builder.go index 36447c99..34662eb1 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -71,40 +71,12 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o eqb.paramValueBytes = make([]byte, 0, 128) } - var err error - var buf []byte pos := len(eqb.paramValueBytes) if arg, ok := arg.(string); ok { return []byte(arg), nil } - if formatCode == TextFormatCode { - if arg, ok := arg.(pgtype.TextEncoder); ok { - buf, err = arg.EncodeText(ci, eqb.paramValueBytes) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - eqb.paramValueBytes = buf - return eqb.paramValueBytes[pos:], nil - } - } else if formatCode == BinaryFormatCode { - if arg, ok := arg.(pgtype.BinaryEncoder); ok { - buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - eqb.paramValueBytes = buf - return eqb.paramValueBytes[pos:], nil - } - } - if argIsPtr { // We have already checked that arg is not pointing to nil, // so it is safe to dereference here. @@ -143,28 +115,6 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o } } - // There is no data type registered for the destination OID, but maybe there is data type registered for the arg - // type. If so use it's text encoder (if available). - if dt, ok := ci.DataTypeForValue(arg); ok { - value := dt.Value - if textEncoder, ok := value.(pgtype.TextEncoder); ok { - err := value.Set(arg) - if err != nil { - return nil, err - } - - buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - eqb.paramValueBytes = buf - return eqb.paramValueBytes[pos:], nil - } - } - if strippedArg, ok := stripNamedType(&refVal); ok { return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg) } diff --git a/messages.go b/messages.go index 87c0aa22..01ece44e 100644 --- a/messages.go +++ b/messages.go @@ -2,15 +2,11 @@ package pgx import ( "database/sql/driver" - - "github.com/jackc/pgx/v5/pgtype" ) func convertDriverValuers(args []interface{}) ([]interface{}, error) { for i, arg := range args { switch arg := arg.(type) { - case pgtype.BinaryEncoder: - case pgtype.TextEncoder: case driver.Valuer: v, err := callValuerValue(arg) if err != nil { diff --git a/pgtype/bits_test.go b/pgtype/bits_test.go index a585ef8b..5a82743c 100644 --- a/pgtype/bits_test.go +++ b/pgtype/bits_test.go @@ -17,7 +17,7 @@ func isExpectedEqBits(a interface{}) func(interface{}) bool { } func TestBitsCodecBit(t *testing.T) { - testPgxCodec(t, "bit(40)", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "bit(40)", []testutil.TranscodeTestCase{ { pgtype.Bits{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}, new(pgtype.Bits), @@ -34,7 +34,7 @@ func TestBitsCodecBit(t *testing.T) { } func TestBitsCodecVarbit(t *testing.T) { - testPgxCodec(t, "varbit", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "varbit", []testutil.TranscodeTestCase{ { pgtype.Bits{Bytes: []byte{}, Len: 0, Valid: true}, new(pgtype.Bits), @@ -54,12 +54,3 @@ func TestBitsCodecVarbit(t *testing.T) { {nil, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, }) } - -func TestBitsNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select B'111111111'", - Value: &pgtype.Bits{Bytes: []byte{255, 128}, Len: 9, Valid: true}, - }, - }) -} diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index ec8c31d9..57094144 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestBoolCodec(t *testing.T) { - testPgxCodec(t, "bool", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "bool", []testutil.TranscodeTestCase{ {true, new(bool), isExpectedEq(true)}, {false, new(bool), isExpectedEq(false)}, {true, new(pgtype.Bool), isExpectedEq(pgtype.Bool{Bool: true, Valid: true})}, diff --git a/pgtype/box_test.go b/pgtype/box_test.go index 8056e819..72e37b76 100644 --- a/pgtype/box_test.go +++ b/pgtype/box_test.go @@ -8,7 +8,7 @@ import ( ) func TestBoxCodec(t *testing.T) { - testPgxCodec(t, "box", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "box", []testutil.TranscodeTestCase{ { pgtype.Box{ P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, @@ -35,15 +35,3 @@ func TestBoxCodec(t *testing.T) { {nil, new(pgtype.Box), isExpectedEq(pgtype.Box{})}, }) } - -func TestBoxNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select '3.14, 1.678, 7.1, 5.234'::box", - Value: &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, - Valid: true, - }, - }, - }) -} diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index d99d28b6..41c3482e 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -28,7 +28,7 @@ func isExpectedEqBytes(a interface{}) func(interface{}) bool { } func TestByteaCodec(t *testing.T) { - testPgxCodec(t, "bytea", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "bytea", []testutil.TranscodeTestCase{ {[]byte{1, 2, 3}, new([]byte), isExpectedEqBytes([]byte{1, 2, 3})}, {[]byte{}, new([]byte), isExpectedEqBytes([]byte{})}, {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go index 6fbf4c31..f38d8194 100644 --- a/pgtype/circle_test.go +++ b/pgtype/circle_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestCircleTranscode(t *testing.T) { - testPgxCodec(t, "circle", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "circle", []testutil.TranscodeTestCase{ { pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, new(pgtype.Circle), diff --git a/pgtype/date_test.go b/pgtype/date_test.go index 268759c1..d57b9115 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func isExpectedEqTime(a interface{}) func(interface{}) bool { @@ -17,7 +18,7 @@ func isExpectedEqTime(a interface{}) func(interface{}) bool { } func TestDateCodec(t *testing.T) { - testPgxCodec(t, "date", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "date", []testutil.TranscodeTestCase{ {time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC))}, {time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC))}, {time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC))}, diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go index 85b3b21d..a0069836 100644 --- a/pgtype/float4_test.go +++ b/pgtype/float4_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestFloat4Codec(t *testing.T) { - testPgxCodec(t, "float4", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "float4", []testutil.TranscodeTestCase{ {pgtype.Float4{Float: -1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float: -1, Valid: true})}, {pgtype.Float4{Float: 0, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float: 0, Valid: true})}, {pgtype.Float4{Float: 1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float: 1, Valid: true})}, diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go index 3c7660b8..e69174eb 100644 --- a/pgtype/float8_test.go +++ b/pgtype/float8_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestFloat8Codec(t *testing.T) { - testPgxCodec(t, "float8", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "float8", []testutil.TranscodeTestCase{ {pgtype.Float8{Float: -1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float: -1, Valid: true})}, {pgtype.Float8{Float: 0, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float: 0, Valid: true})}, {pgtype.Float8{Float: 1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float: 1, Valid: true})}, diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index edd94db7..8d2b6971 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -75,7 +75,7 @@ func TestHstoreCodec(t *testing.T) { return &s } - tests := []PgxTranscodeTestCase{ + tests := []testutil.TranscodeTestCase{ { map[string]string{}, new(map[string]string), @@ -134,25 +134,25 @@ func TestHstoreCodec(t *testing.T) { // Special key values // at beginning - tests = append(tests, PgxTranscodeTestCase{ + tests = append(tests, testutil.TranscodeTestCase{ map[string]string{s + "foo": "bar"}, new(map[string]string), isExpectedEqMapStringString(map[string]string{s + "foo": "bar"}), }) // in middle - tests = append(tests, PgxTranscodeTestCase{ + tests = append(tests, testutil.TranscodeTestCase{ map[string]string{"foo" + s + "bar": "bar"}, new(map[string]string), isExpectedEqMapStringString(map[string]string{"foo" + s + "bar": "bar"}), }) // at end - tests = append(tests, PgxTranscodeTestCase{ + tests = append(tests, testutil.TranscodeTestCase{ map[string]string{"foo" + s: "bar"}, new(map[string]string), isExpectedEqMapStringString(map[string]string{"foo" + s: "bar"}), }) // is key - tests = append(tests, PgxTranscodeTestCase{ + tests = append(tests, testutil.TranscodeTestCase{ map[string]string{s: "bar"}, new(map[string]string), isExpectedEqMapStringString(map[string]string{s: "bar"}), @@ -161,25 +161,25 @@ func TestHstoreCodec(t *testing.T) { // Special value values // at beginning - tests = append(tests, PgxTranscodeTestCase{ + tests = append(tests, testutil.TranscodeTestCase{ map[string]string{"foo": s + "bar"}, new(map[string]string), isExpectedEqMapStringString(map[string]string{"foo": s + "bar"}), }) // in middle - tests = append(tests, PgxTranscodeTestCase{ + tests = append(tests, testutil.TranscodeTestCase{ map[string]string{"foo": "foo" + s + "bar"}, new(map[string]string), isExpectedEqMapStringString(map[string]string{"foo": "foo" + s + "bar"}), }) // at end - tests = append(tests, PgxTranscodeTestCase{ + tests = append(tests, testutil.TranscodeTestCase{ map[string]string{"foo": "foo" + s}, new(map[string]string), isExpectedEqMapStringString(map[string]string{"foo": "foo" + s}), }) // is key - tests = append(tests, PgxTranscodeTestCase{ + tests = append(tests, testutil.TranscodeTestCase{ map[string]string{"foo": s}, new(map[string]string), isExpectedEqMapStringString(map[string]string{"foo": s}), @@ -187,6 +187,6 @@ func TestHstoreCodec(t *testing.T) { } for _, format := range formats { - testPgxCodecFormat(t, "hstore", tests, conn, format.name, format.code) + testutil.RunTranscodeTestsFormat(t, "hstore", tests, conn, format.name, format.code) } } diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index 4ead4672..c3f66755 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func isExpectedEqIPNet(a interface{}) func(interface{}) bool { @@ -17,7 +18,7 @@ func isExpectedEqIPNet(a interface{}) func(interface{}) bool { } func TestInetTranscode(t *testing.T) { - testPgxCodec(t, "inet", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "inet", []testutil.TranscodeTestCase{ {mustParseInet(t, "0.0.0.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "0.0.0.0/32"))}, {mustParseInet(t, "127.0.0.1/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "127.0.0.1/8"))}, {mustParseInet(t, "12.34.56.65/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "12.34.56.65/32"))}, @@ -34,7 +35,7 @@ func TestInetTranscode(t *testing.T) { } func TestCidrTranscode(t *testing.T) { - testPgxCodec(t, "cidr", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "cidr", []testutil.TranscodeTestCase{ {mustParseInet(t, "0.0.0.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "0.0.0.0/32"))}, {mustParseInet(t, "127.0.0.1/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "127.0.0.1/32"))}, {mustParseInet(t, "12.34.56.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "12.34.56.0/32"))}, diff --git a/pgtype/int_test.go b/pgtype/int_test.go index 77aa0589..1dbf32d5 100644 --- a/pgtype/int_test.go +++ b/pgtype/int_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestInt2Codec(t *testing.T) { - testPgxCodec(t, "int2", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "int2", []testutil.TranscodeTestCase{ {int8(1), new(int16), isExpectedEq(int16(1))}, {int16(1), new(int16), isExpectedEq(int16(1))}, {int32(1), new(int16), isExpectedEq(int16(1))}, @@ -89,7 +90,7 @@ func TestInt2UnmarshalJSON(t *testing.T) { } func TestInt4Codec(t *testing.T) { - testPgxCodec(t, "int4", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "int4", []testutil.TranscodeTestCase{ {int8(1), new(int32), isExpectedEq(int32(1))}, {int16(1), new(int32), isExpectedEq(int32(1))}, {int32(1), new(int32), isExpectedEq(int32(1))}, @@ -169,7 +170,7 @@ func TestInt4UnmarshalJSON(t *testing.T) { } func TestInt8Codec(t *testing.T) { - testPgxCodec(t, "int8", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "int8", []testutil.TranscodeTestCase{ {int8(1), new(int64), isExpectedEq(int64(1))}, {int16(1), new(int64), isExpectedEq(int64(1))}, {int32(1), new(int64), isExpectedEq(int64(1))}, diff --git a/pgtype/int_test.go.erb b/pgtype/int_test.go.erb index afcc8b9c..c98f6488 100644 --- a/pgtype/int_test.go.erb +++ b/pgtype/int_test.go.erb @@ -10,7 +10,7 @@ import ( <% [2, 4, 8].each do |pg_byte_size| %> <% pg_bit_size = pg_byte_size * 8 %> func TestInt<%= pg_byte_size %>Codec(t *testing.T) { - testPgxCodec(t, "int<%= pg_byte_size %>", []PgxTranscodeTestCase{ + testPgxCodec(t, "int<%= pg_byte_size %>", []testutil.TranscodeTestCase{ {int8(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {int16(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {int32(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go index 75733ff1..310ea6bc 100644 --- a/pgtype/interval_test.go +++ b/pgtype/interval_test.go @@ -5,10 +5,11 @@ import ( "time" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestIntervalCodec(t *testing.T) { - testPgxCodec(t, "interval", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "interval", []testutil.TranscodeTestCase{ { pgtype.Interval{Microseconds: 1, Valid: true}, new(pgtype.Interval), diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 156217ac..a255c45a 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -2,6 +2,8 @@ package pgtype_test import ( "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" ) func isExpectedEqMap(a interface{}) func(interface{}) bool { @@ -37,7 +39,7 @@ func TestJSONCodec(t *testing.T) { Age int `json:"age"` } - testPgxCodec(t, "json", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "json", []testutil.TranscodeTestCase{ {[]byte("{}"), new([]byte), isExpectedEqBytes([]byte("{}"))}, {[]byte("null"), new([]byte), isExpectedEqBytes([]byte("null"))}, {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 282caeb1..981ec28e 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -2,6 +2,8 @@ package pgtype_test import ( "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestJSONBTranscode(t *testing.T) { @@ -10,7 +12,7 @@ func TestJSONBTranscode(t *testing.T) { Age int `json:"age"` } - testPgxCodec(t, "jsonb", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "jsonb", []testutil.TranscodeTestCase{ {[]byte("{}"), new([]byte), isExpectedEqBytes([]byte("{}"))}, {[]byte("null"), new([]byte), isExpectedEqBytes([]byte("null"))}, {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, diff --git a/pgtype/line_test.go b/pgtype/line_test.go index 669d9b8d..b7c82e35 100644 --- a/pgtype/line_test.go +++ b/pgtype/line_test.go @@ -25,7 +25,7 @@ func TestLineTranscode(t *testing.T) { t.Skip("Skipping due to unimplemented line type in PG 9.3") } - testPgxCodec(t, "line", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "line", []testutil.TranscodeTestCase{ { pgtype.Line{ A: 1.23, B: 4.56, C: 7.89012345, diff --git a/pgtype/lseg_test.go b/pgtype/lseg_test.go index 1866439f..51fe2adb 100644 --- a/pgtype/lseg_test.go +++ b/pgtype/lseg_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestLsegTranscode(t *testing.T) { - testPgxCodec(t, "lseg", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "lseg", []testutil.TranscodeTestCase{ { pgtype.Lseg{ P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go index 3e59c580..2ce7b007 100644 --- a/pgtype/macaddr_test.go +++ b/pgtype/macaddr_test.go @@ -4,6 +4,8 @@ import ( "bytes" "net" "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" ) func isExpectedEqHardwareAddr(a interface{}) func(interface{}) bool { @@ -24,7 +26,7 @@ func isExpectedEqHardwareAddr(a interface{}) func(interface{}) bool { } func TestMacaddrCodec(t *testing.T) { - testPgxCodec(t, "macaddr", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "macaddr", []testutil.TranscodeTestCase{ { mustParseMacaddr(t, "01:23:45:67:89:ab"), new(net.HardwareAddr), diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 06d81ffe..0d89dc2d 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -74,7 +74,7 @@ func TestNumericCodec(t *testing.T) { max.Add(max, big.NewInt(1)) longestNumeric := pgtype.Numeric{Int: max, Exp: -16383, Valid: true} - testPgxCodec(t, "numeric", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "numeric", []testutil.TranscodeTestCase{ {mustParseNumeric(t, "1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, {mustParseNumeric(t, "3.14159"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "3.14159"))}, {mustParseNumeric(t, "100010001"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "100010001"))}, @@ -122,22 +122,22 @@ func TestNumericCodecFuzz(t *testing.T) { max := &big.Int{} max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) - tests := make([]PgxTranscodeTestCase, 0, 2000) + tests := make([]testutil.TranscodeTestCase, 0, 2000) for i := 0; i < 10; i++ { for j := -50; j < 50; j++ { num := (&big.Int{}).Rand(r, max) n := pgtype.Numeric{Int: num, Exp: int32(j), Valid: true} - tests = append(tests, PgxTranscodeTestCase{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) + tests = append(tests, testutil.TranscodeTestCase{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) negNum := &big.Int{} negNum.Neg(num) n = pgtype.Numeric{Int: negNum, Exp: int32(j), Valid: true} - tests = append(tests, PgxTranscodeTestCase{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) + tests = append(tests, testutil.TranscodeTestCase{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) } } - testPgxCodec(t, "numeric", tests) + testutil.RunTranscodeTests(t, "numeric", tests) } func TestNumericMarshalJSON(t *testing.T) { diff --git a/pgtype/path_test.go b/pgtype/path_test.go index 291fa9d4..546f4d36 100644 --- a/pgtype/path_test.go +++ b/pgtype/path_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func isExpectedEqPath(a interface{}) func(interface{}) bool { @@ -26,7 +27,7 @@ func isExpectedEqPath(a interface{}) func(interface{}) bool { } func TestPathTranscode(t *testing.T) { - testPgxCodec(t, "path", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "path", []testutil.TranscodeTestCase{ { pgtype.Path{ P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index ee2730ee..5c818ec7 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -186,26 +186,6 @@ type TextDecoder interface { DecodeText(ci *ConnInfo, src []byte) error } -// BinaryEncoder is implemented by types that can encode themselves into the -// PostgreSQL binary wire format. -type BinaryEncoder interface { - // EncodeBinary should append the binary format of self to buf. If self is the - // SQL value NULL then append nothing and return (nil, nil). The caller of - // EncodeBinary is responsible for writing the correct NULL value or the - // length of the data written. - EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) -} - -// TextEncoder is implemented by types that can encode themselves into the -// PostgreSQL text wire format. -type TextEncoder interface { - // EncodeText should append the text format of self to buf. If self is the - // SQL value NULL then append nothing and return (nil, nil). The caller of - // EncodeText is responsible for writing the correct NULL value or the - // length of the data written. - EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) -} - type nullAssignmentError struct { dst interface{} } @@ -400,8 +380,6 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { var formatCode int16 if t.Codec != nil { formatCode = t.Codec.PreferredFormat() - } else if _, ok := t.Value.(BinaryEncoder); ok { - formatCode = BinaryFormatCode } ci.oidToFormatCode[t.OID] = formatCode } diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 703e1843..3c6b138a 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -2,17 +2,13 @@ package pgtype_test import ( "bytes" - "context" "database/sql" "errors" - "fmt" "net" - "reflect" "testing" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" _ "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -292,54 +288,8 @@ func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { } } -type PgxTranscodeTestCase struct { - src interface{} - dst interface{} - test func(interface{}) bool -} - func isExpectedEq(a interface{}) func(interface{}) bool { return func(v interface{}) bool { return a == v } } - -func testPgxCodec(t testing.TB, pgTypeName string, tests []PgxTranscodeTestCase) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - formats := []struct { - name string - code int16 - }{ - {name: "TextFormat", code: pgx.TextFormatCode}, - {name: "BinaryFormat", code: pgx.BinaryFormatCode}, - } - - for _, format := range formats { - testPgxCodecFormat(t, pgTypeName, tests, conn, format.name, format.code) - } -} - -func testPgxCodecFormat(t testing.TB, pgTypeName string, tests []PgxTranscodeTestCase, conn *pgx.Conn, formatName string, formatCode int16) { - _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - for i, tt := range tests { - err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{formatCode}, tt.src).Scan(tt.dst) - if err != nil { - t.Errorf("%s %d: %v", formatName, i, err) - } - - dst := reflect.ValueOf(tt.dst) - if dst.Kind() == reflect.Ptr { - dst = dst.Elem() - } - - if !tt.test(dst.Interface()) { - t.Errorf("%s %d: unexpected result for %v: %v", formatName, i, tt.src, dst.Interface()) - } - } -} diff --git a/pgtype/point_test.go b/pgtype/point_test.go index 8046da92..03d948b7 100644 --- a/pgtype/point_test.go +++ b/pgtype/point_test.go @@ -5,11 +5,12 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/require" ) func TestPointCodec(t *testing.T) { - testPgxCodec(t, "point", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "point", []testutil.TranscodeTestCase{ { pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Valid: true}, new(pgtype.Point), diff --git a/pgtype/polygon_test.go b/pgtype/polygon_test.go index c0912b31..9c7c0182 100644 --- a/pgtype/polygon_test.go +++ b/pgtype/polygon_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func isExpectedEqPolygon(a interface{}) func(interface{}) bool { @@ -26,7 +27,7 @@ func isExpectedEqPolygon(a interface{}) func(interface{}) bool { } func TestPolygonTranscode(t *testing.T) { - testPgxCodec(t, "polygon", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "polygon", []testutil.TranscodeTestCase{ { pgtype.Polygon{ P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go index ec555eb2..36742f75 100644 --- a/pgtype/qchar_test.go +++ b/pgtype/qchar_test.go @@ -3,16 +3,18 @@ package pgtype_test import ( "math" "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestQcharTranscode(t *testing.T) { - var tests []PgxTranscodeTestCase + var tests []testutil.TranscodeTestCase for i := 0; i <= math.MaxUint8; i++ { - tests = append(tests, PgxTranscodeTestCase{rune(i), new(rune), isExpectedEq(rune(i))}) - tests = append(tests, PgxTranscodeTestCase{byte(i), new(byte), isExpectedEq(byte(i))}) + tests = append(tests, testutil.TranscodeTestCase{rune(i), new(rune), isExpectedEq(rune(i))}) + tests = append(tests, testutil.TranscodeTestCase{byte(i), new(byte), isExpectedEq(byte(i))}) } - tests = append(tests, PgxTranscodeTestCase{nil, new(*rune), isExpectedEq((*rune)(nil))}) - tests = append(tests, PgxTranscodeTestCase{nil, new(*byte), isExpectedEq((*byte)(nil))}) + tests = append(tests, testutil.TranscodeTestCase{nil, new(*rune), isExpectedEq((*rune)(nil))}) + tests = append(tests, testutil.TranscodeTestCase{nil, new(*byte), isExpectedEq((*byte)(nil))}) - testPgxCodec(t, `"char"`, tests) + testutil.RunTranscodeTests(t, `"char"`, tests) } diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go index bfe9b01f..19ed4412 100644 --- a/pgtype/testutil/testutil.go +++ b/pgtype/testutil/testutil.go @@ -2,34 +2,15 @@ package testutil import ( "context" - "database/sql" "fmt" "os" "reflect" "testing" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" _ "github.com/jackc/pgx/v5/stdlib" ) -func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { - var sqlDriverName string - switch driverName { - case "github.com/jackc/pgx/stdlib": - sqlDriverName = "pgx" - default: - t.Fatalf("Unknown driver %v", driverName) - } - - db, err := sql.Open(sqlDriverName, os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - t.Fatal(err) - } - - return db -} - func MustConnectPgx(t testing.TB) *pgx.Conn { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { @@ -57,369 +38,48 @@ func MustCloseContext(t testing.TB, conn interface { } } -type forceTextEncoder struct { - e pgtype.TextEncoder +type TranscodeTestCase struct { + Src interface{} + Dst interface{} + Test func(interface{}) bool } -func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - return f.e.EncodeText(ci, buf) -} - -type forceBinaryEncoder struct { - e pgtype.BinaryEncoder -} - -func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - return f.e.EncodeBinary(ci, buf) -} - -func ForceEncoder(e interface{}, formatCode int16) interface{} { - switch formatCode { - case pgx.TextFormatCode: - if e, ok := e.(pgtype.TextEncoder); ok { - return forceTextEncoder{e: e} - } - case pgx.BinaryFormatCode: - if e, ok := e.(pgtype.BinaryEncoder); ok { - return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} - } - } - return nil -} - -func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { - TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) -} - -func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { +func RunTranscodeTests(t testing.TB, pgTypeName string, tests []TranscodeTestCase) { conn := MustConnectPgx(t) defer MustCloseContext(t, conn) + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + RunTranscodeTestsFormat(t, pgTypeName, tests, conn, format.name, format.code) + } +} + +func RunTranscodeTestsFormat(t testing.TB, pgTypeName string, tests []TranscodeTestCase, conn *pgx.Conn, formatName string, formatCode int16) { _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) if err != nil { t.Fatal(err) } - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for i, v := range values { - for _, paramFormat := range formats { - for _, resultFormat := range formats { - vEncoder := ForceEncoder(v, paramFormat.formatCode) - if vEncoder == nil { - t.Logf("Skipping Param %s Result %s: %#v does not implement %v for encoding", paramFormat.name, resultFormat.name, v, paramFormat.name) - continue - } - switch resultFormat.formatCode { - case pgx.TextFormatCode: - if _, ok := v.(pgtype.TextEncoder); !ok { - t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name) - continue - } - case pgx.BinaryFormatCode: - if _, ok := v.(pgtype.BinaryEncoder); !ok { - t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name) - continue - } - } - - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - - err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{resultFormat.formatCode}, vEncoder).Scan(result.Interface()) - if err != nil { - t.Errorf("Param %s Result %s %d: %v", paramFormat.name, resultFormat.name, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("Param %s Result %s %d: expected %v, got %v", paramFormat.name, resultFormat.name, i, derefV, result.Elem().Interface()) - } - } - } - } -} - -func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := MustConnectDatabaseSQL(t, driverName) - defer MustClose(t, conn) - - ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - for i, v := range values { - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err := ps.QueryRow(v).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", driverName, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) - } - } -} - -type NormalizeTest struct { - SQL string - Value interface{} -} - -func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) { - TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { - TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) - TestDatabaseSQLSuccessfulNormalizeEqFunc(t, "github.com/jackc/pgx/stdlib", tests, eqFunc) -} - -func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { - conn := MustConnectPgx(t) - defer MustCloseContext(t, conn) - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - for i, tt := range tests { - for _, fc := range formats { - psName := fmt.Sprintf("test%d", i) - _, err := conn.Prepare(context.Background(), psName, tt.SQL) - if err != nil { - t.Fatal(err) - } - - queryResultFormats := pgx.QueryResultFormats{fc.formatCode} - if ForceEncoder(tt.Value, fc.formatCode) == nil { - t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name) - continue - } - // Derefence value if it is a pointer - derefV := tt.Value - refVal := reflect.ValueOf(tt.Value) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err = conn.QueryRow(context.Background(), psName, queryResultFormats).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", fc.name, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) - } - } - } -} - -func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { - conn := MustConnectDatabaseSQL(t, driverName) - defer MustClose(t, conn) - - for i, tt := range tests { - ps, err := conn.Prepare(tt.SQL) + err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{formatCode}, tt.Src).Scan(tt.Dst) if err != nil { - t.Errorf("%d. %v", i, err) - continue + t.Errorf("%s %d: %v", formatName, i, err) } - // Derefence value if it is a pointer - derefV := tt.Value - refVal := reflect.ValueOf(tt.Value) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() + dst := reflect.ValueOf(tt.Dst) + if dst.Kind() == reflect.Ptr { + dst = dst.Elem() } - result := reflect.New(reflect.TypeOf(derefV)) - err = ps.QueryRow().Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", driverName, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + if !tt.Test(dst.Interface()) { + t.Errorf("%s %d: unexpected result for %v: %v", formatName, i, tt.Src, dst.Interface()) } } } - -func TestGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { - TestPgxGoZeroToNullConversion(t, pgTypeName, zero) - TestDatabaseSQLGoZeroToNullConversion(t, "github.com/jackc/pgx/stdlib", pgTypeName, zero) -} - -func TestNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) { - TestPgxNullToGoZeroConversion(t, pgTypeName, zero) - TestDatabaseSQLNullToGoZeroConversion(t, "github.com/jackc/pgx/stdlib", pgTypeName, zero) -} - -func TestPgxGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { - conn := MustConnectPgx(t) - defer MustCloseContext(t, conn) - - _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s is null", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for _, paramFormat := range formats { - vEncoder := ForceEncoder(zero, paramFormat.formatCode) - if vEncoder == nil { - t.Logf("Skipping Param %s: %#v does not implement %v for encoding", paramFormat.name, zero, paramFormat.name) - continue - } - - var result bool - err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(&result) - if err != nil { - t.Errorf("Param %s: %v", paramFormat.name, err) - } - - if !result { - t.Errorf("Param %s: did not convert zero to null", paramFormat.name) - } - } -} - -func TestPgxNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) { - conn := MustConnectPgx(t) - defer MustCloseContext(t, conn) - - _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select null::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for _, resultFormat := range formats { - - switch resultFormat.formatCode { - case pgx.TextFormatCode: - if _, ok := zero.(pgtype.TextEncoder); !ok { - t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name) - continue - } - case pgx.BinaryFormatCode: - if _, ok := zero.(pgtype.BinaryEncoder); !ok { - t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name) - continue - } - } - - // Derefence value if it is a pointer - derefZero := zero - refVal := reflect.ValueOf(zero) - if refVal.Kind() == reflect.Ptr { - derefZero = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefZero)) - - err := conn.QueryRow(context.Background(), "test").Scan(result.Interface()) - if err != nil { - t.Errorf("Result %s: %v", resultFormat.name, err) - } - - if !reflect.DeepEqual(result.Elem().Interface(), derefZero) { - t.Errorf("Result %s: did not convert null to zero", resultFormat.name) - } - } -} - -func TestDatabaseSQLGoZeroToNullConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) { - conn := MustConnectDatabaseSQL(t, driverName) - defer MustClose(t, conn) - - ps, err := conn.Prepare(fmt.Sprintf("select $1::%s is null", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - var result bool - err = ps.QueryRow(zero).Scan(&result) - if err != nil { - t.Errorf("%v %v", driverName, err) - } - - if !result { - t.Errorf("%v: did not convert zero to null", driverName) - } -} - -func TestDatabaseSQLNullToGoZeroConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) { - conn := MustConnectDatabaseSQL(t, driverName) - defer MustClose(t, conn) - - ps, err := conn.Prepare(fmt.Sprintf("select null::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - // Derefence value if it is a pointer - derefZero := zero - refVal := reflect.ValueOf(zero) - if refVal.Kind() == reflect.Ptr { - derefZero = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefZero)) - - err = ps.QueryRow().Scan(result.Interface()) - if err != nil { - t.Errorf("%v %v", driverName, err) - } - - if !reflect.DeepEqual(result.Elem().Interface(), derefZero) { - t.Errorf("%s: did not convert null to zero", driverName) - } -} diff --git a/pgtype/text_test.go b/pgtype/text_test.go index f45978a7..7a188a67 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -17,7 +17,7 @@ func (someFmtStringer) String() string { func TestTextCodec(t *testing.T) { for _, pgTypeName := range []string{"text", "varchar"} { - testPgxCodec(t, pgTypeName, []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, pgTypeName, []testutil.TranscodeTestCase{ { pgtype.Text{String: "", Valid: true}, new(pgtype.Text), @@ -47,7 +47,7 @@ func TestTextCodec(t *testing.T) { // // So this is simply a smoke test of the name type. func TestTextCodecName(t *testing.T) { - testPgxCodec(t, "name", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "name", []testutil.TranscodeTestCase{ { pgtype.Text{String: "", Valid: true}, new(pgtype.Text), @@ -65,7 +65,7 @@ func TestTextCodecName(t *testing.T) { // Test fixed length char types like char(3) func TestTextCodecBPChar(t *testing.T) { - testPgxCodec(t, "char(3)", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "char(3)", []testutil.TranscodeTestCase{ { pgtype.Text{String: "a ", Valid: true}, new(pgtype.Text), @@ -95,7 +95,7 @@ func TestTextCodecACLItem(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) - testPgxCodecFormat(t, "aclitem", []PgxTranscodeTestCase{ + testutil.RunTranscodeTestsFormat(t, "aclitem", []testutil.TranscodeTestCase{ { pgtype.Text{String: "postgres=arwdDxt/postgres", Valid: true}, new(pgtype.Text), @@ -123,7 +123,7 @@ func TestTextCodecACLItemRoleWithSpecialCharacters(t *testing.T) { t.Skipf("Role with special characters does not exist.") } - testPgxCodecFormat(t, "aclitem", []PgxTranscodeTestCase{ + testutil.RunTranscodeTestsFormat(t, "aclitem", []testutil.TranscodeTestCase{ { pgtype.Text{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}, new(pgtype.Text), diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go index 0d3afe5e..4ff53151 100644 --- a/pgtype/tid_test.go +++ b/pgtype/tid_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestTIDCodec(t *testing.T) { - testPgxCodec(t, "tid", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "tid", []testutil.TranscodeTestCase{ { pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, new(pgtype.TID), diff --git a/pgtype/time_test.go b/pgtype/time_test.go index 8394a951..61f9ef0e 100644 --- a/pgtype/time_test.go +++ b/pgtype/time_test.go @@ -5,10 +5,11 @@ import ( "time" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestTimeCodec(t *testing.T) { - testPgxCodec(t, "time", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "time", []testutil.TranscodeTestCase{ { pgtype.Time{Microseconds: 0, Valid: true}, new(pgtype.Time), diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 1caca58b..1bbceaf5 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -11,7 +11,7 @@ import ( ) func TestTimestampCodec(t *testing.T) { - testPgxCodec(t, "timestamp", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "timestamp", []testutil.TranscodeTestCase{ {time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC))}, {time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC))}, {time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC))}, diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index 2a45d2cb..42439b7b 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -11,7 +11,7 @@ import ( ) func TestTimestamptzCodec(t *testing.T) { - testPgxCodec(t, "timestamptz", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "timestamptz", []testutil.TranscodeTestCase{ {time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local))}, {time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local))}, {time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local))}, diff --git a/pgtype/uint32_test.go b/pgtype/uint32_test.go index 8e58605d..98adbee4 100644 --- a/pgtype/uint32_test.go +++ b/pgtype/uint32_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestUint32Codec(t *testing.T) { - testPgxCodec(t, "oid", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "oid", []testutil.TranscodeTestCase{ { pgtype.Uint32{Uint: pgtype.TextOID, Valid: true}, new(pgtype.Uint32), diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 71de8d67..870e7ae1 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -5,11 +5,12 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/require" ) func TestUUIDCodec(t *testing.T) { - testPgxCodec(t, "uuid", []PgxTranscodeTestCase{ + testutil.RunTranscodeTests(t, "uuid", []testutil.TranscodeTestCase{ { pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, new(pgtype.UUID), diff --git a/pgtype/zeronull/float8.go b/pgtype/zeronull/float8.go index 1c29d331..d1c053c5 100644 --- a/pgtype/zeronull/float8.go +++ b/pgtype/zeronull/float8.go @@ -8,6 +8,8 @@ import ( type Float8 float64 +func (Float8) SkipUnderlyingTypePlan() {} + // ScanFloat64 implements the Float64Scanner interface. func (f *Float8) ScanFloat64(n pgtype.Float8) error { if !n.Valid { diff --git a/pgtype/zeronull/float8_test.go b/pgtype/zeronull/float8_test.go index b0331faa..fd683657 100644 --- a/pgtype/zeronull/float8_test.go +++ b/pgtype/zeronull/float8_test.go @@ -7,17 +7,28 @@ import ( "github.com/jackc/pgx/v5/pgtype/zeronull" ) +func isExpectedEq(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + return a == v + } +} + func TestFloat8Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ - (zeronull.Float8)(1), - (zeronull.Float8)(0), + testutil.RunTranscodeTests(t, "float8", []testutil.TranscodeTestCase{ + { + (zeronull.Float8)(1), + new(zeronull.Float8), + isExpectedEq((zeronull.Float8)(1)), + }, + { + nil, + new(zeronull.Float8), + isExpectedEq((zeronull.Float8)(0)), + }, + { + (zeronull.Float8)(0), + new(interface{}), + isExpectedEq(nil), + }, }) } - -func TestFloat8ConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "float8", (zeronull.Float8)(0)) -} - -func TestFloat8ConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "float8", (zeronull.Float8)(0)) -} diff --git a/pgtype/zeronull/int.go b/pgtype/zeronull/int.go index 0149834b..1d479307 100644 --- a/pgtype/zeronull/int.go +++ b/pgtype/zeronull/int.go @@ -11,6 +11,8 @@ import ( type Int2 int16 +func (Int2) SkipUnderlyingTypePlan() {} + // ScanInt64 implements the Int64Scanner interface. func (dst *Int2) ScanInt64(n int64, valid bool) error { if !valid { @@ -57,6 +59,8 @@ func (src Int2) Value() (driver.Value, error) { type Int4 int32 +func (Int4) SkipUnderlyingTypePlan() {} + // ScanInt64 implements the Int64Scanner interface. func (dst *Int4) ScanInt64(n int64, valid bool) error { if !valid { @@ -103,6 +107,8 @@ func (src Int4) Value() (driver.Value, error) { type Int8 int64 +func (Int8) SkipUnderlyingTypePlan() {} + // ScanInt64 implements the Int64Scanner interface. func (dst *Int8) ScanInt64(n int64, valid bool) error { if !valid { diff --git a/pgtype/zeronull/int.go.erb b/pgtype/zeronull/int.go.erb index 935b56a9..9e3b5ef0 100644 --- a/pgtype/zeronull/int.go.erb +++ b/pgtype/zeronull/int.go.erb @@ -12,6 +12,8 @@ import ( <% pg_bit_size = pg_byte_size * 8 %> type Int<%= pg_byte_size %> int<%= pg_bit_size %> +func (Int<%= pg_byte_size %>) SkipUnderlyingTypePlan() {} + // ScanInt64 implements the Int64Scanner interface. func (dst *Int<%= pg_byte_size %>) ScanInt64(n int64, valid bool) error { if !valid { diff --git a/pgtype/zeronull/int_test.go b/pgtype/zeronull/int_test.go index bd2ef0b2..d687a733 100644 --- a/pgtype/zeronull/int_test.go +++ b/pgtype/zeronull/int_test.go @@ -9,46 +9,61 @@ import ( ) func TestInt2Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ - (zeronull.Int2)(1), - (zeronull.Int2)(0), + testutil.RunTranscodeTests(t, "int2", []testutil.TranscodeTestCase{ + { + (zeronull.Int2)(1), + new(zeronull.Int2), + isExpectedEq((zeronull.Int2)(1)), + }, + { + nil, + new(zeronull.Int2), + isExpectedEq((zeronull.Int2)(0)), + }, + { + (zeronull.Int2)(0), + new(interface{}), + isExpectedEq(nil), + }, }) } -func TestInt2ConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "int2", (zeronull.Int2)(0)) -} - -func TestInt2ConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "int2", (zeronull.Int2)(0)) -} - func TestInt4Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ - (zeronull.Int4)(1), - (zeronull.Int4)(0), + testutil.RunTranscodeTests(t, "int4", []testutil.TranscodeTestCase{ + { + (zeronull.Int4)(1), + new(zeronull.Int4), + isExpectedEq((zeronull.Int4)(1)), + }, + { + nil, + new(zeronull.Int4), + isExpectedEq((zeronull.Int4)(0)), + }, + { + (zeronull.Int4)(0), + new(interface{}), + isExpectedEq(nil), + }, }) } -func TestInt4ConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "int4", (zeronull.Int4)(0)) -} - -func TestInt4ConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "int4", (zeronull.Int4)(0)) -} - func TestInt8Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ - (zeronull.Int8)(1), - (zeronull.Int8)(0), + testutil.RunTranscodeTests(t, "int8", []testutil.TranscodeTestCase{ + { + (zeronull.Int8)(1), + new(zeronull.Int8), + isExpectedEq((zeronull.Int8)(1)), + }, + { + nil, + new(zeronull.Int8), + isExpectedEq((zeronull.Int8)(0)), + }, + { + (zeronull.Int8)(0), + new(interface{}), + isExpectedEq(nil), + }, }) } - -func TestInt8ConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "int8", (zeronull.Int8)(0)) -} - -func TestInt8ConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "int8", (zeronull.Int8)(0)) -} diff --git a/pgtype/zeronull/int_test.go.erb b/pgtype/zeronull/int_test.go.erb index 51273710..b33cfa4a 100644 --- a/pgtype/zeronull/int_test.go.erb +++ b/pgtype/zeronull/int_test.go.erb @@ -10,17 +10,22 @@ import ( <% [2, 4, 8].each do |pg_byte_size| %> <% pg_bit_size = pg_byte_size * 8 %> func TestInt<%= pg_byte_size %>Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int<%= pg_byte_size %>", []interface{}{ - (zeronull.Int<%= pg_byte_size %>)(1), - (zeronull.Int<%= pg_byte_size %>)(0), + testutil.RunTranscodeTests(t, "int<%= pg_byte_size %>", []testutil.TranscodeTestCase{ + { + (zeronull.Int<%= pg_byte_size %>)(1), + new(zeronull.Int<%= pg_byte_size %>), + isExpectedEq((zeronull.Int<%= pg_byte_size %>)(1)), + }, + { + nil, + new(zeronull.Int<%= pg_byte_size %>), + isExpectedEq((zeronull.Int<%= pg_byte_size %>)(0)), + }, + { + (zeronull.Int<%= pg_byte_size %>)(0), + new(interface{}), + isExpectedEq(nil), + }, }) } - -func TestInt<%= pg_byte_size %>ConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "int<%= pg_byte_size %>", (zeronull.Int<%= pg_byte_size %>)(0)) -} - -func TestInt<%= pg_byte_size %>ConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "int<%= pg_byte_size %>", (zeronull.Int<%= pg_byte_size %>)(0)) -} <% end %> diff --git a/pgtype/zeronull/text.go b/pgtype/zeronull/text.go index fcbc16d7..b768e308 100644 --- a/pgtype/zeronull/text.go +++ b/pgtype/zeronull/text.go @@ -8,6 +8,8 @@ import ( type Text string +func (Text) SkipUnderlyingTypePlan() {} + // ScanText implements the TextScanner interface. func (dst *Text) ScanText(v pgtype.Text) error { if !v.Valid { diff --git a/pgtype/zeronull/text_test.go b/pgtype/zeronull/text_test.go index e4293024..e20ab868 100644 --- a/pgtype/zeronull/text_test.go +++ b/pgtype/zeronull/text_test.go @@ -8,16 +8,21 @@ import ( ) func TestTextTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "text", []interface{}{ - (zeronull.Text)("foo"), - (zeronull.Text)(""), + testutil.RunTranscodeTests(t, "text", []testutil.TranscodeTestCase{ + { + (zeronull.Text)("foo"), + new(zeronull.Text), + isExpectedEq((zeronull.Text)("foo")), + }, + { + nil, + new(zeronull.Text), + isExpectedEq((zeronull.Text)("")), + }, + { + (zeronull.Text)(""), + new(interface{}), + isExpectedEq(nil), + }, }) } - -func TestTextConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "text", (zeronull.Text)("")) -} - -func TestTextConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "text", (zeronull.Text)("")) -} diff --git a/pgtype/zeronull/timestamp.go b/pgtype/zeronull/timestamp.go index 1c2a1a63..6e2c3d1e 100644 --- a/pgtype/zeronull/timestamp.go +++ b/pgtype/zeronull/timestamp.go @@ -10,6 +10,8 @@ import ( type Timestamp time.Time +func (Timestamp) SkipUnderlyingTypePlan() {} + func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error { if !v.Valid { *ts = Timestamp{} diff --git a/pgtype/zeronull/timestamp_test.go b/pgtype/zeronull/timestamp_test.go index 2eb072c6..9d8ee7ae 100644 --- a/pgtype/zeronull/timestamp_test.go +++ b/pgtype/zeronull/timestamp_test.go @@ -8,22 +8,31 @@ import ( "github.com/jackc/pgx/v5/pgtype/zeronull" ) -func TestTimestampTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ - (zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), - (zeronull.Timestamp)(time.Time{}), - }, func(a, b interface{}) bool { - at := a.(zeronull.Timestamp) - bt := b.(zeronull.Timestamp) +func isExpectedEqTimestamp(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + at := time.Time(a.(zeronull.Timestamp)) + vt := time.Time(v.(zeronull.Timestamp)) - return time.Time(at).Equal(time.Time(bt)) + return at.Equal(vt) + } +} + +func TestTimestampTranscode(t *testing.T) { + testutil.RunTranscodeTests(t, "timestamp", []testutil.TranscodeTestCase{ + { + (zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + new(zeronull.Timestamp), + isExpectedEqTimestamp((zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC))), + }, + { + nil, + new(zeronull.Timestamp), + isExpectedEqTimestamp((zeronull.Timestamp)(time.Time{})), + }, + { + (zeronull.Timestamp)(time.Time{}), + new(interface{}), + isExpectedEq(nil), + }, }) } - -func TestTimestampConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "timestamp", (zeronull.Timestamp)(time.Time{})) -} - -func TestTimestampConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "timestamp", (zeronull.Timestamp)(time.Time{})) -} diff --git a/pgtype/zeronull/timestamptz.go b/pgtype/zeronull/timestamptz.go index c5378059..79fcb563 100644 --- a/pgtype/zeronull/timestamptz.go +++ b/pgtype/zeronull/timestamptz.go @@ -10,6 +10,8 @@ import ( type Timestamptz time.Time +func (Timestamptz) SkipUnderlyingTypePlan() {} + func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error { if !v.Valid { *ts = Timestamptz{} diff --git a/pgtype/zeronull/timestamptz_test.go b/pgtype/zeronull/timestamptz_test.go index e288b9e8..15ac66da 100644 --- a/pgtype/zeronull/timestamptz_test.go +++ b/pgtype/zeronull/timestamptz_test.go @@ -8,22 +8,31 @@ import ( "github.com/jackc/pgx/v5/pgtype/zeronull" ) -func TestTimestamptzTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ - (zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), - (zeronull.Timestamptz)(time.Time{}), - }, func(a, b interface{}) bool { - at := a.(zeronull.Timestamptz) - bt := b.(zeronull.Timestamptz) +func isExpectedEqTimestamptz(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + at := time.Time(a.(zeronull.Timestamptz)) + vt := time.Time(v.(zeronull.Timestamptz)) - return time.Time(at).Equal(time.Time(bt)) + return at.Equal(vt) + } +} + +func TestTimestamptzTranscode(t *testing.T) { + testutil.RunTranscodeTests(t, "timestamptz", []testutil.TranscodeTestCase{ + { + (zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + new(zeronull.Timestamptz), + isExpectedEqTimestamptz((zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC))), + }, + { + nil, + new(zeronull.Timestamptz), + isExpectedEqTimestamptz((zeronull.Timestamptz)(time.Time{})), + }, + { + (zeronull.Timestamptz)(time.Time{}), + new(interface{}), + isExpectedEq(nil), + }, }) } - -func TestTimestamptzConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "timestamptz", (zeronull.Timestamptz)(time.Time{})) -} - -func TestTimestamptzConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "timestamptz", (zeronull.Timestamptz)(time.Time{})) -} diff --git a/pgtype/zeronull/uuid.go b/pgtype/zeronull/uuid.go index a87003e1..abe5049e 100644 --- a/pgtype/zeronull/uuid.go +++ b/pgtype/zeronull/uuid.go @@ -8,6 +8,8 @@ import ( type UUID [16]byte +func (UUID) SkipUnderlyingTypePlan() {} + // ScanUUID implements the UUIDScanner interface. func (u *UUID) ScanUUID(v pgtype.UUID) error { if !v.Valid { diff --git a/pgtype/zeronull/uuid_test.go b/pgtype/zeronull/uuid_test.go index 913698d9..5be1d22e 100644 --- a/pgtype/zeronull/uuid_test.go +++ b/pgtype/zeronull/uuid_test.go @@ -8,16 +8,21 @@ import ( ) func TestUUIDTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - (*zeronull.UUID)(&[16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), - (*zeronull.UUID)(&[16]byte{}), + testutil.RunTranscodeTests(t, "uuid", []testutil.TranscodeTestCase{ + { + (zeronull.UUID)([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), + new(zeronull.UUID), + isExpectedEq((zeronull.UUID)([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})), + }, + { + nil, + new(zeronull.UUID), + isExpectedEq((zeronull.UUID)([16]byte{})), + }, + { + (zeronull.UUID)([16]byte{}), + new(interface{}), + isExpectedEq(nil), + }, }) } - -func TestUUIDConvertsGoZeroToNull(t *testing.T) { - testutil.TestGoZeroToNullConversion(t, "uuid", (*zeronull.UUID)(&[16]byte{})) -} - -func TestUUIDConvertsNullToGoZero(t *testing.T) { - testutil.TestNullToGoZeroConversion(t, "uuid", (*zeronull.UUID)(&[16]byte{})) -} diff --git a/values.go b/values.go index b5ce4f7c..3cb27bbf 100644 --- a/values.go +++ b/values.go @@ -37,15 +37,6 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e switch arg := arg.(type) { case driver.Valuer: return callValuerValue(arg) - case pgtype.TextEncoder: - buf, err := arg.EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil case float32: return float64(arg), nil case float64: @@ -89,21 +80,7 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e } if dt, found := ci.DataTypeForValue(arg); found { - if dt.Value != nil { - v := dt.Value - err := v.Set(arg) - if err != nil { - return nil, err - } - buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil - } else if dt.Codec != nil { + if dt.Codec != nil { buf, err := ci.Encode(0, TextFormatCode, arg, nil) if err != nil { return nil, err @@ -132,30 +109,6 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 } switch arg := arg.(type) { - case pgtype.BinaryEncoder: - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := arg.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil - case pgtype.TextEncoder: - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := arg.EncodeText(ci, buf) - if err != nil { - return nil, err - } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil case string: buf = pgio.AppendInt32(buf, int32(len(arg))) buf = append(buf, arg...) @@ -173,35 +126,7 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 } if dt, ok := ci.DataTypeForOID(oid); ok { - if dt.Value != nil { - value := dt.Value - err := value.Set(arg) - if err != nil { - { - if arg, ok := arg.(driver.Valuer); ok { - v, err := callValuerValue(arg) - if err != nil { - return nil, err - } - return encodePreparedStatementArgument(ci, buf, oid, v) - } - } - - return nil, err - } - - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil - } else if dt.Codec != nil { + if dt.Codec != nil { sp := len(buf) buf = pgio.AppendInt32(buf, -1) argBuf, err := ci.Encode(oid, BinaryFormatCode, arg, buf) @@ -227,9 +152,7 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 // determination can be made. func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 { switch arg.(type) { - case pgtype.BinaryEncoder: - return BinaryFormatCode - case string, *string, pgtype.TextEncoder: + case string, *string: return TextFormatCode } From 4cf6dc94477921d444fc6822ae2ad0eb59ab0aae Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 12:16:02 -0600 Subject: [PATCH 0869/1158] Remove BinaryEncoder and TextEncoder --- example_custom_type_test.go | 60 ++++++------------------ pgtype/pgtype.go | 92 ------------------------------------- pgtype/pgtype_test.go | 1 + rows.go | 27 ++--------- 4 files changed, 18 insertions(+), 162 deletions(-) diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 10014278..fc0d4b78 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -5,7 +5,6 @@ import ( "fmt" "os" "regexp" - "strconv" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" @@ -15,52 +14,26 @@ var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) // Point represents a point that may be null. type Point struct { - X, Y float64 // Coordinates of point + X, Y float32 // Coordinates of point Valid bool } -func (dst *Point) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Point", src) -} - -func (dst *Point) Get() interface{} { - if !dst.Valid { - return nil +func (p *Point) ScanPoint(v pgtype.Point) error { + *p = Point{ + X: float32(v.P.X), + Y: float32(v.P.Y), + Valid: v.Valid, } - - return dst -} - -func (src *Point) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = Point{} - return nil - } - - s := string(src) - match := pointRegexp.FindStringSubmatch(s) - if match == nil { - return fmt.Errorf("Received invalid point: %v", s) - } - - x, err := strconv.ParseFloat(match[1], 64) - if err != nil { - return fmt.Errorf("Received invalid point: %v", s) - } - y, err := strconv.ParseFloat(match[2], 64) - if err != nil { - return fmt.Errorf("Received invalid point: %v", s) - } - - *dst = Point{X: x, Y: y, Valid: true} - return nil } +func (p Point) PointValue() (pgtype.Point, error) { + return pgtype.Point{ + P: pgtype.Vec2{X: float64(p.X), Y: float64(p.Y)}, + Valid: true, + }, nil +} + func (src *Point) String() string { if !src.Valid { return "null point" @@ -85,13 +58,6 @@ func Example_CustomType() { return } - // Override registered handler for point - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Value: &Point{}, - Name: "point", - OID: 600, - }) - p := &Point{} err = conn.QueryRow(context.Background(), "select null::point").Scan(p) if err != nil { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5c818ec7..8725aaa0 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -172,20 +172,6 @@ type Codec interface { DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) } -type BinaryDecoder interface { - // DecodeBinary decodes src into BinaryDecoder. If src is nil then the - // original SQL value is NULL. BinaryDecoder takes ownership of src. The - // caller MUST not use it again. - DecodeBinary(ci *ConnInfo, src []byte) error -} - -type TextDecoder interface { - // DecodeText decodes src into TextDecoder. If src is nil then the original - // SQL value is NULL. TextDecoder takes ownership of src. The caller MUST not - // use it again. - DecodeText(ci *ConnInfo, src []byte) error -} - type nullAssignmentError struct { dst interface{} } @@ -197,9 +183,6 @@ func (e *nullAssignmentError) Error() string { type DataType struct { Value Value - textDecoder TextDecoder - binaryDecoder BinaryDecoder - Codec Codec Name string @@ -384,14 +367,6 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { ci.oidToFormatCode[t.OID] = formatCode } - if d, ok := t.Value.(TextDecoder); ok { - t.textDecoder = d - } - - if d, ok := t.Value.(BinaryDecoder); ok { - t.binaryDecoder = d - } - ci.reflectTypeToDataType = nil // Invalidated by type registration } @@ -476,69 +451,6 @@ func (scanPlanDstResultDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, return newPlan.Scan(ci, oid, formatCode, src, dst) } -type scanPlanDstBinaryDecoder struct{} - -func (scanPlanDstBinaryDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if d, ok := (dst).(BinaryDecoder); ok { - return d.DecodeBinary(ci, src) - } - - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) -} - -type scanPlanDstTextDecoder struct{} - -func (plan scanPlanDstTextDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if d, ok := (dst).(TextDecoder); ok { - return d.DecodeText(ci, src) - } - - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) -} - -type scanPlanDataTypeAssignTo DataType - -func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - dt := (*DataType)(plan) - var err error - - switch formatCode { - case BinaryFormatCode: - if dt.binaryDecoder == nil { - return fmt.Errorf("dt.binaryDecoder is nil") - } - err = dt.binaryDecoder.DecodeBinary(ci, src) - case TextFormatCode: - if dt.textDecoder == nil { - return fmt.Errorf("dt.textDecoder is nil") - } - err = dt.textDecoder.DecodeText(ci, src) - } - if err != nil { - return err - } - - assignToErr := dt.Value.AssignTo(dst) - if assignToErr == nil { - return nil - } - - if dstPtr, ok := dst.(*interface{}); ok { - *dstPtr = dt.Value.Get() - return nil - } - - // assignToErr might have failed because the type of destination has changed - newPlan := ci.PlanScan(oid, formatCode, dst) - if newPlan, sameType := newPlan.(*scanPlanDataTypeAssignTo); !sameType { - return newPlan.Scan(ci, oid, formatCode, src, dst) - } - - return assignToErr -} - type scanPlanCodecSQLScanner struct{ c Codec } func (plan *scanPlanCodecSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { @@ -1086,8 +998,6 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan case ByteaOID, TextOID, VarcharOID, JSONOID: return scanPlanBinaryBytes{} } - case BinaryDecoder: - return scanPlanDstBinaryDecoder{} } case TextFormatCode: switch dst.(type) { @@ -1097,8 +1007,6 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan if oid != ByteaOID { return scanPlanBinaryBytes{} } - case TextDecoder: - return scanPlanDstTextDecoder{} case TextScanner: return scanPlanTextAnyToTextScanner{} } diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 3c6b138a..9bd665c5 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -159,6 +159,7 @@ func (ct *pgCustomType) DecodeText(ci *pgtype.ConnInfo, buf []byte) error { } func TestConnInfoScanUnregisteredOIDToCustomType(t *testing.T) { + t.Skip("TODO - unskip later in v5") // may no longer be relevent unregisteredOID := uint32(999999) ci := pgtype.NewConnInfo() diff --git a/rows.go b/rows.go index 8e9fdc70..620ce5d6 100644 --- a/rows.go +++ b/rows.go @@ -247,32 +247,13 @@ func (rows *connRows) Values() ([]interface{}, error) { if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok { if dt.Value != nil { - - value := dt.Value - switch fd.Format { case TextFormatCode: - if decoder, ok := value.(pgtype.TextDecoder); ok { - err := decoder.DecodeText(rows.connInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, decoder.(pgtype.Value).Get()) - } else { - values = append(values, string(buf)) - } + values = append(values, string(buf)) case BinaryFormatCode: - if decoder, ok := value.(pgtype.BinaryDecoder); ok { - err := decoder.DecodeBinary(rows.connInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, value.Get()) - } else { - newBuf := make([]byte, len(buf)) - copy(newBuf, buf) - values = append(values, newBuf) - } + newBuf := make([]byte, len(buf)) + copy(newBuf, buf) + values = append(values, newBuf) default: rows.fatal(errors.New("Unknown format code")) } From db95cee40c8dd7f629ac7abd0e1f6e6ca0a21e6e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 12:18:40 -0600 Subject: [PATCH 0870/1158] Remove pgtype.Value interface --- extended_query_builder.go | 20 +------------------ pgtype/pgtype.go | 42 ++------------------------------------- rows.go | 13 +----------- 3 files changed, 4 insertions(+), 71 deletions(-) diff --git a/extended_query_builder.go b/extended_query_builder.go index 34662eb1..e87b5edd 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -1,7 +1,6 @@ package pgx import ( - "database/sql/driver" "fmt" "reflect" @@ -85,24 +84,7 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o } if dt, ok := ci.DataTypeForOID(oid); ok { - if dt.Value != nil { - value := dt.Value - err := value.Set(arg) - if err != nil { - { - if arg, ok := arg.(driver.Valuer); ok { - v, err := callValuerValue(arg) - if err != nil { - return nil, err - } - return eqb.encodeExtendedParamValue(ci, oid, formatCode, v) - } - } - - return nil, err - } - return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) - } else if dt.Codec != nil { + if dt.Codec != nil { buf, err := ci.Encode(oid, formatCode, arg, eqb.paramValueBytes) if err != nil { return nil, err diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 8725aaa0..9a39e3ff 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -129,26 +129,6 @@ const ( BinaryFormatCode = 1 ) -// Value translates values to and from an internal canonical representation for the type. To actually be usable a type -// that implements Value should also implement some combination of BinaryDecoder, BinaryEncoder, TextDecoder, -// and TextEncoder. -// -// Operations that update a Value (e.g. Set, DecodeText, DecodeBinary) should entirely replace the value. e.g. Internal -// slices should be replaced not resized and reused. This allows Get and AssignTo to return a slice directly rather -// than incur a usually unnecessary copy. -type Value interface { - // Set converts and assigns src to itself. Value takes ownership of src. - Set(src interface{}) error - - // Get returns the simplest representation of Value. Get may return a pointer to an internal value but it must never - // mutate that value. e.g. If Get returns a []byte Value must never change the contents of the []byte. - Get() interface{} - - // AssignTo converts and assigns the Value to dst. AssignTo may a pointer to an internal value but it must never - // mutate that value. e.g. If Get returns a []byte Value must never change the contents of the []byte. - AssignTo(dst interface{}) error -} - type Codec interface { // FormatSupported returns true if the format is supported. FormatSupported(int16) bool @@ -181,12 +161,9 @@ func (e *nullAssignmentError) Error() string { } type DataType struct { - Value Value - Codec Codec - - Name string - OID uint32 + Name string + OID uint32 } type ConnInfo struct { @@ -352,10 +329,6 @@ func NewConnInfo() *ConnInfo { } func (ci *ConnInfo) RegisterDataType(t DataType) { - if t.Value != nil { - t.Value = NewValue(t.Value) - } - ci.oidToDataType[t.OID] = &t ci.nameToDataType[t.Name] = &t @@ -391,12 +364,6 @@ func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { func (ci *ConnInfo) buildReflectTypeToDataType() { ci.reflectTypeToDataType = make(map[reflect.Type]*DataType) - for _, dt := range ci.oidToDataType { - if dt.Value != nil { - ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt - } - } - for reflectType, name := range ci.reflectTypeToName { if dt, ok := ci.nameToDataType[name]; ok { ci.reflectTypeToDataType[reflectType] = dt @@ -1091,11 +1058,6 @@ func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) } } -// NewValue returns a new instance of the same type as v. -func NewValue(v Value) Value { - return reflect.New(reflect.ValueOf(v).Elem().Type()).Interface().(Value) -} - var ErrScanTargetTypeChanged = errors.New("scan target type changed") func codecScan(codec Codec, ci *ConnInfo, oid uint32, format int16, src []byte, dst interface{}) error { diff --git a/rows.go b/rows.go index 620ce5d6..e076ce43 100644 --- a/rows.go +++ b/rows.go @@ -246,18 +246,7 @@ func (rows *connRows) Values() ([]interface{}, error) { } if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok { - if dt.Value != nil { - switch fd.Format { - case TextFormatCode: - values = append(values, string(buf)) - case BinaryFormatCode: - newBuf := make([]byte, len(buf)) - copy(newBuf, buf) - values = append(values, newBuf) - default: - rows.fatal(errors.New("Unknown format code")) - } - } else if dt.Codec != nil { + if dt.Codec != nil { value, err := dt.Codec.DecodeValue(rows.connInfo, fd.DataTypeOID, fd.Format, buf) if err != nil { rows.fatal(err) From 2b395f3730e54d7121c22188d23de9580d18658b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 12:21:16 -0600 Subject: [PATCH 0871/1158] pgtype.DataType.Codec can never be nil --- extended_query_builder.go | 20 +++++++++---------- pgtype/pgtype.go | 14 +++---------- rows.go | 10 ++++------ values.go | 42 ++++++++++++++++++--------------------- 4 files changed, 35 insertions(+), 51 deletions(-) diff --git a/extended_query_builder.go b/extended_query_builder.go index e87b5edd..8b250685 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -83,18 +83,16 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg) } - if dt, ok := ci.DataTypeForOID(oid); ok { - if dt.Codec != nil { - buf, err := ci.Encode(oid, formatCode, arg, eqb.paramValueBytes) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - eqb.paramValueBytes = buf - return eqb.paramValueBytes[pos:], nil + if _, ok := ci.DataTypeForOID(oid); ok { + buf, err := ci.Encode(oid, formatCode, arg, eqb.paramValueBytes) + if err != nil { + return nil, err } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil } if strippedArg, ok := stripNamedType(&refVal); ok { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 9a39e3ff..5bb04d71 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -331,15 +331,7 @@ func NewConnInfo() *ConnInfo { func (ci *ConnInfo) RegisterDataType(t DataType) { ci.oidToDataType[t.OID] = &t ci.nameToDataType[t.Name] = &t - - { - var formatCode int16 - if t.Codec != nil { - formatCode = t.Codec.PreferredFormat() - } - ci.oidToFormatCode[t.OID] = formatCode - } - + ci.oidToFormatCode[t.OID] = t.Codec.PreferredFormat() ci.reflectTypeToDataType = nil // Invalidated by type registration } @@ -992,7 +984,7 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan } } - if dt != nil && dt.Codec != nil { + if dt != nil { if plan := dt.Codec.PlanScan(ci, oid, formatCode, dst, false); plan != nil { return plan } @@ -1105,7 +1097,7 @@ func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) Enco } } - if dt != nil && dt.Codec != nil { + if dt != nil { if plan := dt.Codec.PlanEncode(ci, oid, format, value); plan != nil { return plan } diff --git a/rows.go b/rows.go index e076ce43..5a4bc9a9 100644 --- a/rows.go +++ b/rows.go @@ -246,13 +246,11 @@ func (rows *connRows) Values() ([]interface{}, error) { } if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok { - if dt.Codec != nil { - value, err := dt.Codec.DecodeValue(rows.connInfo, fd.DataTypeOID, fd.Format, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, value) + value, err := dt.Codec.DecodeValue(rows.connInfo, fd.DataTypeOID, fd.Format, buf) + if err != nil { + rows.fatal(err) } + values = append(values, value) } else { switch fd.Format { case TextFormatCode: diff --git a/values.go b/values.go index 3cb27bbf..a6fdcc86 100644 --- a/values.go +++ b/values.go @@ -79,17 +79,15 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e return int64(arg), nil } - if dt, found := ci.DataTypeForValue(arg); found { - if dt.Codec != nil { - buf, err := ci.Encode(0, TextFormatCode, arg, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil + if _, found := ci.DataTypeForValue(arg); found { + buf, err := ci.Encode(0, TextFormatCode, arg, nil) + if err != nil { + return nil, err } + if buf == nil { + return nil, nil + } + return string(buf), nil } if refVal.Kind() == reflect.Ptr { @@ -125,20 +123,18 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 return encodePreparedStatementArgument(ci, buf, oid, arg) } - if dt, ok := ci.DataTypeForOID(oid); ok { - if dt.Codec != nil { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := ci.Encode(oid, BinaryFormatCode, arg, buf) - if err != nil { - return nil, err - } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil + if _, ok := ci.DataTypeForOID(oid); ok { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := ci.Encode(oid, BinaryFormatCode, arg, buf) + if err != nil { + return nil, err } + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil } if strippedArg, ok := stripNamedType(&refVal); ok { From aedf7d63e5b7f676856d33d8f7a35583cfb55989 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 16:19:32 -0600 Subject: [PATCH 0872/1158] Expose try wrap functions in ConnInfo --- pgtype/pgtype.go | 132 +++++++++++++++++++++++++++++------------------ 1 file changed, 81 insertions(+), 51 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5bb04d71..8aad4fb7 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -175,22 +175,42 @@ type ConnInfo struct { reflectTypeToDataType map[reflect.Type]*DataType + // TryWrapEncodePlanFuncs is a slice of functions that will wrap a value that cannot be encoded by the Codec. Every + // time a wrapper is found the PlanEncode method will be recursively called with the new value. This allows several layers of wrappers + // to be built up. There are default functions placed in this slice by NewConnInfo(). In most cases these functions + // should run last. i.e. Additional functions should typically be prepended not appended. + TryWrapEncodePlanFuncs []TryWrapEncodePlanFunc + + // TryWrapScanPlanFuncs is a slice of functions that will wrap a target that cannot be scanned into by the Codec. Every + // time a wrapper is found the PlanScan method will be recursively called with the new target. This allows several layers of wrappers + // to be built up. There are default functions placed in this slice by NewConnInfo(). In most cases these functions + // should run last. i.e. Additional functions should typically be prepended not appended. + TryWrapScanPlanFuncs []TryWrapScanPlanFunc + preferAssignToOverSQLScannerTypes map[reflect.Type]struct{} } -func newConnInfo() *ConnInfo { - return &ConnInfo{ +func NewConnInfo() *ConnInfo { + ci := &ConnInfo{ oidToDataType: make(map[uint32]*DataType), nameToDataType: make(map[string]*DataType), reflectTypeToName: make(map[reflect.Type]string), oidToFormatCode: make(map[uint32]int16), oidToResultFormatCode: make(map[uint32]int16), preferAssignToOverSQLScannerTypes: make(map[reflect.Type]struct{}), - } -} -func NewConnInfo() *ConnInfo { - ci := newConnInfo() + TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + TryWrapBuiltinTypeEncodePlan, + }, + + TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ + TryPointerPointerScanPlan, + TryFindUnderlyingTypeScanPlan, + TryWrapBuiltinTypeScanPlan, + }, + } ci.RegisterDataType(DataType{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementCodec: &TextFormatOnlyCodec{TextCodec{}}, ElementOID: ACLItemOID}}) ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementCodec: BoolCodec{}, ElementOID: BoolOID}}) @@ -553,7 +573,11 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt return newPlan.Scan(ci, oid, formatCode, src, dst) } -type tryWrapScanPlanFunc func(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) +// TryWrapScanPlanFunc is a function that tries to create a wrapper plan for target. If successful it returns a plan +// that will convert the target passed to Scan and then call the next plan. nextTarget is target as it will be converted +// by plan. It must be used to find another suitable ScanPlan. When it is found SetNext must be called on plan for it +// to be usabled. ok indicates if a suitable wrapper was found. +type TryWrapScanPlanFunc func(target interface{}) (plan WrappedScanPlanNextSetter, nextTarget interface{}, ok bool) type pointerPointerScanPlan struct { dstType reflect.Type @@ -578,8 +602,10 @@ func (plan *pointerPointerScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode in return plan.next.Scan(ci, oid, formatCode, src, el.Interface()) } -func tryPointerPointerScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) { - if dstValue := reflect.ValueOf(dst); dstValue.Kind() == reflect.Ptr { +// TryPointerPointerScanPlan handles a pointer to a pointer by setting the target to nil for SQL NULL and allocating and +// scanning for non-NULL. +func TryPointerPointerScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextTarget interface{}, ok bool) { + if dstValue := reflect.ValueOf(target); dstValue.Kind() == reflect.Ptr { elemValue := dstValue.Elem() if elemValue.Kind() == reflect.Ptr { plan = &pointerPointerScanPlan{dstType: dstValue.Type()} @@ -628,7 +654,9 @@ func (plan *underlyingTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode in return plan.next.Scan(ci, oid, formatCode, src, reflect.ValueOf(dst).Convert(plan.nextDstType).Interface()) } -func tryUnderlyingTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) { +// TryFindUnderlyingTypeScanPlan tries to convert to a Go builtin type. e.g. If value was of type MyString and +// MyString was defined as a string then a wrapper plan would be returned that converts MyString to string. +func TryFindUnderlyingTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) { if _, ok := dst.(SkipUnderlyingTypePlanner); ok { return nil, nil, false } @@ -651,50 +679,53 @@ type WrappedScanPlanNextSetter interface { ScanPlan } -func tryWrapBuiltinTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) { - switch dst := dst.(type) { +// TryWrapBuiltinTypeScanPlan tries to wrap a builtin type with a wrapper that provides additional methods. e.g. If +// value was of type int32 then a wrapper plan would be returned that converts target to a value that implements +// Int64Scanner. +func TryWrapBuiltinTypeScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) { + switch target := target.(type) { case *int8: - return &wrapInt8ScanPlan{}, (*int8Wrapper)(dst), true + return &wrapInt8ScanPlan{}, (*int8Wrapper)(target), true case *int16: - return &wrapInt16ScanPlan{}, (*int16Wrapper)(dst), true + return &wrapInt16ScanPlan{}, (*int16Wrapper)(target), true case *int32: - return &wrapInt32ScanPlan{}, (*int32Wrapper)(dst), true + return &wrapInt32ScanPlan{}, (*int32Wrapper)(target), true case *int64: - return &wrapInt64ScanPlan{}, (*int64Wrapper)(dst), true + return &wrapInt64ScanPlan{}, (*int64Wrapper)(target), true case *int: - return &wrapIntScanPlan{}, (*intWrapper)(dst), true + return &wrapIntScanPlan{}, (*intWrapper)(target), true case *uint8: - return &wrapUint8ScanPlan{}, (*uint8Wrapper)(dst), true + return &wrapUint8ScanPlan{}, (*uint8Wrapper)(target), true case *uint16: - return &wrapUint16ScanPlan{}, (*uint16Wrapper)(dst), true + return &wrapUint16ScanPlan{}, (*uint16Wrapper)(target), true case *uint32: - return &wrapUint32ScanPlan{}, (*uint32Wrapper)(dst), true + return &wrapUint32ScanPlan{}, (*uint32Wrapper)(target), true case *uint64: - return &wrapUint64ScanPlan{}, (*uint64Wrapper)(dst), true + return &wrapUint64ScanPlan{}, (*uint64Wrapper)(target), true case *uint: - return &wrapUintScanPlan{}, (*uintWrapper)(dst), true + return &wrapUintScanPlan{}, (*uintWrapper)(target), true case *float32: - return &wrapFloat32ScanPlan{}, (*float32Wrapper)(dst), true + return &wrapFloat32ScanPlan{}, (*float32Wrapper)(target), true case *float64: - return &wrapFloat64ScanPlan{}, (*float64Wrapper)(dst), true + return &wrapFloat64ScanPlan{}, (*float64Wrapper)(target), true case *string: - return &wrapStringScanPlan{}, (*stringWrapper)(dst), true + return &wrapStringScanPlan{}, (*stringWrapper)(target), true case *time.Time: - return &wrapTimeScanPlan{}, (*timeWrapper)(dst), true + return &wrapTimeScanPlan{}, (*timeWrapper)(target), true case *time.Duration: - return &wrapDurationScanPlan{}, (*durationWrapper)(dst), true + return &wrapDurationScanPlan{}, (*durationWrapper)(target), true case *net.IPNet: - return &wrapNetIPNetScanPlan{}, (*netIPNetWrapper)(dst), true + return &wrapNetIPNetScanPlan{}, (*netIPNetWrapper)(target), true case *net.IP: - return &wrapNetIPScanPlan{}, (*netIPWrapper)(dst), true + return &wrapNetIPScanPlan{}, (*netIPWrapper)(target), true case *map[string]*string: - return &wrapMapStringToPointerStringScanPlan{}, (*mapStringToPointerStringWrapper)(dst), true + return &wrapMapStringToPointerStringScanPlan{}, (*mapStringToPointerStringWrapper)(target), true case *map[string]string: - return &wrapMapStringToStringScanPlan{}, (*mapStringToStringWrapper)(dst), true + return &wrapMapStringToStringScanPlan{}, (*mapStringToStringWrapper)(target), true case *[16]byte: - return &wrapByte16ScanPlan{}, (*byte16Wrapper)(dst), true + return &wrapByte16ScanPlan{}, (*byte16Wrapper)(target), true case *[]byte: - return &wrapByteSliceScanPlan{}, (*byteSliceWrapper)(dst), true + return &wrapByteSliceScanPlan{}, (*byteSliceWrapper)(target), true } return nil, nil, false @@ -989,13 +1020,7 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan return plan } - tryWrappers := []tryWrapScanPlanFunc{ - tryPointerPointerScanPlan, - tryUnderlyingTypeScanPlan, - tryWrapBuiltinTypeScanPlan, - } - - for _, f := range tryWrappers { + for _, f := range ci.TryWrapScanPlanFuncs { if wrapperPlan, nextDst, ok := f(dst); ok { if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { if _, ok := nextPlan.(scanPlanReflection); !ok { // avoid fallthrough -- this will go away when old system removed. @@ -1102,13 +1127,7 @@ func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) Enco return plan } - tryWrappers := []tryWrapEncodePlanFunc{ - tryDerefPointerEncodePlan, - tryUnderlyingTypeEncodePlan, - tryWrapBuiltinTypeEncodePlan, - } - - for _, f := range tryWrappers { + for _, f := range ci.TryWrapEncodePlanFuncs { if wrapperPlan, nextValue, ok := f(value); ok { if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { wrapperPlan.SetNext(nextPlan) @@ -1121,7 +1140,11 @@ func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) Enco return nil } -type tryWrapEncodePlanFunc func(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) +// TryWrapEncodePlanFunc is a function that tries to create a wrapper plan for value. If successful it returns a plan +// that will convert the value passed to Encode and then call the next plan. nextValue is value as it will be converted +// by plan. It must be used to find another suitable EncodePlan. When it is found SetNext must be called on plan for it +// to be usabled. ok indicates if a suitable wrapper was found. +type TryWrapEncodePlanFunc func(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) type derefPointerEncodePlan struct { next EncodePlan @@ -1139,7 +1162,9 @@ func (plan *derefPointerEncodePlan) Encode(value interface{}, buf []byte) (newBu return plan.next.Encode(ptr.Elem().Interface(), buf) } -func tryDerefPointerEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { +// TryWrapDerefPointerEncodePlan tries to dereference a pointer. e.g. If value was of type *string then a wrapper plan +// would be returned that derefences the value. +func TryWrapDerefPointerEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { if valueType := reflect.TypeOf(value); valueType.Kind() == reflect.Ptr { return &derefPointerEncodePlan{}, reflect.New(valueType.Elem()).Elem().Interface(), true } @@ -1174,7 +1199,9 @@ func (plan *underlyingTypeEncodePlan) Encode(value interface{}, buf []byte) (new return plan.next.Encode(reflect.ValueOf(value).Convert(plan.nextValueType).Interface(), buf) } -func tryUnderlyingTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { +// TryWrapFindUnderlyingTypeEncodePlan tries to convert to a Go builtin type. e.g. If value was of type MyString and +// MyString was defined as a string then a wrapper plan would be returned that converts MyString to string. +func TryWrapFindUnderlyingTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { if _, ok := value.(SkipUnderlyingTypePlanner); ok { return nil, nil, false } @@ -1194,7 +1221,10 @@ type WrappedEncodePlanNextSetter interface { EncodePlan } -func tryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { +// TryWrapBuiltinTypeEncodePlan tries to wrap a builtin type with a wrapper that provides additional methods. e.g. If +// value was of type int32 then a wrapper plan would be returned that converts value to a type that implements +// Int64Valuer. +func TryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { switch value := value.(type) { case int8: return &wrapInt8EncodePlan{}, int8Wrapper(value), true From 322bfedc600692de7e6deb257bed0767c210fe59 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 16:20:37 -0600 Subject: [PATCH 0873/1158] Remove old SQL scanner integration --- pgtype/pgtype.go | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 8aad4fb7..73bd249c 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -186,18 +186,15 @@ type ConnInfo struct { // to be built up. There are default functions placed in this slice by NewConnInfo(). In most cases these functions // should run last. i.e. Additional functions should typically be prepended not appended. TryWrapScanPlanFuncs []TryWrapScanPlanFunc - - preferAssignToOverSQLScannerTypes map[reflect.Type]struct{} } func NewConnInfo() *ConnInfo { ci := &ConnInfo{ - oidToDataType: make(map[uint32]*DataType), - nameToDataType: make(map[string]*DataType), - reflectTypeToName: make(map[reflect.Type]string), - oidToFormatCode: make(map[uint32]int16), - oidToResultFormatCode: make(map[uint32]int16), - preferAssignToOverSQLScannerTypes: make(map[reflect.Type]struct{}), + oidToDataType: make(map[uint32]*DataType), + nameToDataType: make(map[string]*DataType), + reflectTypeToName: make(map[reflect.Type]string), + oidToFormatCode: make(map[uint32]int16), + oidToResultFormatCode: make(map[uint32]int16), TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ TryWrapDerefPointerEncodePlan, @@ -402,12 +399,6 @@ func (ci *ConnInfo) FormatCodeForOID(oid uint32) int16 { return TextFormatCode } -// PreferAssignToOverSQLScannerForType makes a sql.Scanner type use the AssignTo scan path instead of sql.Scanner. -// This is primarily for efficient integration with 3rd party numeric and UUID types. -func (ci *ConnInfo) PreferAssignToOverSQLScannerForType(value interface{}) { - ci.preferAssignToOverSQLScannerTypes[reflect.TypeOf(value)] = struct{}{} -} - // EncodePlan is a precompiled plan to encode a particular type into a particular OID and format. type EncodePlan interface { // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return From 5ed95dcd1c05ea64d3c82fedd79caaa28006325a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 17:48:31 -0600 Subject: [PATCH 0874/1158] Expose wrap functions on ConnInfo - Remove rarely used ScanPlan.Scan arguments - Plus other refactorings and fixes that fell out of this change. - Plus rows Scan now handles checking for changed type. --- pgtype/array_codec.go | 27 ++- pgtype/bits.go | 6 +- pgtype/bool.go | 8 +- pgtype/box.go | 6 +- pgtype/builtin_wrappers.go | 22 +++ pgtype/bytea.go | 10 +- pgtype/circle.go | 6 +- pgtype/date.go | 6 +- pgtype/enum_codec.go | 4 +- pgtype/float4.go | 8 +- pgtype/float8.go | 10 +- pgtype/hstore.go | 8 +- pgtype/inet.go | 6 +- pgtype/int.go | 88 ++++----- pgtype/int.go.erb | 16 +- pgtype/interval.go | 6 +- pgtype/json.go | 8 +- pgtype/jsonb.go | 6 +- pgtype/line.go | 6 +- pgtype/lseg.go | 6 +- pgtype/macaddr.go | 6 +- pgtype/numeric.go | 14 +- pgtype/numeric_test.go | 2 +- pgtype/path.go | 6 +- pgtype/pgtype.go | 368 +++++++++++++------------------------ pgtype/pgtype_test.go | 27 +-- pgtype/point.go | 6 +- pgtype/polygon.go | 6 +- pgtype/qchar.go | 4 +- pgtype/text.go | 8 +- pgtype/tid.go | 8 +- pgtype/time.go | 6 +- pgtype/timestamp.go | 6 +- pgtype/timestamp_test.go | 2 +- pgtype/timestamptz.go | 6 +- pgtype/timestamptz_test.go | 2 +- pgtype/uint32.go | 6 +- pgtype/uuid.go | 4 +- rows.go | 11 +- stdlib/sql.go | 26 +-- 40 files changed, 352 insertions(+), 435 deletions(-) diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 1e506a43..4cc7e84c 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -204,7 +204,12 @@ func (c *ArrayCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target int return nil } - return (*scanPlanArrayCodec)(c) + return &scanPlanArrayCodec{ + arrayCodec: c, + ci: ci, + oid: oid, + formatCode: format, + } } func (c *ArrayCodec) decodeBinary(ci *ConnInfo, arrayOID uint32, src []byte, array ArraySetter) error { @@ -244,7 +249,7 @@ func (c *ArrayCodec) decodeBinary(ci *ConnInfo, arrayOID uint32, src []byte, arr elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elementScanPlan.Scan(ci, c.ElementOID, BinaryFormatCode, elemSrc, elem) + err = elementScanPlan.Scan(elemSrc, elem) if err != nil { return err } @@ -286,7 +291,7 @@ func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array elemSrc = []byte(s) } - err = elementScanPlan.Scan(ci, c.ElementOID, TextFormatCode, elemSrc, elem) + err = elementScanPlan.Scan(elemSrc, elem) if err != nil { return err } @@ -295,15 +300,23 @@ func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array return nil } -type scanPlanArrayCodec ArrayCodec +type scanPlanArrayCodec struct { + arrayCodec *ArrayCodec + ci *ConnInfo + oid uint32 + formatCode int16 +} -func (spac *scanPlanArrayCodec) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - c := (*ArrayCodec)(spac) +func (spac *scanPlanArrayCodec) Scan(src []byte, dst interface{}) error { + c := spac.arrayCodec + ci := spac.ci + oid := spac.oid + formatCode := spac.formatCode array, err := makeArraySetter(dst) if err != nil { newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) + return newPlan.Scan(src, dst) } if src == nil { diff --git a/pgtype/bits.go b/pgtype/bits.go index 9b499c35..541a3a6b 100644 --- a/pgtype/bits.go +++ b/pgtype/bits.go @@ -41,7 +41,7 @@ func (dst *Bits) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToBitsScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + return scanPlanTextAnyToBitsScanner{}.Scan([]byte(src), dst) } return fmt.Errorf("cannot scan %T", src) @@ -163,7 +163,7 @@ func (c BitsCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt type scanPlanBinaryBitsToBitsScanner struct{} -func (scanPlanBinaryBitsToBitsScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(BitsScanner) if src == nil { @@ -182,7 +182,7 @@ func (scanPlanBinaryBitsToBitsScanner) Scan(ci *ConnInfo, oid uint32, formatCode type scanPlanTextAnyToBitsScanner struct{} -func (scanPlanTextAnyToBitsScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToBitsScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(BitsScanner) if src == nil { diff --git a/pgtype/bool.go b/pgtype/bool.go index 71ce09b6..5aa06870 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -238,7 +238,7 @@ func (c BoolCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt type scanPlanBinaryBoolToBool struct{} -func (scanPlanBinaryBoolToBool) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryBoolToBool) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -259,7 +259,7 @@ func (scanPlanBinaryBoolToBool) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanTextAnyToBool struct{} -func (scanPlanTextAnyToBool) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToBool) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -280,7 +280,7 @@ func (scanPlanTextAnyToBool) Scan(ci *ConnInfo, oid uint32, formatCode int16, sr type scanPlanBinaryBoolToBoolScanner struct{} -func (scanPlanBinaryBoolToBoolScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryBoolToBoolScanner) Scan(src []byte, dst interface{}) error { s, ok := (dst).(BoolScanner) if !ok { return ErrScanTargetTypeChanged @@ -299,7 +299,7 @@ func (scanPlanBinaryBoolToBoolScanner) Scan(ci *ConnInfo, oid uint32, formatCode type scanPlanTextAnyToBoolScanner struct{} -func (scanPlanTextAnyToBoolScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst interface{}) error { s, ok := (dst).(BoolScanner) if !ok { return ErrScanTargetTypeChanged diff --git a/pgtype/box.go b/pgtype/box.go index 80e1bd19..6c637308 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -42,7 +42,7 @@ func (dst *Box) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToBoxScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + return scanPlanTextAnyToBoxScanner{}.Scan([]byte(src), dst) } return fmt.Errorf("cannot scan %T", src) @@ -146,7 +146,7 @@ func (BoxCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfac type scanPlanBinaryBoxToBoxScanner struct{} -func (scanPlanBinaryBoxToBoxScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryBoxToBoxScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(BoxScanner) if src == nil { @@ -173,7 +173,7 @@ func (scanPlanBinaryBoxToBoxScanner) Scan(ci *ConnInfo, oid uint32, formatCode i type scanPlanTextAnyToBoxScanner struct{} -func (scanPlanTextAnyToBoxScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToBoxScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(BoxScanner) if src == nil { diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 9df28f55..fe58eee0 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -494,6 +494,8 @@ func (w netIPNetWrapper) InetValue() (Inet, error) { type netIPWrapper net.IP +func (w netIPWrapper) SkipUnderlyingTypePlan() {} + func (w *netIPWrapper) ScanInet(v Inet) error { if !v.Valid { *w = nil @@ -578,6 +580,26 @@ func (w byte16Wrapper) UUIDValue() (UUID, error) { type byteSliceWrapper []byte +func (w byteSliceWrapper) SkipUnderlyingTypePlan() {} + +func (w *byteSliceWrapper) ScanText(v Text) error { + if !v.Valid { + *w = nil + return nil + } + + *w = byteSliceWrapper(v.String) + return nil +} + +func (w byteSliceWrapper) TextValue() (Text, error) { + if w == nil { + return Text{}, nil + } + + return Text{String: string(w), Valid: true}, nil +} + func (w *byteSliceWrapper) ScanUUID(v UUID) error { if !v.Valid { *w = nil diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 2eb50610..501e0c59 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -49,7 +49,7 @@ type UndecodedBytes []byte type scanPlanAnyToUndecodedBytes struct{} -func (scanPlanAnyToUndecodedBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanAnyToUndecodedBytes) Scan(src []byte, dst interface{}) error { dstBuf := dst.(*UndecodedBytes) if src == nil { *dstBuf = nil @@ -170,7 +170,7 @@ func (ByteaCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interf type scanPlanBinaryBytesToBytes struct{} -func (scanPlanBinaryBytesToBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryBytesToBytes) Scan(src []byte, dst interface{}) error { dstBuf := dst.(*[]byte) if src == nil { *dstBuf = nil @@ -184,14 +184,14 @@ func (scanPlanBinaryBytesToBytes) Scan(ci *ConnInfo, oid uint32, formatCode int1 type scanPlanBinaryBytesToBytesScanner struct{} -func (scanPlanBinaryBytesToBytesScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryBytesToBytesScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(BytesScanner) return scanner.ScanBytes(src) } type scanPlanTextByteaToBytes struct{} -func (scanPlanTextByteaToBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextByteaToBytes) Scan(src []byte, dst interface{}) error { dstBuf := dst.(*[]byte) if src == nil { *dstBuf = nil @@ -209,7 +209,7 @@ func (scanPlanTextByteaToBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanTextByteaToBytesScanner struct{} -func (scanPlanTextByteaToBytesScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextByteaToBytesScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(BytesScanner) buf, err := decodeHexBytea(src) if err != nil { diff --git a/pgtype/circle.go b/pgtype/circle.go index ae8aa352..8e06de88 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -43,7 +43,7 @@ func (dst *Circle) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToCircleScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + return scanPlanTextAnyToCircleScanner{}.Scan([]byte(src), dst) } return fmt.Errorf("cannot scan %T", src) @@ -161,7 +161,7 @@ func (c CircleCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []b type scanPlanBinaryCircleToCircleScanner struct{} -func (scanPlanBinaryCircleToCircleScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryCircleToCircleScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(CircleScanner) if src == nil { @@ -185,7 +185,7 @@ func (scanPlanBinaryCircleToCircleScanner) Scan(ci *ConnInfo, oid uint32, format type scanPlanTextAnyToCircleScanner struct{} -func (scanPlanTextAnyToCircleScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToCircleScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(CircleScanner) if src == nil { diff --git a/pgtype/date.go b/pgtype/date.go index fde66745..adfa0999 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -47,7 +47,7 @@ func (dst *Date) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToDateScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + return scanPlanTextAnyToDateScanner{}.Scan([]byte(src), dst) case time.Time: *dst = Date{Time: src, Valid: true} return nil @@ -216,7 +216,7 @@ func (DateCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa type scanPlanBinaryDateToDateScanner struct{} -func (scanPlanBinaryDateToDateScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryDateToDateScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(DateScanner) if src == nil { @@ -242,7 +242,7 @@ func (scanPlanBinaryDateToDateScanner) Scan(ci *ConnInfo, oid uint32, formatCode type scanPlanTextAnyToDateScanner struct{} -func (scanPlanTextAnyToDateScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(DateScanner) if src == nil { diff --git a/pgtype/enum_codec.go b/pgtype/enum_codec.go index d405245f..3bf29f4a 100644 --- a/pgtype/enum_codec.go +++ b/pgtype/enum_codec.go @@ -86,7 +86,7 @@ type scanPlanTextAnyToEnumString struct { codec *EnumCodec } -func (plan *scanPlanTextAnyToEnumString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (plan *scanPlanTextAnyToEnumString) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -101,7 +101,7 @@ type scanPlanTextAnyToEnumTextScanner struct { codec *EnumCodec } -func (plan *scanPlanTextAnyToEnumTextScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (plan *scanPlanTextAnyToEnumTextScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(TextScanner) if src == nil { diff --git a/pgtype/float4.go b/pgtype/float4.go index 7699f656..db9b2215 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -164,7 +164,7 @@ func (Float4Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target inter type scanPlanBinaryFloat4ToFloat32 struct{} -func (scanPlanBinaryFloat4ToFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryFloat4ToFloat32) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -182,7 +182,7 @@ func (scanPlanBinaryFloat4ToFloat32) Scan(ci *ConnInfo, oid uint32, formatCode i type scanPlanBinaryFloat4ToFloat64Scanner struct{} -func (scanPlanBinaryFloat4ToFloat64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryFloat4ToFloat64Scanner) Scan(src []byte, dst interface{}) error { s := (dst).(Float64Scanner) if src == nil { @@ -199,7 +199,7 @@ func (scanPlanBinaryFloat4ToFloat64Scanner) Scan(ci *ConnInfo, oid uint32, forma type scanPlanBinaryFloat4ToInt64Scanner struct{} -func (scanPlanBinaryFloat4ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryFloat4ToInt64Scanner) Scan(src []byte, dst interface{}) error { s := (dst).(Int64Scanner) if src == nil { @@ -222,7 +222,7 @@ func (scanPlanBinaryFloat4ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatC type scanPlanTextAnyToFloat32 struct{} -func (scanPlanTextAnyToFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToFloat32) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } diff --git a/pgtype/float8.go b/pgtype/float8.go index 86638ab1..96dcb0f3 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -202,7 +202,7 @@ func (Float8Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target inter type scanPlanBinaryFloat8ToFloat64 struct{} -func (scanPlanBinaryFloat8ToFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryFloat8ToFloat64) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -220,7 +220,7 @@ func (scanPlanBinaryFloat8ToFloat64) Scan(ci *ConnInfo, oid uint32, formatCode i type scanPlanBinaryFloat8ToFloat64Scanner struct{} -func (scanPlanBinaryFloat8ToFloat64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryFloat8ToFloat64Scanner) Scan(src []byte, dst interface{}) error { s := (dst).(Float64Scanner) if src == nil { @@ -237,7 +237,7 @@ func (scanPlanBinaryFloat8ToFloat64Scanner) Scan(ci *ConnInfo, oid uint32, forma type scanPlanBinaryFloat8ToInt64Scanner struct{} -func (scanPlanBinaryFloat8ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryFloat8ToInt64Scanner) Scan(src []byte, dst interface{}) error { s := (dst).(Int64Scanner) if src == nil { @@ -260,7 +260,7 @@ func (scanPlanBinaryFloat8ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatC type scanPlanTextAnyToFloat64 struct{} -func (scanPlanTextAnyToFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToFloat64) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -278,7 +278,7 @@ func (scanPlanTextAnyToFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanTextAnyToFloat64Scanner struct{} -func (scanPlanTextAnyToFloat64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToFloat64Scanner) Scan(src []byte, dst interface{}) error { s := (dst).(Float64Scanner) if src == nil { diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 6ff8164c..dc5caa84 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -43,7 +43,7 @@ func (h *Hstore) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToHstoreScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), h) + return scanPlanTextAnyToHstoreScanner{}.Scan([]byte(src), h) } return fmt.Errorf("cannot scan %T", src) @@ -170,7 +170,7 @@ func (HstoreCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target inter type scanPlanBinaryHstoreToHstoreScanner struct{} -func (scanPlanBinaryHstoreToHstoreScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(HstoreScanner) if src == nil { @@ -213,7 +213,7 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(ci *ConnInfo, oid uint32, format } var value Text - err := scanPlanTextAnyToTextScanner{}.Scan(ci, TextOID, TextFormatCode, valueBuf, &value) + err := scanPlanTextAnyToTextScanner{}.Scan(valueBuf, &value) if err != nil { return err } @@ -230,7 +230,7 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(ci *ConnInfo, oid uint32, format type scanPlanTextAnyToHstoreScanner struct{} -func (scanPlanTextAnyToHstoreScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(HstoreScanner) if src == nil { diff --git a/pgtype/inet.go b/pgtype/inet.go index f88d1712..9530d1a2 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -46,7 +46,7 @@ func (dst *Inet) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToInetScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + return scanPlanTextAnyToInetScanner{}.Scan([]byte(src), dst) } return fmt.Errorf("cannot scan %T", src) @@ -182,7 +182,7 @@ func (c InetCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt type scanPlanBinaryInetToInetScanner struct{} -func (scanPlanBinaryInetToInetScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInetToInetScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(InetScanner) if src == nil { @@ -211,7 +211,7 @@ func (scanPlanBinaryInetToInetScanner) Scan(ci *ConnInfo, oid uint32, formatCode type scanPlanTextAnyToInetScanner struct{} -func (scanPlanTextAnyToInetScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToInetScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(InetScanner) if src == nil { diff --git a/pgtype/int.go b/pgtype/int.go index 553d4dd0..a5b1c0a5 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -292,7 +292,7 @@ func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt type scanPlanBinaryInt2ToInt8 struct{} -func (scanPlanBinaryInt2ToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt8) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -320,7 +320,7 @@ func (scanPlanBinaryInt2ToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanBinaryInt2ToUint8 struct{} -func (scanPlanBinaryInt2ToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToUint8) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -350,7 +350,7 @@ func (scanPlanBinaryInt2ToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16 type scanPlanBinaryInt2ToInt16 struct{} -func (scanPlanBinaryInt2ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt16) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -371,7 +371,7 @@ func (scanPlanBinaryInt2ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16 type scanPlanBinaryInt2ToUint16 struct{} -func (scanPlanBinaryInt2ToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToUint16) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -397,7 +397,7 @@ func (scanPlanBinaryInt2ToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int1 type scanPlanBinaryInt2ToInt32 struct{} -func (scanPlanBinaryInt2ToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt32) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -418,7 +418,7 @@ func (scanPlanBinaryInt2ToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16 type scanPlanBinaryInt2ToUint32 struct{} -func (scanPlanBinaryInt2ToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToUint32) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -444,7 +444,7 @@ func (scanPlanBinaryInt2ToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int1 type scanPlanBinaryInt2ToInt64 struct{} -func (scanPlanBinaryInt2ToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt64) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -465,7 +465,7 @@ func (scanPlanBinaryInt2ToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16 type scanPlanBinaryInt2ToUint64 struct{} -func (scanPlanBinaryInt2ToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToUint64) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -491,7 +491,7 @@ func (scanPlanBinaryInt2ToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int1 type scanPlanBinaryInt2ToInt struct{} -func (scanPlanBinaryInt2ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -512,7 +512,7 @@ func (scanPlanBinaryInt2ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanBinaryInt2ToUint struct{} -func (scanPlanBinaryInt2ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToUint) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -538,7 +538,7 @@ func (scanPlanBinaryInt2ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanBinaryInt2ToInt64Scanner struct{} -func (scanPlanBinaryInt2ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt64Scanner) Scan(src []byte, dst interface{}) error { s, ok := (dst).(Int64Scanner) if !ok { return ErrScanTargetTypeChanged @@ -829,7 +829,7 @@ func (c Int4Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt type scanPlanBinaryInt4ToInt8 struct{} -func (scanPlanBinaryInt4ToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToInt8) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -857,7 +857,7 @@ func (scanPlanBinaryInt4ToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanBinaryInt4ToUint8 struct{} -func (scanPlanBinaryInt4ToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToUint8) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -887,7 +887,7 @@ func (scanPlanBinaryInt4ToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16 type scanPlanBinaryInt4ToInt16 struct{} -func (scanPlanBinaryInt4ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToInt16) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -915,7 +915,7 @@ func (scanPlanBinaryInt4ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16 type scanPlanBinaryInt4ToUint16 struct{} -func (scanPlanBinaryInt4ToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToUint16) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -945,7 +945,7 @@ func (scanPlanBinaryInt4ToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int1 type scanPlanBinaryInt4ToInt32 struct{} -func (scanPlanBinaryInt4ToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToInt32) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -966,7 +966,7 @@ func (scanPlanBinaryInt4ToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16 type scanPlanBinaryInt4ToUint32 struct{} -func (scanPlanBinaryInt4ToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToUint32) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -992,7 +992,7 @@ func (scanPlanBinaryInt4ToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int1 type scanPlanBinaryInt4ToInt64 struct{} -func (scanPlanBinaryInt4ToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToInt64) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1013,7 +1013,7 @@ func (scanPlanBinaryInt4ToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16 type scanPlanBinaryInt4ToUint64 struct{} -func (scanPlanBinaryInt4ToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToUint64) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1039,7 +1039,7 @@ func (scanPlanBinaryInt4ToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int1 type scanPlanBinaryInt4ToInt struct{} -func (scanPlanBinaryInt4ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToInt) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1060,7 +1060,7 @@ func (scanPlanBinaryInt4ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanBinaryInt4ToUint struct{} -func (scanPlanBinaryInt4ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToUint) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1086,7 +1086,7 @@ func (scanPlanBinaryInt4ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanBinaryInt4ToInt64Scanner struct{} -func (scanPlanBinaryInt4ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToInt64Scanner) Scan(src []byte, dst interface{}) error { s, ok := (dst).(Int64Scanner) if !ok { return ErrScanTargetTypeChanged @@ -1377,7 +1377,7 @@ func (c Int8Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt type scanPlanBinaryInt8ToInt8 struct{} -func (scanPlanBinaryInt8ToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToInt8) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1405,7 +1405,7 @@ func (scanPlanBinaryInt8ToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanBinaryInt8ToUint8 struct{} -func (scanPlanBinaryInt8ToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToUint8) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1435,7 +1435,7 @@ func (scanPlanBinaryInt8ToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16 type scanPlanBinaryInt8ToInt16 struct{} -func (scanPlanBinaryInt8ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToInt16) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1463,7 +1463,7 @@ func (scanPlanBinaryInt8ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16 type scanPlanBinaryInt8ToUint16 struct{} -func (scanPlanBinaryInt8ToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToUint16) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1493,7 +1493,7 @@ func (scanPlanBinaryInt8ToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int1 type scanPlanBinaryInt8ToInt32 struct{} -func (scanPlanBinaryInt8ToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToInt32) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1521,7 +1521,7 @@ func (scanPlanBinaryInt8ToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16 type scanPlanBinaryInt8ToUint32 struct{} -func (scanPlanBinaryInt8ToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToUint32) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1551,7 +1551,7 @@ func (scanPlanBinaryInt8ToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int1 type scanPlanBinaryInt8ToInt64 struct{} -func (scanPlanBinaryInt8ToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToInt64) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1572,7 +1572,7 @@ func (scanPlanBinaryInt8ToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16 type scanPlanBinaryInt8ToUint64 struct{} -func (scanPlanBinaryInt8ToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToUint64) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1598,7 +1598,7 @@ func (scanPlanBinaryInt8ToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int1 type scanPlanBinaryInt8ToInt struct{} -func (scanPlanBinaryInt8ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToInt) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1626,7 +1626,7 @@ func (scanPlanBinaryInt8ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanBinaryInt8ToUint struct{} -func (scanPlanBinaryInt8ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToUint) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1656,7 +1656,7 @@ func (scanPlanBinaryInt8ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanBinaryInt8ToInt64Scanner struct{} -func (scanPlanBinaryInt8ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToInt64Scanner) Scan(src []byte, dst interface{}) error { s, ok := (dst).(Int64Scanner) if !ok { return ErrScanTargetTypeChanged @@ -1677,7 +1677,7 @@ func (scanPlanBinaryInt8ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod type scanPlanTextAnyToInt8 struct{} -func (scanPlanTextAnyToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt8) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1698,7 +1698,7 @@ func (scanPlanTextAnyToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, sr type scanPlanTextAnyToUint8 struct{} -func (scanPlanTextAnyToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint8) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1719,7 +1719,7 @@ func (scanPlanTextAnyToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, s type scanPlanTextAnyToInt16 struct{} -func (scanPlanTextAnyToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt16) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1740,7 +1740,7 @@ func (scanPlanTextAnyToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, s type scanPlanTextAnyToUint16 struct{} -func (scanPlanTextAnyToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint16) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1761,7 +1761,7 @@ func (scanPlanTextAnyToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanTextAnyToInt32 struct{} -func (scanPlanTextAnyToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt32) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1782,7 +1782,7 @@ func (scanPlanTextAnyToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, s type scanPlanTextAnyToUint32 struct{} -func (scanPlanTextAnyToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint32) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1803,7 +1803,7 @@ func (scanPlanTextAnyToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanTextAnyToInt64 struct{} -func (scanPlanTextAnyToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt64) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1824,7 +1824,7 @@ func (scanPlanTextAnyToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, s type scanPlanTextAnyToUint64 struct{} -func (scanPlanTextAnyToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint64) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1845,7 +1845,7 @@ func (scanPlanTextAnyToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanTextAnyToInt struct{} -func (scanPlanTextAnyToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1866,7 +1866,7 @@ func (scanPlanTextAnyToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src type scanPlanTextAnyToUint struct{} -func (scanPlanTextAnyToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1887,7 +1887,7 @@ func (scanPlanTextAnyToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, sr type scanPlanTextAnyToInt64Scanner struct{} -func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt64Scanner) Scan(src []byte, dst interface{}) error { s, ok := (dst).(Int64Scanner) if !ok { return ErrScanTargetTypeChanged diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 6aecb761..8524136f 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -295,7 +295,7 @@ func (c Int<%= pg_byte_size %>Codec) DecodeValue(ci *ConnInfo, oid uint32, forma <% [8, 16, 32, 64].each do |dst_bit_size| %> type scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %> struct{} -func (scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %>) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -329,7 +329,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %>) Scan(ci *Con type scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %> struct{} -func (scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %>) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -361,7 +361,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %>) Scan(ci *Co <%# PostgreSQL binary format integer to Go machine integers %> type scanPlanBinaryInt<%= pg_byte_size %>ToInt struct{} -func (scanPlanBinaryInt<%= pg_byte_size %>ToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -393,7 +393,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt) Scan(ci *ConnInfo, oid uint32, type scanPlanBinaryInt<%= pg_byte_size %>ToUint struct{} -func (scanPlanBinaryInt<%= pg_byte_size %>ToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt<%= pg_byte_size %>ToUint) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -424,7 +424,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToUint) Scan(ci *ConnInfo, oid uint32, <%# PostgreSQL binary format integer to Go Int64Scanner %> type scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner struct{} -func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(src []byte, dst interface{}) error { s, ok := (dst).(Int64Scanner) if !ok { return ErrScanTargetTypeChanged @@ -455,7 +455,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(ci *ConnInfo, oid ].each do |type_suffix, bit_size| %> type scanPlanTextAnyToInt<%= type_suffix %> struct{} -func (scanPlanTextAnyToInt<%= type_suffix %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt<%= type_suffix %>) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -476,7 +476,7 @@ func (scanPlanTextAnyToInt<%= type_suffix %>) Scan(ci *ConnInfo, oid uint32, for type scanPlanTextAnyToUint<%= type_suffix %> struct{} -func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -498,7 +498,7 @@ func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(ci *ConnInfo, oid uint32, fo type scanPlanTextAnyToInt64Scanner struct{} -func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt64Scanner) Scan(src []byte, dst interface{}) error { s, ok := (dst).(Int64Scanner) if !ok { return ErrScanTargetTypeChanged diff --git a/pgtype/interval.go b/pgtype/interval.go index 41216f37..a20266eb 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -51,7 +51,7 @@ func (interval *Interval) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToIntervalScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), interval) + return scanPlanTextAnyToIntervalScanner{}.Scan([]byte(src), interval) } return fmt.Errorf("cannot scan %T", src) @@ -171,7 +171,7 @@ func (IntervalCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target int type scanPlanBinaryIntervalToIntervalScanner struct{} -func (scanPlanBinaryIntervalToIntervalScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryIntervalToIntervalScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(IntervalScanner) if src == nil { @@ -191,7 +191,7 @@ func (scanPlanBinaryIntervalToIntervalScanner) Scan(ci *ConnInfo, oid uint32, fo type scanPlanTextAnyToIntervalScanner struct{} -func (scanPlanTextAnyToIntervalScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(IntervalScanner) if src == nil { diff --git a/pgtype/json.go b/pgtype/json.go index 510b638e..cd8b8ec9 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -65,7 +65,7 @@ func (JSONCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa type scanPlanAnyToString struct{} -func (scanPlanAnyToString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanAnyToString) Scan(src []byte, dst interface{}) error { p := dst.(*string) *p = string(src) return nil @@ -73,7 +73,7 @@ func (scanPlanAnyToString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src type scanPlanJSONToByteSlice struct{} -func (scanPlanJSONToByteSlice) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanJSONToByteSlice) Scan(src []byte, dst interface{}) error { dstBuf := dst.(*[]byte) if src == nil { *dstBuf = nil @@ -87,14 +87,14 @@ func (scanPlanJSONToByteSlice) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanJSONToBytesScanner struct{} -func (scanPlanJSONToBytesScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanJSONToBytesScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(BytesScanner) return scanner.ScanBytes(src) } type scanPlanJSONToJSONUnmarshal struct{} -func (scanPlanJSONToJSONUnmarshal) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst interface{}) error { if src == nil { dstValue := reflect.ValueOf(dst) if dstValue.Kind() == reflect.Ptr { diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 6e329150..07ea58bc 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -57,9 +57,9 @@ type scanPlanJSONBCodecBinaryUnwrapper struct { textPlan ScanPlan } -func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst interface{}) error { if src == nil { - return plan.textPlan.Scan(ci, oid, formatCode, src, dst) + return plan.textPlan.Scan(src, dst) } if len(src) == 0 { @@ -70,7 +70,7 @@ func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(ci *ConnInfo, oid uint32, fo return fmt.Errorf("unknown jsonb version number %d", src[0]) } - return plan.textPlan.Scan(ci, oid, formatCode, src[1:], dst) + return plan.textPlan.Scan(src[1:], dst) } func (c JSONBCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { diff --git a/pgtype/line.go b/pgtype/line.go index db584862..acae903b 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -46,7 +46,7 @@ func (line *Line) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToLineScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), line) + return scanPlanTextAnyToLineScanner{}.Scan([]byte(src), line) } return fmt.Errorf("cannot scan %T", src) @@ -148,7 +148,7 @@ func (LineCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa type scanPlanBinaryLineToLineScanner struct{} -func (scanPlanBinaryLineToLineScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryLineToLineScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(LineScanner) if src == nil { @@ -173,7 +173,7 @@ func (scanPlanBinaryLineToLineScanner) Scan(ci *ConnInfo, oid uint32, formatCode type scanPlanTextAnyToLineScanner struct{} -func (scanPlanTextAnyToLineScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToLineScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(LineScanner) if src == nil { diff --git a/pgtype/lseg.go b/pgtype/lseg.go index 26730e85..471b36b2 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -42,7 +42,7 @@ func (lseg *Lseg) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToLsegScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), lseg) + return scanPlanTextAnyToLsegScanner{}.Scan([]byte(src), lseg) } return fmt.Errorf("cannot scan %T", src) @@ -146,7 +146,7 @@ func (LsegCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa type scanPlanBinaryLsegToLsegScanner struct{} -func (scanPlanBinaryLsegToLsegScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryLsegToLsegScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(LsegScanner) if src == nil { @@ -173,7 +173,7 @@ func (scanPlanBinaryLsegToLsegScanner) Scan(ci *ConnInfo, oid uint32, formatCode type scanPlanTextAnyToLsegScanner struct{} -func (scanPlanTextAnyToLsegScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToLsegScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(LsegScanner) if src == nil { diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index 0ac003ae..5b42811a 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -101,7 +101,7 @@ func (MacaddrCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target inte type scanPlanBinaryMacaddrToHardwareAddr struct{} -func (scanPlanBinaryMacaddrToHardwareAddr) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryMacaddrToHardwareAddr) Scan(src []byte, dst interface{}) error { dstBuf := dst.(*net.HardwareAddr) if src == nil { *dstBuf = nil @@ -115,7 +115,7 @@ func (scanPlanBinaryMacaddrToHardwareAddr) Scan(ci *ConnInfo, oid uint32, format type scanPlanBinaryMacaddrToTextScanner struct{} -func (scanPlanBinaryMacaddrToTextScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryMacaddrToTextScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(TextScanner) if src == nil { return scanner.ScanText(Text{}) @@ -126,7 +126,7 @@ func (scanPlanBinaryMacaddrToTextScanner) Scan(ci *ConnInfo, oid uint32, formatC type scanPlanTextMacaddrToHardwareAddr struct{} -func (scanPlanTextMacaddrToHardwareAddr) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextMacaddrToHardwareAddr) Scan(src []byte, dst interface{}) error { p := dst.(*net.HardwareAddr) if src == nil { diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 435c9618..5bdbd4d5 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -175,7 +175,7 @@ func (n *Numeric) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToNumericScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), n) + return scanPlanTextAnyToNumericScanner{}.Scan([]byte(src), n) } return fmt.Errorf("cannot scan %T", src) @@ -522,7 +522,7 @@ func (NumericCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target inte type scanPlanBinaryNumericToNumericScanner struct{} -func (scanPlanBinaryNumericToNumericScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryNumericToNumericScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(NumericScanner) if src == nil { @@ -628,7 +628,7 @@ func (scanPlanBinaryNumericToNumericScanner) Scan(ci *ConnInfo, oid uint32, form type scanPlanBinaryNumericToFloat64Scanner struct{} -func (scanPlanBinaryNumericToFloat64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryNumericToFloat64Scanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(Float64Scanner) if src == nil { @@ -637,7 +637,7 @@ func (scanPlanBinaryNumericToFloat64Scanner) Scan(ci *ConnInfo, oid uint32, form var n Numeric - err := scanPlanBinaryNumericToNumericScanner{}.Scan(ci, oid, formatCode, src, &n) + err := scanPlanBinaryNumericToNumericScanner{}.Scan(src, &n) if err != nil { return err } @@ -652,7 +652,7 @@ func (scanPlanBinaryNumericToFloat64Scanner) Scan(ci *ConnInfo, oid uint32, form type scanPlanBinaryNumericToInt64Scanner struct{} -func (scanPlanBinaryNumericToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryNumericToInt64Scanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(Int64Scanner) if src == nil { @@ -661,7 +661,7 @@ func (scanPlanBinaryNumericToInt64Scanner) Scan(ci *ConnInfo, oid uint32, format var n Numeric - err := scanPlanBinaryNumericToNumericScanner{}.Scan(ci, oid, formatCode, src, &n) + err := scanPlanBinaryNumericToNumericScanner{}.Scan(src, &n) if err != nil { return err } @@ -680,7 +680,7 @@ func (scanPlanBinaryNumericToInt64Scanner) Scan(ci *ConnInfo, oid uint32, format type scanPlanTextAnyToNumericScanner struct{} -func (scanPlanTextAnyToNumericScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToNumericScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(NumericScanner) if src == nil { diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 0d89dc2d..c74fb9a3 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -64,7 +64,7 @@ func mustParseNumeric(t *testing.T, src string) pgtype.Numeric { var n pgtype.Numeric plan := pgtype.NumericCodec{}.PlanScan(nil, pgtype.NumericOID, pgtype.TextFormatCode, &n, false) require.NotNil(t, plan) - err := plan.Scan(nil, pgtype.NumericOID, pgtype.TextFormatCode, []byte(src), &n) + err := plan.Scan([]byte(src), &n) require.NoError(t, err) return n } diff --git a/pgtype/path.go b/pgtype/path.go index be7daaa0..62a23219 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -43,7 +43,7 @@ func (path *Path) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToPathScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), path) + return scanPlanTextAnyToPathScanner{}.Scan([]byte(src), path) } return fmt.Errorf("cannot scan %T", src) @@ -173,7 +173,7 @@ func (PathCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa type scanPlanBinaryPathToPathScanner struct{} -func (scanPlanBinaryPathToPathScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryPathToPathScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(PathScanner) if src == nil { @@ -211,7 +211,7 @@ func (scanPlanBinaryPathToPathScanner) Scan(ci *ConnInfo, oid uint32, formatCode type scanPlanTextAnyToPathScanner struct{} -func (scanPlanTextAnyToPathScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToPathScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(PathScanner) if src == nil { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 73bd249c..150f1a23 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -3,10 +3,8 @@ package pgtype import ( "database/sql" "database/sql/driver" - "encoding/binary" "errors" "fmt" - "math" "net" "reflect" "time" @@ -198,14 +196,14 @@ func NewConnInfo() *ConnInfo { TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ TryWrapDerefPointerEncodePlan, - TryWrapFindUnderlyingTypeEncodePlan, TryWrapBuiltinTypeEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, }, TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ TryPointerPointerScanPlan, - TryFindUnderlyingTypeScanPlan, TryWrapBuiltinTypeScanPlan, + TryFindUnderlyingTypeScanPlan, }, } @@ -409,22 +407,19 @@ type EncodePlan interface { // ScanPlan is a precompiled plan to scan into a type of destination. type ScanPlan interface { - // Scan scans src into dst. If the dst type has changed in an incompatible way a ScanPlan should automatically - // replan and scan. - Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error + // Scan scans src into target. + Scan(src []byte, target interface{}) error } -type scanPlanDstResultDecoder struct{} - -func (scanPlanDstResultDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) +type scanPlanCodecSQLScanner struct { + c Codec + ci *ConnInfo + oid uint32 + formatCode int16 } -type scanPlanCodecSQLScanner struct{ c Codec } - -func (plan *scanPlanCodecSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - value, err := plan.c.DecodeDatabaseSQLValue(ci, oid, formatCode, src) +func (plan *scanPlanCodecSQLScanner) Scan(src []byte, dst interface{}) error { + value, err := plan.c.DecodeDatabaseSQLValue(plan.ci, plan.oid, plan.formatCode, src) if err != nil { return err } @@ -433,135 +428,56 @@ func (plan *scanPlanCodecSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode i return scanner.Scan(value) } -type scanPlanSQLScanner struct{} +type scanPlanSQLScanner struct { + formatCode int16 +} -func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (plan *scanPlanSQLScanner) Scan(src []byte, dst interface{}) error { scanner := dst.(sql.Scanner) if src == nil { // This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the // text format path would be converted to empty string. return scanner.Scan(nil) - } else if formatCode == BinaryFormatCode { + } else if plan.formatCode == BinaryFormatCode { return scanner.Scan(src) } else { return scanner.Scan(string(src)) } } -type scanPlanReflection struct{} - -func (scanPlanReflection) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - // We might be given a pointer to something that implements the decoder interface(s), - // even though the pointer itself doesn't. - refVal := reflect.ValueOf(dst) - if refVal.Kind() == reflect.Ptr && refVal.Type().Elem().Kind() == reflect.Ptr { - // If the database returned NULL, then we set dest as nil to indicate that. - if src == nil { - nilPtr := reflect.Zero(refVal.Type().Elem()) - refVal.Elem().Set(nilPtr) - return nil - } - - // We need to allocate an element, and set the destination to it - // Then we can retry as that element. - elemPtr := reflect.New(refVal.Type().Elem().Elem()) - refVal.Elem().Set(elemPtr) - - plan := ci.PlanScan(oid, formatCode, elemPtr.Interface()) - return plan.Scan(ci, oid, formatCode, src, elemPtr.Interface()) - } - - return scanUnknownType(oid, formatCode, src, dst) -} - -type scanPlanBinaryInt64 struct{} - -func (scanPlanBinaryInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - if len(src) != 8 { - return fmt.Errorf("invalid length for int8: %v", len(src)) - } - - if p, ok := (dst).(*int64); ok { - *p = int64(binary.BigEndian.Uint64(src)) - return nil - } - - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) -} - -type scanPlanBinaryFloat32 struct{} - -func (scanPlanBinaryFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - if len(src) != 4 { - return fmt.Errorf("invalid length for int4: %v", len(src)) - } - - if p, ok := (dst).(*float32); ok { - n := int32(binary.BigEndian.Uint32(src)) - *p = float32(math.Float32frombits(uint32(n))) - return nil - } - - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) -} - -type scanPlanBinaryFloat64 struct{} - -func (scanPlanBinaryFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - if len(src) != 8 { - return fmt.Errorf("invalid length for int8: %v", len(src)) - } - - if p, ok := (dst).(*float64); ok { - n := int64(binary.BigEndian.Uint64(src)) - *p = float64(math.Float64frombits(uint64(n))) - return nil - } - - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) -} - -type scanPlanBinaryBytes struct{} - -func (scanPlanBinaryBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if p, ok := (dst).(*[]byte); ok { - *p = src - return nil - } - - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) -} - type scanPlanString struct{} -func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanString) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } - if p, ok := (dst).(*string); ok { - *p = string(src) + p := (dst).(*string) + *p = string(src) + return nil +} + +type scanPlanAnyTextToBytes struct{} + +func (scanPlanAnyTextToBytes) Scan(src []byte, dst interface{}) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil return nil } - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanFail struct { + oid uint32 + formatCode int16 +} + +func (plan *scanPlanFail) Scan(src []byte, dst interface{}) error { + return fmt.Errorf("cannot scan OID %v in format %v into %T", plan.oid, plan.formatCode, dst) } // TryWrapScanPlanFunc is a function that tries to create a wrapper plan for target. If successful it returns a plan @@ -577,12 +493,7 @@ type pointerPointerScanPlan struct { func (plan *pointerPointerScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *pointerPointerScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if plan.dstType != reflect.TypeOf(dst) { - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) - } - +func (plan *pointerPointerScanPlan) Scan(src []byte, dst interface{}) error { el := reflect.ValueOf(dst).Elem() if src == nil { el.Set(reflect.Zero(el.Type())) @@ -590,7 +501,7 @@ func (plan *pointerPointerScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode in } el.Set(reflect.New(el.Type().Elem())) - return plan.next.Scan(ci, oid, formatCode, src, el.Interface()) + return plan.next.Scan(src, el.Interface()) } // TryPointerPointerScanPlan handles a pointer to a pointer by setting the target to nil for SQL NULL and allocating and @@ -636,13 +547,8 @@ type underlyingTypeScanPlan struct { func (plan *underlyingTypeScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *underlyingTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - if plan.dstType != reflect.TypeOf(dst) { - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) - } - - return plan.next.Scan(ci, oid, formatCode, src, reflect.ValueOf(dst).Convert(plan.nextDstType).Interface()) +func (plan *underlyingTypeScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, reflect.ValueOf(dst).Convert(plan.nextDstType).Interface()) } // TryFindUnderlyingTypeScanPlan tries to convert to a Go builtin type. e.g. If value was of type MyString and @@ -657,9 +563,17 @@ func TryFindUnderlyingTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSet if dstValue.Kind() == reflect.Ptr { elemValue := dstValue.Elem() nextDstType := elemKindToPointerTypes[elemValue.Kind()] + if nextDstType == nil && elemValue.Kind() == reflect.Slice { + if elemValue.Type().Elem().Kind() == reflect.Uint8 { + var v *[]byte + nextDstType = reflect.TypeOf(v) + } + } + if nextDstType != nil && dstValue.Type() != nextDstType { return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true } + } return nil, nil, false @@ -728,8 +642,8 @@ type wrapInt8ScanPlan struct { func (plan *wrapInt8ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapInt8ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*int8Wrapper)(dst.(*int8))) +func (plan *wrapInt8ScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*int8Wrapper)(dst.(*int8))) } type wrapInt16ScanPlan struct { @@ -738,8 +652,8 @@ type wrapInt16ScanPlan struct { func (plan *wrapInt16ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapInt16ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*int16Wrapper)(dst.(*int16))) +func (plan *wrapInt16ScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*int16Wrapper)(dst.(*int16))) } type wrapInt32ScanPlan struct { @@ -748,8 +662,8 @@ type wrapInt32ScanPlan struct { func (plan *wrapInt32ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapInt32ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*int32Wrapper)(dst.(*int32))) +func (plan *wrapInt32ScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*int32Wrapper)(dst.(*int32))) } type wrapInt64ScanPlan struct { @@ -758,8 +672,8 @@ type wrapInt64ScanPlan struct { func (plan *wrapInt64ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapInt64ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*int64Wrapper)(dst.(*int64))) +func (plan *wrapInt64ScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*int64Wrapper)(dst.(*int64))) } type wrapIntScanPlan struct { @@ -768,8 +682,8 @@ type wrapIntScanPlan struct { func (plan *wrapIntScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapIntScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*intWrapper)(dst.(*int))) +func (plan *wrapIntScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*intWrapper)(dst.(*int))) } type wrapUint8ScanPlan struct { @@ -778,8 +692,8 @@ type wrapUint8ScanPlan struct { func (plan *wrapUint8ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapUint8ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*uint8Wrapper)(dst.(*uint8))) +func (plan *wrapUint8ScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*uint8Wrapper)(dst.(*uint8))) } type wrapUint16ScanPlan struct { @@ -788,8 +702,8 @@ type wrapUint16ScanPlan struct { func (plan *wrapUint16ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapUint16ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*uint16Wrapper)(dst.(*uint16))) +func (plan *wrapUint16ScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*uint16Wrapper)(dst.(*uint16))) } type wrapUint32ScanPlan struct { @@ -798,8 +712,8 @@ type wrapUint32ScanPlan struct { func (plan *wrapUint32ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapUint32ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*uint32Wrapper)(dst.(*uint32))) +func (plan *wrapUint32ScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*uint32Wrapper)(dst.(*uint32))) } type wrapUint64ScanPlan struct { @@ -808,8 +722,8 @@ type wrapUint64ScanPlan struct { func (plan *wrapUint64ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapUint64ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*uint64Wrapper)(dst.(*uint64))) +func (plan *wrapUint64ScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*uint64Wrapper)(dst.(*uint64))) } type wrapUintScanPlan struct { @@ -818,8 +732,8 @@ type wrapUintScanPlan struct { func (plan *wrapUintScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapUintScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*uintWrapper)(dst.(*uint))) +func (plan *wrapUintScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*uintWrapper)(dst.(*uint))) } type wrapFloat32ScanPlan struct { @@ -828,8 +742,8 @@ type wrapFloat32ScanPlan struct { func (plan *wrapFloat32ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapFloat32ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*float32Wrapper)(dst.(*float32))) +func (plan *wrapFloat32ScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*float32Wrapper)(dst.(*float32))) } type wrapFloat64ScanPlan struct { @@ -838,8 +752,8 @@ type wrapFloat64ScanPlan struct { func (plan *wrapFloat64ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapFloat64ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*float64Wrapper)(dst.(*float64))) +func (plan *wrapFloat64ScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*float64Wrapper)(dst.(*float64))) } type wrapStringScanPlan struct { @@ -848,8 +762,8 @@ type wrapStringScanPlan struct { func (plan *wrapStringScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapStringScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*stringWrapper)(dst.(*string))) +func (plan *wrapStringScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*stringWrapper)(dst.(*string))) } type wrapTimeScanPlan struct { @@ -858,8 +772,8 @@ type wrapTimeScanPlan struct { func (plan *wrapTimeScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapTimeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*timeWrapper)(dst.(*time.Time))) +func (plan *wrapTimeScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*timeWrapper)(dst.(*time.Time))) } type wrapDurationScanPlan struct { @@ -868,8 +782,8 @@ type wrapDurationScanPlan struct { func (plan *wrapDurationScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapDurationScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*durationWrapper)(dst.(*time.Duration))) +func (plan *wrapDurationScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*durationWrapper)(dst.(*time.Duration))) } type wrapNetIPNetScanPlan struct { @@ -878,8 +792,8 @@ type wrapNetIPNetScanPlan struct { func (plan *wrapNetIPNetScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapNetIPNetScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*netIPNetWrapper)(dst.(*net.IPNet))) +func (plan *wrapNetIPNetScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*netIPNetWrapper)(dst.(*net.IPNet))) } type wrapNetIPScanPlan struct { @@ -888,8 +802,8 @@ type wrapNetIPScanPlan struct { func (plan *wrapNetIPScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapNetIPScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*netIPWrapper)(dst.(*net.IP))) +func (plan *wrapNetIPScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*netIPWrapper)(dst.(*net.IP))) } type wrapMapStringToPointerStringScanPlan struct { @@ -898,8 +812,8 @@ type wrapMapStringToPointerStringScanPlan struct { func (plan *wrapMapStringToPointerStringScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapMapStringToPointerStringScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*mapStringToPointerStringWrapper)(dst.(*map[string]*string))) +func (plan *wrapMapStringToPointerStringScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*mapStringToPointerStringWrapper)(dst.(*map[string]*string))) } type wrapMapStringToStringScanPlan struct { @@ -908,8 +822,8 @@ type wrapMapStringToStringScanPlan struct { func (plan *wrapMapStringToStringScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapMapStringToStringScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*mapStringToStringWrapper)(dst.(*map[string]string))) +func (plan *wrapMapStringToStringScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*mapStringToStringWrapper)(dst.(*map[string]string))) } type wrapByte16ScanPlan struct { @@ -918,8 +832,8 @@ type wrapByte16ScanPlan struct { func (plan *wrapByte16ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapByte16ScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*byte16Wrapper)(dst.(*[16]byte))) +func (plan *wrapByte16ScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*byte16Wrapper)(dst.(*[16]byte))) } type wrapByteSliceScanPlan struct { @@ -928,16 +842,19 @@ type wrapByteSliceScanPlan struct { func (plan *wrapByteSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapByteSliceScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - return plan.next.Scan(ci, oid, formatCode, src, (*byteSliceWrapper)(dst.(*[]byte))) +func (plan *wrapByteSliceScanPlan) Scan(src []byte, dst interface{}) error { + return plan.next.Scan(src, (*byteSliceWrapper)(dst.(*[]byte))) } type pointerEmptyInterfaceScanPlan struct { - codec Codec + codec Codec + ci *ConnInfo + oid uint32 + formatCode int16 } -func (plan *pointerEmptyInterfaceScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - value, err := plan.codec.DecodeValue(ci, oid, formatCode, src) +func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst interface{}) error { + value, err := plan.codec.DecodeValue(plan.ci, plan.oid, plan.formatCode, src) if err != nil { return err } @@ -948,45 +865,28 @@ func (plan *pointerEmptyInterfaceScanPlan) Scan(ci *ConnInfo, oid uint32, format return nil } -// PlanScan prepares a plan to scan a value into dst. -func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { - if _, ok := dst.(*UndecodedBytes); ok { +// PlanScan prepares a plan to scan a value into target. +func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan { + if _, ok := target.(*UndecodedBytes); ok { return scanPlanAnyToUndecodedBytes{} } switch formatCode { case BinaryFormatCode: - switch dst.(type) { + switch target.(type) { case *string: switch oid { case TextOID, VarcharOID: return scanPlanString{} } - case *int64: - if oid == Int8OID { - return scanPlanBinaryInt64{} - } - case *float32: - if oid == Float4OID { - return scanPlanBinaryFloat32{} - } - case *float64: - if oid == Float8OID { - return scanPlanBinaryFloat64{} - } - case *[]byte: - switch oid { - case ByteaOID, TextOID, VarcharOID, JSONOID: - return scanPlanBinaryBytes{} - } } case TextFormatCode: - switch dst.(type) { + switch target.(type) { case *string: return scanPlanString{} case *[]byte: if oid != ByteaOID { - return scanPlanBinaryBytes{} + return scanPlanAnyTextToBytes{} } case TextScanner: return scanPlanTextAnyToTextScanner{} @@ -995,47 +895,43 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan var dt *DataType - if oid == 0 { - if dataType, ok := ci.DataTypeForValue(dst); ok { - dt = dataType - oid = dt.OID // Preserve assumed OID in case we are recursively called below. - } - } else { - if dataType, ok := ci.DataTypeForOID(oid); ok { - dt = dataType - } + if dataType, ok := ci.DataTypeForOID(oid); ok { + dt = dataType + } else if dataType, ok := ci.DataTypeForValue(target); ok { + dt = dataType + oid = dt.OID // Preserve assumed OID in case we are recursively called below. } if dt != nil { - if plan := dt.Codec.PlanScan(ci, oid, formatCode, dst, false); plan != nil { + if plan := dt.Codec.PlanScan(ci, oid, formatCode, target, false); plan != nil { return plan } - for _, f := range ci.TryWrapScanPlanFuncs { - if wrapperPlan, nextDst, ok := f(dst); ok { - if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { - if _, ok := nextPlan.(scanPlanReflection); !ok { // avoid fallthrough -- this will go away when old system removed. - wrapperPlan.SetNext(nextPlan) - return wrapperPlan - } + if _, ok := target.(*interface{}); ok { + return &pointerEmptyInterfaceScanPlan{codec: dt.Codec, ci: ci, oid: oid, formatCode: formatCode} + } + + if _, ok := target.(sql.Scanner); ok { + return &scanPlanCodecSQLScanner{c: dt.Codec, ci: ci, oid: oid, formatCode: formatCode} + } + } + + for _, f := range ci.TryWrapScanPlanFuncs { + if wrapperPlan, nextDst, ok := f(target); ok { + if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { + if _, failed := nextPlan.(*scanPlanFail); !failed { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan } } } - - if _, ok := dst.(*interface{}); ok { - return &pointerEmptyInterfaceScanPlan{codec: dt.Codec} - } - - if _, ok := dst.(sql.Scanner); ok { - return &scanPlanCodecSQLScanner{c: dt.Codec} - } } - if _, ok := dst.(sql.Scanner); ok { - return scanPlanSQLScanner{} + if _, ok := target.(sql.Scanner); ok { + return &scanPlanSQLScanner{formatCode: formatCode} } - return scanPlanReflection{} + return &scanPlanFail{oid: oid, formatCode: formatCode} } func (ci *ConnInfo) Scan(oid uint32, formatCode int16, src []byte, dst interface{}) error { @@ -1044,7 +940,7 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, src []byte, dst interface } plan := ci.PlanScan(oid, formatCode, dst) - return plan.Scan(ci, oid, formatCode, src, dst) + return plan.Scan(src, dst) } func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) error { @@ -1073,7 +969,7 @@ func codecScan(codec Codec, ci *ConnInfo, oid uint32, format int16, src []byte, if scanPlan == nil { return fmt.Errorf("PlanScan did not find a plan") } - return scanPlan.Scan(ci, oid, format, src, dst) + return scanPlan.Scan(src, dst) } func codecDecodeToTextFormat(codec Codec, ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 9bd665c5..2917c31c 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -118,18 +118,10 @@ func TestConnInfoScanUnknownOIDToStringsAndBytes(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []byte("foo"), b) - err = ci.Scan(unknownOID, pgx.BinaryFormatCode, srcBuf, &b) - assert.NoError(t, err) - assert.Equal(t, []byte("foo"), b) - var rb _byteSlice err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rb) assert.NoError(t, err) assert.Equal(t, []byte("foo"), []byte(rb)) - - err = ci.Scan(unknownOID, pgx.BinaryFormatCode, srcBuf, &b) - assert.NoError(t, err) - assert.Equal(t, []byte("foo"), []byte(rb)) } type pgCustomType struct { @@ -219,21 +211,6 @@ func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) { } } -func TestScanPlanBinaryInt32ScanChangedType(t *testing.T) { - ci := pgtype.NewConnInfo() - src := []byte{0, 0, 0, 42} - var v int32 - - plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) - err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) - require.NoError(t, err) - require.EqualValues(t, 42, v) - - var d pgtype.Int4 - err = plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &d) - require.EqualError(t, err, pgtype.ErrScanTargetTypeChanged.Error()) -} - func BenchmarkConnInfoScanInt4IntoGoInt32(b *testing.B) { ci := pgtype.NewConnInfo() src := []byte{0, 0, 0, 42} @@ -260,7 +237,7 @@ func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) { for i := 0; i < b.N; i++ { v = pgtype.Int4{} - err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + err := plan.Scan(src, &v) if err != nil { b.Fatal(err) } @@ -279,7 +256,7 @@ func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { for i := 0; i < b.N; i++ { v = 0 - err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + err := plan.Scan(src, &v) if err != nil { b.Fatal(err) } diff --git a/pgtype/point.go b/pgtype/point.go index a9be4fdc..0a300fe2 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -77,7 +77,7 @@ func (dst *Point) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToPointScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + return scanPlanTextAnyToPointScanner{}.Scan([]byte(src), dst) } return fmt.Errorf("cannot scan %T", src) @@ -214,7 +214,7 @@ func (c PointCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []by type scanPlanBinaryPointToPointScanner struct{} -func (scanPlanBinaryPointToPointScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryPointToPointScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(PointScanner) if src == nil { @@ -236,7 +236,7 @@ func (scanPlanBinaryPointToPointScanner) Scan(ci *ConnInfo, oid uint32, formatCo type scanPlanTextAnyToPointScanner struct{} -func (scanPlanTextAnyToPointScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToPointScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(PointScanner) if src == nil { diff --git a/pgtype/polygon.go b/pgtype/polygon.go index 47dbfed9..e4a1b2af 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -42,7 +42,7 @@ func (p *Polygon) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToPolygonScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), p) + return scanPlanTextAnyToPolygonScanner{}.Scan([]byte(src), p) } return fmt.Errorf("cannot scan %T", src) @@ -158,7 +158,7 @@ func (PolygonCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target inte type scanPlanBinaryPolygonToPolygonScanner struct{} -func (scanPlanBinaryPolygonToPolygonScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryPolygonToPolygonScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(PolygonScanner) if src == nil { @@ -193,7 +193,7 @@ func (scanPlanBinaryPolygonToPolygonScanner) Scan(ci *ConnInfo, oid uint32, form type scanPlanTextAnyToPolygonScanner struct{} -func (scanPlanTextAnyToPolygonScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToPolygonScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(PolygonScanner) if src == nil { diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 28c91110..5c712369 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -72,7 +72,7 @@ func (QCharCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interf type scanPlanQcharCodecByte struct{} -func (scanPlanQcharCodecByte) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanQcharCodecByte) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -94,7 +94,7 @@ func (scanPlanQcharCodecByte) Scan(ci *ConnInfo, oid uint32, formatCode int16, s type scanPlanQcharCodecRune struct{} -func (scanPlanQcharCodecRune) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanQcharCodecRune) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } diff --git a/pgtype/text.go b/pgtype/text.go index 3c73cc15..7e4f8b99 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -189,7 +189,7 @@ func (c TextCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt type scanPlanTextAnyToString struct{} -func (scanPlanTextAnyToString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToString) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -202,7 +202,7 @@ func (scanPlanTextAnyToString) Scan(ci *ConnInfo, oid uint32, formatCode int16, type scanPlanAnyToNewByteSlice struct{} -func (scanPlanAnyToNewByteSlice) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanAnyToNewByteSlice) Scan(src []byte, dst interface{}) error { p := (dst).(*[]byte) if src == nil { *p = nil @@ -216,7 +216,7 @@ func (scanPlanAnyToNewByteSlice) Scan(ci *ConnInfo, oid uint32, formatCode int16 type scanPlanTextAnyToRune struct{} -func (scanPlanTextAnyToRune) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToRune) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -234,7 +234,7 @@ func (scanPlanTextAnyToRune) Scan(ci *ConnInfo, oid uint32, formatCode int16, sr type scanPlanTextAnyToTextScanner struct{} -func (scanPlanTextAnyToTextScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToTextScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(TextScanner) if src == nil { diff --git a/pgtype/tid.go b/pgtype/tid.go index 450cfbc9..cc38f404 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -53,7 +53,7 @@ func (dst *TID) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToTIDScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + return scanPlanTextAnyToTIDScanner{}.Scan([]byte(src), dst) } return fmt.Errorf("cannot scan %T", src) @@ -152,7 +152,7 @@ func (TIDCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfac type scanPlanBinaryTIDToTIDScanner struct{} -func (scanPlanBinaryTIDToTIDScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryTIDToTIDScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(TIDScanner) if src == nil { @@ -172,7 +172,7 @@ func (scanPlanBinaryTIDToTIDScanner) Scan(ci *ConnInfo, oid uint32, formatCode i type scanPlanBinaryTIDToTextScanner struct{} -func (scanPlanBinaryTIDToTextScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryTIDToTextScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(TextScanner) if src == nil { @@ -194,7 +194,7 @@ func (scanPlanBinaryTIDToTextScanner) Scan(ci *ConnInfo, oid uint32, formatCode type scanPlanTextAnyToTIDScanner struct{} -func (scanPlanTextAnyToTIDScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToTIDScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(TIDScanner) if src == nil { diff --git a/pgtype/time.go b/pgtype/time.go index 47dabe99..734687cb 100644 --- a/pgtype/time.go +++ b/pgtype/time.go @@ -45,7 +45,7 @@ func (t *Time) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextAnyToTimeScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), t) + return scanPlanTextAnyToTimeScanner{}.Scan([]byte(src), t) } return fmt.Errorf("cannot scan %T", src) @@ -149,7 +149,7 @@ func (TimeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa type scanPlanBinaryTimeToTimeScanner struct{} -func (scanPlanBinaryTimeToTimeScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryTimeToTimeScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(TimeScanner) if src == nil { @@ -167,7 +167,7 @@ func (scanPlanBinaryTimeToTimeScanner) Scan(ci *ConnInfo, oid uint32, formatCode type scanPlanTextAnyToTimeScanner struct{} -func (scanPlanTextAnyToTimeScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(TimeScanner) if src == nil { diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 374aafe4..24ee76b0 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -44,7 +44,7 @@ func (ts *Timestamp) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextTimestampToTimestampScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), ts) + return scanPlanTextTimestampToTimestampScanner{}.Scan([]byte(src), ts) case time.Time: *ts = Timestamp{Time: src, Valid: true} return nil @@ -172,7 +172,7 @@ func (TimestampCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target in type scanPlanBinaryTimestampToTimestampScanner struct{} -func (scanPlanBinaryTimestampToTimestampScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(TimestampScanner) if src == nil { @@ -204,7 +204,7 @@ func (scanPlanBinaryTimestampToTimestampScanner) Scan(ci *ConnInfo, oid uint32, type scanPlanTextTimestampToTimestampScanner struct{} -func (scanPlanTextTimestampToTimestampScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(TimestampScanner) if src == nil { diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 1bbceaf5..562bb192 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -52,6 +52,6 @@ func TestTimestampCodecDecodeTextInvalid(t *testing.T) { c := &pgtype.TimestampCodec{} var ts pgtype.Timestamp plan := c.PlanScan(nil, pgtype.TimestampOID, pgtype.TextFormatCode, &ts, false) - err := plan.Scan(nil, pgtype.TimestampOID, pgtype.TextFormatCode, []byte(`eeeee`), &ts) + err := plan.Scan([]byte(`eeeee`), &ts) require.Error(t, err) } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index eec1dca5..ea2ebfbe 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -53,7 +53,7 @@ func (tstz *Timestamptz) Scan(src interface{}) error { switch src := src.(type) { case string: - return scanPlanTextTimestamptzToTimestamptzScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), tstz) + return scanPlanTextTimestamptzToTimestamptzScanner{}.Scan([]byte(src), tstz) case time.Time: *tstz = Timestamptz{Time: src, Valid: true} return nil @@ -220,7 +220,7 @@ func (TimestamptzCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target type scanPlanBinaryTimestamptzToTimestamptzScanner struct{} -func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(TimestamptzScanner) if src == nil { @@ -252,7 +252,7 @@ func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(ci *ConnInfo, oid uint type scanPlanTextTimestamptzToTimestamptzScanner struct{} -func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(TimestamptzScanner) if src == nil { diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index 42439b7b..ec408f47 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -52,7 +52,7 @@ func TestTimestamptzDecodeTextInvalid(t *testing.T) { c := &pgtype.TimestamptzCodec{} var tstz pgtype.Timestamptz plan := c.PlanScan(nil, pgtype.TimestamptzOID, pgtype.TextFormatCode, &tstz, false) - err := plan.Scan(nil, pgtype.TimestamptzOID, pgtype.TextFormatCode, []byte(`eeeee`), &tstz) + err := plan.Scan([]byte(`eeeee`), &tstz) require.Error(t, err) } diff --git a/pgtype/uint32.go b/pgtype/uint32.go index ccf39471..7d481a27 100644 --- a/pgtype/uint32.go +++ b/pgtype/uint32.go @@ -246,7 +246,7 @@ func (c Uint32Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []b type scanPlanBinaryUint32ToUint32 struct{} -func (scanPlanBinaryUint32ToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryUint32ToUint32) Scan(src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -263,7 +263,7 @@ func (scanPlanBinaryUint32ToUint32) Scan(ci *ConnInfo, oid uint32, formatCode in type scanPlanBinaryUint32ToUint32Scanner struct{} -func (scanPlanBinaryUint32ToUint32Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryUint32ToUint32Scanner) Scan(src []byte, dst interface{}) error { s, ok := (dst).(Uint32Scanner) if !ok { return ErrScanTargetTypeChanged @@ -284,7 +284,7 @@ func (scanPlanBinaryUint32ToUint32Scanner) Scan(ci *ConnInfo, oid uint32, format type scanPlanTextAnyToUint32Scanner struct{} -func (scanPlanTextAnyToUint32Scanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint32Scanner) Scan(src []byte, dst interface{}) error { s, ok := (dst).(Uint32Scanner) if !ok { return ErrScanTargetTypeChanged diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 288fc454..2655a124 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -186,7 +186,7 @@ func (UUIDCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa type scanPlanBinaryUUIDToUUIDScanner struct{} -func (scanPlanBinaryUUIDToUUIDScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryUUIDToUUIDScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(UUIDScanner) if src == nil { @@ -205,7 +205,7 @@ func (scanPlanBinaryUUIDToUUIDScanner) Scan(ci *ConnInfo, oid uint32, formatCode type scanPlanTextAnyToUUIDScanner struct{} -func (scanPlanTextAnyToUUIDScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToUUIDScanner) Scan(src []byte, dst interface{}) error { scanner := (dst).(UUIDScanner) if src == nil { diff --git a/rows.go b/rows.go index 5a4bc9a9..06e0a933 100644 --- a/rows.go +++ b/rows.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "reflect" "time" "github.com/jackc/pgproto3/v2" @@ -111,6 +112,7 @@ type connRows struct { multiResultReader *pgconn.MultiResultReader scanPlans []pgtype.ScanPlan + scanTypes []reflect.Type } func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription { @@ -208,8 +210,10 @@ func (rows *connRows) Scan(dest ...interface{}) error { if rows.scanPlans == nil { rows.scanPlans = make([]pgtype.ScanPlan, len(values)) + rows.scanTypes = make([]reflect.Type, len(values)) for i := range dest { rows.scanPlans[i] = ci.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) + rows.scanTypes[i] = reflect.TypeOf(dest[i]) } } @@ -218,7 +222,12 @@ func (rows *connRows) Scan(dest ...interface{}) error { continue } - err := rows.scanPlans[i].Scan(ci, fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], dst) + if rows.scanTypes[i] != reflect.TypeOf(dst) { + rows.scanPlans[i] = ci.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) + rows.scanTypes[i] = reflect.TypeOf(dest[i]) + } + + err := rows.scanPlans[i].Scan(values[i], dst) if err != nil { err = ScanArgError{ColumnIndex: i, Err: err} rows.fatal(err) diff --git a/stdlib/sql.go b/stdlib/sql.go index 40693ded..5e2c3a03 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -605,21 +605,21 @@ func (r *Rows) Next(dest []driver.Value) error { var d bool scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return d, err } case pgtype.ByteaOID: var d []byte scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return d, err } case pgtype.CIDOID, pgtype.OIDOID, pgtype.XIDOID: var d pgtype.Uint32 scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) if err != nil { return nil, err } @@ -629,7 +629,7 @@ func (r *Rows) Next(dest []driver.Value) error { var d pgtype.Date scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) if err != nil { return nil, err } @@ -639,42 +639,42 @@ func (r *Rows) Next(dest []driver.Value) error { var d float32 scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return float64(d), err } case pgtype.Float8OID: var d float64 scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return d, err } case pgtype.Int2OID: var d int16 scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return int64(d), err } case pgtype.Int4OID: var d int32 scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return int64(d), err } case pgtype.Int8OID: var d int64 scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return d, err } case pgtype.JSONOID, pgtype.JSONBOID: var d []byte scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) if err != nil { return nil, err } @@ -684,7 +684,7 @@ func (r *Rows) Next(dest []driver.Value) error { var d pgtype.Timestamp scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) if err != nil { return nil, err } @@ -694,7 +694,7 @@ func (r *Rows) Next(dest []driver.Value) error { var d pgtype.Timestamptz scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) if err != nil { return nil, err } @@ -704,7 +704,7 @@ func (r *Rows) Next(dest []driver.Value) error { var d string scanPlan := ci.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + err := scanPlan.Scan(src, &d) return d, err } } From 0ddf9e3b4b97c566c0c6ce23a622bc49bd03b963 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 18:40:46 -0600 Subject: [PATCH 0875/1158] Try wrapping scan target before sql.Scanner This allows wrappers to directly avoid the slow sql.Scanner interface. --- pgtype/pgtype.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 150f1a23..cba1bb2f 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -906,14 +906,6 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) S if plan := dt.Codec.PlanScan(ci, oid, formatCode, target, false); plan != nil { return plan } - - if _, ok := target.(*interface{}); ok { - return &pointerEmptyInterfaceScanPlan{codec: dt.Codec, ci: ci, oid: oid, formatCode: formatCode} - } - - if _, ok := target.(sql.Scanner); ok { - return &scanPlanCodecSQLScanner{c: dt.Codec, ci: ci, oid: oid, formatCode: formatCode} - } } for _, f := range ci.TryWrapScanPlanFuncs { @@ -927,6 +919,16 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) S } } + if dt != nil { + if _, ok := target.(*interface{}); ok { + return &pointerEmptyInterfaceScanPlan{codec: dt.Codec, ci: ci, oid: oid, formatCode: formatCode} + } + + if _, ok := target.(sql.Scanner); ok { + return &scanPlanCodecSQLScanner{c: dt.Codec, ci: ci, oid: oid, formatCode: formatCode} + } + } + if _, ok := target.(sql.Scanner); ok { return &scanPlanSQLScanner{formatCode: formatCode} } From f5806bc01c49ab9ed8b239419be48df54a08eaa0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 24 Jan 2022 08:10:01 -0600 Subject: [PATCH 0876/1158] Add a fuzz test Investigating https://github.com/jackc/pgx/issues/938. --- chunkreader_test.go | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/chunkreader_test.go b/chunkreader_test.go index 67a20af2..ddc2fbf6 100644 --- a/chunkreader_test.go +++ b/chunkreader_test.go @@ -2,6 +2,7 @@ package chunkreader import ( "bytes" + "math/rand" "testing" ) @@ -94,3 +95,34 @@ func TestChunkReaderDoesNotReuseBuf(t *testing.T) { t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1) } } + +type randomReader struct { + rnd *rand.Rand +} + +// Read reads a random number of random bytes. +func (r *randomReader) Read(p []byte) (n int, err error) { + n = r.rnd.Intn(len(p) + 1) + return r.rnd.Read(p[:n]) +} + +func TestChunkReaderNextFuzz(t *testing.T) { + rr := &randomReader{rnd: rand.New(rand.NewSource(1))} + r, err := NewConfig(rr, Config{MinBufLen: 8192}) + if err != nil { + t.Fatal(err) + } + + randomSizes := rand.New(rand.NewSource(0)) + + for i := 0; i < 100000; i++ { + size := randomSizes.Intn(16384) + 1 + buf, err := r.Next(size) + if err != nil { + t.Fatal(err) + } + if len(buf) != size { + t.Fatalf("Expected to get %v bytes but got %v bytes", size, len(buf)) + } + } +} From 551d26ca41b018c38b25a6f268791823b2d8da96 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 25 Jan 2022 20:19:02 -0600 Subject: [PATCH 0877/1158] Change ArrayHeader.ElementOID to uint32 --- pgtype/array.go | 6 +++--- pgtype/array_codec.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pgtype/array.go b/pgtype/array.go index 29d6f803..54e85f37 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -19,7 +19,7 @@ import ( type ArrayHeader struct { ContainsNull bool - ElementOID int32 + ElementOID uint32 Dimensions []ArrayDimension } @@ -55,7 +55,7 @@ func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 rp += 4 - dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:])) + dst.ElementOID = binary.BigEndian.Uint32(src[rp:]) rp += 4 if numDims > 0 { @@ -84,7 +84,7 @@ func (src ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { } buf = pgio.AppendInt32(buf, containsNull) - buf = pgio.AppendInt32(buf, src.ElementOID) + buf = pgio.AppendUint32(buf, src.ElementOID) for i := range src.Dimensions { buf = pgio.AppendInt32(buf, src.Dimensions[i].Length) diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 4cc7e84c..922f6d26 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -155,7 +155,7 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB arrayHeader := ArrayHeader{ Dimensions: dimensions, - ElementOID: int32(p.ac.ElementOID), + ElementOID: p.ac.ElementOID, } containsNullIndex := len(buf) + 4 From 47345e0d1ef511304af30dd6beac4c673b491dae Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 25 Jan 2022 20:21:28 -0600 Subject: [PATCH 0878/1158] ArrayHeader.EncodeBinary doesn't need ci parameter --- pgtype/array.go | 2 +- pgtype/array_codec.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pgtype/array.go b/pgtype/array.go index 54e85f37..0e8e31a0 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -75,7 +75,7 @@ func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { return rp, nil } -func (src ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { +func (src ArrayHeader) EncodeBinary(buf []byte) []byte { buf = pgio.AppendInt32(buf, int32(len(src.Dimensions))) var containsNull int32 diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 922f6d26..54e1bf90 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -160,7 +160,7 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB containsNullIndex := len(buf) + 4 - buf = arrayHeader.EncodeBinary(p.ci, buf) + buf = arrayHeader.EncodeBinary(buf) elementCount := cardinality(dimensions) From dc77e7c2da14312ae3362862f17201fee93f81bc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Jan 2022 08:17:50 -0600 Subject: [PATCH 0879/1158] Add QueryRow warning to DriverBytes --- pgtype/bytea.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 501e0c59..58f3b348 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -16,8 +16,9 @@ type BytesValuer interface { BytesValue() ([]byte, error) } -// DriverBytes is a byte slice that holds a reference to memory owned by the driver. It is only valid until the next -// database method call. e.g. Any call to a Rows or Conn method invalidates the slice. +// DriverBytes is a byte slice that holds a reference to memory owned by the driver. It is only valid from the time it +// is scanned until Rows.Next or Rows.Close is called. It is safe to use in a function passed to QueryFunc. It is never +// safe to use DriverBytes with QueryRow as Row.Scan internally calls Rows.Close before returning. type DriverBytes []byte func (b *DriverBytes) ScanBytes(v []byte) error { From f5c3eeb813aa682f1a8ae95582ca86b876e0e563 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Jan 2022 15:43:18 -0600 Subject: [PATCH 0880/1158] Initial rebuilt composite support --- pgtype/composite.go | 551 +++++++++++++++++++++++++++++++++++++++ pgtype/composite_test.go | 76 ++++++ 2 files changed, 627 insertions(+) create mode 100644 pgtype/composite.go create mode 100644 pgtype/composite_test.go diff --git a/pgtype/composite.go b/pgtype/composite.go new file mode 100644 index 00000000..d21ab665 --- /dev/null +++ b/pgtype/composite.go @@ -0,0 +1,551 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "errors" + "fmt" + "strings" + + "github.com/jackc/pgio" +) + +// CompositeIndexGetter is a type accessed by index that can be converted into a PostgreSQL composite. +type CompositeIndexGetter interface { + // IsNull returns true if the value is SQL NULL. + IsNull() bool + + // Index returns the element at i. + Index(i int) interface{} +} + +// CompositeIndexScanner is a type accessed by index that can be scanned from a PostgreSQL composite. +type CompositeIndexScanner interface { + // ScanNull sets the value to SQL NULL. + ScanNull() error + + // ScanIndex returns a value usable as a scan target for i. + ScanIndex(i int) interface{} +} + +type CompositeCodecField struct { + Name string + DataType *DataType +} + +type CompositeCodec struct { + Fields []CompositeCodecField +} + +func (c *CompositeCodec) FormatSupported(format int16) bool { + for _, f := range c.Fields { + if !f.DataType.Codec.FormatSupported(format) { + return false + } + } + + return true +} + +func (c *CompositeCodec) PreferredFormat() int16 { + if c.FormatSupported(BinaryFormatCode) { + return BinaryFormatCode + } + return TextFormatCode +} + +func (c *CompositeCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(CompositeIndexGetter); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return &encodePlanCompositeCodecCompositeIndexGetterToBinary{cc: c, ci: ci} + case TextFormatCode: + return &encodePlanCompositeCodecCompositeIndexGetterToText{cc: c, ci: ci} + } + + return nil +} + +type encodePlanCompositeCodecCompositeIndexGetterToBinary struct { + cc *CompositeCodec + ci *ConnInfo +} + +func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + getter := value.(CompositeIndexGetter) + + if getter.IsNull() { + return nil, nil + } + + builder := NewCompositeBinaryBuilder(plan.ci, buf) + for i, field := range plan.cc.Fields { + builder.AppendValue(field.DataType.OID, getter.Index(i)) + } + + return builder.Finish() +} + +type encodePlanCompositeCodecCompositeIndexGetterToText struct { + cc *CompositeCodec + ci *ConnInfo +} + +func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + getter := value.(CompositeIndexGetter) + + if getter.IsNull() { + return nil, nil + } + + b := NewCompositeTextBuilder(plan.ci, buf) + for i, field := range plan.cc.Fields { + b.AppendValue(field.DataType.OID, getter.Index(i)) + } + + return b.Finish() +} + +func (c *CompositeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case CompositeIndexScanner: + return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, ci: ci} + } + case TextFormatCode: + switch target.(type) { + case CompositeIndexScanner: + return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, ci: ci} + } + } + + return nil +} + +type scanPlanBinaryCompositeToCompositeIndexScanner struct { + cc *CompositeCodec + ci *ConnInfo +} + +func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target interface{}) error { + targetScanner := (target).(CompositeIndexScanner) + + if src == nil { + return targetScanner.ScanNull() + } + + scanner := NewCompositeBinaryScanner(plan.ci, src) + for i, field := range plan.cc.Fields { + if scanner.Next() { + fieldTarget := targetScanner.ScanIndex(i) + if fieldTarget != nil { + fieldPlan := plan.ci.PlanScan(field.DataType.OID, BinaryFormatCode, fieldTarget) + if fieldPlan == nil { + return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.DataType.OID) + } + + err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) + if err != nil { + return err + } + } + } else { + return errors.New("read past end of composite") + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + +type scanPlanTextCompositeToCompositeIndexScanner struct { + cc *CompositeCodec + ci *ConnInfo +} + +func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target interface{}) error { + targetScanner := (target).(CompositeIndexScanner) + + if src == nil { + return targetScanner.ScanNull() + } + + scanner := NewCompositeTextScanner(plan.ci, src) + for i, field := range plan.cc.Fields { + if scanner.Next() { + fieldTarget := targetScanner.ScanIndex(i) + if fieldTarget != nil { + fieldPlan := plan.ci.PlanScan(field.DataType.OID, TextFormatCode, fieldTarget) + if fieldPlan == nil { + return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.DataType.OID) + } + + err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) + if err != nil { + return err + } + } + } else { + return errors.New("read past end of composite") + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + +func (c *CompositeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + // var n int64 + // err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) + // return n, err + + return nil, fmt.Errorf("not implemented") +} + +func (c *CompositeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + // var n int16 + // err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) + // return n, err + + return nil, fmt.Errorf("not implemented") +} + +type CompositeBinaryScanner struct { + ci *ConnInfo + rp int + src []byte + + fieldCount int32 + fieldBytes []byte + fieldOID uint32 + err error +} + +// NewCompositeBinaryScanner a scanner over a binary encoded composite balue. +func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner { + rp := 0 + if len(src[rp:]) < 4 { + return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)} + } + + fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + return &CompositeBinaryScanner{ + ci: ci, + rp: rp, + src: src, + fieldCount: fieldCount, + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeBinaryScanner) Next() bool { + if cfs.err != nil { + return false + } + + if cfs.rp == len(cfs.src) { + return false + } + + if len(cfs.src[cfs.rp:]) < 8 { + cfs.err = fmt.Errorf("Record incomplete %v", cfs.src) + return false + } + cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:]) + cfs.rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:]))) + cfs.rp += 4 + + if fieldLen >= 0 { + if len(cfs.src[cfs.rp:]) < fieldLen { + cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src) + return false + } + cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen] + cfs.rp += fieldLen + } else { + cfs.fieldBytes = nil + } + + return true +} + +func (cfs *CompositeBinaryScanner) FieldCount() int { + return int(cfs.fieldCount) +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// OID returns the OID of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) OID() uint32 { + return cfs.fieldOID +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeBinaryScanner) Err() error { + return cfs.err +} + +type CompositeTextScanner struct { + ci *ConnInfo + rp int + src []byte + + fieldBytes []byte + err error +} + +// NewCompositeTextScanner a scanner over a text encoded composite value. +func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner { + if len(src) < 2 { + return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)} + } + + if src[0] != '(' { + return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")} + } + + if src[len(src)-1] != ')' { + return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")} + } + + return &CompositeTextScanner{ + ci: ci, + rp: 1, + src: src, + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeTextScanner) Next() bool { + if cfs.err != nil { + return false + } + + if cfs.rp == len(cfs.src) { + return false + } + + switch cfs.src[cfs.rp] { + case ',', ')': // null + cfs.rp++ + cfs.fieldBytes = nil + return true + case '"': // quoted value + cfs.rp++ + cfs.fieldBytes = make([]byte, 0, 16) + for { + ch := cfs.src[cfs.rp] + + if ch == '"' { + cfs.rp++ + if cfs.src[cfs.rp] == '"' { + cfs.fieldBytes = append(cfs.fieldBytes, '"') + cfs.rp++ + } else { + break + } + } else if ch == '\\' { + cfs.rp++ + cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp]) + cfs.rp++ + } else { + cfs.fieldBytes = append(cfs.fieldBytes, ch) + cfs.rp++ + } + } + cfs.rp++ + return true + default: // unquoted value + start := cfs.rp + for { + ch := cfs.src[cfs.rp] + if ch == ',' || ch == ')' { + break + } + cfs.rp++ + } + cfs.fieldBytes = cfs.src[start:cfs.rp] + cfs.rp++ + return true + } +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeTextScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeTextScanner) Err() error { + return cfs.err +} + +type CompositeBinaryBuilder struct { + ci *ConnInfo + buf []byte + startIdx int + fieldCount uint32 + err error +} + +func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder { + startIdx := len(buf) + buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields + return &CompositeBinaryBuilder{ci: ci, buf: buf, startIdx: startIdx} +} + +func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { + if b.err != nil { + return + } + + if field == nil { + b.buf = pgio.AppendUint32(b.buf, oid) + b.buf = pgio.AppendInt32(b.buf, -1) + b.fieldCount++ + return + } + + plan := b.ci.PlanEncode(oid, BinaryFormatCode, field) + if plan == nil { + b.err = fmt.Errorf("unable to encode %v into OID %d in binary format", field, oid) + return + } + + b.buf = pgio.AppendUint32(b.buf, oid) + lengthPos := len(b.buf) + b.buf = pgio.AppendInt32(b.buf, -1) + fieldBuf, err := plan.Encode(field, b.buf) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf))) + b.buf = fieldBuf + } + + b.fieldCount++ +} + +func (b *CompositeBinaryBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount) + return b.buf, nil +} + +type CompositeTextBuilder struct { + ci *ConnInfo + buf []byte + startIdx int + fieldCount uint32 + err error + fieldBuf [32]byte +} + +func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder { + buf = append(buf, '(') // allocate room for number of fields + return &CompositeTextBuilder{ci: ci, buf: buf} +} + +func (b *CompositeTextBuilder) AppendValue(oid uint32, field interface{}) { + if b.err != nil { + return + } + + if field == nil { + b.buf = append(b.buf, ',') + return + } + + plan := b.ci.PlanEncode(oid, TextFormatCode, field) + if plan == nil { + b.err = fmt.Errorf("unable to encode %v into OID %d in text format", field, oid) + return + } + + fieldBuf, err := plan.Encode(field, b.fieldBuf[0:0]) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...) + } + + b.buf = append(b.buf, ',') +} + +func (b *CompositeTextBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + b.buf[len(b.buf)-1] = ')' + return b.buf, nil +} + +var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteCompositeField(src string) string { + return `"` + quoteCompositeReplacer.Replace(src) + `"` +} + +func quoteCompositeFieldIfNeeded(src string) string { + if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) { + return quoteCompositeField(src) + } + return src +} + +// CompositeFields represents the values of a composite value. It can be used as an encoding source or as a scan target. +// It cannot scan a NULL, but the composite fields can be NULL. +type CompositeFields []interface{} + +func (cf CompositeFields) SkipUnderlyingTypePlan() {} + +func (cf CompositeFields) IsNull() bool { + return cf == nil +} + +func (cf CompositeFields) Index(i int) interface{} { + return cf[i] +} + +func (cf CompositeFields) ScanNull() error { + return fmt.Errorf("cannot scan NULL into CompositeFields") +} + +func (cf CompositeFields) ScanIndex(i int) interface{} { + return cf[i] +} diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go new file mode 100644 index 00000000..ba91de80 --- /dev/null +++ b/pgtype/composite_test.go @@ -0,0 +1,76 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestCompositeCodecTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists ct_test; + +create type ct_test as ( + a text, + b int4 +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type ct_test") + + var oid uint32 + err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid) + require.NoError(t, err) + + defer conn.Exec(context.Background(), "drop type ct_test") + + textDataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.TextOID) + require.True(t, ok) + + int4DataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.Int4OID) + require.True(t, ok) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Name: "ct_test", + OID: oid, + Codec: &pgtype.CompositeCodec{ + Fields: []pgtype.CompositeCodecField{ + { + Name: "a", + DataType: textDataType, + }, + { + Name: "b", + DataType: int4DataType, + }, + }, + }, + }) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + var a string + var b int32 + + err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QueryResultFormats{format.code}, + pgtype.CompositeFields{"hi", int32(42)}, + ).Scan( + pgtype.CompositeFields{&a, &b}, + ) + require.NoErrorf(t, err, "%v", format.name) + require.EqualValuesf(t, "hi", a, "%v", format.name) + require.EqualValuesf(t, 42, b, "%v", format.name) + } +} From b5bf9d7bb9fc3d84e14e91778e30009fce72809f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Jan 2022 16:32:05 -0600 Subject: [PATCH 0881/1158] Move LoadDataType to pgx.Conn --- conn.go | 90 ++++++++++++++++++++++++++++++ pgtype/composite_test.go | 103 +++++++++++++++++++++++++--------- pgtype/pgxtype/README.md | 3 - pgtype/pgxtype/pgxtype.go | 114 -------------------------------------- 4 files changed, 166 insertions(+), 144 deletions(-) delete mode 100644 pgtype/pgxtype/README.md delete mode 100644 pgtype/pgxtype/pgxtype.go diff --git a/conn.go b/conn.go index 4412e174..11f275a6 100644 --- a/conn.go +++ b/conn.go @@ -862,3 +862,93 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, return sanitize.SanitizeSQL(sql, valueArgs...) } + +// LoadDataType inspects the database for typeName and produces a pgtype.DataType suitable for +// registration. +func (c *Conn) LoadDataType(ctx context.Context, typeName string) (*pgtype.DataType, error) { + var oid uint32 + + err := c.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid) + if err != nil { + return nil, err + } + + var typtype string + + err = c.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype) + if err != nil { + return nil, err + } + + switch typtype { + case "b": // array + elementOID, err := c.getArrayElementOID(ctx, oid) + if err != nil { + return nil, err + } + + var elementCodec pgtype.Codec + if dt, ok := c.ConnInfo().DataTypeForOID(elementOID); ok { + if dt.Codec == nil { + return nil, errors.New("array element OID not registered with Codec") + } + elementCodec = dt.Codec + } + + return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementOID: elementOID, ElementCodec: elementCodec}}, nil + case "c": // composite + fields, err := c.getCompositeFields(ctx, oid) + if err != nil { + return nil, err + } + + return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil + case "e": // enum + return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil + default: + return &pgtype.DataType{}, errors.New("unknown typtype") + } +} + +func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, error) { + var typelem uint32 + + err := c.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem) + if err != nil { + return 0, err + } + + return typelem, nil +} + +func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) { + var typrelid uint32 + + err := c.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid) + if err != nil { + return nil, err + } + + var fields []pgtype.CompositeCodecField + var fieldName string + var fieldOID uint32 + _, err = c.QueryFunc(ctx, `select attname, atttypid +from pg_attribute +where attrelid=$1 +order by attnum`, + []interface{}{typrelid}, + []interface{}{&fieldName, &fieldOID}, + func(qfr QueryFuncRow) error { + dt, ok := c.ConnInfo().DataTypeForOID(fieldOID) + if !ok { + return fmt.Errorf("unknown composite type field OID: %v", fieldOID) + } + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, DataType: dt}) + return nil + }) + if err != nil { + return nil, err + } + + return fields, nil +} diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index ba91de80..c9319c2d 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -2,6 +2,7 @@ package pgtype_test import ( "context" + "fmt" "testing" pgx "github.com/jackc/pgx/v5" @@ -23,34 +24,9 @@ create type ct_test as ( require.NoError(t, err) defer conn.Exec(context.Background(), "drop type ct_test") - var oid uint32 - err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid) + dt, err := conn.LoadDataType(context.Background(), "ct_test") require.NoError(t, err) - - defer conn.Exec(context.Background(), "drop type ct_test") - - textDataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.TextOID) - require.True(t, ok) - - int4DataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.Int4OID) - require.True(t, ok) - - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Name: "ct_test", - OID: oid, - Codec: &pgtype.CompositeCodec{ - Fields: []pgtype.CompositeCodecField{ - { - Name: "a", - DataType: textDataType, - }, - { - Name: "b", - DataType: int4DataType, - }, - }, - }, - }) + conn.ConnInfo().RegisterDataType(*dt) formats := []struct { name string @@ -74,3 +50,76 @@ create type ct_test as ( require.EqualValuesf(t, 42, b, "%v", format.name) } } + +type point3d struct { + X, Y, Z float64 +} + +func (p point3d) IsNull() bool { + return false +} + +func (p point3d) Index(i int) interface{} { + switch i { + case 0: + return p.X + case 1: + return p.Y + case 2: + return p.Z + default: + panic("invalid index") + } +} + +func (p *point3d) ScanNull() error { + return fmt.Errorf("cannot scan NULL into point3d") +} + +func (p *point3d) ScanIndex(i int) interface{} { + switch i { + case 0: + return &p.X + case 1: + return &p.Y + case 2: + return &p.Z + default: + panic("invalid index") + } +} + +func TestCompositeCodecTranscodeStruct(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists point3d; + +create type point3d as ( + x float8, + y float8, + z float8 +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type point3d") + + dt, err := conn.LoadDataType(context.Background(), "point3d") + require.NoError(t, err) + conn.ConnInfo().RegisterDataType(*dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + input := point3d{X: 1, Y: 2, Z: 3} + var output point3d + err := conn.QueryRow(context.Background(), "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output) + require.NoErrorf(t, err, "%v", format.name) + require.Equalf(t, input, output, "%v", format.name) + } +} diff --git a/pgtype/pgxtype/README.md b/pgtype/pgxtype/README.md deleted file mode 100644 index a070111f..00000000 --- a/pgtype/pgxtype/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# pgxtype - -pgxtype is a helper module that connects pgx and pgtype. This package is not currently covered by semantic version guarantees. i.e. The interfaces may change without a major version release of pgtype. diff --git a/pgtype/pgxtype/pgxtype.go b/pgtype/pgxtype/pgxtype.go deleted file mode 100644 index 6436f01b..00000000 --- a/pgtype/pgxtype/pgxtype.go +++ /dev/null @@ -1,114 +0,0 @@ -package pgxtype - -import ( - "context" - "errors" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgtype" -) - -type Querier interface { - Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) - Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) - QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row -} - -// LoadDataType uses conn to inspect the database for typeName and produces a pgtype.DataType suitable for -// registration on ci. -func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeName string) (pgtype.DataType, error) { - var oid uint32 - - err := conn.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid) - if err != nil { - return pgtype.DataType{}, err - } - - var typtype string - - err = conn.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype) - if err != nil { - return pgtype.DataType{}, err - } - - switch typtype { - case "b": // array - elementOID, err := GetArrayElementOID(ctx, conn, oid) - if err != nil { - return pgtype.DataType{}, err - } - - var elementCodec pgtype.Codec - if dt, ok := ci.DataTypeForOID(elementOID); ok { - if dt.Codec == nil { - return pgtype.DataType{}, errors.New("array element OID not registered with Codec") - } - elementCodec = dt.Codec - } - - return pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementOID: elementOID, ElementCodec: elementCodec}}, nil - case "c": // composite - panic("TODO - restore composite support") - // fields, err := GetCompositeFields(ctx, conn, oid) - // if err != nil { - // return pgtype.DataType{}, err - // } - // ct, err := pgtype.NewCompositeType(typeName, fields, ci) - // if err != nil { - // return pgtype.DataType{}, err - // } - // return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil - case "e": // enum - return pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil - default: - return pgtype.DataType{}, errors.New("unknown typtype") - } -} - -func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, error) { - var typelem uint32 - - err := conn.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem) - if err != nil { - return 0, err - } - - return typelem, nil -} - -// TODO - restore composite support -// GetCompositeFields gets the fields of a composite type. -// func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) { -// var typrelid uint32 - -// err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid) -// if err != nil { -// return nil, err -// } - -// var fields []pgtype.CompositeTypeField - -// rows, err := conn.Query(ctx, `select attname, atttypid -// from pg_attribute -// where attrelid=$1 -// order by attnum`, typrelid) -// if err != nil { -// return nil, err -// } - -// for rows.Next() { -// var f pgtype.CompositeTypeField -// err := rows.Scan(&f.Name, &f.OID) -// if err != nil { -// return nil, err -// } -// fields = append(fields, f) -// } - -// if rows.Err() != nil { -// return nil, rows.Err() -// } - -// return fields, nil -// } From 558748ef9c2a5411fa051919e40a0ede5b8f627c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Jan 2022 16:41:07 -0600 Subject: [PATCH 0882/1158] ArrayCodec contains element DataType --- conn.go | 11 ++--- pgtype/array_codec.go | 21 +++++----- pgtype/pgtype.go | 95 ++++++++++++++++++++++--------------------- 3 files changed, 62 insertions(+), 65 deletions(-) diff --git a/conn.go b/conn.go index 11f275a6..877589a7 100644 --- a/conn.go +++ b/conn.go @@ -887,15 +887,12 @@ func (c *Conn) LoadDataType(ctx context.Context, typeName string) (*pgtype.DataT return nil, err } - var elementCodec pgtype.Codec - if dt, ok := c.ConnInfo().DataTypeForOID(elementOID); ok { - if dt.Codec == nil { - return nil, errors.New("array element OID not registered with Codec") - } - elementCodec = dt.Codec + dt, ok := c.ConnInfo().DataTypeForOID(elementOID) + if !ok { + return nil, errors.New("array element OID not registered") } - return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementOID: elementOID, ElementCodec: elementCodec}}, nil + return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementDataType: dt}}, nil case "c": // composite fields, err := c.getCompositeFields(ctx, oid) if err != nil { diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 54e1bf90..5d847fae 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -31,16 +31,15 @@ type ArraySetter interface { // ArrayCodec is a codec for any array type. type ArrayCodec struct { - ElementCodec Codec - ElementOID uint32 + ElementDataType *DataType } func (c *ArrayCodec) FormatSupported(format int16) bool { - return c.ElementCodec.FormatSupported(format) + return c.ElementDataType.Codec.FormatSupported(format) } func (c *ArrayCodec) PreferredFormat() int16 { - return c.ElementCodec.PreferredFormat() + return c.ElementDataType.Codec.PreferredFormat() } func (c *ArrayCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { @@ -109,7 +108,7 @@ func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf elemType := reflect.TypeOf(elem) if lastElemType != elemType { lastElemType = elemType - encodePlan = p.ci.PlanEncode(p.ac.ElementOID, TextFormatCode, elem) + encodePlan = p.ci.PlanEncode(p.ac.ElementDataType.OID, TextFormatCode, elem) if encodePlan == nil { return nil, fmt.Errorf("unable to encode %v", array.Index(i)) } @@ -155,7 +154,7 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB arrayHeader := ArrayHeader{ Dimensions: dimensions, - ElementOID: p.ac.ElementOID, + ElementOID: p.ac.ElementDataType.OID, } containsNullIndex := len(buf) + 4 @@ -176,7 +175,7 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB elemType := reflect.TypeOf(elem) if lastElemType != elemType { lastElemType = elemType - encodePlan = p.ci.PlanEncode(p.ac.ElementOID, BinaryFormatCode, elem) + encodePlan = p.ci.PlanEncode(p.ac.ElementDataType.OID, BinaryFormatCode, elem) if encodePlan == nil { return nil, fmt.Errorf("unable to encode %v", array.Index(i)) } @@ -235,9 +234,9 @@ func (c *ArrayCodec) decodeBinary(ci *ConnInfo, arrayOID uint32, src []byte, arr return nil } - elementScanPlan := c.ElementCodec.PlanScan(ci, c.ElementOID, BinaryFormatCode, array.ScanIndex(0), false) + elementScanPlan := c.ElementDataType.Codec.PlanScan(ci, c.ElementDataType.OID, BinaryFormatCode, array.ScanIndex(0), false) if elementScanPlan == nil { - elementScanPlan = ci.PlanScan(c.ElementOID, BinaryFormatCode, array.ScanIndex(0)) + elementScanPlan = ci.PlanScan(c.ElementDataType.OID, BinaryFormatCode, array.ScanIndex(0)) } for i := 0; i < elementCount; i++ { @@ -279,9 +278,9 @@ func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array return nil } - elementScanPlan := c.ElementCodec.PlanScan(ci, c.ElementOID, TextFormatCode, array.ScanIndex(0), false) + elementScanPlan := c.ElementDataType.Codec.PlanScan(ci, c.ElementDataType.OID, TextFormatCode, array.ScanIndex(0), false) if elementScanPlan == nil { - elementScanPlan = ci.PlanScan(c.ElementOID, TextFormatCode, array.ScanIndex(0)) + elementScanPlan = ci.PlanScan(c.ElementDataType.OID, TextFormatCode, array.ScanIndex(0)) } for i, s := range uta.Elements { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index cba1bb2f..0c8f4763 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -207,44 +207,15 @@ func NewConnInfo() *ConnInfo { }, } - ci.RegisterDataType(DataType{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementCodec: &TextFormatOnlyCodec{TextCodec{}}, ElementOID: ACLItemOID}}) - ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementCodec: BoolCodec{}, ElementOID: BoolOID}}) - ci.RegisterDataType(DataType{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: BPCharOID}}) - ci.RegisterDataType(DataType{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementCodec: ByteaCodec{}, ElementOID: ByteaOID}}) - ci.RegisterDataType(DataType{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementCodec: InetCodec{}, ElementOID: CIDROID}}) - ci.RegisterDataType(DataType{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementCodec: DateCodec{}, ElementOID: DateOID}}) - ci.RegisterDataType(DataType{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementCodec: Float4Codec{}, ElementOID: Float4OID}}) - ci.RegisterDataType(DataType{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementCodec: Float8Codec{}, ElementOID: Float8OID}}) - ci.RegisterDataType(DataType{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementCodec: InetCodec{}, ElementOID: InetOID}}) - ci.RegisterDataType(DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}}) - ci.RegisterDataType(DataType{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementCodec: Int4Codec{}, ElementOID: Int4OID}}) - ci.RegisterDataType(DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementCodec: Int8Codec{}, ElementOID: Int8OID}}) - ci.RegisterDataType(DataType{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementCodec: IntervalCodec{}, ElementOID: IntervalOID}}) - ci.RegisterDataType(DataType{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementCodec: BoxCodec{}, ElementOID: BoxOID}}) - ci.RegisterDataType(DataType{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementCodec: LineCodec{}, ElementOID: LineOID}}) - ci.RegisterDataType(DataType{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementCodec: LsegCodec{}, ElementOID: LsegOID}}) - ci.RegisterDataType(DataType{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementCodec: PathCodec{}, ElementOID: PathOID}}) - ci.RegisterDataType(DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementCodec: CircleCodec{}, ElementOID: CircleOID}}) - ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementCodec: PointCodec{}, ElementOID: PointOID}}) - ci.RegisterDataType(DataType{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementCodec: PolygonCodec{}, ElementOID: PolygonOID}}) - ci.RegisterDataType(DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: NameOID}}) - ci.RegisterDataType(DataType{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementCodec: QCharCodec{}, ElementOID: QCharOID}}) - ci.RegisterDataType(DataType{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementCodec: NumericCodec{}, ElementOID: NumericOID}}) - ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: TextOID}}) - ci.RegisterDataType(DataType{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementCodec: TimestampCodec{}, ElementOID: TimestampOID}}) - ci.RegisterDataType(DataType{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementCodec: TimestamptzCodec{}, ElementOID: TimestamptzOID}}) - ci.RegisterDataType(DataType{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementCodec: MacaddrCodec{}, ElementOID: MacaddrOID}}) - ci.RegisterDataType(DataType{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementCodec: TIDCodec{}, ElementOID: TIDOID}}) - ci.RegisterDataType(DataType{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementCodec: UUIDCodec{}, ElementOID: UUIDOID}}) - ci.RegisterDataType(DataType{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementCodec: JSONBCodec{}, ElementOID: JSONBOID}}) - ci.RegisterDataType(DataType{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementCodec: JSONCodec{}, ElementOID: JSONOID}}) - ci.RegisterDataType(DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: VarcharOID}}) - ci.RegisterDataType(DataType{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementCodec: BitsCodec{}, ElementOID: BitOID}}) - ci.RegisterDataType(DataType{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementCodec: BitsCodec{}, ElementOID: VarbitOID}}) - ci.RegisterDataType(DataType{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementCodec: Uint32Codec{}, ElementOID: CIDOID}}) - ci.RegisterDataType(DataType{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementCodec: Uint32Codec{}, ElementOID: OIDOID}}) - ci.RegisterDataType(DataType{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementCodec: Uint32Codec{}, ElementOID: XIDOID}}) - ci.RegisterDataType(DataType{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementCodec: TimeCodec{}, ElementOID: TimeOID}}) + // ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) + // ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) + // ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) + // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) + // ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) + // ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) + // ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) + // ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) + // ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) ci.RegisterDataType(DataType{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) ci.RegisterDataType(DataType{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) @@ -256,15 +227,12 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) ci.RegisterDataType(DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) ci.RegisterDataType(DataType{Name: "date", OID: DateOID, Codec: DateCodec{}}) - // ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) ci.RegisterDataType(DataType{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) ci.RegisterDataType(DataType{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) ci.RegisterDataType(DataType{Name: "inet", OID: InetOID, Codec: InetCodec{}}) ci.RegisterDataType(DataType{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) ci.RegisterDataType(DataType{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) - // ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) ci.RegisterDataType(DataType{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) - // ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) ci.RegisterDataType(DataType{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) ci.RegisterDataType(DataType{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) ci.RegisterDataType(DataType{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) @@ -273,27 +241,60 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) ci.RegisterDataType(DataType{Name: "name", OID: NameOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) - // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) ci.RegisterDataType(DataType{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) ci.RegisterDataType(DataType{Name: "path", OID: PathOID, Codec: PathCodec{}}) ci.RegisterDataType(DataType{Name: "point", OID: PointOID, Codec: PointCodec{}}) ci.RegisterDataType(DataType{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) - // ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) ci.RegisterDataType(DataType{Name: "text", OID: TextOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) ci.RegisterDataType(DataType{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) ci.RegisterDataType(DataType{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) ci.RegisterDataType(DataType{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) - // ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) - // ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) - // ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) - // ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) ci.RegisterDataType(DataType{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) ci.RegisterDataType(DataType{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + ci.RegisterDataType(DataType{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[ACLItemOID]}}) + ci.RegisterDataType(DataType{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BitOID]}}) + ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BoolOID]}}) + ci.RegisterDataType(DataType{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BoxOID]}}) + ci.RegisterDataType(DataType{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BPCharOID]}}) + ci.RegisterDataType(DataType{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[ByteaOID]}}) + ci.RegisterDataType(DataType{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[QCharOID]}}) + ci.RegisterDataType(DataType{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CIDOID]}}) + ci.RegisterDataType(DataType{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CIDROID]}}) + ci.RegisterDataType(DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CircleOID]}}) + ci.RegisterDataType(DataType{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[DateOID]}}) + ci.RegisterDataType(DataType{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Float4OID]}}) + ci.RegisterDataType(DataType{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Float8OID]}}) + ci.RegisterDataType(DataType{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[InetOID]}}) + ci.RegisterDataType(DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int2OID]}}) + ci.RegisterDataType(DataType{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int4OID]}}) + ci.RegisterDataType(DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int8OID]}}) + ci.RegisterDataType(DataType{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[IntervalOID]}}) + ci.RegisterDataType(DataType{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[JSONOID]}}) + ci.RegisterDataType(DataType{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[JSONBOID]}}) + ci.RegisterDataType(DataType{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[LineOID]}}) + ci.RegisterDataType(DataType{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[LsegOID]}}) + ci.RegisterDataType(DataType{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[MacaddrOID]}}) + ci.RegisterDataType(DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NameOID]}}) + ci.RegisterDataType(DataType{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NumericOID]}}) + ci.RegisterDataType(DataType{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[OIDOID]}}) + ci.RegisterDataType(DataType{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PathOID]}}) + ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PointOID]}}) + ci.RegisterDataType(DataType{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PolygonOID]}}) + ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TextOID]}}) + ci.RegisterDataType(DataType{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TIDOID]}}) + ci.RegisterDataType(DataType{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimeOID]}}) + ci.RegisterDataType(DataType{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimestampOID]}}) + ci.RegisterDataType(DataType{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimestamptzOID]}}) + ci.RegisterDataType(DataType{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[UUIDOID]}}) + ci.RegisterDataType(DataType{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[VarbitOID]}}) + ci.RegisterDataType(DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[VarcharOID]}}) + ci.RegisterDataType(DataType{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[XIDOID]}}) + registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { // T ci.RegisterDefaultPgType(value, name) From ef7114a8ceec41043c2006dc7f4032962269b770 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Jan 2022 20:39:50 -0600 Subject: [PATCH 0883/1158] Add DecodeValue and DecodeDatabaseSQLValue for ArrayCodec --- pgtype/array_codec.go | 27 +++++++++++++++------------ pgtype/array_codec_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 5d847fae..901ea5f7 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -332,26 +332,29 @@ func (spac *scanPlanArrayCodec) Scan(src []byte, dst interface{}) error { } } -func (c ArrayCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *ArrayCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } - // var n int64 - // err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) - // return n, err - - return nil, fmt.Errorf("not implemented") + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } } -func (c ArrayCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c *ArrayCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } - // var n int16 - // err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) - // return n, err - - return nil, fmt.Errorf("not implemented") + var slice []interface{} + err := ci.PlanScan(oid, format, &slice).Scan(src, &slice) + return slice, err } diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index c358586e..0c31dcee 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -6,6 +6,7 @@ import ( "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestArrayCodec(t *testing.T) { @@ -70,3 +71,40 @@ func TestArrayCodecAnySlice(t *testing.T) { assert.Equalf(t, tt.expected, actual, "%d", i) } } + +func TestArrayCodecDecodeValue(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for _, tt := range []struct { + sql string + expected interface{} + }{ + { + sql: `select '{}'::int4[]`, + expected: []interface{}{}, + }, + { + sql: `select '{1,2}'::int8[]`, + expected: []interface{}{int64(1), int64(2)}, + }, + { + sql: `select '{foo,bar}'::text[]`, + expected: []interface{}{"foo", "bar"}, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(context.Background(), tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } +} From 11223497b3e7b4531bcd7cb827ad71f36ed4efcb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Jan 2022 20:42:12 -0600 Subject: [PATCH 0884/1158] Restore record support --- pgtype/pgtype.go | 4 +- pgtype/record_codec.go | 116 ++++++++++++++++++++++++++++++++++++ pgtype/record_codec_test.go | 72 ++++++++++++++++++++++ 3 files changed, 191 insertions(+), 1 deletion(-) create mode 100644 pgtype/record_codec.go create mode 100644 pgtype/record_codec_test.go diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 0c8f4763..ab317f6e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -86,6 +86,7 @@ const ( VarbitArrayOID = 1563 NumericOID = 1700 RecordOID = 2249 + RecordArrayOID = 2287 UUIDOID = 2950 UUIDArrayOID = 2951 JSONBOID = 3802 @@ -211,7 +212,6 @@ func NewConnInfo() *ConnInfo { // ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) // ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) - // ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) // ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) // ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) // ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) @@ -245,6 +245,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "path", OID: PathOID, Codec: PathCodec{}}) ci.RegisterDataType(DataType{Name: "point", OID: PointOID, Codec: PointCodec{}}) ci.RegisterDataType(DataType{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) + ci.RegisterDataType(DataType{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) ci.RegisterDataType(DataType{Name: "text", OID: TextOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) ci.RegisterDataType(DataType{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) @@ -285,6 +286,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PathOID]}}) ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PointOID]}}) ci.RegisterDataType(DataType{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PolygonOID]}}) + ci.RegisterDataType(DataType{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[RecordOID]}}) ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TextOID]}}) ci.RegisterDataType(DataType{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TIDOID]}}) ci.RegisterDataType(DataType{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimeOID]}}) diff --git a/pgtype/record_codec.go b/pgtype/record_codec.go new file mode 100644 index 00000000..31001b1f --- /dev/null +++ b/pgtype/record_codec.go @@ -0,0 +1,116 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +// ArrayGetter is a type that can be converted into a PostgreSQL array. + +// RecordCodec is a codec for the generic PostgreSQL record type such as is created with the "row" function. Record can +// only decode the binary format. The text format output format from PostgreSQL does not include type information and +// is therefore impossible to decode. Encoding is impossible because PostgreSQL does not support input of generic +// records. +type RecordCodec struct{} + +func (RecordCodec) FormatSupported(format int16) bool { + return format == BinaryFormatCode +} + +func (RecordCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (RecordCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + return nil +} + +func (RecordCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + if format == BinaryFormatCode { + switch target.(type) { + case CompositeIndexScanner: + return &scanPlanBinaryRecordToCompositeIndexScanner{ci: ci} + } + } + + return nil +} + +type scanPlanBinaryRecordToCompositeIndexScanner struct { + ci *ConnInfo +} + +func (plan *scanPlanBinaryRecordToCompositeIndexScanner) Scan(src []byte, target interface{}) error { + targetScanner := (target).(CompositeIndexScanner) + + if src == nil { + return targetScanner.ScanNull() + } + + scanner := NewCompositeBinaryScanner(plan.ci, src) + for i := 0; scanner.Next(); i++ { + fieldTarget := targetScanner.ScanIndex(i) + if fieldTarget != nil { + fieldPlan := plan.ci.PlanScan(scanner.OID(), BinaryFormatCode, fieldTarget) + if fieldPlan == nil { + return fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), fieldTarget) + } + + err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) + if err != nil { + return err + } + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + +func (RecordCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + return nil, fmt.Errorf("not implemented") +} + +func (RecordCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + scanner := NewCompositeBinaryScanner(ci, src) + values := make([]interface{}, scanner.FieldCount()) + for i := 0; scanner.Next(); i++ { + var v interface{} + fieldPlan := ci.PlanScan(scanner.OID(), BinaryFormatCode, &v) + if fieldPlan == nil { + return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v) + } + + err := fieldPlan.Scan(scanner.Bytes(), &v) + if err != nil { + return nil, err + } + + values[i] = v + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return values, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } + +} diff --git a/pgtype/record_codec_test.go b/pgtype/record_codec_test.go new file mode 100644 index 00000000..14018e9e --- /dev/null +++ b/pgtype/record_codec_test.go @@ -0,0 +1,72 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestRecordCodec(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + var a string + var b int32 + err := conn.QueryRow(context.Background(), `select row('foo'::text, 42::int4)`).Scan(pgtype.CompositeFields{&a, &b}) + require.NoError(t, err) + + require.Equal(t, "foo", a) + require.Equal(t, int32(42), b) +} + +func TestRecordCodecDecodeValue(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for _, tt := range []struct { + sql string + expected interface{} + }{ + { + sql: `select row()`, + expected: []interface{}{}, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: []interface{}{"foo", int32(42)}, + }, + { + sql: `select row(100.0::float4, 1.09::float4)`, + expected: []interface{}{float32(100), float32(1.09)}, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: []interface{}{"foo", []interface{}{int32(1), int32(2), nil, int32(4)}, int32(42)}, + }, + { + sql: `select row(null)`, + expected: []interface{}{nil}, + }, + { + sql: `select null::record`, + expected: nil, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(context.Background(), tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } +} From 94e10b98b1558e160816e82090e68dd9a7e8b66c Mon Sep 17 00:00:00 2001 From: Pinank Solanki Date: Wed, 2 Feb 2022 03:23:29 +0530 Subject: [PATCH 0885/1158] Fix typo in float8 --- float8.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/float8.go b/float8.go index 4d9e7116..6297ab5e 100644 --- a/float8.go +++ b/float8.go @@ -204,7 +204,7 @@ func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return fmt.Errorf("invalid length for float4: %v", len(src)) + return fmt.Errorf("invalid length for float8: %v", len(src)) } n := int64(binary.BigEndian.Uint64(src)) From cebe44ee85040d17ad915f265809d6cab6ae9564 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Feb 2022 08:40:42 -0600 Subject: [PATCH 0886/1158] Restore range support --- Rakefile | 11 +- pgtype/int_test.go.erb | 2 +- pgtype/pgtype.go | 18 +- pgtype/range.go | 277 +++++++++++++++++++++++++ pgtype/range_codec.go | 414 +++++++++++++++++++++++++++++++++++++ pgtype/range_codec_test.go | 72 +++++++ pgtype/range_test.go | 177 ++++++++++++++++ pgtype/range_types.go | 218 +++++++++++++++++++ pgtype/range_types.go.erb | 49 +++++ 9 files changed, 1228 insertions(+), 10 deletions(-) create mode 100644 pgtype/range.go create mode 100644 pgtype/range_codec.go create mode 100644 pgtype/range_codec_test.go create mode 100644 pgtype/range_test.go create mode 100644 pgtype/range_types.go create mode 100644 pgtype/range_types.go.erb diff --git a/Rakefile b/Rakefile index 4579034d..f3a61a09 100644 --- a/Rakefile +++ b/Rakefile @@ -6,5 +6,14 @@ rule '.go' => '.go.erb' do |task| sh "goimports", "-w", task.name end +generated_code_files = [ + "pgtype/int.go", + "pgtype/int_test.go", + "pgtype/integration_benchmark_test.go", + "pgtype/range_types.go", + "pgtype/zeronull/int.go", + "pgtype/zeronull/int_test.go" +] + desc "Generate code" -task generate: ["pgtype/int.go", "pgtype/int_test.go", "pgtype/integration_benchmark_test.go", "pgtype/zeronull/int.go", "pgtype/zeronull/int_test.go"] +task generate: generated_code_files diff --git a/pgtype/int_test.go.erb b/pgtype/int_test.go.erb index c98f6488..8858ce90 100644 --- a/pgtype/int_test.go.erb +++ b/pgtype/int_test.go.erb @@ -10,7 +10,7 @@ import ( <% [2, 4, 8].each do |pg_byte_size| %> <% pg_bit_size = pg_byte_size * 8 %> func TestInt<%= pg_byte_size %>Codec(t *testing.T) { - testPgxCodec(t, "int<%= pg_byte_size %>", []testutil.TranscodeTestCase{ + testutil.RunTranscodeTests(t, "int<%= pg_byte_size %>", []testutil.TranscodeTestCase{ {int8(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {int16(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {int32(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index ab317f6e..ec6d3ec9 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -208,14 +208,6 @@ func NewConnInfo() *ConnInfo { }, } - // ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) - // ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) - // ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) - // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) - // ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) - // ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) - // ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) - // ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) ci.RegisterDataType(DataType{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) ci.RegisterDataType(DataType{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) @@ -257,6 +249,16 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + ci.RegisterDataType(DataType{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[DateOID]}}) + ci.RegisterDataType(DataType{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[Int4OID]}}) + ci.RegisterDataType(DataType{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[Int8OID]}}) + ci.RegisterDataType(DataType{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[NumericOID]}}) + ci.RegisterDataType(DataType{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestampOID]}}) + ci.RegisterDataType(DataType{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestamptzOID]}}) + + // ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) + // ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) + ci.RegisterDataType(DataType{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[ACLItemOID]}}) ci.RegisterDataType(DataType{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BitOID]}}) ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BoolOID]}}) diff --git a/pgtype/range.go b/pgtype/range.go new file mode 100644 index 00000000..e999f6a9 --- /dev/null +++ b/pgtype/range.go @@ -0,0 +1,277 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +type BoundType byte + +const ( + Inclusive = BoundType('i') + Exclusive = BoundType('e') + Unbounded = BoundType('U') + Empty = BoundType('E') +) + +func (bt BoundType) String() string { + return string(bt) +} + +type UntypedTextRange struct { + Lower string + Upper string + LowerType BoundType + UpperType BoundType +} + +func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { + utr := &UntypedTextRange{} + if src == "empty" { + utr.LowerType = Empty + utr.UpperType = Empty + return utr, nil + } + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower bound: %v", err) + } + switch r { + case '(': + utr.LowerType = Exclusive + case '[': + utr.LowerType = Inclusive + default: + return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + buf.UnreadRune() + + if r == ',' { + utr.LowerType = Unbounded + } else { + utr.Lower, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing range separator: %v", err) + } + if r != ',' { + return nil, fmt.Errorf("missing range separator: %v", r) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + + if r == ')' || r == ']' { + utr.UpperType = Unbounded + } else { + buf.UnreadRune() + utr.Upper, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing upper bound: %v", err) + } + switch r { + case ')': + utr.UpperType = Exclusive + case ']': + utr.UpperType = Inclusive + default: + return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + return utr, nil +} + +func rangeParseValue(buf *bytes.Buffer) (string, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + if r == '"' { + return rangeParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case ',', '[', ']', '(', ')': + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + if r != '"' { + buf.UnreadRune() + return s.String(), nil + } + } + s.WriteRune(r) + } +} + +type UntypedBinaryRange struct { + Lower []byte + Upper []byte + LowerType BoundType + UpperType BoundType +} + +// 0 = () = 00000 +// 1 = empty = 00001 +// 2 = [) = 00010 +// 4 = (] = 00100 +// 6 = [] = 00110 +// 8 = ) = 01000 +// 12 = ] = 01100 +// 16 = ( = 10000 +// 18 = [ = 10010 +// 24 = = 11000 + +const emptyMask = 1 +const lowerInclusiveMask = 2 +const upperInclusiveMask = 4 +const lowerUnboundedMask = 8 +const upperUnboundedMask = 16 + +func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { + ubr := &UntypedBinaryRange{} + + if len(src) == 0 { + return nil, fmt.Errorf("range too short: %v", len(src)) + } + + rangeType := src[0] + rp := 1 + + if rangeType&emptyMask > 0 { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) + } + ubr.LowerType = Empty + ubr.UpperType = Empty + return ubr, nil + } + + if rangeType&lowerInclusiveMask > 0 { + ubr.LowerType = Inclusive + } else if rangeType&lowerUnboundedMask > 0 { + ubr.LowerType = Unbounded + } else { + ubr.LowerType = Exclusive + } + + if rangeType&upperInclusiveMask > 0 { + ubr.UpperType = Inclusive + } else if rangeType&upperUnboundedMask > 0 { + ubr.UpperType = Unbounded + } else { + ubr.UpperType = Exclusive + } + + if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) + } + return ubr, nil + } + + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + val := src[rp : rp+valueLen] + rp += valueLen + + if ubr.LowerType != Unbounded { + ubr.Lower = val + } else { + ubr.Upper = val + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + return ubr, nil + } + + if ubr.UpperType != Unbounded { + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + ubr.Upper = src[rp : rp+valueLen] + rp += valueLen + } + + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + + return ubr, nil + +} diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go new file mode 100644 index 00000000..0dc63e6c --- /dev/null +++ b/pgtype/range_codec.go @@ -0,0 +1,414 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +// RangeValuer is a type that can be converted into a PostgreSQL range. +type RangeValuer interface { + // IsNull returns true if the value is SQL NULL. + IsNull() bool + + // BoundTypes returns the lower and upper bound types. + BoundTypes() (lower, upper BoundType) + + // Bounds returns the lower and upper range values. + Bounds() (lower, upper interface{}) +} + +// RangeScanner is a type can be scanned from a PostgreSQL range. +type RangeScanner interface { + // ScanNull sets the value to SQL NULL. + ScanNull() error + + // ScanBounds returns values usable as a scan target. The returned values may not be scanned if the range is empty or + // the bound type is unbounded. + ScanBounds() (lowerTarget, upperTarget interface{}) + + // SetBoundTypes sets the lower and upper bound types. ScanBounds will be called and the returned values scanned + // (if appropriate) before SetBoundTypes is called. + SetBoundTypes(lower, upper BoundType) error +} + +type GenericRange struct { + Lower interface{} + Upper interface{} + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r GenericRange) IsNull() bool { + return !r.Valid +} + +func (r GenericRange) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r GenericRange) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *GenericRange) ScanNull() error { + *r = GenericRange{} + return nil +} + +func (r *GenericRange) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *GenericRange) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +// RangeCodec is a codec for any range type. +type RangeCodec struct { + ElementDataType *DataType +} + +func (c *RangeCodec) FormatSupported(format int16) bool { + return c.ElementDataType.Codec.FormatSupported(format) +} + +func (c *RangeCodec) PreferredFormat() int16 { + if c.FormatSupported(BinaryFormatCode) { + return BinaryFormatCode + } + return TextFormatCode +} + +func (c *RangeCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(RangeValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return &encodePlanRangeCodecRangeValuerToBinary{rc: c, ci: ci} + case TextFormatCode: + return &encodePlanRangeCodecRangeValuerToText{rc: c, ci: ci} + } + + return nil +} + +type encodePlanRangeCodecRangeValuerToBinary struct { + rc *RangeCodec + ci *ConnInfo +} + +func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + getter := value.(RangeValuer) + + if getter.IsNull() { + return nil, nil + } + + lowerType, upperType := getter.BoundTypes() + lower, upper := getter.Bounds() + + var rangeType byte + switch lowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", lowerType) + } + + switch upperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", upperType) + } + + buf = append(buf, rangeType) + + if lowerType != Unbounded { + if lower == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + lowerPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, BinaryFormatCode, lower) + if lowerPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", lower) + } + + buf, err = lowerPlan.Encode(lower, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err) + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if upperType != Unbounded { + if upper == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + upperPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, BinaryFormatCode, upper) + if upperPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", upper) + } + + buf, err = upperPlan.Encode(upper, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err) + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +type encodePlanRangeCodecRangeValuerToText struct { + rc *RangeCodec + ci *ConnInfo +} + +func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + getter := value.(RangeValuer) + + if getter.IsNull() { + return nil, nil + } + + lowerType, upperType := getter.BoundTypes() + lower, upper := getter.Bounds() + + switch lowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", lowerType) + } + + if lowerType != Unbounded { + if lower == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + lowerPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, TextFormatCode, lower) + if lowerPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", lower) + } + + buf, err = lowerPlan.Encode(lower, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err) + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if upperType != Unbounded { + if upper == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + upperPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, TextFormatCode, upper) + if upperPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", upper) + } + + buf, err = upperPlan.Encode(upper, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err) + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch upperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", upperType) + } + + return buf, nil +} + +func (c *RangeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case RangeScanner: + return &scanPlanBinaryRangeToRangeScanner{rc: c, ci: ci} + } + case TextFormatCode: + switch target.(type) { + case RangeScanner: + return &scanPlanTextRangeToRangeScanner{rc: c, ci: ci} + } + } + + return nil +} + +type scanPlanBinaryRangeToRangeScanner struct { + rc *RangeCodec + ci *ConnInfo +} + +func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target interface{}) error { + rangeScanner := (target).(RangeScanner) + + if src == nil { + return rangeScanner.ScanNull() + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + if ubr.LowerType == Empty { + return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType) + } + + lowerTarget, upperTarget := rangeScanner.ScanBounds() + + if ubr.LowerType == Inclusive || ubr.LowerType == Exclusive { + lowerPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, BinaryFormatCode, lowerTarget) + if lowerPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", lowerTarget) + } + + err = lowerPlan.Scan(ubr.Lower, lowerTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %v", lowerTarget, err) + } + } + + if ubr.UpperType == Inclusive || ubr.UpperType == Exclusive { + upperPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, BinaryFormatCode, upperTarget) + if upperPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", upperTarget) + } + + err = upperPlan.Scan(ubr.Upper, upperTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %v", upperTarget, err) + } + } + + return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType) +} + +type scanPlanTextRangeToRangeScanner struct { + rc *RangeCodec + ci *ConnInfo +} + +func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target interface{}) error { + rangeScanner := (target).(RangeScanner) + + if src == nil { + return rangeScanner.ScanNull() + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + if utr.LowerType == Empty { + return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType) + } + + lowerTarget, upperTarget := rangeScanner.ScanBounds() + + if utr.LowerType == Inclusive || utr.LowerType == Exclusive { + lowerPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, TextFormatCode, lowerTarget) + if lowerPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", lowerTarget) + } + + err = lowerPlan.Scan([]byte(utr.Lower), lowerTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %v", lowerTarget, err) + } + } + + if utr.UpperType == Inclusive || utr.UpperType == Exclusive { + upperPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, TextFormatCode, upperTarget) + if upperPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", upperTarget) + } + + err = upperPlan.Scan([]byte(utr.Upper), upperTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %v", upperTarget, err) + } + } + + return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType) +} + +func (c *RangeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +func (c *RangeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var r GenericRange + err := c.PlanScan(ci, oid, format, &r, true).Scan(src, &r) + return r, err +} diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go new file mode 100644 index 00000000..b4cc9e8e --- /dev/null +++ b/pgtype/range_codec_test.go @@ -0,0 +1,72 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestRangeCodecTranscode(t *testing.T) { + testutil.RunTranscodeTests(t, "int4range", []testutil.TranscodeTestCase{ + { + pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + new(pgtype.Int4range), + isExpectedEq(pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), + }, + { + pgtype.Int4range{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Int4{Int: 1, Valid: true}, + Upper: pgtype.Int4{Int: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }, + new(pgtype.Int4range), + isExpectedEq(pgtype.Int4range{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Int4{Int: 1, Valid: true}, + Upper: pgtype.Int4{Int: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }), + }, + {pgtype.Int4range{}, new(pgtype.Int4range), isExpectedEq(pgtype.Int4range{})}, + {nil, new(pgtype.Int4range), isExpectedEq(pgtype.Int4range{})}, + }) +} + +func TestRangeCodecDecodeValue(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for _, tt := range []struct { + sql string + expected interface{} + }{ + { + sql: `select '[1,5)'::int4range`, + expected: pgtype.GenericRange{ + Lower: int32(1), + Upper: int32(5), + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(context.Background(), tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } +} diff --git a/pgtype/range_test.go b/pgtype/range_test.go new file mode 100644 index 00000000..9e16df59 --- /dev/null +++ b/pgtype/range_test.go @@ -0,0 +1,177 @@ +package pgtype + +import ( + "bytes" + "testing" +) + +func TestParseUntypedTextRange(t *testing.T) { + tests := []struct { + src string + result UntypedTextRange + err error + }{ + { + src: `[1,2)`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[1,2]`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: `(1,3)`, + result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: ` [1,2) `, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[ foo , bar )`, + result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["foo","bar")`, + result: UntypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["","bar")`, + result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[f\"oo\,,b\\ar\))`, + result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `empty`, + result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedTextRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if r.Lower != tt.result.Lower { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if r.Upper != tt.result.Upper { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} + +func TestParseUntypedBinaryRange(t *testing.T) { + tests := []struct { + src []byte + result UntypedBinaryRange + err error + }{ + { + src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{1}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, + err: nil, + }, + { + src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{8, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{12, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{16, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{18, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{24}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedBinaryRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if bytes.Compare(r.Lower, tt.result.Lower) != 0 { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if bytes.Compare(r.Upper, tt.result.Upper) != 0 { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} diff --git a/pgtype/range_types.go b/pgtype/range_types.go new file mode 100644 index 00000000..3f1e7d8a --- /dev/null +++ b/pgtype/range_types.go @@ -0,0 +1,218 @@ +// Do not edit. Generated from pgtype/range_types.go.erb +package pgtype + +type Int4range struct { + Lower Int4 + Upper Int4 + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Int4range) IsNull() bool { + return !r.Valid +} + +func (r Int4range) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Int4range) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Int4range) ScanNull() error { + *r = Int4range{} + return nil +} + +func (r *Int4range) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Int4range) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +type Int8range struct { + Lower Int8 + Upper Int8 + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Int8range) IsNull() bool { + return !r.Valid +} + +func (r Int8range) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Int8range) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Int8range) ScanNull() error { + *r = Int8range{} + return nil +} + +func (r *Int8range) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Int8range) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +type Numrange struct { + Lower Numeric + Upper Numeric + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Numrange) IsNull() bool { + return !r.Valid +} + +func (r Numrange) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Numrange) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Numrange) ScanNull() error { + *r = Numrange{} + return nil +} + +func (r *Numrange) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Numrange) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +type Tsrange struct { + Lower Timestamp + Upper Timestamp + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Tsrange) IsNull() bool { + return !r.Valid +} + +func (r Tsrange) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Tsrange) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Tsrange) ScanNull() error { + *r = Tsrange{} + return nil +} + +func (r *Tsrange) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Tsrange) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +type Tstzrange struct { + Lower Timestamptz + Upper Timestamptz + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Tstzrange) IsNull() bool { + return !r.Valid +} + +func (r Tstzrange) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Tstzrange) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Tstzrange) ScanNull() error { + *r = Tstzrange{} + return nil +} + +func (r *Tstzrange) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Tstzrange) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +type Daterange struct { + Lower Date + Upper Date + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Daterange) IsNull() bool { + return !r.Valid +} + +func (r Daterange) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Daterange) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Daterange) ScanNull() error { + *r = Daterange{} + return nil +} + +func (r *Daterange) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Daterange) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} diff --git a/pgtype/range_types.go.erb b/pgtype/range_types.go.erb new file mode 100644 index 00000000..11b12822 --- /dev/null +++ b/pgtype/range_types.go.erb @@ -0,0 +1,49 @@ +package pgtype + +<% + [ + ["Int4range", "Int4"], + ["Int8range", "Int8"], + ["Numrange", "Numeric"], + ["Tsrange", "Timestamp"], + ["Tstzrange", "Timestamptz"], + ["Daterange", "Date"] + ].each do |range_type, element_type| +%> +type <%= range_type %> struct { + Lower <%= element_type %> + Upper <%= element_type %> + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r <%= range_type %>) IsNull() bool { + return !r.Valid +} + +func (r <%= range_type %>) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r <%= range_type %>) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *<%= range_type %>) ScanNull() error { + *r = <%= range_type %>{} + return nil +} + +func (r *<%= range_type %>) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *<%= range_type %>) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +<% end %> From ba4583cf4c64dd593d21fbe3d0e9ce9bc277f2a9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Feb 2022 08:47:56 -0600 Subject: [PATCH 0887/1158] Add range array types --- pgtype/pgtype.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index ec6d3ec9..571b5bfc 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -92,13 +92,17 @@ const ( JSONBOID = 3802 JSONBArrayOID = 3807 DaterangeOID = 3912 + DaterangeArrayOID = 3913 Int4rangeOID = 3904 + Int4rangeArrayOID = 3905 NumrangeOID = 3906 + NumrangeArrayOID = 3907 TsrangeOID = 3908 TsrangeArrayOID = 3909 TstzrangeOID = 3910 TstzrangeArrayOID = 3911 Int8rangeOID = 3926 + Int8rangeArrayOID = 3927 ) type InfinityModifier int8 @@ -256,9 +260,6 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestampOID]}}) ci.RegisterDataType(DataType{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestamptzOID]}}) - // ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) - // ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) - ci.RegisterDataType(DataType{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[ACLItemOID]}}) ci.RegisterDataType(DataType{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BitOID]}}) ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BoolOID]}}) @@ -270,12 +271,15 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CIDROID]}}) ci.RegisterDataType(DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CircleOID]}}) ci.RegisterDataType(DataType{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[DateOID]}}) + ci.RegisterDataType(DataType{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[DaterangeOID]}}) ci.RegisterDataType(DataType{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Float4OID]}}) ci.RegisterDataType(DataType{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Float8OID]}}) ci.RegisterDataType(DataType{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[InetOID]}}) ci.RegisterDataType(DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int2OID]}}) ci.RegisterDataType(DataType{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int4OID]}}) + ci.RegisterDataType(DataType{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int4rangeOID]}}) ci.RegisterDataType(DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int8OID]}}) + ci.RegisterDataType(DataType{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int8rangeOID]}}) ci.RegisterDataType(DataType{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[IntervalOID]}}) ci.RegisterDataType(DataType{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[JSONOID]}}) ci.RegisterDataType(DataType{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[JSONBOID]}}) @@ -284,6 +288,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[MacaddrOID]}}) ci.RegisterDataType(DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NameOID]}}) ci.RegisterDataType(DataType{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NumericOID]}}) + ci.RegisterDataType(DataType{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NumrangeOID]}}) ci.RegisterDataType(DataType{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[OIDOID]}}) ci.RegisterDataType(DataType{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PathOID]}}) ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PointOID]}}) @@ -294,6 +299,8 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimeOID]}}) ci.RegisterDataType(DataType{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimestampOID]}}) ci.RegisterDataType(DataType{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimestamptzOID]}}) + ci.RegisterDataType(DataType{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TsrangeOID]}}) + ci.RegisterDataType(DataType{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TstzrangeOID]}}) ci.RegisterDataType(DataType{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[UUIDOID]}}) ci.RegisterDataType(DataType{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[VarbitOID]}}) ci.RegisterDataType(DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[VarcharOID]}}) From a280f4db8a2646c33870897a93c5d2d2f7e66a42 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 3 Feb 2022 20:19:52 -0600 Subject: [PATCH 0888/1158] Float4 and Float8 implement Int64 Scanner and Valuer --- pgtype/float4.go | 9 +++++++++ pgtype/float8.go | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/pgtype/float4.go b/pgtype/float4.go index db9b2215..fd5f4523 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -25,6 +25,15 @@ func (f Float4) Float64Value() (Float8, error) { return Float8{Float: float64(f.Float), Valid: f.Valid}, nil } +func (f *Float4) ScanInt64(n Int8) error { + *f = Float4{Float: float32(n.Int), Valid: n.Valid} + return nil +} + +func (f Float4) Int64Value() (Int8, error) { + return Int8{Int: int64(f.Float), Valid: f.Valid}, nil +} + // Scan implements the database/sql Scanner interface. func (f *Float4) Scan(src interface{}) error { if src == nil { diff --git a/pgtype/float8.go b/pgtype/float8.go index 96dcb0f3..54b1796c 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -33,6 +33,15 @@ func (f Float8) Float64Value() (Float8, error) { return f, nil } +func (f *Float8) ScanInt64(n Int8) error { + *f = Float8{Float: float64(n.Int), Valid: n.Valid} + return nil +} + +func (f Float8) Int64Value() (Int8, error) { + return Int8{Int: int64(f.Float), Valid: f.Valid}, nil +} + // Scan implements the database/sql Scanner interface. func (f *Float8) Scan(src interface{}) error { if src == nil { From a74ebc9e51fe210504c63a13220390bbc8b1cef6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Feb 2022 08:39:53 -0600 Subject: [PATCH 0889/1158] pgtype.Numeric implements Float64Valuer --- pgtype/numeric.go | 55 ++++++++++++++++++++++++------------------ pgtype/numeric_test.go | 25 +++++++++++++++++++ 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 5bdbd4d5..d2311f3a 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -80,6 +80,35 @@ func (n Numeric) NumericValue() (Numeric, error) { return n, nil } +func (n Numeric) Float64Value() (Float8, error) { + if !n.Valid { + return Float8{}, nil + } else if n.NaN { + return Float8{Float: math.NaN(), Valid: true}, nil + } else if n.InfinityModifier == Infinity { + return Float8{Float: math.Inf(1), Valid: true}, nil + } else if n.InfinityModifier == NegativeInfinity { + return Float8{Float: math.Inf(-1), Valid: true}, nil + } + + buf := make([]byte, 0, 32) + + if n.Int == nil { + buf = append(buf, '0') + } else { + buf = append(buf, n.Int.String()...) + } + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...) + + f, err := strconv.ParseFloat(string(buf), 64) + if err != nil { + return Float8{}, err + } + + return Float8{Float: f, Valid: true}, nil +} + func (n *Numeric) toBigInt() (*big.Int, error) { if n.Exp == 0 { return n.Int, nil @@ -104,28 +133,6 @@ func (n *Numeric) toBigInt() (*big.Int, error) { return num, nil } -func (n *Numeric) toFloat64() (float64, error) { - if n.NaN { - return math.NaN(), nil - } else if n.InfinityModifier == Infinity { - return math.Inf(1), nil - } else if n.InfinityModifier == NegativeInfinity { - return math.Inf(-1), nil - } - - buf := make([]byte, 0, 32) - - buf = append(buf, n.Int.String()...) - buf = append(buf, 'e') - buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...) - - f, err := strconv.ParseFloat(string(buf), 64) - if err != nil { - return 0, err - } - return f, nil -} - func parseNumericString(str string) (n *big.Int, exp int32, err error) { parts := strings.SplitN(str, ".", 2) digits := strings.Join(parts, "") @@ -642,12 +649,12 @@ func (scanPlanBinaryNumericToFloat64Scanner) Scan(src []byte, dst interface{}) e return err } - f64, err := n.toFloat64() + f8, err := n.Float64Value() if err != nil { return err } - return scanner.ScanFloat64(Float8{Float: f64, Valid: true}) + return scanner.ScanFloat64(f8) } type scanPlanBinaryNumericToInt64Scanner struct{} diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index c74fb9a3..0449059e 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -11,6 +11,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -117,6 +118,30 @@ func TestNumericCodec(t *testing.T) { }) } +func TestNumericFloat64Valuer(t *testing.T) { + for i, tt := range []struct { + n pgtype.Numeric + f pgtype.Float8 + }{ + {mustParseNumeric(t, "1"), pgtype.Float8{Float: 1, Valid: true}}, + {mustParseNumeric(t, "0.0000000000000000001"), pgtype.Float8{Float: 0.0000000000000000001, Valid: true}}, + {mustParseNumeric(t, "-99999999999"), pgtype.Float8{Float: -99999999999, Valid: true}}, + {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, pgtype.Float8{Float: math.Inf(1), Valid: true}}, + {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, pgtype.Float8{Float: math.Inf(-1), Valid: true}}, + {pgtype.Numeric{Valid: true}, pgtype.Float8{Valid: true}}, + {pgtype.Numeric{}, pgtype.Float8{}}, + } { + f, err := tt.n.Float64Value() + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.f, f, "%d", i) + } + + f, err := pgtype.Numeric{NaN: true, Valid: true}.Float64Value() + assert.NoError(t, err) + assert.True(t, math.IsNaN(f.Float)) + assert.True(t, f.Valid) +} + func TestNumericCodecFuzz(t *testing.T) { r := rand.New(rand.NewSource(0)) max := &big.Int{} From 0355d2ffeae3bc0d5cb14ae63fbd407b790a9358 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Feb 2022 08:54:38 -0600 Subject: [PATCH 0890/1158] Add Float8range PostgreSQL doesn't define float8range out of the box though it can easily be created by the user. However, it is still convenient to treat a numrange as a float8range. --- pgtype/range_codec_test.go | 27 +++++++++++++++++++++++++++ pgtype/range_types.go | 36 ++++++++++++++++++++++++++++++++++++ pgtype/range_types.go.erb | 3 ++- 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index b4cc9e8e..30095065 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -36,6 +36,33 @@ func TestRangeCodecTranscode(t *testing.T) { }) } +func TestRangeCodecTranscodeCompatibleRangeElementTypes(t *testing.T) { + testutil.RunTranscodeTests(t, "numrange", []testutil.TranscodeTestCase{ + { + pgtype.Float8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + new(pgtype.Float8range), + isExpectedEq(pgtype.Float8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), + }, + { + pgtype.Float8range{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Float8{Float: 1, Valid: true}, + Upper: pgtype.Float8{Float: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }, + new(pgtype.Float8range), + isExpectedEq(pgtype.Float8range{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Float8{Float: 1, Valid: true}, + Upper: pgtype.Float8{Float: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }), + }, + {pgtype.Float8range{}, new(pgtype.Float8range), isExpectedEq(pgtype.Float8range{})}, + {nil, new(pgtype.Float8range), isExpectedEq(pgtype.Float8range{})}, + }) +} + func TestRangeCodecDecodeValue(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) diff --git a/pgtype/range_types.go b/pgtype/range_types.go index 3f1e7d8a..aa979d56 100644 --- a/pgtype/range_types.go +++ b/pgtype/range_types.go @@ -216,3 +216,39 @@ func (r *Daterange) SetBoundTypes(lower, upper BoundType) error { r.Valid = true return nil } + +type Float8range struct { + Lower Float8 + Upper Float8 + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Float8range) IsNull() bool { + return !r.Valid +} + +func (r Float8range) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Float8range) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Float8range) ScanNull() error { + *r = Float8range{} + return nil +} + +func (r *Float8range) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Float8range) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} diff --git a/pgtype/range_types.go.erb b/pgtype/range_types.go.erb index 11b12822..dc796a1d 100644 --- a/pgtype/range_types.go.erb +++ b/pgtype/range_types.go.erb @@ -7,7 +7,8 @@ package pgtype ["Numrange", "Numeric"], ["Tsrange", "Timestamp"], ["Tstzrange", "Timestamptz"], - ["Daterange", "Date"] + ["Daterange", "Date"], + ["Float8range", "Float8"] ].each do |range_type, element_type| %> type <%= range_type %> struct { From 288080c58c5aa31ad978e8a5840a25c28c02526a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Feb 2022 09:34:39 -0600 Subject: [PATCH 0891/1158] Add test documenting typed nil json encoding Encoded into json null not SQL NULL. --- pgtype/json_test.go | 2 ++ pgtype/jsonb_test.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/pgtype/json_test.go b/pgtype/json_test.go index a255c45a..a1dd63fb 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -48,6 +48,8 @@ func TestJSONCodec(t *testing.T) { {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, + {map[string]interface{}(nil), new(string), isExpectedEq(`null`)}, + {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte("null"))}, {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, }) diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 981ec28e..fa5ea20e 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -21,6 +21,8 @@ func TestJSONBTranscode(t *testing.T) { {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, + {map[string]interface{}(nil), new(string), isExpectedEq(`null`)}, + {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte("null"))}, {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, }) From 6ebf54b62be12e5139f7995c8163cfdcb38fbcd6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Feb 2022 09:57:40 -0600 Subject: [PATCH 0892/1158] Fix EnumCodec caching and add tests --- pgtype/enum_codec.go | 7 ++-- pgtype/enum_codec_test.go | 69 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 3 deletions(-) create mode 100644 pgtype/enum_codec_test.go diff --git a/pgtype/enum_codec.go b/pgtype/enum_codec.go index 3bf29f4a..970895d8 100644 --- a/pgtype/enum_codec.go +++ b/pgtype/enum_codec.go @@ -76,10 +76,11 @@ func (c *EnumCodec) lookupAndCacheString(src []byte) string { if s, found := c.membersMap[string(src)]; found { return s - } else { - c.membersMap[s] = s - return s } + + s := string(src) + c.membersMap[s] = s + return s } type scanPlanTextAnyToEnumString struct { diff --git a/pgtype/enum_codec_test.go b/pgtype/enum_codec_test.go new file mode 100644 index 00000000..139bfc34 --- /dev/null +++ b/pgtype/enum_codec_test.go @@ -0,0 +1,69 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestEnumCodec(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists enum_test; + +create type enum_test as enum ('foo', 'bar', 'baz');`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type enum_test") + + dt, err := conn.LoadDataType(context.Background(), "enum_test") + require.NoError(t, err) + + conn.ConnInfo().RegisterDataType(*dt) + + var s string + err = conn.QueryRow(context.Background(), `select 'foo'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "foo", s) + + err = conn.QueryRow(context.Background(), `select $1::enum_test`, "bar").Scan(&s) + require.NoError(t, err) + require.Equal(t, "bar", s) + + err = conn.QueryRow(context.Background(), `select 'foo'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "foo", s) + + err = conn.QueryRow(context.Background(), `select $1::enum_test`, "bar").Scan(&s) + require.NoError(t, err) + require.Equal(t, "bar", s) + + err = conn.QueryRow(context.Background(), `select 'baz'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "baz", s) +} + +func TestEnumCodecValues(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists enum_test; + +create type enum_test as enum ('foo', 'bar', 'baz');`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type enum_test") + + dt, err := conn.LoadDataType(context.Background(), "enum_test") + require.NoError(t, err) + + conn.ConnInfo().RegisterDataType(*dt) + + rows, err := conn.Query(context.Background(), `select 'foo'::enum_test`) + require.NoError(t, err) + require.True(t, rows.Next()) + values, err := rows.Values() + require.NoError(t, err) + require.Equal(t, values, []interface{}{"foo"}) +} From 28ea2cd1905808c0d7d9db0bc573ee4861ef91c6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Feb 2022 13:05:23 -0600 Subject: [PATCH 0893/1158] Better error messages --- pgtype/array_codec.go | 2 +- pgtype/pgtype.go | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 901ea5f7..5a02c435 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -250,7 +250,7 @@ func (c *ArrayCodec) decodeBinary(ci *ConnInfo, arrayOID uint32, src []byte, arr } err = elementScanPlan.Scan(elemSrc, elem) if err != nil { - return err + return fmt.Errorf("failed to scan array element %d: %w", i, err) } } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 571b5bfc..45b9a092 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -489,7 +489,17 @@ type scanPlanFail struct { } func (plan *scanPlanFail) Scan(src []byte, dst interface{}) error { - return fmt.Errorf("cannot scan OID %v in format %v into %T", plan.oid, plan.formatCode, dst) + var format string + switch plan.formatCode { + case TextFormatCode: + format = "text" + case BinaryFormatCode: + format = "binary" + default: + format = fmt.Sprintf("unknown %d", plan.formatCode) + } + + return fmt.Errorf("cannot scan OID %v in %v format into %T", plan.oid, format, dst) } // TryWrapScanPlanFunc is a function that tries to create a wrapper plan for target. If successful it returns a plan From 727fc19cb7b51af1cef6eca65d4599f34c1c6a5f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Feb 2022 13:10:58 -0600 Subject: [PATCH 0894/1158] Another error message improvement --- pgtype/pgtype.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 45b9a092..8c4a8c49 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1416,7 +1416,7 @@ func (ci *ConnInfo) Encode(oid uint32, formatCode int16, value interface{}, buf plan := ci.PlanEncode(oid, formatCode, value) if plan == nil { - return nil, fmt.Errorf("unable to encode %v", value) + return nil, fmt.Errorf("unable to encode %#v into OID %d", value, oid) } return plan.Encode(value, buf) } From 3a94113118632989997242de144d5d970a392067 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Feb 2022 14:24:34 -0600 Subject: [PATCH 0895/1158] Add composite to arbitrary struct encoding and decoding --- pgtype/builtin_wrappers.go | 37 ++++++++++++++++ pgtype/composite_test.go | 39 +++++++++++++++++ pgtype/pgtype.go | 89 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 165 insertions(+) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index fe58eee0..1799de55 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -4,6 +4,7 @@ import ( "fmt" "math" "net" + "reflect" "strconv" "time" ) @@ -619,3 +620,39 @@ func (w byteSliceWrapper) UUIDValue() (UUID, error) { copy(uuid.Bytes[:], w) return uuid, nil } + +// structWrapper implements CompositeIndexGetter for a struct. +type structWrapper struct { + s interface{} + exportedFields []reflect.Value +} + +func (w structWrapper) IsNull() bool { + return w.s == nil +} + +func (w structWrapper) Index(i int) interface{} { + if i >= len(w.exportedFields) { + return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i) + } + + return w.exportedFields[i].Interface() +} + +// ptrStructWrapper implements CompositeIndexScanner for a pointer to a struct. +type ptrStructWrapper struct { + s interface{} + exportedFields []reflect.Value +} + +func (w *ptrStructWrapper) ScanNull() error { + return fmt.Errorf("cannot scan NULL into %#v", w.s) +} + +func (w *ptrStructWrapper) ScanIndex(i int) interface{} { + if i >= len(w.exportedFields) { + return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i) + } + + return w.exportedFields[i].Addr().Interface() +} diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index c9319c2d..9a0eff2a 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -123,3 +123,42 @@ create type point3d as ( require.Equalf(t, input, output, "%v", format.name) } } + +func TestCompositeCodecTranscodeStructWrapper(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists point3d; + +create type point3d as ( + x float8, + y float8, + z float8 +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type point3d") + + dt, err := conn.LoadDataType(context.Background(), "point3d") + require.NoError(t, err) + conn.ConnInfo().RegisterDataType(*dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + type anotherPoint struct { + X, Y, Z float64 + } + + for _, format := range formats { + input := anotherPoint{X: 1, Y: 2, Z: 3} + var output anotherPoint + err := conn.QueryRow(context.Background(), "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output) + require.NoErrorf(t, err, "%v", format.name) + require.Equalf(t, input, output, "%v", format.name) + } +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 8c4a8c49..8db5ae3f 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -203,12 +203,14 @@ func NewConnInfo() *ConnInfo { TryWrapDerefPointerEncodePlan, TryWrapBuiltinTypeEncodePlan, TryWrapFindUnderlyingTypeEncodePlan, + TryWrapStructEncodePlan, }, TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ TryPointerPointerScanPlan, TryWrapBuiltinTypeScanPlan, TryFindUnderlyingTypeScanPlan, + TryWrapStructScanPlan, }, } @@ -887,6 +889,47 @@ func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst interface{}) err return nil } +// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. +func TryWrapStructScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + targetElemValue := targetValue.Elem() + targetElemType := targetElemValue.Type() + + if targetElemType.Kind() == reflect.Struct { + exportedFields := getExportedFieldValues(targetElemValue) + if len(exportedFields) == 0 { + return nil, nil, false + } + + w := ptrStructWrapper{ + s: target, + exportedFields: exportedFields, + } + return &wrapAnyPtrStructScanPlan{}, &w, true + } + + return nil, nil, false +} + +type wrapAnyPtrStructScanPlan struct { + next ScanPlan +} + +func (plan *wrapAnyPtrStructScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target interface{}) error { + w := ptrStructWrapper{ + s: target, + exportedFields: getExportedFieldValues(reflect.ValueOf(target).Elem()), + } + + return plan.next.Scan(src, &w) +} + // PlanScan prepares a plan to scan a value into target. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan { if _, ok := target.(*UndecodedBytes); ok { @@ -1406,6 +1449,52 @@ func (plan *wrapFmtStringerEncodePlan) Encode(value interface{}, buf []byte) (ne return plan.next.Encode(fmtStringerWrapper{value.(fmt.Stringer)}, buf) } +// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. +func TryWrapStructEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { + if reflect.TypeOf(value).Kind() == reflect.Struct { + exportedFields := getExportedFieldValues(reflect.ValueOf(value)) + if len(exportedFields) == 0 { + return nil, nil, false + } + + w := structWrapper{ + s: value, + exportedFields: exportedFields, + } + return &wrapAnyStructEncodePlan{}, w, true + } + + return nil, nil, false +} + +type wrapAnyStructEncodePlan struct { + next EncodePlan +} + +func (plan *wrapAnyStructEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapAnyStructEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + w := structWrapper{ + s: value, + exportedFields: getExportedFieldValues(reflect.ValueOf(value)), + } + + return plan.next.Encode(w, buf) +} + +func getExportedFieldValues(structValue reflect.Value) []reflect.Value { + structType := structValue.Type() + exportedFields := make([]reflect.Value, 0, structValue.NumField()) + for i := 0; i < structType.NumField(); i++ { + sf := structType.Field(i) + if sf.IsExported() { + exportedFields = append(exportedFields, structValue.Field(i)) + } + } + + return exportedFields +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. From 02372f1c3c30d54d271ac0e32d16a866f22ac620 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Feb 2022 15:12:09 -0600 Subject: [PATCH 0896/1158] Add DecodeValue to composites --- pgtype/composite.go | 69 ++++++++++++++++++++++++++++++++++------ pgtype/composite_test.go | 39 +++++++++++++++++++++++ pgtype/record_codec.go | 11 ++++++- 3 files changed, 109 insertions(+), 10 deletions(-) diff --git a/pgtype/composite.go b/pgtype/composite.go index d21ab665..2ccc7b1d 100644 --- a/pgtype/composite.go +++ b/pgtype/composite.go @@ -209,11 +209,16 @@ func (c *CompositeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format return nil, nil } - // var n int64 - // err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) - // return n, err - - return nil, fmt.Errorf("not implemented") + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } } func (c *CompositeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { @@ -221,11 +226,57 @@ func (c *CompositeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src return nil, nil } - // var n int16 - // err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) - // return n, err + switch format { + case TextFormatCode: + scanner := NewCompositeTextScanner(ci, src) + values := make(map[string]interface{}, len(c.Fields)) + for i := 0; scanner.Next() && i < len(c.Fields); i++ { + var v interface{} + fieldPlan := ci.PlanScan(c.Fields[i].DataType.OID, TextFormatCode, &v) + if fieldPlan == nil { + return nil, fmt.Errorf("unable to scan OID %d in text format into %v", c.Fields[i].DataType.OID, v) + } + + err := fieldPlan.Scan(scanner.Bytes(), &v) + if err != nil { + return nil, err + } + + values[c.Fields[i].Name] = v + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return values, nil + case BinaryFormatCode: + scanner := NewCompositeBinaryScanner(ci, src) + values := make(map[string]interface{}, len(c.Fields)) + for i := 0; scanner.Next() && i < len(c.Fields); i++ { + var v interface{} + fieldPlan := ci.PlanScan(scanner.OID(), BinaryFormatCode, &v) + if fieldPlan == nil { + return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v) + } + + err := fieldPlan.Scan(scanner.Bytes(), &v) + if err != nil { + return nil, err + } + + values[c.Fields[i].Name] = v + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return values, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } - return nil, fmt.Errorf("not implemented") } type CompositeBinaryScanner struct { diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index 9a0eff2a..66db4281 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -162,3 +162,42 @@ create type point3d as ( require.Equalf(t, input, output, "%v", format.name) } } + +func TestCompositeCodecDecodeValue(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists point3d; + +create type point3d as ( + x float8, + y float8, + z float8 +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type point3d") + + dt, err := conn.LoadDataType(context.Background(), "point3d") + require.NoError(t, err) + conn.ConnInfo().RegisterDataType(*dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + rows, err := conn.Query(context.Background(), "select '(1,2,3)'::point3d", pgx.QueryResultFormats{format.code}) + require.NoErrorf(t, err, "%v", format.name) + require.True(t, rows.Next()) + values, err := rows.Values() + require.NoErrorf(t, err, "%v", format.name) + require.Lenf(t, values, 1, "%v", format.name) + require.Equalf(t, map[string]interface{}{"x": 1.0, "y": 2.0, "z": 3.0}, values[0], "%v", format.name) + require.False(t, rows.Next()) + require.NoErrorf(t, rows.Err(), "%v", format.name) + } +} diff --git a/pgtype/record_codec.go b/pgtype/record_codec.go index 31001b1f..92c197b2 100644 --- a/pgtype/record_codec.go +++ b/pgtype/record_codec.go @@ -75,7 +75,16 @@ func (RecordCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16 return nil, nil } - return nil, fmt.Errorf("not implemented") + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } } func (RecordCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { From f4252a58be6acfa4acdb5c5aa18d6e40d77b5b76 Mon Sep 17 00:00:00 2001 From: Collin Forsyth Date: Wed, 2 Feb 2022 23:37:56 -0500 Subject: [PATCH 0897/1158] correctly Scan type aliases for floating point types --- convert.go | 4 ++++ float4_test.go | 7 +++++++ float8_test.go | 8 ++++++++ 3 files changed, 19 insertions(+) diff --git a/convert.go b/convert.go index de9ba9ba..f7219bd4 100644 --- a/convert.go +++ b/convert.go @@ -337,6 +337,10 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { el := v.Elem() switch el.Kind() { + // if dst is a type alias of a float32 or 64, set dst val + case reflect.Float32, reflect.Float64: + el.SetFloat(srcVal) + return nil // if dst is a pointer to pointer, strip the pointer and try again case reflect.Ptr: if el.IsNil() { diff --git a/float4_test.go b/float4_test.go index d2524cda..1977f194 100644 --- a/float4_test.go +++ b/float4_test.go @@ -56,6 +56,9 @@ func TestFloat4Set(t *testing.T) { } func TestFloat4AssignTo(t *testing.T) { + type aliasf32 float32 + type aliasf64 float64 + var i8 int8 var i16 int16 var i32 int32 @@ -73,6 +76,8 @@ func TestFloat4AssignTo(t *testing.T) { var f64 float64 var pf32 *float32 var pf64 *float64 + var a32 aliasf32 + var a64 aliasf64 simpleTests := []struct { src pgtype.Float4 @@ -91,6 +96,8 @@ func TestFloat4AssignTo(t *testing.T) { {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &a32, expected: aliasf32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &a64, expected: aliasf64(42)}, {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, } diff --git a/float8_test.go b/float8_test.go index 6bc7c652..c21f00d0 100644 --- a/float8_test.go +++ b/float8_test.go @@ -56,6 +56,9 @@ func TestFloat8Set(t *testing.T) { } func TestFloat8AssignTo(t *testing.T) { + type aliasf32 float32 + type aliasf64 float64 + var i8 int8 var i16 int16 var i32 int32 @@ -73,6 +76,8 @@ func TestFloat8AssignTo(t *testing.T) { var f64 float64 var pf32 *float32 var pf64 *float64 + var a32 aliasf32 + var a64 aliasf64 simpleTests := []struct { src pgtype.Float8 @@ -91,6 +96,9 @@ func TestFloat8AssignTo(t *testing.T) { {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &a32, expected: aliasf32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &a64, expected: aliasf64(42)}, + {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, } From 202542ead5c88f0d66fddc0159b70b4c6c948170 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 7 Feb 2022 10:51:03 -0600 Subject: [PATCH 0898/1158] Release v1.10.0 --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e34c7979..73126cf3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 1.10.0 (February 7, 2022) + +* Normalize UTC timestamps to comply with stdlib (Torkel Rogstad) +* Assign Numeric to *big.Rat (Oleg Lomaka) +* Fix typo in float8 error message (Pinank Solanki) +* Scan type aliases for floating point types (Collin Forsyth) + # 1.9.1 (November 28, 2021) * Fix: binary timestamp is assumed to be in UTC (restored behavior changed in v1.9.0) From 3e5de443149f6d3dbbe1bf7387203c413d0efab4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 7 Feb 2022 10:54:39 -0600 Subject: [PATCH 0899/1158] Release v1.11.0 --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 63933a3a..a37eecfe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 1.11.0 (February 7, 2022) + +* Support port in ip from LookupFunc to override config (James Hartig) +* Fix TLS connection timeout (Blake Embrey) +* Add support for read-only, primary, standby, prefer-standby target_session_attributes (Oscar) +* Fix connect when receiving NoticeResponse + # 1.10.1 (November 20, 2021) * Close without waiting for response (Kei Kamikawa) From 7193e4892302469eeb6ffe7410672037b2404697 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 8 Feb 2022 10:07:40 -0600 Subject: [PATCH 0900/1158] Restore multi-dimensional slices Move ArrayCode to use pgtype wrapper pattern as well --- Rakefile | 1 + pgtype/array_codec.go | 78 ++++++++++---- pgtype/array_codec_test.go | 57 +++++++++++ pgtype/array_getter_setter.go | 99 +++--------------- pgtype/array_getter_setter.go.erb | 80 ++------------- pgtype/builtin_wrappers.go | 165 ++++++++++++++++++++++++++++++ pgtype/pgtype.go | 117 +++++++++++++++++++++ 7 files changed, 424 insertions(+), 173 deletions(-) diff --git a/Rakefile b/Rakefile index f3a61a09..3fe26cb5 100644 --- a/Rakefile +++ b/Rakefile @@ -7,6 +7,7 @@ rule '.go' => '.go.erb' do |task| end generated_code_files = [ + "pgtype/array_getter_setter.go", "pgtype/int.go", "pgtype/int_test.go", "pgtype/integration_benchmark_test.go", diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 5a02c435..94d24fc9 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -16,6 +16,9 @@ type ArrayGetter interface { // Index returns the element at i. Index(i int) interface{} + + // IndexType returns a non-nil scan target of the type Index will return. This is used by ArrayCodec.PlanEncode. + IndexType() interface{} } // ArraySetter is a type can be set from a PostgreSQL array. @@ -27,6 +30,10 @@ type ArraySetter interface { // ScanIndex returns a value usable as a scan target for i. SetDimensions must be called before ScanIndex. ScanIndex(i int) interface{} + + // ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by + // ArrayCodec.PlanScan. + ScanIndexType() interface{} } // ArrayCodec is a codec for any array type. @@ -43,6 +50,18 @@ func (c *ArrayCodec) PreferredFormat() int16 { } func (c *ArrayCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + arrayValuer, ok := value.(ArrayGetter) + if !ok { + return nil + } + + elementType := arrayValuer.IndexType() + + elementEncodePlan := ci.PlanEncode(c.ElementDataType.OID, format, elementType) + if elementEncodePlan == nil { + return nil + } + switch format { case BinaryFormatCode: return &encodePlanArrayCodecBinary{ac: c, ci: ci, oid: oid} @@ -60,10 +79,7 @@ type encodePlanArrayCodecText struct { } func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - array, err := makeArrayGetter(value) - if err != nil { - return nil, err - } + array := value.(ArrayGetter) dimensions := array.Dimensions() if dimensions == nil { @@ -142,10 +158,7 @@ type encodePlanArrayCodecBinary struct { } func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - array, err := makeArrayGetter(value) - if err != nil { - return nil, err - } + array := value.(ArrayGetter) dimensions := array.Dimensions() if dimensions == nil { @@ -198,8 +211,15 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB } func (c *ArrayCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { - _, err := makeArraySetter(target) - if err != nil { + arrayScanner, ok := target.(ArraySetter) + if !ok { + return nil + } + + elementType := arrayScanner.ScanIndexType() + + elementScanPlan := ci.PlanScan(c.ElementDataType.OID, format, elementType) + if _, ok := elementScanPlan.(*scanPlanFail); ok { return nil } @@ -300,10 +320,11 @@ func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array } type scanPlanArrayCodec struct { - arrayCodec *ArrayCodec - ci *ConnInfo - oid uint32 - formatCode int16 + arrayCodec *ArrayCodec + ci *ConnInfo + oid uint32 + formatCode int16 + elementScanPlan ScanPlan } func (spac *scanPlanArrayCodec) Scan(src []byte, dst interface{}) error { @@ -312,11 +333,7 @@ func (spac *scanPlanArrayCodec) Scan(src []byte, dst interface{}) error { oid := spac.oid formatCode := spac.formatCode - array, err := makeArraySetter(dst) - if err != nil { - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(src, dst) - } + array := dst.(ArraySetter) if src == nil { return array.SetDimensions(nil) @@ -358,3 +375,26 @@ func (c *ArrayCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []b err := ci.PlanScan(oid, format, &slice).Scan(src, &slice) return slice, err } + +func isRagged(slice reflect.Value) bool { + if slice.Type().Elem().Kind() != reflect.Slice { + return false + } + + sliceLen := slice.Len() + innerLen := 0 + for i := 0; i < sliceLen; i++ { + if i == 0 { + innerLen = slice.Index(i).Len() + } else { + if slice.Index(i).Len() != innerLen { + return true + } + } + if isRagged(slice.Index(i)) { + return true + } + } + + return false +} diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index 0c31dcee..b4b9b6a7 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -108,3 +108,60 @@ func TestArrayCodecDecodeValue(t *testing.T) { }) } } + +func TestArrayCodecScanMultipleDimensions(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + rows, err := conn.Query(context.Background(), `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) + require.NoError(t, err) + + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss) + } + + require.NoError(t, rows.Err()) +} + +func TestArrayCodecScanWrongMultipleDimensions(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + rows, err := conn.Query(context.Background(), `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) + require.NoError(t, err) + + for rows.Next() { + var ss [][][]int32 + err := rows.Scan(&ss) + require.Error(t, err, "can't scan into dest[0]: PostgreSQL array has 2 dimensions but slice has 3 dimensions") + } +} + +func TestArrayCodecEncodeMultipleDimensions(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + rows, err := conn.Query(context.Background(), `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}) + require.NoError(t, err) + + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss) + } + + require.NoError(t, rows.Err()) +} + +func TestArrayCodecEncodeMultipleDimensionsRagged(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + rows, err := conn.Query(context.Background(), `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5}, {9, 10, 11, 12}}) + require.Error(t, err, "cannot convert [][]int32 to ArrayGetter because it is a ragged multi-dimensional") + defer rows.Close() +} diff --git a/pgtype/array_getter_setter.go b/pgtype/array_getter_setter.go index 72a6f0e7..2e20f9ec 100644 --- a/pgtype/array_getter_setter.go +++ b/pgtype/array_getter_setter.go @@ -1,10 +1,6 @@ +// Do not edit. Generated from pgtype/array_getter_setter.go.erb package pgtype -import ( - "fmt" - "reflect" -) - type int16Array []int16 func (a int16Array) Dimensions() []ArrayDimension { @@ -19,6 +15,11 @@ func (a int16Array) Index(i int) interface{} { return a[i] } +func (a int16Array) IndexType() interface{} { + var el int16 + return el +} + func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error { if dimensions == nil { a = nil @@ -34,6 +35,10 @@ func (a int16Array) ScanIndex(i int) interface{} { return &a[i] } +func (a int16Array) ScanIndexType() interface{} { + return new(int16) +} + type uint16Array []uint16 func (a uint16Array) Dimensions() []ArrayDimension { @@ -48,6 +53,11 @@ func (a uint16Array) Index(i int) interface{} { return a[i] } +func (a uint16Array) IndexType() interface{} { + var el uint16 + return el +} + func (a *uint16Array) SetDimensions(dimensions []ArrayDimension) error { if dimensions == nil { a = nil @@ -63,81 +73,6 @@ func (a uint16Array) ScanIndex(i int) interface{} { return &a[i] } -type anySliceArray struct { - slice reflect.Value -} - -func (a anySliceArray) Dimensions() []ArrayDimension { - if a.slice.IsNil() { - return nil - } - - return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}} -} - -func (a anySliceArray) Index(i int) interface{} { - return a.slice.Index(i).Interface() -} - -func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error { - sliceType := a.slice.Type() - - if dimensions == nil { - a.slice.Set(reflect.Zero(sliceType)) - return nil - } - - elementCount := cardinality(dimensions) - slice := reflect.MakeSlice(sliceType, elementCount, elementCount) - a.slice.Set(slice) - return nil -} - -func (a anySliceArray) ScanIndex(i int) interface{} { - return a.slice.Index(i).Addr().Interface() -} - -func makeArrayGetter(a interface{}) (ArrayGetter, error) { - switch a := a.(type) { - case ArrayGetter: - return a, nil - - case []int16: - return (*int16Array)(&a), nil - - case []uint16: - return (*uint16Array)(&a), nil - - } - - reflectValue := reflect.ValueOf(a) - if reflectValue.Kind() == reflect.Slice { - return &anySliceArray{slice: reflectValue}, nil - } - - return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a) -} - -func makeArraySetter(a interface{}) (ArraySetter, error) { - switch a := a.(type) { - case ArraySetter: - return a, nil - - case *[]int16: - return (*int16Array)(a), nil - - case *[]uint16: - return (*uint16Array)(a), nil - - } - - value := reflect.ValueOf(a) - if value.Kind() == reflect.Ptr { - elemValue := value.Elem() - if elemValue.Kind() == reflect.Slice { - return &anySliceArray{slice: elemValue}, nil - } - } - - return nil, fmt.Errorf("cannot convert %T to ArraySetter", a) +func (a uint16Array) ScanIndexType() interface{} { + return new(uint16) } diff --git a/pgtype/array_getter_setter.go.erb b/pgtype/array_getter_setter.go.erb index 01b7d4fa..a9d60d35 100644 --- a/pgtype/array_getter_setter.go.erb +++ b/pgtype/array_getter_setter.go.erb @@ -27,6 +27,11 @@ import ( return a[i] } + func (a <%= array_type %>) IndexType() interface{} { + var el <%= element_type %> + return el + } + func (a *<%= array_type %>) SetDimensions(dimensions []ArrayDimension) error { if dimensions == nil { a = nil @@ -41,77 +46,8 @@ import ( func (a <%= array_type %>) ScanIndex(i int) interface{} { return &a[i] } -<% end %> -type anySliceArray struct { - slice reflect.Value -} - -func (a anySliceArray) Dimensions() []ArrayDimension { - if a.slice.IsNil() { - return nil - } - - return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}} -} - -func (a anySliceArray) Index(i int) interface{} { - return a.slice.Index(i).Interface() -} - -func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error { - sliceType := a.slice.Type() - - if dimensions == nil { - a.slice.Set(reflect.Zero(sliceType)) - return nil - } - - elementCount := cardinality(dimensions) - slice := reflect.MakeSlice(sliceType, elementCount, elementCount) - a.slice.Set(slice) - return nil -} - -func (a anySliceArray) ScanIndex(i int) interface{} { - return a.slice.Index(i).Addr().Interface() -} - -func makeArrayGetter(a interface{}) (ArrayGetter, error) { - switch a := a.(type) { - case ArrayGetter: - return a, nil - <% types.each do |array_type, element_type| %> - case []<%= element_type %>: - return (*<%= array_type %>)(&a), nil - <% end %> - } - - reflectValue := reflect.ValueOf(a) - if reflectValue.Kind() == reflect.Slice { - return &anySliceArray{slice: reflectValue}, nil + func (a <%= array_type %>) ScanIndexType() interface{} { + return new(<%= element_type %>) } - - return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a) -} - -func makeArraySetter(a interface{}) (ArraySetter, error) { - switch a := a.(type) { - case ArraySetter: - return a, nil - <% types.each do |array_type, element_type| %> - case *[]<%= element_type %>: - return (*<%= array_type %>)(a), nil - <% end %> - } - - value := reflect.ValueOf(a) - if value.Kind() == reflect.Ptr { - elemValue := value.Elem() - if elemValue.Kind() == reflect.Slice { - return &anySliceArray{slice: elemValue}, nil - } - } - - return nil, fmt.Errorf("cannot convert %T to ArraySetter", a) -} +<% end %> diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 1799de55..466ef45a 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -656,3 +656,168 @@ func (w *ptrStructWrapper) ScanIndex(i int) interface{} { return w.exportedFields[i].Addr().Interface() } + +type anySliceArray struct { + slice reflect.Value +} + +func (a anySliceArray) Dimensions() []ArrayDimension { + if a.slice.IsNil() { + return nil + } + + return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}} +} + +func (a anySliceArray) Index(i int) interface{} { + return a.slice.Index(i).Interface() +} + +func (a anySliceArray) IndexType() interface{} { + return reflect.New(a.slice.Type().Elem()).Elem().Interface() +} + +func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error { + sliceType := a.slice.Type() + + if dimensions == nil { + a.slice.Set(reflect.Zero(sliceType)) + return nil + } + + elementCount := cardinality(dimensions) + slice := reflect.MakeSlice(sliceType, elementCount, elementCount) + a.slice.Set(slice) + return nil +} + +func (a *anySliceArray) ScanIndex(i int) interface{} { + return a.slice.Index(i).Addr().Interface() +} + +func (a *anySliceArray) ScanIndexType() interface{} { + return reflect.New(a.slice.Type().Elem()).Interface() +} + +type anyMultiDimSliceArray struct { + slice reflect.Value + dims []ArrayDimension +} + +func (a *anyMultiDimSliceArray) Dimensions() []ArrayDimension { + if a.slice.IsNil() { + return nil + } + + s := a.slice + for { + a.dims = append(a.dims, ArrayDimension{Length: int32(s.Len()), LowerBound: 1}) + if s.Len() > 0 { + s = s.Index(0) + } else { + break + } + if s.Type().Kind() == reflect.Slice { + } else { + break + } + } + + return a.dims +} + +func (a *anyMultiDimSliceArray) Index(i int) interface{} { + if len(a.dims) == 1 { + return a.slice.Index(i).Interface() + } + + indexes := make([]int, len(a.dims)) + for j := len(a.dims) - 1; j >= 0; j-- { + dimLen := int(a.dims[j].Length) + indexes[j] = i % dimLen + i = i / dimLen + } + + v := a.slice + for _, si := range indexes { + v = v.Index(si) + } + + return v.Interface() +} + +func (a *anyMultiDimSliceArray) IndexType() interface{} { + lowestSliceType := a.slice.Type() + for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { + } + return reflect.New(lowestSliceType.Elem()).Elem().Interface() +} + +func (a *anyMultiDimSliceArray) SetDimensions(dimensions []ArrayDimension) error { + sliceType := a.slice.Type() + + if dimensions == nil { + a.slice.Set(reflect.Zero(sliceType)) + return nil + } + + switch len(dimensions) { + case 0: + return fmt.Errorf("impossible: non-nil dimensions but zero elements") + case 1: + elementCount := cardinality(dimensions) + slice := reflect.MakeSlice(sliceType, elementCount, elementCount) + a.slice.Set(slice) + return nil + default: + sliceDimensionCount := 1 + lowestSliceType := sliceType + for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { + sliceDimensionCount++ + } + + if sliceDimensionCount != len(dimensions) { + return fmt.Errorf("PostgreSQL array has %d dimensions but slice has %d dimensions", len(dimensions), sliceDimensionCount) + } + + elementCount := cardinality(dimensions) + flatSlice := reflect.MakeSlice(lowestSliceType, elementCount, elementCount) + + multiDimSlice := a.makeMultidimensionalSlice(sliceType, dimensions, flatSlice, 0) + a.slice.Set(multiDimSlice) + + // Now that a.slice is a multi-dimensional slice with the underlying data pointed at flatSlice change a.slice to + // flatSlice so ScanIndex only has to handle simple one dimensional slices. + a.slice = flatSlice + + return nil + } + +} + +func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type, dimensions []ArrayDimension, flatSlice reflect.Value, flatSliceIdx int) reflect.Value { + if len(dimensions) == 1 { + endIdx := flatSliceIdx + int(dimensions[0].Length) + return flatSlice.Slice3(flatSliceIdx, endIdx, endIdx) + } + + sliceLen := int(dimensions[0].Length) + slice := reflect.MakeSlice(sliceType, sliceLen, sliceLen) + for i := 0; i < sliceLen; i++ { + subSlice := a.makeMultidimensionalSlice(sliceType.Elem(), dimensions[1:], flatSlice, flatSliceIdx+(i*int(dimensions[1].Length))) + slice.Index(i).Set(subSlice) + } + + return slice +} + +func (a *anyMultiDimSliceArray) ScanIndex(i int) interface{} { + return a.slice.Index(i).Addr().Interface() +} + +func (a *anyMultiDimSliceArray) ScanIndexType() interface{} { + lowestSliceType := a.slice.Type() + for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { + } + return reflect.New(lowestSliceType.Elem()).Interface() +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 8db5ae3f..54792963 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -204,6 +204,8 @@ func NewConnInfo() *ConnInfo { TryWrapBuiltinTypeEncodePlan, TryWrapFindUnderlyingTypeEncodePlan, TryWrapStructEncodePlan, + TryWrapSliceEncodePlan, + TryWrapMultiDimSliceEncodePlan, }, TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ @@ -211,6 +213,8 @@ func NewConnInfo() *ConnInfo { TryWrapBuiltinTypeScanPlan, TryFindUnderlyingTypeScanPlan, TryWrapStructScanPlan, + TryWrapPtrSliceScanPlan, + TryWrapPtrMultiDimSliceScanPlan, }, } @@ -930,6 +934,62 @@ func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target interface{}) error return plan.next.Scan(src, &w) } +// TryWrapPtrSliceScanPlan tries to wrap a pointer to a single dimension slice. +func TryWrapPtrSliceScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + targetElemValue := targetValue.Elem() + + if targetElemValue.Kind() == reflect.Slice { + return &wrapPtrSliceScanPlan{}, &anySliceArray{slice: targetElemValue}, true + } + return nil, nil, false +} + +type wrapPtrSliceScanPlan struct { + next ScanPlan +} + +func (plan *wrapPtrSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrSliceScanPlan) Scan(src []byte, target interface{}) error { + return plan.next.Scan(src, &anySliceArray{slice: reflect.ValueOf(target).Elem()}) +} + +// TryWrapPtrMultiDimSliceScanPlan tries to wrap a pointer to a multi-dimension slice. +func TryWrapPtrMultiDimSliceScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + targetElemValue := targetValue.Elem() + + if targetElemValue.Kind() == reflect.Slice { + elemElemKind := targetElemValue.Type().Elem().Kind() + if elemElemKind == reflect.Slice { + if !isRagged(targetElemValue) { + return &wrapPtrMultiDimSliceScanPlan{}, &anyMultiDimSliceArray{slice: targetValue.Elem()}, true + } + } + } + + return nil, nil, false +} + +type wrapPtrMultiDimSliceScanPlan struct { + next ScanPlan +} + +func (plan *wrapPtrMultiDimSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrMultiDimSliceScanPlan) Scan(src []byte, target interface{}) error { + return plan.next.Scan(src, &anyMultiDimSliceArray{slice: reflect.ValueOf(target).Elem()}) +} + // PlanScan prepares a plan to scan a value into target. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan { if _, ok := target.(*UndecodedBytes); ok { @@ -1495,6 +1555,63 @@ func getExportedFieldValues(structValue reflect.Value) []reflect.Value { return exportedFields } +func TryWrapSliceEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { + if reflect.TypeOf(value).Kind() == reflect.Slice { + w := anySliceArray{ + slice: reflect.ValueOf(value), + } + return &wrapSliceEncodePlan{}, w, true + } + + return nil, nil, false +} + +type wrapSliceEncodePlan struct { + next EncodePlan +} + +func (plan *wrapSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapSliceEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + w := anySliceArray{ + slice: reflect.ValueOf(value), + } + + return plan.next.Encode(w, buf) +} + +func TryWrapMultiDimSliceEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { + sliceValue := reflect.ValueOf(value) + if sliceValue.Kind() == reflect.Slice { + valueElemType := sliceValue.Type().Elem() + + if valueElemType.Kind() == reflect.Slice { + if !isRagged(sliceValue) { + w := anyMultiDimSliceArray{ + slice: reflect.ValueOf(value), + } + return &wrapMultiDimSliceEncodePlan{}, &w, true + } + } + } + + return nil, nil, false +} + +type wrapMultiDimSliceEncodePlan struct { + next EncodePlan +} + +func (plan *wrapMultiDimSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapMultiDimSliceEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + w := anyMultiDimSliceArray{ + slice: reflect.ValueOf(value), + } + + return plan.next.Encode(&w, buf) +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. From bcc0af3f56716ac821fef71811aaf2ad5db25542 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 8 Feb 2022 11:12:05 -0600 Subject: [PATCH 0901/1158] Fix scan empty array into multi-dimension slice --- pgtype/array_codec_test.go | 17 +++++++++++++++++ pgtype/builtin_wrappers.go | 5 ++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index b4b9b6a7..81e564c7 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -126,6 +126,23 @@ func TestArrayCodecScanMultipleDimensions(t *testing.T) { require.NoError(t, rows.Err()) } +func TestArrayCodecScanMultipleDimensionsEmpty(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + rows, err := conn.Query(context.Background(), `select '{}'::int4[]`) + require.NoError(t, err) + + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{}, ss) + } + + require.NoError(t, rows.Err()) +} + func TestArrayCodecScanWrongMultipleDimensions(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 466ef45a..cb981906 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -763,7 +763,10 @@ func (a *anyMultiDimSliceArray) SetDimensions(dimensions []ArrayDimension) error switch len(dimensions) { case 0: - return fmt.Errorf("impossible: non-nil dimensions but zero elements") + // Empty, but non-nil array + slice := reflect.MakeSlice(sliceType, 0, 0) + a.slice.Set(slice) + return nil case 1: elementCount := cardinality(dimensions) slice := reflect.MakeSlice(sliceType, elementCount, elementCount) From 1334d45d7129a0419659bc6803e7a6316307b701 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 8 Feb 2022 11:35:40 -0600 Subject: [PATCH 0902/1158] Parse array header to empty slices instead of nils --- pgtype/array.go | 11 ++++++----- pgtype/array_codec.go | 12 ------------ pgtype/array_test.go | 6 +++--- 3 files changed, 9 insertions(+), 20 deletions(-) diff --git a/pgtype/array.go b/pgtype/array.go index 0e8e31a0..d1a78e64 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -58,9 +58,7 @@ func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { dst.ElementOID = binary.BigEndian.Uint32(src[rp:]) rp += 4 - if numDims > 0 { - dst.Dimensions = make([]ArrayDimension, numDims) - } + dst.Dimensions = make([]ArrayDimension, numDims) if len(src) < 12+numDims*8 { return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) } @@ -101,7 +99,11 @@ type UntypedTextArray struct { } func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { - dst := &UntypedTextArray{} + dst := &UntypedTextArray{ + Elements: []string{}, + Quoted: []bool{}, + Dimensions: []ArrayDimension{}, + } buf := bytes.NewBufferString(src) @@ -234,7 +236,6 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { } if len(dst.Elements) == 0 { - dst.Dimensions = nil } else if len(explicitDimensions) > 0 { dst.Dimensions = explicitDimensions } else { diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 94d24fc9..f23d8e3b 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -238,12 +238,6 @@ func (c *ArrayCodec) decodeBinary(ci *ConnInfo, arrayOID uint32, src []byte, arr return err } - // TODO - ArrayHeader.DecodeBinary should do this. But doing this there breaks old array code. Leave until old code - // can be removed. - if arrayHeader.Dimensions == nil { - arrayHeader.Dimensions = []ArrayDimension{} - } - err = array.SetDimensions(arrayHeader.Dimensions) if err != nil { return err @@ -283,12 +277,6 @@ func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array return err } - // TODO - ParseUntypedTextArray should do this. But doing this there breaks old array code. Leave until old code - // can be removed. - if uta.Dimensions == nil { - uta.Dimensions = []ArrayDimension{} - } - err = array.SetDimensions(uta.Dimensions) if err != nil { return err diff --git a/pgtype/array_test.go b/pgtype/array_test.go index 82f5f229..8043e12f 100644 --- a/pgtype/array_test.go +++ b/pgtype/array_test.go @@ -15,9 +15,9 @@ func TestParseUntypedTextArray(t *testing.T) { { source: "{}", result: pgtype.UntypedTextArray{ - Elements: nil, - Quoted: nil, - Dimensions: nil, + Elements: []string{}, + Quoted: []bool{}, + Dimensions: []pgtype.ArrayDimension{}, }, }, { From 0306ce3a1944d453fc956a091735d6eddf659cef Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 8 Feb 2022 14:13:06 -0600 Subject: [PATCH 0903/1158] Fix scanning negative ints into Int64Scanner --- pgtype/int.go | 6 +++--- pgtype/int.go.erb | 2 +- pgtype/int_test.go | 3 +++ pgtype/int_test.go.erb | 1 + 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pgtype/int.go b/pgtype/int.go index a5b1c0a5..237fe7e7 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -552,7 +552,7 @@ func (scanPlanBinaryInt2ToInt64Scanner) Scan(src []byte, dst interface{}) error return fmt.Errorf("invalid length for int2: %v", len(src)) } - n := int64(binary.BigEndian.Uint16(src)) + n := int64(int16(binary.BigEndian.Uint16(src))) return s.ScanInt64(Int8{Int: n, Valid: true}) } @@ -1100,7 +1100,7 @@ func (scanPlanBinaryInt4ToInt64Scanner) Scan(src []byte, dst interface{}) error return fmt.Errorf("invalid length for int4: %v", len(src)) } - n := int64(binary.BigEndian.Uint32(src)) + n := int64(int32(binary.BigEndian.Uint32(src))) return s.ScanInt64(Int8{Int: n, Valid: true}) } @@ -1670,7 +1670,7 @@ func (scanPlanBinaryInt8ToInt64Scanner) Scan(src []byte, dst interface{}) error return fmt.Errorf("invalid length for int8: %v", len(src)) } - n := int64(binary.BigEndian.Uint64(src)) + n := int64(int64(binary.BigEndian.Uint64(src))) return s.ScanInt64(Int8{Int: n, Valid: true}) } diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 8524136f..18e708fa 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -439,7 +439,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(src []byte, dst i } - n := int64(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + n := int64(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) return s.ScanInt64(Int8{Int: n, Valid: true}) } diff --git a/pgtype/int_test.go b/pgtype/int_test.go index 1dbf32d5..a2e64f4e 100644 --- a/pgtype/int_test.go +++ b/pgtype/int_test.go @@ -22,6 +22,7 @@ func TestInt2Codec(t *testing.T) { {int(1), new(int16), isExpectedEq(int16(1))}, {uint(1), new(int16), isExpectedEq(int16(1))}, {pgtype.Int2{Int: 1, Valid: true}, new(int16), isExpectedEq(int16(1))}, + {int32(-1), new(pgtype.Int2), isExpectedEq(pgtype.Int2{Int: -1, Valid: true})}, {1, new(int8), isExpectedEq(int8(1))}, {1, new(int16), isExpectedEq(int16(1))}, {1, new(int32), isExpectedEq(int32(1))}, @@ -102,6 +103,7 @@ func TestInt4Codec(t *testing.T) { {int(1), new(int32), isExpectedEq(int32(1))}, {uint(1), new(int32), isExpectedEq(int32(1))}, {pgtype.Int4{Int: 1, Valid: true}, new(int32), isExpectedEq(int32(1))}, + {int32(-1), new(pgtype.Int4), isExpectedEq(pgtype.Int4{Int: -1, Valid: true})}, {1, new(int8), isExpectedEq(int8(1))}, {1, new(int16), isExpectedEq(int16(1))}, {1, new(int32), isExpectedEq(int32(1))}, @@ -182,6 +184,7 @@ func TestInt8Codec(t *testing.T) { {int(1), new(int64), isExpectedEq(int64(1))}, {uint(1), new(int64), isExpectedEq(int64(1))}, {pgtype.Int8{Int: 1, Valid: true}, new(int64), isExpectedEq(int64(1))}, + {int32(-1), new(pgtype.Int8), isExpectedEq(pgtype.Int8{Int: -1, Valid: true})}, {1, new(int8), isExpectedEq(int8(1))}, {1, new(int16), isExpectedEq(int16(1))}, {1, new(int32), isExpectedEq(int32(1))}, diff --git a/pgtype/int_test.go.erb b/pgtype/int_test.go.erb index 8858ce90..d55851c2 100644 --- a/pgtype/int_test.go.erb +++ b/pgtype/int_test.go.erb @@ -22,6 +22,7 @@ func TestInt<%= pg_byte_size %>Codec(t *testing.T) { {int(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {uint(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {pgtype.Int<%= pg_byte_size %>{Int: 1, Valid: true}, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int32(-1), new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{Int: -1, Valid: true})}, {1, new(int8), isExpectedEq(int8(1))}, {1, new(int16), isExpectedEq(int16(1))}, {1, new(int32), isExpectedEq(int32(1))}, From f861d83a17de567831a2a50ab9e4743aac335bbe Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 8 Feb 2022 16:48:17 -0600 Subject: [PATCH 0904/1158] Fix range types not clearing unbounded or empty --- pgtype/range_codec.go | 3 ++- pgtype/range_codec_test.go | 52 ++++++++++++++++++++++++++++++++++++++ pgtype/range_types.go | 42 ++++++++++++++++++++++++++++++ pgtype/range_types.go.erb | 6 +++++ 4 files changed, 102 insertions(+), 1 deletion(-) diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go index 0dc63e6c..f5091c36 100644 --- a/pgtype/range_codec.go +++ b/pgtype/range_codec.go @@ -29,7 +29,8 @@ type RangeScanner interface { ScanBounds() (lowerTarget, upperTarget interface{}) // SetBoundTypes sets the lower and upper bound types. ScanBounds will be called and the returned values scanned - // (if appropriate) before SetBoundTypes is called. + // (if appropriate) before SetBoundTypes is called. If the bound types are unbounded or empty this method must + // also set the bound values. SetBoundTypes(lower, upper BoundType) error } diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index 30095065..6597ab98 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -63,6 +63,58 @@ func TestRangeCodecTranscodeCompatibleRangeElementTypes(t *testing.T) { }) } +func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + var r pgtype.Int4range + + err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r) + require.NoError(t, err) + + require.Equal( + t, + pgtype.Int4range{ + Lower: pgtype.Int4{Int: 1, Valid: true}, + Upper: pgtype.Int4{Int: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + r, + ) + + err = conn.QueryRow(context.Background(), `select '[1,)'::int4range`).Scan(&r) + require.NoError(t, err) + + require.Equal( + t, + pgtype.Int4range{ + Lower: pgtype.Int4{Int: 1, Valid: true}, + Upper: pgtype.Int4{}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Unbounded, + Valid: true, + }, + r, + ) + + err = conn.QueryRow(context.Background(), `select 'empty'::int4range`).Scan(&r) + require.NoError(t, err) + + require.Equal( + t, + pgtype.Int4range{ + Lower: pgtype.Int4{}, + Upper: pgtype.Int4{}, + LowerType: pgtype.Empty, + UpperType: pgtype.Empty, + Valid: true, + }, + r, + ) +} + func TestRangeCodecDecodeValue(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) diff --git a/pgtype/range_types.go b/pgtype/range_types.go index aa979d56..1496ca30 100644 --- a/pgtype/range_types.go +++ b/pgtype/range_types.go @@ -31,6 +31,12 @@ func (r *Int4range) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Int4range) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Int4{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Int4{} + } r.LowerType = lower r.UpperType = upper r.Valid = true @@ -67,6 +73,12 @@ func (r *Int8range) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Int8range) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Int8{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Int8{} + } r.LowerType = lower r.UpperType = upper r.Valid = true @@ -103,6 +115,12 @@ func (r *Numrange) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Numrange) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Numeric{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Numeric{} + } r.LowerType = lower r.UpperType = upper r.Valid = true @@ -139,6 +157,12 @@ func (r *Tsrange) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Tsrange) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Timestamp{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Timestamp{} + } r.LowerType = lower r.UpperType = upper r.Valid = true @@ -175,6 +199,12 @@ func (r *Tstzrange) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Tstzrange) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Timestamptz{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Timestamptz{} + } r.LowerType = lower r.UpperType = upper r.Valid = true @@ -211,6 +241,12 @@ func (r *Daterange) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Daterange) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Date{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Date{} + } r.LowerType = lower r.UpperType = upper r.Valid = true @@ -247,6 +283,12 @@ func (r *Float8range) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Float8range) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Float8{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Float8{} + } r.LowerType = lower r.UpperType = upper r.Valid = true diff --git a/pgtype/range_types.go.erb b/pgtype/range_types.go.erb index dc796a1d..8b43f7f9 100644 --- a/pgtype/range_types.go.erb +++ b/pgtype/range_types.go.erb @@ -41,6 +41,12 @@ func (r *<%= range_type %>) ScanBounds() (lowerTarget, upperTarget interface{}) } func (r *<%= range_type %>) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = <%= element_type %>{} + } + if upper == Unbounded || upper == Empty { + r.Upper = <%= element_type %>{} + } r.LowerType = lower r.UpperType = upper r.Valid = true From a14f3f291fd0e65afae4b039514e9f3d70c5565e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Feb 2022 09:35:52 -0600 Subject: [PATCH 0905/1158] Re-enable domain type test --- conn_test.go | 33 +++++++++------------------------ 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/conn_test.go b/conn_test.go index cedd6a7a..e35def64 100644 --- a/conn_test.go +++ b/conn_test.go @@ -874,28 +874,18 @@ func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) { } func TestDomainType(t *testing.T) { - t.Skip("TODO - unskip later in v5") - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") - var n uint64 - // Domain type uint64 is a PostgreSQL domain of underlying type numeric. - err := conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) + // Unregistered type can be used as string. + var s string + err := conn.QueryRow(context.Background(), "select $1::uint64", "24").Scan(&s) require.NoError(t, err) + require.Equal(t, "24", s) - // A string can be used. But a string cannot be the result because the describe result from the PostgreSQL server gives - // the underlying type of numeric. - err = conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n) - if err != nil { - t.Fatal(err) - } - if n != 42 { - t.Fatalf("Expected n to be 42, but was %v", n) - } - + // Register type var uint64OID uint32 err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID) if err != nil { @@ -903,6 +893,10 @@ func TestDomainType(t *testing.T) { } conn.ConnInfo().RegisterDataType(pgtype.DataType{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}}) + var n uint64 + err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) + require.NoError(t, err) + // String is still an acceptable argument after registration err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n) if err != nil { @@ -911,15 +905,6 @@ func TestDomainType(t *testing.T) { if n != 7 { t.Fatalf("Expected n to be 7, but was %v", n) } - - // But a uint64 is acceptable - err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) - if err != nil { - t.Fatal(err) - } - if n != 24 { - t.Fatalf("Expected n to be 24, but was %v", n) - } }) } From 60da2914f3ae06f879c7cf0a1b792276e29fa1f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Feb 2022 09:37:12 -0600 Subject: [PATCH 0906/1158] Re-enable test --- query_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/query_test.go b/query_test.go index c22c2795..c85802b2 100644 --- a/query_test.go +++ b/query_test.go @@ -369,7 +369,6 @@ func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) { // Test that a connection stays valid when query results read incorrectly func TestConnQueryReadWrongTypeError(t *testing.T) { - t.Skip("TODO - unskip later in v5") t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -397,7 +396,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { t.Fatal("Expected Rows to have an error after an improper read but it didn't") } - if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") { + if rows.Err().Error() != "can't scan into dest[0]: cannot scan OID 23 in binary format into *time.Time" { t.Fatalf("Expected different Rows.Err(): %v", rows.Err()) } From e6680127e3f25a91391adee30f29dc14fc2b88c6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Feb 2022 09:40:33 -0600 Subject: [PATCH 0907/1158] Reenable TestRowsScanNilThenScanValue --- values_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/values_test.go b/values_test.go index 080bb305..f9cfd8ce 100644 --- a/values_test.go +++ b/values_test.go @@ -991,7 +991,6 @@ func TestEncodeTypeRename(t *testing.T) { // https://github.com/jackc/pgx/issues/810 func TestRowsScanNilThenScanValue(t *testing.T) { - t.Skip("TODO - unskip later in v5") t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { From 9c5dfbdfb39f2c804ab719628bb460cd2b92054d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Feb 2022 10:26:26 -0600 Subject: [PATCH 0908/1158] pgconn.CommandTag is now an opaque type It now makes a copy instead of retaining driver memory. This is in preparation to reuse the driver read buffer. --- batch.go | 16 ++-- conn.go | 18 ++--- conn_test.go | 32 ++++---- pgconn/benchmark_private_test.go | 73 ++++++++++++++++++ pgconn/benchmark_test.go | 68 ----------------- pgconn/pgconn.go | 123 +++++++++++++++++-------------- pgconn/pgconn_private_test.go | 41 +++++++++++ pgconn/pgconn_test.go | 72 +++++------------- pgxpool/batch_results.go | 4 +- pgxpool/common_test.go | 2 +- pgxpool/pool.go | 4 +- pgxpool/rows.go | 2 +- query_test.go | 8 +- tx.go | 8 +- 14 files changed, 246 insertions(+), 225 deletions(-) create mode 100644 pgconn/benchmark_private_test.go create mode 100644 pgconn/pgconn_private_test.go diff --git a/batch.go b/batch.go index caa5a02f..689877a9 100644 --- a/batch.go +++ b/batch.go @@ -64,10 +64,10 @@ type batchResults struct { // Exec reads the results from the next query in the batch as if the query has been sent with Exec. func (br *batchResults) Exec() (pgconn.CommandTag, error) { if br.err != nil { - return nil, br.err + return pgconn.CommandTag{}, br.err } if br.closed { - return nil, fmt.Errorf("batch already closed") + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") } query, arguments, _ := br.nextQueryAndArgs() @@ -84,7 +84,7 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { "err": err, }) } - return nil, err + return pgconn.CommandTag{}, err } commandTag, err := br.mrr.ResultReader().Close() @@ -151,29 +151,29 @@ func (br *batchResults) Query() (Rows, error) { // QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { if br.closed { - return nil, fmt.Errorf("batch already closed") + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") } rows, err := br.Query() if err != nil { - return nil, err + return pgconn.CommandTag{}, err } defer rows.Close() for rows.Next() { err = rows.Scan(scans...) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } err = f(rows) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } } if err := rows.Err(); err != nil { - return nil, err + return pgconn.CommandTag{}, err } return rows.CommandTag(), nil diff --git a/conn.go b/conn.go index 8e0707c4..a03871ad 100644 --- a/conn.go +++ b/conn.go @@ -432,7 +432,7 @@ optionLoop: if c.stmtcache != nil { sd, err := c.stmtcache.Get(ctx, sql) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } if c.stmtcache.Mode() == stmtcache.ModeDescribe { @@ -443,7 +443,7 @@ optionLoop: sd, err := c.Prepare(ctx, "", sql) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } return c.execPrepared(ctx, sd, arguments) } @@ -452,7 +452,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i if len(arguments) > 0 { sql, err = c.sanitizeForSimpleQuery(sql, arguments...) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } } @@ -493,7 +493,7 @@ func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, argu func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { err := c.execParamsAndPreparedPrefix(sd, arguments) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read() @@ -504,7 +504,7 @@ func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { err := c.execParamsAndPreparedPrefix(sd, arguments) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() @@ -688,24 +688,24 @@ type QueryFuncRow interface { func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { rows, err := c.Query(ctx, sql, args...) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } defer rows.Close() for rows.Next() { err = rows.Scan(scans...) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } err = f(rows) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } } if err := rows.Err(); err != nil { - return nil, err + return pgconn.CommandTag{}, err } return rows.CommandTag(), nil diff --git a/conn_test.go b/conn_test.go index e35def64..0cbc0040 100644 --- a/conn_test.go +++ b/conn_test.go @@ -188,31 +188,31 @@ func TestExec(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); string(results) != "CREATE TABLE" { + if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } // Accept parameters - if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); string(results) != "INSERT 0 1" { + if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); results.String() != "INSERT 0 1" { t.Errorf("Unexpected results from Exec: %v", results) } - if results := mustExec(t, conn, "drop table foo;"); string(results) != "DROP TABLE" { + if results := mustExec(t, conn, "drop table foo;"); results.String() != "DROP TABLE" { t.Error("Unexpected results from Exec") } // Multiple statements can be executed -- last command tag is returned - if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); string(results) != "DROP TABLE" { + if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results.String() != "DROP TABLE" { t.Error("Unexpected results from Exec") } // Can execute longer SQL strings than sharedBufferSize - if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); string(results) != "SELECT 1" { + if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); results.String() != "SELECT 1" { t.Errorf("Unexpected results from Exec: %v", results) } // Exec no-op which does not return a command tag - if results := mustExec(t, conn, "--;"); string(results) != "" { + if results := mustExec(t, conn, "--;"); results.String() != "" { t.Errorf("Unexpected results from Exec: %v", results) } }) @@ -260,7 +260,7 @@ func TestExecContextWithoutCancelation(t *testing.T) { if err != nil { t.Fatal(err) } - if string(commandTag) != "CREATE TABLE" { + if commandTag.String() != "CREATE TABLE" { t.Fatalf("Unexpected results from Exec: %v", commandTag) } assert.False(t, pgconn.SafeToRetry(err)) @@ -350,15 +350,15 @@ func TestExecStatementCacheModes(t *testing.T) { commandTag, err := conn.Exec(context.Background(), "select 1") assert.NoError(t, err, tt.name) - assert.Equal(t, "SELECT 1", string(commandTag), tt.name) + assert.Equal(t, "SELECT 1", commandTag.String(), tt.name) commandTag, err = conn.Exec(context.Background(), "select 1 union all select 1") assert.NoError(t, err, tt.name) - assert.Equal(t, "SELECT 2", string(commandTag), tt.name) + assert.Equal(t, "SELECT 2", commandTag.String(), tt.name) commandTag, err = conn.Exec(context.Background(), "select 1") assert.NoError(t, err, tt.name) - assert.Equal(t, "SELECT 1", string(commandTag), tt.name) + assert.Equal(t, "SELECT 1", commandTag.String(), tt.name) ensureConnValid(t, conn) }() @@ -378,7 +378,7 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) { if err != nil { t.Fatal(err) } - if string(commandTag) != "CREATE TABLE" { + if commandTag.String() != "CREATE TABLE" { t.Fatalf("Unexpected results from Exec: %v", commandTag) } @@ -390,7 +390,7 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) { if err != nil { t.Fatal(err) } - if string(commandTag) != "INSERT 0 1" { + if commandTag.String() != "INSERT 0 1" { t.Fatalf("Unexpected results from Exec: %v", commandTag) } @@ -720,12 +720,12 @@ func TestInsertBoolArray(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); string(results) != "CREATE TABLE" { + if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } // Accept parameters - if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); string(results) != "INSERT 0 1" { + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results.String() != "INSERT 0 1" { t.Errorf("Unexpected results from Exec: %v", results) } }) @@ -735,12 +735,12 @@ func TestInsertTimestampArray(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { - if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); string(results) != "CREATE TABLE" { + if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } // Accept parameters - if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); string(results) != "INSERT 0 1" { + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results.String() != "INSERT 0 1" { t.Errorf("Unexpected results from Exec: %v", results) } }) diff --git a/pgconn/benchmark_private_test.go b/pgconn/benchmark_private_test.go new file mode 100644 index 00000000..e074c75c --- /dev/null +++ b/pgconn/benchmark_private_test.go @@ -0,0 +1,73 @@ +package pgconn + +import ( + "strings" + "testing" +) + +func BenchmarkCommandTagRowsAffected(b *testing.B) { + benchmarks := []struct { + commandTag string + rowsAffected int64 + }{ + {"UPDATE 1", 1}, + {"UPDATE 123456789", 123456789}, + {"INSERT 0 1", 1}, + {"INSERT 0 123456789", 123456789}, + } + + for _, bm := range benchmarks { + ct := CommandTag{buf: []byte(bm.commandTag)} + b.Run(bm.commandTag, func(b *testing.B) { + var n int64 + for i := 0; i < b.N; i++ { + n = ct.RowsAffected() + } + if n != bm.rowsAffected { + b.Errorf("expected %d got %d", bm.rowsAffected, n) + } + }) + } +} + +func BenchmarkCommandTagTypeFromString(b *testing.B) { + ct := CommandTag{buf: []byte("UPDATE 1")} + + var update bool + for i := 0; i < b.N; i++ { + update = strings.HasPrefix(ct.String(), "UPDATE") + } + if !update { + b.Error("expected update") + } +} + +func BenchmarkCommandTagInsert(b *testing.B) { + benchmarks := []struct { + commandTag string + is bool + }{ + {"INSERT 1", true}, + {"INSERT 1234567890", true}, + {"UPDATE 1", false}, + {"UPDATE 1234567890", false}, + {"DELETE 1", false}, + {"DELETE 1234567890", false}, + {"SELECT 1", false}, + {"SELECT 1234567890", false}, + {"UNKNOWN 1234567890", false}, + } + + for _, bm := range benchmarks { + ct := CommandTag{buf: []byte(bm.commandTag)} + b.Run(bm.commandTag, func(b *testing.B) { + var is bool + for i := 0; i < b.N; i++ { + is = ct.Insert() + } + if is != bm.is { + b.Errorf("expected %v got %v", bm.is, is) + } + }) + } +} diff --git a/pgconn/benchmark_test.go b/pgconn/benchmark_test.go index 088a9bd9..ffa42243 100644 --- a/pgconn/benchmark_test.go +++ b/pgconn/benchmark_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "os" - "strings" "testing" "github.com/jackc/pgx/v5/pgconn" @@ -253,70 +252,3 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { // conn.ChanToSetDeadline().Ignore() // } // } - -func BenchmarkCommandTagRowsAffected(b *testing.B) { - benchmarks := []struct { - commandTag string - rowsAffected int64 - }{ - {"UPDATE 1", 1}, - {"UPDATE 123456789", 123456789}, - {"INSERT 0 1", 1}, - {"INSERT 0 123456789", 123456789}, - } - - for _, bm := range benchmarks { - ct := pgconn.CommandTag(bm.commandTag) - b.Run(bm.commandTag, func(b *testing.B) { - var n int64 - for i := 0; i < b.N; i++ { - n = ct.RowsAffected() - } - if n != bm.rowsAffected { - b.Errorf("expected %d got %d", bm.rowsAffected, n) - } - }) - } -} - -func BenchmarkCommandTagTypeFromString(b *testing.B) { - ct := pgconn.CommandTag("UPDATE 1") - - var update bool - for i := 0; i < b.N; i++ { - update = strings.HasPrefix(ct.String(), "UPDATE") - } - if !update { - b.Error("expected update") - } -} - -func BenchmarkCommandTagInsert(b *testing.B) { - benchmarks := []struct { - commandTag string - is bool - }{ - {"INSERT 1", true}, - {"INSERT 1234567890", true}, - {"UPDATE 1", false}, - {"UPDATE 1234567890", false}, - {"DELETE 1", false}, - {"DELETE 1234567890", false}, - {"SELECT 1", false}, - {"SELECT 1234567890", false}, - {"UNKNOWN 1234567890", false}, - } - - for _, bm := range benchmarks { - ct := pgconn.CommandTag(bm.commandTag) - b.Run(bm.commandTag, func(b *testing.B) { - var is bool - for i := 0; i < b.N; i++ { - is = ct.Insert() - } - if is != bm.is { - b.Errorf("expected %v got %v", bm.is, is) - } - }) - } -} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 843bbef4..16d54f3a 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -685,15 +685,17 @@ func (pgConn *PgConn) ParameterStatus(key string) string { } // CommandTag is the result of an Exec function -type CommandTag []byte +type CommandTag struct { + buf []byte +} // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. func (ct CommandTag) RowsAffected() int64 { // Find last non-digit idx := -1 - for i := len(ct) - 1; i >= 0; i-- { - if ct[i] >= '0' && ct[i] <= '9' { + for i := len(ct.buf) - 1; i >= 0; i-- { + if ct.buf[i] >= '0' && ct.buf[i] <= '9' { idx = i } else { break @@ -705,7 +707,7 @@ func (ct CommandTag) RowsAffected() int64 { } var n int64 - for _, b := range ct[idx:] { + for _, b := range ct.buf[idx:] { n = n*10 + int64(b-'0') } @@ -713,51 +715,51 @@ func (ct CommandTag) RowsAffected() int64 { } func (ct CommandTag) String() string { - return string(ct) + return string(ct.buf) } // Insert is true if the command tag starts with "INSERT". func (ct CommandTag) Insert() bool { - return len(ct) >= 6 && - ct[0] == 'I' && - ct[1] == 'N' && - ct[2] == 'S' && - ct[3] == 'E' && - ct[4] == 'R' && - ct[5] == 'T' + return len(ct.buf) >= 6 && + ct.buf[0] == 'I' && + ct.buf[1] == 'N' && + ct.buf[2] == 'S' && + ct.buf[3] == 'E' && + ct.buf[4] == 'R' && + ct.buf[5] == 'T' } // Update is true if the command tag starts with "UPDATE". func (ct CommandTag) Update() bool { - return len(ct) >= 6 && - ct[0] == 'U' && - ct[1] == 'P' && - ct[2] == 'D' && - ct[3] == 'A' && - ct[4] == 'T' && - ct[5] == 'E' + return len(ct.buf) >= 6 && + ct.buf[0] == 'U' && + ct.buf[1] == 'P' && + ct.buf[2] == 'D' && + ct.buf[3] == 'A' && + ct.buf[4] == 'T' && + ct.buf[5] == 'E' } // Delete is true if the command tag starts with "DELETE". func (ct CommandTag) Delete() bool { - return len(ct) >= 6 && - ct[0] == 'D' && - ct[1] == 'E' && - ct[2] == 'L' && - ct[3] == 'E' && - ct[4] == 'T' && - ct[5] == 'E' + return len(ct.buf) >= 6 && + ct.buf[0] == 'D' && + ct.buf[1] == 'E' && + ct.buf[2] == 'L' && + ct.buf[3] == 'E' && + ct.buf[4] == 'T' && + ct.buf[5] == 'E' } // Select is true if the command tag starts with "SELECT". func (ct CommandTag) Select() bool { - return len(ct) >= 6 && - ct[0] == 'S' && - ct[1] == 'E' && - ct[2] == 'L' && - ct[3] == 'E' && - ct[4] == 'C' && - ct[5] == 'T' + return len(ct.buf) >= 6 && + ct.buf[0] == 'S' && + ct.buf[1] == 'E' && + ct.buf[2] == 'L' && + ct.buf[3] == 'E' && + ct.buf[4] == 'C' && + ct.buf[5] == 'T' } type StatementDescription struct { @@ -1076,13 +1078,13 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by result := &pgConn.resultReader if err := pgConn.lock(); err != nil { - result.concludeCommand(nil, err) + result.concludeCommand(CommandTag{}, err) result.closed = true return result } if len(paramValues) > math.MaxUint16 { - result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.concludeCommand(CommandTag{}, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true pgConn.unlock() return result @@ -1091,7 +1093,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by if ctx != context.Background() { select { case <-ctx.Done(): - result.concludeCommand(nil, newContextAlreadyDoneError(ctx)) + result.concludeCommand(CommandTag{}, newContextAlreadyDoneError(ctx)) result.closed = true pgConn.unlock() return result @@ -1111,7 +1113,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { n, err := pgConn.conn.Write(buf) if err != nil { pgConn.asyncClose() - result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) + result.concludeCommand(CommandTag{}, &writeError{err: err, safeToRetry: n == 0}) pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() @@ -1124,14 +1126,14 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, err + return CommandTag{}, err } if ctx != context.Background() { select { case <-ctx.Done(): pgConn.unlock() - return nil, newContextAlreadyDoneError(ctx) + return CommandTag{}, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -1146,7 +1148,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm if err != nil { pgConn.asyncClose() pgConn.unlock() - return nil, &writeError{err: err, safeToRetry: n == 0} + return CommandTag{}, &writeError{err: err, safeToRetry: n == 0} } // Read results @@ -1156,7 +1158,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1165,13 +1167,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm _, err := w.Write(msg.Data) if err != nil { pgConn.asyncClose() - return nil, err + return CommandTag{}, err } case *pgproto3.ReadyForQuery: pgConn.unlock() return commandTag, pgErr case *pgproto3.CommandComplete: - commandTag = CommandTag(msg.CommandTag) + commandTag = pgConn.makeCommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) } @@ -1184,14 +1186,14 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, err + return CommandTag{}, err } defer pgConn.unlock() if ctx != context.Background() { select { case <-ctx.Done(): - return nil, newContextAlreadyDoneError(ctx) + return CommandTag{}, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -1205,7 +1207,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co n, err := pgConn.conn.Write(buf) if err != nil { pgConn.asyncClose() - return nil, &writeError{err: err, safeToRetry: n == 0} + return CommandTag{}, &writeError{err: err, safeToRetry: n == 0} } // Send copy data @@ -1255,7 +1257,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1279,7 +1281,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.asyncClose() - return nil, err + return CommandTag{}, err } // Read results @@ -1288,14 +1290,14 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { case *pgproto3.ReadyForQuery: return commandTag, pgErr case *pgproto3.CommandComplete: - commandTag = CommandTag(msg.CommandTag) + commandTag = pgConn.makeCommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) } @@ -1368,7 +1370,7 @@ func (mrr *MultiResultReader) NextResult() bool { return true case *pgproto3.CommandComplete: mrr.pgConn.resultReader = ResultReader{ - commandTag: CommandTag(msg.CommandTag), + commandTag: mrr.pgConn.makeCommandTag(msg.CommandTag), commandConcluded: true, closed: true, } @@ -1483,7 +1485,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for !rr.commandConcluded { _, err := rr.receiveMessage() if err != nil { - return nil, rr.err + return CommandTag{}, rr.err } } @@ -1491,7 +1493,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for { msg, err := rr.receiveMessage() if err != nil { - return nil, rr.err + return CommandTag{}, rr.err } switch msg := msg.(type) { @@ -1538,7 +1540,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error if err != nil { err = preferContextOverNetTimeoutError(rr.ctx, err) - rr.concludeCommand(nil, err) + rr.concludeCommand(CommandTag{}, err) rr.pgConn.contextWatcher.Unwatch() rr.closed = true if rr.multiResultReader == nil { @@ -1552,11 +1554,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.RowDescription: rr.fieldDescriptions = msg.Fields case *pgproto3.CommandComplete: - rr.concludeCommand(CommandTag(msg.CommandTag), nil) + rr.concludeCommand(rr.pgConn.makeCommandTag(msg.CommandTag), nil) case *pgproto3.EmptyQueryResponse: - rr.concludeCommand(nil, nil) + rr.concludeCommand(CommandTag{}, nil) case *pgproto3.ErrorResponse: - rr.concludeCommand(nil, ErrorResponseToPgError(msg)) + rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg)) } return msg, nil @@ -1659,6 +1661,13 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) { return strings.Replace(s, "'", "''", -1), nil } +// makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory. +func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { + ct := make([]byte, len(buf)) + copy(ct, buf) + return CommandTag{buf: ct} +} + // HijackedConn is the result of hijacking a connection. // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning diff --git a/pgconn/pgconn_private_test.go b/pgconn/pgconn_private_test.go new file mode 100644 index 00000000..4368f717 --- /dev/null +++ b/pgconn/pgconn_private_test.go @@ -0,0 +1,41 @@ +package pgconn + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCommandTag(t *testing.T) { + t.Parallel() + + var tests = []struct { + commandTag CommandTag + rowsAffected int64 + isInsert bool + isUpdate bool + isDelete bool + isSelect bool + }{ + {commandTag: CommandTag{buf: []byte("INSERT 0 5")}, rowsAffected: 5, isInsert: true}, + {commandTag: CommandTag{buf: []byte("UPDATE 0")}, rowsAffected: 0, isUpdate: true}, + {commandTag: CommandTag{buf: []byte("UPDATE 1")}, rowsAffected: 1, isUpdate: true}, + {commandTag: CommandTag{buf: []byte("DELETE 0")}, rowsAffected: 0, isDelete: true}, + {commandTag: CommandTag{buf: []byte("DELETE 1")}, rowsAffected: 1, isDelete: true}, + {commandTag: CommandTag{buf: []byte("DELETE 1234567890")}, rowsAffected: 1234567890, isDelete: true}, + {commandTag: CommandTag{buf: []byte("SELECT 1")}, rowsAffected: 1, isSelect: true}, + {commandTag: CommandTag{buf: []byte("SELECT 99999999999")}, rowsAffected: 99999999999, isSelect: true}, + {commandTag: CommandTag{buf: []byte("CREATE TABLE")}, rowsAffected: 0}, + {commandTag: CommandTag{buf: []byte("ALTER TABLE")}, rowsAffected: 0}, + {commandTag: CommandTag{buf: []byte("DROP TABLE")}, rowsAffected: 0}, + } + + for i, tt := range tests { + ct := tt.commandTag + assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag) + } +} diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index d1ba29d2..4d975f32 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -538,7 +538,7 @@ func TestConnExec(t *testing.T) { assert.Len(t, results, 1) assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) assert.Len(t, results[0].Rows, 1) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) @@ -579,12 +579,12 @@ func TestConnExecMultipleQueries(t *testing.T) { assert.Len(t, results, 2) assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) assert.Len(t, results[0].Rows, 1) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) assert.Nil(t, results[1].Err) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Equal(t, "SELECT 1", results[1].CommandTag.String()) assert.Len(t, results[1].Rows, 1) assert.Equal(t, "1", string(results[1].Rows[0][0])) @@ -741,7 +741,7 @@ func TestConnExecParams(t *testing.T) { } assert.Equal(t, 1, rowCount) commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Equal(t, "SELECT 1", commandTag.String()) assert.NoError(t, err) ensureConnValid(t, pgConn) @@ -840,7 +840,7 @@ func TestConnExecParamsCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Equal(t, pgconn.CommandTag(nil), commandTag) + assert.Equal(t, pgconn.CommandTag{}, commandTag) assert.True(t, pgconn.Timeout(err)) assert.ErrorIs(t, err, context.DeadlineExceeded) @@ -880,7 +880,7 @@ func TestConnExecParamsEmptySQL(t *testing.T) { defer closeConn(t, pgConn) result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read() - assert.Nil(t, result.CommandTag) + assert.Equal(t, pgconn.CommandTag{}, result.CommandTag) assert.Len(t, result.Rows, 0) assert.NoError(t, result.Err) @@ -907,7 +907,7 @@ func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) { } assert.Equal(t, 1, rowCount) commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Equal(t, "SELECT 1", commandTag.String()) assert.NoError(t, err) ensureConnValid(t, pgConn) @@ -937,7 +937,7 @@ func TestConnExecPrepared(t *testing.T) { } assert.Equal(t, 1, rowCount) commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Equal(t, "SELECT 1", commandTag.String()) assert.NoError(t, err) ensureConnValid(t, pgConn) @@ -1025,7 +1025,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Equal(t, pgconn.CommandTag(nil), commandTag) + assert.Equal(t, pgconn.CommandTag{}, commandTag) assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) select { @@ -1069,7 +1069,7 @@ func TestConnExecPreparedEmptySQL(t *testing.T) { require.NoError(t, err) result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() - assert.Nil(t, result.CommandTag) + assert.Equal(t, pgconn.CommandTag{}, result.CommandTag) assert.Len(t, result.Rows, 0) assert.NoError(t, result.Err) @@ -1097,15 +1097,15 @@ func TestConnExecBatch(t *testing.T) { require.Len(t, results[0].Rows, 1) require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) require.Len(t, results[1].Rows, 1) require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Equal(t, "SELECT 1", results[1].CommandTag.String()) require.Len(t, results[2].Rows, 1) require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) + assert.Equal(t, "SELECT 1", results[2].CommandTag.String()) } func TestConnExecBatchDeferredError(t *testing.T) { @@ -1199,7 +1199,7 @@ func TestConnExecBatchHuge(t *testing.T) { for i := range args { require.Len(t, results[i].Rows, 1) require.Equal(t, args[i], string(results[i].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + assert.Equal(t, "SELECT 1", results[i].CommandTag.String()) } } @@ -1247,47 +1247,13 @@ func TestConnLocking(t *testing.T) { assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) assert.Len(t, results[0].Rows, 1) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) ensureConnValid(t, pgConn) } -func TestCommandTag(t *testing.T) { - t.Parallel() - - var tests = []struct { - commandTag pgconn.CommandTag - rowsAffected int64 - isInsert bool - isUpdate bool - isDelete bool - isSelect bool - }{ - {commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5, isInsert: true}, - {commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0, isUpdate: true}, - {commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1, isUpdate: true}, - {commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0, isDelete: true}, - {commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1, isDelete: true}, - {commandTag: pgconn.CommandTag("DELETE 1234567890"), rowsAffected: 1234567890, isDelete: true}, - {commandTag: pgconn.CommandTag("SELECT 1"), rowsAffected: 1, isSelect: true}, - {commandTag: pgconn.CommandTag("SELECT 99999999999"), rowsAffected: 99999999999, isSelect: true}, - {commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0}, - {commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0}, - {commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0}, - } - - for i, tt := range tests { - ct := tt.commandTag - assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag) - assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag) - assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag) - assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag) - assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag) - } -} - func TestConnOnNotice(t *testing.T) { t.Parallel() @@ -1546,7 +1512,7 @@ func TestConnCopyToCanceled(t *testing.T) { defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") assert.Error(t, err) - assert.Equal(t, pgconn.CommandTag(nil), res) + assert.Equal(t, pgconn.CommandTag{}, res) assert.True(t, pgConn.IsClosed()) select { @@ -1571,7 +1537,7 @@ func TestConnCopyToPrecanceled(t *testing.T) { require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, pgconn.SafeToRetry(err)) - assert.Equal(t, pgconn.CommandTag(nil), res) + assert.Equal(t, pgconn.CommandTag{}, res) ensureConnValid(t, pgConn) } @@ -1692,7 +1658,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) { require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, pgconn.SafeToRetry(err)) - assert.Equal(t, pgconn.CommandTag(nil), ct) + assert.Equal(t, pgconn.CommandTag{}, ct) ensureConnValid(t, pgConn) } @@ -2014,7 +1980,7 @@ func TestHijackAndConstruct(t *testing.T) { assert.Len(t, results, 1) assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) assert.Len(t, results[0].Rows, 1) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) diff --git a/pgxpool/batch_results.go b/pgxpool/batch_results.go index 8bec35cb..aa1d609d 100644 --- a/pgxpool/batch_results.go +++ b/pgxpool/batch_results.go @@ -10,7 +10,7 @@ type errBatchResults struct { } func (br errBatchResults) Exec() (pgconn.CommandTag, error) { - return nil, br.err + return pgconn.CommandTag{}, br.err } func (br errBatchResults) Query() (pgx.Rows, error) { @@ -18,7 +18,7 @@ func (br errBatchResults) Query() (pgx.Rows, error) { } func (br errBatchResults) QueryFunc(scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - return nil, br.err + return pgconn.CommandTag{}, br.err } func (br errBatchResults) QueryRow() pgx.Row { diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index c6f3b77b..7b9f9f29 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -27,7 +27,7 @@ type execer interface { func testExec(t *testing.T, db execer) { results, err := db.Exec(context.Background(), "set time zone 'America/Chicago'") require.NoError(t, err) - assert.EqualValues(t, "SET", results) + assert.EqualValues(t, "SET", results.String()) } type queryer interface { diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 41fb4d5b..30d02879 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -470,7 +470,7 @@ func (p *Pool) Stat() *Stat { func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { c, err := p.Acquire(ctx) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } defer c.Release() @@ -527,7 +527,7 @@ func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pg func (p *Pool) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { c, err := p.Acquire(ctx) if err != nil { - return nil, err + return pgconn.CommandTag{}, err } defer c.Release() diff --git a/pgxpool/rows.go b/pgxpool/rows.go index 0c97dc91..f3f24649 100644 --- a/pgxpool/rows.go +++ b/pgxpool/rows.go @@ -12,7 +12,7 @@ type errRows struct { func (errRows) Close() {} func (e errRows) Err() error { return e.err } -func (errRows) CommandTag() pgconn.CommandTag { return nil } +func (errRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} } func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil } func (errRows) Next() bool { return false } func (e errRows) Scan(dest ...interface{}) error { return e.err } diff --git a/query_test.go b/query_test.go index c85802b2..2f8975ac 100644 --- a/query_test.go +++ b/query_test.go @@ -45,7 +45,7 @@ func TestConnQueryScan(t *testing.T) { t.Fatalf("conn.Query failed: %v", err) } - assert.Equal(t, "SELECT 10", string(rows.CommandTag())) + assert.Equal(t, "SELECT 10", rows.CommandTag().String()) if rowCount != 10 { t.Error("Select called onDataRow wrong number of times") @@ -79,7 +79,7 @@ func TestConnQueryWithoutResultSetCommandTag(t *testing.T) { assert.NoError(t, err) rows.Close() assert.NoError(t, rows.Err()) - assert.Equal(t, "CREATE TABLE", string(rows.CommandTag())) + assert.Equal(t, "CREATE TABLE", rows.CommandTag().String()) } func TestConnQueryScanWithManyColumns(t *testing.T) { @@ -1139,7 +1139,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes if err != nil { t.Fatal(err) } - if string(commandTag) != "INSERT 0 1" { + if commandTag.String() != "INSERT 0 1" { t.Fatalf("want %s, got %s", "INSERT 0 1", commandTag) } @@ -1976,7 +1976,7 @@ func TestConnQueryFuncAbort(t *testing.T) { }, ) require.EqualError(t, err, "abort") - require.Nil(t, ct) + require.Equal(t, pgconn.CommandTag{}, ct) }) } diff --git a/tx.go b/tx.go index 3ed0ca67..6b85b303 100644 --- a/tx.go +++ b/tx.go @@ -235,7 +235,7 @@ func (tx *dbTx) Commit(ctx context.Context) error { } return err } - if string(commandTag) == "ROLLBACK" { + if commandTag.String() == "ROLLBACK" { return ErrTxCommitRollback } @@ -296,7 +296,7 @@ func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...interface{}) R // QueryFunc delegates to the underlying *Conn. func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { if tx.closed { - return nil, ErrTxClosed + return pgconn.CommandTag{}, ErrTxClosed } return tx.conn.QueryFunc(ctx, sql, args, scans, f) @@ -380,7 +380,7 @@ func (sp *dbSavepoint) Rollback(ctx context.Context) error { // Exec delegates to the underlying Tx func (sp *dbSavepoint) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { if sp.closed { - return nil, ErrTxClosed + return pgconn.CommandTag{}, ErrTxClosed } return sp.tx.Exec(ctx, sql, arguments...) @@ -415,7 +415,7 @@ func (sp *dbSavepoint) QueryRow(ctx context.Context, sql string, args ...interfa // QueryFunc delegates to the underlying Tx. func (sp *dbSavepoint) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { if sp.closed { - return nil, ErrTxClosed + return pgconn.CommandTag{}, ErrTxClosed } return sp.tx.QueryFunc(ctx, sql, args, scans, f) From 34bf0a5df957597aaef6997388ce77f64b69a8d8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Feb 2022 08:00:49 -0600 Subject: [PATCH 0909/1158] Upgrade golang.org/x/text to v0.3.7 https://github.com/jackc/pgconn/issues/103 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 6fdd0e97..fb3ed181 100644 --- a/go.mod +++ b/go.mod @@ -11,5 +11,5 @@ require ( github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.7.0 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 - golang.org/x/text v0.3.6 + golang.org/x/text v0.3.7 ) diff --git a/go.sum b/go.sum index 3c77ee21..bdb5ee8c 100644 --- a/go.sum +++ b/go.sum @@ -114,6 +114,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= From ccb96b8aca08245245455a3d6c168e7768d51734 Mon Sep 17 00:00:00 2001 From: William Storey Date: Wed, 16 Feb 2022 11:34:09 -0800 Subject: [PATCH 0910/1158] Fix typos in comments --- pgconn.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index 7bf2f20e..29889a74 100644 --- a/pgconn.go +++ b/pgconn.go @@ -99,7 +99,7 @@ type PgConn struct { } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) -// to provide configuration. See documention for ParseConfig for details. ctx can be used to cancel a connect attempt. +// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt. func Connect(ctx context.Context, connString string) (*PgConn, error) { config, err := ParseConfig(connString) if err != nil { @@ -154,7 +154,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err break } else if pgerr, ok := err.(*PgError); ok { err = &connectError{config: config, msg: "server error", err: pgerr} - ERRCODE_INVALID_PASSWORD := "28P01" // worng password + ERRCODE_INVALID_PASSWORD := "28P01" // wrong password ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION := "28000" // db does not exist if pgerr.Code == ERRCODE_INVALID_PASSWORD || pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION { break From ded272b1f2c31e345b3209d7ab373822bf90b761 Mon Sep 17 00:00:00 2001 From: William Storey Date: Wed, 16 Feb 2022 11:37:26 -0800 Subject: [PATCH 0911/1158] Remove documentation line stating only one IP is used With `expandWithIPs()` (added in #14), we try all IPs. --- config.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/config.go b/config.go index 0eab23af..5cee9297 100644 --- a/config.go +++ b/config.go @@ -176,8 +176,6 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // Other known differences with libpq: // -// If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first. -// // When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn // does not. // From a3c351d11a055cce561bf6360117b086cd1e063c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 21 Feb 2022 08:49:04 -0600 Subject: [PATCH 0912/1158] RegisterDataType now accepts *DataType --- bench_test.go | 2 +- conn_test.go | 2 +- pgtype/composite_test.go | 8 +- pgtype/enum_codec_test.go | 4 +- pgtype/hstore_test.go | 2 +- pgtype/pgtype.go | 188 +++++++++++++++++++------------------- 6 files changed, 103 insertions(+), 103 deletions(-) diff --git a/bench_test.go b/bench_test.go index c49c87f6..f6e8a871 100644 --- a/bench_test.go +++ b/bench_test.go @@ -918,7 +918,7 @@ func BenchmarkSelectManyRegisteredEnum(b *testing.B) { err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "color").Scan(&oid) require.NoError(b, err) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Name: "color", OID: oid, Codec: &pgtype.EnumCodec{}}) + conn.ConnInfo().RegisterDataType(&pgtype.DataType{Name: "color", OID: oid, Codec: &pgtype.EnumCodec{}}) b.ResetTimer() var x, y, z string diff --git a/conn_test.go b/conn_test.go index 0cbc0040..da83a7cc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -891,7 +891,7 @@ func TestDomainType(t *testing.T) { if err != nil { t.Fatalf("did not find uint64 OID, %v", err) } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}}) + conn.ConnInfo().RegisterDataType(&pgtype.DataType{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}}) var n uint64 err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index 66db4281..954a5f6a 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -26,7 +26,7 @@ create type ct_test as ( dt, err := conn.LoadDataType(context.Background(), "ct_test") require.NoError(t, err) - conn.ConnInfo().RegisterDataType(*dt) + conn.ConnInfo().RegisterDataType(dt) formats := []struct { name string @@ -105,7 +105,7 @@ create type point3d as ( dt, err := conn.LoadDataType(context.Background(), "point3d") require.NoError(t, err) - conn.ConnInfo().RegisterDataType(*dt) + conn.ConnInfo().RegisterDataType(dt) formats := []struct { name string @@ -140,7 +140,7 @@ create type point3d as ( dt, err := conn.LoadDataType(context.Background(), "point3d") require.NoError(t, err) - conn.ConnInfo().RegisterDataType(*dt) + conn.ConnInfo().RegisterDataType(dt) formats := []struct { name string @@ -179,7 +179,7 @@ create type point3d as ( dt, err := conn.LoadDataType(context.Background(), "point3d") require.NoError(t, err) - conn.ConnInfo().RegisterDataType(*dt) + conn.ConnInfo().RegisterDataType(dt) formats := []struct { name string diff --git a/pgtype/enum_codec_test.go b/pgtype/enum_codec_test.go index 139bfc34..5ced8a11 100644 --- a/pgtype/enum_codec_test.go +++ b/pgtype/enum_codec_test.go @@ -21,7 +21,7 @@ create type enum_test as enum ('foo', 'bar', 'baz');`) dt, err := conn.LoadDataType(context.Background(), "enum_test") require.NoError(t, err) - conn.ConnInfo().RegisterDataType(*dt) + conn.ConnInfo().RegisterDataType(dt) var s string err = conn.QueryRow(context.Background(), `select 'foo'::enum_test`).Scan(&s) @@ -58,7 +58,7 @@ create type enum_test as enum ('foo', 'bar', 'baz');`) dt, err := conn.LoadDataType(context.Background(), "enum_test") require.NoError(t, err) - conn.ConnInfo().RegisterDataType(*dt) + conn.ConnInfo().RegisterDataType(dt) rows, err := conn.Query(context.Background(), `select 'foo'::enum_test`) require.NoError(t, err) diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index 8d2b6971..0967bf0c 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -61,7 +61,7 @@ func TestHstoreCodec(t *testing.T) { t.Skipf("Skipping: cannot find hstore OID") } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Name: "hstore", OID: hstoreOID, Codec: pgtype.HstoreCodec{}}) + conn.ConnInfo().RegisterDataType(&pgtype.DataType{Name: "hstore", OID: hstoreOID, Codec: pgtype.HstoreCodec{}}) formats := []struct { name string diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 54792963..6f5c8878 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -218,99 +218,99 @@ func NewConnInfo() *ConnInfo { }, } - ci.RegisterDataType(DataType{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) - ci.RegisterDataType(DataType{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) - ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) - ci.RegisterDataType(DataType{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) - ci.RegisterDataType(DataType{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) - ci.RegisterDataType(DataType{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) - ci.RegisterDataType(DataType{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) - ci.RegisterDataType(DataType{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) - ci.RegisterDataType(DataType{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) - ci.RegisterDataType(DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) - ci.RegisterDataType(DataType{Name: "date", OID: DateOID, Codec: DateCodec{}}) - ci.RegisterDataType(DataType{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) - ci.RegisterDataType(DataType{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) - ci.RegisterDataType(DataType{Name: "inet", OID: InetOID, Codec: InetCodec{}}) - ci.RegisterDataType(DataType{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) - ci.RegisterDataType(DataType{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) - ci.RegisterDataType(DataType{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) - ci.RegisterDataType(DataType{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) - ci.RegisterDataType(DataType{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) - ci.RegisterDataType(DataType{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) - ci.RegisterDataType(DataType{Name: "line", OID: LineOID, Codec: LineCodec{}}) - ci.RegisterDataType(DataType{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) - ci.RegisterDataType(DataType{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) - ci.RegisterDataType(DataType{Name: "name", OID: NameOID, Codec: TextCodec{}}) - ci.RegisterDataType(DataType{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) - ci.RegisterDataType(DataType{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) - ci.RegisterDataType(DataType{Name: "path", OID: PathOID, Codec: PathCodec{}}) - ci.RegisterDataType(DataType{Name: "point", OID: PointOID, Codec: PointCodec{}}) - ci.RegisterDataType(DataType{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) - ci.RegisterDataType(DataType{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) - ci.RegisterDataType(DataType{Name: "text", OID: TextOID, Codec: TextCodec{}}) - ci.RegisterDataType(DataType{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) - ci.RegisterDataType(DataType{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) - ci.RegisterDataType(DataType{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) - ci.RegisterDataType(DataType{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) - ci.RegisterDataType(DataType{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) - ci.RegisterDataType(DataType{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) - ci.RegisterDataType(DataType{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) - ci.RegisterDataType(DataType{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) - ci.RegisterDataType(DataType{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + ci.RegisterDataType(&DataType{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) + ci.RegisterDataType(&DataType{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) + ci.RegisterDataType(&DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) + ci.RegisterDataType(&DataType{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) + ci.RegisterDataType(&DataType{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) + ci.RegisterDataType(&DataType{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) + ci.RegisterDataType(&DataType{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) + ci.RegisterDataType(&DataType{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) + ci.RegisterDataType(&DataType{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) + ci.RegisterDataType(&DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) + ci.RegisterDataType(&DataType{Name: "date", OID: DateOID, Codec: DateCodec{}}) + ci.RegisterDataType(&DataType{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) + ci.RegisterDataType(&DataType{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) + ci.RegisterDataType(&DataType{Name: "inet", OID: InetOID, Codec: InetCodec{}}) + ci.RegisterDataType(&DataType{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) + ci.RegisterDataType(&DataType{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) + ci.RegisterDataType(&DataType{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) + ci.RegisterDataType(&DataType{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) + ci.RegisterDataType(&DataType{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) + ci.RegisterDataType(&DataType{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) + ci.RegisterDataType(&DataType{Name: "line", OID: LineOID, Codec: LineCodec{}}) + ci.RegisterDataType(&DataType{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) + ci.RegisterDataType(&DataType{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) + ci.RegisterDataType(&DataType{Name: "name", OID: NameOID, Codec: TextCodec{}}) + ci.RegisterDataType(&DataType{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) + ci.RegisterDataType(&DataType{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) + ci.RegisterDataType(&DataType{Name: "path", OID: PathOID, Codec: PathCodec{}}) + ci.RegisterDataType(&DataType{Name: "point", OID: PointOID, Codec: PointCodec{}}) + ci.RegisterDataType(&DataType{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) + ci.RegisterDataType(&DataType{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) + ci.RegisterDataType(&DataType{Name: "text", OID: TextOID, Codec: TextCodec{}}) + ci.RegisterDataType(&DataType{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) + ci.RegisterDataType(&DataType{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) + ci.RegisterDataType(&DataType{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) + ci.RegisterDataType(&DataType{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) + ci.RegisterDataType(&DataType{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) + ci.RegisterDataType(&DataType{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) + ci.RegisterDataType(&DataType{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) + ci.RegisterDataType(&DataType{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) + ci.RegisterDataType(&DataType{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) - ci.RegisterDataType(DataType{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[DateOID]}}) - ci.RegisterDataType(DataType{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[Int4OID]}}) - ci.RegisterDataType(DataType{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[Int8OID]}}) - ci.RegisterDataType(DataType{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[NumericOID]}}) - ci.RegisterDataType(DataType{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestampOID]}}) - ci.RegisterDataType(DataType{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestamptzOID]}}) + ci.RegisterDataType(&DataType{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[DateOID]}}) + ci.RegisterDataType(&DataType{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[Int4OID]}}) + ci.RegisterDataType(&DataType{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[Int8OID]}}) + ci.RegisterDataType(&DataType{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[NumericOID]}}) + ci.RegisterDataType(&DataType{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestampOID]}}) + ci.RegisterDataType(&DataType{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestamptzOID]}}) - ci.RegisterDataType(DataType{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[ACLItemOID]}}) - ci.RegisterDataType(DataType{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BitOID]}}) - ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BoolOID]}}) - ci.RegisterDataType(DataType{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BoxOID]}}) - ci.RegisterDataType(DataType{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BPCharOID]}}) - ci.RegisterDataType(DataType{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[ByteaOID]}}) - ci.RegisterDataType(DataType{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[QCharOID]}}) - ci.RegisterDataType(DataType{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CIDOID]}}) - ci.RegisterDataType(DataType{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CIDROID]}}) - ci.RegisterDataType(DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CircleOID]}}) - ci.RegisterDataType(DataType{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[DateOID]}}) - ci.RegisterDataType(DataType{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[DaterangeOID]}}) - ci.RegisterDataType(DataType{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Float4OID]}}) - ci.RegisterDataType(DataType{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Float8OID]}}) - ci.RegisterDataType(DataType{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[InetOID]}}) - ci.RegisterDataType(DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int2OID]}}) - ci.RegisterDataType(DataType{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int4OID]}}) - ci.RegisterDataType(DataType{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int4rangeOID]}}) - ci.RegisterDataType(DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int8OID]}}) - ci.RegisterDataType(DataType{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int8rangeOID]}}) - ci.RegisterDataType(DataType{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[IntervalOID]}}) - ci.RegisterDataType(DataType{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[JSONOID]}}) - ci.RegisterDataType(DataType{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[JSONBOID]}}) - ci.RegisterDataType(DataType{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[LineOID]}}) - ci.RegisterDataType(DataType{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[LsegOID]}}) - ci.RegisterDataType(DataType{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[MacaddrOID]}}) - ci.RegisterDataType(DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NameOID]}}) - ci.RegisterDataType(DataType{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NumericOID]}}) - ci.RegisterDataType(DataType{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NumrangeOID]}}) - ci.RegisterDataType(DataType{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[OIDOID]}}) - ci.RegisterDataType(DataType{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PathOID]}}) - ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PointOID]}}) - ci.RegisterDataType(DataType{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PolygonOID]}}) - ci.RegisterDataType(DataType{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[RecordOID]}}) - ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TextOID]}}) - ci.RegisterDataType(DataType{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TIDOID]}}) - ci.RegisterDataType(DataType{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimeOID]}}) - ci.RegisterDataType(DataType{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimestampOID]}}) - ci.RegisterDataType(DataType{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimestamptzOID]}}) - ci.RegisterDataType(DataType{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TsrangeOID]}}) - ci.RegisterDataType(DataType{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TstzrangeOID]}}) - ci.RegisterDataType(DataType{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[UUIDOID]}}) - ci.RegisterDataType(DataType{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[VarbitOID]}}) - ci.RegisterDataType(DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[VarcharOID]}}) - ci.RegisterDataType(DataType{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[XIDOID]}}) + ci.RegisterDataType(&DataType{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[ACLItemOID]}}) + ci.RegisterDataType(&DataType{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BitOID]}}) + ci.RegisterDataType(&DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BoolOID]}}) + ci.RegisterDataType(&DataType{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BoxOID]}}) + ci.RegisterDataType(&DataType{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BPCharOID]}}) + ci.RegisterDataType(&DataType{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[ByteaOID]}}) + ci.RegisterDataType(&DataType{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[QCharOID]}}) + ci.RegisterDataType(&DataType{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CIDOID]}}) + ci.RegisterDataType(&DataType{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CIDROID]}}) + ci.RegisterDataType(&DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CircleOID]}}) + ci.RegisterDataType(&DataType{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[DateOID]}}) + ci.RegisterDataType(&DataType{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[DaterangeOID]}}) + ci.RegisterDataType(&DataType{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Float4OID]}}) + ci.RegisterDataType(&DataType{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Float8OID]}}) + ci.RegisterDataType(&DataType{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[InetOID]}}) + ci.RegisterDataType(&DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int2OID]}}) + ci.RegisterDataType(&DataType{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int4OID]}}) + ci.RegisterDataType(&DataType{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int4rangeOID]}}) + ci.RegisterDataType(&DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int8OID]}}) + ci.RegisterDataType(&DataType{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int8rangeOID]}}) + ci.RegisterDataType(&DataType{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[IntervalOID]}}) + ci.RegisterDataType(&DataType{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[JSONOID]}}) + ci.RegisterDataType(&DataType{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[JSONBOID]}}) + ci.RegisterDataType(&DataType{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[LineOID]}}) + ci.RegisterDataType(&DataType{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[LsegOID]}}) + ci.RegisterDataType(&DataType{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[MacaddrOID]}}) + ci.RegisterDataType(&DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NameOID]}}) + ci.RegisterDataType(&DataType{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NumericOID]}}) + ci.RegisterDataType(&DataType{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NumrangeOID]}}) + ci.RegisterDataType(&DataType{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[OIDOID]}}) + ci.RegisterDataType(&DataType{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PathOID]}}) + ci.RegisterDataType(&DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PointOID]}}) + ci.RegisterDataType(&DataType{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PolygonOID]}}) + ci.RegisterDataType(&DataType{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[RecordOID]}}) + ci.RegisterDataType(&DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TextOID]}}) + ci.RegisterDataType(&DataType{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TIDOID]}}) + ci.RegisterDataType(&DataType{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimeOID]}}) + ci.RegisterDataType(&DataType{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimestampOID]}}) + ci.RegisterDataType(&DataType{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimestamptzOID]}}) + ci.RegisterDataType(&DataType{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TsrangeOID]}}) + ci.RegisterDataType(&DataType{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TstzrangeOID]}}) + ci.RegisterDataType(&DataType{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[UUIDOID]}}) + ci.RegisterDataType(&DataType{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[VarbitOID]}}) + ci.RegisterDataType(&DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[VarcharOID]}}) + ci.RegisterDataType(&DataType{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[XIDOID]}}) registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { // T @@ -361,9 +361,9 @@ func NewConnInfo() *ConnInfo { return ci } -func (ci *ConnInfo) RegisterDataType(t DataType) { - ci.oidToDataType[t.OID] = &t - ci.nameToDataType[t.Name] = &t +func (ci *ConnInfo) RegisterDataType(t *DataType) { + ci.oidToDataType[t.OID] = t + ci.nameToDataType[t.Name] = t ci.oidToFormatCode[t.OID] = t.Codec.PreferredFormat() ci.reflectTypeToDataType = nil // Invalidated by type registration } From bda10b2ec97374a4d724ea492aab730faf14cb18 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 21 Feb 2022 09:01:48 -0600 Subject: [PATCH 0913/1158] Rename pgtype.DataType to pgtype.Type --- bench_test.go | 2 +- conn.go | 19 ++- conn_test.go | 6 +- copy_from_test.go | 2 +- extended_query_builder.go | 2 +- pgtype/array_codec.go | 24 ++-- pgtype/composite.go | 22 ++-- pgtype/composite_test.go | 16 +-- pgtype/enum_codec_test.go | 8 +- pgtype/hstore_test.go | 2 +- pgtype/line_test.go | 2 +- pgtype/pgtype.go | 244 +++++++++++++++++++------------------- pgtype/range_codec.go | 20 ++-- rows.go | 2 +- stdlib/sql.go | 2 +- values.go | 4 +- values_test.go | 4 +- 17 files changed, 190 insertions(+), 191 deletions(-) diff --git a/bench_test.go b/bench_test.go index f6e8a871..587ce162 100644 --- a/bench_test.go +++ b/bench_test.go @@ -918,7 +918,7 @@ func BenchmarkSelectManyRegisteredEnum(b *testing.B) { err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "color").Scan(&oid) require.NoError(b, err) - conn.ConnInfo().RegisterDataType(&pgtype.DataType{Name: "color", OID: oid, Codec: &pgtype.EnumCodec{}}) + conn.ConnInfo().RegisterType(&pgtype.Type{Name: "color", OID: oid, Codec: &pgtype.EnumCodec{}}) b.ResetTimer() var x, y, z string diff --git a/conn.go b/conn.go index 035ed209..7a66b5e2 100644 --- a/conn.go +++ b/conn.go @@ -832,9 +832,8 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, return sanitize.SanitizeSQL(sql, valueArgs...) } -// LoadDataType inspects the database for typeName and produces a pgtype.DataType suitable for -// registration. -func (c *Conn) LoadDataType(ctx context.Context, typeName string) (*pgtype.DataType, error) { +// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. +func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) { var oid uint32 err := c.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid) @@ -856,23 +855,23 @@ func (c *Conn) LoadDataType(ctx context.Context, typeName string) (*pgtype.DataT return nil, err } - dt, ok := c.ConnInfo().DataTypeForOID(elementOID) + dt, ok := c.ConnInfo().TypeForOID(elementOID) if !ok { return nil, errors.New("array element OID not registered") } - return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementDataType: dt}}, nil + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}, nil case "c": // composite fields, err := c.getCompositeFields(ctx, oid) if err != nil { return nil, err } - return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil case "e": // enum - return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil default: - return &pgtype.DataType{}, errors.New("unknown typtype") + return &pgtype.Type{}, errors.New("unknown typtype") } } @@ -905,11 +904,11 @@ order by attnum`, []interface{}{typrelid}, []interface{}{&fieldName, &fieldOID}, func(qfr QueryFuncRow) error { - dt, ok := c.ConnInfo().DataTypeForOID(fieldOID) + dt, ok := c.ConnInfo().TypeForOID(fieldOID) if !ok { return fmt.Errorf("unknown composite type field OID: %v", fieldOID) } - fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, DataType: dt}) + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt}) return nil }) if err != nil { diff --git a/conn_test.go b/conn_test.go index da83a7cc..22928a01 100644 --- a/conn_test.go +++ b/conn_test.go @@ -841,11 +841,11 @@ func TestConnInitConnInfo(t *testing.T) { "text": pgtype.TextOID, } for name, oid := range nameOIDs { - dtByName, ok := conn.ConnInfo().DataTypeForName(name) + dtByName, ok := conn.ConnInfo().TypeForName(name) if !ok { t.Fatalf("Expected type named %v to be present", name) } - dtByOID, ok := conn.ConnInfo().DataTypeForOID(oid) + dtByOID, ok := conn.ConnInfo().TypeForOID(oid) if !ok { t.Fatalf("Expected type OID %v to be present", oid) } @@ -891,7 +891,7 @@ func TestDomainType(t *testing.T) { if err != nil { t.Fatalf("did not find uint64 OID, %v", err) } - conn.ConnInfo().RegisterDataType(&pgtype.DataType{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}}) + conn.ConnInfo().RegisterType(&pgtype.Type{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}}) var n uint64 err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) diff --git a/copy_from_test.go b/copy_from_test.go index 32644f38..867b53af 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -256,7 +256,7 @@ func TestConnCopyFromJSON(t *testing.T) { defer closeConn(t, conn) for _, typeName := range []string{"json", "jsonb"} { - if _, ok := conn.ConnInfo().DataTypeForName(typeName); !ok { + if _, ok := conn.ConnInfo().TypeForName(typeName); !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } } diff --git a/extended_query_builder.go b/extended_query_builder.go index 8b250685..a28900c4 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -83,7 +83,7 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg) } - if _, ok := ci.DataTypeForOID(oid); ok { + if _, ok := ci.TypeForOID(oid); ok { buf, err := ci.Encode(oid, formatCode, arg, eqb.paramValueBytes) if err != nil { return nil, err diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index f23d8e3b..0b32b37f 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -38,15 +38,15 @@ type ArraySetter interface { // ArrayCodec is a codec for any array type. type ArrayCodec struct { - ElementDataType *DataType + ElementType *Type } func (c *ArrayCodec) FormatSupported(format int16) bool { - return c.ElementDataType.Codec.FormatSupported(format) + return c.ElementType.Codec.FormatSupported(format) } func (c *ArrayCodec) PreferredFormat() int16 { - return c.ElementDataType.Codec.PreferredFormat() + return c.ElementType.Codec.PreferredFormat() } func (c *ArrayCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { @@ -57,7 +57,7 @@ func (c *ArrayCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value in elementType := arrayValuer.IndexType() - elementEncodePlan := ci.PlanEncode(c.ElementDataType.OID, format, elementType) + elementEncodePlan := ci.PlanEncode(c.ElementType.OID, format, elementType) if elementEncodePlan == nil { return nil } @@ -124,7 +124,7 @@ func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf elemType := reflect.TypeOf(elem) if lastElemType != elemType { lastElemType = elemType - encodePlan = p.ci.PlanEncode(p.ac.ElementDataType.OID, TextFormatCode, elem) + encodePlan = p.ci.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem) if encodePlan == nil { return nil, fmt.Errorf("unable to encode %v", array.Index(i)) } @@ -167,7 +167,7 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB arrayHeader := ArrayHeader{ Dimensions: dimensions, - ElementOID: p.ac.ElementDataType.OID, + ElementOID: p.ac.ElementType.OID, } containsNullIndex := len(buf) + 4 @@ -188,7 +188,7 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB elemType := reflect.TypeOf(elem) if lastElemType != elemType { lastElemType = elemType - encodePlan = p.ci.PlanEncode(p.ac.ElementDataType.OID, BinaryFormatCode, elem) + encodePlan = p.ci.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem) if encodePlan == nil { return nil, fmt.Errorf("unable to encode %v", array.Index(i)) } @@ -218,7 +218,7 @@ func (c *ArrayCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target int elementType := arrayScanner.ScanIndexType() - elementScanPlan := ci.PlanScan(c.ElementDataType.OID, format, elementType) + elementScanPlan := ci.PlanScan(c.ElementType.OID, format, elementType) if _, ok := elementScanPlan.(*scanPlanFail); ok { return nil } @@ -248,9 +248,9 @@ func (c *ArrayCodec) decodeBinary(ci *ConnInfo, arrayOID uint32, src []byte, arr return nil } - elementScanPlan := c.ElementDataType.Codec.PlanScan(ci, c.ElementDataType.OID, BinaryFormatCode, array.ScanIndex(0), false) + elementScanPlan := c.ElementType.Codec.PlanScan(ci, c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0), false) if elementScanPlan == nil { - elementScanPlan = ci.PlanScan(c.ElementDataType.OID, BinaryFormatCode, array.ScanIndex(0)) + elementScanPlan = ci.PlanScan(c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0)) } for i := 0; i < elementCount; i++ { @@ -286,9 +286,9 @@ func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array return nil } - elementScanPlan := c.ElementDataType.Codec.PlanScan(ci, c.ElementDataType.OID, TextFormatCode, array.ScanIndex(0), false) + elementScanPlan := c.ElementType.Codec.PlanScan(ci, c.ElementType.OID, TextFormatCode, array.ScanIndex(0), false) if elementScanPlan == nil { - elementScanPlan = ci.PlanScan(c.ElementDataType.OID, TextFormatCode, array.ScanIndex(0)) + elementScanPlan = ci.PlanScan(c.ElementType.OID, TextFormatCode, array.ScanIndex(0)) } for i, s := range uta.Elements { diff --git a/pgtype/composite.go b/pgtype/composite.go index 2ccc7b1d..1142a704 100644 --- a/pgtype/composite.go +++ b/pgtype/composite.go @@ -29,8 +29,8 @@ type CompositeIndexScanner interface { } type CompositeCodecField struct { - Name string - DataType *DataType + Name string + Type *Type } type CompositeCodec struct { @@ -39,7 +39,7 @@ type CompositeCodec struct { func (c *CompositeCodec) FormatSupported(format int16) bool { for _, f := range c.Fields { - if !f.DataType.Codec.FormatSupported(format) { + if !f.Type.Codec.FormatSupported(format) { return false } } @@ -83,7 +83,7 @@ func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value i builder := NewCompositeBinaryBuilder(plan.ci, buf) for i, field := range plan.cc.Fields { - builder.AppendValue(field.DataType.OID, getter.Index(i)) + builder.AppendValue(field.Type.OID, getter.Index(i)) } return builder.Finish() @@ -103,7 +103,7 @@ func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value int b := NewCompositeTextBuilder(plan.ci, buf) for i, field := range plan.cc.Fields { - b.AppendValue(field.DataType.OID, getter.Index(i)) + b.AppendValue(field.Type.OID, getter.Index(i)) } return b.Finish() @@ -143,9 +143,9 @@ func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, tar if scanner.Next() { fieldTarget := targetScanner.ScanIndex(i) if fieldTarget != nil { - fieldPlan := plan.ci.PlanScan(field.DataType.OID, BinaryFormatCode, fieldTarget) + fieldPlan := plan.ci.PlanScan(field.Type.OID, BinaryFormatCode, fieldTarget) if fieldPlan == nil { - return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.DataType.OID) + return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.Type.OID) } err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) @@ -182,9 +182,9 @@ func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, targe if scanner.Next() { fieldTarget := targetScanner.ScanIndex(i) if fieldTarget != nil { - fieldPlan := plan.ci.PlanScan(field.DataType.OID, TextFormatCode, fieldTarget) + fieldPlan := plan.ci.PlanScan(field.Type.OID, TextFormatCode, fieldTarget) if fieldPlan == nil { - return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.DataType.OID) + return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.Type.OID) } err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) @@ -232,9 +232,9 @@ func (c *CompositeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src values := make(map[string]interface{}, len(c.Fields)) for i := 0; scanner.Next() && i < len(c.Fields); i++ { var v interface{} - fieldPlan := ci.PlanScan(c.Fields[i].DataType.OID, TextFormatCode, &v) + fieldPlan := ci.PlanScan(c.Fields[i].Type.OID, TextFormatCode, &v) if fieldPlan == nil { - return nil, fmt.Errorf("unable to scan OID %d in text format into %v", c.Fields[i].DataType.OID, v) + return nil, fmt.Errorf("unable to scan OID %d in text format into %v", c.Fields[i].Type.OID, v) } err := fieldPlan.Scan(scanner.Bytes(), &v) diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index 954a5f6a..f96a6470 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -24,9 +24,9 @@ create type ct_test as ( require.NoError(t, err) defer conn.Exec(context.Background(), "drop type ct_test") - dt, err := conn.LoadDataType(context.Background(), "ct_test") + dt, err := conn.LoadType(context.Background(), "ct_test") require.NoError(t, err) - conn.ConnInfo().RegisterDataType(dt) + conn.ConnInfo().RegisterType(dt) formats := []struct { name string @@ -103,9 +103,9 @@ create type point3d as ( require.NoError(t, err) defer conn.Exec(context.Background(), "drop type point3d") - dt, err := conn.LoadDataType(context.Background(), "point3d") + dt, err := conn.LoadType(context.Background(), "point3d") require.NoError(t, err) - conn.ConnInfo().RegisterDataType(dt) + conn.ConnInfo().RegisterType(dt) formats := []struct { name string @@ -138,9 +138,9 @@ create type point3d as ( require.NoError(t, err) defer conn.Exec(context.Background(), "drop type point3d") - dt, err := conn.LoadDataType(context.Background(), "point3d") + dt, err := conn.LoadType(context.Background(), "point3d") require.NoError(t, err) - conn.ConnInfo().RegisterDataType(dt) + conn.ConnInfo().RegisterType(dt) formats := []struct { name string @@ -177,9 +177,9 @@ create type point3d as ( require.NoError(t, err) defer conn.Exec(context.Background(), "drop type point3d") - dt, err := conn.LoadDataType(context.Background(), "point3d") + dt, err := conn.LoadType(context.Background(), "point3d") require.NoError(t, err) - conn.ConnInfo().RegisterDataType(dt) + conn.ConnInfo().RegisterType(dt) formats := []struct { name string diff --git a/pgtype/enum_codec_test.go b/pgtype/enum_codec_test.go index 5ced8a11..f8dba1a0 100644 --- a/pgtype/enum_codec_test.go +++ b/pgtype/enum_codec_test.go @@ -18,10 +18,10 @@ create type enum_test as enum ('foo', 'bar', 'baz');`) require.NoError(t, err) defer conn.Exec(context.Background(), "drop type enum_test") - dt, err := conn.LoadDataType(context.Background(), "enum_test") + dt, err := conn.LoadType(context.Background(), "enum_test") require.NoError(t, err) - conn.ConnInfo().RegisterDataType(dt) + conn.ConnInfo().RegisterType(dt) var s string err = conn.QueryRow(context.Background(), `select 'foo'::enum_test`).Scan(&s) @@ -55,10 +55,10 @@ create type enum_test as enum ('foo', 'bar', 'baz');`) require.NoError(t, err) defer conn.Exec(context.Background(), "drop type enum_test") - dt, err := conn.LoadDataType(context.Background(), "enum_test") + dt, err := conn.LoadType(context.Background(), "enum_test") require.NoError(t, err) - conn.ConnInfo().RegisterDataType(dt) + conn.ConnInfo().RegisterType(dt) rows, err := conn.Query(context.Background(), `select 'foo'::enum_test`) require.NoError(t, err) diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index 0967bf0c..2437a240 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -61,7 +61,7 @@ func TestHstoreCodec(t *testing.T) { t.Skipf("Skipping: cannot find hstore OID") } - conn.ConnInfo().RegisterDataType(&pgtype.DataType{Name: "hstore", OID: hstoreOID, Codec: pgtype.HstoreCodec{}}) + conn.ConnInfo().RegisterType(&pgtype.Type{Name: "hstore", OID: hstoreOID, Codec: pgtype.HstoreCodec{}}) formats := []struct { name string diff --git a/pgtype/line_test.go b/pgtype/line_test.go index b7c82e35..3ed8fc4b 100644 --- a/pgtype/line_test.go +++ b/pgtype/line_test.go @@ -11,7 +11,7 @@ import ( func TestLineTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) defer conn.Close(context.Background()) - if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { + if _, ok := conn.ConnInfo().TypeForName("line"); !ok { t.Skip("Skipping due to no line type") } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 6f5c8878..81431826 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -163,20 +163,20 @@ func (e *nullAssignmentError) Error() string { return fmt.Sprintf("cannot assign NULL to %T", e.dst) } -type DataType struct { +type Type struct { Codec Codec Name string OID uint32 } type ConnInfo struct { - oidToDataType map[uint32]*DataType - nameToDataType map[string]*DataType + oidToType map[uint32]*Type + nameToType map[string]*Type reflectTypeToName map[reflect.Type]string oidToFormatCode map[uint32]int16 oidToResultFormatCode map[uint32]int16 - reflectTypeToDataType map[reflect.Type]*DataType + reflectTypeToType map[reflect.Type]*Type // TryWrapEncodePlanFuncs is a slice of functions that will wrap a value that cannot be encoded by the Codec. Every // time a wrapper is found the PlanEncode method will be recursively called with the new value. This allows several layers of wrappers @@ -193,8 +193,8 @@ type ConnInfo struct { func NewConnInfo() *ConnInfo { ci := &ConnInfo{ - oidToDataType: make(map[uint32]*DataType), - nameToDataType: make(map[string]*DataType), + oidToType: make(map[uint32]*Type), + nameToType: make(map[string]*Type), reflectTypeToName: make(map[reflect.Type]string), oidToFormatCode: make(map[uint32]int16), oidToResultFormatCode: make(map[uint32]int16), @@ -218,99 +218,99 @@ func NewConnInfo() *ConnInfo { }, } - ci.RegisterDataType(&DataType{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) - ci.RegisterDataType(&DataType{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) - ci.RegisterDataType(&DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) - ci.RegisterDataType(&DataType{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) - ci.RegisterDataType(&DataType{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) - ci.RegisterDataType(&DataType{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) - ci.RegisterDataType(&DataType{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) - ci.RegisterDataType(&DataType{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) - ci.RegisterDataType(&DataType{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) - ci.RegisterDataType(&DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) - ci.RegisterDataType(&DataType{Name: "date", OID: DateOID, Codec: DateCodec{}}) - ci.RegisterDataType(&DataType{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) - ci.RegisterDataType(&DataType{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) - ci.RegisterDataType(&DataType{Name: "inet", OID: InetOID, Codec: InetCodec{}}) - ci.RegisterDataType(&DataType{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) - ci.RegisterDataType(&DataType{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) - ci.RegisterDataType(&DataType{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) - ci.RegisterDataType(&DataType{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) - ci.RegisterDataType(&DataType{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) - ci.RegisterDataType(&DataType{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) - ci.RegisterDataType(&DataType{Name: "line", OID: LineOID, Codec: LineCodec{}}) - ci.RegisterDataType(&DataType{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) - ci.RegisterDataType(&DataType{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) - ci.RegisterDataType(&DataType{Name: "name", OID: NameOID, Codec: TextCodec{}}) - ci.RegisterDataType(&DataType{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) - ci.RegisterDataType(&DataType{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) - ci.RegisterDataType(&DataType{Name: "path", OID: PathOID, Codec: PathCodec{}}) - ci.RegisterDataType(&DataType{Name: "point", OID: PointOID, Codec: PointCodec{}}) - ci.RegisterDataType(&DataType{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) - ci.RegisterDataType(&DataType{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) - ci.RegisterDataType(&DataType{Name: "text", OID: TextOID, Codec: TextCodec{}}) - ci.RegisterDataType(&DataType{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) - ci.RegisterDataType(&DataType{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) - ci.RegisterDataType(&DataType{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) - ci.RegisterDataType(&DataType{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) - ci.RegisterDataType(&DataType{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) - ci.RegisterDataType(&DataType{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) - ci.RegisterDataType(&DataType{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) - ci.RegisterDataType(&DataType{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) - ci.RegisterDataType(&DataType{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + ci.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) + ci.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) + ci.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) + ci.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) + ci.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) + ci.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) + ci.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) + ci.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) + ci.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) + ci.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) + ci.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}}) + ci.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) + ci.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) + ci.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}}) + ci.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) + ci.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) + ci.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) + ci.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) + ci.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) + ci.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) + ci.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) + ci.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) + ci.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) + ci.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) + ci.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) + ci.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) + ci.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}}) + ci.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}}) + ci.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) + ci.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) + ci.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) + ci.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) + ci.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) + ci.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) + ci.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) + ci.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) + ci.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) + ci.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) + ci.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) + ci.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) - ci.RegisterDataType(&DataType{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[DateOID]}}) - ci.RegisterDataType(&DataType{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[Int4OID]}}) - ci.RegisterDataType(&DataType{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[Int8OID]}}) - ci.RegisterDataType(&DataType{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[NumericOID]}}) - ci.RegisterDataType(&DataType{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestampOID]}}) - ci.RegisterDataType(&DataType{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestamptzOID]}}) + ci.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: ci.oidToType[DateOID]}}) + ci.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: ci.oidToType[Int4OID]}}) + ci.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: ci.oidToType[Int8OID]}}) + ci.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: ci.oidToType[NumericOID]}}) + ci.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: ci.oidToType[TimestampOID]}}) + ci.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: ci.oidToType[TimestamptzOID]}}) - ci.RegisterDataType(&DataType{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[ACLItemOID]}}) - ci.RegisterDataType(&DataType{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BitOID]}}) - ci.RegisterDataType(&DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BoolOID]}}) - ci.RegisterDataType(&DataType{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BoxOID]}}) - ci.RegisterDataType(&DataType{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BPCharOID]}}) - ci.RegisterDataType(&DataType{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[ByteaOID]}}) - ci.RegisterDataType(&DataType{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[QCharOID]}}) - ci.RegisterDataType(&DataType{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CIDOID]}}) - ci.RegisterDataType(&DataType{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CIDROID]}}) - ci.RegisterDataType(&DataType{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[CircleOID]}}) - ci.RegisterDataType(&DataType{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[DateOID]}}) - ci.RegisterDataType(&DataType{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[DaterangeOID]}}) - ci.RegisterDataType(&DataType{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Float4OID]}}) - ci.RegisterDataType(&DataType{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Float8OID]}}) - ci.RegisterDataType(&DataType{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[InetOID]}}) - ci.RegisterDataType(&DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int2OID]}}) - ci.RegisterDataType(&DataType{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int4OID]}}) - ci.RegisterDataType(&DataType{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int4rangeOID]}}) - ci.RegisterDataType(&DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int8OID]}}) - ci.RegisterDataType(&DataType{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[Int8rangeOID]}}) - ci.RegisterDataType(&DataType{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[IntervalOID]}}) - ci.RegisterDataType(&DataType{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[JSONOID]}}) - ci.RegisterDataType(&DataType{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[JSONBOID]}}) - ci.RegisterDataType(&DataType{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[LineOID]}}) - ci.RegisterDataType(&DataType{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[LsegOID]}}) - ci.RegisterDataType(&DataType{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[MacaddrOID]}}) - ci.RegisterDataType(&DataType{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NameOID]}}) - ci.RegisterDataType(&DataType{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NumericOID]}}) - ci.RegisterDataType(&DataType{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[NumrangeOID]}}) - ci.RegisterDataType(&DataType{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[OIDOID]}}) - ci.RegisterDataType(&DataType{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PathOID]}}) - ci.RegisterDataType(&DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PointOID]}}) - ci.RegisterDataType(&DataType{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PolygonOID]}}) - ci.RegisterDataType(&DataType{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[RecordOID]}}) - ci.RegisterDataType(&DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TextOID]}}) - ci.RegisterDataType(&DataType{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TIDOID]}}) - ci.RegisterDataType(&DataType{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimeOID]}}) - ci.RegisterDataType(&DataType{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimestampOID]}}) - ci.RegisterDataType(&DataType{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimestamptzOID]}}) - ci.RegisterDataType(&DataType{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TsrangeOID]}}) - ci.RegisterDataType(&DataType{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TstzrangeOID]}}) - ci.RegisterDataType(&DataType{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[UUIDOID]}}) - ci.RegisterDataType(&DataType{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[VarbitOID]}}) - ci.RegisterDataType(&DataType{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[VarcharOID]}}) - ci.RegisterDataType(&DataType{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[XIDOID]}}) + ci.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[ACLItemOID]}}) + ci.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[BitOID]}}) + ci.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[BoolOID]}}) + ci.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[BoxOID]}}) + ci.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[BPCharOID]}}) + ci.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[ByteaOID]}}) + ci.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[QCharOID]}}) + ci.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[CIDOID]}}) + ci.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[CIDROID]}}) + ci.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[CircleOID]}}) + ci.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[DateOID]}}) + ci.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[DaterangeOID]}}) + ci.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Float4OID]}}) + ci.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Float8OID]}}) + ci.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[InetOID]}}) + ci.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Int2OID]}}) + ci.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Int4OID]}}) + ci.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Int4rangeOID]}}) + ci.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Int8OID]}}) + ci.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Int8rangeOID]}}) + ci.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[IntervalOID]}}) + ci.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[JSONOID]}}) + ci.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[JSONBOID]}}) + ci.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[LineOID]}}) + ci.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[LsegOID]}}) + ci.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[MacaddrOID]}}) + ci.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[NameOID]}}) + ci.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[NumericOID]}}) + ci.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[NumrangeOID]}}) + ci.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[OIDOID]}}) + ci.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[PathOID]}}) + ci.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[PointOID]}}) + ci.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[PolygonOID]}}) + ci.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[RecordOID]}}) + ci.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TextOID]}}) + ci.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TIDOID]}}) + ci.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TimeOID]}}) + ci.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TimestampOID]}}) + ci.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TimestamptzOID]}}) + ci.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TsrangeOID]}}) + ci.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TstzrangeOID]}}) + ci.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[UUIDOID]}}) + ci.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[VarbitOID]}}) + ci.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[VarcharOID]}}) + ci.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[XIDOID]}}) registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { // T @@ -361,49 +361,49 @@ func NewConnInfo() *ConnInfo { return ci } -func (ci *ConnInfo) RegisterDataType(t *DataType) { - ci.oidToDataType[t.OID] = t - ci.nameToDataType[t.Name] = t +func (ci *ConnInfo) RegisterType(t *Type) { + ci.oidToType[t.OID] = t + ci.nameToType[t.Name] = t ci.oidToFormatCode[t.OID] = t.Codec.PreferredFormat() - ci.reflectTypeToDataType = nil // Invalidated by type registration + ci.reflectTypeToType = nil // Invalidated by type registration } // RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be // encoded or decoded is determined by the PostgreSQL OID. But if the OID of a value to be encoded or decoded is -// unknown, this additional mapping will be used by DataTypeForValue to determine a suitable data type. +// unknown, this additional mapping will be used by TypeForValue to determine a suitable data type. func (ci *ConnInfo) RegisterDefaultPgType(value interface{}, name string) { ci.reflectTypeToName[reflect.TypeOf(value)] = name - ci.reflectTypeToDataType = nil // Invalidated by registering a default type + ci.reflectTypeToType = nil // Invalidated by registering a default type } -func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) { - dt, ok := ci.oidToDataType[oid] +func (ci *ConnInfo) TypeForOID(oid uint32) (*Type, bool) { + dt, ok := ci.oidToType[oid] return dt, ok } -func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { - dt, ok := ci.nameToDataType[name] +func (ci *ConnInfo) TypeForName(name string) (*Type, bool) { + dt, ok := ci.nameToType[name] return dt, ok } -func (ci *ConnInfo) buildReflectTypeToDataType() { - ci.reflectTypeToDataType = make(map[reflect.Type]*DataType) +func (ci *ConnInfo) buildReflectTypeToType() { + ci.reflectTypeToType = make(map[reflect.Type]*Type) for reflectType, name := range ci.reflectTypeToName { - if dt, ok := ci.nameToDataType[name]; ok { - ci.reflectTypeToDataType[reflectType] = dt + if dt, ok := ci.nameToType[name]; ok { + ci.reflectTypeToType[reflectType] = dt } } } -// DataTypeForValue finds a data type suitable for v. Use RegisterDataType to register types that can encode and decode +// TypeForValue finds a data type suitable for v. Use RegisterType to register types that can encode and decode // themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. -func (ci *ConnInfo) DataTypeForValue(v interface{}) (*DataType, bool) { - if ci.reflectTypeToDataType == nil { - ci.buildReflectTypeToDataType() +func (ci *ConnInfo) TypeForValue(v interface{}) (*Type, bool) { + if ci.reflectTypeToType == nil { + ci.buildReflectTypeToType() } - dt, ok := ci.reflectTypeToDataType[reflect.TypeOf(v)] + dt, ok := ci.reflectTypeToType[reflect.TypeOf(v)] return dt, ok } @@ -1018,11 +1018,11 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) S } } - var dt *DataType + var dt *Type - if dataType, ok := ci.DataTypeForOID(oid); ok { + if dataType, ok := ci.TypeForOID(oid); ok { dt = dataType - } else if dataType, ok := ci.DataTypeForValue(target); ok { + } else if dataType, ok := ci.TypeForValue(target); ok { dt = dataType oid = dt.OID // Preserve assumed OID in case we are recursively called below. } @@ -1123,15 +1123,15 @@ func codecDecodeToTextFormat(codec Codec, ci *ConnInfo, oid uint32, format int16 // found then nil is returned. func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan { - var dt *DataType + var dt *Type if oid == 0 { - if dataType, ok := ci.DataTypeForValue(value); ok { + if dataType, ok := ci.TypeForValue(value); ok { dt = dataType oid = dt.OID // Preserve assumed OID in case we are recursively called below. } } else { - if dataType, ok := ci.DataTypeForOID(oid); ok { + if dataType, ok := ci.TypeForOID(oid); ok { dt = dataType } } diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go index f5091c36..98903d3a 100644 --- a/pgtype/range_codec.go +++ b/pgtype/range_codec.go @@ -72,11 +72,11 @@ func (r *GenericRange) SetBoundTypes(lower, upper BoundType) error { // RangeCodec is a codec for any range type. type RangeCodec struct { - ElementDataType *DataType + ElementType *Type } func (c *RangeCodec) FormatSupported(format int16) bool { - return c.ElementDataType.Codec.FormatSupported(format) + return c.ElementType.Codec.FormatSupported(format) } func (c *RangeCodec) PreferredFormat() int16 { @@ -149,7 +149,7 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value interface{}, b sp := len(buf) buf = pgio.AppendInt32(buf, -1) - lowerPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, BinaryFormatCode, lower) + lowerPlan := plan.ci.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, lower) if lowerPlan == nil { return nil, fmt.Errorf("cannot encode %v as element of range", lower) } @@ -173,7 +173,7 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value interface{}, b sp := len(buf) buf = pgio.AppendInt32(buf, -1) - upperPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, BinaryFormatCode, upper) + upperPlan := plan.ci.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, upper) if upperPlan == nil { return nil, fmt.Errorf("cannot encode %v as element of range", upper) } @@ -223,7 +223,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value interface{}, buf return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - lowerPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, TextFormatCode, lower) + lowerPlan := plan.ci.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, lower) if lowerPlan == nil { return nil, fmt.Errorf("cannot encode %v as element of range", lower) } @@ -244,7 +244,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value interface{}, buf return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - upperPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, TextFormatCode, upper) + upperPlan := plan.ci.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, upper) if upperPlan == nil { return nil, fmt.Errorf("cannot encode %v as element of range", upper) } @@ -311,7 +311,7 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target interface lowerTarget, upperTarget := rangeScanner.ScanBounds() if ubr.LowerType == Inclusive || ubr.LowerType == Exclusive { - lowerPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, BinaryFormatCode, lowerTarget) + lowerPlan := plan.ci.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, lowerTarget) if lowerPlan == nil { return fmt.Errorf("cannot scan into %v from range element", lowerTarget) } @@ -323,7 +323,7 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target interface } if ubr.UpperType == Inclusive || ubr.UpperType == Exclusive { - upperPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, BinaryFormatCode, upperTarget) + upperPlan := plan.ci.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, upperTarget) if upperPlan == nil { return fmt.Errorf("cannot scan into %v from range element", upperTarget) } @@ -361,7 +361,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target interface{} lowerTarget, upperTarget := rangeScanner.ScanBounds() if utr.LowerType == Inclusive || utr.LowerType == Exclusive { - lowerPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, TextFormatCode, lowerTarget) + lowerPlan := plan.ci.PlanScan(plan.rc.ElementType.OID, TextFormatCode, lowerTarget) if lowerPlan == nil { return fmt.Errorf("cannot scan into %v from range element", lowerTarget) } @@ -373,7 +373,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target interface{} } if utr.UpperType == Inclusive || utr.UpperType == Exclusive { - upperPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, TextFormatCode, upperTarget) + upperPlan := plan.ci.PlanScan(plan.rc.ElementType.OID, TextFormatCode, upperTarget) if upperPlan == nil { return fmt.Errorf("cannot scan into %v from range element", upperTarget) } diff --git a/rows.go b/rows.go index aa5310fe..25a42466 100644 --- a/rows.go +++ b/rows.go @@ -257,7 +257,7 @@ func (rows *connRows) Values() ([]interface{}, error) { continue } - if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok { + if dt, ok := rows.connInfo.TypeForOID(fd.DataTypeOID); ok { value, err := dt.Codec.DecodeValue(rows.connInfo, fd.DataTypeOID, fd.Format, buf) if err != nil { rows.fatal(err) diff --git a/stdlib/sql.go b/stdlib/sql.go index b0d92af3..b9079fc8 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -519,7 +519,7 @@ func (r *Rows) Columns() []string { // ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned. func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { - if dt, ok := r.conn.conn.ConnInfo().DataTypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { + if dt, ok := r.conn.conn.ConnInfo().TypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { return strings.ToUpper(dt.Name) } diff --git a/values.go b/values.go index a6fdcc86..7661a94e 100644 --- a/values.go +++ b/values.go @@ -79,7 +79,7 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e return int64(arg), nil } - if _, found := ci.DataTypeForValue(arg); found { + if _, found := ci.TypeForValue(arg); found { buf, err := ci.Encode(0, TextFormatCode, arg, nil) if err != nil { return nil, err @@ -123,7 +123,7 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 return encodePreparedStatementArgument(ci, buf, oid, arg) } - if _, ok := ci.DataTypeForOID(oid); ok { + if _, ok := ci.TypeForOID(oid); ok { sp := len(buf) buf = pgio.AppendInt32(buf, -1) argBuf, err := ci.Encode(oid, BinaryFormatCode, arg, buf) diff --git a/values_test.go b/values_test.go index f9cfd8ce..05139d29 100644 --- a/values_test.go +++ b/values_test.go @@ -79,7 +79,7 @@ func TestJSONAndJSONBTranscode(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { for _, typename := range []string{"json", "jsonb"} { - if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok { + if _, ok := conn.ConnInfo().TypeForName(typename); !ok { continue // No JSON/JSONB type -- must be running against old PostgreSQL } @@ -96,7 +96,7 @@ func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) { defer closeConn(t, conn) for _, typename := range []string{"json", "jsonb"} { - if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok { + if _, ok := conn.ConnInfo().TypeForName(typename); !ok { continue // No JSON/JSONB type -- must be running against old PostgreSQL } testJSONSingleLevelStringMap(t, conn, typename) From 1f2f239d097e983ce77bec7b53fd5254927bd527 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 21 Feb 2022 09:13:09 -0600 Subject: [PATCH 0914/1158] Renamed pgtype.ConnInfo to pgtype.Map --- bench_test.go | 2 +- conn.go | 28 ++-- conn_test.go | 8 +- copy_from.go | 2 +- copy_from_test.go | 2 +- extended_query_builder.go | 16 +- pgtype/array.go | 2 +- pgtype/array_codec.go | 50 +++--- pgtype/bits.go | 12 +- pgtype/bool.go | 12 +- pgtype/box.go | 12 +- pgtype/bytea.go | 12 +- pgtype/circle.go | 12 +- pgtype/composite.go | 72 ++++----- pgtype/composite_test.go | 8 +- pgtype/date.go | 12 +- pgtype/enum_codec.go | 10 +- pgtype/enum_codec_test.go | 4 +- pgtype/float4.go | 12 +- pgtype/float8.go | 12 +- pgtype/hstore.go | 12 +- pgtype/hstore_test.go | 2 +- pgtype/inet.go | 12 +- pgtype/int.go | 36 ++--- pgtype/int.go.erb | 12 +- pgtype/interval.go | 12 +- pgtype/json.go | 8 +- pgtype/jsonb.go | 16 +- pgtype/line.go | 12 +- pgtype/line_test.go | 2 +- pgtype/lseg.go | 12 +- pgtype/macaddr.go | 12 +- pgtype/numeric.go | 14 +- pgtype/path.go | 12 +- pgtype/pgtype.go | 316 +++++++++++++++++++------------------- pgtype/pgtype_test.go | 80 +++++----- pgtype/point.go | 12 +- pgtype/polygon.go | 12 +- pgtype/qchar.go | 12 +- pgtype/range_codec.go | 42 ++--- pgtype/record_codec.go | 20 +-- pgtype/text.go | 10 +- pgtype/tid.go | 12 +- pgtype/time.go | 12 +- pgtype/timestamp.go | 12 +- pgtype/timestamptz.go | 12 +- pgtype/uint32.go | 12 +- pgtype/uuid.go | 12 +- query_test.go | 2 +- rows.go | 18 +-- stdlib/sql.go | 30 ++-- values.go | 24 +-- values_test.go | 4 +- 53 files changed, 565 insertions(+), 563 deletions(-) diff --git a/bench_test.go b/bench_test.go index 587ce162..bd182ebd 100644 --- a/bench_test.go +++ b/bench_test.go @@ -918,7 +918,7 @@ func BenchmarkSelectManyRegisteredEnum(b *testing.B) { err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "color").Scan(&oid) require.NoError(b, err) - conn.ConnInfo().RegisterType(&pgtype.Type{Name: "color", OID: oid, Codec: &pgtype.EnumCodec{}}) + conn.TypeMap().RegisterType(&pgtype.Type{Name: "color", OID: oid, Codec: &pgtype.EnumCodec{}}) b.ResetTimer() var x, y, z string diff --git a/conn.go b/conn.go index 7a66b5e2..2a66a5ff 100644 --- a/conn.go +++ b/conn.go @@ -71,7 +71,7 @@ type Conn struct { doneChan chan struct{} closedChan chan error - connInfo *pgtype.ConnInfo + typeMap *pgtype.Map wbuf []byte eqb extendedQueryBuilder @@ -202,7 +202,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { c = &Conn{ config: originalConfig, - connInfo: pgtype.NewConnInfo(), + typeMap: pgtype.NewMap(), logLevel: config.LogLevel, logger: config.Logger, } @@ -375,8 +375,8 @@ func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn } // StatementCache returns the statement cache used for this connection. func (c *Conn) StatementCache() stmtcache.Cache { return c.stmtcache } -// ConnInfo returns the connection info used for this connection. -func (c *Conn) ConnInfo() *pgtype.ConnInfo { return c.connInfo } +// TypeMap returns the connection info used for this connection. +func (c *Conn) TypeMap() *pgtype.Map { return c.typeMap } // Config returns a copy of config that was used to establish this connection. func (c *Conn) Config() *ConnConfig { return c.config.Copy() } @@ -476,14 +476,14 @@ func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, argu } for i := range args { - err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) + err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) if err != nil { return err } } for i := range sd.Fields { - c.eqb.AppendResultFormat(c.ConnInfo().FormatCodeForOID(sd.Fields[i].DataTypeOID)) + c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) } return nil @@ -516,7 +516,7 @@ func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *con r.ctx = ctx r.logger = c - r.connInfo = c.connInfo + r.typeMap = c.typeMap r.startTime = time.Now() r.sql = sql r.args = args @@ -622,7 +622,7 @@ optionLoop: } for i := range args { - err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) + err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) if err != nil { rows.fatal(err) return rows, rows.err @@ -638,7 +638,7 @@ optionLoop: if resultFormats == nil { for i := range sd.Fields { - c.eqb.AppendResultFormat(c.ConnInfo().FormatCodeForOID(sd.Fields[i].DataTypeOID)) + c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) } resultFormats = c.eqb.resultFormats @@ -781,14 +781,14 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } for i := range args { - err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) + err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } } for i := range sd.Fields { - c.eqb.AppendResultFormat(c.ConnInfo().FormatCodeForOID(sd.Fields[i].DataTypeOID)) + c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) } if sd.Name == "" { @@ -823,7 +823,7 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, var err error valueArgs := make([]interface{}, len(args)) for i, a := range args { - valueArgs[i], err = convertSimpleArgument(c.connInfo, a) + valueArgs[i], err = convertSimpleArgument(c.typeMap, a) if err != nil { return "", err } @@ -855,7 +855,7 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err return nil, err } - dt, ok := c.ConnInfo().TypeForOID(elementOID) + dt, ok := c.TypeMap().TypeForOID(elementOID) if !ok { return nil, errors.New("array element OID not registered") } @@ -904,7 +904,7 @@ order by attnum`, []interface{}{typrelid}, []interface{}{&fieldName, &fieldOID}, func(qfr QueryFuncRow) error { - dt, ok := c.ConnInfo().TypeForOID(fieldOID) + dt, ok := c.TypeMap().TypeForOID(fieldOID) if !ok { return fmt.Errorf("unknown composite type field OID: %v", fieldOID) } diff --git a/conn_test.go b/conn_test.go index 22928a01..3240c954 100644 --- a/conn_test.go +++ b/conn_test.go @@ -829,7 +829,7 @@ func TestIdentifierSanitize(t *testing.T) { } } -func TestConnInitConnInfo(t *testing.T) { +func TestConnInitTypeMap(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) @@ -841,11 +841,11 @@ func TestConnInitConnInfo(t *testing.T) { "text": pgtype.TextOID, } for name, oid := range nameOIDs { - dtByName, ok := conn.ConnInfo().TypeForName(name) + dtByName, ok := conn.TypeMap().TypeForName(name) if !ok { t.Fatalf("Expected type named %v to be present", name) } - dtByOID, ok := conn.ConnInfo().TypeForOID(oid) + dtByOID, ok := conn.TypeMap().TypeForOID(oid) if !ok { t.Fatalf("Expected type OID %v to be present", oid) } @@ -891,7 +891,7 @@ func TestDomainType(t *testing.T) { if err != nil { t.Fatalf("did not find uint64 OID, %v", err) } - conn.ConnInfo().RegisterType(&pgtype.Type{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}}) + conn.TypeMap().RegisterType(&pgtype.Type{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}}) var n uint64 err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) diff --git a/copy_from.go b/copy_from.go index 0bf3478f..8eb3c111 100644 --- a/copy_from.go +++ b/copy_from.go @@ -178,7 +178,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) for i, val := range values { - buf, err = encodePreparedStatementArgument(ct.conn.connInfo, buf, sd.Fields[i].DataTypeOID, val) + buf, err = encodePreparedStatementArgument(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) if err != nil { return false, nil, err } diff --git a/copy_from_test.go b/copy_from_test.go index 867b53af..5c22dc35 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -256,7 +256,7 @@ func TestConnCopyFromJSON(t *testing.T) { defer closeConn(t, conn) for _, typeName := range []string{"json", "jsonb"} { - if _, ok := conn.ConnInfo().TypeForName(typeName); !ok { + if _, ok := conn.TypeMap().TypeForName(typeName); !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } } diff --git a/extended_query_builder.go b/extended_query_builder.go index a28900c4..51759362 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -14,11 +14,11 @@ type extendedQueryBuilder struct { resultFormats []int16 } -func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error { - f := chooseParameterFormatCode(ci, oid, arg) +func (eqb *extendedQueryBuilder) AppendParam(m *pgtype.Map, oid uint32, arg interface{}) error { + f := chooseParameterFormatCode(m, oid, arg) eqb.paramFormats = append(eqb.paramFormats, f) - v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg) + v, err := eqb.encodeExtendedParamValue(m, oid, f, arg) if err != nil { return err } @@ -54,7 +54,7 @@ func (eqb *extendedQueryBuilder) Reset() { } } -func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) { +func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg interface{}) ([]byte, error) { if arg == nil { return nil, nil } @@ -80,11 +80,11 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o // We have already checked that arg is not pointing to nil, // so it is safe to dereference here. arg = refVal.Elem().Interface() - return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg) + return eqb.encodeExtendedParamValue(m, oid, formatCode, arg) } - if _, ok := ci.TypeForOID(oid); ok { - buf, err := ci.Encode(oid, formatCode, arg, eqb.paramValueBytes) + if _, ok := m.TypeForOID(oid); ok { + buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes) if err != nil { return nil, err } @@ -96,7 +96,7 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o } if strippedArg, ok := stripNamedType(&refVal); ok { - return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg) + return eqb.encodeExtendedParamValue(m, oid, formatCode, strippedArg) } return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } diff --git a/pgtype/array.go b/pgtype/array.go index d1a78e64..3648f385 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -42,7 +42,7 @@ func cardinality(dimensions []ArrayDimension) int { return elementCount } -func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { +func (dst *ArrayHeader) DecodeBinary(m *Map, src []byte) (int, error) { if len(src) < 12 { return 0, fmt.Errorf("array header too short: %d", len(src)) } diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 0b32b37f..dd14e83e 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -49,7 +49,7 @@ func (c *ArrayCodec) PreferredFormat() int16 { return c.ElementType.Codec.PreferredFormat() } -func (c *ArrayCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (c *ArrayCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { arrayValuer, ok := value.(ArrayGetter) if !ok { return nil @@ -57,16 +57,16 @@ func (c *ArrayCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value in elementType := arrayValuer.IndexType() - elementEncodePlan := ci.PlanEncode(c.ElementType.OID, format, elementType) + elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType) if elementEncodePlan == nil { return nil } switch format { case BinaryFormatCode: - return &encodePlanArrayCodecBinary{ac: c, ci: ci, oid: oid} + return &encodePlanArrayCodecBinary{ac: c, m: m, oid: oid} case TextFormatCode: - return &encodePlanArrayCodecText{ac: c, ci: ci, oid: oid} + return &encodePlanArrayCodecText{ac: c, m: m, oid: oid} } return nil @@ -74,7 +74,7 @@ func (c *ArrayCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value in type encodePlanArrayCodecText struct { ac *ArrayCodec - ci *ConnInfo + m *Map oid uint32 } @@ -124,7 +124,7 @@ func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf elemType := reflect.TypeOf(elem) if lastElemType != elemType { lastElemType = elemType - encodePlan = p.ci.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem) + encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem) if encodePlan == nil { return nil, fmt.Errorf("unable to encode %v", array.Index(i)) } @@ -153,7 +153,7 @@ func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf type encodePlanArrayCodecBinary struct { ac *ArrayCodec - ci *ConnInfo + m *Map oid uint32 } @@ -188,7 +188,7 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB elemType := reflect.TypeOf(elem) if lastElemType != elemType { lastElemType = elemType - encodePlan = p.ci.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem) + encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem) if encodePlan == nil { return nil, fmt.Errorf("unable to encode %v", array.Index(i)) } @@ -210,7 +210,7 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB return buf, nil } -func (c *ArrayCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { arrayScanner, ok := target.(ArraySetter) if !ok { return nil @@ -218,22 +218,22 @@ func (c *ArrayCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target int elementType := arrayScanner.ScanIndexType() - elementScanPlan := ci.PlanScan(c.ElementType.OID, format, elementType) + elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType) if _, ok := elementScanPlan.(*scanPlanFail); ok { return nil } return &scanPlanArrayCodec{ arrayCodec: c, - ci: ci, + m: m, oid: oid, formatCode: format, } } -func (c *ArrayCodec) decodeBinary(ci *ConnInfo, arrayOID uint32, src []byte, array ArraySetter) error { +func (c *ArrayCodec) decodeBinary(m *Map, arrayOID uint32, src []byte, array ArraySetter) error { var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) + rp, err := arrayHeader.DecodeBinary(m, src) if err != nil { return err } @@ -248,9 +248,9 @@ func (c *ArrayCodec) decodeBinary(ci *ConnInfo, arrayOID uint32, src []byte, arr return nil } - elementScanPlan := c.ElementType.Codec.PlanScan(ci, c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0), false) + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0), false) if elementScanPlan == nil { - elementScanPlan = ci.PlanScan(c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0)) + elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0)) } for i := 0; i < elementCount; i++ { @@ -271,7 +271,7 @@ func (c *ArrayCodec) decodeBinary(ci *ConnInfo, arrayOID uint32, src []byte, arr return nil } -func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array ArraySetter) error { +func (c *ArrayCodec) decodeText(m *Map, arrayOID uint32, src []byte, array ArraySetter) error { uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err @@ -286,9 +286,9 @@ func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array return nil } - elementScanPlan := c.ElementType.Codec.PlanScan(ci, c.ElementType.OID, TextFormatCode, array.ScanIndex(0), false) + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, array.ScanIndex(0), false) if elementScanPlan == nil { - elementScanPlan = ci.PlanScan(c.ElementType.OID, TextFormatCode, array.ScanIndex(0)) + elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, array.ScanIndex(0)) } for i, s := range uta.Elements { @@ -309,7 +309,7 @@ func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array type scanPlanArrayCodec struct { arrayCodec *ArrayCodec - ci *ConnInfo + m *Map oid uint32 formatCode int16 elementScanPlan ScanPlan @@ -317,7 +317,7 @@ type scanPlanArrayCodec struct { func (spac *scanPlanArrayCodec) Scan(src []byte, dst interface{}) error { c := spac.arrayCodec - ci := spac.ci + m := spac.m oid := spac.oid formatCode := spac.formatCode @@ -329,15 +329,15 @@ func (spac *scanPlanArrayCodec) Scan(src []byte, dst interface{}) error { switch formatCode { case BinaryFormatCode: - return c.decodeBinary(ci, oid, src, array) + return c.decodeBinary(m, oid, src, array) case TextFormatCode: - return c.decodeText(ci, oid, src, array) + return c.decodeText(m, oid, src, array) default: return fmt.Errorf("unknown format code %d", formatCode) } } -func (c *ArrayCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *ArrayCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -354,13 +354,13 @@ func (c *ArrayCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int } } -func (c *ArrayCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c *ArrayCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var slice []interface{} - err := ci.PlanScan(oid, format, &slice).Scan(src, &slice) + err := m.PlanScan(oid, format, &slice).Scan(src, &slice) return slice, err } diff --git a/pgtype/bits.go b/pgtype/bits.go index 541a3a6b..8afd8d90 100644 --- a/pgtype/bits.go +++ b/pgtype/bits.go @@ -70,7 +70,7 @@ func (BitsCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (BitsCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (BitsCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(BitsValuer); !ok { return nil } @@ -126,7 +126,7 @@ func (encodePlanBitsCodecText) Encode(value interface{}, buf []byte) (newBuf []b return buf, nil } -func (BitsCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -144,17 +144,17 @@ func (BitsCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa return nil } -func (c BitsCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c BitsCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c BitsCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c BitsCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var box Bits - err := codecScan(c, ci, oid, format, src, &box) + err := codecScan(c, m, oid, format, src, &box) if err != nil { return nil, err } diff --git a/pgtype/bool.go b/pgtype/bool.go index 5aa06870..1158ad06 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -106,7 +106,7 @@ func (BoolCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (BoolCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (BoolCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -197,7 +197,7 @@ func (encodePlanBoolCodecTextBool) Encode(value interface{}, buf []byte) (newBuf return buf, nil } -func (BoolCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -219,17 +219,17 @@ func (BoolCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa return nil } -func (c BoolCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return c.DecodeValue(ci, oid, format, src) +func (c BoolCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) } -func (c BoolCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c BoolCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var b bool - err := codecScan(c, ci, oid, format, src, &b) + err := codecScan(c, m, oid, format, src, &b) if err != nil { return nil, err } diff --git a/pgtype/box.go b/pgtype/box.go index 6c637308..6e44c436 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -71,7 +71,7 @@ func (BoxCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (BoxCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (BoxCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(BoxValuer); !ok { return nil } @@ -126,7 +126,7 @@ func (encodePlanBoxCodecText) Encode(value interface{}, buf []byte) (newBuf []by return buf, nil } -func (BoxCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -220,17 +220,17 @@ func (scanPlanTextAnyToBoxScanner) Scan(src []byte, dst interface{}) error { return scanner.ScanBox(Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true}) } -func (c BoxCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c BoxCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c BoxCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c BoxCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var box Box - err := codecScan(c, ci, oid, format, src, &box) + err := codecScan(c, m, oid, format, src, &box) if err != nil { return nil, err } diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 58f3b348..eb865df0 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -72,7 +72,7 @@ func (ByteaCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (ByteaCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (ByteaCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -147,7 +147,7 @@ func (encodePlanBytesCodecTextBytesValuer) Encode(value interface{}, buf []byte) return buf, nil } -func (ByteaCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -237,17 +237,17 @@ func decodeHexBytea(src []byte) ([]byte, error) { return buf, nil } -func (c ByteaCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c ByteaCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c ByteaCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c ByteaCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var buf []byte - err := codecScan(c, ci, oid, format, src, &buf) + err := codecScan(c, m, oid, format, src, &buf) if err != nil { return nil, err } diff --git a/pgtype/circle.go b/pgtype/circle.go index 8e06de88..6a83b41f 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -72,7 +72,7 @@ func (CircleCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (CircleCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (CircleCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(CircleValuer); !ok { return nil } @@ -125,7 +125,7 @@ func (encodePlanCircleCodecText) Encode(value interface{}, buf []byte) (newBuf [ return buf, nil } -func (CircleCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (CircleCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { @@ -142,17 +142,17 @@ func (CircleCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target inter return nil } -func (c CircleCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c CircleCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c CircleCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c CircleCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var circle Circle - err := codecScan(c, ci, oid, format, src, &circle) + err := codecScan(c, m, oid, format, src, &circle) if err != nil { return nil, err } diff --git a/pgtype/composite.go b/pgtype/composite.go index 1142a704..c538834d 100644 --- a/pgtype/composite.go +++ b/pgtype/composite.go @@ -54,16 +54,16 @@ func (c *CompositeCodec) PreferredFormat() int16 { return TextFormatCode } -func (c *CompositeCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (c *CompositeCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(CompositeIndexGetter); !ok { return nil } switch format { case BinaryFormatCode: - return &encodePlanCompositeCodecCompositeIndexGetterToBinary{cc: c, ci: ci} + return &encodePlanCompositeCodecCompositeIndexGetterToBinary{cc: c, m: m} case TextFormatCode: - return &encodePlanCompositeCodecCompositeIndexGetterToText{cc: c, ci: ci} + return &encodePlanCompositeCodecCompositeIndexGetterToText{cc: c, m: m} } return nil @@ -71,7 +71,7 @@ func (c *CompositeCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, valu type encodePlanCompositeCodecCompositeIndexGetterToBinary struct { cc *CompositeCodec - ci *ConnInfo + m *Map } func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { @@ -81,7 +81,7 @@ func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value i return nil, nil } - builder := NewCompositeBinaryBuilder(plan.ci, buf) + builder := NewCompositeBinaryBuilder(plan.m, buf) for i, field := range plan.cc.Fields { builder.AppendValue(field.Type.OID, getter.Index(i)) } @@ -91,7 +91,7 @@ func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value i type encodePlanCompositeCodecCompositeIndexGetterToText struct { cc *CompositeCodec - ci *ConnInfo + m *Map } func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { @@ -101,7 +101,7 @@ func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value int return nil, nil } - b := NewCompositeTextBuilder(plan.ci, buf) + b := NewCompositeTextBuilder(plan.m, buf) for i, field := range plan.cc.Fields { b.AppendValue(field.Type.OID, getter.Index(i)) } @@ -109,17 +109,17 @@ func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value int return b.Finish() } -func (c *CompositeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case CompositeIndexScanner: - return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, ci: ci} + return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, m: m} } case TextFormatCode: switch target.(type) { case CompositeIndexScanner: - return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, ci: ci} + return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, m: m} } } @@ -128,7 +128,7 @@ func (c *CompositeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target type scanPlanBinaryCompositeToCompositeIndexScanner struct { cc *CompositeCodec - ci *ConnInfo + m *Map } func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target interface{}) error { @@ -138,12 +138,12 @@ func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, tar return targetScanner.ScanNull() } - scanner := NewCompositeBinaryScanner(plan.ci, src) + scanner := NewCompositeBinaryScanner(plan.m, src) for i, field := range plan.cc.Fields { if scanner.Next() { fieldTarget := targetScanner.ScanIndex(i) if fieldTarget != nil { - fieldPlan := plan.ci.PlanScan(field.Type.OID, BinaryFormatCode, fieldTarget) + fieldPlan := plan.m.PlanScan(field.Type.OID, BinaryFormatCode, fieldTarget) if fieldPlan == nil { return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.Type.OID) } @@ -167,7 +167,7 @@ func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, tar type scanPlanTextCompositeToCompositeIndexScanner struct { cc *CompositeCodec - ci *ConnInfo + m *Map } func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target interface{}) error { @@ -177,12 +177,12 @@ func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, targe return targetScanner.ScanNull() } - scanner := NewCompositeTextScanner(plan.ci, src) + scanner := NewCompositeTextScanner(plan.m, src) for i, field := range plan.cc.Fields { if scanner.Next() { fieldTarget := targetScanner.ScanIndex(i) if fieldTarget != nil { - fieldPlan := plan.ci.PlanScan(field.Type.OID, TextFormatCode, fieldTarget) + fieldPlan := plan.m.PlanScan(field.Type.OID, TextFormatCode, fieldTarget) if fieldPlan == nil { return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.Type.OID) } @@ -204,7 +204,7 @@ func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, targe return nil } -func (c *CompositeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *CompositeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -221,18 +221,18 @@ func (c *CompositeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format } } -func (c *CompositeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } switch format { case TextFormatCode: - scanner := NewCompositeTextScanner(ci, src) + scanner := NewCompositeTextScanner(m, src) values := make(map[string]interface{}, len(c.Fields)) for i := 0; scanner.Next() && i < len(c.Fields); i++ { var v interface{} - fieldPlan := ci.PlanScan(c.Fields[i].Type.OID, TextFormatCode, &v) + fieldPlan := m.PlanScan(c.Fields[i].Type.OID, TextFormatCode, &v) if fieldPlan == nil { return nil, fmt.Errorf("unable to scan OID %d in text format into %v", c.Fields[i].Type.OID, v) } @@ -251,11 +251,11 @@ func (c *CompositeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src return values, nil case BinaryFormatCode: - scanner := NewCompositeBinaryScanner(ci, src) + scanner := NewCompositeBinaryScanner(m, src) values := make(map[string]interface{}, len(c.Fields)) for i := 0; scanner.Next() && i < len(c.Fields); i++ { var v interface{} - fieldPlan := ci.PlanScan(scanner.OID(), BinaryFormatCode, &v) + fieldPlan := m.PlanScan(scanner.OID(), BinaryFormatCode, &v) if fieldPlan == nil { return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v) } @@ -280,7 +280,7 @@ func (c *CompositeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src } type CompositeBinaryScanner struct { - ci *ConnInfo + m *Map rp int src []byte @@ -291,7 +291,7 @@ type CompositeBinaryScanner struct { } // NewCompositeBinaryScanner a scanner over a binary encoded composite balue. -func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner { +func NewCompositeBinaryScanner(m *Map, src []byte) *CompositeBinaryScanner { rp := 0 if len(src[rp:]) < 4 { return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)} @@ -301,7 +301,7 @@ func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner rp += 4 return &CompositeBinaryScanner{ - ci: ci, + m: m, rp: rp, src: src, fieldCount: fieldCount, @@ -363,7 +363,7 @@ func (cfs *CompositeBinaryScanner) Err() error { } type CompositeTextScanner struct { - ci *ConnInfo + m *Map rp int src []byte @@ -372,7 +372,7 @@ type CompositeTextScanner struct { } // NewCompositeTextScanner a scanner over a text encoded composite value. -func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner { +func NewCompositeTextScanner(m *Map, src []byte) *CompositeTextScanner { if len(src) < 2 { return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)} } @@ -386,7 +386,7 @@ func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner { } return &CompositeTextScanner{ - ci: ci, + m: m, rp: 1, src: src, } @@ -459,17 +459,17 @@ func (cfs *CompositeTextScanner) Err() error { } type CompositeBinaryBuilder struct { - ci *ConnInfo + m *Map buf []byte startIdx int fieldCount uint32 err error } -func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder { +func NewCompositeBinaryBuilder(m *Map, buf []byte) *CompositeBinaryBuilder { startIdx := len(buf) buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields - return &CompositeBinaryBuilder{ci: ci, buf: buf, startIdx: startIdx} + return &CompositeBinaryBuilder{m: m, buf: buf, startIdx: startIdx} } func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { @@ -484,7 +484,7 @@ func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { return } - plan := b.ci.PlanEncode(oid, BinaryFormatCode, field) + plan := b.m.PlanEncode(oid, BinaryFormatCode, field) if plan == nil { b.err = fmt.Errorf("unable to encode %v into OID %d in binary format", field, oid) return @@ -516,7 +516,7 @@ func (b *CompositeBinaryBuilder) Finish() ([]byte, error) { } type CompositeTextBuilder struct { - ci *ConnInfo + m *Map buf []byte startIdx int fieldCount uint32 @@ -524,9 +524,9 @@ type CompositeTextBuilder struct { fieldBuf [32]byte } -func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder { +func NewCompositeTextBuilder(m *Map, buf []byte) *CompositeTextBuilder { buf = append(buf, '(') // allocate room for number of fields - return &CompositeTextBuilder{ci: ci, buf: buf} + return &CompositeTextBuilder{m: m, buf: buf} } func (b *CompositeTextBuilder) AppendValue(oid uint32, field interface{}) { @@ -539,7 +539,7 @@ func (b *CompositeTextBuilder) AppendValue(oid uint32, field interface{}) { return } - plan := b.ci.PlanEncode(oid, TextFormatCode, field) + plan := b.m.PlanEncode(oid, TextFormatCode, field) if plan == nil { b.err = fmt.Errorf("unable to encode %v into OID %d in text format", field, oid) return diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index f96a6470..d97f617b 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -26,7 +26,7 @@ create type ct_test as ( dt, err := conn.LoadType(context.Background(), "ct_test") require.NoError(t, err) - conn.ConnInfo().RegisterType(dt) + conn.TypeMap().RegisterType(dt) formats := []struct { name string @@ -105,7 +105,7 @@ create type point3d as ( dt, err := conn.LoadType(context.Background(), "point3d") require.NoError(t, err) - conn.ConnInfo().RegisterType(dt) + conn.TypeMap().RegisterType(dt) formats := []struct { name string @@ -140,7 +140,7 @@ create type point3d as ( dt, err := conn.LoadType(context.Background(), "point3d") require.NoError(t, err) - conn.ConnInfo().RegisterType(dt) + conn.TypeMap().RegisterType(dt) formats := []struct { name string @@ -179,7 +179,7 @@ create type point3d as ( dt, err := conn.LoadType(context.Background(), "point3d") require.NoError(t, err) - conn.ConnInfo().RegisterType(dt) + conn.TypeMap().RegisterType(dt) formats := []struct { name string diff --git a/pgtype/date.go b/pgtype/date.go index adfa0999..fe917a3e 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -126,7 +126,7 @@ func (DateCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (DateCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (DateCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(DateValuer); !ok { return nil } @@ -196,7 +196,7 @@ func (encodePlanDateCodecText) Encode(value interface{}, buf []byte) (newBuf []b return append(buf, s...), nil } -func (DateCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -265,17 +265,17 @@ func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst interface{}) error { } } -func (c DateCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c DateCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c DateCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c DateCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var date Date - err := codecScan(c, ci, oid, format, src, &date) + err := codecScan(c, m, oid, format, src, &date) if err != nil { return nil, err } diff --git a/pgtype/enum_codec.go b/pgtype/enum_codec.go index 970895d8..3dce4449 100644 --- a/pgtype/enum_codec.go +++ b/pgtype/enum_codec.go @@ -20,7 +20,7 @@ func (EnumCodec) PreferredFormat() int16 { return TextFormatCode } -func (EnumCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (EnumCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case TextFormatCode, BinaryFormatCode: switch value.(type) { @@ -38,7 +38,7 @@ func (EnumCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interf return nil } -func (c *EnumCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (c *EnumCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case TextFormatCode, BinaryFormatCode: switch target.(type) { @@ -56,11 +56,11 @@ func (c *EnumCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target inte return nil } -func (c *EnumCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return c.DecodeValue(ci, oid, format, src) +func (c *EnumCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) } -func (c *EnumCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c *EnumCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } diff --git a/pgtype/enum_codec_test.go b/pgtype/enum_codec_test.go index f8dba1a0..afd062a2 100644 --- a/pgtype/enum_codec_test.go +++ b/pgtype/enum_codec_test.go @@ -21,7 +21,7 @@ create type enum_test as enum ('foo', 'bar', 'baz');`) dt, err := conn.LoadType(context.Background(), "enum_test") require.NoError(t, err) - conn.ConnInfo().RegisterType(dt) + conn.TypeMap().RegisterType(dt) var s string err = conn.QueryRow(context.Background(), `select 'foo'::enum_test`).Scan(&s) @@ -58,7 +58,7 @@ create type enum_test as enum ('foo', 'bar', 'baz');`) dt, err := conn.LoadType(context.Background(), "enum_test") require.NoError(t, err) - conn.ConnInfo().RegisterType(dt) + conn.TypeMap().RegisterType(dt) rows, err := conn.Query(context.Background(), `select 'foo'::enum_test`) require.NoError(t, err) diff --git a/pgtype/float4.go b/pgtype/float4.go index fd5f4523..9ca6fe6a 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -75,7 +75,7 @@ func (Float4Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Float4Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (Float4Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -145,7 +145,7 @@ func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value interface{}, buf []by return pgio.AppendUint32(buf, math.Float32bits(f)), nil } -func (Float4Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -247,26 +247,26 @@ func (scanPlanTextAnyToFloat32) Scan(src []byte, dst interface{}) error { return nil } -func (c Float4Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c Float4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } var n float64 - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } return n, nil } -func (c Float4Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var n float32 - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } diff --git a/pgtype/float8.go b/pgtype/float8.go index 54b1796c..d5461ab0 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -83,7 +83,7 @@ func (Float8Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Float8Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (Float8Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -183,7 +183,7 @@ func (encodePlanTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf [ return append(buf, strconv.FormatInt(n.Int, 10)...), nil } -func (Float8Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -302,17 +302,17 @@ func (scanPlanTextAnyToFloat64Scanner) Scan(src []byte, dst interface{}) error { return s.ScanFloat64(Float8{Float: n, Valid: true}) } -func (c Float8Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return c.DecodeValue(ci, oid, format, src) +func (c Float8Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) } -func (c Float8Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Float8Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var n float64 - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } diff --git a/pgtype/hstore.go b/pgtype/hstore.go index dc5caa84..a27330ae 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -72,7 +72,7 @@ func (HstoreCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (HstoreCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (HstoreCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(HstoreValuer); !ok { return nil } @@ -150,7 +150,7 @@ func (encodePlanHstoreCodecText) Encode(value interface{}, buf []byte) (newBuf [ return buf, nil } -func (HstoreCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -254,17 +254,17 @@ func (scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst interface{}) error { return scanner.ScanHstore(m) } -func (c HstoreCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c HstoreCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var hstore Hstore - err := codecScan(c, ci, oid, format, src, &hstore) + err := codecScan(c, m, oid, format, src, &hstore) if err != nil { return nil, err } diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index 2437a240..8141687a 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -61,7 +61,7 @@ func TestHstoreCodec(t *testing.T) { t.Skipf("Skipping: cannot find hstore OID") } - conn.ConnInfo().RegisterType(&pgtype.Type{Name: "hstore", OID: hstoreOID, Codec: pgtype.HstoreCodec{}}) + conn.TypeMap().RegisterType(&pgtype.Type{Name: "hstore", OID: hstoreOID, Codec: pgtype.HstoreCodec{}}) formats := []struct { name string diff --git a/pgtype/inet.go b/pgtype/inet.go index 9530d1a2..ab4cff47 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -75,7 +75,7 @@ func (InetCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (InetCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (InetCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(InetValuer); !ok { return nil } @@ -140,7 +140,7 @@ func (encodePlanInetCodecText) Encode(value interface{}, buf []byte) (newBuf []b return append(buf, inet.IPNet.String()...), nil } -func (InetCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -158,17 +158,17 @@ func (InetCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa return nil } -func (c InetCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c InetCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c InetCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c InetCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var inet Inet - err := codecScan(c, ci, oid, format, src, &inet) + err := codecScan(c, m, oid, format, src, &inet) if err != nil { return nil, err } diff --git a/pgtype/int.go b/pgtype/int.go index 237fe7e7..bfdb0184 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -127,7 +127,7 @@ func (Int2Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int2Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (Int2Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -206,7 +206,7 @@ func (encodePlanInt2CodecTextInt64Valuer) Encode(value interface{}, buf []byte) return append(buf, strconv.FormatInt(n.Int, 10)...), nil } -func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Int2Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -264,26 +264,26 @@ func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa return nil } -func (c Int2Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c Int2Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } var n int64 - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } return n, nil } -func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Int2Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var n int16 - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } @@ -664,7 +664,7 @@ func (Int4Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int4Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (Int4Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -743,7 +743,7 @@ func (encodePlanInt4CodecTextInt64Valuer) Encode(value interface{}, buf []byte) return append(buf, strconv.FormatInt(n.Int, 10)...), nil } -func (Int4Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Int4Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -801,26 +801,26 @@ func (Int4Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa return nil } -func (c Int4Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c Int4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } var n int64 - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } return n, nil } -func (c Int4Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Int4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var n int32 - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } @@ -1212,7 +1212,7 @@ func (Int8Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int8Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (Int8Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -1291,7 +1291,7 @@ func (encodePlanInt8CodecTextInt64Valuer) Encode(value interface{}, buf []byte) return append(buf, strconv.FormatInt(n.Int, 10)...), nil } -func (Int8Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Int8Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -1349,26 +1349,26 @@ func (Int8Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa return nil } -func (c Int8Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c Int8Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } var n int64 - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } return n, nil } -func (c Int8Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Int8Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var n int64 - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 18e708fa..cec88984 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -128,7 +128,7 @@ func (Int<%= pg_byte_size %>Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int<%= pg_byte_size %>Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (Int<%= pg_byte_size %>Codec) PlanEncode(m *TypeMap, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -207,7 +207,7 @@ func (encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer) Encode(value interfa return append(buf, strconv.FormatInt(n.Int, 10)...), nil } -func (Int<%= pg_byte_size %>Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Int<%= pg_byte_size %>Codec) PlanScan(m *TypeMap, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -265,26 +265,26 @@ func (Int<%= pg_byte_size %>Codec) PlanScan(ci *ConnInfo, oid uint32, format int return nil } -func (c Int<%= pg_byte_size %>Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c Int<%= pg_byte_size %>Codec) DecodeDatabaseSQLValue(m *TypeMap, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } var n int64 - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } return n, nil } -func (c Int<%= pg_byte_size %>Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Int<%= pg_byte_size %>Codec) DecodeValue(m *TypeMap, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var n int<%= pg_bit_size %> - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } diff --git a/pgtype/interval.go b/pgtype/interval.go index a20266eb..a13969c3 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -80,7 +80,7 @@ func (IntervalCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (IntervalCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (IntervalCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(IntervalValuer); !ok { return nil } @@ -151,7 +151,7 @@ func (encodePlanIntervalCodecText) Encode(value interface{}, buf []byte) (newBuf return buf, nil } -func (IntervalCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -274,17 +274,17 @@ func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst interface{}) error return scanner.ScanInterval(Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true}) } -func (c IntervalCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c IntervalCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c IntervalCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c IntervalCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var interval Interval - err := codecScan(c, ci, oid, format, src, &interval) + err := codecScan(c, m, oid, format, src, &interval) if err != nil { return nil, err } diff --git a/pgtype/json.go b/pgtype/json.go index cd8b8ec9..04ce6f6b 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -16,7 +16,7 @@ func (JSONCodec) PreferredFormat() int16 { return TextFormatCode } -func (JSONCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch value.(type) { case []byte: return encodePlanJSONCodecEitherFormatByteSlice{} @@ -49,7 +49,7 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value interface{}, buf []by return buf, nil } -func (JSONCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch target.(type) { case *string: return scanPlanAnyToString{} @@ -110,7 +110,7 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst interface{}) error { return json.Unmarshal(src, dst) } -func (c JSONCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -120,7 +120,7 @@ func (c JSONCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16 return dstBuf, nil } -func (c JSONCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 07ea58bc..0504ee62 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -16,15 +16,15 @@ func (JSONBCodec) PreferredFormat() int16 { return TextFormatCode } -func (JSONBCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - plan := JSONCodec{}.PlanEncode(ci, oid, TextFormatCode, value) + plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value) if plan != nil { return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan} } case TextFormatCode: - return JSONCodec{}.PlanEncode(ci, oid, format, value) + return JSONCodec{}.PlanEncode(m, oid, format, value) } return nil @@ -39,15 +39,15 @@ func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value interface{}, buf []b return plan.textPlan.Encode(value, buf) } -func (JSONBCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: - plan := JSONCodec{}.PlanScan(ci, oid, TextFormatCode, target, actualTarget) + plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target, actualTarget) if plan != nil { return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan} } case TextFormatCode: - return JSONCodec{}.PlanScan(ci, oid, format, target, actualTarget) + return JSONCodec{}.PlanScan(m, oid, format, target, actualTarget) } return nil @@ -73,7 +73,7 @@ func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst interface{}) return plan.textPlan.Scan(src[1:], dst) } -func (c JSONBCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -100,7 +100,7 @@ func (c JSONBCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int1 } } -func (c JSONBCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } diff --git a/pgtype/line.go b/pgtype/line.go index acae903b..17acab81 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -75,7 +75,7 @@ func (LineCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (LineCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (LineCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(LineValuer); !ok { return nil } @@ -128,7 +128,7 @@ func (encodePlanLineCodecText) Encode(value interface{}, buf []byte) (newBuf []b return buf, nil } -func (LineCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -207,17 +207,17 @@ func (scanPlanTextAnyToLineScanner) Scan(src []byte, dst interface{}) error { return scanner.ScanLine(Line{A: a, B: b, C: c, Valid: true}) } -func (c LineCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c LineCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c LineCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c LineCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var line Line - err := codecScan(c, ci, oid, format, src, &line) + err := codecScan(c, m, oid, format, src, &line) if err != nil { return nil, err } diff --git a/pgtype/line_test.go b/pgtype/line_test.go index 3ed8fc4b..6c7b734b 100644 --- a/pgtype/line_test.go +++ b/pgtype/line_test.go @@ -11,7 +11,7 @@ import ( func TestLineTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) defer conn.Close(context.Background()) - if _, ok := conn.ConnInfo().TypeForName("line"); !ok { + if _, ok := conn.TypeMap().TypeForName("line"); !ok { t.Skip("Skipping due to no line type") } diff --git a/pgtype/lseg.go b/pgtype/lseg.go index 471b36b2..8f65c7c3 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -71,7 +71,7 @@ func (LsegCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (LsegCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (LsegCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(LsegValuer); !ok { return nil } @@ -126,7 +126,7 @@ func (encodePlanLsegCodecText) Encode(value interface{}, buf []byte) (newBuf []b return buf, nil } -func (LsegCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -220,17 +220,17 @@ func (scanPlanTextAnyToLsegScanner) Scan(src []byte, dst interface{}) error { return scanner.ScanLseg(Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true}) } -func (c LsegCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c LsegCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c LsegCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c LsegCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var lseg Lseg - err := codecScan(c, ci, oid, format, src, &lseg) + err := codecScan(c, m, oid, format, src, &lseg) if err != nil { return nil, err } diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index 5b42811a..23ca55d1 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -15,7 +15,7 @@ func (MacaddrCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (MacaddrCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (MacaddrCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -78,7 +78,7 @@ func (encodePlanMacaddrCodecTextHardwareAddr) Encode(value interface{}, buf []by return append(buf, addr.String()...), nil } -func (MacaddrCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (MacaddrCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { @@ -144,17 +144,17 @@ func (scanPlanTextMacaddrToHardwareAddr) Scan(src []byte, dst interface{}) error return nil } -func (c MacaddrCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c MacaddrCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c MacaddrCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c MacaddrCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var addr net.HardwareAddr - err := codecScan(c, ci, oid, format, src, &addr) + err := codecScan(c, m, oid, format, src, &addr) if err != nil { return nil, err } diff --git a/pgtype/numeric.go b/pgtype/numeric.go index d2311f3a..3e7f972f 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -249,7 +249,7 @@ func (NumericCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (NumericCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (NumericCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -501,7 +501,7 @@ func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { return buf, nil } -func (NumericCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -710,7 +710,7 @@ func (scanPlanTextAnyToNumericScanner) Scan(src []byte, dst interface{}) error { return scanner.ScanNumeric(Numeric{Int: num, Exp: exp, Valid: true}) } -func (c NumericCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c NumericCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -720,25 +720,25 @@ func (c NumericCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format in } var n Numeric - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } - buf, err := ci.Encode(oid, TextFormatCode, n, nil) + buf, err := m.Encode(oid, TextFormatCode, n, nil) if err != nil { return nil, err } return string(buf), nil } -func (c NumericCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c NumericCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var n Numeric - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } diff --git a/pgtype/path.go b/pgtype/path.go index 62a23219..c1355e41 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -73,7 +73,7 @@ func (PathCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (PathCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (PathCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(PathValuer); !ok { return nil } @@ -153,7 +153,7 @@ func (encodePlanPathCodecText) Encode(value interface{}, buf []byte) (newBuf []b return buf, nil } -func (PathCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -254,17 +254,17 @@ func (scanPlanTextAnyToPathScanner) Scan(src []byte, dst interface{}) error { return scanner.ScanPath(Path{P: points, Closed: closed, Valid: true}) } -func (c PathCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c PathCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c PathCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c PathCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var path Path - err := codecScan(c, ci, oid, format, src, &path) + err := codecScan(c, m, oid, format, src, &path) if err != nil { return nil, err } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 81431826..90f07d51 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -141,18 +141,18 @@ type Codec interface { // PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be // found then nil is returned. - PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan + PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan // PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If // actualTarget is true then the returned ScanPlan may be optimized to directly scan into target. If no plan can be // found then nil is returned. - PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan + PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan // DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface. - DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) + DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) // DecodeValue returns src decoded into its default format. - DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) + DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) } type nullAssignmentError struct { @@ -169,7 +169,9 @@ type Type struct { OID uint32 } -type ConnInfo struct { +// Map is the mapping between PostgreSQL server types and Go type handling logic. It can encode values for +// transmission to a PostgreSQL server and scan received values. +type Map struct { oidToType map[uint32]*Type nameToType map[string]*Type reflectTypeToName map[reflect.Type]string @@ -180,19 +182,19 @@ type ConnInfo struct { // TryWrapEncodePlanFuncs is a slice of functions that will wrap a value that cannot be encoded by the Codec. Every // time a wrapper is found the PlanEncode method will be recursively called with the new value. This allows several layers of wrappers - // to be built up. There are default functions placed in this slice by NewConnInfo(). In most cases these functions + // to be built up. There are default functions placed in this slice by NewMap(). In most cases these functions // should run last. i.e. Additional functions should typically be prepended not appended. TryWrapEncodePlanFuncs []TryWrapEncodePlanFunc // TryWrapScanPlanFuncs is a slice of functions that will wrap a target that cannot be scanned into by the Codec. Every // time a wrapper is found the PlanScan method will be recursively called with the new target. This allows several layers of wrappers - // to be built up. There are default functions placed in this slice by NewConnInfo(). In most cases these functions + // to be built up. There are default functions placed in this slice by NewMap(). In most cases these functions // should run last. i.e. Additional functions should typically be prepended not appended. TryWrapScanPlanFuncs []TryWrapScanPlanFunc } -func NewConnInfo() *ConnInfo { - ci := &ConnInfo{ +func NewMap() *Map { + m := &Map{ oidToType: make(map[uint32]*Type), nameToType: make(map[string]*Type), reflectTypeToName: make(map[reflect.Type]string), @@ -218,121 +220,121 @@ func NewConnInfo() *ConnInfo { }, } - ci.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) - ci.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) - ci.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) - ci.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) - ci.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) - ci.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) - ci.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) - ci.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) - ci.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) - ci.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) - ci.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}}) - ci.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) - ci.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) - ci.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}}) - ci.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) - ci.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) - ci.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) - ci.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) - ci.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) - ci.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) - ci.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) - ci.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) - ci.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) - ci.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) - ci.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) - ci.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) - ci.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}}) - ci.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}}) - ci.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) - ci.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) - ci.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) - ci.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) - ci.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) - ci.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) - ci.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) - ci.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) - ci.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) - ci.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) - ci.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) - ci.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + m.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) + m.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) + m.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) + m.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) + m.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) + m.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) + m.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) + m.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) + m.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) + m.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) + m.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}}) + m.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) + m.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) + m.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}}) + m.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) + m.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) + m.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) + m.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) + m.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) + m.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) + m.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) + m.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) + m.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) + m.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) + m.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) + m.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) + m.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}}) + m.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}}) + m.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) + m.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) + m.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) + m.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) + m.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) + m.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) + m.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) + m.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) + m.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) + m.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) + m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) + m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) - ci.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: ci.oidToType[DateOID]}}) - ci.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: ci.oidToType[Int4OID]}}) - ci.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: ci.oidToType[Int8OID]}}) - ci.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: ci.oidToType[NumericOID]}}) - ci.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: ci.oidToType[TimestampOID]}}) - ci.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: ci.oidToType[TimestamptzOID]}}) + m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}}) + m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}}) + m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}}) + m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}}) + m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}}) + m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}}) - ci.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[ACLItemOID]}}) - ci.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[BitOID]}}) - ci.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[BoolOID]}}) - ci.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[BoxOID]}}) - ci.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[BPCharOID]}}) - ci.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[ByteaOID]}}) - ci.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[QCharOID]}}) - ci.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[CIDOID]}}) - ci.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[CIDROID]}}) - ci.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[CircleOID]}}) - ci.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[DateOID]}}) - ci.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[DaterangeOID]}}) - ci.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Float4OID]}}) - ci.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Float8OID]}}) - ci.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[InetOID]}}) - ci.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Int2OID]}}) - ci.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Int4OID]}}) - ci.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Int4rangeOID]}}) - ci.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Int8OID]}}) - ci.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[Int8rangeOID]}}) - ci.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[IntervalOID]}}) - ci.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[JSONOID]}}) - ci.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[JSONBOID]}}) - ci.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[LineOID]}}) - ci.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[LsegOID]}}) - ci.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[MacaddrOID]}}) - ci.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[NameOID]}}) - ci.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[NumericOID]}}) - ci.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[NumrangeOID]}}) - ci.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[OIDOID]}}) - ci.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[PathOID]}}) - ci.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[PointOID]}}) - ci.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[PolygonOID]}}) - ci.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[RecordOID]}}) - ci.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TextOID]}}) - ci.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TIDOID]}}) - ci.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TimeOID]}}) - ci.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TimestampOID]}}) - ci.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TimestamptzOID]}}) - ci.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TsrangeOID]}}) - ci.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[TstzrangeOID]}}) - ci.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[UUIDOID]}}) - ci.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[VarbitOID]}}) - ci.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[VarcharOID]}}) - ci.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: ci.oidToType[XIDOID]}}) + m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}}) + m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}}) + m.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoolOID]}}) + m.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoxOID]}}) + m.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BPCharOID]}}) + m.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ByteaOID]}}) + m.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[QCharOID]}}) + m.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDOID]}}) + m.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDROID]}}) + m.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CircleOID]}}) + m.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DateOID]}}) + m.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DaterangeOID]}}) + m.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float4OID]}}) + m.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float8OID]}}) + m.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[InetOID]}}) + m.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int2OID]}}) + m.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4OID]}}) + m.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4rangeOID]}}) + m.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8OID]}}) + m.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8rangeOID]}}) + m.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[IntervalOID]}}) + m.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONOID]}}) + m.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONBOID]}}) + m.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LineOID]}}) + m.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LsegOID]}}) + m.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[MacaddrOID]}}) + m.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NameOID]}}) + m.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumericOID]}}) + m.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumrangeOID]}}) + m.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[OIDOID]}}) + m.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PathOID]}}) + m.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PointOID]}}) + m.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PolygonOID]}}) + m.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[RecordOID]}}) + m.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TextOID]}}) + m.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TIDOID]}}) + m.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimeOID]}}) + m.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestampOID]}}) + m.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestamptzOID]}}) + m.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TsrangeOID]}}) + m.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TstzrangeOID]}}) + m.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[UUIDOID]}}) + m.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarbitOID]}}) + m.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarcharOID]}}) + m.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[XIDOID]}}) registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { // T - ci.RegisterDefaultPgType(value, name) + m.RegisterDefaultPgType(value, name) // *T valueType := reflect.TypeOf(value) - ci.RegisterDefaultPgType(reflect.New(valueType).Interface(), name) + m.RegisterDefaultPgType(reflect.New(valueType).Interface(), name) // []T sliceType := reflect.SliceOf(valueType) - ci.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName) + m.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName) // *[]T - ci.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName) + m.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName) // []*T sliceOfPointerType := reflect.SliceOf(reflect.TypeOf(reflect.New(valueType).Interface())) - ci.RegisterDefaultPgType(reflect.MakeSlice(sliceOfPointerType, 0, 0).Interface(), arrayName) + m.RegisterDefaultPgType(reflect.MakeSlice(sliceOfPointerType, 0, 0).Interface(), arrayName) // *[]*T - ci.RegisterDefaultPgType(reflect.New(sliceOfPointerType).Interface(), arrayName) + m.RegisterDefaultPgType(reflect.New(sliceOfPointerType).Interface(), arrayName) } // Integer types that directly map to a PostgreSQL type @@ -358,57 +360,57 @@ func NewConnInfo() *ConnInfo { registerDefaultPgTypeVariants("inet", "_inet", net.IP{}) registerDefaultPgTypeVariants("cidr", "_cidr", net.IPNet{}) - return ci + return m } -func (ci *ConnInfo) RegisterType(t *Type) { - ci.oidToType[t.OID] = t - ci.nameToType[t.Name] = t - ci.oidToFormatCode[t.OID] = t.Codec.PreferredFormat() - ci.reflectTypeToType = nil // Invalidated by type registration +func (m *Map) RegisterType(t *Type) { + m.oidToType[t.OID] = t + m.nameToType[t.Name] = t + m.oidToFormatCode[t.OID] = t.Codec.PreferredFormat() + m.reflectTypeToType = nil // Invalidated by type registration } // RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be // encoded or decoded is determined by the PostgreSQL OID. But if the OID of a value to be encoded or decoded is // unknown, this additional mapping will be used by TypeForValue to determine a suitable data type. -func (ci *ConnInfo) RegisterDefaultPgType(value interface{}, name string) { - ci.reflectTypeToName[reflect.TypeOf(value)] = name - ci.reflectTypeToType = nil // Invalidated by registering a default type +func (m *Map) RegisterDefaultPgType(value interface{}, name string) { + m.reflectTypeToName[reflect.TypeOf(value)] = name + m.reflectTypeToType = nil // Invalidated by registering a default type } -func (ci *ConnInfo) TypeForOID(oid uint32) (*Type, bool) { - dt, ok := ci.oidToType[oid] +func (m *Map) TypeForOID(oid uint32) (*Type, bool) { + dt, ok := m.oidToType[oid] return dt, ok } -func (ci *ConnInfo) TypeForName(name string) (*Type, bool) { - dt, ok := ci.nameToType[name] +func (m *Map) TypeForName(name string) (*Type, bool) { + dt, ok := m.nameToType[name] return dt, ok } -func (ci *ConnInfo) buildReflectTypeToType() { - ci.reflectTypeToType = make(map[reflect.Type]*Type) +func (m *Map) buildReflectTypeToType() { + m.reflectTypeToType = make(map[reflect.Type]*Type) - for reflectType, name := range ci.reflectTypeToName { - if dt, ok := ci.nameToType[name]; ok { - ci.reflectTypeToType[reflectType] = dt + for reflectType, name := range m.reflectTypeToName { + if dt, ok := m.nameToType[name]; ok { + m.reflectTypeToType[reflectType] = dt } } } // TypeForValue finds a data type suitable for v. Use RegisterType to register types that can encode and decode // themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. -func (ci *ConnInfo) TypeForValue(v interface{}) (*Type, bool) { - if ci.reflectTypeToType == nil { - ci.buildReflectTypeToType() +func (m *Map) TypeForValue(v interface{}) (*Type, bool) { + if m.reflectTypeToType == nil { + m.buildReflectTypeToType() } - dt, ok := ci.reflectTypeToType[reflect.TypeOf(v)] + dt, ok := m.reflectTypeToType[reflect.TypeOf(v)] return dt, ok } -func (ci *ConnInfo) FormatCodeForOID(oid uint32) int16 { - fc, ok := ci.oidToFormatCode[oid] +func (m *Map) FormatCodeForOID(oid uint32) int16 { + fc, ok := m.oidToFormatCode[oid] if ok { return fc } @@ -431,13 +433,13 @@ type ScanPlan interface { type scanPlanCodecSQLScanner struct { c Codec - ci *ConnInfo + m *Map oid uint32 formatCode int16 } func (plan *scanPlanCodecSQLScanner) Scan(src []byte, dst interface{}) error { - value, err := plan.c.DecodeDatabaseSQLValue(plan.ci, plan.oid, plan.formatCode, src) + value, err := plan.c.DecodeDatabaseSQLValue(plan.m, plan.oid, plan.formatCode, src) if err != nil { return err } @@ -876,13 +878,13 @@ func (plan *wrapByteSliceScanPlan) Scan(src []byte, dst interface{}) error { type pointerEmptyInterfaceScanPlan struct { codec Codec - ci *ConnInfo + m *Map oid uint32 formatCode int16 } func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst interface{}) error { - value, err := plan.codec.DecodeValue(plan.ci, plan.oid, plan.formatCode, src) + value, err := plan.codec.DecodeValue(plan.m, plan.oid, plan.formatCode, src) if err != nil { return err } @@ -991,7 +993,7 @@ func (plan *wrapPtrMultiDimSliceScanPlan) Scan(src []byte, target interface{}) e } // PlanScan prepares a plan to scan a value into target. -func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan { +func (m *Map) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan { if _, ok := target.(*UndecodedBytes); ok { return scanPlanAnyToUndecodedBytes{} } @@ -1020,22 +1022,22 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) S var dt *Type - if dataType, ok := ci.TypeForOID(oid); ok { + if dataType, ok := m.TypeForOID(oid); ok { dt = dataType - } else if dataType, ok := ci.TypeForValue(target); ok { + } else if dataType, ok := m.TypeForValue(target); ok { dt = dataType oid = dt.OID // Preserve assumed OID in case we are recursively called below. } if dt != nil { - if plan := dt.Codec.PlanScan(ci, oid, formatCode, target, false); plan != nil { + if plan := dt.Codec.PlanScan(m, oid, formatCode, target, false); plan != nil { return plan } } - for _, f := range ci.TryWrapScanPlanFuncs { + for _, f := range m.TryWrapScanPlanFuncs { if wrapperPlan, nextDst, ok := f(target); ok { - if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { + if nextPlan := m.PlanScan(oid, formatCode, nextDst); nextPlan != nil { if _, failed := nextPlan.(*scanPlanFail); !failed { wrapperPlan.SetNext(nextPlan) return wrapperPlan @@ -1046,11 +1048,11 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) S if dt != nil { if _, ok := target.(*interface{}); ok { - return &pointerEmptyInterfaceScanPlan{codec: dt.Codec, ci: ci, oid: oid, formatCode: formatCode} + return &pointerEmptyInterfaceScanPlan{codec: dt.Codec, m: m, oid: oid, formatCode: formatCode} } if _, ok := target.(sql.Scanner); ok { - return &scanPlanCodecSQLScanner{c: dt.Codec, ci: ci, oid: oid, formatCode: formatCode} + return &scanPlanCodecSQLScanner{c: dt.Codec, m: m, oid: oid, formatCode: formatCode} } } @@ -1061,12 +1063,12 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) S return &scanPlanFail{oid: oid, formatCode: formatCode} } -func (ci *ConnInfo) Scan(oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (m *Map) Scan(oid uint32, formatCode int16, src []byte, dst interface{}) error { if dst == nil { return nil } - plan := ci.PlanScan(oid, formatCode, dst) + plan := m.PlanScan(oid, formatCode, dst) return plan.Scan(src, dst) } @@ -1091,15 +1093,15 @@ func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) var ErrScanTargetTypeChanged = errors.New("scan target type changed") -func codecScan(codec Codec, ci *ConnInfo, oid uint32, format int16, src []byte, dst interface{}) error { - scanPlan := codec.PlanScan(ci, oid, format, dst, true) +func codecScan(codec Codec, m *Map, oid uint32, format int16, src []byte, dst interface{}) error { + scanPlan := codec.PlanScan(m, oid, format, dst, true) if scanPlan == nil { return fmt.Errorf("PlanScan did not find a plan") } return scanPlan.Scan(src, dst) } -func codecDecodeToTextFormat(codec Codec, ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -1107,11 +1109,11 @@ func codecDecodeToTextFormat(codec Codec, ci *ConnInfo, oid uint32, format int16 if format == TextFormatCode { return string(src), nil } else { - value, err := codec.DecodeValue(ci, oid, format, src) + value, err := codec.DecodeValue(m, oid, format, src) if err != nil { return nil, err } - buf, err := ci.Encode(oid, TextFormatCode, value, nil) + buf, err := m.Encode(oid, TextFormatCode, value, nil) if err != nil { return nil, err } @@ -1121,29 +1123,29 @@ func codecDecodeToTextFormat(codec Codec, ci *ConnInfo, oid uint32, format int16 // PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be // found then nil is returned. -func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan { +func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan { var dt *Type if oid == 0 { - if dataType, ok := ci.TypeForValue(value); ok { + if dataType, ok := m.TypeForValue(value); ok { dt = dataType oid = dt.OID // Preserve assumed OID in case we are recursively called below. } } else { - if dataType, ok := ci.TypeForOID(oid); ok { + if dataType, ok := m.TypeForOID(oid); ok { dt = dataType } } if dt != nil { - if plan := dt.Codec.PlanEncode(ci, oid, format, value); plan != nil { + if plan := dt.Codec.PlanEncode(m, oid, format, value); plan != nil { return plan } - for _, f := range ci.TryWrapEncodePlanFuncs { + for _, f := range m.TryWrapEncodePlanFuncs { if wrapperPlan, nextValue, ok := f(value); ok { - if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { + if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil { wrapperPlan.SetNext(nextPlan) return wrapperPlan } @@ -1615,12 +1617,12 @@ func (plan *wrapMultiDimSliceEncodePlan) Encode(value interface{}, buf []byte) ( // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. -func (ci *ConnInfo) Encode(oid uint32, formatCode int16, value interface{}, buf []byte) (newBuf []byte, err error) { +func (m *Map) Encode(oid uint32, formatCode int16, value interface{}, buf []byte) (newBuf []byte, err error) { if value == nil { return nil, nil } - plan := ci.PlanEncode(oid, formatCode, value) + plan := m.PlanEncode(oid, formatCode, value) if plan == nil { return nil, fmt.Errorf("unable to encode %#v into OID %d", value, oid) } diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 2917c31c..bbec30f3 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -67,59 +67,59 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { return addr } -func TestConnInfoScanNilIsNoOp(t *testing.T) { - ci := pgtype.NewConnInfo() +func TestTypeMapScanNilIsNoOp(t *testing.T) { + m := pgtype.NewMap() - err := ci.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), nil) + err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), nil) assert.NoError(t, err) } -func TestConnInfoScanTextFormatInterfacePtr(t *testing.T) { - ci := pgtype.NewConnInfo() +func TestTypeMapScanTextFormatInterfacePtr(t *testing.T) { + m := pgtype.NewMap() var got interface{} - err := ci.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &got) + err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &got) require.NoError(t, err) assert.Equal(t, "foo", got) } -func TestConnInfoScanTextFormatNonByteaIntoByteSlice(t *testing.T) { - ci := pgtype.NewConnInfo() +func TestTypeMapScanTextFormatNonByteaIntoByteSlice(t *testing.T) { + m := pgtype.NewMap() var got []byte - err := ci.Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &got) + err := m.Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &got) require.NoError(t, err) assert.Equal(t, []byte("{}"), got) } -func TestConnInfoScanBinaryFormatInterfacePtr(t *testing.T) { - ci := pgtype.NewConnInfo() +func TestTypeMapScanBinaryFormatInterfacePtr(t *testing.T) { + m := pgtype.NewMap() var got interface{} - err := ci.Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &got) + err := m.Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &got) require.NoError(t, err) assert.Equal(t, "foo", got) } -func TestConnInfoScanUnknownOIDToStringsAndBytes(t *testing.T) { +func TestTypeMapScanUnknownOIDToStringsAndBytes(t *testing.T) { unknownOID := uint32(999999) srcBuf := []byte("foo") - ci := pgtype.NewConnInfo() + m := pgtype.NewMap() var s string - err := ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &s) + err := m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &s) assert.NoError(t, err) assert.Equal(t, "foo", s) var rs _string - err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rs) + err = m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rs) assert.NoError(t, err) assert.Equal(t, "foo", string(rs)) var b []byte - err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &b) + err = m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &b) assert.NoError(t, err) assert.Equal(t, []byte("foo"), b) var rb _byteSlice - err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rb) + err = m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rb) assert.NoError(t, err) assert.Equal(t, []byte("foo"), []byte(rb)) } @@ -129,7 +129,7 @@ type pgCustomType struct { b string } -func (ct *pgCustomType) DecodeText(ci *pgtype.ConnInfo, buf []byte) error { +func (ct *pgCustomType) DecodeText(m *pgtype.Map, buf []byte) error { // This is not a complete parser for the text format of composite types. This is just for test purposes. if buf == nil { return errors.New("cannot parse null") @@ -150,58 +150,58 @@ func (ct *pgCustomType) DecodeText(ci *pgtype.ConnInfo, buf []byte) error { return nil } -func TestConnInfoScanUnregisteredOIDToCustomType(t *testing.T) { +func TestTypeMapScanUnregisteredOIDToCustomType(t *testing.T) { t.Skip("TODO - unskip later in v5") // may no longer be relevent unregisteredOID := uint32(999999) - ci := pgtype.NewConnInfo() + m := pgtype.NewMap() var ct pgCustomType - err := ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct) + err := m.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct) assert.NoError(t, err) assert.Equal(t, "foo", ct.a) assert.Equal(t, "bar", ct.b) // Scan value into pointer to custom type var pCt *pgCustomType - err = ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt) + err = m.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt) assert.NoError(t, err) require.NotNil(t, pCt) assert.Equal(t, "foo", pCt.a) assert.Equal(t, "bar", pCt.b) // Scan null into pointer to custom type - err = ci.Scan(unregisteredOID, pgx.TextFormatCode, nil, &pCt) + err = m.Scan(unregisteredOID, pgx.TextFormatCode, nil, &pCt) assert.NoError(t, err) assert.Nil(t, pCt) } -func TestConnInfoScanUnknownOIDTextFormat(t *testing.T) { - ci := pgtype.NewConnInfo() +func TestTypeMapScanUnknownOIDTextFormat(t *testing.T) { + m := pgtype.NewMap() var n int32 - err := ci.Scan(0, pgx.TextFormatCode, []byte("123"), &n) + err := m.Scan(0, pgx.TextFormatCode, []byte("123"), &n) assert.NoError(t, err) assert.EqualValues(t, 123, n) } -func TestConnInfoScanUnknownOIDIntoSQLScanner(t *testing.T) { - ci := pgtype.NewConnInfo() +func TestTypeMapScanUnknownOIDIntoSQLScanner(t *testing.T) { + m := pgtype.NewMap() var s sql.NullString - err := ci.Scan(0, pgx.TextFormatCode, []byte(nil), &s) + err := m.Scan(0, pgx.TextFormatCode, []byte(nil), &s) assert.NoError(t, err) assert.Equal(t, "", s.String) assert.False(t, s.Valid) } -func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) { - ci := pgtype.NewConnInfo() +func BenchmarkTypeMapScanInt4IntoBinaryDecoder(b *testing.B) { + m := pgtype.NewMap() src := []byte{0, 0, 0, 42} var v pgtype.Int4 for i := 0; i < b.N; i++ { v = pgtype.Int4{} - err := ci.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + err := m.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) if err != nil { b.Fatal(err) } @@ -211,14 +211,14 @@ func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) { } } -func BenchmarkConnInfoScanInt4IntoGoInt32(b *testing.B) { - ci := pgtype.NewConnInfo() +func BenchmarkTypeMapScanInt4IntoGoInt32(b *testing.B) { + m := pgtype.NewMap() src := []byte{0, 0, 0, 42} var v int32 for i := 0; i < b.N; i++ { v = 0 - err := ci.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + err := m.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) if err != nil { b.Fatal(err) } @@ -229,11 +229,11 @@ func BenchmarkConnInfoScanInt4IntoGoInt32(b *testing.B) { } func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) { - ci := pgtype.NewConnInfo() + m := pgtype.NewMap() src := []byte{0, 0, 0, 42} var v pgtype.Int4 - plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) + plan := m.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) for i := 0; i < b.N; i++ { v = pgtype.Int4{} @@ -248,11 +248,11 @@ func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) { } func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { - ci := pgtype.NewConnInfo() + m := pgtype.NewMap() src := []byte{0, 0, 0, 42} var v int32 - plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) + plan := m.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) for i := 0; i < b.N; i++ { v = 0 diff --git a/pgtype/point.go b/pgtype/point.go index 0a300fe2..4a9637ee 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -127,7 +127,7 @@ func (PointCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (PointCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (PointCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(PointValuer); !ok { return nil } @@ -177,7 +177,7 @@ func (encodePlanPointCodecText) Encode(value interface{}, buf []byte) (newBuf [] )...), nil } -func (PointCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -195,17 +195,17 @@ func (PointCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interf return nil } -func (c PointCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c PointCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c PointCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c PointCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var point Point - err := codecScan(c, ci, oid, format, src, &point) + err := codecScan(c, m, oid, format, src, &point) if err != nil { return nil, err } diff --git a/pgtype/polygon.go b/pgtype/polygon.go index e4a1b2af..5a090ff0 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -72,7 +72,7 @@ func (PolygonCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (PolygonCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (PolygonCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(PolygonValuer); !ok { return nil } @@ -138,7 +138,7 @@ func (encodePlanPolygonCodecText) Encode(value interface{}, buf []byte) (newBuf return buf, nil } -func (PolygonCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -235,17 +235,17 @@ func (scanPlanTextAnyToPolygonScanner) Scan(src []byte, dst interface{}) error { return scanner.ScanPolygon(Polygon{P: points, Valid: true}) } -func (c PolygonCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c PolygonCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c PolygonCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c PolygonCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var polygon Polygon - err := codecScan(c, ci, oid, format, src, &polygon) + err := codecScan(c, m, oid, format, src, &polygon) if err != nil { return nil, err } diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 5c712369..ce2c3dcb 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -22,7 +22,7 @@ func (QCharCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (QCharCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (QCharCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case TextFormatCode, BinaryFormatCode: switch value.(type) { @@ -56,7 +56,7 @@ func (encodePlanQcharCodecRune) Encode(value interface{}, buf []byte) (newBuf [] return buf, nil } -func (QCharCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (QCharCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case TextFormatCode, BinaryFormatCode: switch target.(type) { @@ -114,26 +114,26 @@ func (scanPlanQcharCodecRune) Scan(src []byte, dst interface{}) error { return nil } -func (c QCharCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c QCharCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } var r rune - err := codecScan(c, ci, oid, format, src, &r) + err := codecScan(c, m, oid, format, src, &r) if err != nil { return nil, err } return string(r), nil } -func (c QCharCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c QCharCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var r rune - err := codecScan(c, ci, oid, format, src, &r) + err := codecScan(c, m, oid, format, src, &r) if err != nil { return nil, err } diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go index 98903d3a..9628ce4b 100644 --- a/pgtype/range_codec.go +++ b/pgtype/range_codec.go @@ -86,16 +86,16 @@ func (c *RangeCodec) PreferredFormat() int16 { return TextFormatCode } -func (c *RangeCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(RangeValuer); !ok { return nil } switch format { case BinaryFormatCode: - return &encodePlanRangeCodecRangeValuerToBinary{rc: c, ci: ci} + return &encodePlanRangeCodecRangeValuerToBinary{rc: c, m: m} case TextFormatCode: - return &encodePlanRangeCodecRangeValuerToText{rc: c, ci: ci} + return &encodePlanRangeCodecRangeValuerToText{rc: c, m: m} } return nil @@ -103,7 +103,7 @@ func (c *RangeCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value in type encodePlanRangeCodecRangeValuerToBinary struct { rc *RangeCodec - ci *ConnInfo + m *Map } func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { @@ -149,7 +149,7 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value interface{}, b sp := len(buf) buf = pgio.AppendInt32(buf, -1) - lowerPlan := plan.ci.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, lower) + lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, lower) if lowerPlan == nil { return nil, fmt.Errorf("cannot encode %v as element of range", lower) } @@ -173,7 +173,7 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value interface{}, b sp := len(buf) buf = pgio.AppendInt32(buf, -1) - upperPlan := plan.ci.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, upper) + upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, upper) if upperPlan == nil { return nil, fmt.Errorf("cannot encode %v as element of range", upper) } @@ -194,7 +194,7 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value interface{}, b type encodePlanRangeCodecRangeValuerToText struct { rc *RangeCodec - ci *ConnInfo + m *Map } func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { @@ -223,7 +223,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value interface{}, buf return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - lowerPlan := plan.ci.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, lower) + lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, lower) if lowerPlan == nil { return nil, fmt.Errorf("cannot encode %v as element of range", lower) } @@ -244,7 +244,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value interface{}, buf return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - upperPlan := plan.ci.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, upper) + upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, upper) if upperPlan == nil { return nil, fmt.Errorf("cannot encode %v as element of range", upper) } @@ -270,17 +270,17 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value interface{}, buf return buf, nil } -func (c *RangeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case RangeScanner: - return &scanPlanBinaryRangeToRangeScanner{rc: c, ci: ci} + return &scanPlanBinaryRangeToRangeScanner{rc: c, m: m} } case TextFormatCode: switch target.(type) { case RangeScanner: - return &scanPlanTextRangeToRangeScanner{rc: c, ci: ci} + return &scanPlanTextRangeToRangeScanner{rc: c, m: m} } } @@ -289,7 +289,7 @@ func (c *RangeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target int type scanPlanBinaryRangeToRangeScanner struct { rc *RangeCodec - ci *ConnInfo + m *Map } func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target interface{}) error { @@ -311,7 +311,7 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target interface lowerTarget, upperTarget := rangeScanner.ScanBounds() if ubr.LowerType == Inclusive || ubr.LowerType == Exclusive { - lowerPlan := plan.ci.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, lowerTarget) + lowerPlan := plan.m.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, lowerTarget) if lowerPlan == nil { return fmt.Errorf("cannot scan into %v from range element", lowerTarget) } @@ -323,7 +323,7 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target interface } if ubr.UpperType == Inclusive || ubr.UpperType == Exclusive { - upperPlan := plan.ci.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, upperTarget) + upperPlan := plan.m.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, upperTarget) if upperPlan == nil { return fmt.Errorf("cannot scan into %v from range element", upperTarget) } @@ -339,7 +339,7 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target interface type scanPlanTextRangeToRangeScanner struct { rc *RangeCodec - ci *ConnInfo + m *Map } func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target interface{}) error { @@ -361,7 +361,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target interface{} lowerTarget, upperTarget := rangeScanner.ScanBounds() if utr.LowerType == Inclusive || utr.LowerType == Exclusive { - lowerPlan := plan.ci.PlanScan(plan.rc.ElementType.OID, TextFormatCode, lowerTarget) + lowerPlan := plan.m.PlanScan(plan.rc.ElementType.OID, TextFormatCode, lowerTarget) if lowerPlan == nil { return fmt.Errorf("cannot scan into %v from range element", lowerTarget) } @@ -373,7 +373,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target interface{} } if utr.UpperType == Inclusive || utr.UpperType == Exclusive { - upperPlan := plan.ci.PlanScan(plan.rc.ElementType.OID, TextFormatCode, upperTarget) + upperPlan := plan.m.PlanScan(plan.rc.ElementType.OID, TextFormatCode, upperTarget) if upperPlan == nil { return fmt.Errorf("cannot scan into %v from range element", upperTarget) } @@ -387,7 +387,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target interface{} return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType) } -func (c *RangeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -404,12 +404,12 @@ func (c *RangeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int } } -func (c *RangeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var r GenericRange - err := c.PlanScan(ci, oid, format, &r, true).Scan(src, &r) + err := c.PlanScan(m, oid, format, &r, true).Scan(src, &r) return r, err } diff --git a/pgtype/record_codec.go b/pgtype/record_codec.go index 92c197b2..c31aa63c 100644 --- a/pgtype/record_codec.go +++ b/pgtype/record_codec.go @@ -21,15 +21,15 @@ func (RecordCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (RecordCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (RecordCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { return nil } -func (RecordCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (RecordCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { if format == BinaryFormatCode { switch target.(type) { case CompositeIndexScanner: - return &scanPlanBinaryRecordToCompositeIndexScanner{ci: ci} + return &scanPlanBinaryRecordToCompositeIndexScanner{m: m} } } @@ -37,7 +37,7 @@ func (RecordCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target inter } type scanPlanBinaryRecordToCompositeIndexScanner struct { - ci *ConnInfo + m *Map } func (plan *scanPlanBinaryRecordToCompositeIndexScanner) Scan(src []byte, target interface{}) error { @@ -47,11 +47,11 @@ func (plan *scanPlanBinaryRecordToCompositeIndexScanner) Scan(src []byte, target return targetScanner.ScanNull() } - scanner := NewCompositeBinaryScanner(plan.ci, src) + scanner := NewCompositeBinaryScanner(plan.m, src) for i := 0; scanner.Next(); i++ { fieldTarget := targetScanner.ScanIndex(i) if fieldTarget != nil { - fieldPlan := plan.ci.PlanScan(scanner.OID(), BinaryFormatCode, fieldTarget) + fieldPlan := plan.m.PlanScan(scanner.OID(), BinaryFormatCode, fieldTarget) if fieldPlan == nil { return fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), fieldTarget) } @@ -70,7 +70,7 @@ func (plan *scanPlanBinaryRecordToCompositeIndexScanner) Scan(src []byte, target return nil } -func (RecordCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (RecordCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -87,7 +87,7 @@ func (RecordCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16 } } -func (RecordCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (RecordCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } @@ -96,11 +96,11 @@ func (RecordCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt case TextFormatCode: return string(src), nil case BinaryFormatCode: - scanner := NewCompositeBinaryScanner(ci, src) + scanner := NewCompositeBinaryScanner(m, src) values := make([]interface{}, scanner.FieldCount()) for i := 0; scanner.Next(); i++ { var v interface{} - fieldPlan := ci.PlanScan(scanner.OID(), BinaryFormatCode, &v) + fieldPlan := m.PlanScan(scanner.OID(), BinaryFormatCode, &v) if fieldPlan == nil { return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v) } diff --git a/pgtype/text.go b/pgtype/text.go index 7e4f8b99..2c551958 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -90,7 +90,7 @@ func (TextCodec) PreferredFormat() int16 { return TextFormatCode } -func (TextCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (TextCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case TextFormatCode, BinaryFormatCode: switch value.(type) { @@ -156,7 +156,7 @@ func (encodePlanTextCodecTextValuer) Encode(value interface{}, buf []byte) (newB return buf, nil } -func (TextCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case TextFormatCode, BinaryFormatCode: @@ -175,11 +175,11 @@ func (TextCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interfa return nil } -func (c TextCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return c.DecodeValue(ci, oid, format, src) +func (c TextCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) } -func (c TextCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c TextCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } diff --git a/pgtype/tid.go b/pgtype/tid.go index cc38f404..2744cb15 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -82,7 +82,7 @@ func (TIDCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TIDCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (TIDCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(TIDValuer); !ok { return nil } @@ -130,7 +130,7 @@ func (encodePlanTIDCodecText) Encode(value interface{}, buf []byte) (newBuf []by return buf, nil } -func (TIDCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -223,17 +223,17 @@ func (scanPlanTextAnyToTIDScanner) Scan(src []byte, dst interface{}) error { return scanner.ScanTID(TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Valid: true}) } -func (c TIDCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c TIDCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c TIDCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c TIDCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var tid TID - err := codecScan(c, ci, oid, format, src, &tid) + err := codecScan(c, m, oid, format, src, &tid) if err != nil { return nil, err } diff --git a/pgtype/time.go b/pgtype/time.go index 734687cb..71e7f597 100644 --- a/pgtype/time.go +++ b/pgtype/time.go @@ -74,7 +74,7 @@ func (TimeCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TimeCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (TimeCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(TimeValuer); !ok { return nil } @@ -129,7 +129,7 @@ func (encodePlanTimeCodecText) Encode(value interface{}, buf []byte) (newBuf []b return append(buf, s...), nil } -func (TimeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -215,17 +215,17 @@ func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst interface{}) error { return scanner.ScanTime(Time{Microseconds: usec, Valid: true}) } -func (c TimeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, ci, oid, format, src) +func (c TimeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c TimeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c TimeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var t Time - err := codecScan(c, ci, oid, format, src, &t) + err := codecScan(c, m, oid, format, src, &t) if err != nil { return nil, err } diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 24ee76b0..9cfa7702 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -75,7 +75,7 @@ func (TimestampCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TimestampCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(TimestampValuer); !ok { return nil } @@ -152,7 +152,7 @@ func discardTimeZone(t time.Time) time.Time { return t } -func (TimestampCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -230,13 +230,13 @@ func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst interface{}) return scanner.ScanTimestamp(ts) } -func (c TimestampCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } var ts Timestamp - err := codecScan(c, ci, oid, format, src, &ts) + err := codecScan(c, m, oid, format, src, &ts) if err != nil { return nil, err } @@ -248,13 +248,13 @@ func (c TimestampCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format return ts.Time, nil } -func (c TimestampCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var ts Timestamp - err := codecScan(c, ci, oid, format, src, &ts) + err := codecScan(c, m, oid, format, src, &ts) if err != nil { return nil, err } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index ea2ebfbe..0bdd0fd6 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -133,7 +133,7 @@ func (TimestamptzCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TimestamptzCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(TimestamptzValuer); !ok { return nil } @@ -200,7 +200,7 @@ func (encodePlanTimestamptzCodecText) Encode(value interface{}, buf []byte) (new return buf, nil } -func (TimestamptzCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -287,13 +287,13 @@ func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst interfac return scanner.ScanTimestamptz(tstz) } -func (c TimestamptzCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } var tstz Timestamptz - err := codecScan(c, ci, oid, format, src, &tstz) + err := codecScan(c, m, oid, format, src, &tstz) if err != nil { return nil, err } @@ -305,13 +305,13 @@ func (c TimestamptzCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, forma return tstz.Time, nil } -func (c TimestamptzCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var tstz Timestamptz - err := codecScan(c, ci, oid, format, src, &tstz) + err := codecScan(c, m, oid, format, src, &tstz) if err != nil { return nil, err } diff --git a/pgtype/uint32.go b/pgtype/uint32.go index 7d481a27..44238fa0 100644 --- a/pgtype/uint32.go +++ b/pgtype/uint32.go @@ -85,7 +85,7 @@ func (Uint32Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Uint32Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (Uint32Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -196,7 +196,7 @@ func (encodePlanUint32CodecTextInt64Valuer) Encode(value interface{}, buf []byte return append(buf, strconv.FormatInt(v.Int, 10)...), nil } -func (Uint32Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: @@ -218,26 +218,26 @@ func (Uint32Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target inter return nil } -func (c Uint32Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c Uint32Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } var n uint32 - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } return int64(n), nil } -func (c Uint32Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Uint32Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var n uint32 - err := codecScan(c, ci, oid, format, src, &n) + err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 2655a124..39df8537 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -122,7 +122,7 @@ func (UUIDCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (UUIDCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { +func (UUIDCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { if _, ok := value.(UUIDValuer); !ok { return nil } @@ -167,7 +167,7 @@ func (encodePlanUUIDCodecTextUUIDValuer) Encode(value interface{}, buf []byte) ( return append(buf, encodeUUID(uuid.Bytes)...), nil } -func (UUIDCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (UUIDCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { @@ -220,13 +220,13 @@ func (scanPlanTextAnyToUUIDScanner) Scan(src []byte, dst interface{}) error { return scanner.ScanUUID(UUID{Bytes: buf, Valid: true}) } -func (c UUIDCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c UUIDCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } var uuid UUID - err := codecScan(c, ci, oid, format, src, &uuid) + err := codecScan(c, m, oid, format, src, &uuid) if err != nil { return nil, err } @@ -234,13 +234,13 @@ func (c UUIDCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16 return encodeUUID(uuid.Bytes), nil } -func (c UUIDCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { +func (c UUIDCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } var uuid UUID - err := codecScan(c, ci, oid, format, src, &uuid) + err := codecScan(c, m, oid, format, src, &uuid) if err != nil { return nil, err } diff --git a/query_test.go b/query_test.go index 2f8975ac..3728f8a3 100644 --- a/query_test.go +++ b/query_test.go @@ -1358,7 +1358,7 @@ func TestScanRow(t *testing.T) { for resultReader.NextRow() { var n int32 - err := pgx.ScanRow(conn.ConnInfo(), resultReader.FieldDescriptions(), resultReader.Values(), &n) + err := pgx.ScanRow(conn.TypeMap(), resultReader.FieldDescriptions(), resultReader.Values(), &n) assert.NoError(t, err) sum += n rowCount++ diff --git a/rows.go b/rows.go index 25a42466..805be100 100644 --- a/rows.go +++ b/rows.go @@ -100,7 +100,7 @@ type rowLog interface { type connRows struct { ctx context.Context logger rowLog - connInfo *pgtype.ConnInfo + typeMap *pgtype.Map values [][]byte rowCount int err error @@ -196,7 +196,7 @@ func (rows *connRows) Next() bool { } func (rows *connRows) Scan(dest ...interface{}) error { - ci := rows.connInfo + m := rows.typeMap fieldDescriptions := rows.FieldDescriptions() values := rows.values @@ -215,7 +215,7 @@ func (rows *connRows) Scan(dest ...interface{}) error { rows.scanPlans = make([]pgtype.ScanPlan, len(values)) rows.scanTypes = make([]reflect.Type, len(values)) for i := range dest { - rows.scanPlans[i] = ci.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) + rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) rows.scanTypes[i] = reflect.TypeOf(dest[i]) } } @@ -226,7 +226,7 @@ func (rows *connRows) Scan(dest ...interface{}) error { } if rows.scanTypes[i] != reflect.TypeOf(dst) { - rows.scanPlans[i] = ci.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) + rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) rows.scanTypes[i] = reflect.TypeOf(dest[i]) } @@ -257,8 +257,8 @@ func (rows *connRows) Values() ([]interface{}, error) { continue } - if dt, ok := rows.connInfo.TypeForOID(fd.DataTypeOID); ok { - value, err := dt.Codec.DecodeValue(rows.connInfo, fd.DataTypeOID, fd.Format, buf) + if dt, ok := rows.typeMap.TypeForOID(fd.DataTypeOID); ok { + value, err := dt.Codec.DecodeValue(rows.typeMap, fd.DataTypeOID, fd.Format, buf) if err != nil { rows.fatal(err) } @@ -303,11 +303,11 @@ func (e ScanArgError) Unwrap() error { // ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface. // -// connInfo - OID to Go type mapping. +// typeMap - OID to Go type mapping. // fieldDescriptions - OID and format of values // values - the raw data as returned from the PostgreSQL server // dest - the destination that values will be decoded into -func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error { +func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error { if len(fieldDescriptions) != len(values) { return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) } @@ -320,7 +320,7 @@ func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescri continue } - err := connInfo.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) + err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) if err != nil { return ScanArgError{ColumnIndex: i, Err: err} } diff --git a/stdlib/sql.go b/stdlib/sql.go index b9079fc8..7624605c 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -519,7 +519,7 @@ func (r *Rows) Columns() []string { // ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned. func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { - if dt, ok := r.conn.conn.ConnInfo().TypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { + if dt, ok := r.conn.conn.TypeMap().TypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { return strings.ToUpper(dt.Name) } @@ -594,7 +594,7 @@ func (r *Rows) Close() error { } func (r *Rows) Next(dest []driver.Value) error { - ci := r.conn.conn.ConnInfo() + m := r.conn.conn.TypeMap() fieldDescriptions := r.rows.FieldDescriptions() if r.valueFuncs == nil { @@ -607,21 +607,21 @@ func (r *Rows) Next(dest []driver.Value) error { switch fd.DataTypeOID { case pgtype.BoolOID: var d bool - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(src, &d) return d, err } case pgtype.ByteaOID: var d []byte - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(src, &d) return d, err } case pgtype.CIDOID, pgtype.OIDOID, pgtype.XIDOID: var d pgtype.Uint32 - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(src, &d) if err != nil { @@ -631,7 +631,7 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.DateOID: var d pgtype.Date - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(src, &d) if err != nil { @@ -641,42 +641,42 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.Float4OID: var d float32 - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(src, &d) return float64(d), err } case pgtype.Float8OID: var d float64 - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(src, &d) return d, err } case pgtype.Int2OID: var d int16 - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(src, &d) return int64(d), err } case pgtype.Int4OID: var d int32 - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(src, &d) return int64(d), err } case pgtype.Int8OID: var d int64 - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(src, &d) return d, err } case pgtype.JSONOID, pgtype.JSONBOID: var d []byte - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(src, &d) if err != nil { @@ -686,7 +686,7 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.TimestampOID: var d pgtype.Timestamp - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(src, &d) if err != nil { @@ -696,7 +696,7 @@ func (r *Rows) Next(dest []driver.Value) error { } case pgtype.TimestamptzOID: var d pgtype.Timestamptz - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(src, &d) if err != nil { @@ -706,7 +706,7 @@ func (r *Rows) Next(dest []driver.Value) error { } default: var d string - scanPlan := ci.PlanScan(dataTypeOID, format, &d) + scanPlan := m.PlanScan(dataTypeOID, format, &d) r.valueFuncs[i] = func(src []byte) (driver.Value, error) { err := scanPlan.Scan(src, &d) return d, err diff --git a/values.go b/values.go index 7661a94e..075ac2ff 100644 --- a/values.go +++ b/values.go @@ -24,7 +24,7 @@ func (e SerializationError) Error() string { return string(e) } -func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) { +func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) { if arg == nil { return nil, nil } @@ -79,8 +79,8 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e return int64(arg), nil } - if _, found := ci.TypeForValue(arg); found { - buf, err := ci.Encode(0, TextFormatCode, arg, nil) + if _, found := m.TypeForValue(arg); found { + buf, err := m.Encode(0, TextFormatCode, arg, nil) if err != nil { return nil, err } @@ -92,16 +92,16 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e if refVal.Kind() == reflect.Ptr { arg = refVal.Elem().Interface() - return convertSimpleArgument(ci, arg) + return convertSimpleArgument(m, arg) } if strippedArg, ok := stripNamedType(&refVal); ok { - return convertSimpleArgument(ci, strippedArg) + return convertSimpleArgument(m, strippedArg) } return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) } -func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32, arg interface{}) ([]byte, error) { +func encodePreparedStatementArgument(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) { if arg == nil { return pgio.AppendInt32(buf, -1), nil } @@ -120,13 +120,13 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 return pgio.AppendInt32(buf, -1), nil } arg = refVal.Elem().Interface() - return encodePreparedStatementArgument(ci, buf, oid, arg) + return encodePreparedStatementArgument(m, buf, oid, arg) } - if _, ok := ci.TypeForOID(oid); ok { + if _, ok := m.TypeForOID(oid); ok { sp := len(buf) buf = pgio.AppendInt32(buf, -1) - argBuf, err := ci.Encode(oid, BinaryFormatCode, arg, buf) + argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) if err != nil { return nil, err } @@ -138,7 +138,7 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 } if strippedArg, ok := stripNamedType(&refVal); ok { - return encodePreparedStatementArgument(ci, buf, oid, strippedArg) + return encodePreparedStatementArgument(m, buf, oid, strippedArg) } return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } @@ -146,13 +146,13 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 // chooseParameterFormatCode determines the correct format code for an // argument to a prepared statement. It defaults to TextFormatCode if no // determination can be made. -func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 { +func chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg interface{}) int16 { switch arg.(type) { case string, *string: return TextFormatCode } - return ci.FormatCodeForOID(oid) + return m.FormatCodeForOID(oid) } func stripNamedType(val *reflect.Value) (interface{}, bool) { diff --git a/values_test.go b/values_test.go index 05139d29..81138bfa 100644 --- a/values_test.go +++ b/values_test.go @@ -79,7 +79,7 @@ func TestJSONAndJSONBTranscode(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { for _, typename := range []string{"json", "jsonb"} { - if _, ok := conn.ConnInfo().TypeForName(typename); !ok { + if _, ok := conn.TypeMap().TypeForName(typename); !ok { continue // No JSON/JSONB type -- must be running against old PostgreSQL } @@ -96,7 +96,7 @@ func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) { defer closeConn(t, conn) for _, typename := range []string{"json", "jsonb"} { - if _, ok := conn.ConnInfo().TypeForName(typename); !ok { + if _, ok := conn.TypeMap().TypeForName(typename); !ok { continue // No JSON/JSONB type -- must be running against old PostgreSQL } testJSONSingleLevelStringMap(t, conn, typename) From f3defbc150fadbba3aa35e66bf7f2a492f59b556 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 21 Feb 2022 09:25:30 -0600 Subject: [PATCH 0915/1158] Rename pgtype.None to pgtype.Finite --- pgtype/builtin_wrappers.go | 6 +++--- pgtype/date.go | 10 +++++----- pgtype/numeric_test.go | 2 +- pgtype/pgtype.go | 6 +++--- pgtype/timestamp.go | 10 +++++----- pgtype/timestamptz.go | 12 ++++++------ pgtype/zeronull/timestamp.go | 2 +- pgtype/zeronull/timestamptz.go | 2 +- 8 files changed, 25 insertions(+), 25 deletions(-) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index cb981906..e43244cf 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -368,7 +368,7 @@ func (w *timeWrapper) ScanDate(v Date) error { } switch v.InfinityModifier { - case None: + case Finite: *w = timeWrapper(v.Time) return nil case Infinity: @@ -390,7 +390,7 @@ func (w *timeWrapper) ScanTimestamp(v Timestamp) error { } switch v.InfinityModifier { - case None: + case Finite: *w = timeWrapper(v.Time) return nil case Infinity: @@ -412,7 +412,7 @@ func (w *timeWrapper) ScanTimestamptz(v Timestamptz) error { } switch v.InfinityModifier { - case None: + case Finite: *w = timeWrapper(v.Time) return nil case Infinity: diff --git a/pgtype/date.go b/pgtype/date.go index fe917a3e..bb3975a7 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -62,7 +62,7 @@ func (src Date) Value() (driver.Value, error) { return nil, nil } - if src.InfinityModifier != None { + if src.InfinityModifier != Finite { return src.InfinityModifier.String(), nil } return src.Time, nil @@ -76,7 +76,7 @@ func (src Date) MarshalJSON() ([]byte, error) { var s string switch src.InfinityModifier { - case None: + case Finite: s = src.Time.Format("2006-01-02") case Infinity: s = "infinity" @@ -155,7 +155,7 @@ func (encodePlanDateCodecBinary) Encode(value interface{}, buf []byte) (newBuf [ var daysSinceDateEpoch int32 switch date.InfinityModifier { - case None: + case Finite: tUnix := time.Date(date.Time.Year(), date.Time.Month(), date.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() @@ -185,7 +185,7 @@ func (encodePlanDateCodecText) Encode(value interface{}, buf []byte) (newBuf []b var s string switch date.InfinityModifier { - case None: + case Finite: s = date.Time.Format("2006-01-02") case Infinity: s = "infinity" @@ -282,7 +282,7 @@ func (c DateCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (in if date.Valid { switch date.InfinityModifier { - case None: + case Finite: return date.Time, nil case Infinity: return "infinity", nil diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 0449059e..950e5b11 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -42,7 +42,7 @@ func isExpectedEqNumeric(a interface{}) func(interface{}) bool { } // If NaN or InfinityModifier are set then Int and Exp don't matter. - if aa.NaN || aa.InfinityModifier != pgtype.None { + if aa.NaN || aa.InfinityModifier != pgtype.Finite { return true } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 90f07d51..b42773bc 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -109,14 +109,14 @@ type InfinityModifier int8 const ( Infinity InfinityModifier = 1 - None InfinityModifier = 0 + Finite InfinityModifier = 0 NegativeInfinity InfinityModifier = -Infinity ) func (im InfinityModifier) String() string { switch im { - case None: - return "none" + case Finite: + return "finite" case Infinity: return "infinity" case NegativeInfinity: diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 9cfa7702..314d7371 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -59,7 +59,7 @@ func (ts Timestamp) Value() (driver.Value, error) { return nil, nil } - if ts.InfinityModifier != None { + if ts.InfinityModifier != Finite { return ts.InfinityModifier.String(), nil } return ts.Time, nil @@ -104,7 +104,7 @@ func (encodePlanTimestampCodecBinary) Encode(value interface{}, buf []byte) (new var microsecSinceY2K int64 switch ts.InfinityModifier { - case None: + case Finite: t := discardTimeZone(ts.Time) microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K @@ -130,7 +130,7 @@ func (encodePlanTimestampCodecText) Encode(value interface{}, buf []byte) (newBu var s string switch ts.InfinityModifier { - case None: + case Finite: t := discardTimeZone(ts.Time) s = t.Truncate(time.Microsecond).Format(pgTimestampFormat) case Infinity: @@ -241,7 +241,7 @@ func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, return nil, err } - if ts.InfinityModifier != None { + if ts.InfinityModifier != Finite { return ts.InfinityModifier.String(), nil } @@ -259,7 +259,7 @@ func (c TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte return nil, err } - if ts.InfinityModifier != None { + if ts.InfinityModifier != Finite { return ts.InfinityModifier, nil } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 0bdd0fd6..46554f54 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -68,7 +68,7 @@ func (tstz Timestamptz) Value() (driver.Value, error) { return nil, nil } - if tstz.InfinityModifier != None { + if tstz.InfinityModifier != Finite { return tstz.InfinityModifier.String(), nil } return tstz.Time, nil @@ -82,7 +82,7 @@ func (tstz Timestamptz) MarshalJSON() ([]byte, error) { var s string switch tstz.InfinityModifier { - case None: + case Finite: s = tstz.Time.Format(time.RFC3339Nano) case Infinity: s = "infinity" @@ -162,7 +162,7 @@ func (encodePlanTimestamptzCodecBinary) Encode(value interface{}, buf []byte) (n var microsecSinceY2K int64 switch ts.InfinityModifier { - case None: + case Finite: microsecSinceUnixEpoch := ts.Time.Unix()*1000000 + int64(ts.Time.Nanosecond())/1000 microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K case Infinity: @@ -187,7 +187,7 @@ func (encodePlanTimestamptzCodecText) Encode(value interface{}, buf []byte) (new var s string switch ts.InfinityModifier { - case None: + case Finite: s = ts.Time.UTC().Truncate(time.Microsecond).Format(pgTimestamptzSecondFormat) case Infinity: s = "infinity" @@ -298,7 +298,7 @@ func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int1 return nil, err } - if tstz.InfinityModifier != None { + if tstz.InfinityModifier != Finite { return tstz.InfinityModifier.String(), nil } @@ -316,7 +316,7 @@ func (c TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []by return nil, err } - if tstz.InfinityModifier != None { + if tstz.InfinityModifier != Finite { return tstz.InfinityModifier, nil } diff --git a/pgtype/zeronull/timestamp.go b/pgtype/zeronull/timestamp.go index 6e2c3d1e..163af041 100644 --- a/pgtype/zeronull/timestamp.go +++ b/pgtype/zeronull/timestamp.go @@ -19,7 +19,7 @@ func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error { } switch v.InfinityModifier { - case pgtype.None: + case pgtype.Finite: *ts = Timestamp(v.Time) return nil case pgtype.Infinity: diff --git a/pgtype/zeronull/timestamptz.go b/pgtype/zeronull/timestamptz.go index 79fcb563..6cd60c37 100644 --- a/pgtype/zeronull/timestamptz.go +++ b/pgtype/zeronull/timestamptz.go @@ -19,7 +19,7 @@ func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error { } switch v.InfinityModifier { - case pgtype.None: + case pgtype.Finite: *ts = Timestamptz(v.Time) return nil case pgtype.Infinity: From 9c538cd4a9cfe3483a382559b5e8f852943c8698 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 21 Feb 2022 09:30:01 -0600 Subject: [PATCH 0916/1158] Remove actualTarget argument --- pgtype/array_codec.go | 6 +++--- pgtype/bits.go | 2 +- pgtype/bool.go | 2 +- pgtype/box.go | 2 +- pgtype/bytea.go | 2 +- pgtype/circle.go | 2 +- pgtype/composite.go | 2 +- pgtype/date.go | 2 +- pgtype/enum_codec.go | 2 +- pgtype/float4.go | 2 +- pgtype/float8.go | 2 +- pgtype/hstore.go | 2 +- pgtype/inet.go | 2 +- pgtype/int.go | 6 +++--- pgtype/int.go.erb | 2 +- pgtype/interval.go | 2 +- pgtype/json.go | 2 +- pgtype/jsonb.go | 6 +++--- pgtype/line.go | 2 +- pgtype/lseg.go | 2 +- pgtype/macaddr.go | 2 +- pgtype/numeric.go | 2 +- pgtype/numeric_test.go | 2 +- pgtype/path.go | 2 +- pgtype/pgtype.go | 9 ++++----- pgtype/point.go | 2 +- pgtype/polygon.go | 2 +- pgtype/qchar.go | 2 +- pgtype/range_codec.go | 4 ++-- pgtype/record_codec.go | 2 +- pgtype/text.go | 2 +- pgtype/tid.go | 2 +- pgtype/time.go | 2 +- pgtype/timestamp.go | 2 +- pgtype/timestamp_test.go | 2 +- pgtype/timestamptz.go | 2 +- pgtype/timestamptz_test.go | 2 +- pgtype/uint32.go | 2 +- pgtype/uuid.go | 2 +- 39 files changed, 49 insertions(+), 50 deletions(-) diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index dd14e83e..74bcc5d3 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -210,7 +210,7 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB return buf, nil } -func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { arrayScanner, ok := target.(ArraySetter) if !ok { return nil @@ -248,7 +248,7 @@ func (c *ArrayCodec) decodeBinary(m *Map, arrayOID uint32, src []byte, array Arr return nil } - elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0), false) + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0)) if elementScanPlan == nil { elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0)) } @@ -286,7 +286,7 @@ func (c *ArrayCodec) decodeText(m *Map, arrayOID uint32, src []byte, array Array return nil } - elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, array.ScanIndex(0), false) + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, array.ScanIndex(0)) if elementScanPlan == nil { elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, array.ScanIndex(0)) } diff --git a/pgtype/bits.go b/pgtype/bits.go index 8afd8d90..12df03d5 100644 --- a/pgtype/bits.go +++ b/pgtype/bits.go @@ -126,7 +126,7 @@ func (encodePlanBitsCodecText) Encode(value interface{}, buf []byte) (newBuf []b return buf, nil } -func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/bool.go b/pgtype/bool.go index 1158ad06..24ae5c5f 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -197,7 +197,7 @@ func (encodePlanBoolCodecTextBool) Encode(value interface{}, buf []byte) (newBuf return buf, nil } -func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/box.go b/pgtype/box.go index 6e44c436..25d4f153 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -126,7 +126,7 @@ func (encodePlanBoxCodecText) Encode(value interface{}, buf []byte) (newBuf []by return buf, nil } -func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/bytea.go b/pgtype/bytea.go index eb865df0..f3c33cb9 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -147,7 +147,7 @@ func (encodePlanBytesCodecTextBytesValuer) Encode(value interface{}, buf []byte) return buf, nil } -func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/circle.go b/pgtype/circle.go index 6a83b41f..6dfb4fae 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -125,7 +125,7 @@ func (encodePlanCircleCodecText) Encode(value interface{}, buf []byte) (newBuf [ return buf, nil } -func (CircleCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (CircleCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/composite.go b/pgtype/composite.go index c538834d..5a67c3df 100644 --- a/pgtype/composite.go +++ b/pgtype/composite.go @@ -109,7 +109,7 @@ func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value int return b.Finish() } -func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/date.go b/pgtype/date.go index bb3975a7..f59508d3 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -196,7 +196,7 @@ func (encodePlanDateCodecText) Encode(value interface{}, buf []byte) (newBuf []b return append(buf, s...), nil } -func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/enum_codec.go b/pgtype/enum_codec.go index 3dce4449..ecad0d9f 100644 --- a/pgtype/enum_codec.go +++ b/pgtype/enum_codec.go @@ -38,7 +38,7 @@ func (EnumCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) return nil } -func (c *EnumCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (c *EnumCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case TextFormatCode, BinaryFormatCode: switch target.(type) { diff --git a/pgtype/float4.go b/pgtype/float4.go index 9ca6fe6a..9b31579f 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -145,7 +145,7 @@ func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value interface{}, buf []by return pgio.AppendUint32(buf, math.Float32bits(f)), nil } -func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/float8.go b/pgtype/float8.go index d5461ab0..30548b88 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -183,7 +183,7 @@ func (encodePlanTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf [ return append(buf, strconv.FormatInt(n.Int, 10)...), nil } -func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/hstore.go b/pgtype/hstore.go index a27330ae..46b3d236 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -150,7 +150,7 @@ func (encodePlanHstoreCodecText) Encode(value interface{}, buf []byte) (newBuf [ return buf, nil } -func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/inet.go b/pgtype/inet.go index ab4cff47..a272e00b 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -140,7 +140,7 @@ func (encodePlanInetCodecText) Encode(value interface{}, buf []byte) (newBuf []b return append(buf, inet.IPNet.String()...), nil } -func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/int.go b/pgtype/int.go index bfdb0184..a799f2bf 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -206,7 +206,7 @@ func (encodePlanInt2CodecTextInt64Valuer) Encode(value interface{}, buf []byte) return append(buf, strconv.FormatInt(n.Int, 10)...), nil } -func (Int2Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Int2Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: @@ -743,7 +743,7 @@ func (encodePlanInt4CodecTextInt64Valuer) Encode(value interface{}, buf []byte) return append(buf, strconv.FormatInt(n.Int, 10)...), nil } -func (Int4Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Int4Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: @@ -1291,7 +1291,7 @@ func (encodePlanInt8CodecTextInt64Valuer) Encode(value interface{}, buf []byte) return append(buf, strconv.FormatInt(n.Int, 10)...), nil } -func (Int8Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Int8Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index cec88984..d3c519a7 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -207,7 +207,7 @@ func (encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer) Encode(value interfa return append(buf, strconv.FormatInt(n.Int, 10)...), nil } -func (Int<%= pg_byte_size %>Codec) PlanScan(m *TypeMap, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Int<%= pg_byte_size %>Codec) PlanScan(m *TypeMap, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/interval.go b/pgtype/interval.go index a13969c3..b4dcf0a6 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -151,7 +151,7 @@ func (encodePlanIntervalCodecText) Encode(value interface{}, buf []byte) (newBuf return buf, nil } -func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/json.go b/pgtype/json.go index 04ce6f6b..e8882d3a 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -49,7 +49,7 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value interface{}, buf []by return buf, nil } -func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch target.(type) { case *string: return scanPlanAnyToString{} diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 0504ee62..7e3d3f8d 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -39,15 +39,15 @@ func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value interface{}, buf []b return plan.textPlan.Encode(value, buf) } -func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: - plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target, actualTarget) + plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target) if plan != nil { return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan} } case TextFormatCode: - return JSONCodec{}.PlanScan(m, oid, format, target, actualTarget) + return JSONCodec{}.PlanScan(m, oid, format, target) } return nil diff --git a/pgtype/line.go b/pgtype/line.go index 17acab81..c9cac4a7 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -128,7 +128,7 @@ func (encodePlanLineCodecText) Encode(value interface{}, buf []byte) (newBuf []b return buf, nil } -func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/lseg.go b/pgtype/lseg.go index 8f65c7c3..4243f6e0 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -126,7 +126,7 @@ func (encodePlanLsegCodecText) Encode(value interface{}, buf []byte) (newBuf []b return buf, nil } -func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index 23ca55d1..686e759a 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -78,7 +78,7 @@ func (encodePlanMacaddrCodecTextHardwareAddr) Encode(value interface{}, buf []by return append(buf, addr.String()...), nil } -func (MacaddrCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (MacaddrCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 3e7f972f..41bf1432 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -501,7 +501,7 @@ func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { return buf, nil } -func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 950e5b11..91b77881 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -63,7 +63,7 @@ func isExpectedEqNumeric(a interface{}) func(interface{}) bool { func mustParseNumeric(t *testing.T, src string) pgtype.Numeric { var n pgtype.Numeric - plan := pgtype.NumericCodec{}.PlanScan(nil, pgtype.NumericOID, pgtype.TextFormatCode, &n, false) + plan := pgtype.NumericCodec{}.PlanScan(nil, pgtype.NumericOID, pgtype.TextFormatCode, &n) require.NotNil(t, plan) err := plan.Scan([]byte(src), &n) require.NoError(t, err) diff --git a/pgtype/path.go b/pgtype/path.go index c1355e41..3b8e598e 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -153,7 +153,7 @@ func (encodePlanPathCodecText) Encode(value interface{}, buf []byte) (newBuf []b return buf, nil } -func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index b42773bc..70c1059c 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -144,9 +144,8 @@ type Codec interface { PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan // PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If - // actualTarget is true then the returned ScanPlan may be optimized to directly scan into target. If no plan can be - // found then nil is returned. - PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan + // no plan can be found then nil is returned. + PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan // DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface. DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) @@ -1030,7 +1029,7 @@ func (m *Map) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPla } if dt != nil { - if plan := dt.Codec.PlanScan(m, oid, formatCode, target, false); plan != nil { + if plan := dt.Codec.PlanScan(m, oid, formatCode, target); plan != nil { return plan } } @@ -1094,7 +1093,7 @@ func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) var ErrScanTargetTypeChanged = errors.New("scan target type changed") func codecScan(codec Codec, m *Map, oid uint32, format int16, src []byte, dst interface{}) error { - scanPlan := codec.PlanScan(m, oid, format, dst, true) + scanPlan := codec.PlanScan(m, oid, format, dst) if scanPlan == nil { return fmt.Errorf("PlanScan did not find a plan") } diff --git a/pgtype/point.go b/pgtype/point.go index 4a9637ee..8df57703 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -177,7 +177,7 @@ func (encodePlanPointCodecText) Encode(value interface{}, buf []byte) (newBuf [] )...), nil } -func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/polygon.go b/pgtype/polygon.go index 5a090ff0..ca479965 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -138,7 +138,7 @@ func (encodePlanPolygonCodecText) Encode(value interface{}, buf []byte) (newBuf return buf, nil } -func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/qchar.go b/pgtype/qchar.go index ce2c3dcb..677b9003 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -56,7 +56,7 @@ func (encodePlanQcharCodecRune) Encode(value interface{}, buf []byte) (newBuf [] return buf, nil } -func (QCharCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (QCharCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case TextFormatCode, BinaryFormatCode: switch target.(type) { diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go index 9628ce4b..0fa43a68 100644 --- a/pgtype/range_codec.go +++ b/pgtype/range_codec.go @@ -270,7 +270,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value interface{}, buf return buf, nil } -func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { @@ -410,6 +410,6 @@ func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) ( } var r GenericRange - err := c.PlanScan(m, oid, format, &r, true).Scan(src, &r) + err := c.PlanScan(m, oid, format, &r).Scan(src, &r) return r, err } diff --git a/pgtype/record_codec.go b/pgtype/record_codec.go index c31aa63c..a5c72aac 100644 --- a/pgtype/record_codec.go +++ b/pgtype/record_codec.go @@ -25,7 +25,7 @@ func (RecordCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{ return nil } -func (RecordCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (RecordCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { if format == BinaryFormatCode { switch target.(type) { case CompositeIndexScanner: diff --git a/pgtype/text.go b/pgtype/text.go index 2c551958..82e7753c 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -156,7 +156,7 @@ func (encodePlanTextCodecTextValuer) Encode(value interface{}, buf []byte) (newB return buf, nil } -func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case TextFormatCode, BinaryFormatCode: diff --git a/pgtype/tid.go b/pgtype/tid.go index 2744cb15..5faa7502 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -130,7 +130,7 @@ func (encodePlanTIDCodecText) Encode(value interface{}, buf []byte) (newBuf []by return buf, nil } -func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/time.go b/pgtype/time.go index 71e7f597..dc40f1fc 100644 --- a/pgtype/time.go +++ b/pgtype/time.go @@ -129,7 +129,7 @@ func (encodePlanTimeCodecText) Encode(value interface{}, buf []byte) (newBuf []b return append(buf, s...), nil } -func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 314d7371..03fb2b28 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -152,7 +152,7 @@ func discardTimeZone(t time.Time) time.Time { return t } -func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 562bb192..a33ce78f 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -51,7 +51,7 @@ func TestTimestampTranscodeBigTimeBinary(t *testing.T) { func TestTimestampCodecDecodeTextInvalid(t *testing.T) { c := &pgtype.TimestampCodec{} var ts pgtype.Timestamp - plan := c.PlanScan(nil, pgtype.TimestampOID, pgtype.TextFormatCode, &ts, false) + plan := c.PlanScan(nil, pgtype.TimestampOID, pgtype.TextFormatCode, &ts) err := plan.Scan([]byte(`eeeee`), &ts) require.Error(t, err) } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 46554f54..8be5970c 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -200,7 +200,7 @@ func (encodePlanTimestamptzCodecText) Encode(value interface{}, buf []byte) (new return buf, nil } -func (TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index ec408f47..ec198fa1 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -51,7 +51,7 @@ func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) { func TestTimestamptzDecodeTextInvalid(t *testing.T) { c := &pgtype.TimestamptzCodec{} var tstz pgtype.Timestamptz - plan := c.PlanScan(nil, pgtype.TimestamptzOID, pgtype.TextFormatCode, &tstz, false) + plan := c.PlanScan(nil, pgtype.TimestamptzOID, pgtype.TextFormatCode, &tstz) err := plan.Scan([]byte(`eeeee`), &tstz) require.Error(t, err) } diff --git a/pgtype/uint32.go b/pgtype/uint32.go index 44238fa0..d406f79b 100644 --- a/pgtype/uint32.go +++ b/pgtype/uint32.go @@ -196,7 +196,7 @@ func (encodePlanUint32CodecTextInt64Valuer) Encode(value interface{}, buf []byte return append(buf, strconv.FormatInt(v.Int, 10)...), nil } -func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 39df8537..a561bed9 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -167,7 +167,7 @@ func (encodePlanUUIDCodecTextUUIDValuer) Encode(value interface{}, buf []byte) ( return append(buf, encodeUUID(uuid.Bytes)...), nil } -func (UUIDCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { +func (UUIDCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { From 43083cb0e3c4432bd21a478e71e46219b9087108 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 21 Feb 2022 10:10:16 -0600 Subject: [PATCH 0917/1158] Memoize pgtype.Map.PlanScan --- pgtype/pgtype.go | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 70c1059c..cb24f295 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -179,6 +179,8 @@ type Map struct { reflectTypeToType map[reflect.Type]*Type + memoizedScanPlans map[uint32]map[reflect.Type][2]ScanPlan + // TryWrapEncodePlanFuncs is a slice of functions that will wrap a value that cannot be encoded by the Codec. Every // time a wrapper is found the PlanEncode method will be recursively called with the new value. This allows several layers of wrappers // to be built up. There are default functions placed in this slice by NewMap(). In most cases these functions @@ -200,6 +202,8 @@ func NewMap() *Map { oidToFormatCode: make(map[uint32]int16), oidToResultFormatCode: make(map[uint32]int16), + memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), + TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ TryWrapDerefPointerEncodePlan, TryWrapBuiltinTypeEncodePlan, @@ -366,7 +370,12 @@ func (m *Map) RegisterType(t *Type) { m.oidToType[t.OID] = t m.nameToType[t.Name] = t m.oidToFormatCode[t.OID] = t.Codec.PreferredFormat() - m.reflectTypeToType = nil // Invalidated by type registration + + // Invalidated by type registration + m.reflectTypeToType = nil + for k := range m.memoizedScanPlans { + delete(m.memoizedScanPlans, k) + } } // RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be @@ -374,7 +383,12 @@ func (m *Map) RegisterType(t *Type) { // unknown, this additional mapping will be used by TypeForValue to determine a suitable data type. func (m *Map) RegisterDefaultPgType(value interface{}, name string) { m.reflectTypeToName[reflect.TypeOf(value)] = name - m.reflectTypeToType = nil // Invalidated by registering a default type + + // Invalidated by type registration + m.reflectTypeToType = nil + for k := range m.memoizedScanPlans { + delete(m.memoizedScanPlans, k) + } } func (m *Map) TypeForOID(oid uint32) (*Type, bool) { @@ -993,6 +1007,24 @@ func (plan *wrapPtrMultiDimSliceScanPlan) Scan(src []byte, target interface{}) e // PlanScan prepares a plan to scan a value into target. func (m *Map) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan { + oidMemo := m.memoizedScanPlans[oid] + if oidMemo == nil { + oidMemo = make(map[reflect.Type][2]ScanPlan) + m.memoizedScanPlans[oid] = oidMemo + } + targetReflectType := reflect.TypeOf(target) + typeMemo := oidMemo[targetReflectType] + plan := typeMemo[formatCode] + if plan == nil { + plan = m.planScan(oid, formatCode, target) + typeMemo[formatCode] = plan + oidMemo[targetReflectType] = typeMemo + } + + return plan +} + +func (m *Map) planScan(oid uint32, formatCode int16, target interface{}) ScanPlan { if _, ok := target.(*UndecodedBytes); ok { return scanPlanAnyToUndecodedBytes{} } From 04476c4a131370216c1bdd64c9fdb686f068a03d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 21 Feb 2022 11:57:34 -0600 Subject: [PATCH 0918/1158] Move pgproto3 to subdirectory --- LICENSE => pgproto3/LICENSE | 0 README.md => pgproto3/README.md | 0 .../authentication_cleartext_password.go | 0 .../authentication_md5_password.go | 0 authentication_ok.go => pgproto3/authentication_ok.go | 0 authentication_sasl.go => pgproto3/authentication_sasl.go | 0 .../authentication_sasl_continue.go | 0 .../authentication_sasl_final.go | 0 backend.go => pgproto3/backend.go | 0 backend_key_data.go => pgproto3/backend_key_data.go | 0 backend_test.go => pgproto3/backend_test.go | 0 big_endian.go => pgproto3/big_endian.go | 0 bind.go => pgproto3/bind.go | 0 bind_complete.go => pgproto3/bind_complete.go | 0 cancel_request.go => pgproto3/cancel_request.go | 0 chunkreader.go => pgproto3/chunkreader.go | 0 close.go => pgproto3/close.go | 0 close_complete.go => pgproto3/close_complete.go | 0 command_complete.go => pgproto3/command_complete.go | 0 copy_both_response.go => pgproto3/copy_both_response.go | 0 copy_data.go => pgproto3/copy_data.go | 0 copy_done.go => pgproto3/copy_done.go | 0 copy_fail.go => pgproto3/copy_fail.go | 0 copy_in_response.go => pgproto3/copy_in_response.go | 0 copy_out_response.go => pgproto3/copy_out_response.go | 0 data_row.go => pgproto3/data_row.go | 0 describe.go => pgproto3/describe.go | 0 doc.go => pgproto3/doc.go | 0 empty_query_response.go => pgproto3/empty_query_response.go | 0 error_response.go => pgproto3/error_response.go | 0 {example => pgproto3/example}/pgfortune/README.md | 0 {example => pgproto3/example}/pgfortune/main.go | 0 {example => pgproto3/example}/pgfortune/server.go | 0 execute.go => pgproto3/execute.go | 0 flush.go => pgproto3/flush.go | 0 frontend.go => pgproto3/frontend.go | 0 frontend_test.go => pgproto3/frontend_test.go | 0 function_call.go => pgproto3/function_call.go | 0 function_call_response.go => pgproto3/function_call_response.go | 0 function_call_test.go => pgproto3/function_call_test.go | 0 go.mod => pgproto3/go.mod | 0 go.sum => pgproto3/go.sum | 0 gss_enc_request.go => pgproto3/gss_enc_request.go | 0 json_test.go => pgproto3/json_test.go | 0 no_data.go => pgproto3/no_data.go | 0 notice_response.go => pgproto3/notice_response.go | 0 notification_response.go => pgproto3/notification_response.go | 0 parameter_description.go => pgproto3/parameter_description.go | 0 parameter_status.go => pgproto3/parameter_status.go | 0 parse.go => pgproto3/parse.go | 0 parse_complete.go => pgproto3/parse_complete.go | 0 password_message.go => pgproto3/password_message.go | 0 pgproto3.go => pgproto3/pgproto3.go | 0 portal_suspended.go => pgproto3/portal_suspended.go | 0 query.go => pgproto3/query.go | 0 ready_for_query.go => pgproto3/ready_for_query.go | 0 row_description.go => pgproto3/row_description.go | 0 sasl_initial_response.go => pgproto3/sasl_initial_response.go | 0 sasl_response.go => pgproto3/sasl_response.go | 0 ssl_request.go => pgproto3/ssl_request.go | 0 startup_message.go => pgproto3/startup_message.go | 0 sync.go => pgproto3/sync.go | 0 terminate.go => pgproto3/terminate.go | 0 63 files changed, 0 insertions(+), 0 deletions(-) rename LICENSE => pgproto3/LICENSE (100%) rename README.md => pgproto3/README.md (100%) rename authentication_cleartext_password.go => pgproto3/authentication_cleartext_password.go (100%) rename authentication_md5_password.go => pgproto3/authentication_md5_password.go (100%) rename authentication_ok.go => pgproto3/authentication_ok.go (100%) rename authentication_sasl.go => pgproto3/authentication_sasl.go (100%) rename authentication_sasl_continue.go => pgproto3/authentication_sasl_continue.go (100%) rename authentication_sasl_final.go => pgproto3/authentication_sasl_final.go (100%) rename backend.go => pgproto3/backend.go (100%) rename backend_key_data.go => pgproto3/backend_key_data.go (100%) rename backend_test.go => pgproto3/backend_test.go (100%) rename big_endian.go => pgproto3/big_endian.go (100%) rename bind.go => pgproto3/bind.go (100%) rename bind_complete.go => pgproto3/bind_complete.go (100%) rename cancel_request.go => pgproto3/cancel_request.go (100%) rename chunkreader.go => pgproto3/chunkreader.go (100%) rename close.go => pgproto3/close.go (100%) rename close_complete.go => pgproto3/close_complete.go (100%) rename command_complete.go => pgproto3/command_complete.go (100%) rename copy_both_response.go => pgproto3/copy_both_response.go (100%) rename copy_data.go => pgproto3/copy_data.go (100%) rename copy_done.go => pgproto3/copy_done.go (100%) rename copy_fail.go => pgproto3/copy_fail.go (100%) rename copy_in_response.go => pgproto3/copy_in_response.go (100%) rename copy_out_response.go => pgproto3/copy_out_response.go (100%) rename data_row.go => pgproto3/data_row.go (100%) rename describe.go => pgproto3/describe.go (100%) rename doc.go => pgproto3/doc.go (100%) rename empty_query_response.go => pgproto3/empty_query_response.go (100%) rename error_response.go => pgproto3/error_response.go (100%) rename {example => pgproto3/example}/pgfortune/README.md (100%) rename {example => pgproto3/example}/pgfortune/main.go (100%) rename {example => pgproto3/example}/pgfortune/server.go (100%) rename execute.go => pgproto3/execute.go (100%) rename flush.go => pgproto3/flush.go (100%) rename frontend.go => pgproto3/frontend.go (100%) rename frontend_test.go => pgproto3/frontend_test.go (100%) rename function_call.go => pgproto3/function_call.go (100%) rename function_call_response.go => pgproto3/function_call_response.go (100%) rename function_call_test.go => pgproto3/function_call_test.go (100%) rename go.mod => pgproto3/go.mod (100%) rename go.sum => pgproto3/go.sum (100%) rename gss_enc_request.go => pgproto3/gss_enc_request.go (100%) rename json_test.go => pgproto3/json_test.go (100%) rename no_data.go => pgproto3/no_data.go (100%) rename notice_response.go => pgproto3/notice_response.go (100%) rename notification_response.go => pgproto3/notification_response.go (100%) rename parameter_description.go => pgproto3/parameter_description.go (100%) rename parameter_status.go => pgproto3/parameter_status.go (100%) rename parse.go => pgproto3/parse.go (100%) rename parse_complete.go => pgproto3/parse_complete.go (100%) rename password_message.go => pgproto3/password_message.go (100%) rename pgproto3.go => pgproto3/pgproto3.go (100%) rename portal_suspended.go => pgproto3/portal_suspended.go (100%) rename query.go => pgproto3/query.go (100%) rename ready_for_query.go => pgproto3/ready_for_query.go (100%) rename row_description.go => pgproto3/row_description.go (100%) rename sasl_initial_response.go => pgproto3/sasl_initial_response.go (100%) rename sasl_response.go => pgproto3/sasl_response.go (100%) rename ssl_request.go => pgproto3/ssl_request.go (100%) rename startup_message.go => pgproto3/startup_message.go (100%) rename sync.go => pgproto3/sync.go (100%) rename terminate.go => pgproto3/terminate.go (100%) diff --git a/LICENSE b/pgproto3/LICENSE similarity index 100% rename from LICENSE rename to pgproto3/LICENSE diff --git a/README.md b/pgproto3/README.md similarity index 100% rename from README.md rename to pgproto3/README.md diff --git a/authentication_cleartext_password.go b/pgproto3/authentication_cleartext_password.go similarity index 100% rename from authentication_cleartext_password.go rename to pgproto3/authentication_cleartext_password.go diff --git a/authentication_md5_password.go b/pgproto3/authentication_md5_password.go similarity index 100% rename from authentication_md5_password.go rename to pgproto3/authentication_md5_password.go diff --git a/authentication_ok.go b/pgproto3/authentication_ok.go similarity index 100% rename from authentication_ok.go rename to pgproto3/authentication_ok.go diff --git a/authentication_sasl.go b/pgproto3/authentication_sasl.go similarity index 100% rename from authentication_sasl.go rename to pgproto3/authentication_sasl.go diff --git a/authentication_sasl_continue.go b/pgproto3/authentication_sasl_continue.go similarity index 100% rename from authentication_sasl_continue.go rename to pgproto3/authentication_sasl_continue.go diff --git a/authentication_sasl_final.go b/pgproto3/authentication_sasl_final.go similarity index 100% rename from authentication_sasl_final.go rename to pgproto3/authentication_sasl_final.go diff --git a/backend.go b/pgproto3/backend.go similarity index 100% rename from backend.go rename to pgproto3/backend.go diff --git a/backend_key_data.go b/pgproto3/backend_key_data.go similarity index 100% rename from backend_key_data.go rename to pgproto3/backend_key_data.go diff --git a/backend_test.go b/pgproto3/backend_test.go similarity index 100% rename from backend_test.go rename to pgproto3/backend_test.go diff --git a/big_endian.go b/pgproto3/big_endian.go similarity index 100% rename from big_endian.go rename to pgproto3/big_endian.go diff --git a/bind.go b/pgproto3/bind.go similarity index 100% rename from bind.go rename to pgproto3/bind.go diff --git a/bind_complete.go b/pgproto3/bind_complete.go similarity index 100% rename from bind_complete.go rename to pgproto3/bind_complete.go diff --git a/cancel_request.go b/pgproto3/cancel_request.go similarity index 100% rename from cancel_request.go rename to pgproto3/cancel_request.go diff --git a/chunkreader.go b/pgproto3/chunkreader.go similarity index 100% rename from chunkreader.go rename to pgproto3/chunkreader.go diff --git a/close.go b/pgproto3/close.go similarity index 100% rename from close.go rename to pgproto3/close.go diff --git a/close_complete.go b/pgproto3/close_complete.go similarity index 100% rename from close_complete.go rename to pgproto3/close_complete.go diff --git a/command_complete.go b/pgproto3/command_complete.go similarity index 100% rename from command_complete.go rename to pgproto3/command_complete.go diff --git a/copy_both_response.go b/pgproto3/copy_both_response.go similarity index 100% rename from copy_both_response.go rename to pgproto3/copy_both_response.go diff --git a/copy_data.go b/pgproto3/copy_data.go similarity index 100% rename from copy_data.go rename to pgproto3/copy_data.go diff --git a/copy_done.go b/pgproto3/copy_done.go similarity index 100% rename from copy_done.go rename to pgproto3/copy_done.go diff --git a/copy_fail.go b/pgproto3/copy_fail.go similarity index 100% rename from copy_fail.go rename to pgproto3/copy_fail.go diff --git a/copy_in_response.go b/pgproto3/copy_in_response.go similarity index 100% rename from copy_in_response.go rename to pgproto3/copy_in_response.go diff --git a/copy_out_response.go b/pgproto3/copy_out_response.go similarity index 100% rename from copy_out_response.go rename to pgproto3/copy_out_response.go diff --git a/data_row.go b/pgproto3/data_row.go similarity index 100% rename from data_row.go rename to pgproto3/data_row.go diff --git a/describe.go b/pgproto3/describe.go similarity index 100% rename from describe.go rename to pgproto3/describe.go diff --git a/doc.go b/pgproto3/doc.go similarity index 100% rename from doc.go rename to pgproto3/doc.go diff --git a/empty_query_response.go b/pgproto3/empty_query_response.go similarity index 100% rename from empty_query_response.go rename to pgproto3/empty_query_response.go diff --git a/error_response.go b/pgproto3/error_response.go similarity index 100% rename from error_response.go rename to pgproto3/error_response.go diff --git a/example/pgfortune/README.md b/pgproto3/example/pgfortune/README.md similarity index 100% rename from example/pgfortune/README.md rename to pgproto3/example/pgfortune/README.md diff --git a/example/pgfortune/main.go b/pgproto3/example/pgfortune/main.go similarity index 100% rename from example/pgfortune/main.go rename to pgproto3/example/pgfortune/main.go diff --git a/example/pgfortune/server.go b/pgproto3/example/pgfortune/server.go similarity index 100% rename from example/pgfortune/server.go rename to pgproto3/example/pgfortune/server.go diff --git a/execute.go b/pgproto3/execute.go similarity index 100% rename from execute.go rename to pgproto3/execute.go diff --git a/flush.go b/pgproto3/flush.go similarity index 100% rename from flush.go rename to pgproto3/flush.go diff --git a/frontend.go b/pgproto3/frontend.go similarity index 100% rename from frontend.go rename to pgproto3/frontend.go diff --git a/frontend_test.go b/pgproto3/frontend_test.go similarity index 100% rename from frontend_test.go rename to pgproto3/frontend_test.go diff --git a/function_call.go b/pgproto3/function_call.go similarity index 100% rename from function_call.go rename to pgproto3/function_call.go diff --git a/function_call_response.go b/pgproto3/function_call_response.go similarity index 100% rename from function_call_response.go rename to pgproto3/function_call_response.go diff --git a/function_call_test.go b/pgproto3/function_call_test.go similarity index 100% rename from function_call_test.go rename to pgproto3/function_call_test.go diff --git a/go.mod b/pgproto3/go.mod similarity index 100% rename from go.mod rename to pgproto3/go.mod diff --git a/go.sum b/pgproto3/go.sum similarity index 100% rename from go.sum rename to pgproto3/go.sum diff --git a/gss_enc_request.go b/pgproto3/gss_enc_request.go similarity index 100% rename from gss_enc_request.go rename to pgproto3/gss_enc_request.go diff --git a/json_test.go b/pgproto3/json_test.go similarity index 100% rename from json_test.go rename to pgproto3/json_test.go diff --git a/no_data.go b/pgproto3/no_data.go similarity index 100% rename from no_data.go rename to pgproto3/no_data.go diff --git a/notice_response.go b/pgproto3/notice_response.go similarity index 100% rename from notice_response.go rename to pgproto3/notice_response.go diff --git a/notification_response.go b/pgproto3/notification_response.go similarity index 100% rename from notification_response.go rename to pgproto3/notification_response.go diff --git a/parameter_description.go b/pgproto3/parameter_description.go similarity index 100% rename from parameter_description.go rename to pgproto3/parameter_description.go diff --git a/parameter_status.go b/pgproto3/parameter_status.go similarity index 100% rename from parameter_status.go rename to pgproto3/parameter_status.go diff --git a/parse.go b/pgproto3/parse.go similarity index 100% rename from parse.go rename to pgproto3/parse.go diff --git a/parse_complete.go b/pgproto3/parse_complete.go similarity index 100% rename from parse_complete.go rename to pgproto3/parse_complete.go diff --git a/password_message.go b/pgproto3/password_message.go similarity index 100% rename from password_message.go rename to pgproto3/password_message.go diff --git a/pgproto3.go b/pgproto3/pgproto3.go similarity index 100% rename from pgproto3.go rename to pgproto3/pgproto3.go diff --git a/portal_suspended.go b/pgproto3/portal_suspended.go similarity index 100% rename from portal_suspended.go rename to pgproto3/portal_suspended.go diff --git a/query.go b/pgproto3/query.go similarity index 100% rename from query.go rename to pgproto3/query.go diff --git a/ready_for_query.go b/pgproto3/ready_for_query.go similarity index 100% rename from ready_for_query.go rename to pgproto3/ready_for_query.go diff --git a/row_description.go b/pgproto3/row_description.go similarity index 100% rename from row_description.go rename to pgproto3/row_description.go diff --git a/sasl_initial_response.go b/pgproto3/sasl_initial_response.go similarity index 100% rename from sasl_initial_response.go rename to pgproto3/sasl_initial_response.go diff --git a/sasl_response.go b/pgproto3/sasl_response.go similarity index 100% rename from sasl_response.go rename to pgproto3/sasl_response.go diff --git a/ssl_request.go b/pgproto3/ssl_request.go similarity index 100% rename from ssl_request.go rename to pgproto3/ssl_request.go diff --git a/startup_message.go b/pgproto3/startup_message.go similarity index 100% rename from startup_message.go rename to pgproto3/startup_message.go diff --git a/sync.go b/pgproto3/sync.go similarity index 100% rename from sync.go rename to pgproto3/sync.go diff --git a/terminate.go b/pgproto3/terminate.go similarity index 100% rename from terminate.go rename to pgproto3/terminate.go From 95cbbfe441740661a430c0f493ef60dae06b7483 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 21 Feb 2022 13:22:42 -0600 Subject: [PATCH 0919/1158] Import pgproto3 Also copy in pgmock as an internal package. --- conn.go | 2 +- go.mod | 4 +- go.sum | 96 ------------------- internal/pgmock/pgmock.go | 135 +++++++++++++++++++++++++++ internal/pgmock/pgmock_test.go | 93 ++++++++++++++++++ pgconn/auth_scram.go | 2 +- pgconn/config.go | 2 +- pgconn/frontend_test.go | 2 +- pgconn/pgconn.go | 2 +- pgconn/pgconn_test.go | 4 +- pgproto3/LICENSE | 22 ----- pgproto3/README.md | 5 - pgproto3/backend_test.go | 2 +- pgproto3/example/pgfortune/server.go | 2 +- pgproto3/frontend_test.go | 2 +- pgproto3/go.mod | 9 -- pgproto3/go.sum | 14 --- pgxpool/rows.go | 2 +- rows.go | 2 +- 19 files changed, 242 insertions(+), 160 deletions(-) create mode 100644 internal/pgmock/pgmock.go create mode 100644 internal/pgmock/pgmock_test.go delete mode 100644 pgproto3/LICENSE delete mode 100644 pgproto3/go.mod delete mode 100644 pgproto3/go.sum diff --git a/conn.go b/conn.go index 2a66a5ff..ba0d9d00 100644 --- a/conn.go +++ b/conn.go @@ -8,10 +8,10 @@ import ( "strings" "time" - "github.com/jackc/pgproto3/v2" "github.com/jackc/pgx/v5/internal/sanitize" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn/stmtcache" + "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" ) diff --git a/go.mod b/go.mod index a4a9b8b3..efcfd9af 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,7 @@ go 1.17 require ( github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/pgio v1.0.0 - github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.2.0 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/jackc/puddle v1.2.1 github.com/stretchr/testify v1.7.0 @@ -17,6 +15,8 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/pretty v0.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index b055f738..0e049ef2 100644 --- a/go.sum +++ b/go.sum @@ -1,138 +1,42 @@ -github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= -github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= -github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= -github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= -github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= -github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= -github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= -github.com/jackc/pgconn v1.9.0 h1:gqibKSTJup/ahCsNKyMZAniPuZEfIqfXFc8FOWVYR+Q= -github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= -github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= -github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= -github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= -github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= -github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= -github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.2.0 h1:r7JypeP2D3onoQTCxWdTpCtJ4D+qpKr0TxvoyMhZ5ns= -github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= -github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= -github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= -github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= -github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= -github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= -github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= -github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.1 h1:gI8os0wpRXFd4FiAY2dWiqRK037tjj3t7rKFeO4X5iw= github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= -github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= -github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= -github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= -github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= -github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= -github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= -github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= -golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b h1:QAqMVf3pSa6eeTsuklijukjXBlj7Es2QQplab+/RbQ4= golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= diff --git a/internal/pgmock/pgmock.go b/internal/pgmock/pgmock.go new file mode 100644 index 00000000..97dd024d --- /dev/null +++ b/internal/pgmock/pgmock.go @@ -0,0 +1,135 @@ +// Package pgmock provides the ability to mock a PostgreSQL server. +package pgmock + +import ( + "fmt" + "io" + "reflect" + + "github.com/jackc/pgx/v5/pgproto3" +) + +type Step interface { + Step(*pgproto3.Backend) error +} + +type Script struct { + Steps []Step +} + +func (s *Script) Run(backend *pgproto3.Backend) error { + for _, step := range s.Steps { + err := step.Step(backend) + if err != nil { + return err + } + } + + return nil +} + +func (s *Script) Step(backend *pgproto3.Backend) error { + return s.Run(backend) +} + +type expectMessageStep struct { + want pgproto3.FrontendMessage + any bool +} + +func (e *expectMessageStep) Step(backend *pgproto3.Backend) error { + msg, err := backend.Receive() + if err != nil { + return err + } + + if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) { + return nil + } + + if !reflect.DeepEqual(msg, e.want) { + return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want) + } + + return nil +} + +type expectStartupMessageStep struct { + want *pgproto3.StartupMessage + any bool +} + +func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error { + msg, err := backend.ReceiveStartupMessage() + if err != nil { + return err + } + + if e.any { + return nil + } + + if !reflect.DeepEqual(msg, e.want) { + return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want) + } + + return nil +} + +func ExpectMessage(want pgproto3.FrontendMessage) Step { + return expectMessage(want, false) +} + +func ExpectAnyMessage(want pgproto3.FrontendMessage) Step { + return expectMessage(want, true) +} + +func expectMessage(want pgproto3.FrontendMessage, any bool) Step { + if want, ok := want.(*pgproto3.StartupMessage); ok { + return &expectStartupMessageStep{want: want, any: any} + } + + return &expectMessageStep{want: want, any: any} +} + +type sendMessageStep struct { + msg pgproto3.BackendMessage +} + +func (e *sendMessageStep) Step(backend *pgproto3.Backend) error { + return backend.Send(e.msg) +} + +func SendMessage(msg pgproto3.BackendMessage) Step { + return &sendMessageStep{msg: msg} +} + +type waitForCloseMessageStep struct{} + +func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error { + for { + msg, err := backend.Receive() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if _, ok := msg.(*pgproto3.Terminate); ok { + return nil + } + } +} + +func WaitForClose() Step { + return &waitForCloseMessageStep{} +} + +func AcceptUnauthenticatedConnRequestSteps() []Step { + return []Step{ + ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + SendMessage(&pgproto3.AuthenticationOk{}), + SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + } +} diff --git a/internal/pgmock/pgmock_test.go b/internal/pgmock/pgmock_test.go new file mode 100644 index 00000000..1e22cbcb --- /dev/null +++ b/internal/pgmock/pgmock_test.go @@ -0,0 +1,93 @@ +package pgmock_test + +import ( + "context" + "fmt" + "net" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/internal/pgmock" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScript(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"})) + script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + pgproto3.FieldDescription{ + Name: []byte("?column?"), + TableOID: 0, + TableAttributeNumber: 0, + DataTypeOID: 23, + DataTypeSize: 4, + TypeModifier: -1, + Format: 0, + }, + }, + })) + script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.DataRow{ + Values: [][]byte{[]byte("42")}, + })) + script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")})) + script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})) + script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Terminate{})) + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Second)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + pgConn, err := pgconn.Connect(ctx, connStr) + require.NoError(t, err) + results, err := pgConn.Exec(ctx, "select 42").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "42", string(results[0].Rows[0][0])) + + pgConn.Close(ctx) + + assert.NoError(t, <-serverErrChan) +} diff --git a/pgconn/auth_scram.go b/pgconn/auth_scram.go index 6a143fcd..de13c687 100644 --- a/pgconn/auth_scram.go +++ b/pgconn/auth_scram.go @@ -22,7 +22,7 @@ import ( "fmt" "strconv" - "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgx/v5/pgproto3" "golang.org/x/crypto/pbkdf2" "golang.org/x/text/secure/precis" ) diff --git a/pgconn/config.go b/pgconn/config.go index 0eab23af..6c166834 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -19,8 +19,8 @@ import ( "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" - "github.com/jackc/pgproto3/v2" "github.com/jackc/pgservicefile" + "github.com/jackc/pgx/v5/pgproto3" ) type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error diff --git a/pgconn/frontend_test.go b/pgconn/frontend_test.go index 9ea53b10..439d3251 100644 --- a/pgconn/frontend_test.go +++ b/pgconn/frontend_test.go @@ -6,8 +6,8 @@ import ( "os" "testing" - "github.com/jackc/pgproto3/v2" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 16d54f3a..20dd6858 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -17,8 +17,8 @@ import ( "time" "github.com/jackc/pgio" - "github.com/jackc/pgproto3/v2" "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" + "github.com/jackc/pgx/v5/pgproto3" ) const ( diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 4d975f32..42214d2c 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -18,9 +18,9 @@ import ( "testing" "time" - "github.com/jackc/pgmock" - "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgx/v5/internal/pgmock" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgproto3/LICENSE b/pgproto3/LICENSE deleted file mode 100644 index c1c4f50f..00000000 --- a/pgproto3/LICENSE +++ /dev/null @@ -1,22 +0,0 @@ -Copyright (c) 2019 Jack Christensen - -MIT License - -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/pgproto3/README.md b/pgproto3/README.md index 565b3efd..79d3a68b 100644 --- a/pgproto3/README.md +++ b/pgproto3/README.md @@ -1,6 +1,3 @@ -[![](https://godoc.org/github.com/jackc/pgproto3?status.svg)](https://godoc.org/github.com/jackc/pgproto3) -[![Build Status](https://travis-ci.org/jackc/pgproto3.svg)](https://travis-ci.org/jackc/pgproto3) - # pgproto3 Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3. @@ -8,5 +5,3 @@ Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol versio pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more. See example/pgfortune for a playful example of a fake PostgreSQL server. - -Extracted from original implementation in https://github.com/jackc/pgx. diff --git a/pgproto3/backend_test.go b/pgproto3/backend_test.go index 708f1280..c3be614a 100644 --- a/pgproto3/backend_test.go +++ b/pgproto3/backend_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/jackc/pgio" - "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgx/v5/pgproto3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgproto3/example/pgfortune/server.go b/pgproto3/example/pgfortune/server.go index 777192a6..fe406452 100644 --- a/pgproto3/example/pgfortune/server.go +++ b/pgproto3/example/pgfortune/server.go @@ -4,7 +4,7 @@ import ( "fmt" "net" - "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgx/v5/pgproto3" ) type PgFortuneBackend struct { diff --git a/pgproto3/frontend_test.go b/pgproto3/frontend_test.go index d202451f..595877bd 100644 --- a/pgproto3/frontend_test.go +++ b/pgproto3/frontend_test.go @@ -4,7 +4,7 @@ import ( "io" "testing" - "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgx/v5/pgproto3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pgproto3/go.mod b/pgproto3/go.mod deleted file mode 100644 index 36041a94..00000000 --- a/pgproto3/go.mod +++ /dev/null @@ -1,9 +0,0 @@ -module github.com/jackc/pgproto3/v2 - -go 1.12 - -require ( - github.com/jackc/chunkreader/v2 v2.0.0 - github.com/jackc/pgio v1.0.0 - github.com/stretchr/testify v1.4.0 -) diff --git a/pgproto3/go.sum b/pgproto3/go.sum deleted file mode 100644 index dd9cd044..00000000 --- a/pgproto3/go.sum +++ /dev/null @@ -1,14 +0,0 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= -github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= -github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= -github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/pgxpool/rows.go b/pgxpool/rows.go index f3f24649..ff7ad80b 100644 --- a/pgxpool/rows.go +++ b/pgxpool/rows.go @@ -1,9 +1,9 @@ package pgxpool import ( - "github.com/jackc/pgproto3/v2" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" ) type errRows struct { diff --git a/rows.go b/rows.go index 805be100..d9b155e6 100644 --- a/rows.go +++ b/rows.go @@ -7,8 +7,8 @@ import ( "reflect" "time" - "github.com/jackc/pgproto3/v2" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" ) From fd1a98f85875ad45f74a7dd0a3c646bec0323437 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 21 Feb 2022 14:27:05 -0600 Subject: [PATCH 0920/1158] Move and clean for import --- .travis.yml | 9 -------- LICENSE | 22 ------------------- README.md | 8 ------- chunkreader.go => chunkreader/chunkreader.go | 0 .../chunkreader_test.go | 0 go.mod | 3 --- 6 files changed, 42 deletions(-) delete mode 100644 .travis.yml delete mode 100644 LICENSE delete mode 100644 README.md rename chunkreader.go => chunkreader/chunkreader.go (100%) rename chunkreader_test.go => chunkreader/chunkreader_test.go (100%) delete mode 100644 go.mod diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index e176228e..00000000 --- a/.travis.yml +++ /dev/null @@ -1,9 +0,0 @@ -language: go - -go: - - 1.x - - tip - -matrix: - allow_failures: - - go: tip diff --git a/LICENSE b/LICENSE deleted file mode 100644 index c1c4f50f..00000000 --- a/LICENSE +++ /dev/null @@ -1,22 +0,0 @@ -Copyright (c) 2019 Jack Christensen - -MIT License - -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md deleted file mode 100644 index 01209bfa..00000000 --- a/README.md +++ /dev/null @@ -1,8 +0,0 @@ -[![](https://godoc.org/github.com/jackc/chunkreader?status.svg)](https://godoc.org/github.com/jackc/chunkreader) -[![Build Status](https://travis-ci.org/jackc/chunkreader.svg)](https://travis-ci.org/jackc/chunkreader) - -# chunkreader - -Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations. - -Extracted from original implementation in https://github.com/jackc/pgx. diff --git a/chunkreader.go b/chunkreader/chunkreader.go similarity index 100% rename from chunkreader.go rename to chunkreader/chunkreader.go diff --git a/chunkreader_test.go b/chunkreader/chunkreader_test.go similarity index 100% rename from chunkreader_test.go rename to chunkreader/chunkreader_test.go diff --git a/go.mod b/go.mod deleted file mode 100644 index a1384b40..00000000 --- a/go.mod +++ /dev/null @@ -1,3 +0,0 @@ -module github.com/jackc/chunkreader/v2 - -go 1.12 From 032ea5f5c0ffd93c006ca6156aa1cf8253297c4e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 21 Feb 2022 14:29:39 -0600 Subject: [PATCH 0921/1158] Finish import of chunkreader --- go.mod | 1 - go.sum | 2 -- pgconn/config.go | 2 +- pgproto3/chunkreader.go | 2 +- 4 files changed, 2 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index efcfd9af..b467e7f4 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/jackc/pgx/v5 go 1.17 require ( - github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b diff --git a/go.sum b/go.sum index 0e049ef2..0931401b 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= -github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= diff --git a/pgconn/config.go b/pgconn/config.go index 6c166834..6f7f3ca5 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -17,9 +17,9 @@ import ( "strings" "time" - "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" "github.com/jackc/pgservicefile" + "github.com/jackc/pgx/v5/chunkreader" "github.com/jackc/pgx/v5/pgproto3" ) diff --git a/pgproto3/chunkreader.go b/pgproto3/chunkreader.go index 92206f35..3f878183 100644 --- a/pgproto3/chunkreader.go +++ b/pgproto3/chunkreader.go @@ -3,7 +3,7 @@ package pgproto3 import ( "io" - "github.com/jackc/chunkreader/v2" + "github.com/jackc/pgx/v5/chunkreader" ) // ChunkReader is an interface to decouple github.com/jackc/chunkreader from this package. From d35500e3979fc75bfe76d8e727223ea039edef76 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 21 Feb 2022 14:32:55 -0600 Subject: [PATCH 0922/1158] Move pgio --- .travis.yml | 9 --------- LICENSE | 22 ---------------------- README.md | 11 ----------- go.mod | 3 --- pgio/README.md | 6 ++++++ doc.go => pgio/doc.go | 0 write.go => pgio/write.go | 0 write_test.go => pgio/write_test.go | 0 8 files changed, 6 insertions(+), 45 deletions(-) delete mode 100644 .travis.yml delete mode 100644 LICENSE delete mode 100644 README.md delete mode 100644 go.mod create mode 100644 pgio/README.md rename doc.go => pgio/doc.go (100%) rename write.go => pgio/write.go (100%) rename write_test.go => pgio/write_test.go (100%) diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index e176228e..00000000 --- a/.travis.yml +++ /dev/null @@ -1,9 +0,0 @@ -language: go - -go: - - 1.x - - tip - -matrix: - allow_failures: - - go: tip diff --git a/LICENSE b/LICENSE deleted file mode 100644 index c1c4f50f..00000000 --- a/LICENSE +++ /dev/null @@ -1,22 +0,0 @@ -Copyright (c) 2019 Jack Christensen - -MIT License - -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md deleted file mode 100644 index 1952ed86..00000000 --- a/README.md +++ /dev/null @@ -1,11 +0,0 @@ -[![](https://godoc.org/github.com/jackc/pgio?status.svg)](https://godoc.org/github.com/jackc/pgio) -[![Build Status](https://travis-ci.org/jackc/pgio.svg)](https://travis-ci.org/jackc/pgio) - -# pgio - -Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol. - -pgio provides functions for appending integers to a []byte while doing byte -order conversion. - -Extracted from original implementation in https://github.com/jackc/pgx. diff --git a/go.mod b/go.mod deleted file mode 100644 index c1efdddb..00000000 --- a/go.mod +++ /dev/null @@ -1,3 +0,0 @@ -module github.com/jackc/pgio - -go 1.12 diff --git a/pgio/README.md b/pgio/README.md new file mode 100644 index 00000000..b2fc5801 --- /dev/null +++ b/pgio/README.md @@ -0,0 +1,6 @@ +# pgio + +Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol. + +pgio provides functions for appending integers to a []byte while doing byte +order conversion. diff --git a/doc.go b/pgio/doc.go similarity index 100% rename from doc.go rename to pgio/doc.go diff --git a/write.go b/pgio/write.go similarity index 100% rename from write.go rename to pgio/write.go diff --git a/write_test.go b/pgio/write_test.go similarity index 100% rename from write_test.go rename to pgio/write_test.go From d13f651810bfcd7534170b0a0b498dccfd9fd7cc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 21 Feb 2022 14:35:20 -0600 Subject: [PATCH 0923/1158] Finish importing pgio as internal package --- copy_from.go | 2 +- go.mod | 1 - go.sum | 2 -- {pgio => internal/pgio}/README.md | 0 {pgio => internal/pgio}/doc.go | 0 {pgio => internal/pgio}/write.go | 0 {pgio => internal/pgio}/write_test.go | 0 pgconn/pgconn.go | 2 +- pgproto3/authentication_cleartext_password.go | 2 +- pgproto3/authentication_md5_password.go | 2 +- pgproto3/authentication_ok.go | 2 +- pgproto3/authentication_sasl.go | 2 +- pgproto3/authentication_sasl_continue.go | 2 +- pgproto3/authentication_sasl_final.go | 2 +- pgproto3/backend_key_data.go | 2 +- pgproto3/backend_test.go | 2 +- pgproto3/bind.go | 2 +- pgproto3/cancel_request.go | 2 +- pgproto3/close.go | 2 +- pgproto3/command_complete.go | 2 +- pgproto3/copy_both_response.go | 2 +- pgproto3/copy_data.go | 2 +- pgproto3/copy_fail.go | 2 +- pgproto3/copy_in_response.go | 2 +- pgproto3/copy_out_response.go | 2 +- pgproto3/data_row.go | 2 +- pgproto3/describe.go | 2 +- pgproto3/execute.go | 2 +- pgproto3/function_call.go | 3 ++- pgproto3/function_call_response.go | 2 +- pgproto3/gss_enc_request.go | 2 +- pgproto3/notification_response.go | 2 +- pgproto3/parameter_description.go | 2 +- pgproto3/parameter_status.go | 2 +- pgproto3/parse.go | 2 +- pgproto3/password_message.go | 2 +- pgproto3/query.go | 2 +- pgproto3/row_description.go | 2 +- pgproto3/sasl_initial_response.go | 2 +- pgproto3/sasl_response.go | 2 +- pgproto3/ssl_request.go | 2 +- pgproto3/startup_message.go | 2 +- pgtype/array.go | 2 +- pgtype/array_codec.go | 2 +- pgtype/bits.go | 2 +- pgtype/box.go | 2 +- pgtype/circle.go | 2 +- pgtype/composite.go | 2 +- pgtype/date.go | 2 +- pgtype/float4.go | 2 +- pgtype/float8.go | 2 +- pgtype/hstore.go | 2 +- pgtype/int.go | 2 +- pgtype/int.go.erb | 2 +- pgtype/interval.go | 2 +- pgtype/line.go | 2 +- pgtype/lseg.go | 2 +- pgtype/numeric.go | 2 +- pgtype/path.go | 2 +- pgtype/point.go | 2 +- pgtype/polygon.go | 2 +- pgtype/range_codec.go | 2 +- pgtype/tid.go | 2 +- pgtype/time.go | 2 +- pgtype/timestamp.go | 2 +- pgtype/timestamptz.go | 2 +- pgtype/uint32.go | 2 +- values.go | 2 +- 68 files changed, 63 insertions(+), 65 deletions(-) rename {pgio => internal/pgio}/README.md (100%) rename {pgio => internal/pgio}/doc.go (100%) rename {pgio => internal/pgio}/write.go (100%) rename {pgio => internal/pgio}/write_test.go (100%) diff --git a/copy_from.go b/copy_from.go index 8eb3c111..7d6a8813 100644 --- a/copy_from.go +++ b/copy_from.go @@ -7,7 +7,7 @@ import ( "io" "time" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgconn" ) diff --git a/go.mod b/go.mod index b467e7f4..79cbd50d 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/jackc/pgx/v5 go 1.17 require ( - github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/jackc/puddle v1.2.1 diff --git a/go.sum b/go.sum index 0931401b..a9851d43 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= -github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= diff --git a/pgio/README.md b/internal/pgio/README.md similarity index 100% rename from pgio/README.md rename to internal/pgio/README.md diff --git a/pgio/doc.go b/internal/pgio/doc.go similarity index 100% rename from pgio/doc.go rename to internal/pgio/doc.go diff --git a/pgio/write.go b/internal/pgio/write.go similarity index 100% rename from pgio/write.go rename to internal/pgio/write.go diff --git a/pgio/write_test.go b/internal/pgio/write_test.go similarity index 100% rename from pgio/write_test.go rename to internal/pgio/write_test.go diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 20dd6858..a9a6de8c 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -16,7 +16,7 @@ import ( "sync" "time" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" ) diff --git a/pgproto3/authentication_cleartext_password.go b/pgproto3/authentication_cleartext_password.go index 241fa600..d8f98b9a 100644 --- a/pgproto3/authentication_cleartext_password.go +++ b/pgproto3/authentication_cleartext_password.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) // AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required. diff --git a/pgproto3/authentication_md5_password.go b/pgproto3/authentication_md5_password.go index 32ec0390..5671c84c 100644 --- a/pgproto3/authentication_md5_password.go +++ b/pgproto3/authentication_md5_password.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) // AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required. diff --git a/pgproto3/authentication_ok.go b/pgproto3/authentication_ok.go index 2b476fe5..88d648ae 100644 --- a/pgproto3/authentication_ok.go +++ b/pgproto3/authentication_ok.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) // AuthenticationOk is a message sent from the backend indicating that authentication was successful. diff --git a/pgproto3/authentication_sasl.go b/pgproto3/authentication_sasl.go index bdcb2c36..996b97d3 100644 --- a/pgproto3/authentication_sasl.go +++ b/pgproto3/authentication_sasl.go @@ -6,7 +6,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) // AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required. diff --git a/pgproto3/authentication_sasl_continue.go b/pgproto3/authentication_sasl_continue.go index 7f4a9c23..2ce70a47 100644 --- a/pgproto3/authentication_sasl_continue.go +++ b/pgproto3/authentication_sasl_continue.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) // AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge. diff --git a/pgproto3/authentication_sasl_final.go b/pgproto3/authentication_sasl_final.go index d82b9ee4..a38a8b91 100644 --- a/pgproto3/authentication_sasl_final.go +++ b/pgproto3/authentication_sasl_final.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) // AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed. diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go index ca20dd25..12c60817 100644 --- a/pgproto3/backend_key_data.go +++ b/pgproto3/backend_key_data.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type BackendKeyData struct { diff --git a/pgproto3/backend_test.go b/pgproto3/backend_test.go index c3be614a..75755f22 100644 --- a/pgproto3/backend_test.go +++ b/pgproto3/backend_test.go @@ -4,7 +4,7 @@ import ( "io" "testing" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgproto3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/pgproto3/bind.go b/pgproto3/bind.go index e9664f59..fdd2d3b8 100644 --- a/pgproto3/bind.go +++ b/pgproto3/bind.go @@ -7,7 +7,7 @@ import ( "encoding/json" "fmt" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Bind struct { diff --git a/pgproto3/cancel_request.go b/pgproto3/cancel_request.go index 942e404b..8fcf8217 100644 --- a/pgproto3/cancel_request.go +++ b/pgproto3/cancel_request.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) const cancelRequestCode = 80877102 diff --git a/pgproto3/close.go b/pgproto3/close.go index a45f2b93..f99b5943 100644 --- a/pgproto3/close.go +++ b/pgproto3/close.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Close struct { diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go index cdc49f39..a19b906c 100644 --- a/pgproto3/command_complete.go +++ b/pgproto3/command_complete.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type CommandComplete struct { diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go index fbd985d8..3243dbc1 100644 --- a/pgproto3/copy_both_response.go +++ b/pgproto3/copy_both_response.go @@ -6,7 +6,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type CopyBothResponse struct { diff --git a/pgproto3/copy_data.go b/pgproto3/copy_data.go index 128aa198..59e3dd94 100644 --- a/pgproto3/copy_data.go +++ b/pgproto3/copy_data.go @@ -4,7 +4,7 @@ import ( "encoding/hex" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type CopyData struct { diff --git a/pgproto3/copy_fail.go b/pgproto3/copy_fail.go index 78ff0b30..0041bbb1 100644 --- a/pgproto3/copy_fail.go +++ b/pgproto3/copy_fail.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type CopyFail struct { diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go index 80733adc..4584f7df 100644 --- a/pgproto3/copy_in_response.go +++ b/pgproto3/copy_in_response.go @@ -6,7 +6,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type CopyInResponse struct { diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go index 5e607e3a..3175c6a4 100644 --- a/pgproto3/copy_out_response.go +++ b/pgproto3/copy_out_response.go @@ -6,7 +6,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type CopyOutResponse struct { diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go index 63768761..0bfe9a0d 100644 --- a/pgproto3/data_row.go +++ b/pgproto3/data_row.go @@ -5,7 +5,7 @@ import ( "encoding/hex" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type DataRow struct { diff --git a/pgproto3/describe.go b/pgproto3/describe.go index 0d825db1..f131d1f4 100644 --- a/pgproto3/describe.go +++ b/pgproto3/describe.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Describe struct { diff --git a/pgproto3/execute.go b/pgproto3/execute.go index 8bae6133..a5fee7cb 100644 --- a/pgproto3/execute.go +++ b/pgproto3/execute.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Execute struct { diff --git a/pgproto3/function_call.go b/pgproto3/function_call.go index b3a22c4f..2c4f38df 100644 --- a/pgproto3/function_call.go +++ b/pgproto3/function_call.go @@ -2,7 +2,8 @@ package pgproto3 import ( "encoding/binary" - "github.com/jackc/pgio" + + "github.com/jackc/pgx/v5/internal/pgio" ) type FunctionCall struct { diff --git a/pgproto3/function_call_response.go b/pgproto3/function_call_response.go index 53d64222..3d3606dd 100644 --- a/pgproto3/function_call_response.go +++ b/pgproto3/function_call_response.go @@ -5,7 +5,7 @@ import ( "encoding/hex" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type FunctionCallResponse struct { diff --git a/pgproto3/gss_enc_request.go b/pgproto3/gss_enc_request.go index cf405a3e..30ffc08d 100644 --- a/pgproto3/gss_enc_request.go +++ b/pgproto3/gss_enc_request.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) const gssEncReqNumber = 80877104 diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go index e762eb96..03ce51e5 100644 --- a/pgproto3/notification_response.go +++ b/pgproto3/notification_response.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type NotificationResponse struct { diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go index e28965c8..374d38a3 100644 --- a/pgproto3/parameter_description.go +++ b/pgproto3/parameter_description.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type ParameterDescription struct { diff --git a/pgproto3/parameter_status.go b/pgproto3/parameter_status.go index c4021d92..a303e453 100644 --- a/pgproto3/parameter_status.go +++ b/pgproto3/parameter_status.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type ParameterStatus struct { diff --git a/pgproto3/parse.go b/pgproto3/parse.go index 723885d4..b53200dc 100644 --- a/pgproto3/parse.go +++ b/pgproto3/parse.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Parse struct { diff --git a/pgproto3/password_message.go b/pgproto3/password_message.go index cae76c50..41f98692 100644 --- a/pgproto3/password_message.go +++ b/pgproto3/password_message.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type PasswordMessage struct { diff --git a/pgproto3/query.go b/pgproto3/query.go index 41c93b4a..e963a0ec 100644 --- a/pgproto3/query.go +++ b/pgproto3/query.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Query struct { diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go index a2e0d28e..6f6f0681 100644 --- a/pgproto3/row_description.go +++ b/pgproto3/row_description.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) const ( diff --git a/pgproto3/sasl_initial_response.go b/pgproto3/sasl_initial_response.go index f862f2a8..eeda4691 100644 --- a/pgproto3/sasl_initial_response.go +++ b/pgproto3/sasl_initial_response.go @@ -6,7 +6,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type SASLInitialResponse struct { diff --git a/pgproto3/sasl_response.go b/pgproto3/sasl_response.go index d402759a..54c3d96f 100644 --- a/pgproto3/sasl_response.go +++ b/pgproto3/sasl_response.go @@ -4,7 +4,7 @@ import ( "encoding/hex" "encoding/json" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type SASLResponse struct { diff --git a/pgproto3/ssl_request.go b/pgproto3/ssl_request.go index 96ce489e..1b00c16b 100644 --- a/pgproto3/ssl_request.go +++ b/pgproto3/ssl_request.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) const sslRequestNumber = 80877103 diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go index 5f1cd24f..5c974f02 100644 --- a/pgproto3/startup_message.go +++ b/pgproto3/startup_message.go @@ -7,7 +7,7 @@ import ( "errors" "fmt" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) const ProtocolVersionNumber = 196608 // 3.0 diff --git a/pgtype/array.go b/pgtype/array.go index 3648f385..8de2b4dd 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -10,7 +10,7 @@ import ( "strings" "unicode" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) // Information on the internals of PostgreSQL arrays can be found in diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 74bcc5d3..84012083 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -6,7 +6,7 @@ import ( "fmt" "reflect" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) // ArrayGetter is a type that can be converted into a PostgreSQL array. diff --git a/pgtype/bits.go b/pgtype/bits.go index 12df03d5..5b0671ca 100644 --- a/pgtype/bits.go +++ b/pgtype/bits.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "fmt" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type BitsScanner interface { diff --git a/pgtype/box.go b/pgtype/box.go index 25d4f153..d6087eab 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type BoxScanner interface { diff --git a/pgtype/circle.go b/pgtype/circle.go index 6dfb4fae..4b499a12 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type CircleScanner interface { diff --git a/pgtype/composite.go b/pgtype/composite.go index 5a67c3df..af6ed28b 100644 --- a/pgtype/composite.go +++ b/pgtype/composite.go @@ -7,7 +7,7 @@ import ( "fmt" "strings" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) // CompositeIndexGetter is a type accessed by index that can be converted into a PostgreSQL composite. diff --git a/pgtype/date.go b/pgtype/date.go index f59508d3..1d27fc78 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -7,7 +7,7 @@ import ( "fmt" "time" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type DateScanner interface { diff --git a/pgtype/float4.go b/pgtype/float4.go index 9b31579f..127eb56a 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -7,7 +7,7 @@ import ( "math" "strconv" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Float4 struct { diff --git a/pgtype/float8.go b/pgtype/float8.go index 30548b88..b8b962b2 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -7,7 +7,7 @@ import ( "math" "strconv" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Float64Scanner interface { diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 46b3d236..7f0fa8c2 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -10,7 +10,7 @@ import ( "unicode" "unicode/utf8" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type HstoreScanner interface { diff --git a/pgtype/int.go b/pgtype/int.go index a799f2bf..ebac1403 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -9,7 +9,7 @@ import ( "math" "strconv" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Int64Scanner interface { diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index d3c519a7..3b5b14a9 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -7,7 +7,7 @@ import ( "math" "strconv" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Int64Scanner interface { diff --git a/pgtype/interval.go b/pgtype/interval.go index b4dcf0a6..882fd6d6 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -7,7 +7,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) const ( diff --git a/pgtype/line.go b/pgtype/line.go index c9cac4a7..087c7688 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type LineScanner interface { diff --git a/pgtype/lseg.go b/pgtype/lseg.go index 4243f6e0..f5cf888e 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type LsegScanner interface { diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 41bf1432..58707a02 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -10,7 +10,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) // PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 diff --git a/pgtype/path.go b/pgtype/path.go index 3b8e598e..10767404 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type PathScanner interface { diff --git a/pgtype/point.go b/pgtype/point.go index 8df57703..d2ddaf2f 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -9,7 +9,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Vec2 struct { diff --git a/pgtype/polygon.go b/pgtype/polygon.go index ca479965..a7a6d606 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type PolygonScanner interface { diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go index 0fa43a68..207e3b39 100644 --- a/pgtype/range_codec.go +++ b/pgtype/range_codec.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "fmt" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) // RangeValuer is a type that can be converted into a PostgreSQL range. diff --git a/pgtype/tid.go b/pgtype/tid.go index 5faa7502..6eefd34e 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -7,7 +7,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type TIDScanner interface { diff --git a/pgtype/time.go b/pgtype/time.go index dc40f1fc..9005b848 100644 --- a/pgtype/time.go +++ b/pgtype/time.go @@ -6,7 +6,7 @@ import ( "fmt" "strconv" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type TimeScanner interface { diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 03fb2b28..3a0bd275 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) const pgTimestampFormat = "2006-01-02 15:04:05.999999999" diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 8be5970c..5069af02 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -7,7 +7,7 @@ import ( "fmt" "time" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" diff --git a/pgtype/uint32.go b/pgtype/uint32.go index d406f79b..297ca5c2 100644 --- a/pgtype/uint32.go +++ b/pgtype/uint32.go @@ -7,7 +7,7 @@ import ( "math" "strconv" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Uint32Scanner interface { diff --git a/values.go b/values.go index 075ac2ff..67363986 100644 --- a/values.go +++ b/values.go @@ -7,7 +7,7 @@ import ( "reflect" "time" - "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgtype" ) From 2e0ec225def2ba825e185e8813fe8d9ade63cfde Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Feb 2022 08:50:46 -0600 Subject: [PATCH 0924/1158] Make Chunkreader an internal implementation detail --- chunkreader/chunkreader.go | 104 ------------------ internal/pgmock/pgmock_test.go | 2 +- pgconn/config.go | 25 +---- pgconn/config_test.go | 40 +++---- pgconn/defaults.go | 3 +- pgconn/defaults_windows.go | 2 - pgconn/pgconn_test.go | 4 +- pgproto3/backend.go | 14 +-- pgproto3/backend_test.go | 8 +- pgproto3/chunkreader.go | 90 +++++++++++++-- {chunkreader => pgproto3}/chunkreader_test.go | 22 +--- pgproto3/example/pgfortune/server.go | 2 +- pgproto3/frontend.go | 5 +- pgproto3/frontend_test.go | 6 +- 14 files changed, 124 insertions(+), 203 deletions(-) delete mode 100644 chunkreader/chunkreader.go rename {chunkreader => pgproto3}/chunkreader_test.go (87%) diff --git a/chunkreader/chunkreader.go b/chunkreader/chunkreader.go deleted file mode 100644 index afea1c52..00000000 --- a/chunkreader/chunkreader.go +++ /dev/null @@ -1,104 +0,0 @@ -// Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations. -package chunkreader - -import ( - "io" -) - -// ChunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and -// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually -// requested. The memory returned via Next is owned by the caller. This avoids the need for an additional copy. -// -// The downside of this approach is that a large buffer can be pinned in memory even if only a small slice is -// referenced. For example, an entire 4096 byte block could be pinned in memory by even a 1 byte slice. In these rare -// cases it would be advantageous to copy the bytes to another slice. -type ChunkReader struct { - r io.Reader - - buf []byte - rp, wp int // buf read position and write position - - config Config -} - -// Config contains configuration parameters for ChunkReader. -type Config struct { - MinBufLen int // Minimum buffer length -} - -// New creates and returns a new ChunkReader for r with default configuration. -func New(r io.Reader) *ChunkReader { - cr, err := NewConfig(r, Config{}) - if err != nil { - panic("default config can't be bad") - } - - return cr -} - -// NewConfig creates and a new ChunkReader for r configured by config. -func NewConfig(r io.Reader, config Config) (*ChunkReader, error) { - if config.MinBufLen == 0 { - // By historical reasons Postgres currently has 8KB send buffer inside, - // so here we want to have at least the same size buffer. - // @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134 - // @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru - config.MinBufLen = 8192 - } - - return &ChunkReader{ - r: r, - buf: make([]byte, config.MinBufLen), - config: config, - }, nil -} - -// Next returns buf filled with the next n bytes. The caller gains ownership of buf. It is not necessary to make a copy -// of buf. If an error occurs, buf will be nil. -func (r *ChunkReader) Next(n int) (buf []byte, err error) { - // n bytes already in buf - if (r.wp - r.rp) >= n { - buf = r.buf[r.rp : r.rp+n] - r.rp += n - return buf, err - } - - // available space in buf is less than n - if len(r.buf) < n { - r.copyBufContents(r.newBuf(n)) - } - - // buf is large enough, but need to shift filled area to start to make enough contiguous space - minReadCount := n - (r.wp - r.rp) - if (len(r.buf) - r.wp) < minReadCount { - newBuf := r.newBuf(n) - r.copyBufContents(newBuf) - } - - if err := r.appendAtLeast(minReadCount); err != nil { - return nil, err - } - - buf = r.buf[r.rp : r.rp+n] - r.rp += n - return buf, nil -} - -func (r *ChunkReader) appendAtLeast(fillLen int) error { - n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen) - r.wp += n - return err -} - -func (r *ChunkReader) newBuf(size int) []byte { - if size < r.config.MinBufLen { - size = r.config.MinBufLen - } - return make([]byte, size) -} - -func (r *ChunkReader) copyBufContents(dest []byte) { - r.wp = copy(dest, r.buf[r.rp:r.wp]) - r.rp = 0 - r.buf = dest -} diff --git a/internal/pgmock/pgmock_test.go b/internal/pgmock/pgmock_test.go index 1e22cbcb..bc787398 100644 --- a/internal/pgmock/pgmock_test.go +++ b/internal/pgmock/pgmock_test.go @@ -62,7 +62,7 @@ func TestScript(t *testing.T) { return } - err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + err = script.Run(pgproto3.NewBackend(conn, conn)) if err != nil { serverErrChan <- err return diff --git a/pgconn/config.go b/pgconn/config.go index 6f7f3ca5..e99dabfb 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -19,7 +19,6 @@ import ( "github.com/jackc/pgpassfile" "github.com/jackc/pgservicefile" - "github.com/jackc/pgx/v5/chunkreader" "github.com/jackc/pgx/v5/pgproto3" ) @@ -183,8 +182,6 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // In addition, ParseConfig accepts the following options: // -// min_read_buffer_size -// The minimum size of the internal read buffer. Default 8192. // servicefile // libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a // part of the connection string. @@ -219,18 +216,15 @@ func ParseConfig(connString string) (*Config, error) { settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) } - minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) - if err != nil { - return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err} - } - config := &Config{ createdByParseConfig: true, Database: settings["database"], User: settings["user"], Password: settings["password"], RuntimeParams: make(map[string]string), - BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)), + BuildFrontend: func(r io.Reader, w io.Writer) Frontend { + return pgproto3.NewFrontend(r, w) + }, } if connectTimeoutSetting, present := settings["connect_timeout"]; present { @@ -260,7 +254,6 @@ func ParseConfig(connString string) (*Config, error) { "sslcert": {}, "sslrootcert": {}, "target_session_attrs": {}, - "min_read_buffer_size": {}, "service": {}, "servicefile": {}, } @@ -693,18 +686,6 @@ func makeDefaultResolver() *net.Resolver { return net.DefaultResolver } -func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { - return func(r io.Reader, w io.Writer) Frontend { - cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen}) - if err != nil { - panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err)) - } - frontend := pgproto3.NewFrontend(cr, w) - - return frontend - } -} - func parseConnectTimeoutSetting(s string) (time.Duration, error) { timeout, err := strconv.ParseInt(s, 10, 64) if err != nil { diff --git a/pgconn/config_test.go b/pgconn/config_test.go index 335a25ca..40db7bc2 100644 --- a/pgconn/config_test.go +++ b/pgconn/config_test.go @@ -572,13 +572,13 @@ func TestParseConfig(t *testing.T) { name: "target_session_attrs primary", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=primary", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPrimary, }, }, @@ -586,13 +586,13 @@ func TestParseConfig(t *testing.T) { name: "target_session_attrs standby", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=standby", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsStandby, }, }, @@ -967,15 +967,3 @@ application_name = spaced string assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) } } - -func TestParseConfigExtractsMinReadBufferSize(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig("min_read_buffer_size=0") - require.NoError(t, err) - _, present := config.RuntimeParams["min_read_buffer_size"] - require.False(t, present) - - // The buffer size is internal so there isn't much that can be done to test it other than see that the runtime param - // was removed. -} diff --git a/pgconn/defaults.go b/pgconn/defaults.go index f69cad31..1dd514ff 100644 --- a/pgconn/defaults.go +++ b/pgconn/defaults.go @@ -1,3 +1,4 @@ +//go:build !windows // +build !windows package pgconn @@ -39,8 +40,6 @@ func defaultSettings() map[string]string { settings["target_session_attrs"] = "any" - settings["min_read_buffer_size"] = "8192" - return settings } diff --git a/pgconn/defaults_windows.go b/pgconn/defaults_windows.go index 71eb77db..33b4a1ff 100644 --- a/pgconn/defaults_windows.go +++ b/pgconn/defaults_windows.go @@ -46,8 +46,6 @@ func defaultSettings() map[string]string { settings["target_session_attrs"] = "any" - settings["min_read_buffer_size"] = "8192" - return settings } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 42214d2c..3ae0d1d4 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -141,7 +141,7 @@ func TestConnectTimeout(t *testing.T) { return } - err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + err = script.Run(pgproto3.NewBackend(conn, conn)) if err != nil { serverErrChan <- err return @@ -2044,7 +2044,7 @@ func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) { return } - err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + err = script.Run(pgproto3.NewBackend(conn, conn)) if err != nil { serverErrChan <- err return diff --git a/pgproto3/backend.go b/pgproto3/backend.go index 9c42ad02..c8d2f331 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -8,7 +8,7 @@ import ( // Backend acts as a server for the PostgreSQL wire protocol version 3. type Backend struct { - cr ChunkReader + cr *chunkReader w io.Writer // Frontend message flyweights @@ -30,11 +30,10 @@ type Backend struct { sync Sync terminate Terminate - bodyLen int - msgType byte - partialMsg bool - authType uint32 - + bodyLen int + msgType byte + partialMsg bool + authType uint32 } const ( @@ -43,7 +42,8 @@ const ( ) // NewBackend creates a new Backend. -func NewBackend(cr ChunkReader, w io.Writer) *Backend { +func NewBackend(r io.Reader, w io.Writer) *Backend { + cr := newChunkReader(r, 0) return &Backend{cr: cr, w: w} } diff --git a/pgproto3/backend_test.go b/pgproto3/backend_test.go index 75755f22..596245dd 100644 --- a/pgproto3/backend_test.go +++ b/pgproto3/backend_test.go @@ -16,7 +16,7 @@ func TestBackendReceiveInterrupted(t *testing.T) { server := &interruptReader{} server.push([]byte{'Q', 0, 0, 0, 6}) - backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) + backend := pgproto3.NewBackend(server, nil) msg, err := backend.Receive() if err == nil { @@ -43,7 +43,7 @@ func TestBackendReceiveUnexpectedEOF(t *testing.T) { server := &interruptReader{} server.push([]byte{'Q', 0, 0, 0, 6}) - backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) + backend := pgproto3.NewBackend(server, nil) // Receive regular msg msg, err := backend.Receive() @@ -77,7 +77,7 @@ func TestStartupMessage(t *testing.T) { server := &interruptReader{} server.push(dst) - backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) + backend := pgproto3.NewBackend(server, nil) msg, err := backend.ReceiveStartupMessage() require.NoError(t, err) @@ -110,7 +110,7 @@ func TestStartupMessage(t *testing.T) { dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber) server.push(dst) - backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) + backend := pgproto3.NewBackend(server, nil) msg, err := backend.ReceiveStartupMessage() require.Error(t, err) diff --git a/pgproto3/chunkreader.go b/pgproto3/chunkreader.go index 3f878183..1781d6cd 100644 --- a/pgproto3/chunkreader.go +++ b/pgproto3/chunkreader.go @@ -2,18 +2,88 @@ package pgproto3 import ( "io" - - "github.com/jackc/pgx/v5/chunkreader" ) -// ChunkReader is an interface to decouple github.com/jackc/chunkreader from this package. -type ChunkReader interface { - // Next returns buf filled with the next n bytes. If an error (including a partial read) occurs, - // buf must be nil. Next must preserve any partially read data. Next must not reuse buf. - Next(n int) (buf []byte, err error) +// chunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and +// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually +// requested. The memory returned via Next is owned by the caller. This avoids the need for an additional copy. +// +// The downside of this approach is that a large buffer can be pinned in memory even if only a small slice is +// referenced. For example, an entire 4096 byte block could be pinned in memory by even a 1 byte slice. In these rare +// cases it would be advantageous to copy the bytes to another slice. +type chunkReader struct { + r io.Reader + + buf []byte + rp, wp int // buf read position and write position + + minBufLen int } -// NewChunkReader creates and returns a new default ChunkReader. -func NewChunkReader(r io.Reader) ChunkReader { - return chunkreader.New(r) +// newChunkReader creates and returns a new chunkReader for r with default configuration with minBufSize internal buffer. +// If bufSize is <= 0 it uses a default value. +func newChunkReader(r io.Reader, minBufSize int) *chunkReader { + if minBufSize <= 0 { + // By historical reasons Postgres currently has 8KB send buffer inside, + // so here we want to have at least the same size buffer. + // @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134 + // @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru + minBufSize = 8192 + } + + return &chunkReader{ + r: r, + buf: make([]byte, minBufSize), + minBufLen: minBufSize, + } +} + +// Next returns buf filled with the next n bytes. The caller gains ownership of buf. It is not necessary to make a copy +// of buf. If an error occurs, buf will be nil. +func (r *chunkReader) Next(n int) (buf []byte, err error) { + // n bytes already in buf + if (r.wp - r.rp) >= n { + buf = r.buf[r.rp : r.rp+n] + r.rp += n + return buf, err + } + + // available space in buf is less than n + if len(r.buf) < n { + r.copyBufContents(r.newBuf(n)) + } + + // buf is large enough, but need to shift filled area to start to make enough contiguous space + minReadCount := n - (r.wp - r.rp) + if (len(r.buf) - r.wp) < minReadCount { + newBuf := r.newBuf(n) + r.copyBufContents(newBuf) + } + + if err := r.appendAtLeast(minReadCount); err != nil { + return nil, err + } + + buf = r.buf[r.rp : r.rp+n] + r.rp += n + return buf, nil +} + +func (r *chunkReader) appendAtLeast(fillLen int) error { + n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen) + r.wp += n + return err +} + +func (r *chunkReader) newBuf(size int) []byte { + if size < r.minBufLen { + size = r.minBufLen + } + return make([]byte, size) +} + +func (r *chunkReader) copyBufContents(dest []byte) { + r.wp = copy(dest, r.buf[r.rp:r.wp]) + r.rp = 0 + r.buf = dest } diff --git a/chunkreader/chunkreader_test.go b/pgproto3/chunkreader_test.go similarity index 87% rename from chunkreader/chunkreader_test.go rename to pgproto3/chunkreader_test.go index ddc2fbf6..86fbd8b2 100644 --- a/chunkreader/chunkreader_test.go +++ b/pgproto3/chunkreader_test.go @@ -1,4 +1,4 @@ -package chunkreader +package pgproto3 import ( "bytes" @@ -8,10 +8,7 @@ import ( func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { server := &bytes.Buffer{} - r, err := NewConfig(server, Config{MinBufLen: 4}) - if err != nil { - t.Fatal(err) - } + r := newChunkReader(server, 4) src := []byte{1, 2, 3, 4} server.Write(src) @@ -45,10 +42,7 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { server := &bytes.Buffer{} - r, err := NewConfig(server, Config{MinBufLen: 4}) - if err != nil { - t.Fatal(err) - } + r := newChunkReader(server, 4) src := []byte{1, 2, 3, 4, 5, 6, 7, 8} server.Write(src) @@ -67,10 +61,7 @@ func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { func TestChunkReaderDoesNotReuseBuf(t *testing.T) { server := &bytes.Buffer{} - r, err := NewConfig(server, Config{MinBufLen: 4}) - if err != nil { - t.Fatal(err) - } + r := newChunkReader(server, 4) src := []byte{1, 2, 3, 4, 5, 6, 7, 8} server.Write(src) @@ -108,10 +99,7 @@ func (r *randomReader) Read(p []byte) (n int, err error) { func TestChunkReaderNextFuzz(t *testing.T) { rr := &randomReader{rnd: rand.New(rand.NewSource(1))} - r, err := NewConfig(rr, Config{MinBufLen: 8192}) - if err != nil { - t.Fatal(err) - } + r := newChunkReader(rr, 8192) randomSizes := rand.New(rand.NewSource(0)) diff --git a/pgproto3/example/pgfortune/server.go b/pgproto3/example/pgfortune/server.go index fe406452..14ae71f8 100644 --- a/pgproto3/example/pgfortune/server.go +++ b/pgproto3/example/pgfortune/server.go @@ -14,7 +14,7 @@ type PgFortuneBackend struct { } func NewPgFortuneBackend(conn net.Conn, responder func() ([]byte, error)) *PgFortuneBackend { - backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn) + backend := pgproto3.NewBackend(conn, conn) connHandler := &PgFortuneBackend{ backend: backend, diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index c33dfb08..ea6757ad 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -9,7 +9,7 @@ import ( // Frontend acts as a client for the PostgreSQL wire protocol version 3. type Frontend struct { - cr ChunkReader + cr *chunkReader w io.Writer // Backend message flyweights @@ -49,7 +49,8 @@ type Frontend struct { } // NewFrontend creates a new Frontend. -func NewFrontend(cr ChunkReader, w io.Writer) *Frontend { +func NewFrontend(r io.Reader, w io.Writer) *Frontend { + cr := newChunkReader(r, 0) return &Frontend{cr: cr, w: w} } diff --git a/pgproto3/frontend_test.go b/pgproto3/frontend_test.go index 595877bd..e02457d6 100644 --- a/pgproto3/frontend_test.go +++ b/pgproto3/frontend_test.go @@ -38,7 +38,7 @@ func TestFrontendReceiveInterrupted(t *testing.T) { server := &interruptReader{} server.push([]byte{'Z', 0, 0, 0, 5}) - frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil) + frontend := pgproto3.NewFrontend(server, nil) msg, err := frontend.Receive() if err == nil { @@ -65,7 +65,7 @@ func TestFrontendReceiveUnexpectedEOF(t *testing.T) { server := &interruptReader{} server.push([]byte{'Z', 0, 0, 0, 5}) - frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil) + frontend := pgproto3.NewFrontend(server, nil) msg, err := frontend.Receive() if err == nil { @@ -109,7 +109,7 @@ func TestErrorResponse(t *testing.T) { server := &interruptReader{} server.push(raw) - frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil) + frontend := pgproto3.NewFrontend(server, nil) got, err := frontend.Receive() require.NoError(t, err) From e641d0a5add7bb612bdff3234f5f0432c3e41194 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Feb 2022 09:31:45 -0600 Subject: [PATCH 0925/1158] Reuse connection read buffer To avoid extra copies and small allocations previously large read buffers were allocated and never reused. However, the down side of this was greater total memory allocation and the possibility that a reference to a single byte could pin an entire buffer. Now the buffer is reused. --- pgconn/pgconn.go | 8 ++++++-- pgproto3/chunkreader.go | 22 ++++++++++++++++------ pgproto3/chunkreader_test.go | 18 +++++++++--------- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index a9a6de8c..9c5e8e79 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1433,8 +1433,12 @@ func (rr *ResultReader) Read() *Result { copy(br.FieldDescriptions, rr.FieldDescriptions()) } - row := make([][]byte, len(rr.Values())) - copy(row, rr.Values()) + values := rr.Values() + row := make([][]byte, len(values)) + for i := range row { + row[i] = make([]byte, len(values[i])) + copy(row[i], values[i]) + } br.Rows = append(br.Rows, row) } diff --git a/pgproto3/chunkreader.go b/pgproto3/chunkreader.go index 1781d6cd..598aaa85 100644 --- a/pgproto3/chunkreader.go +++ b/pgproto3/chunkreader.go @@ -6,11 +6,9 @@ import ( // chunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and // will read as much as will fit in the current buffer in a single call regardless of how large a read is actually -// requested. The memory returned via Next is owned by the caller. This avoids the need for an additional copy. +// requested. The memory returned via Next is only valid until the next call to Next. // -// The downside of this approach is that a large buffer can be pinned in memory even if only a small slice is -// referenced. For example, an entire 4096 byte block could be pinned in memory by even a 1 byte slice. In these rare -// cases it would be advantageous to copy the bytes to another slice. +// This is roughly equivalent to a bufio.Reader that only uses Peek and Discard to never copy bytes. type chunkReader struct { r io.Reader @@ -38,13 +36,14 @@ func newChunkReader(r io.Reader, minBufSize int) *chunkReader { } } -// Next returns buf filled with the next n bytes. The caller gains ownership of buf. It is not necessary to make a copy -// of buf. If an error occurs, buf will be nil. +// Next returns buf filled with the next n bytes. buf is only valid until next call of Next. If an error occurs, buf +// will be nil. func (r *chunkReader) Next(n int) (buf []byte, err error) { // n bytes already in buf if (r.wp - r.rp) >= n { buf = r.buf[r.rp : r.rp+n] r.rp += n + r.resetBufIfEmpty() return buf, err } @@ -66,6 +65,7 @@ func (r *chunkReader) Next(n int) (buf []byte, err error) { buf = r.buf[r.rp : r.rp+n] r.rp += n + r.resetBufIfEmpty() return buf, nil } @@ -87,3 +87,13 @@ func (r *chunkReader) copyBufContents(dest []byte) { r.rp = 0 r.buf = dest } + +func (r *chunkReader) resetBufIfEmpty() { + if r.rp == r.wp { + if len(r.buf) > r.minBufLen { + r.buf = make([]byte, r.minBufLen) + } + r.rp = 0 + r.wp = 0 + } +} diff --git a/pgproto3/chunkreader_test.go b/pgproto3/chunkreader_test.go index 86fbd8b2..1c0c63d8 100644 --- a/pgproto3/chunkreader_test.go +++ b/pgproto3/chunkreader_test.go @@ -32,11 +32,11 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { if bytes.Compare(r.buf, src) != 0 { t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf) } - if r.rp != 4 { - t.Fatalf("Expected r.rp to be %v, but it was %v", 4, r.rp) + if r.rp != 0 { + t.Fatalf("Expected r.rp to be %v, but it was %v", 0, r.rp) } - if r.wp != 4 { - t.Fatalf("Expected r.wp to be %v, but it was %v", 4, r.wp) + if r.wp != 0 { + t.Fatalf("Expected r.wp to be %v, but it was %v", 0, r.wp) } } @@ -54,12 +54,12 @@ func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { if bytes.Compare(n1, src[0:5]) != 0 { t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1) } - if len(r.buf) != 5 { - t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf)) + if len(r.buf) != 4 { + t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 4, len(r.buf)) } } -func TestChunkReaderDoesNotReuseBuf(t *testing.T) { +func TestChunkReaderReusesBuf(t *testing.T) { server := &bytes.Buffer{} r := newChunkReader(server, 4) @@ -82,8 +82,8 @@ func TestChunkReaderDoesNotReuseBuf(t *testing.T) { t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) } - if bytes.Compare(n1, src[0:4]) != 0 { - t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1) + if bytes.Compare(n1, src[4:8]) != 0 { + t.Fatalf("Expected slice to be reused, expected %v but it was %v", src[4:8], n1) } } From 2fad63c189e58311318d6c19c56d8342c95d065f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Feb 2022 09:37:14 -0600 Subject: [PATCH 0926/1158] Set cap when returning slice from chunkReader --- pgproto3/chunkreader.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgproto3/chunkreader.go b/pgproto3/chunkreader.go index 598aaa85..d17b93ea 100644 --- a/pgproto3/chunkreader.go +++ b/pgproto3/chunkreader.go @@ -41,7 +41,7 @@ func newChunkReader(r io.Reader, minBufSize int) *chunkReader { func (r *chunkReader) Next(n int) (buf []byte, err error) { // n bytes already in buf if (r.wp - r.rp) >= n { - buf = r.buf[r.rp : r.rp+n] + buf = r.buf[r.rp : r.rp+n : r.rp+n] r.rp += n r.resetBufIfEmpty() return buf, err @@ -63,7 +63,7 @@ func (r *chunkReader) Next(n int) (buf []byte, err error) { return nil, err } - buf = r.buf[r.rp : r.rp+n] + buf = r.buf[r.rp : r.rp+n : r.rp+n] r.rp += n r.resetBufIfEmpty() return buf, nil From b1e4b96e6c81af5cfba2da84a9e9fc37a99698ec Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Feb 2022 19:57:41 -0600 Subject: [PATCH 0927/1158] Reduce big read buffer allocations with sync.Pool --- pgproto3/chunkreader.go | 118 ++++++++++++++++++++++------------- pgproto3/chunkreader_test.go | 12 +++- 2 files changed, 83 insertions(+), 47 deletions(-) diff --git a/pgproto3/chunkreader.go b/pgproto3/chunkreader.go index d17b93ea..2d116c91 100644 --- a/pgproto3/chunkreader.go +++ b/pgproto3/chunkreader.go @@ -2,8 +2,48 @@ package pgproto3 import ( "io" + "sync" ) +type bigBufPool struct { + pool sync.Pool + byteSize int +} + +var bigBufPools []*bigBufPool + +func init() { + KiB := 1024 + bigBufSizes := []int{64 * KiB, 256 * KiB, 1024 * KiB, 4096 * KiB} + bigBufPools = make([]*bigBufPool, len(bigBufSizes)) + + for i := range bigBufPools { + byteSize := bigBufSizes[i] + bigBufPools[i] = &bigBufPool{ + pool: sync.Pool{New: func() interface{} { return make([]byte, byteSize) }}, + byteSize: byteSize, + } + } +} + +func getBigBuf(size int) []byte { + for _, bigBufPool := range bigBufPools { + if size < bigBufPool.byteSize { + return bigBufPool.pool.Get().([]byte) + } + } + return make([]byte, size) +} + +func releaseBigBuf(buf []byte) { + for _, bigBufPool := range bigBufPools { + if len(buf) == bigBufPool.byteSize { + bigBufPool.pool.Put(buf) + return + } + } +} + // chunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and // will read as much as will fit in the current buffer in a single call regardless of how large a read is actually // requested. The memory returned via Next is only valid until the next call to Next. @@ -15,85 +55,75 @@ type chunkReader struct { buf []byte rp, wp int // buf read position and write position - minBufLen int + ownBuf []byte // buf owned by chunkReader } -// newChunkReader creates and returns a new chunkReader for r with default configuration with minBufSize internal buffer. +// newChunkReader creates and returns a new chunkReader for r with default configuration with bufSize internal buffer. // If bufSize is <= 0 it uses a default value. -func newChunkReader(r io.Reader, minBufSize int) *chunkReader { - if minBufSize <= 0 { +func newChunkReader(r io.Reader, bufSize int) *chunkReader { + if bufSize <= 0 { // By historical reasons Postgres currently has 8KB send buffer inside, // so here we want to have at least the same size buffer. // @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134 // @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru - minBufSize = 8192 + // + // In addition, testing has found no benefit of any larger buffer. + bufSize = 8192 } + buf := make([]byte, bufSize) + return &chunkReader{ - r: r, - buf: make([]byte, minBufSize), - minBufLen: minBufSize, + r: r, + buf: buf, + ownBuf: buf, } } // Next returns buf filled with the next n bytes. buf is only valid until next call of Next. If an error occurs, buf // will be nil. func (r *chunkReader) Next(n int) (buf []byte, err error) { + // Reset the buffer if it is empty + if r.rp == r.wp { + if len(r.buf) != len(r.ownBuf) { + releaseBigBuf(r.buf) + r.buf = r.ownBuf + } + r.rp = 0 + r.wp = 0 + } + // n bytes already in buf if (r.wp - r.rp) >= n { buf = r.buf[r.rp : r.rp+n : r.rp+n] r.rp += n - r.resetBufIfEmpty() return buf, err } - // available space in buf is less than n + // buf is smaller than requested number of bytes if len(r.buf) < n { - r.copyBufContents(r.newBuf(n)) + bigBuf := getBigBuf(n) + r.wp = copy(bigBuf, r.buf[r.rp:r.wp]) + r.rp = 0 + r.buf = bigBuf } // buf is large enough, but need to shift filled area to start to make enough contiguous space minReadCount := n - (r.wp - r.rp) if (len(r.buf) - r.wp) < minReadCount { - newBuf := r.newBuf(n) - r.copyBufContents(newBuf) + r.wp = copy(r.buf, r.buf[r.rp:r.wp]) + r.rp = 0 } - if err := r.appendAtLeast(minReadCount); err != nil { + // Read at least the required number of bytes from the underlying io.Reader + readBytesCount, err := io.ReadAtLeast(r.r, r.buf[r.wp:], minReadCount) + r.wp += readBytesCount + // fmt.Println("read", n) + if err != nil { return nil, err } buf = r.buf[r.rp : r.rp+n : r.rp+n] r.rp += n - r.resetBufIfEmpty() return buf, nil } - -func (r *chunkReader) appendAtLeast(fillLen int) error { - n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen) - r.wp += n - return err -} - -func (r *chunkReader) newBuf(size int) []byte { - if size < r.minBufLen { - size = r.minBufLen - } - return make([]byte, size) -} - -func (r *chunkReader) copyBufContents(dest []byte) { - r.wp = copy(dest, r.buf[r.rp:r.wp]) - r.rp = 0 - r.buf = dest -} - -func (r *chunkReader) resetBufIfEmpty() { - if r.rp == r.wp { - if len(r.buf) > r.minBufLen { - r.buf = make([]byte, r.minBufLen) - } - r.rp = 0 - r.wp = 0 - } -} diff --git a/pgproto3/chunkreader_test.go b/pgproto3/chunkreader_test.go index 1c0c63d8..7d7bac7f 100644 --- a/pgproto3/chunkreader_test.go +++ b/pgproto3/chunkreader_test.go @@ -32,6 +32,12 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { if bytes.Compare(r.buf, src) != 0 { t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf) } + + _, err = r.Next(0) // Trigger the buffer reset. + if err != nil { + t.Fatal(err) + } + if r.rp != 0 { t.Fatalf("Expected r.rp to be %v, but it was %v", 0, r.rp) } @@ -40,7 +46,7 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { } } -func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { +func TestChunkReaderNextGetsBiggerBufAsNeededFromBigBufPools(t *testing.T) { server := &bytes.Buffer{} r := newChunkReader(server, 4) @@ -54,8 +60,8 @@ func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { if bytes.Compare(n1, src[0:5]) != 0 { t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1) } - if len(r.buf) != 4 { - t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 4, len(r.buf)) + if len(r.buf) != bigBufPools[0].byteSize { + t.Fatalf("Expected len(r.buf) to be %v, but it was %v", bigBufPools[0].byteSize, len(r.buf)) } } From ffc5a692cb0fb4b98769103fdedf6b36d1c43aa2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Feb 2022 20:23:35 -0600 Subject: [PATCH 0928/1158] Detect unsafe pgtype.DriverBytes usage Add test for unsafe usage and test for correct usage that ensures driver memory is actually used. --- pgtype/bytea_test.go | 45 ++++++++++++++++++++++++++++++++++++++------ rows.go | 7 +++++++ 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index 41c3482e..ae4a8760 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -36,7 +36,7 @@ func TestByteaCodec(t *testing.T) { }) } -func TestDriverBytes(t *testing.T) { +func TestDriverBytesQueryRow(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) @@ -44,14 +44,47 @@ func TestDriverBytes(t *testing.T) { var buf []byte err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf)) + require.EqualError(t, err, "cannot scan into *pgtype.DriverBytes from QueryRow") +} + +func TestDriverBytes(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + ctx := context.Background() + + argBuf := make([]byte, 128) + for i := range argBuf { + argBuf[i] = byte(i) + } + + rows, err := conn.Query(ctx, `select $1::bytea from generate_series(1, 1000)`, argBuf) require.NoError(t, err) + defer rows.Close() - require.Len(t, buf, 2) - require.Equal(t, buf, []byte{1, 2}) - require.Equalf(t, cap(buf), len(buf), "cap(buf) is larger than len(buf)") + rowCount := 0 + resultBuf := argBuf + detectedResultMutation := false + for rows.Next() { + rowCount++ - // Don't actually have any way to be sure that the bytes are from the driver at the moment as underlying driver - // doesn't reuse buffers at the present. + // At some point the buffer should be reused and change. + if bytes.Compare(argBuf, resultBuf) != 0 { + detectedResultMutation = true + } + + err = rows.Scan((*pgtype.DriverBytes)(&resultBuf)) + require.NoError(t, err) + + require.Len(t, resultBuf, len(argBuf)) + require.Equal(t, resultBuf, argBuf) + require.Equalf(t, cap(resultBuf), len(resultBuf), "cap(resultBuf) is larger than len(resultBuf)") + } + + require.True(t, detectedResultMutation) + + err = rows.Err() + require.NoError(t, err) } func TestPreallocBytes(t *testing.T) { diff --git a/rows.go b/rows.go index d9b155e6..3ff8c93e 100644 --- a/rows.go +++ b/rows.go @@ -79,6 +79,13 @@ func (r *connRow) Scan(dest ...interface{}) (err error) { return rows.Err() } + for _, d := range dest { + if _, ok := d.(*pgtype.DriverBytes); ok { + rows.Close() + return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow") + } + } + if !rows.Next() { if rows.Err() == nil { return ErrNoRows From a8f6674a07b2c582e35ccdfc96c6b2dec5fb0669 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Feb 2022 20:28:15 -0600 Subject: [PATCH 0929/1158] TextCodec specifically supports scanning to BytesScanner This lets it support DriverBytes and PreallocatedBytes. --- pgtype/text.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pgtype/text.go b/pgtype/text.go index 82e7753c..53fcb368 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -165,6 +165,8 @@ func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) return scanPlanTextAnyToString{} case *[]byte: return scanPlanAnyToNewByteSlice{} + case BytesScanner: + return scanPlanAnyToByteScanner{} case TextScanner: return scanPlanTextAnyToTextScanner{} case *rune: @@ -214,6 +216,13 @@ func (scanPlanAnyToNewByteSlice) Scan(src []byte, dst interface{}) error { return nil } +type scanPlanAnyToByteScanner struct{} + +func (scanPlanAnyToByteScanner) Scan(src []byte, dst interface{}) error { + p := (dst).(BytesScanner) + return p.ScanBytes(src) +} + type scanPlanTextAnyToRune struct{} func (scanPlanTextAnyToRune) Scan(src []byte, dst interface{}) error { From 45a8b00271664a5d2718c3fd7e69fda9ca429193 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 4 Mar 2022 11:04:46 -0600 Subject: [PATCH 0930/1158] Do not recursively call public PlanScan that caches Otherwise, wrapper types get cached. Wrapper types are expected to fail most of the time. These failures should not be cached. In addition, wrappers wrap multiple different types so it doesn't make sense to cache results of a wrapper. --- pgtype/pgtype.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index cb24f295..8c743163 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1068,7 +1068,7 @@ func (m *Map) planScan(oid uint32, formatCode int16, target interface{}) ScanPla for _, f := range m.TryWrapScanPlanFuncs { if wrapperPlan, nextDst, ok := f(target); ok { - if nextPlan := m.PlanScan(oid, formatCode, nextDst); nextPlan != nil { + if nextPlan := m.planScan(oid, formatCode, nextDst); nextPlan != nil { if _, failed := nextPlan.(*scanPlanFail); !failed { wrapperPlan.SetNext(nextPlan) return wrapperPlan From a365c9a3c2c71facc0874f0a422a9b91703695db Mon Sep 17 00:00:00 2001 From: Vu Date: Tue, 1 Mar 2022 23:18:55 +0800 Subject: [PATCH 0931/1158] Add multirange support for num, int4 and int8 type --- int4_multirange.go | 239 ++++++++++++++++++++++++++++++++++++++++ int4_multirange_test.go | 81 ++++++++++++++ int8_multirange.go | 239 ++++++++++++++++++++++++++++++++++++++++ int8_multirange_test.go | 81 ++++++++++++++ multirange.go | 83 ++++++++++++++ multirange_test.go | 51 +++++++++ num_multirange.go | 239 ++++++++++++++++++++++++++++++++++++++++ num_multirange_test.go | 55 +++++++++ pgtype.go | 143 +++++++++++++----------- typed_multirange.go.erb | 239 ++++++++++++++++++++++++++++++++++++++++ typed_multirange_gen.sh | 8 ++ 11 files changed, 1391 insertions(+), 67 deletions(-) create mode 100644 int4_multirange.go create mode 100644 int4_multirange_test.go create mode 100644 int8_multirange.go create mode 100644 int8_multirange_test.go create mode 100644 multirange.go create mode 100644 multirange_test.go create mode 100644 num_multirange.go create mode 100644 num_multirange_test.go create mode 100644 typed_multirange.go.erb create mode 100755 typed_multirange_gen.sh diff --git a/int4_multirange.go b/int4_multirange.go new file mode 100644 index 00000000..c3432ce6 --- /dev/null +++ b/int4_multirange.go @@ -0,0 +1,239 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + + "github.com/jackc/pgio" +) + +type Int4multirange struct { + Ranges []Int4range + Status Status +} + +func (dst *Int4multirange) Set(src interface{}) error { + //untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int4multirange{Status: Null} + return nil + } + + switch value := src.(type) { + case Int4multirange: + *dst = value + case *Int4multirange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + case []Int4range: + if value == nil { + *dst = Int4multirange{Status: Null} + } else if len(value) == 0 { + *dst = Int4multirange{Status: Present} + } else { + elements := make([]Int4range, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4multirange{ + Ranges: elements, + Status: Present, + } + } + case []*Int4range: + if value == nil { + *dst = Int4multirange{Status: Null} + } else if len(value) == 0 { + *dst = Int4multirange{Status: Present} + } else { + elements := make([]Int4range, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4multirange{ + Ranges: elements, + Status: Present, + } + } + default: + return fmt.Errorf("cannot convert %v to Int4multirange", src) + } + + return nil + +} + +func (dst Int4multirange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int4multirange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Int4multirange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4multirange{Status: Null} + return nil + } + + utmr, err := ParseUntypedTextMultirange(string(src)) + if err != nil { + return err + } + + var elements []Int4range + + if len(utmr.Elements) > 0 { + elements = make([]Int4range, len(utmr.Elements)) + + for i, s := range utmr.Elements { + var elem Int4range + + elemSrc := []byte(s) + + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int4multirange{Ranges: elements, Status: Present} + + return nil +} + +func (dst *Int4multirange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4multirange{Status: Null} + return nil + } + + rp := 0 + + numElems := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + if numElems == 0 { + *dst = Int4multirange{Status: Present} + return nil + } + + elements := make([]Int4range, numElems) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err := elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Int4multirange{Ranges: elements, Status: Present} + return nil +} + +func (src Int4multirange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, '{') + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Ranges { + if i > 0 { + buf = append(buf, ',') + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + return nil, fmt.Errorf("multi-range does not allow null range") + } else { + buf = append(buf, string(elemBuf)...) + } + + } + + buf = append(buf, '}') + + return buf, nil +} + +func (src Int4multirange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendInt32(buf, int32(len(src.Ranges))) + + for i := range src.Ranges { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Ranges[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4multirange) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4multirange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/int4_multirange_test.go b/int4_multirange_test.go new file mode 100644 index 00000000..e123c402 --- /dev/null +++ b/int4_multirange_test.go @@ -0,0 +1,81 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt4multirangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4multirange", []interface{}{ + &pgtype.Int4multirange{ + Ranges: nil, + Status: pgtype.Present, + }, + &pgtype.Int4multirange{ + Ranges: []pgtype.Int4range{ + { + Lower: pgtype.Int4{Int: -543, Status: pgtype.Present}, + Upper: pgtype.Int4{Int: 342, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + Status: pgtype.Present, + }, + &pgtype.Int4multirange{ + Ranges: []pgtype.Int4range{ + { + Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, + Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + { + Lower: pgtype.Int4{Int: 5, Status: pgtype.Present}, + Upper: pgtype.Int4{Int: 42, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + { + Lower: pgtype.Int4{Int: 52, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Unbounded, + Status: pgtype.Present, + }, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInt4multirangeNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select int4multirange(int4range(1, 14, '(]'), int4range(20, 25, '()'))", + Value: pgtype.Int4multirange{ + Ranges: []pgtype.Int4range{ + { + Lower: pgtype.Int4{Int: 2, Status: pgtype.Present}, + Upper: pgtype.Int4{Int: 15, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + { + Lower: pgtype.Int4{Int: 21, Status: pgtype.Present}, + Upper: pgtype.Int4{Int: 25, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + Status: pgtype.Present, + }, + }, + }) +} diff --git a/int8_multirange.go b/int8_multirange.go new file mode 100644 index 00000000..e0976427 --- /dev/null +++ b/int8_multirange.go @@ -0,0 +1,239 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + + "github.com/jackc/pgio" +) + +type Int8multirange struct { + Ranges []Int8range + Status Status +} + +func (dst *Int8multirange) Set(src interface{}) error { + //untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int8multirange{Status: Null} + return nil + } + + switch value := src.(type) { + case Int8multirange: + *dst = value + case *Int8multirange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + case []Int8range: + if value == nil { + *dst = Int8multirange{Status: Null} + } else if len(value) == 0 { + *dst = Int8multirange{Status: Present} + } else { + elements := make([]Int8range, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8multirange{ + Ranges: elements, + Status: Present, + } + } + case []*Int8range: + if value == nil { + *dst = Int8multirange{Status: Null} + } else if len(value) == 0 { + *dst = Int8multirange{Status: Present} + } else { + elements := make([]Int8range, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8multirange{ + Ranges: elements, + Status: Present, + } + } + default: + return fmt.Errorf("cannot convert %v to Int8multirange", src) + } + + return nil + +} + +func (dst Int8multirange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int8multirange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Int8multirange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8multirange{Status: Null} + return nil + } + + utmr, err := ParseUntypedTextMultirange(string(src)) + if err != nil { + return err + } + + var elements []Int8range + + if len(utmr.Elements) > 0 { + elements = make([]Int8range, len(utmr.Elements)) + + for i, s := range utmr.Elements { + var elem Int8range + + elemSrc := []byte(s) + + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int8multirange{Ranges: elements, Status: Present} + + return nil +} + +func (dst *Int8multirange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8multirange{Status: Null} + return nil + } + + rp := 0 + + numElems := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + if numElems == 0 { + *dst = Int8multirange{Status: Present} + return nil + } + + elements := make([]Int8range, numElems) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err := elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Int8multirange{Ranges: elements, Status: Present} + return nil +} + +func (src Int8multirange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, '{') + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Ranges { + if i > 0 { + buf = append(buf, ',') + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + return nil, fmt.Errorf("multi-range does not allow null range") + } else { + buf = append(buf, string(elemBuf)...) + } + + } + + buf = append(buf, '}') + + return buf, nil +} + +func (src Int8multirange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendInt32(buf, int32(len(src.Ranges))) + + for i := range src.Ranges { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Ranges[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8multirange) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8multirange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/int8_multirange_test.go b/int8_multirange_test.go new file mode 100644 index 00000000..a4233384 --- /dev/null +++ b/int8_multirange_test.go @@ -0,0 +1,81 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt8multirangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int8multirange", []interface{}{ + &pgtype.Int8multirange{ + Ranges: nil, + Status: pgtype.Present, + }, + &pgtype.Int8multirange{ + Ranges: []pgtype.Int8range{ + { + Lower: pgtype.Int8{Int: -543, Status: pgtype.Present}, + Upper: pgtype.Int8{Int: 342, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + Status: pgtype.Present, + }, + &pgtype.Int8multirange{ + Ranges: []pgtype.Int8range{ + { + Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, + Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + { + Lower: pgtype.Int8{Int: 5, Status: pgtype.Present}, + Upper: pgtype.Int8{Int: 42, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + { + Lower: pgtype.Int8{Int: 52, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Unbounded, + Status: pgtype.Present, + }, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInt8multirangeNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select int8multirange(int8range(1, 14, '(]'), int8range(20, 25, '()'))", + Value: pgtype.Int8multirange{ + Ranges: []pgtype.Int8range{ + { + Lower: pgtype.Int8{Int: 2, Status: pgtype.Present}, + Upper: pgtype.Int8{Int: 15, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + { + Lower: pgtype.Int8{Int: 21, Status: pgtype.Present}, + Upper: pgtype.Int8{Int: 25, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + Status: pgtype.Present, + }, + }, + }) +} diff --git a/multirange.go b/multirange.go new file mode 100644 index 00000000..beb11f70 --- /dev/null +++ b/multirange.go @@ -0,0 +1,83 @@ +package pgtype + +import ( + "bytes" + "fmt" +) + +type UntypedTextMultirange struct { + Elements []string +} + +func ParseUntypedTextMultirange(src string) (*UntypedTextMultirange, error) { + utmr := &UntypedTextMultirange{} + utmr.Elements = make([]string, 0) + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != '{' { + return nil, fmt.Errorf("invalid multirange, expected '{': %v", err) + } + +parseValueLoop: + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid multirange: %v", err) + } + + switch r { + case ',': // skip range separator + case '}': + break parseValueLoop + default: + buf.UnreadRune() + value, err := parseRange(buf) + if err != nil { + return nil, fmt.Errorf("invalid multirange value: %v", err) + } + utmr.Elements = append(utmr.Elements, value) + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + return utmr, nil + +} + +func parseRange(buf *bytes.Buffer) (string, error) { + + s := &bytes.Buffer{} + + boundSepRead := false + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case ',', '}': + if r == ',' && !boundSepRead { + boundSepRead = true + break + } + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} diff --git a/multirange_test.go b/multirange_test.go new file mode 100644 index 00000000..4991aecf --- /dev/null +++ b/multirange_test.go @@ -0,0 +1,51 @@ +package pgtype + +import ( + "reflect" + "testing" +) + +func TestParseUntypedTextMultirange(t *testing.T) { + tests := []struct { + src string + result UntypedTextMultirange + err error + }{ + { + src: `{[1,2)}`, + result: UntypedTextMultirange{Elements: []string{`[1,2)`}}, + err: nil, + }, + { + src: `{[,),["foo", "bar"]}`, + result: UntypedTextMultirange{Elements: []string{`[,)`, `["foo", "bar"]`}}, + err: nil, + }, + { + src: `{}`, + result: UntypedTextMultirange{Elements: []string{}}, + err: nil, + }, + { + src: ` { (,) , [1,2] } `, + result: UntypedTextMultirange{Elements: []string{` (,) `, ` [1,2] `}}, + err: nil, + }, + { + src: `{["f""oo","b""ar")}`, + result: UntypedTextMultirange{Elements: []string{`["f""oo","b""ar")`}}, + err: nil, + }, + } + for i, tt := range tests { + r, err := ParseUntypedTextMultirange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if !reflect.DeepEqual(*r, tt.result) { + t.Errorf("%d: expected %+v to be parsed to %+v, but it was %+v", i, tt.src, tt.result, *r) + } + } +} diff --git a/num_multirange.go b/num_multirange.go new file mode 100644 index 00000000..cbabc8ac --- /dev/null +++ b/num_multirange.go @@ -0,0 +1,239 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + + "github.com/jackc/pgio" +) + +type Nummultirange struct { + Ranges []Numrange + Status Status +} + +func (dst *Nummultirange) Set(src interface{}) error { + //untyped nil and typed nil interfaces are different + if src == nil { + *dst = Nummultirange{Status: Null} + return nil + } + + switch value := src.(type) { + case Nummultirange: + *dst = value + case *Nummultirange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + case []Numrange: + if value == nil { + *dst = Nummultirange{Status: Null} + } else if len(value) == 0 { + *dst = Nummultirange{Status: Present} + } else { + elements := make([]Numrange, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Nummultirange{ + Ranges: elements, + Status: Present, + } + } + case []*Numrange: + if value == nil { + *dst = Nummultirange{Status: Null} + } else if len(value) == 0 { + *dst = Nummultirange{Status: Present} + } else { + elements := make([]Numrange, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Nummultirange{ + Ranges: elements, + Status: Present, + } + } + default: + return fmt.Errorf("cannot convert %v to Nummultirange", src) + } + + return nil + +} + +func (dst Nummultirange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Nummultirange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Nummultirange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Nummultirange{Status: Null} + return nil + } + + utmr, err := ParseUntypedTextMultirange(string(src)) + if err != nil { + return err + } + + var elements []Numrange + + if len(utmr.Elements) > 0 { + elements = make([]Numrange, len(utmr.Elements)) + + for i, s := range utmr.Elements { + var elem Numrange + + elemSrc := []byte(s) + + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Nummultirange{Ranges: elements, Status: Present} + + return nil +} + +func (dst *Nummultirange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Nummultirange{Status: Null} + return nil + } + + rp := 0 + + numElems := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + if numElems == 0 { + *dst = Nummultirange{Status: Present} + return nil + } + + elements := make([]Numrange, numElems) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err := elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Nummultirange{Ranges: elements, Status: Present} + return nil +} + +func (src Nummultirange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, '{') + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Ranges { + if i > 0 { + buf = append(buf, ',') + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + return nil, fmt.Errorf("multi-range does not allow null range") + } else { + buf = append(buf, string(elemBuf)...) + } + + } + + buf = append(buf, '}') + + return buf, nil +} + +func (src Nummultirange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendInt32(buf, int32(len(src.Ranges))) + + for i := range src.Ranges { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Ranges[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Nummultirange) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Nummultirange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/num_multirange_test.go b/num_multirange_test.go new file mode 100644 index 00000000..f4289794 --- /dev/null +++ b/num_multirange_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "math/big" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestNumericMultirangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "nummultirange", []interface{}{ + &pgtype.Nummultirange{ + Ranges: nil, + Status: pgtype.Present, + }, + &pgtype.Nummultirange{ + Ranges: []pgtype.Numrange{ + { + Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + Status: pgtype.Present, + }, + &pgtype.Nummultirange{ + Ranges: []pgtype.Numrange{ + { + Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + { + Lower: pgtype.Numeric{Int: big.NewInt(5), Exp: 1, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(42), Exp: 1, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Inclusive, + Status: pgtype.Present, + }, + { + Lower: pgtype.Numeric{Int: big.NewInt(42), Exp: 2, Status: pgtype.Present}, + LowerType: pgtype.Exclusive, + UpperType: pgtype.Unbounded, + Status: pgtype.Present, + }, + }, + Status: pgtype.Present, + }, + }) +} diff --git a/pgtype.go b/pgtype.go index 200fb562..eba09fa5 100644 --- a/pgtype.go +++ b/pgtype.go @@ -74,12 +74,15 @@ const ( JSONBArrayOID = 3807 DaterangeOID = 3912 Int4rangeOID = 3904 + Int4multirangeOID = 4451 NumrangeOID = 3906 + NummultirangeOID = 4532 TsrangeOID = 3908 TsrangeArrayOID = 3909 TstzrangeOID = 3910 TstzrangeArrayOID = 3911 Int8rangeOID = 3926 + Int8multirangeOID = 4536 ) type Status byte @@ -288,8 +291,10 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID}) ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID}) ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) + ci.RegisterDataType(DataType{Value: &Int4multirange{}, Name: "int4multirange", OID: Int4multirangeOID}) ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID}) ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) + ci.RegisterDataType(DataType{Value: &Int8multirange{}, Name: "int8multirange", OID: Int8multirangeOID}) ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID}) ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) @@ -300,6 +305,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Name{}, Name: "name", OID: NameOID}) ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) + ci.RegisterDataType(DataType{Value: &Nummultirange{}, Name: "nummultirange", OID: NummultirangeOID}) ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID}) ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID}) ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID}) @@ -873,72 +879,75 @@ var nameValues map[string]Value func init() { nameValues = map[string]Value{ - "_aclitem": &ACLItemArray{}, - "_bool": &BoolArray{}, - "_bpchar": &BPCharArray{}, - "_bytea": &ByteaArray{}, - "_cidr": &CIDRArray{}, - "_date": &DateArray{}, - "_float4": &Float4Array{}, - "_float8": &Float8Array{}, - "_inet": &InetArray{}, - "_int2": &Int2Array{}, - "_int4": &Int4Array{}, - "_int8": &Int8Array{}, - "_numeric": &NumericArray{}, - "_text": &TextArray{}, - "_timestamp": &TimestampArray{}, - "_timestamptz": &TimestamptzArray{}, - "_uuid": &UUIDArray{}, - "_varchar": &VarcharArray{}, - "_jsonb": &JSONBArray{}, - "aclitem": &ACLItem{}, - "bit": &Bit{}, - "bool": &Bool{}, - "box": &Box{}, - "bpchar": &BPChar{}, - "bytea": &Bytea{}, - "char": &QChar{}, - "cid": &CID{}, - "cidr": &CIDR{}, - "circle": &Circle{}, - "date": &Date{}, - "daterange": &Daterange{}, - "float4": &Float4{}, - "float8": &Float8{}, - "hstore": &Hstore{}, - "inet": &Inet{}, - "int2": &Int2{}, - "int4": &Int4{}, - "int4range": &Int4range{}, - "int8": &Int8{}, - "int8range": &Int8range{}, - "interval": &Interval{}, - "json": &JSON{}, - "jsonb": &JSONB{}, - "line": &Line{}, - "lseg": &Lseg{}, - "macaddr": &Macaddr{}, - "name": &Name{}, - "numeric": &Numeric{}, - "numrange": &Numrange{}, - "oid": &OIDValue{}, - "path": &Path{}, - "point": &Point{}, - "polygon": &Polygon{}, - "record": &Record{}, - "text": &Text{}, - "tid": &TID{}, - "timestamp": &Timestamp{}, - "timestamptz": &Timestamptz{}, - "tsrange": &Tsrange{}, - "_tsrange": &TsrangeArray{}, - "tstzrange": &Tstzrange{}, - "_tstzrange": &TstzrangeArray{}, - "unknown": &Unknown{}, - "uuid": &UUID{}, - "varbit": &Varbit{}, - "varchar": &Varchar{}, - "xid": &XID{}, + "_aclitem": &ACLItemArray{}, + "_bool": &BoolArray{}, + "_bpchar": &BPCharArray{}, + "_bytea": &ByteaArray{}, + "_cidr": &CIDRArray{}, + "_date": &DateArray{}, + "_float4": &Float4Array{}, + "_float8": &Float8Array{}, + "_inet": &InetArray{}, + "_int2": &Int2Array{}, + "_int4": &Int4Array{}, + "_int8": &Int8Array{}, + "_numeric": &NumericArray{}, + "_text": &TextArray{}, + "_timestamp": &TimestampArray{}, + "_timestamptz": &TimestamptzArray{}, + "_uuid": &UUIDArray{}, + "_varchar": &VarcharArray{}, + "_jsonb": &JSONBArray{}, + "aclitem": &ACLItem{}, + "bit": &Bit{}, + "bool": &Bool{}, + "box": &Box{}, + "bpchar": &BPChar{}, + "bytea": &Bytea{}, + "char": &QChar{}, + "cid": &CID{}, + "cidr": &CIDR{}, + "circle": &Circle{}, + "date": &Date{}, + "daterange": &Daterange{}, + "float4": &Float4{}, + "float8": &Float8{}, + "hstore": &Hstore{}, + "inet": &Inet{}, + "int2": &Int2{}, + "int4": &Int4{}, + "int4range": &Int4range{}, + "int4multirange": &Int4multirange{}, + "int8": &Int8{}, + "int8range": &Int8range{}, + "int8multirange": &Int8multirange{}, + "interval": &Interval{}, + "json": &JSON{}, + "jsonb": &JSONB{}, + "line": &Line{}, + "lseg": &Lseg{}, + "macaddr": &Macaddr{}, + "name": &Name{}, + "numeric": &Numeric{}, + "numrange": &Numrange{}, + "nummultirange": &Nummultirange{}, + "oid": &OIDValue{}, + "path": &Path{}, + "point": &Point{}, + "polygon": &Polygon{}, + "record": &Record{}, + "text": &Text{}, + "tid": &TID{}, + "timestamp": &Timestamp{}, + "timestamptz": &Timestamptz{}, + "tsrange": &Tsrange{}, + "_tsrange": &TsrangeArray{}, + "tstzrange": &Tstzrange{}, + "_tstzrange": &TstzrangeArray{}, + "unknown": &Unknown{}, + "uuid": &UUID{}, + "varbit": &Varbit{}, + "varchar": &Varchar{}, + "xid": &XID{}, } } diff --git a/typed_multirange.go.erb b/typed_multirange.go.erb new file mode 100644 index 00000000..84c8299f --- /dev/null +++ b/typed_multirange.go.erb @@ -0,0 +1,239 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + + "github.com/jackc/pgio" +) + +type <%= multirange_type %> struct { + Ranges []<%= range_type %> + Status Status +} + +func (dst *<%= multirange_type %>) Set(src interface{}) error { + //untyped nil and typed nil interfaces are different + if src == nil { + *dst = <%= multirange_type %>{Status: Null} + return nil + } + + switch value := src.(type) { + case <%= multirange_type %>: + *dst = value + case *<%= multirange_type %>: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + case []<%= range_type %>: + if value == nil { + *dst = <%= multirange_type %>{Status: Null} + } else if len(value) == 0 { + *dst = <%= multirange_type %>{Status: Present} + } else { + elements := make([]<%= range_type %>, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = <%= multirange_type %>{ + Ranges: elements, + Status: Present, + } + } + case []*<%= range_type %>: + if value == nil { + *dst = <%= multirange_type %>{Status: Null} + } else if len(value) == 0 { + *dst = <%= multirange_type %>{Status: Present} + } else { + elements := make([]<%= range_type %>, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = <%= multirange_type %>{ + Ranges: elements, + Status: Present, + } + } + default: + return fmt.Errorf("cannot convert %v to <%= multirange_type %>", src) + } + + return nil + +} + +func (dst <%= multirange_type %>) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *<%= multirange_type %>) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *<%= multirange_type %>) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= multirange_type %>{Status: Null} + return nil + } + + utmr, err := ParseUntypedTextMultirange(string(src)) + if err != nil { + return err + } + + var elements []<%= range_type %> + + if len(utmr.Elements) > 0 { + elements = make([]<%= range_type %>, len(utmr.Elements)) + + for i, s := range utmr.Elements { + var elem <%= range_type %> + + elemSrc := []byte(s) + + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = <%= multirange_type %>{Ranges: elements, Status: Present} + + return nil +} + +func (dst *<%= multirange_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= multirange_type %>{Status: Null} + return nil + } + + rp := 0 + + numElems := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + if numElems == 0 { + *dst = <%= multirange_type %>{Status: Present} + return nil + } + + elements := make([]<%= range_type %>, numElems) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err := elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = <%= multirange_type %>{Ranges: elements, Status: Present} + return nil +} + +func (src <%= multirange_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, '{') + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Ranges { + if i > 0 { + buf = append(buf, ',') + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + return nil, fmt.Errorf("multi-range does not allow null range") + } else { + buf = append(buf, string(elemBuf)...) + } + + } + + buf = append(buf, '}') + + return buf, nil +} + +func (src <%= multirange_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendInt32(buf, int32(len(src.Ranges))) + + for i := range src.Ranges { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Ranges[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *<%= multirange_type %>) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src <%= multirange_type %>) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/typed_multirange_gen.sh b/typed_multirange_gen.sh new file mode 100755 index 00000000..610f40a1 --- /dev/null +++ b/typed_multirange_gen.sh @@ -0,0 +1,8 @@ +erb range_type=Numrange multirange_type=Nummultirange typed_multirange.go.erb > num_multirange.go +erb range_type=Int4range multirange_type=Int4multirange typed_multirange.go.erb > int4_multirange.go +erb range_type=Int8range multirange_type=Int8multirange typed_multirange.go.erb > int8_multirange.go +# TODO +# erb range_type=Tsrange multirange_type=Tsmultirange typed_multirange.go.erb > ts_multirange.go +# erb range_type=Tstzrange multirange_type=Tstzmultirange typed_multirange.go.erb > tstz_multirange.go +# erb range_type=Daterange multirange_type=Datemultirange typed_multirange.go.erb > date_multirange.go +goimports -w *multirange.go \ No newline at end of file From b7a85d1a6fc58df695e8cf0571ebf4e7dab921d5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 08:23:58 -0600 Subject: [PATCH 0932/1158] Consider any "0A000" error a possible cached plan changed error https://github.com/jackc/pgx/issues/1162 --- stmtcache/lru.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/stmtcache/lru.go b/stmtcache/lru.go index 90fb76c2..f0fb53b9 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -102,10 +102,14 @@ func (c *LRU) StatementErrored(sql string, err error) { return } - isInvalidCachedPlanError := pgErr.Severity == "ERROR" && - pgErr.Code == "0A000" && - pgErr.Message == "cached plan must not change result type" - if isInvalidCachedPlanError { + // https://github.com/jackc/pgx/issues/1162 + // + // We used to look for the message "cached plan must not change result type". However, that message can be localized. + // Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to + // tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't + // have so it should be safe. + possibleInvalidCachedPlanError := pgErr.Code == "0A000" + if possibleInvalidCachedPlanError { c.stmtsToClear = append(c.stmtsToClear, sql) } } From ec8f7c4204372bda4329d4e6d8952c913863e656 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 08:56:41 -0600 Subject: [PATCH 0933/1158] Add comment for FormatCodeForOID --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 8c743163..5c9e5796 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -422,6 +422,8 @@ func (m *Map) TypeForValue(v interface{}) (*Type, bool) { return dt, ok } +// FormatCodeForOID returns the preferred format code for type oid. If the type is not registered it returns the text +// format code. func (m *Map) FormatCodeForOID(oid uint32) int16 { fc, ok := m.oidToFormatCode[oid] if ok { From e7f90ba6e4f98ed686d1a6b0f3ab214a3062615d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 09:00:49 -0600 Subject: [PATCH 0934/1158] Remove unused pgtype.Map field --- pgtype/pgtype.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5c9e5796..d1a92089 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -171,11 +171,10 @@ type Type struct { // Map is the mapping between PostgreSQL server types and Go type handling logic. It can encode values for // transmission to a PostgreSQL server and scan received values. type Map struct { - oidToType map[uint32]*Type - nameToType map[string]*Type - reflectTypeToName map[reflect.Type]string - oidToFormatCode map[uint32]int16 - oidToResultFormatCode map[uint32]int16 + oidToType map[uint32]*Type + nameToType map[string]*Type + reflectTypeToName map[reflect.Type]string + oidToFormatCode map[uint32]int16 reflectTypeToType map[reflect.Type]*Type @@ -196,11 +195,10 @@ type Map struct { func NewMap() *Map { m := &Map{ - oidToType: make(map[uint32]*Type), - nameToType: make(map[string]*Type), - reflectTypeToName: make(map[reflect.Type]string), - oidToFormatCode: make(map[uint32]int16), - oidToResultFormatCode: make(map[uint32]int16), + oidToType: make(map[uint32]*Type), + nameToType: make(map[string]*Type), + reflectTypeToName: make(map[reflect.Type]string), + oidToFormatCode: make(map[uint32]int16), memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), From 872a7a9315037cb0038bf2515e801dd840c939c9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 09:08:14 -0600 Subject: [PATCH 0935/1158] Fix pgtype/int.go.erb --- pgtype/int.go.erb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 3b5b14a9..29846744 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -128,7 +128,7 @@ func (Int<%= pg_byte_size %>Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int<%= pg_byte_size %>Codec) PlanEncode(m *TypeMap, oid uint32, format int16, value interface{}) EncodePlan { +func (Int<%= pg_byte_size %>Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -207,7 +207,7 @@ func (encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer) Encode(value interfa return append(buf, strconv.FormatInt(n.Int, 10)...), nil } -func (Int<%= pg_byte_size %>Codec) PlanScan(m *TypeMap, oid uint32, format int16, target interface{}) ScanPlan { +func (Int<%= pg_byte_size %>Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { switch format { case BinaryFormatCode: @@ -265,7 +265,7 @@ func (Int<%= pg_byte_size %>Codec) PlanScan(m *TypeMap, oid uint32, format int16 return nil } -func (c Int<%= pg_byte_size %>Codec) DecodeDatabaseSQLValue(m *TypeMap, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c Int<%= pg_byte_size %>Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -278,7 +278,7 @@ func (c Int<%= pg_byte_size %>Codec) DecodeDatabaseSQLValue(m *TypeMap, oid uint return n, nil } -func (c Int<%= pg_byte_size %>Codec) DecodeValue(m *TypeMap, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Int<%= pg_byte_size %>Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { return nil, nil } From d723a4ab6fdc23c8d5dde95902d52da14851d074 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 09:10:58 -0600 Subject: [PATCH 0936/1158] pgtype Int2, Int4, and Int8 fields include bit size e.g. Instead of Int it is Int64. This matches the pattern set by the database/sql types. --- pgtype/builtin_wrappers.go | 120 ++++++++++++++++----------------- pgtype/float4.go | 8 +-- pgtype/float8.go | 10 +-- pgtype/int.go | 134 ++++++++++++++++++------------------- pgtype/int.go.erb | 46 ++++++------- pgtype/int_test.go | 42 ++++++------ pgtype/numeric.go | 6 +- pgtype/pgtype_test.go | 4 +- pgtype/range_codec_test.go | 14 ++-- pgtype/uint32.go | 20 +++--- pgtype/zeronull/int.go | 6 +- pgtype/zeronull/int.go.erb | 2 +- 12 files changed, 206 insertions(+), 206 deletions(-) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index e43244cf..3706c994 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -18,19 +18,19 @@ func (w *int8Wrapper) ScanInt64(v Int8) error { return fmt.Errorf("cannot scan NULL into *int8") } - if v.Int < math.MinInt8 { - return fmt.Errorf("%d is less than minimum value for int8", v.Int) + if v.Int64 < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", v.Int64) } - if v.Int > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for int8", v.Int) + if v.Int64 > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", v.Int64) } - *w = int8Wrapper(v.Int) + *w = int8Wrapper(v.Int64) return nil } func (w int8Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(w), Valid: true}, nil + return Int8{Int64: int64(w), Valid: true}, nil } type int16Wrapper int16 @@ -42,19 +42,19 @@ func (w *int16Wrapper) ScanInt64(v Int8) error { return fmt.Errorf("cannot scan NULL into *int16") } - if v.Int < math.MinInt16 { - return fmt.Errorf("%d is less than minimum value for int16", v.Int) + if v.Int64 < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", v.Int64) } - if v.Int > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for int16", v.Int) + if v.Int64 > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", v.Int64) } - *w = int16Wrapper(v.Int) + *w = int16Wrapper(v.Int64) return nil } func (w int16Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(w), Valid: true}, nil + return Int8{Int64: int64(w), Valid: true}, nil } type int32Wrapper int32 @@ -66,19 +66,19 @@ func (w *int32Wrapper) ScanInt64(v Int8) error { return fmt.Errorf("cannot scan NULL into *int32") } - if v.Int < math.MinInt32 { - return fmt.Errorf("%d is less than minimum value for int32", v.Int) + if v.Int64 < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", v.Int64) } - if v.Int > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for int32", v.Int) + if v.Int64 > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", v.Int64) } - *w = int32Wrapper(v.Int) + *w = int32Wrapper(v.Int64) return nil } func (w int32Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(w), Valid: true}, nil + return Int8{Int64: int64(w), Valid: true}, nil } type int64Wrapper int64 @@ -90,13 +90,13 @@ func (w *int64Wrapper) ScanInt64(v Int8) error { return fmt.Errorf("cannot scan NULL into *int64") } - *w = int64Wrapper(v.Int) + *w = int64Wrapper(v.Int64) return nil } func (w int64Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(w), Valid: true}, nil + return Int8{Int64: int64(w), Valid: true}, nil } type intWrapper int @@ -108,20 +108,20 @@ func (w *intWrapper) ScanInt64(v Int8) error { return fmt.Errorf("cannot scan NULL into *int") } - if v.Int < math.MinInt { - return fmt.Errorf("%d is less than minimum value for int", v.Int) + if v.Int64 < math.MinInt { + return fmt.Errorf("%d is less than minimum value for int", v.Int64) } - if v.Int > math.MaxInt { - return fmt.Errorf("%d is greater than maximum value for int", v.Int) + if v.Int64 > math.MaxInt { + return fmt.Errorf("%d is greater than maximum value for int", v.Int64) } - *w = intWrapper(v.Int) + *w = intWrapper(v.Int64) return nil } func (w intWrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(w), Valid: true}, nil + return Int8{Int64: int64(w), Valid: true}, nil } type uint8Wrapper uint8 @@ -133,19 +133,19 @@ func (w *uint8Wrapper) ScanInt64(v Int8) error { return fmt.Errorf("cannot scan NULL into *uint8") } - if v.Int < 0 { - return fmt.Errorf("%d is less than minimum value for uint8", v.Int) + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", v.Int64) } - if v.Int > math.MaxUint8 { - return fmt.Errorf("%d is greater than maximum value for uint8", v.Int) + if v.Int64 > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", v.Int64) } - *w = uint8Wrapper(v.Int) + *w = uint8Wrapper(v.Int64) return nil } func (w uint8Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(w), Valid: true}, nil + return Int8{Int64: int64(w), Valid: true}, nil } type uint16Wrapper uint16 @@ -157,19 +157,19 @@ func (w *uint16Wrapper) ScanInt64(v Int8) error { return fmt.Errorf("cannot scan NULL into *uint16") } - if v.Int < 0 { - return fmt.Errorf("%d is less than minimum value for uint16", v.Int) + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", v.Int64) } - if v.Int > math.MaxUint16 { - return fmt.Errorf("%d is greater than maximum value for uint16", v.Int) + if v.Int64 > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", v.Int64) } - *w = uint16Wrapper(v.Int) + *w = uint16Wrapper(v.Int64) return nil } func (w uint16Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(w), Valid: true}, nil + return Int8{Int64: int64(w), Valid: true}, nil } type uint32Wrapper uint32 @@ -181,19 +181,19 @@ func (w *uint32Wrapper) ScanInt64(v Int8) error { return fmt.Errorf("cannot scan NULL into *uint32") } - if v.Int < 0 { - return fmt.Errorf("%d is less than minimum value for uint32", v.Int) + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", v.Int64) } - if v.Int > math.MaxUint32 { - return fmt.Errorf("%d is greater than maximum value for uint32", v.Int) + if v.Int64 > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", v.Int64) } - *w = uint32Wrapper(v.Int) + *w = uint32Wrapper(v.Int64) return nil } func (w uint32Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(w), Valid: true}, nil + return Int8{Int64: int64(w), Valid: true}, nil } type uint64Wrapper uint64 @@ -205,11 +205,11 @@ func (w *uint64Wrapper) ScanInt64(v Int8) error { return fmt.Errorf("cannot scan NULL into *uint64") } - if v.Int < 0 { - return fmt.Errorf("%d is less than minimum value for uint64", v.Int) + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", v.Int64) } - *w = uint64Wrapper(v.Int) + *w = uint64Wrapper(v.Int64) return nil } @@ -219,7 +219,7 @@ func (w uint64Wrapper) Int64Value() (Int8, error) { return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", w) } - return Int8{Int: int64(w), Valid: true}, nil + return Int8{Int64: int64(w), Valid: true}, nil } type uintWrapper uint @@ -231,15 +231,15 @@ func (w *uintWrapper) ScanInt64(v Int8) error { return fmt.Errorf("cannot scan NULL into *uint64") } - if v.Int < 0 { - return fmt.Errorf("%d is less than minimum value for uint64", v.Int) + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", v.Int64) } - if uint64(v.Int) > math.MaxUint { - return fmt.Errorf("%d is greater than maximum value for uint", v.Int) + if uint64(v.Int64) > math.MaxUint { + return fmt.Errorf("%d is greater than maximum value for uint", v.Int64) } - *w = uintWrapper(v.Int) + *w = uintWrapper(v.Int64) return nil } @@ -249,7 +249,7 @@ func (w uintWrapper) Int64Value() (Int8, error) { return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", w) } - return Int8{Int: int64(w), Valid: true}, nil + return Int8{Int64: int64(w), Valid: true}, nil } type float32Wrapper float32 @@ -261,7 +261,7 @@ func (w *float32Wrapper) ScanInt64(v Int8) error { return fmt.Errorf("cannot scan NULL into *float32") } - *w = float32Wrapper(v.Int) + *w = float32Wrapper(v.Int64) return nil } @@ -271,7 +271,7 @@ func (w float32Wrapper) Int64Value() (Int8, error) { return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", w) } - return Int8{Int: int64(w), Valid: true}, nil + return Int8{Int64: int64(w), Valid: true}, nil } func (w *float32Wrapper) ScanFloat64(v Float8) error { @@ -297,7 +297,7 @@ func (w *float64Wrapper) ScanInt64(v Int8) error { return fmt.Errorf("cannot scan NULL into *float64") } - *w = float64Wrapper(v.Int) + *w = float64Wrapper(v.Int64) return nil } @@ -307,7 +307,7 @@ func (w float64Wrapper) Int64Value() (Int8, error) { return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", w) } - return Int8{Int: int64(w), Valid: true}, nil + return Int8{Int64: int64(w), Valid: true}, nil } func (w *float64Wrapper) ScanFloat64(v Float8) error { @@ -346,7 +346,7 @@ func (w *stringWrapper) ScanInt64(v Int8) error { return fmt.Errorf("cannot scan NULL into *string") } - *w = stringWrapper(strconv.FormatInt(v.Int, 10)) + *w = stringWrapper(strconv.FormatInt(v.Int64, 10)) return nil } @@ -357,7 +357,7 @@ func (w stringWrapper) Int64Value() (Int8, error) { return Int8{}, err } - return Int8{Int: int64(num), Valid: true}, nil + return Int8{Int64: int64(num), Valid: true}, nil } type timeWrapper time.Time diff --git a/pgtype/float4.go b/pgtype/float4.go index 127eb56a..e1f1fbf1 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -26,12 +26,12 @@ func (f Float4) Float64Value() (Float8, error) { } func (f *Float4) ScanInt64(n Int8) error { - *f = Float4{Float: float32(n.Int), Valid: n.Valid} + *f = Float4{Float: float32(n.Int64), Valid: n.Valid} return nil } func (f Float4) Int64Value() (Int8, error) { - return Int8{Int: int64(f.Float), Valid: f.Valid}, nil + return Int8{Int64: int64(f.Float), Valid: f.Valid}, nil } // Scan implements the database/sql Scanner interface. @@ -141,7 +141,7 @@ func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value interface{}, buf []by return nil, nil } - f := float32(n.Int) + f := float32(n.Int64) return pgio.AppendUint32(buf, math.Float32bits(f)), nil } @@ -226,7 +226,7 @@ func (scanPlanBinaryFloat4ToInt64Scanner) Scan(src []byte, dst interface{}) erro return fmt.Errorf("cannot losslessly convert %v to int64", f32) } - return s.ScanInt64(Int8{Int: i64, Valid: true}) + return s.ScanInt64(Int8{Int64: i64, Valid: true}) } type scanPlanTextAnyToFloat32 struct{} diff --git a/pgtype/float8.go b/pgtype/float8.go index b8b962b2..1de86a1d 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -34,12 +34,12 @@ func (f Float8) Float64Value() (Float8, error) { } func (f *Float8) ScanInt64(n Int8) error { - *f = Float8{Float: float64(n.Int), Valid: n.Valid} + *f = Float8{Float: float64(n.Int64), Valid: n.Valid} return nil } func (f Float8) Int64Value() (Int8, error) { - return Int8{Int: int64(f.Float), Valid: f.Valid}, nil + return Int8{Int64: int64(f.Float), Valid: f.Valid}, nil } // Scan implements the database/sql Scanner interface. @@ -164,7 +164,7 @@ func (encodePlanFloat8CodecBinaryInt64Valuer) Encode(value interface{}, buf []by return nil, nil } - f := float64(n.Int) + f := float64(n.Int64) return pgio.AppendUint64(buf, math.Float64bits(f)), nil } @@ -180,7 +180,7 @@ func (encodePlanTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf [ return nil, nil } - return append(buf, strconv.FormatInt(n.Int, 10)...), nil + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil } func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { @@ -264,7 +264,7 @@ func (scanPlanBinaryFloat8ToInt64Scanner) Scan(src []byte, dst interface{}) erro return fmt.Errorf("cannot losslessly convert %v to int64", f64) } - return s.ScanInt64(Int8{Int: i64, Valid: true}) + return s.ScanInt64(Int8{Int64: i64, Valid: true}) } type scanPlanTextAnyToFloat64 struct{} diff --git a/pgtype/int.go b/pgtype/int.go index ebac1403..ee4ab932 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -21,7 +21,7 @@ type Int64Valuer interface { } type Int2 struct { - Int int16 + Int16 int16 Valid bool } @@ -32,19 +32,19 @@ func (dst *Int2) ScanInt64(n Int8) error { return nil } - if n.Int < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n.Int) + if n.Int64 < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64) } - if n.Int > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n.Int) + if n.Int64 > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64) } - *dst = Int2{Int: int16(n.Int), Valid: true} + *dst = Int2{Int16: int16(n.Int64), Valid: true} return nil } func (n Int2) Int64Value() (Int8, error) { - return Int8{Int: int64(n.Int), Valid: n.Valid}, nil + return Int8{Int64: int64(n.Int16), Valid: n.Valid}, nil } // Scan implements the database/sql Scanner interface. @@ -81,7 +81,7 @@ func (dst *Int2) Scan(src interface{}) error { if n > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", n) } - *dst = Int2{Int: int16(n), Valid: true} + *dst = Int2{Int16: int16(n), Valid: true} return nil } @@ -91,14 +91,14 @@ func (src Int2) Value() (driver.Value, error) { if !src.Valid { return nil, nil } - return int64(src.Int), nil + return int64(src.Int16), nil } func (src Int2) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil } - return []byte(strconv.FormatInt(int64(src.Int), 10)), nil + return []byte(strconv.FormatInt(int64(src.Int16), 10)), nil } func (dst *Int2) UnmarshalJSON(b []byte) error { @@ -111,7 +111,7 @@ func (dst *Int2) UnmarshalJSON(b []byte) error { if n == nil { *dst = Int2{} } else { - *dst = Int2{Int: *n, Valid: true} + *dst = Int2{Int16: *n, Valid: true} } return nil @@ -174,14 +174,14 @@ func (encodePlanInt2CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte return nil, nil } - if n.Int > math.MaxInt16 { - return nil, fmt.Errorf("%d is greater than maximum value for int2", n.Int) + if n.Int64 > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n.Int64) } - if n.Int < math.MinInt16 { - return nil, fmt.Errorf("%d is less than minimum value for int2", n.Int) + if n.Int64 < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n.Int64) } - return pgio.AppendInt16(buf, int16(n.Int)), nil + return pgio.AppendInt16(buf, int16(n.Int64)), nil } type encodePlanInt2CodecTextInt64Valuer struct{} @@ -196,14 +196,14 @@ func (encodePlanInt2CodecTextInt64Valuer) Encode(value interface{}, buf []byte) return nil, nil } - if n.Int > math.MaxInt16 { - return nil, fmt.Errorf("%d is greater than maximum value for int2", n.Int) + if n.Int64 > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n.Int64) } - if n.Int < math.MinInt16 { - return nil, fmt.Errorf("%d is less than minimum value for int2", n.Int) + if n.Int64 < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n.Int64) } - return append(buf, strconv.FormatInt(n.Int, 10)...), nil + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil } func (Int2Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { @@ -554,11 +554,11 @@ func (scanPlanBinaryInt2ToInt64Scanner) Scan(src []byte, dst interface{}) error n := int64(int16(binary.BigEndian.Uint16(src))) - return s.ScanInt64(Int8{Int: n, Valid: true}) + return s.ScanInt64(Int8{Int64: n, Valid: true}) } type Int4 struct { - Int int32 + Int32 int32 Valid bool } @@ -569,19 +569,19 @@ func (dst *Int4) ScanInt64(n Int8) error { return nil } - if n.Int < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", n.Int) + if n.Int64 < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64) } - if n.Int > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", n.Int) + if n.Int64 > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64) } - *dst = Int4{Int: int32(n.Int), Valid: true} + *dst = Int4{Int32: int32(n.Int64), Valid: true} return nil } func (n Int4) Int64Value() (Int8, error) { - return Int8{Int: int64(n.Int), Valid: n.Valid}, nil + return Int8{Int64: int64(n.Int32), Valid: n.Valid}, nil } // Scan implements the database/sql Scanner interface. @@ -618,7 +618,7 @@ func (dst *Int4) Scan(src interface{}) error { if n > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", n) } - *dst = Int4{Int: int32(n), Valid: true} + *dst = Int4{Int32: int32(n), Valid: true} return nil } @@ -628,14 +628,14 @@ func (src Int4) Value() (driver.Value, error) { if !src.Valid { return nil, nil } - return int64(src.Int), nil + return int64(src.Int32), nil } func (src Int4) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil } - return []byte(strconv.FormatInt(int64(src.Int), 10)), nil + return []byte(strconv.FormatInt(int64(src.Int32), 10)), nil } func (dst *Int4) UnmarshalJSON(b []byte) error { @@ -648,7 +648,7 @@ func (dst *Int4) UnmarshalJSON(b []byte) error { if n == nil { *dst = Int4{} } else { - *dst = Int4{Int: *n, Valid: true} + *dst = Int4{Int32: *n, Valid: true} } return nil @@ -711,14 +711,14 @@ func (encodePlanInt4CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte return nil, nil } - if n.Int > math.MaxInt32 { - return nil, fmt.Errorf("%d is greater than maximum value for int4", n.Int) + if n.Int64 > math.MaxInt32 { + return nil, fmt.Errorf("%d is greater than maximum value for int4", n.Int64) } - if n.Int < math.MinInt32 { - return nil, fmt.Errorf("%d is less than minimum value for int4", n.Int) + if n.Int64 < math.MinInt32 { + return nil, fmt.Errorf("%d is less than minimum value for int4", n.Int64) } - return pgio.AppendInt32(buf, int32(n.Int)), nil + return pgio.AppendInt32(buf, int32(n.Int64)), nil } type encodePlanInt4CodecTextInt64Valuer struct{} @@ -733,14 +733,14 @@ func (encodePlanInt4CodecTextInt64Valuer) Encode(value interface{}, buf []byte) return nil, nil } - if n.Int > math.MaxInt32 { - return nil, fmt.Errorf("%d is greater than maximum value for int4", n.Int) + if n.Int64 > math.MaxInt32 { + return nil, fmt.Errorf("%d is greater than maximum value for int4", n.Int64) } - if n.Int < math.MinInt32 { - return nil, fmt.Errorf("%d is less than minimum value for int4", n.Int) + if n.Int64 < math.MinInt32 { + return nil, fmt.Errorf("%d is less than minimum value for int4", n.Int64) } - return append(buf, strconv.FormatInt(n.Int, 10)...), nil + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil } func (Int4Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { @@ -1102,11 +1102,11 @@ func (scanPlanBinaryInt4ToInt64Scanner) Scan(src []byte, dst interface{}) error n := int64(int32(binary.BigEndian.Uint32(src))) - return s.ScanInt64(Int8{Int: n, Valid: true}) + return s.ScanInt64(Int8{Int64: n, Valid: true}) } type Int8 struct { - Int int64 + Int64 int64 Valid bool } @@ -1117,19 +1117,19 @@ func (dst *Int8) ScanInt64(n Int8) error { return nil } - if n.Int < math.MinInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", n.Int) + if n.Int64 < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64) } - if n.Int > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", n.Int) + if n.Int64 > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64) } - *dst = Int8{Int: int64(n.Int), Valid: true} + *dst = Int8{Int64: int64(n.Int64), Valid: true} return nil } func (n Int8) Int64Value() (Int8, error) { - return Int8{Int: int64(n.Int), Valid: n.Valid}, nil + return Int8{Int64: int64(n.Int64), Valid: n.Valid}, nil } // Scan implements the database/sql Scanner interface. @@ -1166,7 +1166,7 @@ func (dst *Int8) Scan(src interface{}) error { if n > math.MaxInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", n) } - *dst = Int8{Int: int64(n), Valid: true} + *dst = Int8{Int64: int64(n), Valid: true} return nil } @@ -1176,14 +1176,14 @@ func (src Int8) Value() (driver.Value, error) { if !src.Valid { return nil, nil } - return int64(src.Int), nil + return int64(src.Int64), nil } func (src Int8) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil } - return []byte(strconv.FormatInt(int64(src.Int), 10)), nil + return []byte(strconv.FormatInt(int64(src.Int64), 10)), nil } func (dst *Int8) UnmarshalJSON(b []byte) error { @@ -1196,7 +1196,7 @@ func (dst *Int8) UnmarshalJSON(b []byte) error { if n == nil { *dst = Int8{} } else { - *dst = Int8{Int: *n, Valid: true} + *dst = Int8{Int64: *n, Valid: true} } return nil @@ -1259,14 +1259,14 @@ func (encodePlanInt8CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte return nil, nil } - if n.Int > math.MaxInt64 { - return nil, fmt.Errorf("%d is greater than maximum value for int8", n.Int) + if n.Int64 > math.MaxInt64 { + return nil, fmt.Errorf("%d is greater than maximum value for int8", n.Int64) } - if n.Int < math.MinInt64 { - return nil, fmt.Errorf("%d is less than minimum value for int8", n.Int) + if n.Int64 < math.MinInt64 { + return nil, fmt.Errorf("%d is less than minimum value for int8", n.Int64) } - return pgio.AppendInt64(buf, int64(n.Int)), nil + return pgio.AppendInt64(buf, int64(n.Int64)), nil } type encodePlanInt8CodecTextInt64Valuer struct{} @@ -1281,14 +1281,14 @@ func (encodePlanInt8CodecTextInt64Valuer) Encode(value interface{}, buf []byte) return nil, nil } - if n.Int > math.MaxInt64 { - return nil, fmt.Errorf("%d is greater than maximum value for int8", n.Int) + if n.Int64 > math.MaxInt64 { + return nil, fmt.Errorf("%d is greater than maximum value for int8", n.Int64) } - if n.Int < math.MinInt64 { - return nil, fmt.Errorf("%d is less than minimum value for int8", n.Int) + if n.Int64 < math.MinInt64 { + return nil, fmt.Errorf("%d is less than minimum value for int8", n.Int64) } - return append(buf, strconv.FormatInt(n.Int, 10)...), nil + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil } func (Int8Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { @@ -1672,7 +1672,7 @@ func (scanPlanBinaryInt8ToInt64Scanner) Scan(src []byte, dst interface{}) error n := int64(int64(binary.BigEndian.Uint64(src))) - return s.ScanInt64(Int8{Int: n, Valid: true}) + return s.ScanInt64(Int8{Int64: n, Valid: true}) } type scanPlanTextAnyToInt8 struct{} @@ -1902,7 +1902,7 @@ func (scanPlanTextAnyToInt64Scanner) Scan(src []byte, dst interface{}) error { return err } - err = s.ScanInt64(Int8{Int: n, Valid: true}) + err = s.ScanInt64(Int8{Int64: n, Valid: true}) if err != nil { return err } diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 29846744..81f28bba 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -22,7 +22,7 @@ type Int64Valuer interface { <% [2, 4, 8].each do |pg_byte_size| %> <% pg_bit_size = pg_byte_size * 8 %> type Int<%= pg_byte_size %> struct { - Int int<%= pg_bit_size %> + Int<%= pg_bit_size %> int<%= pg_bit_size %> Valid bool } @@ -33,19 +33,19 @@ func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error { return nil } - if n.Int < math.MinInt<%= pg_bit_size %> { - return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int) + if n.Int64 < math.MinInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64) } - if n.Int > math.MaxInt<%= pg_bit_size %> { - return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int) + if n.Int64 > math.MaxInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64) } - *dst = Int<%= pg_byte_size %>{Int: int<%= pg_bit_size %>(n.Int), Valid: true} + *dst = Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: int<%= pg_bit_size %>(n.Int64), Valid: true} return nil } func (n Int<%= pg_byte_size %>) Int64Value() (Int8, error) { - return Int8{Int: int64(n.Int), Valid: n.Valid}, nil + return Int8{Int64: int64(n.Int<%= pg_bit_size %>), Valid: n.Valid}, nil } // Scan implements the database/sql Scanner interface. @@ -82,7 +82,7 @@ func (dst *Int<%= pg_byte_size %>) Scan(src interface{}) error { if n > math.MaxInt<%= pg_bit_size %> { return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) } - *dst = Int<%= pg_byte_size %>{Int: int<%= pg_bit_size %>(n), Valid: true} + *dst = Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: int<%= pg_bit_size %>(n), Valid: true} return nil } @@ -92,14 +92,14 @@ func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { if !src.Valid { return nil, nil } - return int64(src.Int), nil + return int64(src.Int<%= pg_bit_size %>), nil } func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil } - return []byte(strconv.FormatInt(int64(src.Int), 10)), nil + return []byte(strconv.FormatInt(int64(src.Int<%= pg_bit_size %>), 10)), nil } func (dst *Int<%= pg_byte_size %>) UnmarshalJSON(b []byte) error { @@ -112,7 +112,7 @@ func (dst *Int<%= pg_byte_size %>) UnmarshalJSON(b []byte) error { if n == nil { *dst = Int<%= pg_byte_size %>{} } else { - *dst = Int<%= pg_byte_size %>{Int: *n, Valid: true} + *dst = Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: *n, Valid: true} } return nil @@ -175,14 +175,14 @@ func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer) Encode(value inter return nil, nil } - if n.Int > math.MaxInt<%= pg_bit_size %> { - return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n.Int) + if n.Int64 > math.MaxInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n.Int64) } - if n.Int < math.MinInt<%= pg_bit_size %> { - return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n.Int) + if n.Int64 < math.MinInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n.Int64) } - return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n.Int)), nil + return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n.Int64)), nil } type encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer struct{} @@ -197,14 +197,14 @@ func (encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer) Encode(value interfa return nil, nil } - if n.Int > math.MaxInt<%= pg_bit_size %> { - return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n.Int) + if n.Int64 > math.MaxInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n.Int64) } - if n.Int < math.MinInt<%= pg_bit_size %> { - return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n.Int) + if n.Int64 < math.MinInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n.Int64) } - return append(buf, strconv.FormatInt(n.Int, 10)...), nil + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil } func (Int<%= pg_byte_size %>Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { @@ -441,7 +441,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(src []byte, dst i n := int64(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) - return s.ScanInt64(Int8{Int: n, Valid: true}) + return s.ScanInt64(Int8{Int64: n, Valid: true}) } <% end %> @@ -513,7 +513,7 @@ func (scanPlanTextAnyToInt64Scanner) Scan(src []byte, dst interface{}) error { return err } - err = s.ScanInt64(Int8{Int: n, Valid: true}) + err = s.ScanInt64(Int8{Int64: n, Valid: true}) if err != nil { return err } diff --git a/pgtype/int_test.go b/pgtype/int_test.go index a2e64f4e..6dc65259 100644 --- a/pgtype/int_test.go +++ b/pgtype/int_test.go @@ -21,8 +21,8 @@ func TestInt2Codec(t *testing.T) { {uint64(1), new(int16), isExpectedEq(int16(1))}, {int(1), new(int16), isExpectedEq(int16(1))}, {uint(1), new(int16), isExpectedEq(int16(1))}, - {pgtype.Int2{Int: 1, Valid: true}, new(int16), isExpectedEq(int16(1))}, - {int32(-1), new(pgtype.Int2), isExpectedEq(pgtype.Int2{Int: -1, Valid: true})}, + {pgtype.Int2{Int16: 1, Valid: true}, new(int16), isExpectedEq(int16(1))}, + {int32(-1), new(pgtype.Int2), isExpectedEq(pgtype.Int2{Int16: -1, Valid: true})}, {1, new(int8), isExpectedEq(int8(1))}, {1, new(int16), isExpectedEq(int16(1))}, {1, new(int32), isExpectedEq(int32(1))}, @@ -43,7 +43,7 @@ func TestInt2Codec(t *testing.T) { {0, new(int16), isExpectedEq(int16(0))}, {1, new(int16), isExpectedEq(int16(1))}, {math.MaxInt16, new(int16), isExpectedEq(int16(math.MaxInt16))}, - {1, new(pgtype.Int2), isExpectedEq(pgtype.Int2{Int: 1, Valid: true})}, + {1, new(pgtype.Int2), isExpectedEq(pgtype.Int2{Int16: 1, Valid: true})}, {pgtype.Int2{}, new(pgtype.Int2), isExpectedEq(pgtype.Int2{})}, {nil, new(*int16), isExpectedEq((*int16)(nil))}, }) @@ -54,8 +54,8 @@ func TestInt2MarshalJSON(t *testing.T) { source pgtype.Int2 result string }{ - {source: pgtype.Int2{Int: 0}, result: "null"}, - {source: pgtype.Int2{Int: 1, Valid: true}, result: "1"}, + {source: pgtype.Int2{Int16: 0}, result: "null"}, + {source: pgtype.Int2{Int16: 1, Valid: true}, result: "1"}, } for i, tt := range successfulTests { r, err := tt.source.MarshalJSON() @@ -74,8 +74,8 @@ func TestInt2UnmarshalJSON(t *testing.T) { source string result pgtype.Int2 }{ - {source: "null", result: pgtype.Int2{Int: 0}}, - {source: "1", result: pgtype.Int2{Int: 1, Valid: true}}, + {source: "null", result: pgtype.Int2{Int16: 0}}, + {source: "1", result: pgtype.Int2{Int16: 1, Valid: true}}, } for i, tt := range successfulTests { var r pgtype.Int2 @@ -102,8 +102,8 @@ func TestInt4Codec(t *testing.T) { {uint64(1), new(int32), isExpectedEq(int32(1))}, {int(1), new(int32), isExpectedEq(int32(1))}, {uint(1), new(int32), isExpectedEq(int32(1))}, - {pgtype.Int4{Int: 1, Valid: true}, new(int32), isExpectedEq(int32(1))}, - {int32(-1), new(pgtype.Int4), isExpectedEq(pgtype.Int4{Int: -1, Valid: true})}, + {pgtype.Int4{Int32: 1, Valid: true}, new(int32), isExpectedEq(int32(1))}, + {int32(-1), new(pgtype.Int4), isExpectedEq(pgtype.Int4{Int32: -1, Valid: true})}, {1, new(int8), isExpectedEq(int8(1))}, {1, new(int16), isExpectedEq(int16(1))}, {1, new(int32), isExpectedEq(int32(1))}, @@ -124,7 +124,7 @@ func TestInt4Codec(t *testing.T) { {0, new(int32), isExpectedEq(int32(0))}, {1, new(int32), isExpectedEq(int32(1))}, {math.MaxInt32, new(int32), isExpectedEq(int32(math.MaxInt32))}, - {1, new(pgtype.Int4), isExpectedEq(pgtype.Int4{Int: 1, Valid: true})}, + {1, new(pgtype.Int4), isExpectedEq(pgtype.Int4{Int32: 1, Valid: true})}, {pgtype.Int4{}, new(pgtype.Int4), isExpectedEq(pgtype.Int4{})}, {nil, new(*int32), isExpectedEq((*int32)(nil))}, }) @@ -135,8 +135,8 @@ func TestInt4MarshalJSON(t *testing.T) { source pgtype.Int4 result string }{ - {source: pgtype.Int4{Int: 0}, result: "null"}, - {source: pgtype.Int4{Int: 1, Valid: true}, result: "1"}, + {source: pgtype.Int4{Int32: 0}, result: "null"}, + {source: pgtype.Int4{Int32: 1, Valid: true}, result: "1"}, } for i, tt := range successfulTests { r, err := tt.source.MarshalJSON() @@ -155,8 +155,8 @@ func TestInt4UnmarshalJSON(t *testing.T) { source string result pgtype.Int4 }{ - {source: "null", result: pgtype.Int4{Int: 0}}, - {source: "1", result: pgtype.Int4{Int: 1, Valid: true}}, + {source: "null", result: pgtype.Int4{Int32: 0}}, + {source: "1", result: pgtype.Int4{Int32: 1, Valid: true}}, } for i, tt := range successfulTests { var r pgtype.Int4 @@ -183,8 +183,8 @@ func TestInt8Codec(t *testing.T) { {uint64(1), new(int64), isExpectedEq(int64(1))}, {int(1), new(int64), isExpectedEq(int64(1))}, {uint(1), new(int64), isExpectedEq(int64(1))}, - {pgtype.Int8{Int: 1, Valid: true}, new(int64), isExpectedEq(int64(1))}, - {int32(-1), new(pgtype.Int8), isExpectedEq(pgtype.Int8{Int: -1, Valid: true})}, + {pgtype.Int8{Int64: 1, Valid: true}, new(int64), isExpectedEq(int64(1))}, + {int32(-1), new(pgtype.Int8), isExpectedEq(pgtype.Int8{Int64: -1, Valid: true})}, {1, new(int8), isExpectedEq(int8(1))}, {1, new(int16), isExpectedEq(int16(1))}, {1, new(int32), isExpectedEq(int32(1))}, @@ -205,7 +205,7 @@ func TestInt8Codec(t *testing.T) { {0, new(int64), isExpectedEq(int64(0))}, {1, new(int64), isExpectedEq(int64(1))}, {math.MaxInt64, new(int64), isExpectedEq(int64(math.MaxInt64))}, - {1, new(pgtype.Int8), isExpectedEq(pgtype.Int8{Int: 1, Valid: true})}, + {1, new(pgtype.Int8), isExpectedEq(pgtype.Int8{Int64: 1, Valid: true})}, {pgtype.Int8{}, new(pgtype.Int8), isExpectedEq(pgtype.Int8{})}, {nil, new(*int64), isExpectedEq((*int64)(nil))}, }) @@ -216,8 +216,8 @@ func TestInt8MarshalJSON(t *testing.T) { source pgtype.Int8 result string }{ - {source: pgtype.Int8{Int: 0}, result: "null"}, - {source: pgtype.Int8{Int: 1, Valid: true}, result: "1"}, + {source: pgtype.Int8{Int64: 0}, result: "null"}, + {source: pgtype.Int8{Int64: 1, Valid: true}, result: "1"}, } for i, tt := range successfulTests { r, err := tt.source.MarshalJSON() @@ -236,8 +236,8 @@ func TestInt8UnmarshalJSON(t *testing.T) { source string result pgtype.Int8 }{ - {source: "null", result: pgtype.Int8{Int: 0}}, - {source: "1", result: pgtype.Int8{Int: 1, Valid: true}}, + {source: "null", result: pgtype.Int8{Int64: 0}}, + {source: "1", result: pgtype.Int8{Int64: 1, Valid: true}}, } for i, tt := range successfulTests { var r pgtype.Int8 diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 58707a02..da805ad8 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -324,7 +324,7 @@ func (encodePlanNumericCodecBinaryInt64Valuer) Encode(value interface{}, buf []b return nil, nil } - return encodeNumericBinary(Numeric{Int: big.NewInt(n.Int), Valid: true}, buf) + return encodeNumericBinary(Numeric{Int: big.NewInt(n.Int64), Valid: true}, buf) } func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) { @@ -476,7 +476,7 @@ func (encodePlanNumericCodecTextInt64Valuer) Encode(value interface{}, buf []byt return nil, nil } - return encodeNumericText(Numeric{Int: big.NewInt(n.Int), Valid: true}, buf) + return encodeNumericText(Numeric{Int: big.NewInt(n.Int64), Valid: true}, buf) } func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { @@ -682,7 +682,7 @@ func (scanPlanBinaryNumericToInt64Scanner) Scan(src []byte, dst interface{}) err return fmt.Errorf("%v is out of range for int64", bigInt) } - return scanner.ScanInt64(Int8{Int: bigInt.Int64(), Valid: true}) + return scanner.ScanInt64(Int8{Int64: bigInt.Int64(), Valid: true}) } type scanPlanTextAnyToNumericScanner struct{} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index bbec30f3..ff19790a 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -205,7 +205,7 @@ func BenchmarkTypeMapScanInt4IntoBinaryDecoder(b *testing.B) { if err != nil { b.Fatal(err) } - if v != (pgtype.Int4{Int: 42, Valid: true}) { + if v != (pgtype.Int4{Int32: 42, Valid: true}) { b.Fatal("scan failed due to bad value") } } @@ -241,7 +241,7 @@ func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) { if err != nil { b.Fatal(err) } - if v != (pgtype.Int4{Int: 42, Valid: true}) { + if v != (pgtype.Int4{Int32: 42, Valid: true}) { b.Fatal("scan failed due to bad value") } } diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index 6597ab98..1ed3d552 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -19,15 +19,15 @@ func TestRangeCodecTranscode(t *testing.T) { { pgtype.Int4range{ LowerType: pgtype.Inclusive, - Lower: pgtype.Int4{Int: 1, Valid: true}, - Upper: pgtype.Int4{Int: 5, Valid: true}, + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, UpperType: pgtype.Exclusive, Valid: true, }, new(pgtype.Int4range), isExpectedEq(pgtype.Int4range{ LowerType: pgtype.Inclusive, - Lower: pgtype.Int4{Int: 1, Valid: true}, - Upper: pgtype.Int4{Int: 5, Valid: true}, + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, UpperType: pgtype.Exclusive, Valid: true, }), }, @@ -75,8 +75,8 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { require.Equal( t, pgtype.Int4range{ - Lower: pgtype.Int4{Int: 1, Valid: true}, - Upper: pgtype.Int4{Int: 5, Valid: true}, + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true, @@ -90,7 +90,7 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { require.Equal( t, pgtype.Int4range{ - Lower: pgtype.Int4{Int: 1, Valid: true}, + Lower: pgtype.Int4{Int32: 1, Valid: true}, Upper: pgtype.Int4{}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, diff --git a/pgtype/uint32.go b/pgtype/uint32.go index 297ca5c2..344549ae 100644 --- a/pgtype/uint32.go +++ b/pgtype/uint32.go @@ -142,14 +142,14 @@ func (encodePlanUint32CodecBinaryInt64Valuer) Encode(value interface{}, buf []by return nil, nil } - if v.Int < 0 { - return nil, fmt.Errorf("%d is less than minimum value for uint32", v.Int) + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint32", v.Int64) } - if v.Int > math.MaxUint32 { - return nil, fmt.Errorf("%d is greater than maximum value for uint32", v.Int) + if v.Int64 > math.MaxUint32 { + return nil, fmt.Errorf("%d is greater than maximum value for uint32", v.Int64) } - return pgio.AppendUint32(buf, uint32(v.Int)), nil + return pgio.AppendUint32(buf, uint32(v.Int64)), nil } type encodePlanUint32CodecTextUint32 struct{} @@ -186,14 +186,14 @@ func (encodePlanUint32CodecTextInt64Valuer) Encode(value interface{}, buf []byte return nil, nil } - if v.Int < 0 { - return nil, fmt.Errorf("%d is less than minimum value for uint32", v.Int) + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint32", v.Int64) } - if v.Int > math.MaxUint32 { - return nil, fmt.Errorf("%d is greater than maximum value for uint32", v.Int) + if v.Int64 > math.MaxUint32 { + return nil, fmt.Errorf("%d is greater than maximum value for uint32", v.Int64) } - return append(buf, strconv.FormatInt(v.Int, 10)...), nil + return append(buf, strconv.FormatInt(v.Int64, 10)...), nil } func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { diff --git a/pgtype/zeronull/int.go b/pgtype/zeronull/int.go index 1d479307..9a40691f 100644 --- a/pgtype/zeronull/int.go +++ b/pgtype/zeronull/int.go @@ -44,7 +44,7 @@ func (dst *Int2) Scan(src interface{}) error { return err } - *dst = Int2(nullable.Int) + *dst = Int2(nullable.Int16) return nil } @@ -92,7 +92,7 @@ func (dst *Int4) Scan(src interface{}) error { return err } - *dst = Int4(nullable.Int) + *dst = Int4(nullable.Int32) return nil } @@ -140,7 +140,7 @@ func (dst *Int8) Scan(src interface{}) error { return err } - *dst = Int8(nullable.Int) + *dst = Int8(nullable.Int64) return nil } diff --git a/pgtype/zeronull/int.go.erb b/pgtype/zeronull/int.go.erb index 9e3b5ef0..cdae7597 100644 --- a/pgtype/zeronull/int.go.erb +++ b/pgtype/zeronull/int.go.erb @@ -45,7 +45,7 @@ func (dst *Int<%= pg_byte_size %>) Scan(src interface{}) error { return err } - *dst = Int<%= pg_byte_size %>(nullable.Int) + *dst = Int<%= pg_byte_size %>(nullable.Int<%= pg_bit_size %>) return nil } From 84a3d913228c767285fde522e8ed00985629990e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 09:20:03 -0600 Subject: [PATCH 0937/1158] pgtype Float4 and Float8 fields include bit size e.g. Instead of Float it is Float64. This matches the pattern set by the database/sql types. --- pgtype/builtin_wrappers.go | 8 ++++---- pgtype/float4.go | 22 +++++++++++----------- pgtype/float4_test.go | 6 +++--- pgtype/float8.go | 22 +++++++++++----------- pgtype/float8_test.go | 6 +++--- pgtype/numeric.go | 24 ++++++++++++------------ pgtype/numeric_test.go | 12 ++++++------ pgtype/range_codec_test.go | 8 ++++---- pgtype/zeronull/float8.go | 6 +++--- 9 files changed, 57 insertions(+), 57 deletions(-) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 3706c994..f182f570 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -279,13 +279,13 @@ func (w *float32Wrapper) ScanFloat64(v Float8) error { return fmt.Errorf("cannot scan NULL into *float32") } - *w = float32Wrapper(v.Float) + *w = float32Wrapper(v.Float64) return nil } func (w float32Wrapper) Float64Value() (Float8, error) { - return Float8{Float: float64(w), Valid: true}, nil + return Float8{Float64: float64(w), Valid: true}, nil } type float64Wrapper float64 @@ -315,13 +315,13 @@ func (w *float64Wrapper) ScanFloat64(v Float8) error { return fmt.Errorf("cannot scan NULL into *float64") } - *w = float64Wrapper(v.Float) + *w = float64Wrapper(v.Float64) return nil } func (w float64Wrapper) Float64Value() (Float8, error) { - return Float8{Float: float64(w), Valid: true}, nil + return Float8{Float64: float64(w), Valid: true}, nil } type stringWrapper string diff --git a/pgtype/float4.go b/pgtype/float4.go index e1f1fbf1..fb84c124 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -11,27 +11,27 @@ import ( ) type Float4 struct { - Float float32 - Valid bool + Float32 float32 + Valid bool } // ScanFloat64 implements the Float64Scanner interface. func (f *Float4) ScanFloat64(n Float8) error { - *f = Float4{Float: float32(n.Float), Valid: n.Valid} + *f = Float4{Float32: float32(n.Float64), Valid: n.Valid} return nil } func (f Float4) Float64Value() (Float8, error) { - return Float8{Float: float64(f.Float), Valid: f.Valid}, nil + return Float8{Float64: float64(f.Float32), Valid: f.Valid}, nil } func (f *Float4) ScanInt64(n Int8) error { - *f = Float4{Float: float32(n.Int64), Valid: n.Valid} + *f = Float4{Float32: float32(n.Int64), Valid: n.Valid} return nil } func (f Float4) Int64Value() (Int8, error) { - return Int8{Int64: int64(f.Float), Valid: f.Valid}, nil + return Int8{Int64: int64(f.Float32), Valid: f.Valid}, nil } // Scan implements the database/sql Scanner interface. @@ -43,14 +43,14 @@ func (f *Float4) Scan(src interface{}) error { switch src := src.(type) { case float64: - *f = Float4{Float: float32(src), Valid: true} + *f = Float4{Float32: float32(src), Valid: true} return nil case string: n, err := strconv.ParseFloat(string(src), 32) if err != nil { return err } - *f = Float4{Float: float32(n), Valid: true} + *f = Float4{Float32: float32(n), Valid: true} return nil } @@ -62,7 +62,7 @@ func (f Float4) Value() (driver.Value, error) { if !f.Valid { return nil, nil } - return float64(f.Float), nil + return float64(f.Float32), nil } type Float4Codec struct{} @@ -126,7 +126,7 @@ func (encodePlanFloat4CodecBinaryFloat64Valuer) Encode(value interface{}, buf [] return nil, nil } - return pgio.AppendUint32(buf, math.Float32bits(float32(n.Float))), nil + return pgio.AppendUint32(buf, math.Float32bits(float32(n.Float64))), nil } type encodePlanFloat4CodecBinaryInt64Valuer struct{} @@ -203,7 +203,7 @@ func (scanPlanBinaryFloat4ToFloat64Scanner) Scan(src []byte, dst interface{}) er } n := int32(binary.BigEndian.Uint32(src)) - return s.ScanFloat64(Float8{Float: float64(math.Float32frombits(uint32(n))), Valid: true}) + return s.ScanFloat64(Float8{Float64: float64(math.Float32frombits(uint32(n))), Valid: true}) } type scanPlanBinaryFloat4ToInt64Scanner struct{} diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go index a0069836..39d7ee75 100644 --- a/pgtype/float4_test.go +++ b/pgtype/float4_test.go @@ -9,9 +9,9 @@ import ( func TestFloat4Codec(t *testing.T) { testutil.RunTranscodeTests(t, "float4", []testutil.TranscodeTestCase{ - {pgtype.Float4{Float: -1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float: -1, Valid: true})}, - {pgtype.Float4{Float: 0, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float: 0, Valid: true})}, - {pgtype.Float4{Float: 1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float: 1, Valid: true})}, + {pgtype.Float4{Float32: -1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: -1, Valid: true})}, + {pgtype.Float4{Float32: 0, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: 0, Valid: true})}, + {pgtype.Float4{Float32: 1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: 1, Valid: true})}, {float32(0.00001), new(float32), isExpectedEq(float32(0.00001))}, {float32(9999.99), new(float32), isExpectedEq(float32(9999.99))}, {pgtype.Float4{}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{})}, diff --git a/pgtype/float8.go b/pgtype/float8.go index 1de86a1d..664fb9f8 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -19,8 +19,8 @@ type Float64Valuer interface { } type Float8 struct { - Float float64 - Valid bool + Float64 float64 + Valid bool } // ScanFloat64 implements the Float64Scanner interface. @@ -34,12 +34,12 @@ func (f Float8) Float64Value() (Float8, error) { } func (f *Float8) ScanInt64(n Int8) error { - *f = Float8{Float: float64(n.Int64), Valid: n.Valid} + *f = Float8{Float64: float64(n.Int64), Valid: n.Valid} return nil } func (f Float8) Int64Value() (Int8, error) { - return Int8{Int64: int64(f.Float), Valid: f.Valid}, nil + return Int8{Int64: int64(f.Float64), Valid: f.Valid}, nil } // Scan implements the database/sql Scanner interface. @@ -51,14 +51,14 @@ func (f *Float8) Scan(src interface{}) error { switch src := src.(type) { case float64: - *f = Float8{Float: src, Valid: true} + *f = Float8{Float64: src, Valid: true} return nil case string: n, err := strconv.ParseFloat(string(src), 64) if err != nil { return err } - *f = Float8{Float: n, Valid: true} + *f = Float8{Float64: n, Valid: true} return nil } @@ -70,7 +70,7 @@ func (f Float8) Value() (driver.Value, error) { if !f.Valid { return nil, nil } - return f.Float, nil + return f.Float64, nil } type Float8Codec struct{} @@ -134,7 +134,7 @@ func (encodePlanFloat8CodecBinaryFloat64Valuer) Encode(value interface{}, buf [] return nil, nil } - return pgio.AppendUint64(buf, math.Float64bits(n.Float)), nil + return pgio.AppendUint64(buf, math.Float64bits(n.Float64)), nil } type encodePlanTextFloat64Valuer struct{} @@ -149,7 +149,7 @@ func (encodePlanTextFloat64Valuer) Encode(value interface{}, buf []byte) (newBuf return nil, nil } - return append(buf, strconv.FormatFloat(n.Float, 'f', -1, 64)...), nil + return append(buf, strconv.FormatFloat(n.Float64, 'f', -1, 64)...), nil } type encodePlanFloat8CodecBinaryInt64Valuer struct{} @@ -241,7 +241,7 @@ func (scanPlanBinaryFloat8ToFloat64Scanner) Scan(src []byte, dst interface{}) er } n := int64(binary.BigEndian.Uint64(src)) - return s.ScanFloat64(Float8{Float: math.Float64frombits(uint64(n)), Valid: true}) + return s.ScanFloat64(Float8{Float64: math.Float64frombits(uint64(n)), Valid: true}) } type scanPlanBinaryFloat8ToInt64Scanner struct{} @@ -299,7 +299,7 @@ func (scanPlanTextAnyToFloat64Scanner) Scan(src []byte, dst interface{}) error { return err } - return s.ScanFloat64(Float8{Float: n, Valid: true}) + return s.ScanFloat64(Float8{Float64: n, Valid: true}) } func (c Float8Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go index e69174eb..29bd6f31 100644 --- a/pgtype/float8_test.go +++ b/pgtype/float8_test.go @@ -9,9 +9,9 @@ import ( func TestFloat8Codec(t *testing.T) { testutil.RunTranscodeTests(t, "float8", []testutil.TranscodeTestCase{ - {pgtype.Float8{Float: -1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float: -1, Valid: true})}, - {pgtype.Float8{Float: 0, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float: 0, Valid: true})}, - {pgtype.Float8{Float: 1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float: 1, Valid: true})}, + {pgtype.Float8{Float64: -1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: -1, Valid: true})}, + {pgtype.Float8{Float64: 0, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: 0, Valid: true})}, + {pgtype.Float8{Float64: 1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: 1, Valid: true})}, {float64(0.00001), new(float64), isExpectedEq(float64(0.00001))}, {float64(9999.99), new(float64), isExpectedEq(float64(9999.99))}, {pgtype.Float8{}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{})}, diff --git a/pgtype/numeric.go b/pgtype/numeric.go index da805ad8..426b6782 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -84,11 +84,11 @@ func (n Numeric) Float64Value() (Float8, error) { if !n.Valid { return Float8{}, nil } else if n.NaN { - return Float8{Float: math.NaN(), Valid: true}, nil + return Float8{Float64: math.NaN(), Valid: true}, nil } else if n.InfinityModifier == Infinity { - return Float8{Float: math.Inf(1), Valid: true}, nil + return Float8{Float64: math.Inf(1), Valid: true}, nil } else if n.InfinityModifier == NegativeInfinity { - return Float8{Float: math.Inf(-1), Valid: true}, nil + return Float8{Float64: math.Inf(-1), Valid: true}, nil } buf := make([]byte, 0, 32) @@ -106,7 +106,7 @@ func (n Numeric) Float64Value() (Float8, error) { return Float8{}, err } - return Float8{Float: f, Valid: true}, nil + return Float8{Float64: f, Valid: true}, nil } func (n *Numeric) toBigInt() (*big.Int, error) { @@ -297,14 +297,14 @@ func (encodePlanNumericCodecBinaryFloat64Valuer) Encode(value interface{}, buf [ return nil, nil } - if math.IsNaN(n.Float) { + if math.IsNaN(n.Float64) { return encodeNumericBinary(Numeric{NaN: true, Valid: true}, buf) - } else if math.IsInf(n.Float, 1) { + } else if math.IsInf(n.Float64, 1) { return encodeNumericBinary(Numeric{InfinityModifier: Infinity, Valid: true}, buf) - } else if math.IsInf(n.Float, -1) { + } else if math.IsInf(n.Float64, -1) { return encodeNumericBinary(Numeric{InfinityModifier: NegativeInfinity, Valid: true}, buf) } - num, exp, err := parseNumericString(strconv.FormatFloat(n.Float, 'f', -1, 64)) + num, exp, err := parseNumericString(strconv.FormatFloat(n.Float64, 'f', -1, 64)) if err != nil { return nil, err } @@ -449,14 +449,14 @@ func (encodePlanNumericCodecTextFloat64Valuer) Encode(value interface{}, buf []b return nil, nil } - if math.IsNaN(n.Float) { + if math.IsNaN(n.Float64) { return encodeNumericBinary(Numeric{NaN: true, Valid: true}, buf) - } else if math.IsInf(n.Float, 1) { + } else if math.IsInf(n.Float64, 1) { return encodeNumericBinary(Numeric{InfinityModifier: Infinity, Valid: true}, buf) - } else if math.IsInf(n.Float, -1) { + } else if math.IsInf(n.Float64, -1) { return encodeNumericBinary(Numeric{InfinityModifier: NegativeInfinity, Valid: true}, buf) } - num, exp, err := parseNumericString(strconv.FormatFloat(n.Float, 'f', -1, 64)) + num, exp, err := parseNumericString(strconv.FormatFloat(n.Float64, 'f', -1, 64)) if err != nil { return nil, err } diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 91b77881..448cfff2 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -123,11 +123,11 @@ func TestNumericFloat64Valuer(t *testing.T) { n pgtype.Numeric f pgtype.Float8 }{ - {mustParseNumeric(t, "1"), pgtype.Float8{Float: 1, Valid: true}}, - {mustParseNumeric(t, "0.0000000000000000001"), pgtype.Float8{Float: 0.0000000000000000001, Valid: true}}, - {mustParseNumeric(t, "-99999999999"), pgtype.Float8{Float: -99999999999, Valid: true}}, - {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, pgtype.Float8{Float: math.Inf(1), Valid: true}}, - {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, pgtype.Float8{Float: math.Inf(-1), Valid: true}}, + {mustParseNumeric(t, "1"), pgtype.Float8{Float64: 1, Valid: true}}, + {mustParseNumeric(t, "0.0000000000000000001"), pgtype.Float8{Float64: 0.0000000000000000001, Valid: true}}, + {mustParseNumeric(t, "-99999999999"), pgtype.Float8{Float64: -99999999999, Valid: true}}, + {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, pgtype.Float8{Float64: math.Inf(1), Valid: true}}, + {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, pgtype.Float8{Float64: math.Inf(-1), Valid: true}}, {pgtype.Numeric{Valid: true}, pgtype.Float8{Valid: true}}, {pgtype.Numeric{}, pgtype.Float8{}}, } { @@ -138,7 +138,7 @@ func TestNumericFloat64Valuer(t *testing.T) { f, err := pgtype.Numeric{NaN: true, Valid: true}.Float64Value() assert.NoError(t, err) - assert.True(t, math.IsNaN(f.Float)) + assert.True(t, math.IsNaN(f.Float64)) assert.True(t, f.Valid) } diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index 1ed3d552..84a55a52 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -46,15 +46,15 @@ func TestRangeCodecTranscodeCompatibleRangeElementTypes(t *testing.T) { { pgtype.Float8range{ LowerType: pgtype.Inclusive, - Lower: pgtype.Float8{Float: 1, Valid: true}, - Upper: pgtype.Float8{Float: 5, Valid: true}, + Lower: pgtype.Float8{Float64: 1, Valid: true}, + Upper: pgtype.Float8{Float64: 5, Valid: true}, UpperType: pgtype.Exclusive, Valid: true, }, new(pgtype.Float8range), isExpectedEq(pgtype.Float8range{ LowerType: pgtype.Inclusive, - Lower: pgtype.Float8{Float: 1, Valid: true}, - Upper: pgtype.Float8{Float: 5, Valid: true}, + Lower: pgtype.Float8{Float64: 1, Valid: true}, + Upper: pgtype.Float8{Float64: 5, Valid: true}, UpperType: pgtype.Exclusive, Valid: true, }), }, diff --git a/pgtype/zeronull/float8.go b/pgtype/zeronull/float8.go index d1c053c5..7f3a06b4 100644 --- a/pgtype/zeronull/float8.go +++ b/pgtype/zeronull/float8.go @@ -17,7 +17,7 @@ func (f *Float8) ScanFloat64(n pgtype.Float8) error { return nil } - *f = Float8(n.Float) + *f = Float8(n.Float64) return nil } @@ -26,7 +26,7 @@ func (f Float8) Float64Value() (pgtype.Float8, error) { if f == 0 { return pgtype.Float8{}, nil } - return pgtype.Float8{Float: float64(f), Valid: true}, nil + return pgtype.Float8{Float64: float64(f), Valid: true}, nil } // Scan implements the database/sql Scanner interface. @@ -42,7 +42,7 @@ func (f *Float8) Scan(src interface{}) error { return err } - *f = Float8(nullable.Float) + *f = Float8(nullable.Float64) return nil } From 2885b039d5951674753b82849872767768d493c5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 09:23:25 -0600 Subject: [PATCH 0938/1158] Rename Uint32 field to include bit size i.e. Uint renamed to Uint32. This matches the pattern set by the database/sql types. --- pgtype/uint32.go | 16 ++++++++-------- pgtype/uint32_test.go | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pgtype/uint32.go b/pgtype/uint32.go index 344549ae..25d2423a 100644 --- a/pgtype/uint32.go +++ b/pgtype/uint32.go @@ -20,8 +20,8 @@ type Uint32Valuer interface { // Uint32 is the core type that is used to represent PostgreSQL types such as OID, CID, and XID. type Uint32 struct { - Uint uint32 - Valid bool + Uint32 uint32 + Valid bool } func (n *Uint32) ScanUint32(v Uint32) error { @@ -62,7 +62,7 @@ func (dst *Uint32) Scan(src interface{}) error { return fmt.Errorf("%d is greater than maximum value for Uint32", n) } - *dst = Uint32{Uint: uint32(n), Valid: true} + *dst = Uint32{Uint32: uint32(n), Valid: true} return nil } @@ -72,7 +72,7 @@ func (src Uint32) Value() (driver.Value, error) { if !src.Valid { return nil, nil } - return int64(src.Uint), nil + return int64(src.Uint32), nil } type Uint32Codec struct{} @@ -127,7 +127,7 @@ func (encodePlanUint32CodecBinaryUint32Valuer) Encode(value interface{}, buf []b return nil, nil } - return pgio.AppendUint32(buf, v.Uint), nil + return pgio.AppendUint32(buf, v.Uint32), nil } type encodePlanUint32CodecBinaryInt64Valuer struct{} @@ -171,7 +171,7 @@ func (encodePlanUint32CodecTextUint32Valuer) Encode(value interface{}, buf []byt return nil, nil } - return append(buf, strconv.FormatUint(uint64(v.Uint), 10)...), nil + return append(buf, strconv.FormatUint(uint64(v.Uint32), 10)...), nil } type encodePlanUint32CodecTextInt64Valuer struct{} @@ -279,7 +279,7 @@ func (scanPlanBinaryUint32ToUint32Scanner) Scan(src []byte, dst interface{}) err n := binary.BigEndian.Uint32(src) - return s.ScanUint32(Uint32{Uint: n, Valid: true}) + return s.ScanUint32(Uint32{Uint32: n, Valid: true}) } type scanPlanTextAnyToUint32Scanner struct{} @@ -299,5 +299,5 @@ func (scanPlanTextAnyToUint32Scanner) Scan(src []byte, dst interface{}) error { return err } - return s.ScanUint32(Uint32{Uint: uint32(n), Valid: true}) + return s.ScanUint32(Uint32{Uint32: uint32(n), Valid: true}) } diff --git a/pgtype/uint32_test.go b/pgtype/uint32_test.go index 98adbee4..d6699a03 100644 --- a/pgtype/uint32_test.go +++ b/pgtype/uint32_test.go @@ -10,9 +10,9 @@ import ( func TestUint32Codec(t *testing.T) { testutil.RunTranscodeTests(t, "oid", []testutil.TranscodeTestCase{ { - pgtype.Uint32{Uint: pgtype.TextOID, Valid: true}, + pgtype.Uint32{Uint32: pgtype.TextOID, Valid: true}, new(pgtype.Uint32), - isExpectedEq(pgtype.Uint32{Uint: pgtype.TextOID, Valid: true}), + isExpectedEq(pgtype.Uint32{Uint32: pgtype.TextOID, Valid: true}), }, {pgtype.Uint32{}, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})}, {nil, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})}, From aad3d65e16f9993657a8037b5d64d64445bffbe1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 10:27:15 -0600 Subject: [PATCH 0939/1158] Initial restructure of simple protocol to query exec mode --- batch_test.go | 2 +- conn.go | 91 ++++++++++++++++++++++++++++++------------ conn_test.go | 23 ++++++----- helper_test.go | 4 +- large_objects_test.go | 4 +- pgbouncer_test.go | 2 +- pgxpool/common_test.go | 2 +- query_test.go | 44 ++++++++++---------- stdlib/sql_test.go | 2 +- 9 files changed, 107 insertions(+), 67 deletions(-) diff --git a/batch_test.go b/batch_test.go index 32901830..c2e944a1 100644 --- a/batch_test.go +++ b/batch_test.go @@ -803,7 +803,7 @@ func TestSendBatchSimpleProtocol(t *testing.T) { t.Parallel() config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.PreferSimpleProtocol = true + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() diff --git a/conn.go b/conn.go index ba0d9d00..177e21ff 100644 --- a/conn.go +++ b/conn.go @@ -29,13 +29,11 @@ type ConnConfig struct { // to nil to disable automatic prepared statements. BuildStatementCache BuildStatementCacheFunc - // PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended - // protocol. This can improve performance due to being able to use the binary format. It also does not rely on client - // side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement) - // and may be incompatible proxies such as PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be - // used by default. The same functionality can be controlled on a per query basis by setting - // QueryExOptions.SimpleProtocol. - PreferSimpleProtocol bool + // DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol + // and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as + // PGBouncer. In this case it may be preferrable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same + // functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument. + DefaultQueryExecMode QueryExecMode createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -125,8 +123,9 @@ func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { // server. "describe" is primarily useful when the environment does not allow prepared statements such as when // running a connection pooler like PgBouncer. Default: "prepare" // -// prefer_simple_protocol -// Possible values: "true" and "false". Use the simple protocol instead of extended protocol. Default: false +// default_query_exec_mode +// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See +// QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement". func ParseConfig(connString string) (*ConnConfig, error) { config, err := pgconn.ParseConfig(connString) if err != nil { @@ -163,13 +162,22 @@ func ParseConfig(connString string) (*ConnConfig, error) { } } - preferSimpleProtocol := false - if s, ok := config.RuntimeParams["prefer_simple_protocol"]; ok { - delete(config.RuntimeParams, "prefer_simple_protocol") - if b, err := strconv.ParseBool(s); err == nil { - preferSimpleProtocol = b - } else { - return nil, fmt.Errorf("invalid prefer_simple_protocol: %v", err) + defaultQueryExecMode := QueryExecModeCacheStatement + if s, ok := config.RuntimeParams["default_query_exec_mode"]; ok { + delete(config.RuntimeParams, "default_query_exec_mode") + switch s { + case "cache_statement": + defaultQueryExecMode = QueryExecModeCacheStatement + case "cache_describe": + defaultQueryExecMode = QueryExecModeCacheDescribe + case "describe_exec": + defaultQueryExecMode = QueryExecModeDescribeExec + case "exec": + defaultQueryExecMode = QueryExecModeExec + case "simple_protocol": + defaultQueryExecMode = QueryExecModeSimpleProtocol + default: + return nil, fmt.Errorf("invalid default_query_exec_mode: %v", err) } } @@ -178,7 +186,7 @@ func ParseConfig(connString string) (*ConnConfig, error) { createdByParseConfig: true, LogLevel: LogLevelInfo, BuildStatementCache: buildStatementCache, - PreferSimpleProtocol: preferSimpleProtocol, + DefaultQueryExecMode: defaultQueryExecMode, connString: connString, } @@ -403,13 +411,13 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) ( } func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { - simpleProtocol := c.config.PreferSimpleProtocol + simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol optionLoop: for len(arguments) > 0 { switch arg := arguments[0].(type) { - case QuerySimpleProtocol: - simpleProtocol = bool(arg) + case QueryExecMode: + simpleProtocol = arg == QueryExecModeSimpleProtocol arguments = arguments[1:] default: break optionLoop @@ -525,8 +533,39 @@ func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *con return r } -// QuerySimpleProtocol controls whether the simple or extended protocol is used to send the query. -type QuerySimpleProtocol bool +type QueryExecMode int32 + +const ( + _ QueryExecMode = iota + + // Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single + // round trip after the statement is cached. This is the default. + QueryExecModeCacheStatement + + // Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the + // extended protocol. Queries are executed in a single round trip after the description is cached. If the database + // schema is modified or the search_path is changed this may result in undetected result decoding errors. + QueryExecModeCacheDescribe + + // Get the statement description on every execution. This uses the extended protocol. Queries require two round trips + // to execute. It does not use prepared statements (allowing usage with most connection poolers) and is safe even + // when the the database schema is modified concurrently. + QueryExecModeDescribeExec + + // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended + // protocol. Queries are executed in a single round trip. Type mappings can be registered with + // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious. + // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use + // a map[string]string directly as an argument. This mode cannot. + QueryExecModeExec + + // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. + // Queries are executed in a single round trip. Type mappings can be registered with + // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious. + // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use + // a map[string]string directly as an argument. This mode cannot. + QueryExecModeSimpleProtocol +) // QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position. type QueryResultFormats []int16 @@ -547,7 +586,7 @@ type QueryResultFormatsByOID map[uint32]int16 func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { var resultFormats QueryResultFormats var resultFormatsByOID QueryResultFormatsByOID - simpleProtocol := c.config.PreferSimpleProtocol + simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol optionLoop: for len(args) > 0 { @@ -558,8 +597,8 @@ optionLoop: case QueryResultFormatsByOID: resultFormatsByOID = arg args = args[1:] - case QuerySimpleProtocol: - simpleProtocol = bool(arg) + case QueryExecMode: + simpleProtocol = arg == QueryExecModeSimpleProtocol args = args[1:] default: break optionLoop @@ -709,7 +748,7 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, sc // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { - simpleProtocol := c.config.PreferSimpleProtocol + simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol var sb strings.Builder if simpleProtocol { for i, bi := range b.items { diff --git a/conn_test.go b/conn_test.go index 3240c954..f5a4319f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -81,7 +81,7 @@ func TestConnectWithPreferSimpleProtocol(t *testing.T) { t.Parallel() connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - connConfig.PreferSimpleProtocol = true + connConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol conn := mustConnect(t, connConfig) defer closeConn(t, conn) @@ -164,23 +164,24 @@ func TestParseConfigExtractsStatementCacheOptions(t *testing.T) { require.Equal(t, stmtcache.ModeDescribe, c.Mode()) } -func TestParseConfigExtractsPreferSimpleProtocol(t *testing.T) { +func TestParseConfigExtractsDefaultQueryExecMode(t *testing.T) { t.Parallel() for _, tt := range []struct { connString string - preferSimpleProtocol bool + defaultQueryExecMode pgx.QueryExecMode }{ - {"", false}, - {"prefer_simple_protocol=false", false}, - {"prefer_simple_protocol=0", false}, - {"prefer_simple_protocol=true", true}, - {"prefer_simple_protocol=1", true}, + {"", pgx.QueryExecModeCacheStatement}, + {"default_query_exec_mode=cache_statement", pgx.QueryExecModeCacheStatement}, + {"default_query_exec_mode=cache_describe", pgx.QueryExecModeCacheDescribe}, + {"default_query_exec_mode=describe_exec", pgx.QueryExecModeDescribeExec}, + {"default_query_exec_mode=exec", pgx.QueryExecModeExec}, + {"default_query_exec_mode=simple_protocol", pgx.QueryExecModeSimpleProtocol}, } { config, err := pgx.ParseConfig(tt.connString) require.NoError(t, err) - require.Equalf(t, tt.preferSimpleProtocol, config.PreferSimpleProtocol, "connString: `%s`", tt.connString) - require.Empty(t, config.RuntimeParams["prefer_simple_protocol"]) + require.Equalf(t, tt.defaultQueryExecMode, config.DefaultQueryExecMode, "connString: `%s`", tt.connString) + require.Empty(t, config.RuntimeParams["default_query_exec_mode"]) } } @@ -384,7 +385,7 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) { commandTag, err = conn.Exec(ctx, "insert into foo(name) values($1);", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, "bar'; drop table foo;--", ) if err != nil { diff --git a/helper_test.go b/helper_test.go index 74c17431..22cc8872 100644 --- a/helper_test.go +++ b/helper_test.go @@ -18,7 +18,7 @@ func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, c config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) - config.PreferSimpleProtocol = true + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol conn, err := pgx.ConnectConfig(context.Background(), config) require.NoError(t, err) defer func() { @@ -130,7 +130,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) // Can't test function equality, so just test that they are set or not. assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) - assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName) + assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) diff --git a/large_objects_test.go b/large_objects_test.go index e42a90e7..f86f35e9 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -32,7 +32,7 @@ func TestLargeObjects(t *testing.T) { testLargeObjects(t, ctx, tx) } -func TestLargeObjectsPreferSimpleProtocol(t *testing.T) { +func TestLargeObjectsSimpleProtocol(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -43,7 +43,7 @@ func TestLargeObjectsPreferSimpleProtocol(t *testing.T) { t.Fatal(err) } - config.PreferSimpleProtocol = true + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol conn, err := pgx.ConnectConfig(ctx, config) if err != nil { diff --git a/pgbouncer_test.go b/pgbouncer_test.go index eeae6db4..e80861a0 100644 --- a/pgbouncer_test.go +++ b/pgbouncer_test.go @@ -34,7 +34,7 @@ func TestPgbouncerSimpleProtocol(t *testing.T) { config := mustParseConfig(t, connString) config.BuildStatementCache = nil - config.PreferSimpleProtocol = true + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol testPgbouncer(t, config, 10, 100) } diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index 7b9f9f29..93e1940d 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -168,7 +168,7 @@ func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, test // Can't test function equality, so just test that they are set or not. assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) - assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName) + assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) diff --git a/query_test.go b/query_test.go index 3728f8a3..a0b75313 100644 --- a/query_test.go +++ b/query_test.go @@ -291,7 +291,7 @@ func TestConnQueryRawValues(t *testing.T) { rows, err := conn.Query( context.Background(), "select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, 10, ) require.NoError(t, err) @@ -1385,7 +1385,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::int8", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1402,7 +1402,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::float8", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1419,7 +1419,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1436,7 +1436,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bytea", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1453,7 +1453,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::text", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1478,7 +1478,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::text[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1499,7 +1499,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::smallint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1520,7 +1520,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::int[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1541,7 +1541,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1562,7 +1562,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1583,7 +1583,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::smallint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1604,7 +1604,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1625,7 +1625,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1646,7 +1646,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::bigint[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1667,7 +1667,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::float4[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1688,7 +1688,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::float8[]", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) @@ -1706,7 +1706,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::circle", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, &expected, ).Scan(&actual) if err != nil { @@ -1734,7 +1734,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1::int8, $2::float8, $3, $4::bytea, $5::text", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString, ).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString) if err != nil { @@ -1765,7 +1765,7 @@ func TestConnSimpleProtocol(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1793,7 +1793,7 @@ func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, "test", ).Scan(&expected) if err == nil { @@ -1817,7 +1817,7 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { err := conn.QueryRow( context.Background(), "select $1", - pgx.QuerySimpleProtocol(true), + pgx.QueryExecModeSimpleProtocol, `\'; drop table users; --`, ).Scan(&expected) if err == nil { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 07498843..8695e4ad 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -80,7 +80,7 @@ func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, d config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) - config.PreferSimpleProtocol = true + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol db := stdlib.OpenDB(*config) defer func() { err := db.Close() From 0d8e109c212b2c1e037d56779ec1e5500229a368 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 13:57:48 -0600 Subject: [PATCH 0940/1158] Test every QueryExecMode --- conn.go | 17 +++++++ conn_test.go | 22 ++++----- helper_test.go | 54 +++++++++------------- query_test.go | 6 +-- stdlib/sql_test.go | 109 +++++++++++++++++++++------------------------ values_test.go | 30 ++++++------- 6 files changed, 118 insertions(+), 120 deletions(-) diff --git a/conn.go b/conn.go index 177e21ff..36994698 100644 --- a/conn.go +++ b/conn.go @@ -567,6 +567,23 @@ const ( QueryExecModeSimpleProtocol ) +func (m QueryExecMode) String() string { + switch m { + case QueryExecModeCacheStatement: + return "cache statement" + case QueryExecModeCacheDescribe: + return "cache describe" + case QueryExecModeDescribeExec: + return "describe exec" + case QueryExecModeExec: + return "exec" + case QueryExecModeSimpleProtocol: + return "simple protocol" + default: + return "invalid" + } +} + // QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position. type QueryResultFormats []int16 diff --git a/conn_test.go b/conn_test.go index f5a4319f..f4b3dd78 100644 --- a/conn_test.go +++ b/conn_test.go @@ -188,7 +188,7 @@ func TestParseConfigExtractsDefaultQueryExecMode(t *testing.T) { func TestExec(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } @@ -222,7 +222,7 @@ func TestExec(t *testing.T) { func TestExecFailure(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { if _, err := conn.Exec(context.Background(), "selct;"); err == nil { t.Fatal("Expected SQL syntax error") } @@ -238,7 +238,7 @@ func TestExecFailure(t *testing.T) { func TestExecFailureWithArguments(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { _, err := conn.Exec(context.Background(), "selct $1;", 1) if err == nil { t.Fatal("Expected SQL syntax error") @@ -253,7 +253,7 @@ func TestExecFailureWithArguments(t *testing.T) { func TestExecContextWithoutCancelation(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -271,7 +271,7 @@ func TestExecContextWithoutCancelation(t *testing.T) { func TestExecContextFailureWithoutCancelation(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -293,7 +293,7 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) { func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -720,7 +720,7 @@ func TestFatalTxError(t *testing.T) { func TestInsertBoolArray(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } @@ -735,7 +735,7 @@ func TestInsertBoolArray(t *testing.T) { func TestInsertTimestampArray(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } @@ -859,7 +859,7 @@ func TestConnInitTypeMap(t *testing.T) { } func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") var n uint64 @@ -875,7 +875,7 @@ func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) { } func TestDomainType(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") // Domain type uint64 is a PostgreSQL domain of underlying type numeric. @@ -1046,7 +1046,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) { } func TestInsertDurationInterval(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { _, err := conn.Exec(context.Background(), "create temporary table t(duration INTERVAL(0) NOT NULL)") require.NoError(t, err) diff --git a/helper_test.go b/helper_test.go index 22cc8872..c24d776b 100644 --- a/helper_test.go +++ b/helper_test.go @@ -12,43 +12,33 @@ import ( "github.com/stretchr/testify/require" ) -func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) { - t.Run("SimpleProto", - func(t *testing.T) { - config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol - conn, err := pgx.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer func() { - err := conn.Close(context.Background()) +func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) { + for _, mode := range []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + pgx.QueryExecModeExec, + pgx.QueryExecModeSimpleProtocol, + } { + t.Run(mode.String(), + func(t *testing.T) { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) - }() - f(t, conn) - - ensureConnValid(t, conn) - }, - ) - - t.Run("DefaultProto", - func(t *testing.T) { - config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - conn, err := pgx.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer func() { - err := conn.Close(context.Background()) + config.DefaultQueryExecMode = mode + conn, err := pgx.ConnectConfig(context.Background(), config) require.NoError(t, err) - }() + defer func() { + err := conn.Close(context.Background()) + require.NoError(t, err) + }() - f(t, conn) + f(t, conn) - ensureConnValid(t, conn) - }, - ) + ensureConnValid(t, conn) + }, + ) + } } func mustConnectString(t testing.TB, connString string) *pgx.Conn { diff --git a/query_test.go b/query_test.go index a0b75313..b6a0d65d 100644 --- a/query_test.go +++ b/query_test.go @@ -1912,7 +1912,7 @@ func TestQueryErrorWithNilStatementCacheMode(t *testing.T) { func TestConnQueryFunc(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { var actualResults []interface{} var a, b int @@ -1942,7 +1942,7 @@ func TestConnQueryFuncScanError(t *testing.T) { t.Skip("TODO - unskip later in v5") t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { var actualResults []interface{} var a, b int @@ -1964,7 +1964,7 @@ func TestConnQueryFuncScanError(t *testing.T) { func TestConnQueryFuncAbort(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { var a, b int ct, err := conn.QueryFunc( context.Background(), diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 8695e4ad..5fe03976 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -74,41 +74,32 @@ func skipPostgreSQLVersionLessThan(t testing.TB, db *sql.DB, minVersion int64) { require.NoError(t, err) } -func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, db *sql.DB)) { - t.Run("SimpleProto", - func(t *testing.T) { - config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol - db := stdlib.OpenDB(*config) - defer func() { - err := db.Close() +func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) { + for _, mode := range []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + pgx.QueryExecModeExec, + pgx.QueryExecModeSimpleProtocol, + } { + t.Run(mode.String(), + func(t *testing.T) { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) - }() - f(t, db) + config.DefaultQueryExecMode = mode + db := stdlib.OpenDB(*config) + defer func() { + err := db.Close() + require.NoError(t, err) + }() - ensureDBValid(t, db) - }, - ) + f(t, db) - t.Run("DefaultProto", - func(t *testing.T) { - config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - db := stdlib.OpenDB(*config) - defer func() { - err := db.Close() - require.NoError(t, err) - }() - - f(t, db) - - ensureDBValid(t, db) - }, - ) + ensureDBValid(t, db) + }, + ) + } } // Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should @@ -267,7 +258,7 @@ func TestQueryCloseRowsEarly(t *testing.T) { } func TestConnExec(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.Exec("create temporary table t(a varchar not null)") require.NoError(t, err) @@ -281,7 +272,7 @@ func TestConnExec(t *testing.T) { } func TestConnQuery(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10)) @@ -313,7 +304,7 @@ func TestConnQuery(t *testing.T) { // https://github.com/jackc/pgx/issues/781 func TestConnQueryDifferentScanPlansIssue781(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { var s string var b bool @@ -328,7 +319,7 @@ func TestConnQueryDifferentScanPlansIssue781(t *testing.T) { } func TestConnQueryNull(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { rows, err := db.Query("select $1::int", nil) require.NoError(t, err) @@ -353,7 +344,7 @@ func TestConnQueryNull(t *testing.T) { } func TestConnQueryRowByteSlice(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { expected := []byte{222, 173, 190, 239} var actual []byte @@ -364,7 +355,7 @@ func TestConnQueryRowByteSlice(t *testing.T) { } func TestConnQueryFailure(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.Query("select 'foo") require.Error(t, err) require.IsType(t, new(pgconn.PgError), err) @@ -372,7 +363,7 @@ func TestConnQueryFailure(t *testing.T) { } func TestConnSimpleSlicePassThrough(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server does not support cardinality function") var n int64 @@ -385,7 +376,7 @@ func TestConnSimpleSlicePassThrough(t *testing.T) { // Test type that pgx would handle natively in binary, but since it is not a // database/sql native type should be passed through as a string func TestConnQueryRowPgxBinary(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { sql := "select $1::int4[]" expected := "{1,2,3}" var actual string @@ -397,7 +388,7 @@ func TestConnQueryRowPgxBinary(t *testing.T) { } func TestConnQueryRowUnknownType(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server does not support point type") sql := "select $1::point" @@ -411,7 +402,7 @@ func TestConnQueryRowUnknownType(t *testing.T) { } func TestConnQueryJSONIntoByteSlice(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.Exec(` create temporary table docs( body json not null @@ -471,7 +462,7 @@ func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { } func TestTransactionLifeCycle(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.Exec("create temporary table t(a varchar not null)") require.NoError(t, err) @@ -505,7 +496,7 @@ func TestTransactionLifeCycle(t *testing.T) { } func TestConnBeginTxIsolation(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server always uses serializable isolation level") var defaultIsoLevel string @@ -561,7 +552,7 @@ func TestConnBeginTxIsolation(t *testing.T) { } func TestConnBeginTxReadOnly(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) require.NoError(t, err) defer tx.Rollback() @@ -579,7 +570,7 @@ func TestConnBeginTxReadOnly(t *testing.T) { } func TestBeginTxContextCancel(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.Exec("drop table if exists t") require.NoError(t, err) @@ -607,7 +598,7 @@ func TestBeginTxContextCancel(t *testing.T) { } func TestAcquireConn(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { var conns []*pgx.Conn for i := 1; i < 6; i++ { @@ -643,7 +634,7 @@ func TestAcquireConn(t *testing.T) { } func TestConnRaw(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { conn, err := db.Conn(context.Background()) require.NoError(t, err) @@ -659,7 +650,7 @@ func TestConnRaw(t *testing.T) { // https://github.com/jackc/pgx/issues/673 func TestReleaseConnWithTxInProgress(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server does not support backend PID") c1, err := stdlib.AcquireConn(db) @@ -690,14 +681,14 @@ func TestReleaseConnWithTxInProgress(t *testing.T) { } func TestConnPingContextSuccess(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { err := db.PingContext(context.Background()) require.NoError(t, err) }) } func TestConnPrepareContextSuccess(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { stmt, err := db.PrepareContext(context.Background(), "select now()") require.NoError(t, err) err = stmt.Close() @@ -706,14 +697,14 @@ func TestConnPrepareContextSuccess(t *testing.T) { } func TestConnExecContextSuccess(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)") require.NoError(t, err) }) } func TestConnExecContextFailureRetry(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { // We get a connection, immediately close it, and then get it back; // DB.Conn along with Conn.ResetSession does the retry for us. { @@ -730,7 +721,7 @@ func TestConnExecContextFailureRetry(t *testing.T) { } func TestConnQueryContextSuccess(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") require.NoError(t, err) @@ -744,7 +735,7 @@ func TestConnQueryContextSuccess(t *testing.T) { } func TestConnQueryContextFailureRetry(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { // We get a connection, immediately close it, and then get it back; // DB.Conn along with Conn.ResetSession does the retry for us. { @@ -762,7 +753,7 @@ func TestConnQueryContextFailureRetry(t *testing.T) { } func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { rows, err := db.Query("select 42::bigint") require.NoError(t, err) @@ -846,7 +837,7 @@ func TestStmtQueryContextSuccess(t *testing.T) { } func TestRowsColumnTypes(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { columnTypesTests := []struct { Name string TypeName string @@ -984,7 +975,7 @@ func TestRowsColumnTypes(t *testing.T) { } func TestQueryLifeCycle(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3) @@ -1033,7 +1024,7 @@ func TestQueryLifeCycle(t *testing.T) { // https://github.com/jackc/pgx/issues/409 func TestScanJSONIntoJSONRawMessage(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { var msg json.RawMessage err := db.QueryRow("select '{}'::json").Scan(&msg) @@ -1088,7 +1079,7 @@ func TestRegisterConnConfig(t *testing.T) { // https://github.com/jackc/pgx/issues/958 func TestConnQueryRowConstraintErrors(t *testing.T) { - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { skipPostgreSQLVersionLessThan(t, db, 11) skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") diff --git a/values_test.go b/values_test.go index 81138bfa..b7d5c572 100644 --- a/values_test.go +++ b/values_test.go @@ -18,7 +18,7 @@ import ( func TestDateTranscode(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { dates := []time.Time{ time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC), @@ -57,7 +57,7 @@ func TestDateTranscode(t *testing.T) { func TestTimestampTzTranscode(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local) var outputTime time.Time @@ -77,7 +77,7 @@ func TestTimestampTzTranscode(t *testing.T) { func TestJSONAndJSONBTranscode(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { for _, typename := range []string{"json", "jsonb"} { if _, ok := conn.TypeMap().TypeForName(typename); !ok { continue // No JSON/JSONB type -- must be running against old PostgreSQL @@ -247,7 +247,7 @@ func TestStringToNotTextTypeTranscode(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { input := "01086ee0-4963-4e35-9116-30c173a8d0bd" var output string @@ -272,7 +272,7 @@ func TestStringToNotTextTypeTranscode(t *testing.T) { func TestInetCIDRTranscodeIPNet(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { tests := []struct { sql string value *net.IPNet @@ -323,7 +323,7 @@ func TestInetCIDRTranscodeIPNet(t *testing.T) { func TestInetCIDRTranscodeIP(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { tests := []struct { sql string value net.IP @@ -387,7 +387,7 @@ func TestInetCIDRTranscodeIP(t *testing.T) { func TestInetCIDRArrayTranscodeIPNet(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { tests := []struct { sql string value []*net.IPNet @@ -450,7 +450,7 @@ func TestInetCIDRArrayTranscodeIPNet(t *testing.T) { func TestInetCIDRArrayTranscodeIP(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { tests := []struct { sql string value []net.IP @@ -536,7 +536,7 @@ func TestInetCIDRArrayTranscodeIP(t *testing.T) { func TestInetCIDRTranscodeWithJustIP(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { tests := []struct { sql string value string @@ -582,7 +582,7 @@ func TestInetCIDRTranscodeWithJustIP(t *testing.T) { func TestArrayDecoding(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { tests := []struct { sql string query interface{} @@ -698,7 +698,7 @@ func TestArrayDecoding(t *testing.T) { func TestEmptyArrayDecoding(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { var val []string err := conn.QueryRow(context.Background(), "select array[]::text[]").Scan(&val) @@ -743,7 +743,7 @@ func TestEmptyArrayDecoding(t *testing.T) { func TestPointerPointer(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server auto converts ints to bigint and test relies on exact types") type allTypes struct { @@ -829,7 +829,7 @@ func TestPointerPointer(t *testing.T) { func TestPointerPointerNonZero(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { f := "foo" dest := &f @@ -846,7 +846,7 @@ func TestPointerPointerNonZero(t *testing.T) { func TestEncodeTypeRename(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { type _int int inInt := _int(1) var outInt _int @@ -993,7 +993,7 @@ func TestEncodeTypeRename(t *testing.T) { func TestRowsScanNilThenScanValue(t *testing.T) { t.Parallel() - testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { sql := `select null as a, null as b union select 1, 2 From 39d2e3dc3f5bd61498d0f3e6f7d69a3bc10ce64f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 15:16:12 -0600 Subject: [PATCH 0941/1158] Move chooseParameterFormatCode --- extended_query_builder.go | 14 +++++++++++++- values.go | 12 ------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/extended_query_builder.go b/extended_query_builder.go index 51759362..5d03790e 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -15,7 +15,7 @@ type extendedQueryBuilder struct { } func (eqb *extendedQueryBuilder) AppendParam(m *pgtype.Map, oid uint32, arg interface{}) error { - f := chooseParameterFormatCode(m, oid, arg) + f := eqb.chooseParameterFormatCode(m, oid, arg) eqb.paramFormats = append(eqb.paramFormats, f) v, err := eqb.encodeExtendedParamValue(m, oid, f, arg) @@ -100,3 +100,15 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uin } return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } + +// chooseParameterFormatCode determines the correct format code for an +// argument to a prepared statement. It defaults to TextFormatCode if no +// determination can be made. +func (eqb *extendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg interface{}) int16 { + switch arg.(type) { + case string, *string: + return TextFormatCode + } + + return m.FormatCodeForOID(oid) +} diff --git a/values.go b/values.go index 67363986..7d1933b1 100644 --- a/values.go +++ b/values.go @@ -143,18 +143,6 @@ func encodePreparedStatementArgument(m *pgtype.Map, buf []byte, oid uint32, arg return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } -// chooseParameterFormatCode determines the correct format code for an -// argument to a prepared statement. It defaults to TextFormatCode if no -// determination can be made. -func chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg interface{}) int16 { - switch arg.(type) { - case string, *string: - return TextFormatCode - } - - return m.FormatCodeForOID(oid) -} - func stripNamedType(val *reflect.Value) (interface{}, bool) { switch val.Kind() { case reflect.Int: From 1cef9075d94cd69fbbe8038dc8ed26ef8dfe9842 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 19:45:44 -0600 Subject: [PATCH 0942/1158] Simply typed nil and driver.Valuer handling * Convert typed nils to untyped nils at beginning of encoding process. * Restore v4 json/jsonb null behavior * Add anynil internal package --- conn.go | 11 +++++-- go_stdlib.go | 61 --------------------------------------- internal/anynil/anynil.go | 36 +++++++++++++++++++++++ messages.go | 19 ------------ pgtype/json_test.go | 4 +-- pgtype/jsonb_test.go | 4 +-- values.go | 25 +++++++++++----- 7 files changed, 66 insertions(+), 94 deletions(-) delete mode 100644 go_stdlib.go create mode 100644 internal/anynil/anynil.go delete mode 100644 messages.go diff --git a/conn.go b/conn.go index 36994698..20968892 100644 --- a/conn.go +++ b/conn.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/sanitize" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn/stmtcache" @@ -478,7 +479,9 @@ func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, argu c.eqb.Reset() - args, err := convertDriverValuers(arguments) + anynil.NormalizeSlice(arguments) + + args, err := evaluateDriverValuers(arguments) if err != nil { return err } @@ -671,7 +674,8 @@ optionLoop: rows.sql = sd.SQL - args, err = convertDriverValuers(args) + anynil.NormalizeSlice(args) + args, err = evaluateDriverValuers(args) if err != nil { rows.fatal(err) return rows, rows.err @@ -831,7 +835,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} } - args, err := convertDriverValuers(bi.arguments) + anynil.NormalizeSlice(bi.arguments) + args, err := evaluateDriverValuers(bi.arguments) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } diff --git a/go_stdlib.go b/go_stdlib.go deleted file mode 100644 index 9372f9ef..00000000 --- a/go_stdlib.go +++ /dev/null @@ -1,61 +0,0 @@ -package pgx - -import ( - "database/sql/driver" - "reflect" -) - -// This file contains code copied from the Go standard library due to the -// required function not being public. - -// Copyright (c) 2009 The Go Authors. All rights reserved. - -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: - -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above -// copyright notice, this list of conditions and the following disclaimer -// in the documentation and/or other materials provided with the -// distribution. -// * Neither the name of Google Inc. nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. - -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -// From database/sql/convert.go - -var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() - -// callValuerValue returns vr.Value(), with one exception: -// If vr.Value is an auto-generated method on a pointer type and the -// pointer is nil, it would panic at runtime in the panicwrap -// method. Treat it like nil instead. -// Issue 8415. -// -// This is so people can implement driver.Value on value types and -// still use nil pointers to those types to mean nil/NULL, just like -// string/*string. -// -// This function is mirrored in the database/sql/driver package. -func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { - if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && - rv.IsNil() && - rv.Type().Elem().Implements(valuerReflectType) { - return nil, nil - } - return vr.Value() -} diff --git a/internal/anynil/anynil.go b/internal/anynil/anynil.go new file mode 100644 index 00000000..57a45b95 --- /dev/null +++ b/internal/anynil/anynil.go @@ -0,0 +1,36 @@ +package anynil + +import "reflect" + +// Is returns true if value is any type of nil. e.g. nil or []byte(nil). +func Is(value interface{}) bool { + if value == nil { + return true + } + + refVal := reflect.ValueOf(value) + switch refVal.Kind() { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + return refVal.IsNil() + default: + return false + } +} + +// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified. +func Normalize(v interface{}) interface{} { + if Is(v) { + return nil + } + return v +} + +// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is +// mutated in place. +func NormalizeSlice(s []interface{}) { + for i := range s { + if Is(s[i]) { + s[i] = nil + } + } +} diff --git a/messages.go b/messages.go deleted file mode 100644 index 01ece44e..00000000 --- a/messages.go +++ /dev/null @@ -1,19 +0,0 @@ -package pgx - -import ( - "database/sql/driver" -) - -func convertDriverValuers(args []interface{}) ([]interface{}, error) { - for i, arg := range args { - switch arg := arg.(type) { - case driver.Valuer: - v, err := callValuerValue(arg) - if err != nil { - return nil, err - } - args[i] = v - } - } - return args, nil -} diff --git a/pgtype/json_test.go b/pgtype/json_test.go index a1dd63fb..39658bfa 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -48,8 +48,8 @@ func TestJSONCodec(t *testing.T) { {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, - {map[string]interface{}(nil), new(string), isExpectedEq(`null`)}, - {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte("null"))}, + {map[string]interface{}(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, }) diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index fa5ea20e..c26499c6 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -21,8 +21,8 @@ func TestJSONBTranscode(t *testing.T) { {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, - {map[string]interface{}(nil), new(string), isExpectedEq(`null`)}, - {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte("null"))}, + {map[string]interface{}(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, }) diff --git a/values.go b/values.go index 7d1933b1..fe7f6444 100644 --- a/values.go +++ b/values.go @@ -7,6 +7,7 @@ import ( "reflect" "time" + "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgtype" ) @@ -25,18 +26,13 @@ func (e SerializationError) Error() string { } func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) { - if arg == nil { - return nil, nil - } - - refVal := reflect.ValueOf(arg) - if refVal.Kind() == reflect.Ptr && refVal.IsNil() { + if anynil.Is(arg) { return nil, nil } switch arg := arg.(type) { case driver.Valuer: - return callValuerValue(arg) + return arg.Value() case float32: return float64(arg), nil case float64: @@ -90,6 +86,7 @@ func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) return string(buf), nil } + refVal := reflect.ValueOf(arg) if refVal.Kind() == reflect.Ptr { arg = refVal.Elem().Interface() return convertSimpleArgument(m, arg) @@ -182,3 +179,17 @@ func stripNamedType(val *reflect.Value) (interface{}, bool) { return nil, false } + +func evaluateDriverValuers(args []interface{}) ([]interface{}, error) { + for i, arg := range args { + switch arg := arg.(type) { + case driver.Valuer: + v, err := arg.Value() + if err != nil { + return nil, err + } + args[i] = v + } + } + return args, nil +} From e5685a34fc7f120b3479c719ed96e5b2d92e9221 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 20:16:57 -0600 Subject: [PATCH 0943/1158] Simplify encoding extended query arguments --- extended_query_builder.go | 45 ++++++++------------------------------- pgtype/json.go | 30 +++++++++++++++++++++++--- pgtype/pgtype.go | 29 +++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 39 deletions(-) diff --git a/extended_query_builder.go b/extended_query_builder.go index 5d03790e..5409c0fd 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -1,9 +1,7 @@ package pgx import ( - "fmt" - "reflect" - + "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/pgtype" ) @@ -55,14 +53,7 @@ func (eqb *extendedQueryBuilder) Reset() { } func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg interface{}) ([]byte, error) { - if arg == nil { - return nil, nil - } - - refVal := reflect.ValueOf(arg) - argIsPtr := refVal.Kind() == reflect.Ptr - - if argIsPtr && refVal.IsNil() { + if anynil.Is(arg) { return nil, nil } @@ -72,33 +63,15 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uin pos := len(eqb.paramValueBytes) - if arg, ok := arg.(string); ok { - return []byte(arg), nil + buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes) + if err != nil { + return nil, err } - - if argIsPtr { - // We have already checked that arg is not pointing to nil, - // so it is safe to dereference here. - arg = refVal.Elem().Interface() - return eqb.encodeExtendedParamValue(m, oid, formatCode, arg) + if buf == nil { + return nil, nil } - - if _, ok := m.TypeForOID(oid); ok { - buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - eqb.paramValueBytes = buf - return eqb.paramValueBytes[pos:], nil - } - - if strippedArg, ok := stripNamedType(&refVal); ok { - return eqb.encodeExtendedParamValue(m, oid, formatCode, strippedArg) - } - return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil } // chooseParameterFormatCode determines the correct format code for an diff --git a/pgtype/json.go b/pgtype/json.go index e8882d3a..4d8cf4c4 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -16,13 +16,37 @@ func (JSONCodec) PreferredFormat() int16 { return TextFormatCode } -func (JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch value.(type) { + case string: + return encodePlanJSONCodecEitherFormatString{} case []byte: return encodePlanJSONCodecEitherFormatByteSlice{} - default: - return encodePlanJSONCodecEitherFormatMarshal{} } + + // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the + // appropriate wrappers here. + for _, f := range []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + } { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := c.PlanEncode(m, oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + + return encodePlanJSONCodecEitherFormatMarshal{} +} + +type encodePlanJSONCodecEitherFormatString struct{} + +func (encodePlanJSONCodecEitherFormatString) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + jsonString := value.(string) + buf = append(buf, jsonString...) + return buf, nil } type encodePlanJSONCodecEitherFormatByteSlice struct{} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d1a92089..75934ced 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1155,6 +1155,14 @@ func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src // PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be // found then nil is returned. func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan { + if format == TextFormatCode { + switch value.(type) { + case string: + return encodePlanStringToAnyTextFormat{} + case TextValuer: + return encodePlanTextValuerToAnyTextFormat{} + } + } var dt *Type @@ -1187,6 +1195,27 @@ func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan return nil } +type encodePlanStringToAnyTextFormat struct{} + +func (encodePlanStringToAnyTextFormat) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + s := value.(string) + return append(buf, s...), nil +} + +type encodePlanTextValuerToAnyTextFormat struct{} + +func (encodePlanTextValuerToAnyTextFormat) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + t, err := value.(TextValuer).TextValue() + if err != nil { + return nil, err + } + if !t.Valid { + return nil, nil + } + + return append(buf, t.String...), nil +} + // TryWrapEncodePlanFunc is a function that tries to create a wrapper plan for value. If successful it returns a plan // that will convert the value passed to Encode and then call the next plan. nextValue is value as it will be converted // by plan. It must be used to find another suitable EncodePlan. When it is found SetNext must be called on plan for it From 2831eedef368d14dda5775741dc184b707cf6275 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 20:27:36 -0600 Subject: [PATCH 0944/1158] Simplify copy encoding --- copy_from.go | 2 +- copy_from_test.go | 8 ++++++++ values.go | 46 +++++++++++----------------------------------- 3 files changed, 20 insertions(+), 36 deletions(-) diff --git a/copy_from.go b/copy_from.go index 7d6a8813..ef982269 100644 --- a/copy_from.go +++ b/copy_from.go @@ -178,7 +178,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) for i, val := range values { - buf, err = encodePreparedStatementArgument(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) + buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) if err != nil { return false, nil, err } diff --git a/copy_from_test.go b/copy_from_test.go index 5c22dc35..6e2fe952 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -211,6 +211,14 @@ func TestConnCopyFromEnum(t *testing.T) { _, err = tx.Exec(ctx, `create type fruit as enum ('apple', 'orange', 'grape')`) require.NoError(t, err) + // Obviously using conn while a tx is in use and registering a type after the connection has been established are + // really bad practices, but for the sake of convenience we do it in the test here. + for _, name := range []string{"fruit", "color"} { + typ, err := conn.LoadType(ctx, name) + require.NoError(t, err) + conn.TypeMap().RegisterType(typ) + } + _, err = tx.Exec(ctx, `create table foo( a text, b color, diff --git a/values.go b/values.go index fe7f6444..766074bd 100644 --- a/values.go +++ b/values.go @@ -98,46 +98,22 @@ func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) } -func encodePreparedStatementArgument(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) { - if arg == nil { +func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) { + if anynil.Is(arg) { return pgio.AppendInt32(buf, -1), nil } - switch arg := arg.(type) { - case string: - buf = pgio.AppendInt32(buf, int32(len(arg))) - buf = append(buf, arg...) - return buf, nil + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) + if err != nil { + return nil, err } - - refVal := reflect.ValueOf(arg) - - if refVal.Kind() == reflect.Ptr { - if refVal.IsNil() { - return pgio.AppendInt32(buf, -1), nil - } - arg = refVal.Elem().Interface() - return encodePreparedStatementArgument(m, buf, oid, arg) + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - - if _, ok := m.TypeForOID(oid); ok { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) - if err != nil { - return nil, err - } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil - } - - if strippedArg, ok := stripNamedType(&refVal); ok { - return encodePreparedStatementArgument(m, buf, oid, strippedArg) - } - return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + return buf, nil } func stripNamedType(val *reflect.Value) (interface{}, bool) { From 0905d1f452e79d8ebbf8bcdc73d7c153eae17172 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 21:19:58 -0600 Subject: [PATCH 0945/1158] Register more default types and handle unknown types better --- pgtype/pgtype.go | 50 +++++++++++++++++++++++++++++++++++++++++------- values.go | 8 ++++++-- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 75934ced..3244b504 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -344,10 +344,12 @@ func NewMap() *Map { registerDefaultPgTypeVariants("int8", "_int8", int64(0)) // Integer types that do not have a direct match to a PostgreSQL type + registerDefaultPgTypeVariants("int8", "_int8", int8(0)) + registerDefaultPgTypeVariants("int8", "_int8", int(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint8(0)) registerDefaultPgTypeVariants("int8", "_int8", uint16(0)) registerDefaultPgTypeVariants("int8", "_int8", uint32(0)) registerDefaultPgTypeVariants("int8", "_int8", uint64(0)) - registerDefaultPgTypeVariants("int8", "_int8", int(0)) registerDefaultPgTypeVariants("int8", "_int8", uint(0)) registerDefaultPgTypeVariants("float4", "_float4", float32(0)) @@ -355,12 +357,46 @@ func NewMap() *Map { registerDefaultPgTypeVariants("bool", "_bool", false) registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{}) + registerDefaultPgTypeVariants("interval", "_interval", time.Duration(0)) registerDefaultPgTypeVariants("text", "_text", "") registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil)) registerDefaultPgTypeVariants("inet", "_inet", net.IP{}) registerDefaultPgTypeVariants("cidr", "_cidr", net.IPNet{}) + // pgtype provided structs + registerDefaultPgTypeVariants("varbit", "_varbit", Bits{}) + registerDefaultPgTypeVariants("bool", "_bool", Bool{}) + registerDefaultPgTypeVariants("box", "_box", Box{}) + registerDefaultPgTypeVariants("circle", "_circle", Circle{}) + registerDefaultPgTypeVariants("date", "_date", Date{}) + registerDefaultPgTypeVariants("daterange", "_daterange", Daterange{}) + registerDefaultPgTypeVariants("float4", "_float4", Float4{}) + registerDefaultPgTypeVariants("float8", "_float8", Float8{}) + registerDefaultPgTypeVariants("float8range", "_float8range", Float8range{}) + registerDefaultPgTypeVariants("inet", "_inet", Inet{}) + registerDefaultPgTypeVariants("int2", "_int2", Int2{}) + registerDefaultPgTypeVariants("int4", "_int4", Int4{}) + registerDefaultPgTypeVariants("int4range", "_int4range", Int4range{}) + registerDefaultPgTypeVariants("int8", "_int8", Int8{}) + registerDefaultPgTypeVariants("int8range", "_int8range", Int8range{}) + registerDefaultPgTypeVariants("interval", "_interval", Interval{}) + registerDefaultPgTypeVariants("line", "_line", Line{}) + registerDefaultPgTypeVariants("lseg", "_lseg", Lseg{}) + registerDefaultPgTypeVariants("numeric", "_numeric", Numeric{}) + registerDefaultPgTypeVariants("numrange", "_numrange", Numrange{}) + registerDefaultPgTypeVariants("path", "_path", Path{}) + registerDefaultPgTypeVariants("point", "_point", Point{}) + registerDefaultPgTypeVariants("polygon", "_polygon", Polygon{}) + registerDefaultPgTypeVariants("tid", "_tid", TID{}) + registerDefaultPgTypeVariants("text", "_text", Text{}) + registerDefaultPgTypeVariants("time", "_time", Time{}) + registerDefaultPgTypeVariants("timestamp", "_timestamp", Timestamp{}) + registerDefaultPgTypeVariants("timestamptz", "_timestamptz", Timestamptz{}) + registerDefaultPgTypeVariants("tsrange", "_tsrange", Tsrange{}) + registerDefaultPgTypeVariants("tstzrange", "_tstzrange", Tstzrange{}) + registerDefaultPgTypeVariants("uuid", "_uuid", UUID{}) + return m } @@ -1181,13 +1217,13 @@ func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan if plan := dt.Codec.PlanEncode(m, oid, format, value); plan != nil { return plan } + } - for _, f := range m.TryWrapEncodePlanFuncs { - if wrapperPlan, nextValue, ok := f(value); ok { - if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil { - wrapperPlan.SetNext(nextPlan) - return wrapperPlan - } + for _, f := range m.TryWrapEncodePlanFuncs { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan } } } diff --git a/values.go b/values.go index 766074bd..0f34b6a6 100644 --- a/values.go +++ b/values.go @@ -30,9 +30,13 @@ func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) return nil, nil } + if dv, ok := arg.(driver.Valuer); ok { + return dv.Value() + } + + // All these could be handled by m.Encode below. However, that transforms the argument to a string. That could change + // the type of the argument. e.g. '42' instead of 42. So standard types are special cased. switch arg := arg.(type) { - case driver.Valuer: - return arg.Value() case float32: return float64(arg), nil case float64: From c4b08378f235fdc9f8928f5f31316aff5602bfaf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 21:27:17 -0600 Subject: [PATCH 0946/1158] Handle driver.Valuers inside Map.Encode --- conn.go | 28 +++++++--------------------- pgtype/pgtype.go | 11 +++++++++++ values.go | 14 -------------- 3 files changed, 18 insertions(+), 35 deletions(-) diff --git a/conn.go b/conn.go index 20968892..9f2fdcf0 100644 --- a/conn.go +++ b/conn.go @@ -472,22 +472,17 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i return commandTag, err } -func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, arguments []interface{}) error { - if len(sd.ParamOIDs) != len(arguments) { - return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(arguments)) +func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, args []interface{}) error { + if len(sd.ParamOIDs) != len(args) { + return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)) } c.eqb.Reset() - anynil.NormalizeSlice(arguments) - - args, err := evaluateDriverValuers(arguments) - if err != nil { - return err - } + anynil.NormalizeSlice(args) for i := range args { - err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) + err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) if err != nil { return err } @@ -675,11 +670,6 @@ optionLoop: rows.sql = sd.SQL anynil.NormalizeSlice(args) - args, err = evaluateDriverValuers(args) - if err != nil { - rows.fatal(err) - return rows, rows.err - } for i := range args { err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) @@ -836,13 +826,9 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } anynil.NormalizeSlice(bi.arguments) - args, err := evaluateDriverValuers(bi.arguments) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - for i := range args { - err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) + for i := range bi.arguments { + err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 3244b504..1cc809b1 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1720,6 +1720,17 @@ func (m *Map) Encode(oid uint32, formatCode int16, value interface{}, buf []byte plan := m.PlanEncode(oid, formatCode, value) if plan == nil { + if dv, ok := value.(driver.Valuer); ok { + if dv == nil { + return nil, nil + } + v, err := dv.Value() + if err != nil { + return nil, err + } + return m.Encode(oid, formatCode, v, buf) + } + return nil, fmt.Errorf("unable to encode %#v into OID %d", value, oid) } return plan.Encode(value, buf) diff --git a/values.go b/values.go index 0f34b6a6..a3343d81 100644 --- a/values.go +++ b/values.go @@ -159,17 +159,3 @@ func stripNamedType(val *reflect.Value) (interface{}, bool) { return nil, false } - -func evaluateDriverValuers(args []interface{}) ([]interface{}, error) { - for i, arg := range args { - switch arg := arg.(type) { - case driver.Valuer: - v, err := arg.Value() - if err != nil { - return nil, err - } - args[i] = v - } - } - return args, nil -} From fe21cc74864af0bf2486ac6ea571dc7594338b29 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 21:40:49 -0600 Subject: [PATCH 0947/1158] Use Map.Encode path for simple protocol --- conn.go | 6 +++ query_test.go | 4 +- values.go | 127 +++----------------------------------------------- 3 files changed, 14 insertions(+), 123 deletions(-) diff --git a/conn.go b/conn.go index 9f2fdcf0..34d23198 100644 --- a/conn.go +++ b/conn.go @@ -555,6 +555,9 @@ const ( // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious. // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use // a map[string]string directly as an argument. This mode cannot. + // + // It may be necessary to specify the desired type of an argument in the SQL string when it cannot be inferred. e.g. + // "SELECT $1::boolean". QueryExecModeExec // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. @@ -562,6 +565,9 @@ const ( // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious. // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use // a map[string]string directly as an argument. This mode cannot. + // + // This mode uses client side parameter interpolation. All values are quoted and escaped. It may be necessary to + // specify the desired type of an argument in the SQL string when it cannot be inferred. e.g. "SELECT $1::boolean". QueryExecModeSimpleProtocol ) diff --git a/query_test.go b/query_test.go index b6a0d65d..8ed89007 100644 --- a/query_test.go +++ b/query_test.go @@ -1418,7 +1418,7 @@ func TestConnSimpleProtocol(t *testing.T) { var actual bool err := conn.QueryRow( context.Background(), - "select $1", + "select $1::boolean", pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) @@ -1733,7 +1733,7 @@ func TestConnSimpleProtocol(t *testing.T) { var actualString string err := conn.QueryRow( context.Background(), - "select $1::int8, $2::float8, $3, $4::bytea, $5::text", + "select $1::int8, $2::float8, $3::boolean, $4::bytea, $5::text", pgx.QueryExecModeSimpleProtocol, expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString, ).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString) diff --git a/values.go b/values.go index a3343d81..595f2b4d 100644 --- a/values.go +++ b/values.go @@ -1,12 +1,6 @@ package pgx import ( - "database/sql/driver" - "fmt" - "math" - "reflect" - "time" - "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgtype" @@ -18,88 +12,19 @@ const ( BinaryFormatCode = 1 ) -// SerializationError occurs on failure to encode or decode a value -type SerializationError string - -func (e SerializationError) Error() string { - return string(e) -} - func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) { if anynil.Is(arg) { return nil, nil } - if dv, ok := arg.(driver.Valuer); ok { - return dv.Value() + buf, err := m.Encode(0, TextFormatCode, arg, nil) + if err != nil { + return nil, err } - - // All these could be handled by m.Encode below. However, that transforms the argument to a string. That could change - // the type of the argument. e.g. '42' instead of 42. So standard types are special cased. - switch arg := arg.(type) { - case float32: - return float64(arg), nil - case float64: - return arg, nil - case bool: - return arg, nil - case time.Duration: - return fmt.Sprintf("%d microsecond", int64(arg)/1000), nil - case time.Time: - return arg, nil - case string: - return arg, nil - case []byte: - return arg, nil - case int8: - return int64(arg), nil - case int16: - return int64(arg), nil - case int32: - return int64(arg), nil - case int64: - return arg, nil - case int: - return int64(arg), nil - case uint8: - return int64(arg), nil - case uint16: - return int64(arg), nil - case uint32: - return int64(arg), nil - case uint64: - if arg > math.MaxInt64 { - return nil, fmt.Errorf("arg too big for int64: %v", arg) - } - return int64(arg), nil - case uint: - if uint64(arg) > math.MaxInt64 { - return nil, fmt.Errorf("arg too big for int64: %v", arg) - } - return int64(arg), nil + if buf == nil { + return nil, nil } - - if _, found := m.TypeForValue(arg); found { - buf, err := m.Encode(0, TextFormatCode, arg, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil - } - - refVal := reflect.ValueOf(arg) - if refVal.Kind() == reflect.Ptr { - arg = refVal.Elem().Interface() - return convertSimpleArgument(m, arg) - } - - if strippedArg, ok := stripNamedType(&refVal); ok { - return convertSimpleArgument(m, strippedArg) - } - return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) + return string(buf), nil } func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) { @@ -119,43 +44,3 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([] } return buf, nil } - -func stripNamedType(val *reflect.Value) (interface{}, bool) { - switch val.Kind() { - case reflect.Int: - convVal := int(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int8: - convVal := int8(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int16: - convVal := int16(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int32: - convVal := int32(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int64: - convVal := int64(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint: - convVal := uint(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint8: - convVal := uint8(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint16: - convVal := uint16(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint32: - convVal := uint32(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint64: - convVal := uint64(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.String: - convVal := val.String() - return convVal, reflect.TypeOf(convVal) != val.Type() - } - - return nil, false -} From f27178ba85dcb484d61189c2e0c8f5ce9c2ff097 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Mar 2022 08:35:31 -0600 Subject: [PATCH 0948/1158] Initial privatization of stmtcache ConnConfig.BuildStatementCache is pending removal once connections always have separate caches for prepared and described statements. --- batch_test.go | 2 +- bench_test.go | 2 +- conn.go | 5 +---- conn_test.go | 2 +- {pgconn => internal}/stmtcache/lru.go | 0 {pgconn => internal}/stmtcache/lru_test.go | 2 +- {pgconn => internal}/stmtcache/stmtcache.go | 0 pgbouncer_test.go | 2 +- query_test.go | 2 +- 9 files changed, 7 insertions(+), 10 deletions(-) rename {pgconn => internal}/stmtcache/lru.go (100%) rename {pgconn => internal}/stmtcache/lru_test.go (99%) rename {pgconn => internal}/stmtcache/stmtcache.go (100%) diff --git a/batch_test.go b/batch_test.go index c2e944a1..24a70e39 100644 --- a/batch_test.go +++ b/batch_test.go @@ -7,8 +7,8 @@ import ( "testing" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgconn/stmtcache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/bench_test.go b/bench_test.go index bd182ebd..dfb879e5 100644 --- a/bench_test.go +++ b/bench_test.go @@ -13,8 +13,8 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgconn/stmtcache" "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" ) diff --git a/conn.go b/conn.go index 34d23198..0085e454 100644 --- a/conn.go +++ b/conn.go @@ -10,8 +10,8 @@ import ( "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/sanitize" + "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgconn/stmtcache" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" ) @@ -381,9 +381,6 @@ func (c *Conn) Ping(ctx context.Context) error { // is used and the connection must be returned to the same state before any *pgx.Conn methods are again used. func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn } -// StatementCache returns the statement cache used for this connection. -func (c *Conn) StatementCache() stmtcache.Cache { return c.stmtcache } - // TypeMap returns the connection info used for this connection. func (c *Conn) TypeMap() *pgtype.Map { return c.typeMap } diff --git a/conn_test.go b/conn_test.go index f4b3dd78..3792ba45 100644 --- a/conn_test.go +++ b/conn_test.go @@ -9,8 +9,8 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgconn/stmtcache" "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/pgconn/stmtcache/lru.go b/internal/stmtcache/lru.go similarity index 100% rename from pgconn/stmtcache/lru.go rename to internal/stmtcache/lru.go diff --git a/pgconn/stmtcache/lru_test.go b/internal/stmtcache/lru_test.go similarity index 99% rename from pgconn/stmtcache/lru_test.go rename to internal/stmtcache/lru_test.go index 549e7670..7690a2b0 100644 --- a/pgconn/stmtcache/lru_test.go +++ b/internal/stmtcache/lru_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" + "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgconn/stmtcache" "github.com/stretchr/testify/require" ) diff --git a/pgconn/stmtcache/stmtcache.go b/internal/stmtcache/stmtcache.go similarity index 100% rename from pgconn/stmtcache/stmtcache.go rename to internal/stmtcache/stmtcache.go diff --git a/pgbouncer_test.go b/pgbouncer_test.go index e80861a0..c46f0622 100644 --- a/pgbouncer_test.go +++ b/pgbouncer_test.go @@ -6,8 +6,8 @@ import ( "testing" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgconn/stmtcache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/query_test.go b/query_test.go index 8ed89007..20cc49c0 100644 --- a/query_test.go +++ b/query_test.go @@ -13,8 +13,8 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgconn/stmtcache" "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" From 8e341e20f353830131c92f6f0fdda90ab95a3ad9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Mar 2022 09:23:40 -0600 Subject: [PATCH 0949/1158] Remove ConnConfig.BuildStatementCache --- batch_test.go | 19 ++++---- bench_test.go | 49 +++++++++++---------- conn.go | 94 +++++++++++++++++++--------------------- conn_test.go | 98 +++++++++++++----------------------------- helper_test.go | 5 +-- pgbouncer_test.go | 8 +--- pgxpool/common_test.go | 8 ++-- query_test.go | 58 ++----------------------- rows.go | 4 +- 9 files changed, 124 insertions(+), 219 deletions(-) diff --git a/batch_test.go b/batch_test.go index 24a70e39..da35646b 100644 --- a/batch_test.go +++ b/batch_test.go @@ -7,7 +7,6 @@ import ( "testing" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -249,7 +248,9 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing. config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) - config.BuildStatementCache = nil + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 conn := mustConnect(t, config) defer closeConn(t, conn) @@ -653,7 +654,9 @@ func TestConnBeginBatchDeferredError(t *testing.T) { func TestConnSendBatchNoStatementCache(t *testing.T) { config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = nil + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 conn := mustConnect(t, config) defer closeConn(t, conn) @@ -663,9 +666,8 @@ func TestConnSendBatchNoStatementCache(t *testing.T) { func TestConnSendBatchPrepareStatementCache(t *testing.T) { config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModePrepare, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 32 conn := mustConnect(t, config) defer closeConn(t, conn) @@ -675,9 +677,8 @@ func TestConnSendBatchPrepareStatementCache(t *testing.T) { func TestConnSendBatchDescribeStatementCache(t *testing.T) { config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.DescriptionCacheCapacity = 32 conn := mustConnect(t, config) defer closeConn(t, conn) diff --git a/bench_test.go b/bench_test.go index dfb879e5..e5913995 100644 --- a/bench_test.go +++ b/bench_test.go @@ -13,7 +13,6 @@ import ( "time" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" @@ -21,7 +20,9 @@ import ( func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = nil + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -43,9 +44,9 @@ func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) { func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 32 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -67,9 +68,9 @@ func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B func BenchmarkMinimalUnpreparedSelectWithStatementCacheModePrepare(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModePrepare, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 32 + config.DescriptionCacheCapacity = 0 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -723,7 +724,9 @@ func BenchmarkWrite10000RowsViaCopy(b *testing.B) { func BenchmarkMultipleQueriesNonBatchNoStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = nil + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -733,9 +736,9 @@ func BenchmarkMultipleQueriesNonBatchNoStatementCache(b *testing.B) { func BenchmarkMultipleQueriesNonBatchPrepareStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModePrepare, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 32 + config.DescriptionCacheCapacity = 0 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -745,9 +748,9 @@ func BenchmarkMultipleQueriesNonBatchPrepareStatementCache(b *testing.B) { func BenchmarkMultipleQueriesNonBatchDescribeStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 32 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -783,7 +786,9 @@ func benchmarkMultipleQueriesNonBatch(b *testing.B, conn *pgx.Conn, queryCount i func BenchmarkMultipleQueriesBatchNoStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = nil + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -793,9 +798,9 @@ func BenchmarkMultipleQueriesBatchNoStatementCache(b *testing.B) { func BenchmarkMultipleQueriesBatchPrepareStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModePrepare, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 32 + config.DescriptionCacheCapacity = 0 conn := mustConnect(b, config) defer closeConn(b, conn) @@ -805,9 +810,9 @@ func BenchmarkMultipleQueriesBatchPrepareStatementCache(b *testing.B) { func BenchmarkMultipleQueriesBatchDescribeStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 32) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 32 conn := mustConnect(b, config) defer closeConn(b, conn) diff --git a/conn.go b/conn.go index 0085e454..c85bca88 100644 --- a/conn.go +++ b/conn.go @@ -26,9 +26,13 @@ type ConnConfig struct { // Original connection string that was parsed into config. connString string - // BuildStatementCache creates the stmtcache.Cache implementation for connections created with this config. Set - // to nil to disable automatic prepared statements. - BuildStatementCache BuildStatementCacheFunc + // StatementCacheCapacity is maximum size of the statement cache used when executing a query with "cache_statement" + // query exec mode. + StatementCacheCapacity int + + // DescriptionCacheCapacity is the maximum size of the description cache used when executing a query with + // "cache_describe" query exec mode. + DescriptionCacheCapacity int // DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol // and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as @@ -52,16 +56,14 @@ func (cc *ConnConfig) Copy() *ConnConfig { // ConnString returns the connection string as parsed by pgx.ParseConfig into pgx.ConnConfig. func (cc *ConnConfig) ConnString() string { return cc.connString } -// BuildStatementCacheFunc is a function that can be used to create a stmtcache.Cache implementation for connection. -type BuildStatementCacheFunc func(conn *pgconn.PgConn) stmtcache.Cache - // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access // to multiple database connections from multiple goroutines. type Conn struct { pgConn *pgconn.PgConn config *ConnConfig // config used when establishing this connection preparedStatements map[string]*pgconn.StatementDescription - stmtcache stmtcache.Cache + statementCache stmtcache.Cache + descriptionCache stmtcache.Cache logger Logger logLevel LogLevel @@ -115,27 +117,24 @@ func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { // ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig // does. In addition, it accepts the following options: // -// statement_cache_capacity -// The maximum size of the automatic statement cache. Set to 0 to disable automatic statement caching. Default: 512. -// -// statement_cache_mode -// Possible values: "prepare" and "describe". "prepare" will create prepared statements on the PostgreSQL server. -// "describe" will use the anonymous prepared statement to describe a statement without creating a statement on the -// server. "describe" is primarily useful when the environment does not allow prepared statements such as when -// running a connection pooler like PgBouncer. Default: "prepare" -// // default_query_exec_mode // Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See // QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement". +// +// statement_cache_capacity +// The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode. +// Default: 512. +// +// description_cache_capacity +// The maximum size of the description cache used when executing a query with "cache_describe" query exec mode. +// Default: 512. func ParseConfig(connString string) (*ConnConfig, error) { config, err := pgconn.ParseConfig(connString) if err != nil { return nil, err } - var buildStatementCache BuildStatementCacheFunc statementCacheCapacity := 512 - statementCacheMode := stmtcache.ModePrepare if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok { delete(config.RuntimeParams, "statement_cache_capacity") n, err := strconv.ParseInt(s, 10, 32) @@ -145,22 +144,14 @@ func ParseConfig(connString string) (*ConnConfig, error) { statementCacheCapacity = int(n) } - if s, ok := config.RuntimeParams["statement_cache_mode"]; ok { - delete(config.RuntimeParams, "statement_cache_mode") - switch s { - case "prepare": - statementCacheMode = stmtcache.ModePrepare - case "describe": - statementCacheMode = stmtcache.ModeDescribe - default: - return nil, fmt.Errorf("invalid statement_cache_mod: %s", s) - } - } - - if statementCacheCapacity > 0 { - buildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, statementCacheMode, statementCacheCapacity) + descriptionCacheCapacity := 512 + if s, ok := config.RuntimeParams["description_cache_capacity"]; ok { + delete(config.RuntimeParams, "description_cache_capacity") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err) } + descriptionCacheCapacity = int(n) } defaultQueryExecMode := QueryExecModeCacheStatement @@ -183,12 +174,13 @@ func ParseConfig(connString string) (*ConnConfig, error) { } connConfig := &ConnConfig{ - Config: *config, - createdByParseConfig: true, - LogLevel: LogLevelInfo, - BuildStatementCache: buildStatementCache, - DefaultQueryExecMode: defaultQueryExecMode, - connString: connString, + Config: *config, + createdByParseConfig: true, + LogLevel: LogLevelInfo, + StatementCacheCapacity: statementCacheCapacity, + DescriptionCacheCapacity: descriptionCacheCapacity, + DefaultQueryExecMode: defaultQueryExecMode, + connString: connString, } return connConfig, nil @@ -241,8 +233,12 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { c.closedChan = make(chan error) c.wbuf = make([]byte, 0, 1024) - if c.config.BuildStatementCache != nil { - c.stmtcache = c.config.BuildStatementCache(c.pgConn) + if c.config.StatementCacheCapacity > 0 { + c.statementCache = stmtcache.New(c.pgConn, stmtcache.ModePrepare, c.config.StatementCacheCapacity) + } + + if c.config.DescriptionCacheCapacity > 0 { + c.descriptionCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, c.config.DescriptionCacheCapacity) } // Replication connections can't execute the queries to @@ -434,13 +430,13 @@ optionLoop: return c.execSimpleProtocol(ctx, sql, arguments) } - if c.stmtcache != nil { - sd, err := c.stmtcache.Get(ctx, sql) + if c.statementCache != nil { + sd, err := c.statementCache.Get(ctx, sql) if err != nil { return pgconn.CommandTag{}, err } - if c.stmtcache.Mode() == stmtcache.ModeDescribe { + if c.statementCache.Mode() == stmtcache.ModeDescribe { return c.execParams(ctx, sd, arguments) } return c.execPrepared(ctx, sd, arguments) @@ -651,8 +647,8 @@ optionLoop: c.eqb.Reset() if !ok { - if c.stmtcache != nil { - sd, err = c.stmtcache.Get(ctx, sql) + if c.statementCache != nil { + sd, err = c.statementCache.Get(ctx, sql) if err != nil { rows.fatal(err) return rows, rows.err @@ -697,7 +693,7 @@ optionLoop: resultFormats = c.eqb.resultFormats } - if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe { + if c.statementCache != nil && c.statementCache.Mode() == stmtcache.ModeDescribe { rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats) } else { rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) @@ -796,8 +792,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { var stmtCache stmtcache.Cache if len(distinctUnpreparedQueries) > 0 { - if c.stmtcache != nil && c.stmtcache.Cap() >= len(distinctUnpreparedQueries) { - stmtCache = c.stmtcache + if c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) { + stmtCache = c.statementCache } else { stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) } diff --git a/conn_test.go b/conn_test.go index 3792ba45..625d9693 100644 --- a/conn_test.go +++ b/conn_test.go @@ -9,7 +9,6 @@ import ( "time" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" @@ -137,31 +136,42 @@ func TestParseConfigExtractsStatementCacheOptions(t *testing.T) { config, err := pgx.ParseConfig("statement_cache_capacity=0") require.NoError(t, err) - require.Nil(t, config.BuildStatementCache) + require.EqualValues(t, 0, config.StatementCacheCapacity) config, err = pgx.ParseConfig("statement_cache_capacity=42") require.NoError(t, err) - require.NotNil(t, config.BuildStatementCache) - c := config.BuildStatementCache(nil) - require.NotNil(t, c) - require.Equal(t, 42, c.Cap()) - require.Equal(t, stmtcache.ModePrepare, c.Mode()) + require.EqualValues(t, 42, config.StatementCacheCapacity) - config, err = pgx.ParseConfig("statement_cache_capacity=42 statement_cache_mode=prepare") + config, err = pgx.ParseConfig("description_cache_capacity=0") require.NoError(t, err) - require.NotNil(t, config.BuildStatementCache) - c = config.BuildStatementCache(nil) - require.NotNil(t, c) - require.Equal(t, 42, c.Cap()) - require.Equal(t, stmtcache.ModePrepare, c.Mode()) + require.EqualValues(t, 0, config.DescriptionCacheCapacity) - config, err = pgx.ParseConfig("statement_cache_capacity=42 statement_cache_mode=describe") + config, err = pgx.ParseConfig("description_cache_capacity=42") require.NoError(t, err) - require.NotNil(t, config.BuildStatementCache) - c = config.BuildStatementCache(nil) - require.NotNil(t, c) - require.Equal(t, 42, c.Cap()) - require.Equal(t, stmtcache.ModeDescribe, c.Mode()) + require.EqualValues(t, 42, config.DescriptionCacheCapacity) + + // default_query_exec_mode + // Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See + + config, err = pgx.ParseConfig("default_query_exec_mode=cache_statement") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeCacheStatement, config.DefaultQueryExecMode) + + config, err = pgx.ParseConfig("default_query_exec_mode=cache_describe") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeCacheDescribe, config.DefaultQueryExecMode) + + config, err = pgx.ParseConfig("default_query_exec_mode=describe_exec") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeDescribeExec, config.DefaultQueryExecMode) + + config, err = pgx.ParseConfig("default_query_exec_mode=exec") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeExec, config.DefaultQueryExecMode) + + config, err = pgx.ParseConfig("default_query_exec_mode=simple_protocol") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeSimpleProtocol, config.DefaultQueryExecMode) } func TestParseConfigExtractsDefaultQueryExecMode(t *testing.T) { @@ -316,56 +326,6 @@ func TestExecFailureCloseBefore(t *testing.T) { assert.True(t, pgconn.SafeToRetry(err)) } -func TestExecStatementCacheModes(t *testing.T) { - t.Parallel() - - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - - tests := []struct { - name string - buildStatementCache pgx.BuildStatementCacheFunc - }{ - { - name: "disabled", - buildStatementCache: nil, - }, - { - name: "prepare", - buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModePrepare, 32) - }, - }, - { - name: "describe", - buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 32) - }, - }, - } - - for _, tt := range tests { - func() { - config.BuildStatementCache = tt.buildStatementCache - conn := mustConnect(t, config) - defer closeConn(t, conn) - - commandTag, err := conn.Exec(context.Background(), "select 1") - assert.NoError(t, err, tt.name) - assert.Equal(t, "SELECT 1", commandTag.String(), tt.name) - - commandTag, err = conn.Exec(context.Background(), "select 1 union all select 1") - assert.NoError(t, err, tt.name) - assert.Equal(t, "SELECT 2", commandTag.String(), tt.name) - - commandTag, err = conn.Exec(context.Background(), "select 1") - assert.NoError(t, err, tt.name) - assert.Equal(t, "SELECT 1", commandTag.String(), tt.name) - - ensureConnValid(t, conn) - }() - } -} - func TestExecPerQuerySimpleProtocol(t *testing.T) { t.Parallel() diff --git a/helper_test.go b/helper_test.go index c24d776b..0ef21c5a 100644 --- a/helper_test.go +++ b/helper_test.go @@ -118,10 +118,9 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) - // Can't test function equality, so just test that they are set or not. - assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) + assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName) + assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName) assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) - assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) diff --git a/pgbouncer_test.go b/pgbouncer_test.go index c46f0622..ac22b679 100644 --- a/pgbouncer_test.go +++ b/pgbouncer_test.go @@ -6,8 +6,6 @@ import ( "testing" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/internal/stmtcache" - "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -19,9 +17,8 @@ func TestPgbouncerStatementCacheDescribe(t *testing.T) { } config := mustParseConfig(t, connString) - config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 1024) - } + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.DescriptionCacheCapacity = 1024 testPgbouncer(t, config, 10, 100) } @@ -33,7 +30,6 @@ func TestPgbouncerSimpleProtocol(t *testing.T) { } config := mustParseConfig(t, connString) - config.BuildStatementCache = nil config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol testPgbouncer(t, config, 10, 100) diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index 93e1940d..c0ae07c4 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -164,12 +164,10 @@ func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, test assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) - - // Can't test function equality, so just test that they are set or not. - assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) - + assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName) + assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName) + assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) - assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) diff --git a/query_test.go b/query_test.go index 20cc49c0..7529e076 100644 --- a/query_test.go +++ b/query_test.go @@ -13,7 +13,6 @@ import ( "time" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" @@ -1827,63 +1826,14 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryStatementCacheModes(t *testing.T) { - t.Parallel() - - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - - tests := []struct { - name string - buildStatementCache pgx.BuildStatementCacheFunc - }{ - { - name: "disabled", - buildStatementCache: nil, - }, - { - name: "prepare", - buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModePrepare, 32) - }, - }, - { - name: "describe", - buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { - return stmtcache.New(conn, stmtcache.ModeDescribe, 32) - }, - }, - } - - for _, tt := range tests { - func() { - config.BuildStatementCache = tt.buildStatementCache - conn := mustConnect(t, config) - defer closeConn(t, conn) - - var n int - err := conn.QueryRow(context.Background(), "select 1").Scan(&n) - assert.NoError(t, err, tt.name) - assert.Equal(t, 1, n, tt.name) - - err = conn.QueryRow(context.Background(), "select 2").Scan(&n) - assert.NoError(t, err, tt.name) - assert.Equal(t, 2, n, tt.name) - - err = conn.QueryRow(context.Background(), "select 1").Scan(&n) - assert.NoError(t, err, tt.name) - assert.Equal(t, 1, n, tt.name) - - ensureConnValid(t, conn) - }() - } -} - // https://github.com/jackc/pgx/issues/895 -func TestQueryErrorWithNilStatementCacheMode(t *testing.T) { +func TestQueryErrorWithDisabledStatementCache(t *testing.T) { t.Parallel() config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.BuildStatementCache = nil + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 conn := mustConnect(t, config) defer closeConn(t, conn) diff --git a/rows.go b/rows.go index 3ff8c93e..f3e154bf 100644 --- a/rows.go +++ b/rows.go @@ -161,8 +161,8 @@ func (rows *connRows) Close() { if rows.logger.shouldLog(LogLevelError) { rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"err": rows.err, "sql": rows.sql, "args": logQueryArgs(rows.args)}) } - if rows.err != nil && rows.conn.stmtcache != nil { - rows.conn.stmtcache.StatementErrored(rows.sql, rows.err) + if rows.err != nil && rows.conn.statementCache != nil { + rows.conn.statementCache.StatementErrored(rows.sql, rows.err) } } } From 46966227bc5245d6ede511781e1adf71fcc6d926 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Mar 2022 10:04:02 -0600 Subject: [PATCH 0950/1158] Enable all QueryExecModes for exec path --- conn.go | 84 ++++++++++++++++++++++++++++++++++++++++------------ conn_test.go | 10 ++++++- 2 files changed, 74 insertions(+), 20 deletions(-) diff --git a/conn.go b/conn.go index c85bca88..54d128ab 100644 --- a/conn.go +++ b/conn.go @@ -405,48 +405,62 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) ( } func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { - simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol + mode := c.config.DefaultQueryExecMode optionLoop: for len(arguments) > 0 { switch arg := arguments[0].(type) { case QueryExecMode: - simpleProtocol = arg == QueryExecModeSimpleProtocol + mode = arg arguments = arguments[1:] default: break optionLoop } } + // Always use simple protocol when there are no arguments. + if len(arguments) == 0 { + mode = QueryExecModeSimpleProtocol + } + if sd, ok := c.preparedStatements[sql]; ok { return c.execPrepared(ctx, sd, arguments) } - if simpleProtocol { - return c.execSimpleProtocol(ctx, sql, arguments) - } - - if len(arguments) == 0 { - return c.execSimpleProtocol(ctx, sql, arguments) - } - - if c.statementCache != nil { + switch mode { + case QueryExecModeCacheStatement: + if c.statementCache == nil { + return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") + } sd, err := c.statementCache.Get(ctx, sql) if err != nil { return pgconn.CommandTag{}, err } - if c.statementCache.Mode() == stmtcache.ModeDescribe { - return c.execParams(ctx, sd, arguments) + return c.execPrepared(ctx, sd, arguments) + case QueryExecModeCacheDescribe: + if c.descriptionCache == nil { + return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") + } + sd, err := c.descriptionCache.Get(ctx, sql) + if err != nil { + return pgconn.CommandTag{}, err + } + + return c.execParams(ctx, sd, arguments) + case QueryExecModeDescribeExec: + sd, err := c.Prepare(ctx, "", sql) + if err != nil { + return pgconn.CommandTag{}, err } return c.execPrepared(ctx, sd, arguments) + case QueryExecModeExec: + return c.execSQLParams(ctx, sql, arguments) + case QueryExecModeSimpleProtocol: + return c.execSimpleProtocol(ctx, sql, arguments) + default: + return pgconn.CommandTag{}, fmt.Errorf("unknown QueryExecMode: %v", mode) } - - sd, err := c.Prepare(ctx, "", sql) - if err != nil { - return pgconn.CommandTag{}, err - } - return c.execPrepared(ctx, sd, arguments) } func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) { @@ -510,6 +524,38 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription return result.CommandTag, result.Err } +type unknownArgumentTypeQueryExecModeExecError struct { + arg interface{} +} + +func (e *unknownArgumentTypeQueryExecModeExecError) Error() string { + return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg) +} + +func (c *Conn) execSQLParams(ctx context.Context, sql string, args []interface{}) (pgconn.CommandTag, error) { + c.eqb.Reset() + + anynil.NormalizeSlice(args) + + paramOIDs := make([]uint32, len(args)) + + for i := range args { + dt, ok := c.TypeMap().TypeForValue(args[i]) + if !ok { + return pgconn.CommandTag{}, &unknownArgumentTypeQueryExecModeExecError{arg: args[i]} + } + err := c.eqb.AppendParam(c.typeMap, dt.OID, args[i]) + if err != nil { + return pgconn.CommandTag{}, err + } + paramOIDs[i] = dt.OID + } + + result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, paramOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read() + c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + return result.CommandTag, result.Err +} + func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows { r := &connRows{} diff --git a/conn_test.go b/conn_test.go index 625d9693..85b0da2b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -256,7 +256,15 @@ func TestExecFailureWithArguments(t *testing.T) { assert.False(t, pgconn.SafeToRetry(err)) _, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2") - require.Error(t, err) + if conn.Config().DefaultQueryExecMode == pgx.QueryExecModeExec { + // The PostgreSQL server apparently doesn't care about receiving too many arguments and the only way to detect it + // locally would be to parse the SQL. The simple protocol path has to parse the SQL so it can cheaply do a check + // for the correct number of arguments. But since exec doesn't need to it doesn't make sense to waste time parsing + // the SQL. + require.NoError(t, err) + } else { + require.Error(t, err) + } }) } From 0c166c7620b50edf7d05c38de991132c90bf6318 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Mar 2022 12:47:01 -0600 Subject: [PATCH 0951/1158] Fix BC dates in text format --- pgtype/date.go | 54 ++++++++++++++++++++++++++++++++++++--------- pgtype/date_test.go | 4 ++++ 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/pgtype/date.go b/pgtype/date.go index 1d27fc78..db331e6c 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "strconv" "time" "github.com/jackc/pgx/v5/internal/pgio" @@ -182,18 +183,32 @@ func (encodePlanDateCodecText) Encode(value interface{}, buf []byte) (newBuf []b return nil, nil } - var s string - switch date.InfinityModifier { case Finite: - s = date.Time.Format("2006-01-02") + // Year 0000 is 1 BC + bc := false + year := date.Time.Year() + if year <= 0 { + year = -year + 1 + bc = true + } + + buf = strconv.AppendInt(buf, int64(year), 10) + buf = append(buf, '-') + buf = strconv.AppendInt(buf, int64(date.Time.Month()), 10) + buf = append(buf, '-') + buf = strconv.AppendInt(buf, int64(date.Time.Day()), 10) + + if bc { + buf = append(buf, " BC"...) + } case Infinity: - s = "infinity" + buf = append(buf, "infinity"...) case NegativeInfinity: - s = "-infinity" + buf = append(buf, "-infinity"...) } - return append(buf, s...), nil + return buf, nil } func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { @@ -256,12 +271,29 @@ func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst interface{}) error { case "-infinity": return scanner.ScanDate(Date{InfinityModifier: -Infinity, Valid: true}) default: - t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) - if err != nil { - return err - } + if len(sbuf) >= 10 { + year, err := strconv.ParseInt(sbuf[0:4], 10, 32) + if err != nil { + return fmt.Errorf("cannot parse year: %v", err) + } + month, err := strconv.ParseInt(sbuf[5:7], 10, 32) + if err != nil { + return fmt.Errorf("cannot parse month: %v", err) + } + day, err := strconv.ParseInt(sbuf[8:10], 10, 32) + if err != nil { + return fmt.Errorf("cannot parse day: %v", err) + } - return scanner.ScanDate(Date{Time: t, Valid: true}) + if len(sbuf) == 13 && sbuf[11:] == "BC" { + year = -year + 1 + } + + t := time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, time.UTC) + return scanner.ScanDate(Date{Time: t, Valid: true}) + } else { + return fmt.Errorf("date too short") + } } } diff --git a/pgtype/date_test.go b/pgtype/date_test.go index d57b9115..06539822 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -19,6 +19,10 @@ func isExpectedEqTime(a interface{}) func(interface{}) bool { func TestDateCodec(t *testing.T) { testutil.RunTranscodeTests(t, "date", []testutil.TranscodeTestCase{ + {time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC))}, {time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC))}, {time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC))}, {time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC))}, From 1390a11fe26e8f2d584b0e640909a6b09fd14638 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Mar 2022 14:15:39 -0600 Subject: [PATCH 0952/1158] Query supports QueryExecMode Fixed QueryExecModeExec as it must only use text format without specifying param OIDs. --- conn.go | 227 ++++++++++++++++++++++++-------------- conn_test.go | 10 +- extended_query_builder.go | 8 +- values_test.go | 13 +++ 4 files changed, 162 insertions(+), 96 deletions(-) diff --git a/conn.go b/conn.go index 54d128ab..410acad7 100644 --- a/conn.go +++ b/conn.go @@ -98,6 +98,9 @@ var ErrNoRows = errors.New("no rows in result set") // ErrInvalidLogLevel occurs on attempt to set an invalid log level. var ErrInvalidLogLevel = errors.New("invalid log level") +var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") +var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") + // Connect establishes a connection with a PostgreSQL server with a connection string. See // pgconn.Connect for details. func Connect(ctx context.Context, connString string) (*Conn, error) { @@ -430,7 +433,7 @@ optionLoop: switch mode { case QueryExecModeCacheStatement: if c.statementCache == nil { - return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") + return pgconn.CommandTag{}, errDisabledStatementCache } sd, err := c.statementCache.Get(ctx, sql) if err != nil { @@ -440,7 +443,7 @@ optionLoop: return c.execPrepared(ctx, sd, arguments) case QueryExecModeCacheDescribe: if c.descriptionCache == nil { - return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") + return pgconn.CommandTag{}, errDisabledDescriptionCache } sd, err := c.descriptionCache.Get(ctx, sql) if err != nil { @@ -536,26 +539,51 @@ func (c *Conn) execSQLParams(ctx context.Context, sql string, args []interface{} c.eqb.Reset() anynil.NormalizeSlice(args) - - paramOIDs := make([]uint32, len(args)) - - for i := range args { - dt, ok := c.TypeMap().TypeForValue(args[i]) - if !ok { - return pgconn.CommandTag{}, &unknownArgumentTypeQueryExecModeExecError{arg: args[i]} - } - err := c.eqb.AppendParam(c.typeMap, dt.OID, args[i]) - if err != nil { - return pgconn.CommandTag{}, err - } - paramOIDs[i] = dt.OID + err := c.appendParamsForQueryExecModeExec(args) + if err != nil { + return pgconn.CommandTag{}, err } - result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, paramOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read() + result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats).Read() c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return result.CommandTag, result.Err } +// appendParamsForQueryExecModeExec appends the args to c.eqb. +// +// Parameters must be encoded in the text format because of differences in type conversion between timestamps and +// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the +// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both +// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL +// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date. +// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion +// before converting it to date. This means that dates can be shifted by one day. In text format without that double +// type conversion it takes the date directly and ignores time zone (i.e. it works). +// +// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is +// no way to safely use binary or to specify the parameter OIDs. +func (c *Conn) appendParamsForQueryExecModeExec(args []interface{}) error { + for i := range args { + if args[i] == nil { + err := c.eqb.AppendParamFormat(c.typeMap, 0, TextFormatCode, args[i]) + if err != nil { + return err + } + } else { + dt, ok := c.TypeMap().TypeForValue(args[i]) + if !ok { + return &unknownArgumentTypeQueryExecModeExecError{arg: args[i]} + } + err := c.eqb.AppendParamFormat(c.typeMap, dt.OID, TextFormatCode, args[i]) + if err != nil { + return err + } + } + } + + return nil +} + func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows { r := &connRows{} @@ -589,14 +617,11 @@ const ( // when the the database schema is modified concurrently. QueryExecModeDescribeExec - // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended - // protocol. Queries are executed in a single round trip. Type mappings can be registered with - // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious. - // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use - // a map[string]string directly as an argument. This mode cannot. - // - // It may be necessary to specify the desired type of an argument in the SQL string when it cannot be inferred. e.g. - // "SELECT $1::boolean". + // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol + // with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be + // registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are + // unregistered or ambigious. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know + // the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot. QueryExecModeExec // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. @@ -605,8 +630,13 @@ const ( // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use // a map[string]string directly as an argument. This mode cannot. // - // This mode uses client side parameter interpolation. All values are quoted and escaped. It may be necessary to - // specify the desired type of an argument in the SQL string when it cannot be inferred. e.g. "SELECT $1::boolean". + // QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec with minor + // exceptions such as behavior when multiple result returning queries are erroneously sent in a single string. + // + // QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer + // QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol + // should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does + // not support the extended protocol. QueryExecModeSimpleProtocol ) @@ -640,13 +670,13 @@ type QueryResultFormatsByOID map[uint32]int16 // Err() on the returned Rows must be checked after the Rows is closed to determine if the query executed successfully // as some errors can only be detected by reading the entire response. e.g. A divide by zero error on the last row. // -// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and +// For extra control over how the query is executed, the types QueryExecMode, QueryResultFormats, and // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { var resultFormats QueryResultFormats var resultFormatsByOID QueryResultFormatsByOID - simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol + mode := c.config.DefaultQueryExecMode optionLoop: for len(args) > 0 { @@ -658,19 +688,97 @@ optionLoop: resultFormatsByOID = arg args = args[1:] case QueryExecMode: - simpleProtocol = arg == QueryExecModeSimpleProtocol + mode = arg args = args[1:] default: break optionLoop } } + c.eqb.Reset() + anynil.NormalizeSlice(args) rows := c.getRows(ctx, sql, args) var err error - sd, ok := c.preparedStatements[sql] + sd := c.preparedStatements[sql] + if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec { + if sd == nil { + switch mode { + case QueryExecModeCacheStatement: + if c.statementCache == nil { + err = errDisabledStatementCache + rows.fatal(err) + return rows, err + } + sd, err = c.statementCache.Get(ctx, sql) + if err != nil { + rows.fatal(err) + return rows, err + } + case QueryExecModeCacheDescribe: + if c.descriptionCache == nil { + err = errDisabledDescriptionCache + rows.fatal(err) + return rows, err + } + sd, err = c.descriptionCache.Get(ctx, sql) + if err != nil { + rows.fatal(err) + return rows, err + } + case QueryExecModeDescribeExec: + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + rows.fatal(err) + return rows, err + } + } + } - if simpleProtocol && !ok { + if len(sd.ParamOIDs) != len(args) { + rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args))) + return rows, rows.err + } + + rows.sql = sd.SQL + + for i := range args { + err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + } + + if resultFormatsByOID != nil { + resultFormats = make([]int16, len(sd.Fields)) + for i := range resultFormats { + resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)] + } + } + + if resultFormats == nil { + for i := range sd.Fields { + c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + resultFormats = c.eqb.resultFormats + } + + if mode == QueryExecModeCacheDescribe { + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats) + } else { + rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) + } + } else if mode == QueryExecModeExec { + err := c.appendParamsForQueryExecModeExec(args) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats) + } else if mode == QueryExecModeSimpleProtocol { sql, err = c.sanitizeForSimpleQuery(sql, args...) if err != nil { rows.fatal(err) @@ -688,61 +796,10 @@ optionLoop: } return rows, nil - } - - c.eqb.Reset() - - if !ok { - if c.statementCache != nil { - sd, err = c.statementCache.Get(ctx, sql) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - } else { - sd, err = c.pgConn.Prepare(ctx, "", sql, nil) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - } - } - if len(sd.ParamOIDs) != len(args) { - rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args))) - return rows, rows.err - } - - rows.sql = sd.SQL - - anynil.NormalizeSlice(args) - - for i := range args { - err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - } - - if resultFormatsByOID != nil { - resultFormats = make([]int16, len(sd.Fields)) - for i := range resultFormats { - resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)] - } - } - - if resultFormats == nil { - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) - } - - resultFormats = c.eqb.resultFormats - } - - if c.statementCache != nil && c.statementCache.Mode() == stmtcache.ModeDescribe { - rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats) } else { - rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) + err = fmt.Errorf("unknown QueryExecMode: %v", mode) + rows.fatal(err) + return rows, rows.err } c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. diff --git a/conn_test.go b/conn_test.go index 85b0da2b..625d9693 100644 --- a/conn_test.go +++ b/conn_test.go @@ -256,15 +256,7 @@ func TestExecFailureWithArguments(t *testing.T) { assert.False(t, pgconn.SafeToRetry(err)) _, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2") - if conn.Config().DefaultQueryExecMode == pgx.QueryExecModeExec { - // The PostgreSQL server apparently doesn't care about receiving too many arguments and the only way to detect it - // locally would be to parse the SQL. The simple protocol path has to parse the SQL so it can cheaply do a check - // for the correct number of arguments. But since exec doesn't need to it doesn't make sense to waste time parsing - // the SQL. - require.NoError(t, err) - } else { - require.Error(t, err) - } + require.Error(t, err) }) } diff --git a/extended_query_builder.go b/extended_query_builder.go index 5409c0fd..0b6e1962 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -14,9 +14,13 @@ type extendedQueryBuilder struct { func (eqb *extendedQueryBuilder) AppendParam(m *pgtype.Map, oid uint32, arg interface{}) error { f := eqb.chooseParameterFormatCode(m, oid, arg) - eqb.paramFormats = append(eqb.paramFormats, f) + return eqb.AppendParamFormat(m, oid, f, arg) +} - v, err := eqb.encodeExtendedParamValue(m, oid, f, arg) +func (eqb *extendedQueryBuilder) AppendParamFormat(m *pgtype.Map, oid uint32, format int16, arg interface{}) error { + eqb.paramFormats = append(eqb.paramFormats, format) + + v, err := eqb.encodeExtendedParamValue(m, oid, format, arg) if err != nil { return err } diff --git a/values_test.go b/values_test.go index b7d5c572..f036b8a6 100644 --- a/values_test.go +++ b/values_test.go @@ -891,6 +891,19 @@ func TestEncodeTypeRename(t *testing.T) { inString := _string("foo") var outString _string + // pgx.QueryExecModeExec requires all types to be registered. + conn.TypeMap().RegisterDefaultPgType(inInt, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt8, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt16, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt32, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt64, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint8, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint16, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint32, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint64, "int8") + conn.TypeMap().RegisterDefaultPgType(inString, "text") + err := conn.QueryRow(context.Background(), "select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text", inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString, ).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString) From cb721dfb5b99f4372ba06f47085fb62292c069a0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Mar 2022 15:06:13 -0600 Subject: [PATCH 0953/1158] SendBatch supports default QueryExecMode --- batch_test.go | 817 ++++++++++++++++++++++++------------------------- conn.go | 143 +++++---- helper_test.go | 9 +- 3 files changed, 500 insertions(+), 469 deletions(-) diff --git a/batch_test.go b/batch_test.go index da35646b..3e5a2d46 100644 --- a/batch_test.go +++ b/batch_test.go @@ -15,230 +15,227 @@ import ( func TestConnSendBatch(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + skipCockroachDB(t, conn, "Server serial type is incompatible with test") - skipCockroachDB(t, conn, "Server serial type is incompatible with test") - - sql := `create temporary table ledger( + sql := `create temporary table ledger( id serial primary key, description varchar not null, amount int not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - batch := &pgx.Batch{} - batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) - batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2) - batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3) - batch.Queue("select id, description, amount from ledger order by id") - batch.Queue("select id, description, amount from ledger order by id") - batch.Queue("select * from ledger where false") - batch.Queue("select sum(amount) from ledger") - - br := conn.SendBatch(context.Background(), batch) - - ct, err := br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } - - ct, err = br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } - - ct, err = br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } - - selectFromLedgerExpectedRows := []struct { - id int32 - description string - amount int32 - }{ - {1, "q1", 1}, - {2, "q2", 2}, - {3, "q3", 3}, - } - - rows, err := br.Query() - if err != nil { - t.Error(err) - } - - var id int32 - var description string - var amount int32 - rowCount := 0 - - for rows.Next() { - if rowCount >= len(selectFromLedgerExpectedRows) { - t.Fatalf("got too many rows: %d", rowCount) - } - - if err := rows.Scan(&id, &description, &amount); err != nil { - t.Fatalf("row %d: %v", rowCount, err) - } - - if id != selectFromLedgerExpectedRows[rowCount].id { - t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) - } - if description != selectFromLedgerExpectedRows[rowCount].description { - t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) - } - if amount != selectFromLedgerExpectedRows[rowCount].amount { - t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) - } - - rowCount++ - } - - if rows.Err() != nil { - t.Fatal(rows.Err()) - } - - rowCount = 0 - _, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error { - if id != selectFromLedgerExpectedRows[rowCount].id { - t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) - } - if description != selectFromLedgerExpectedRows[rowCount].description { - t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) - } - if amount != selectFromLedgerExpectedRows[rowCount].amount { - t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) - } - - rowCount++ - - return nil - }) - if err != nil { - t.Error(err) - } - - err = br.QueryRow().Scan(&id, &description, &amount) - if !errors.Is(err, pgx.ErrNoRows) { - t.Errorf("expected pgx.ErrNoRows but got: %v", err) - } - - err = br.QueryRow().Scan(&amount) - if err != nil { - t.Error(err) - } - if amount != 6 { - t.Errorf("amount => %v, want %v", amount, 6) - } - - err = br.Close() - if err != nil { - t.Fatal(err) - } - - ensureConnValid(t, conn) -} - -func TestConnSendBatchMany(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - sql := `create temporary table ledger( - id serial primary key, - description varchar not null, - amount int not null - );` - mustExec(t, conn, sql) - - batch := &pgx.Batch{} - - numInserts := 1000 - - for i := 0; i < numInserts; i++ { + batch := &pgx.Batch{} batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) - } - batch.Queue("select count(*) from ledger") + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2) + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3) + batch.Queue("select id, description, amount from ledger order by id") + batch.Queue("select id, description, amount from ledger order by id") + batch.Queue("select * from ledger where false") + batch.Queue("select sum(amount) from ledger") - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - for i := 0; i < numInserts; i++ { ct, err := br.Exec() - assert.NoError(t, err) - assert.EqualValues(t, 1, ct.RowsAffected()) - } + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - var actualInserts int - err := br.QueryRow().Scan(&actualInserts) - assert.NoError(t, err) - assert.EqualValues(t, numInserts, actualInserts) + ct, err = br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - err = br.Close() - require.NoError(t, err) + ct, err = br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - ensureConnValid(t, conn) -} + selectFromLedgerExpectedRows := []struct { + id int32 + description string + amount int32 + }{ + {1, "q1", 1}, + {2, "q2", 2}, + {3, "q3", 3}, + } -func TestConnSendBatchWithPreparedStatement(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") - - _, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n") - if err != nil { - t.Fatal(err) - } - - batch := &pgx.Batch{} - - queryCount := 3 - for i := 0; i < queryCount; i++ { - batch.Queue("ps1", 5) - } - - br := conn.SendBatch(context.Background(), batch) - - for i := 0; i < queryCount; i++ { rows, err := br.Query() if err != nil { - t.Fatal(err) + t.Error(err) } - for k := 0; rows.Next(); k++ { - var n int - if err := rows.Scan(&n); err != nil { - t.Fatal(err) + var id int32 + var description string + var amount int32 + rowCount := 0 + + for rows.Next() { + if rowCount >= len(selectFromLedgerExpectedRows) { + t.Fatalf("got too many rows: %d", rowCount) } - if n != k { - t.Fatalf("n => %v, want %v", n, k) + + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatalf("row %d: %v", rowCount, err) } + + if id != selectFromLedgerExpectedRows[rowCount].id { + t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) + } + if description != selectFromLedgerExpectedRows[rowCount].description { + t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) + } + if amount != selectFromLedgerExpectedRows[rowCount].amount { + t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) + } + + rowCount++ } if rows.Err() != nil { t.Fatal(rows.Err()) } - } - err = br.Close() - if err != nil { - t.Fatal(err) - } + rowCount = 0 + _, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error { + if id != selectFromLedgerExpectedRows[rowCount].id { + t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) + } + if description != selectFromLedgerExpectedRows[rowCount].description { + t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) + } + if amount != selectFromLedgerExpectedRows[rowCount].amount { + t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) + } - ensureConnValid(t, conn) + rowCount++ + + return nil + }) + if err != nil { + t.Error(err) + } + + err = br.QueryRow().Scan(&id, &description, &amount) + if !errors.Is(err, pgx.ErrNoRows) { + t.Errorf("expected pgx.ErrNoRows but got: %v", err) + } + + err = br.QueryRow().Scan(&amount) + if err != nil { + t.Error(err) + } + if amount != 6 { + t.Errorf("amount => %v, want %v", amount, 6) + } + + err = br.Close() + if err != nil { + t.Fatal(err) + } + }) +} + +func TestConnSendBatchMany(t *testing.T) { + t.Parallel() + + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null + );` + mustExec(t, conn, sql) + + batch := &pgx.Batch{} + + numInserts := 1000 + + for i := 0; i < numInserts; i++ { + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) + } + batch.Queue("select count(*) from ledger") + + br := conn.SendBatch(context.Background(), batch) + + for i := 0; i < numInserts; i++ { + ct, err := br.Exec() + assert.NoError(t, err) + assert.EqualValues(t, 1, ct.RowsAffected()) + } + + var actualInserts int + err := br.QueryRow().Scan(&actualInserts) + assert.NoError(t, err) + assert.EqualValues(t, numInserts, actualInserts) + + err = br.Close() + require.NoError(t, err) + }) +} + +func TestConnSendBatchWithPreparedStatement(t *testing.T) { + t.Parallel() + + modes := []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + pgx.QueryExecModeExec, + // Don't test simple mode with prepared statements. + } + testWithQueryExecModes(t, modes, func(t *testing.T, conn *pgx.Conn) { + skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") + _, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n") + if err != nil { + t.Fatal(err) + } + + batch := &pgx.Batch{} + + queryCount := 3 + for i := 0; i < queryCount; i++ { + batch.Queue("ps1", 5) + } + + br := conn.SendBatch(context.Background(), batch) + + for i := 0; i < queryCount; i++ { + rows, err := br.Query() + if err != nil { + t.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Fatal(err) + } + if n != k { + t.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + } + + err = br.Close() + if err != nil { + t.Fatal(err) + } + }) } // https://github.com/jackc/pgx/issues/856 @@ -303,316 +300,308 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing. func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - batch := &pgx.Batch{} - batch.Queue("select n from generate_series(0,5) n") - batch.Queue("select n from generate_series(0,5) n") + batch := &pgx.Batch{} + batch.Queue("select n from generate_series(0,5) n") + batch.Queue("select n from generate_series(0,5) n") - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - rows, err := br.Query() - if err != nil { - t.Error(err) - } - - for i := 0; i < 3; i++ { - if !rows.Next() { - t.Error("expected a row to be available") - } - - var n int - if err := rows.Scan(&n); err != nil { + rows, err := br.Query() + if err != nil { t.Error(err) } - if n != i { - t.Errorf("n => %v, want %v", n, i) + + for i := 0; i < 3; i++ { + if !rows.Next() { + t.Error("expected a row to be available") + } + + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } } - } - rows.Close() + rows.Close() - rows, err = br.Query() - if err != nil { - t.Error(err) - } - - for i := 0; rows.Next(); i++ { - var n int - if err := rows.Scan(&n); err != nil { + rows, err = br.Query() + if err != nil { t.Error(err) } - if n != i { - t.Errorf("n => %v, want %v", n, i) + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } } - } - if rows.Err() != nil { - t.Error(rows.Err()) - } + if rows.Err() != nil { + t.Error(rows.Err()) + } - err = br.Close() - if err != nil { - t.Fatal(err) - } + err = br.Close() + if err != nil { + t.Fatal(err) + } - ensureConnValid(t, conn) + }) } func TestConnSendBatchQueryError(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - batch := &pgx.Batch{} - batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0") - batch.Queue("select n from generate_series(0,5) n") + batch := &pgx.Batch{} + batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0") + batch.Queue("select n from generate_series(0,5) n") - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - rows, err := br.Query() - if err != nil { - t.Error(err) - } - - for i := 0; rows.Next(); i++ { - var n int - if err := rows.Scan(&n); err != nil { + rows, err := br.Query() + if err != nil { t.Error(err) } - if n != i { - t.Errorf("n => %v, want %v", n, i) + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } } - } - if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") { - t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) - } + if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) + } - err = br.Close() - if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { - t.Errorf("rows.Err() => %v, want error code %v", err, 22012) - } + err = br.Close() + if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("rows.Err() => %v, want error code %v", err, 22012) + } - ensureConnValid(t, conn) + }) } func TestConnSendBatchQuerySyntaxError(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - batch := &pgx.Batch{} - batch.Queue("select 1 1") + batch := &pgx.Batch{} + batch.Queue("select 1 1") - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - var n int32 - err := br.QueryRow().Scan(&n) - if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") { - t.Errorf("rows.Err() => %v, want error code %v", err, 42601) - } + var n int32 + err := br.QueryRow().Scan(&n) + if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") { + t.Errorf("rows.Err() => %v, want error code %v", err, 42601) + } - err = br.Close() - if err == nil { - t.Error("Expected error") - } + err = br.Close() + if err == nil { + t.Error("Expected error") + } - ensureConnValid(t, conn) + }) } func TestConnSendBatchQueryRowInsert(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - sql := `create temporary table ledger( + sql := `create temporary table ledger( id serial primary key, description varchar not null, amount int not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - batch := &pgx.Batch{} - batch.Queue("select 1") - batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) + batch := &pgx.Batch{} + batch.Queue("select 1") + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - var value int - err := br.QueryRow().Scan(&value) - if err != nil { - t.Error(err) - } + var value int + err := br.QueryRow().Scan(&value) + if err != nil { + t.Error(err) + } - ct, err := br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 2 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) - } + ct, err := br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 2 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) + } - br.Close() + br.Close() - ensureConnValid(t, conn) + }) } func TestConnSendBatchQueryPartialReadInsert(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - sql := `create temporary table ledger( + sql := `create temporary table ledger( id serial primary key, description varchar not null, amount int not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - batch := &pgx.Batch{} - batch.Queue("select 1 union all select 2 union all select 3") - batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) + batch := &pgx.Batch{} + batch.Queue("select 1 union all select 2 union all select 3") + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - rows, err := br.Query() - if err != nil { - t.Error(err) - } - rows.Close() + rows, err := br.Query() + if err != nil { + t.Error(err) + } + rows.Close() - ct, err := br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 2 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) - } + ct, err := br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 2 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) + } - br.Close() + br.Close() - ensureConnValid(t, conn) + }) } func TestTxSendBatch(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - sql := `create temporary table ledger1( + sql := `create temporary table ledger1( id serial primary key, description varchar not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - sql = `create temporary table ledger2( + sql = `create temporary table ledger2( id int primary key, amount int not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - tx, _ := conn.Begin(context.Background()) - batch := &pgx.Batch{} - batch.Queue("insert into ledger1(description) values($1) returning id", "q1") + tx, _ := conn.Begin(context.Background()) + batch := &pgx.Batch{} + batch.Queue("insert into ledger1(description) values($1) returning id", "q1") - br := tx.SendBatch(context.Background(), batch) + br := tx.SendBatch(context.Background(), batch) - var id int - err := br.QueryRow().Scan(&id) - if err != nil { - t.Error(err) - } - br.Close() + var id int + err := br.QueryRow().Scan(&id) + if err != nil { + t.Error(err) + } + br.Close() - batch = &pgx.Batch{} - batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2) - batch.Queue("select amount from ledger2 where id = $1", id) + batch = &pgx.Batch{} + batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2) + batch.Queue("select amount from ledger2 where id = $1", id) - br = tx.SendBatch(context.Background(), batch) + br = tx.SendBatch(context.Background(), batch) - ct, err := br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } + ct, err := br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - var amount int - err = br.QueryRow().Scan(&amount) - if err != nil { - t.Error(err) - } + var amount int + err = br.QueryRow().Scan(&amount) + if err != nil { + t.Error(err) + } - br.Close() - tx.Commit(context.Background()) + br.Close() + tx.Commit(context.Background()) - var count int - conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count) - if count != 1 { - t.Errorf("count => %v, want %v", count, 1) - } + var count int + conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count) + if count != 1 { + t.Errorf("count => %v, want %v", count, 1) + } - err = br.Close() - if err != nil { - t.Fatal(err) - } + err = br.Close() + if err != nil { + t.Fatal(err) + } - ensureConnValid(t, conn) + }) } func TestTxSendBatchRollback(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - sql := `create temporary table ledger1( + sql := `create temporary table ledger1( id serial primary key, description varchar not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - tx, _ := conn.Begin(context.Background()) - batch := &pgx.Batch{} - batch.Queue("insert into ledger1(description) values($1) returning id", "q1") + tx, _ := conn.Begin(context.Background()) + batch := &pgx.Batch{} + batch.Queue("insert into ledger1(description) values($1) returning id", "q1") - br := tx.SendBatch(context.Background(), batch) + br := tx.SendBatch(context.Background(), batch) - var id int - err := br.QueryRow().Scan(&id) - if err != nil { - t.Error(err) - } - br.Close() - tx.Rollback(context.Background()) + var id int + err := br.QueryRow().Scan(&id) + if err != nil { + t.Error(err) + } + br.Close() + tx.Rollback(context.Background()) - row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id) - var count int - row.Scan(&count) - if count != 0 { - t.Errorf("count => %v, want %v", count, 0) - } + row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id) + var count int + row.Scan(&count) + if count != 0 { + t.Errorf("count => %v, want %v", count, 0) + } - ensureConnValid(t, conn) + }) } func TestConnBeginBatchDeferredError(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") - mustExec(t, conn, `create temporary table t ( + mustExec(t, conn, `create temporary table t ( id text primary key, n int not null, unique (n) deferrable initially deferred @@ -620,36 +609,36 @@ func TestConnBeginBatchDeferredError(t *testing.T) { insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) - batch := &pgx.Batch{} + batch := &pgx.Batch{} - batch.Queue(`update t set n=n+1 where id='b' returning *`) + batch.Queue(`update t set n=n+1 where id='b' returning *`) - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - rows, err := br.Query() - if err != nil { - t.Error(err) - } - - for rows.Next() { - var id string - var n int32 - err = rows.Scan(&id, &n) + rows, err := br.Query() if err != nil { - t.Fatal(err) + t.Error(err) } - } - err = br.Close() - if err == nil { - t.Fatal("expected error 23505 but got none") - } + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } - if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" { - t.Fatalf("expected error 23505, got %v", err) - } + err = br.Close() + if err == nil { + t.Fatal("expected error 23505 but got none") + } - ensureConnValid(t, conn) + if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + }) } func TestConnSendBatchNoStatementCache(t *testing.T) { diff --git a/conn.go b/conn.go index 410acad7..025d8022 100644 --- a/conn.go +++ b/conn.go @@ -861,9 +861,10 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, sc // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { - simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol - var sb strings.Builder - if simpleProtocol { + mode := c.config.DefaultQueryExecMode + + if mode == QueryExecModeSimpleProtocol { + var sb strings.Builder for i, bi := range b.items { if i > 0 { sb.WriteByte(';') @@ -884,66 +885,102 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } } - distinctUnpreparedQueries := map[string]struct{}{} - - for _, bi := range b.items { - if _, ok := c.preparedStatements[bi.query]; ok { - continue - } - distinctUnpreparedQueries[bi.query] = struct{}{} - } - - var stmtCache stmtcache.Cache - if len(distinctUnpreparedQueries) > 0 { - if c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) { - stmtCache = c.statementCache - } else { - stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) - } - - for sql, _ := range distinctUnpreparedQueries { - _, err := stmtCache.Get(ctx, sql) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - } - } - batch := &pgconn.Batch{} - for _, bi := range b.items { - c.eqb.Reset() + if mode == QueryExecModeExec { + for _, bi := range b.items { + c.eqb.Reset() + anynil.NormalizeSlice(bi.arguments) - sd := c.preparedStatements[bi.query] - if sd == nil { - var err error - sd, err = stmtCache.Get(ctx, bi.query) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} + sd := c.preparedStatements[bi.query] + if sd != nil { + if len(sd.ParamOIDs) != len(bi.arguments) { + return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} + } + + for i := range bi.arguments { + err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + } + + for i := range sd.Fields { + c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) + } else { + err := c.appendParamsForQueryExecModeExec(bi.arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + batch.ExecParams(bi.query, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats) + } + } + } else { + + distinctUnpreparedQueries := map[string]struct{}{} + + for _, bi := range b.items { + if _, ok := c.preparedStatements[bi.query]; ok { + continue + } + distinctUnpreparedQueries[bi.query] = struct{}{} + } + + var stmtCache stmtcache.Cache + if len(distinctUnpreparedQueries) > 0 { + if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) { + stmtCache = c.statementCache + } else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) { + stmtCache = c.descriptionCache + } else { + stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) + } + + for sql, _ := range distinctUnpreparedQueries { + _, err := stmtCache.Get(ctx, sql) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } } } - if len(sd.ParamOIDs) != len(bi.arguments) { - return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} - } + for _, bi := range b.items { + c.eqb.Reset() - anynil.NormalizeSlice(bi.arguments) - - for i := range bi.arguments { - err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} + sd := c.preparedStatements[bi.query] + if sd == nil { + var err error + sd, err = stmtCache.Get(ctx, bi.query) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } } - } - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) - } + if len(sd.ParamOIDs) != len(bi.arguments) { + return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} + } - if sd.Name == "" { - batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats) - } else { - batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) + anynil.NormalizeSlice(bi.arguments) + + for i := range bi.arguments { + err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + } + + for i := range sd.Fields { + c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + if sd.Name == "" { + batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats) + } else { + batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) + } } } diff --git a/helper_test.go b/helper_test.go index 0ef21c5a..26509946 100644 --- a/helper_test.go +++ b/helper_test.go @@ -13,13 +13,18 @@ import ( ) func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) { - for _, mode := range []pgx.QueryExecMode{ + modes := []pgx.QueryExecMode{ pgx.QueryExecModeCacheStatement, pgx.QueryExecModeCacheDescribe, pgx.QueryExecModeDescribeExec, pgx.QueryExecModeExec, pgx.QueryExecModeSimpleProtocol, - } { + } + testWithQueryExecModes(t, modes, f) +} + +func testWithQueryExecModes(t *testing.T, modes []pgx.QueryExecMode, f func(t *testing.T, conn *pgx.Conn)) { + for _, mode := range modes { t.Run(mode.String(), func(t *testing.T) { config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) From 72b72b9ae9bb799cc5da1b75b60137e5c0aa7f06 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Mar 2022 15:07:32 -0600 Subject: [PATCH 0954/1158] Remove dead code --- conn.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/conn.go b/conn.go index 025d8022..eb685645 100644 --- a/conn.go +++ b/conn.go @@ -244,12 +244,6 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { c.descriptionCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, c.config.DescriptionCacheCapacity) } - // Replication connections can't execute the queries to - // populate the c.PgTypes and c.pgsqlAfInet - if _, ok := config.Config.RuntimeParams["replication"]; ok { - return c, nil - } - return c, nil } From 8c18d7808bdf0ccad2887914c78f7b71d0e66e5e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Mar 2022 17:01:12 -0500 Subject: [PATCH 0955/1158] Add documentation --- CHANGELOG.md | 254 ++++-------------- README.md | 31 +-- pgtype/doc.go | 77 ++++++ .../example_custom_type_test.go | 5 +- .../example_json_test.go | 2 +- pgtype/pgtype.go | 3 +- 6 files changed, 132 insertions(+), 240 deletions(-) create mode 100644 pgtype/doc.go rename example_custom_type_test.go => pgtype/example_custom_type_test.go (93%) rename example_json_test.go => pgtype/example_json_test.go (96%) diff --git a/CHANGELOG.md b/CHANGELOG.md index b627ddda..fefb7c4d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,243 +1,83 @@ # Unreleased v5 -* Import github.com/jackc/pgtype repository -* Import github.com/jackc/pgconn repository +## Merged Packages -## pgtype Changes +`github.com/jackc/pgtype`, `github.com/jackc/pgconn`, and `github.com/jackc/pgproto3` are now included in the main `github.com/jackc/pgx` repository. Previously there was confusion as to where issues should be reported, additional release work due to releasing multiple packages, and less clear changelogs. -* Types now have Valid boolean field instead of Status byte. This matches database/sql pattern. -* Extracted integrations with github.com/shopspring/decimal and github.com/gofrs/uuid to https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. +## pgconn -# 4.15.0 (February 7, 2022) +`CommandTag` is now an opaque type instead of directly exposing an underlying `[]byte`. -* Upgrade to pgconn v1.11.0 -* Upgrade to pgtype v1.10.0 -* Upgrade puddle to v1.2.1 -* Make BatchResults.Close safe to be called multiple times +## pgtype -# 4.14.1 (November 28, 2021) +The `pgtype` package has been significantly changed. -* Upgrade pgtype to v1.9.1 (fixes unintentional change to timestamp binary decoding) -* Start pgxpool background health check after initial connections +### NULL Representation -# 4.14.0 (November 20, 2021) +Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a `Valid` `bool` field to harmonize with how `database/sql` represents NULL and to make the zero value useable. -* Upgrade pgconn to v1.10.1 -* Upgrade pgproto3 to v2.2.0 -* Upgrade pgtype to v1.9.0 -* Upgrade puddle to v1.2.0 -* Add QueryFunc to BatchResults -* Add context options to zerologadapter (Thomas Frössman) -* Add zerologadapter.NewContextLogger (urso) -* Eager initialize minpoolsize on connect (Daniel) -* Unpin memory used by large queries immediately after use +### Codec and Value Split -# 4.13.0 (July 24, 2021) +Previously, the type system combined decoding and encoding values with the value types. e.g. Type `Int8` both handled encoding and decoding the PostgreSQL representation and acted as a value object. This caused some difficulties when there was not an exact 1 to 1 relationship between the Go types and the PostgreSQL types For example, scanning a PostgreSQL binary `numeric` into a Go `float64` was awkward (see https://github.com/jackc/pgtype/issues/147). This concepts have been separated. A `Codec` only has responsibility for encoding and decoding values. Value types are generally defined by implementing an interface that a particular `Codec` understands (e.g. `PointScanner` and `PointValuer` for the PostgreSQL `point` type). -* Trimmed pseudo-dependencies in Go modules from other packages tests -* Upgrade pgconn -- context cancellation no longer will return a net.Error -* Support time durations for simple protocol (Michael Darr) +### Array Types -# 4.12.0 (July 10, 2021) +All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This significantly reduced the amount of code and the compiled binary size. This also means that less common array types such as `point[]` are now supported. -* ResetSession hook is called before a connection is reused from pool for another query (Dmytro Haranzha) -* stdlib: Add RandomizeHostOrderFunc (dkinder) -* stdlib: add OptionBeforeConnect (dkinder) -* stdlib: Do not reuse ConnConfig strings (Andrew Kimball) -* stdlib: implement Conn.ResetSession (Jonathan Amsterdam) -* Upgrade pgconn to v1.9.0 -* Upgrade pgtype to v1.8.0 +### Composite Types -# 4.11.0 (March 25, 2021) +Composite types must be registered before use. `CompositeFields` may still be used to construct and destruct composite values, but any type may now implement `CompositeIndexGetter` and `CompositeIndexScanner` to be used as a composite. -* Add BeforeConnect callback to pgxpool.Config (Robert Froehlich) -* Add Ping method to pgxpool.Conn (davidsbond) -* Added a kitlog level log adapter (Fabrice Aneche) -* Make ScanArgError public to allow identification of offending column (Pau Sanchez) -* Add *pgxpool.AcquireFunc -* Add BeginFunc and BeginTxFunc -* Add prefer_simple_protocol to connection string -* Add logging on CopyFrom (Patrick Hemmer) -* Add comment support when sanitizing SQL queries (Rusakow Andrew) -* Do not panic on double close of pgxpool.Pool (Matt Schultz) -* Avoid panic on SendBatch on closed Tx (Matt Schultz) -* Update pgconn to v1.8.1 -* Update pgtype to v1.7.0 +### pgxtype -# 4.10.1 (December 19, 2020) +load data type moved to conn -* Fix panic on Query error with nil stmtcache. +### Bytea -# 4.10.0 (December 3, 2020) +The `Bytea` and `GenericBinary` types have been replaced. Use the following instead: -* Add CopyFromSlice to simplify CopyFrom usage (Egon Elbre) -* Remove broken prepared statements from stmtcache (Ethan Pailes) -* stdlib: consider any Ping error as fatal -* Update puddle to v1.1.3 - this fixes an issue where concurrent Acquires can hang when a connection cannot be established -* Update pgtype to v1.6.2 +* `[]byte` - For normal usage directly use `[]byte`. +* `DriverBytes` - Uses driver memory only available until next database method call. Avoids a copy and an allocation. +* `PreallocBytes` - Uses preallocated byte slice to avoid an allocation. +* `UndecodedBytes` - Avoids any decoding. Allows working with raw bytes. -# 4.9.2 (November 3, 2020) +### Dropped lib/pq Support -The underlying library updates fix an issue where appending to a scanned slice could corrupt other data. +`pgtype` previously supported and was tested against [lib/pq](https://github.com/lib/pq). While it will continue to work in most cases this is no longer supported. -* Update pgconn to v1.7.2 -* Update pgproto3 to v2.0.6 +### database/sql Scan -# 4.9.1 (October 31, 2020) +Previously, most `Scan` implementations would convert `[]byte` to `string` automatically to decode a text value. Now only `string` is handled. This is to allow the possibility of future binary support in `database/sql` mode by considering `[]byte` to be binary format and `string` text format. This change should have no effect for any use with `pgx`. The previous behavior was only necessary for `lib/pq` compatibility. -* Update pgconn to v1.7.1 -* Update pgtype to v1.6.1 -* Fix SendBatch of all prepared statements with statement cache disabled +### Number Type Fields Include Bit size -# 4.9.0 (September 26, 2020) +`Int2`, `Int4`, `Int8`, `Float4`, `Float8`, and `Uint32` fields now include bit size. e.g. `Int` is renamed to `Int64`. This matches the convention set by `database/sql`. In addition, for comparable types like `pgtype.Int8` and `sql.NullInt64` the structures are identical. This means they can be directly converted one to another. -* pgxpool now waits for connection cleanup to finish before making room in pool for another connection. This prevents temporarily exceeding max pool size. -* Fix when scanning a column to nil to skip it on the first row but scanning it to a real value on a subsequent row. -* Fix prefer simple protocol with prepared statements. (Jinzhu) -* Fix FieldDescriptions not being available on Rows before calling Next the first time. -* Various minor fixes in updated versions of pgconn, pgtype, and puddle. +### 3rd Party Type Integrations -# 4.8.1 (July 29, 2020) +* Extracted integrations with github.com/shopspring/decimal and github.com/gofrs/uuid to https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims the pgx dependency tree. -* Update pgconn to v1.6.4 - * Fix deadlock on error after CommandComplete but before ReadyForQuery - * Fix panic on parsing DSN with trailing '=' +### Other Changes -# 4.8.0 (July 22, 2020) +* `Bit` and `Varbit` are both replaced by the `Bits` type. +* `CID`, `OID`, `OIDValue`, and `XID` are replaced by the `Uint32` type. +* `Hstore` is now defined as `map[string]*string`. +* `JSON` and `JSONB` types removed. Use `[]byte` or `string` directly. +* `QChar` type removed. Use `rune` or `byte` directly. +* `Macaddr` type removed. Use `net.HardwareAddr` directly. +* Renamed `pgtype.ConnInfo` to `pgtype.Map`. +* Renamed `pgtype.DataType` to `pgtype.Type`. +* Renamed `pgtype.None` to `pgtype.Finite`. +* `RegisterType` now accepts a `*Type` instead of `Type`. -* All argument types supported by native pgx should now also work through database/sql -* Update pgconn to v1.6.3 -* Update pgtype to v1.4.2 +## Reduced Memory Usage by Reusing Read Buffers -# 4.7.2 (July 14, 2020) +Previously, the connection read buffer would allocate large chunks of memory and never reuse them. This allowed transferring ownership to anything such as scanned values without incurring an additional allocation and memory copy. However, this came at the cost of overall increased memory allocation size. But worse it was also possible to pin large chunks of memory be retaining a reference to a small value that originally came directly from the read buffer. Now ownership remains with the read buffer and anything needing to retain a value must make a copy. -* Improve performance of Columns() (zikaeroh) -* Fix fatal Commit() failure not being considered fatal -* Update pgconn to v1.6.2 -* Update pgtype to v1.4.1 +## Query Execution Modes -# 4.7.1 (June 29, 2020) +Control over automatic prepared statement caching and simple protocol use are now combined into query execution mode. See documentation for `QueryExecMode`. -* Fix stdlib decoding error with certain order and combination of fields +## 3rd Party Logger Integration -# 4.7.0 (June 27, 2020) - -* Update pgtype to v1.4.0 -* Update pgconn to v1.6.1 -* Update puddle to v1.1.1 -* Fix context propagation with Tx commit and Rollback (georgysavva) -* Add lazy connect option to pgxpool (georgysavva) -* Fix connection leak if pgxpool.BeginTx() fail (Jean-Baptiste Bronisz) -* Add native Go slice support for strings and numbers to simple protocol -* stdlib add default timeouts for Conn.Close() and Stmt.Close() (georgysavva) -* Assorted performance improvements especially with large result sets -* Fix close pool on not lazy connect failure (Yegor Myskin) -* Add Config copy (georgysavva) -* Support SendBatch with Simple Protocol (Jordan Lewis) -* Better error logging on rows close (Igor V. Kozinov) -* Expose stdlib.Conn.Conn() to enable database/sql.Conn.Raw() -* Improve unknown type support for database/sql -* Fix transaction commit failure closing connection - -# 4.6.0 (March 30, 2020) - -* stdlib: Bail early if preloading rows.Next() results in rows.Err() (Bas van Beek) -* Sanitize time to microsecond accuracy (Andrew Nicoll) -* Update pgtype to v1.3.0 -* Update pgconn to v1.5.0 - * Update golang.org/x/crypto for security fix - * Implement "verify-ca" SSL mode - -# 4.5.0 (March 7, 2020) - -* Update to pgconn v1.4.0 - * Fixes QueryRow with empty SQL - * Adds PostgreSQL service file support -* Add Len() to *pgx.Batch (WGH) -* Better logging for individual batch items (Ben Bader) - -# 4.4.1 (February 14, 2020) - -* Update pgconn to v1.3.2 - better default read buffer size -* Fix race in CopyFrom - -# 4.4.0 (February 5, 2020) - -* Update puddle to v1.1.0 - fixes possible deadlock when acquire is cancelled -* Update pgconn to v1.3.1 - fixes CopyFrom deadlock when multiple NoticeResponse received during copy -* Update pgtype to v1.2.0 -* Add MaxConnIdleTime to pgxpool (Patrick Ellul) -* Add MinConns to pgxpool (Patrick Ellul) -* Fix: stdlib.ReleaseConn closes connections left in invalid state - -# 4.3.0 (January 23, 2020) - -* Fix Rows.Values panic when unable to decode -* Add Rows.Values support for unknown types -* Add DriverContext support for stdlib (Alex Gaynor) -* Update pgproto3 to v2.0.1 to never return an io.EOF as it would be misinterpreted by database/sql. Instead return io.UnexpectedEOF. - -# 4.2.1 (January 13, 2020) - -* Update pgconn to v1.2.1 (fixes context cancellation data race introduced in v1.2.0)) - -# 4.2.0 (January 11, 2020) - -* Update pgconn to v1.2.0. -* Update pgtype to v1.1.0. -* Return error instead of panic when wrong number of arguments passed to Exec. (malstoun) -* Fix large objects functionality when PreferSimpleProtocol = true. -* Restore GetDefaultDriver which existed in v3. (Johan Brandhorst) -* Add RegisterConnConfig to stdlib which replaces the removed RegisterDriverConfig from v3. - -# 4.1.2 (October 22, 2019) - -* Fix dbSavepoint.Begin recursive self call -* Upgrade pgtype to v1.0.2 - fix scan pointer to pointer - -# 4.1.1 (October 21, 2019) - -* Fix pgxpool Rows.CommandTag() infinite loop / typo - -# 4.1.0 (October 12, 2019) - -## Potentially Breaking Changes - -Technically, two changes are breaking changes, but in practice these are extremely unlikely to break existing code. - -* Conn.Begin and Conn.BeginTx return a Tx interface instead of the internal dbTx struct. This is necessary for the Conn.Begin method to signature as other methods that begin a transaction. -* Add Conn() to Tx interface. This is necessary to allow code using a Tx to access the *Conn (and pgconn.PgConn) on which the Tx is executing. - -## Fixes - -* Releasing a busy connection closes the connection instead of returning an unusable connection to the pool -* Do not mutate config.Config.OnNotification in connect - -# 4.0.1 (September 19, 2019) - -* Fix statement cache cleanup. -* Corrected daterange OID. -* Fix Tx when committing or rolling back multiple times in certain cases. -* Improve documentation. - -# 4.0.0 (September 14, 2019) - -v4 is a major release with many significant changes some of which are breaking changes. The most significant are -included below. - -* Simplified establishing a connection with a connection string. -* All potentially blocking operations now require a context.Context. The non-context aware functions have been removed. -* OIDs are hard-coded for known types. This saves the query on connection. -* Context cancellations while network activity is in progress is now always fatal. Previously, it was sometimes recoverable. This led to increased complexity in pgx itself and in application code. -* Go modules are required. -* Errors are now implemented in the Go 1.13 style. -* `Rows` and `Tx` are now interfaces. -* The connection pool as been decoupled from pgx and is now a separate, included package (github.com/jackc/pgx/v4/pgxpool). -* pgtype has been spun off to a separate package (github.com/jackc/pgtype). -* pgproto3 has been spun off to a separate package (github.com/jackc/pgproto3/v2). -* Logical replication support has been spun off to a separate package (github.com/jackc/pglogrepl). -* Lower level PostgreSQL functionality is now implemented in a separate package (github.com/jackc/pgconn). -* Tests are now configured with environment variables. -* Conn has an automatic statement cache by default. -* Batch interface has been simplified. -* QueryArgs has been removed. +All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency tree. diff --git a/README.md b/README.md index 3a3f1f8b..159712dd 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,8 @@ # pgx - PostgreSQL Driver and Toolkit +*This is the v5 development branch. It is still in active development and testing.* + pgx is a pure Go driver and toolkit for PostgreSQL. pgx aims to be low-level, fast, and performant, while also enabling PostgreSQL-specific features that the standard `database/sql` package does not allow for. @@ -13,8 +15,6 @@ The toolkit component is a related set of packages that implement PostgreSQL fun and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers, proxies, load balancers, logical replication clients, etc. -The current release of `pgx v4` requires Go modules. To use the previous version, checkout and vendor the `v3` branch. - ## Example Usage ```go @@ -74,7 +74,7 @@ pgx supports many features beyond what is available through `database/sql`: * Full TLS connection control * Binary format support for custom types (allows for much quicker encoding/decoding) * COPY protocol support for faster bulk data loads -* Extendable logging support including built-in support for `log15adapter`, [`logrus`](https://github.com/sirupsen/logrus), [`zap`](https://github.com/uber-go/zap), and [`zerolog`](https://github.com/rs/zerolog) +* Extendable logging support * Connection pool with after-connect hook for arbitrary connection setup * Listen / notify * Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings @@ -129,7 +129,7 @@ In addition, there are tests specific for PgBouncer that will be executed if `PG ## Supported Go and PostgreSQL Versions -pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.16 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). +pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.17 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). ## Version Policy @@ -137,29 +137,6 @@ pgx follows semantic versioning for the documented public API on stable releases ## PGX Family Libraries -pgx is the head of a family of PostgreSQL libraries. Many of these can be used independently. Many can also be accessed -from pgx for lower-level control. - -### [github.com/jackc/v4/pgconn](https://github.com/jackc/pgx/tree/master/pgconn) - -`pgconn` is a lower-level PostgreSQL database driver that operates at nearly the same level as the C library `libpq`. - -### [github.com/jackc/pgx/v5/pgxpool](https://github.com/jackc/pgx/tree/master/pgxpool) - -`pgxpool` is a connection pool for pgx. pgx is entirely decoupled from its default pool implementation. This means that pgx can be used with a different pool or without any pool at all. - -### [github.com/jackc/pgx/v5/stdlib](https://github.com/jackc/pgx/tree/master/stdlib) - -This is a `database/sql` compatibility layer for pgx. pgx can be used as a normal `database/sql` driver, but at any time, the native interface can be acquired for more performance or PostgreSQL specific functionality. - -### [github.com/jackc/pgx/v5/pgtype](https://github.com/jackc/pgx/tree/master/pgtype) - -Over 70 PostgreSQL types are supported including `uuid`, `hstore`, `json`, `bytea`, `numeric`, `interval`, `inet`, and arrays. - -### [github.com/jackc/pgproto3](https://github.com/jackc/pgproto3) - -pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling. - ### [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl) pglogrepl provides functionality to act as a client for PostgreSQL logical replication. diff --git a/pgtype/doc.go b/pgtype/doc.go new file mode 100644 index 00000000..1de29bd2 --- /dev/null +++ b/pgtype/doc.go @@ -0,0 +1,77 @@ +// Package pgtype converts between Go and PostgreSQL values. +/* +The primary type is the Map type. It is a map of PostgreSQL types identified by OID (object ID) to a Codec. A Codec is +responsible for converting between Go and PostgreSQL values. NewMap creates a Map with all supported standard PostgreSQL +types already registered. Additional types can be registered with Map.RegisterType. + +Use Map.Scan and Map.Encode to decode PostgreSQL values to Go and encode Go values to PostgreSQL respectively. + +JSON Support + +pgtype automatically marshals and unmarshals data from json and jsonb PostgreSQL types. + +Array Support + +ArrayCodec implements support for arrays. If pgtype supports type T then it can easily support []T by registering an +ArrayCodec for the appropriate PostgreSQL OID. + +Composite Support + +CompositeCodec implements support for PostgreSQL composite types. Go structs can be scanned into if the public fields of +the struct are in the exact order and type of the PostgreSQL type or by implementing CompositeIndexScanner and +CompositeIndexGetter. + +Enum Support + +PostgreSQL enums can usually be treated as text. However, EnumCodec implements support for interning strings which can reduce memory usage. + +Array, Composite, and Enum Type Registration + +Array, composite, and enum types can be easily registered from a pgx.Conn with the LoadType method. + +Extending Existing Type Support + +Generally, all Codecs will support interfaces that can be implemented to enable scanning and encoding. For example, +PointCodec can use any Go type that implements the PointScanner and PointValuer interfaces. So rather than use +pgtype.Point and application can directly use its own point type with pgtype as long as it implements those interfaces. + +Sometimes pgx supports a PostgreSQL type such as numeric but the Go type is in an external package that does not have +pgx support such as github.com/shopspring/decimal. These types can be registered with pgtype with custom conversion +logic. See https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid for a example +integrations. + +Entirely New Type Support + +If the PostgreSQL type is not already supported then an OID / Codec mapping can be registered with Map.RegisterType. +There is no difference between a Codec defined and registered by the application and a Codec built in to pgtype. See any +of the Codecs in pgtype for Codec examples and for examples of type registration. + +Encoding Unknown Types + +pgtype works best when the OID of the PostgreSQL type is known. But in some cases such as using the simple protocol the +OID is unknown. In this case Map.RegisterDefaultPgType can be used to register an assumed OID for a particular Go type. + +Overview of Scanning Implementation + +The first step is to use the OID to lookup the correct Codec. If the OID is unavailable, Map will try to find the OID +from previous calls of Map.RegisterDefaultPgType. The Map will call the Codec's PlanScan method to get a plan for +scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types are +interfaces rather than explicit types. For example, PointCodec can use any Go type that implments the PointScanner and +PointValuer interfaces. + +If a Go value is not supported directly by a Codec then Map will try wrapping it with additional logic and try again. +For example, Int8Codec does not support scanning into a renamed type (e.g. type myInt64 int64). But Map will detect that +myInt64 is a renamed type and create a plan that converts the value to the underlying int64 type and then passes that to +the Codec (see TryFindUnderlyingTypeScanPlan). + +These plan wrappers are contained in Map.TryWrapScanPlanFuncs. By default these contain shared logic to handle renamed +types, pointers to pointers, slices, composite types, etc. Additional plan wrappers can be added to seamlessly integrate +types that do not support pgx directly. For example, the before mentioned +https://github.com/jackc/pgx-shopspring-decimal package detects decimal.Decimal values, wraps them in something +implementing NumericScanner and passes that to the Codec. + +Map.Scan and Map.Encode are convenience methods that wrap Map.PlanScan and Map.PlanEncode. Determining how to scan or +encode a particular type may be a time consuming operation. Hence the planning and execution steps of a conversion are +internally separated. +*/ +package pgtype diff --git a/example_custom_type_test.go b/pgtype/example_custom_type_test.go similarity index 93% rename from example_custom_type_test.go rename to pgtype/example_custom_type_test.go index fc0d4b78..2fd63bcc 100644 --- a/example_custom_type_test.go +++ b/pgtype/example_custom_type_test.go @@ -1,17 +1,14 @@ -package pgx_test +package pgtype_test import ( "context" "fmt" "os" - "regexp" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" ) -var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) - // Point represents a point that may be null. type Point struct { X, Y float32 // Coordinates of point diff --git a/example_json_test.go b/pgtype/example_json_test.go similarity index 96% rename from example_json_test.go rename to pgtype/example_json_test.go index 017699b9..c11348b7 100644 --- a/example_json_test.go +++ b/pgtype/example_json_test.go @@ -1,4 +1,4 @@ -package pgx_test +package pgtype_test import ( "context" diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 1cc809b1..bd065e01 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -132,6 +132,7 @@ const ( BinaryFormatCode = 1 ) +// A Codec converts between Go and PostgreSQL values. type Codec interface { // FormatSupported returns true if the format is supported. FormatSupported(int16) bool @@ -139,7 +140,7 @@ type Codec interface { // PreferredFormat returns the preferred format. PreferredFormat() int16 - // PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be + // PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be // found then nil is returned. PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan From 9f23ed84ba342b8363b360d599196cfe9af5a41d Mon Sep 17 00:00:00 2001 From: Patrick Audley Date: Sat, 19 Mar 2022 19:49:06 -0700 Subject: [PATCH 0956/1158] Minor typo in Changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fefb7c4d..553fd915 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -72,7 +72,7 @@ Previously, most `Scan` implementations would convert `[]byte` to `string` autom ## Reduced Memory Usage by Reusing Read Buffers -Previously, the connection read buffer would allocate large chunks of memory and never reuse them. This allowed transferring ownership to anything such as scanned values without incurring an additional allocation and memory copy. However, this came at the cost of overall increased memory allocation size. But worse it was also possible to pin large chunks of memory be retaining a reference to a small value that originally came directly from the read buffer. Now ownership remains with the read buffer and anything needing to retain a value must make a copy. +Previously, the connection read buffer would allocate large chunks of memory and never reuse them. This allowed transferring ownership to anything such as scanned values without incurring an additional allocation and memory copy. However, this came at the cost of overall increased memory allocation size. But worse it was also possible to pin large chunks of memory by retaining a reference to a small value that originally came directly from the read buffer. Now ownership remains with the read buffer and anything needing to retain a value must make a copy. ## Query Execution Modes From b103a6efbda898f3e26b9562da71837b1163f498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jens=20Emil=20Schulz=20=C3=98stergaard?= Date: Sun, 20 Mar 2022 15:05:19 +0100 Subject: [PATCH 0957/1158] test: jsonbarray set failing test cases --- jsonb_array_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/jsonb_array_test.go b/jsonb_array_test.go index 65f1777a..e63d0c00 100644 --- a/jsonb_array_test.go +++ b/jsonb_array_test.go @@ -1,6 +1,8 @@ package pgtype_test import ( + "encoding/json" + "reflect" "testing" "github.com/jackc/pgtype" @@ -34,3 +36,53 @@ func TestJSONBArrayTranscode(t *testing.T) { }, }) } + +func TestJSONBArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.JSONBArray + }{ + {source: []string{"{}"}, result: pgtype.JSONBArray{ + Elements: []pgtype.JSONB{pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{pgtype.ArrayDimension{Length: 1, LowerBound: 1}}, + Status: pgtype.Present, + }}, + {source: [][]byte{[]byte("{}")}, result: pgtype.JSONBArray{ + Elements: []pgtype.JSONB{pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{pgtype.ArrayDimension{Length: 1, LowerBound: 1}}, + Status: pgtype.Present, + }}, + {source: [][]byte{[]byte(`{"foo":1}`), []byte(`{"bar":2}`)}, result: pgtype.JSONBArray{ + Elements: []pgtype.JSONB{pgtype.JSONB{Bytes: []byte(`{"foo":1}`), Status: pgtype.Present}, pgtype.JSONB{Bytes: []byte(`{"bar":2}`), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{pgtype.ArrayDimension{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }}, + {source: []json.RawMessage{json.RawMessage(`{"foo":1}`), json.RawMessage(`{"bar":2}`)}, result: pgtype.JSONBArray{ + Elements: []pgtype.JSONB{pgtype.JSONB{Bytes: []byte(`{"foo":1}`), Status: pgtype.Present}, pgtype.JSONB{Bytes: []byte(`{"bar":2}`), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{pgtype.ArrayDimension{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }}, + {source: []json.RawMessage{json.RawMessage(`{"foo":12}`), json.RawMessage(`{"bar":2}`)}, result: pgtype.JSONBArray{ + Elements: []pgtype.JSONB{pgtype.JSONB{Bytes: []byte(`{"foo":12}`), Status: pgtype.Present}, pgtype.JSONB{Bytes: []byte(`{"bar":2}`), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{pgtype.ArrayDimension{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }}, + {source: []json.RawMessage{json.RawMessage(`{"foo":1}`), json.RawMessage(`{"bar":{"x":2}}`)}, result: pgtype.JSONBArray{ + Elements: []pgtype.JSONB{pgtype.JSONB{Bytes: []byte(`{"foo":1}`), Status: pgtype.Present}, pgtype.JSONB{Bytes: []byte(`{"bar":{"x":2}}`), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{pgtype.ArrayDimension{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }}, + } + + for i, tt := range successfulTests { + var d pgtype.JSONBArray + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(d, tt.result) { + t.Errorf("%d: expected %+v to convert to %+v, but it was %+v", i, tt.source, tt.result, d) + } + } +} From 4c6f1b1dc49de587531c756b1120fc877c739cee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jens=20Emil=20Schulz=20=C3=98stergaard?= Date: Sun, 20 Mar 2022 15:05:43 +0100 Subject: [PATCH 0958/1158] fix: add json rawmessage to typed_array_gen.sh --- jsonb_array.go | 29 +++++++++++++++++++++++++++++ typed_array_gen.sh | 2 +- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/jsonb_array.go b/jsonb_array.go index c4b7cd3d..e78ad377 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "reflect" @@ -72,6 +73,25 @@ func (dst *JSONBArray) Set(src interface{}) error { } } + case []json.RawMessage: + if value == nil { + *dst = JSONBArray{Status: Null} + } else if len(value) == 0 { + *dst = JSONBArray{Status: Present} + } else { + elements := make([]JSONB, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = JSONBArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []JSONB: if value == nil { *dst = JSONBArray{Status: Null} @@ -214,6 +234,15 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { } return nil + case *[]json.RawMessage: + *v = make([]json.RawMessage, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + } } diff --git a/typed_array_gen.sh b/typed_array_gen.sh index ea28be07..1f4098c7 100755 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -20,7 +20,7 @@ erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[] erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go -erb pgtype_array_type=JSONBArray pgtype_element_type=JSONB go_array_types=[]string,[][]byte element_type_name=jsonb text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go +erb pgtype_array_type=JSONBArray pgtype_element_type=JSONB go_array_types=[]string,[][]byte,[]json.RawMessage element_type_name=jsonb text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go # While the binary format is theoretically possible it is only practical to use the text format. erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go From 5ca048ed2d74fc33a087250c806ac89e6be08d00 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Mar 2022 19:20:22 -0500 Subject: [PATCH 0959/1158] Fix crash with pointer to nil struct --- pgtype/pgtype.go | 7 ++++++- pgtype/pgtype_test.go | 9 +++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index bd065e01..5180e27a 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -952,7 +952,12 @@ func TryWrapStructScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, return nil, nil, false } - targetElemValue := targetValue.Elem() + var targetElemValue reflect.Value + if targetValue.IsNil() { + targetElemValue = reflect.New(targetValue.Type().Elem()) + } else { + targetElemValue = targetValue.Elem() + } targetElemType := targetElemValue.Type() if targetElemType.Kind() == reflect.Struct { diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index ff19790a..776d176c 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -175,6 +175,15 @@ func TestTypeMapScanUnregisteredOIDToCustomType(t *testing.T) { assert.Nil(t, pCt) } +func TestTypeMapScanPointerToNilStructDoesNotCrash(t *testing.T) { + m := pgtype.NewMap() + + type myStruct struct{} + var p *myStruct + err := m.Scan(0, pgx.TextFormatCode, []byte("(foo,bar)"), &p) + require.NotNil(t, err) +} + func TestTypeMapScanUnknownOIDTextFormat(t *testing.T) { m := pgtype.NewMap() From be5a6cc9c0981907ad67092d9954c40922f1fbf4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Mar 2022 19:20:50 -0500 Subject: [PATCH 0960/1158] Remove obsolete test --- pgtype/pgtype_test.go | 52 ------------------------------------------- 1 file changed, 52 deletions(-) diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 776d176c..afa54e7a 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -1,7 +1,6 @@ package pgtype_test import ( - "bytes" "database/sql" "errors" "net" @@ -124,57 +123,6 @@ func TestTypeMapScanUnknownOIDToStringsAndBytes(t *testing.T) { assert.Equal(t, []byte("foo"), []byte(rb)) } -type pgCustomType struct { - a string - b string -} - -func (ct *pgCustomType) DecodeText(m *pgtype.Map, buf []byte) error { - // This is not a complete parser for the text format of composite types. This is just for test purposes. - if buf == nil { - return errors.New("cannot parse null") - } - - if len(buf) < 2 { - return errors.New("invalid text format") - } - - parts := bytes.Split(buf[1:len(buf)-1], []byte(",")) - if len(parts) != 2 { - return errors.New("wrong number of parts") - } - - ct.a = string(parts[0]) - ct.b = string(parts[1]) - - return nil -} - -func TestTypeMapScanUnregisteredOIDToCustomType(t *testing.T) { - t.Skip("TODO - unskip later in v5") // may no longer be relevent - unregisteredOID := uint32(999999) - m := pgtype.NewMap() - - var ct pgCustomType - err := m.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct) - assert.NoError(t, err) - assert.Equal(t, "foo", ct.a) - assert.Equal(t, "bar", ct.b) - - // Scan value into pointer to custom type - var pCt *pgCustomType - err = m.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt) - assert.NoError(t, err) - require.NotNil(t, pCt) - assert.Equal(t, "foo", pCt.a) - assert.Equal(t, "bar", pCt.b) - - // Scan null into pointer to custom type - err = m.Scan(unregisteredOID, pgx.TextFormatCode, nil, &pCt) - assert.NoError(t, err) - assert.Nil(t, pCt) -} - func TestTypeMapScanPointerToNilStructDoesNotCrash(t *testing.T) { m := pgtype.NewMap() From 0cd7c757c392e505c47d4d56ff29524880fd9b5f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Mar 2022 19:23:40 -0500 Subject: [PATCH 0961/1158] Fix skipped test --- values_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/values_test.go b/values_test.go index f036b8a6..d728edab 100644 --- a/values_test.go +++ b/values_test.go @@ -1036,7 +1036,6 @@ order by a nulls first } func TestScanIntoByteSlice(t *testing.T) { - t.Skip("TODO - unskip later in v5") t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -1070,7 +1069,7 @@ func TestScanIntoByteSlice(t *testing.T) { sql string err string }{ - {"int binary", "select 42", "can't scan into dest[0]: cannot assign 42 into *[]uint8"}, + {"int binary", "select 42", "can't scan into dest[0]: cannot scan OID 23 in binary format into *[]uint8"}, } { t.Run(tt.name, func(t *testing.T) { var buf []byte From 29bec2b97e3972cb9b2f6b26a04b348a6ce068c6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Mar 2022 19:25:41 -0500 Subject: [PATCH 0962/1158] Remove skipped test for scan binary to string Receiving a binary value and encoding it back into text seems to be an anti-pattern to may. Don't want to silently enable this. May be able to reverse course later if necessary. --- values_test.go | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/values_test.go b/values_test.go index d728edab..7a6cb3a2 100644 --- a/values_test.go +++ b/values_test.go @@ -242,33 +242,6 @@ func mustParseCIDR(t *testing.T, s string) *net.IPNet { return ipnet } -func TestStringToNotTextTypeTranscode(t *testing.T) { - t.Skip("TODO - unskip later in v5") // Should this even be a thing... i.e. anything is scanable to a string to a string - - t.Parallel() - - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - input := "01086ee0-4963-4e35-9116-30c173a8d0bd" - - var output string - err := conn.QueryRow(context.Background(), "select $1::uuid", input).Scan(&output) - if err != nil { - t.Fatal(err) - } - if input != output { - t.Errorf("uuid: Did not transcode string successfully: %s is not %s", input, output) - } - - err = conn.QueryRow(context.Background(), "select $1::uuid", &input).Scan(&output) - if err != nil { - t.Fatal(err) - } - if input != output { - t.Errorf("uuid: Did not transcode pointer to string successfully: %s is not %s", input, output) - } - }) -} - func TestInetCIDRTranscodeIPNet(t *testing.T) { t.Parallel() From 793eb53017ea8bf9a5ee2557c9b24b98ae9cffb4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Mar 2022 19:28:48 -0500 Subject: [PATCH 0963/1158] Enable test with updated error message --- query_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/query_test.go b/query_test.go index 7529e076..7c9077c5 100644 --- a/query_test.go +++ b/query_test.go @@ -1889,7 +1889,6 @@ func TestConnQueryFunc(t *testing.T) { } func TestConnQueryFuncScanError(t *testing.T) { - t.Skip("TODO - unskip later in v5") t.Parallel() testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { @@ -1906,7 +1905,7 @@ func TestConnQueryFuncScanError(t *testing.T) { return nil }, ) - require.EqualError(t, err, "can't scan into dest[0]: unable to assign to *int") + require.EqualError(t, err, "can't scan into dest[0]: cannot scan OID 25 in text format into *int") require.Nil(t, ct) }) } From 95c03dc9ae906a7d42d19c85b43edbd8b4422548 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Mar 2022 19:34:57 -0500 Subject: [PATCH 0964/1158] Unskip and fix tests --- query_test.go | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/query_test.go b/query_test.go index 7c9077c5..3742775e 100644 --- a/query_test.go +++ b/query_test.go @@ -263,7 +263,6 @@ func TestConnQueryReadRowMultipleTimes(t *testing.T) { // https://github.com/jackc/pgx/issues/228 func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) { - t.Skip("TODO - unskip later in v5") t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -271,8 +270,8 @@ func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) var s string - err := conn.QueryRow(context.Background(), "select 1").Scan(&s) - if err == nil || !(strings.Contains(err.Error(), "cannot decode binary value into string") || strings.Contains(err.Error(), "cannot assign")) { + err := conn.QueryRow(context.Background(), "select point(1,2)").Scan(&s) + if err == nil || !(strings.Contains(err.Error(), "cannot scan OID 600 in binary format into *string")) { t.Fatalf("Expected Scan to fail to encode binary value into string but: %v", err) } @@ -956,7 +955,6 @@ func TestQueryRowCoreByteSlice(t *testing.T) { } func TestQueryRowErrors(t *testing.T) { - t.Skip("TODO - unskip later in v5") t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -976,10 +974,10 @@ func TestQueryRowErrors(t *testing.T) { scanArgs []interface{} err string }{ - // {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, - // {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, - {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "unable to assign"}, - // {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"}, + {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, + {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, + {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot scan OID 25 in text format into *int16"}, + {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "unable to encode 705 into OID 600"}, } for i, tt := range tests { From 69580cd519fb102d71a9933c11dbc2475ad04e33 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Mar 2022 19:43:27 -0500 Subject: [PATCH 0965/1158] Fix a test failure --- query_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/query_test.go b/query_test.go index 3742775e..dfb8d2a4 100644 --- a/query_test.go +++ b/query_test.go @@ -1904,7 +1904,7 @@ func TestConnQueryFuncScanError(t *testing.T) { }, ) require.EqualError(t, err, "can't scan into dest[0]: cannot scan OID 25 in text format into *int") - require.Nil(t, ct) + require.Equal(t, pgconn.CommandTag{}, ct) }) } From 0fd0688d4f70600369238b7af0668846c27a43c8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Mar 2022 19:56:48 -0500 Subject: [PATCH 0966/1158] Alter some tests for CockroachDB --- query_test.go | 8 +++++++- values_test.go | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/query_test.go b/query_test.go index dfb8d2a4..1cff72bf 100644 --- a/query_test.go +++ b/query_test.go @@ -268,6 +268,8 @@ func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + skipCockroachDB(t, conn, "Server does not support point type") + var s string err := conn.QueryRow(context.Background(), "select point(1,2)").Scan(&s) @@ -373,7 +375,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { defer closeConn(t, conn) // Read a single value incorrectly - rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) + rows, err := conn.Query(context.Background(), "select n::int4 from generate_series(1,$1) n", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -960,6 +962,10 @@ func TestQueryRowErrors(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip("Skipping due to known server missing point type") + } + type allTypes struct { i16 int16 i int diff --git a/values_test.go b/values_test.go index 7a6cb3a2..15c16746 100644 --- a/values_test.go +++ b/values_test.go @@ -1042,7 +1042,7 @@ func TestScanIntoByteSlice(t *testing.T) { sql string err string }{ - {"int binary", "select 42", "can't scan into dest[0]: cannot scan OID 23 in binary format into *[]uint8"}, + {"int binary", "select 42::int4", "can't scan into dest[0]: cannot scan OID 23 in binary format into *[]uint8"}, } { t.Run(tt.name, func(t *testing.T) { var buf []byte From 210ebb4a50047f8227d13434a0fa30c4d222218b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Mar 2022 19:59:56 -0500 Subject: [PATCH 0967/1158] Disable incomptible test with CockroachDB --- copy_from_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/copy_from_test.go b/copy_from_test.go index 6e2fe952..1182cb1e 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -416,6 +416,8 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + skipCockroachDB(t, conn, "Server copy error does not fail fast") + mustExec(t, conn, `create temporary table foo( a bytea not null )`) From e04b35bfcb741cc853d739a88e6e6c2eab591544 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Mar 2022 20:31:00 -0500 Subject: [PATCH 0968/1158] Make pgtype test compat with CockroachDB when possible --- pgtype/array_codec_test.go | 10 ++++++++++ pgtype/box_test.go | 2 ++ pgtype/bytea_test.go | 2 +- pgtype/circle_test.go | 2 ++ pgtype/composite_test.go | 8 ++++++++ pgtype/inet_test.go | 2 ++ pgtype/line_test.go | 2 ++ pgtype/lseg_test.go | 2 ++ pgtype/macaddr_test.go | 2 ++ pgtype/numeric_test.go | 6 ++++++ pgtype/path_test.go | 2 ++ pgtype/pgtype_test.go | 10 ++++++++++ pgtype/point_test.go | 2 ++ pgtype/polygon_test.go | 2 ++ pgtype/qchar_test.go | 2 ++ pgtype/range_codec_test.go | 8 ++++++++ pgtype/text_test.go | 4 ++++ pgtype/tid_test.go | 2 ++ pgtype/timestamp_test.go | 2 ++ pgtype/timestamptz_test.go | 2 ++ 20 files changed, 73 insertions(+), 1 deletion(-) diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index 81e564c7..0cc6d7cb 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -110,6 +110,8 @@ func TestArrayCodecDecodeValue(t *testing.T) { } func TestArrayCodecScanMultipleDimensions(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) @@ -127,6 +129,8 @@ func TestArrayCodecScanMultipleDimensions(t *testing.T) { } func TestArrayCodecScanMultipleDimensionsEmpty(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) @@ -144,6 +148,8 @@ func TestArrayCodecScanMultipleDimensionsEmpty(t *testing.T) { } func TestArrayCodecScanWrongMultipleDimensions(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) @@ -158,6 +164,8 @@ func TestArrayCodecScanWrongMultipleDimensions(t *testing.T) { } func TestArrayCodecEncodeMultipleDimensions(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) @@ -175,6 +183,8 @@ func TestArrayCodecEncodeMultipleDimensions(t *testing.T) { } func TestArrayCodecEncodeMultipleDimensionsRagged(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) diff --git a/pgtype/box_test.go b/pgtype/box_test.go index 72e37b76..173fb1f5 100644 --- a/pgtype/box_test.go +++ b/pgtype/box_test.go @@ -8,6 +8,8 @@ import ( ) func TestBoxCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support box type") + testutil.RunTranscodeTests(t, "box", []testutil.TranscodeTestCase{ { pgtype.Box{ diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index ae4a8760..443b73ce 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -119,7 +119,7 @@ func TestUndecodedBytes(t *testing.T) { ctx := context.Background() var buf []byte - err := conn.QueryRow(ctx, `select 1`).Scan((*pgtype.UndecodedBytes)(&buf)) + err := conn.QueryRow(ctx, `select 1::int4`).Scan((*pgtype.UndecodedBytes)(&buf)) require.NoError(t, err) require.Len(t, buf, 4) diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go index f38d8194..b78d35ba 100644 --- a/pgtype/circle_test.go +++ b/pgtype/circle_test.go @@ -8,6 +8,8 @@ import ( ) func TestCircleTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support box type") + testutil.RunTranscodeTests(t, "circle", []testutil.TranscodeTestCase{ { pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index d97f617b..0f112ebd 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -12,6 +12,8 @@ import ( ) func TestCompositeCodecTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) @@ -90,6 +92,8 @@ func (p *point3d) ScanIndex(i int) interface{} { } func TestCompositeCodecTranscodeStruct(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) @@ -125,6 +129,8 @@ create type point3d as ( } func TestCompositeCodecTranscodeStructWrapper(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) @@ -164,6 +170,8 @@ create type point3d as ( } func TestCompositeCodecDecodeValue(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index c3f66755..249caf3f 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -35,6 +35,8 @@ func TestInetTranscode(t *testing.T) { } func TestCidrTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support cidr type (see https://github.com/cockroachdb/cockroach/issues/18846)") + testutil.RunTranscodeTests(t, "cidr", []testutil.TranscodeTestCase{ {mustParseInet(t, "0.0.0.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "0.0.0.0/32"))}, {mustParseInet(t, "127.0.0.1/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "127.0.0.1/32"))}, diff --git a/pgtype/line_test.go b/pgtype/line_test.go index 6c7b734b..8e3d782c 100644 --- a/pgtype/line_test.go +++ b/pgtype/line_test.go @@ -9,6 +9,8 @@ import ( ) func TestLineTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support type line") + conn := testutil.MustConnectPgx(t) defer conn.Close(context.Background()) if _, ok := conn.TypeMap().TypeForName("line"); !ok { diff --git a/pgtype/lseg_test.go b/pgtype/lseg_test.go index 51fe2adb..e754b3b2 100644 --- a/pgtype/lseg_test.go +++ b/pgtype/lseg_test.go @@ -8,6 +8,8 @@ import ( ) func TestLsegTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support type lseg") + testutil.RunTranscodeTests(t, "lseg", []testutil.TranscodeTestCase{ { pgtype.Lseg{ diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go index 2ce7b007..06262876 100644 --- a/pgtype/macaddr_test.go +++ b/pgtype/macaddr_test.go @@ -26,6 +26,8 @@ func isExpectedEqHardwareAddr(a interface{}) func(interface{}) bool { } func TestMacaddrCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support type macaddr") + testutil.RunTranscodeTests(t, "macaddr", []testutil.TranscodeTestCase{ { mustParseMacaddr(t, "01:23:45:67:89:ab"), diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 448cfff2..d5c60575 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -71,6 +71,8 @@ func mustParseNumeric(t *testing.T, src string) pgtype.Numeric { } func TestNumericCodec(t *testing.T) { + skipCockroachDB(t, "server formats numeric text format differently") + max := new(big.Int).Exp(big.NewInt(10), big.NewInt(147454), nil) max.Add(max, big.NewInt(1)) longestNumeric := pgtype.Numeric{Int: max, Exp: -16383, Valid: true} @@ -143,6 +145,8 @@ func TestNumericFloat64Valuer(t *testing.T) { } func TestNumericCodecFuzz(t *testing.T) { + skipCockroachDB(t, "server formats numeric text format differently") + r := rand.New(rand.NewSource(0)) max := &big.Int{} max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) @@ -166,6 +170,8 @@ func TestNumericCodecFuzz(t *testing.T) { } func TestNumericMarshalJSON(t *testing.T) { + skipCockroachDB(t, "server formats numeric text format differently") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) diff --git a/pgtype/path_test.go b/pgtype/path_test.go index 546f4d36..40df2bfb 100644 --- a/pgtype/path_test.go +++ b/pgtype/path_test.go @@ -27,6 +27,8 @@ func isExpectedEqPath(a interface{}) func(interface{}) bool { } func TestPathTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support type path") + testutil.RunTranscodeTests(t, "path", []testutil.TranscodeTestCase{ { pgtype.Path{ diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index afa54e7a..2ee32907 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -8,6 +8,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" _ "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -66,6 +67,15 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { return addr } +func skipCockroachDB(t testing.TB, msg string) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip(msg) + } +} + func TestTypeMapScanNilIsNoOp(t *testing.T) { m := pgtype.NewMap() diff --git a/pgtype/point_test.go b/pgtype/point_test.go index 03d948b7..62fcfc51 100644 --- a/pgtype/point_test.go +++ b/pgtype/point_test.go @@ -10,6 +10,8 @@ import ( ) func TestPointCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support type point") + testutil.RunTranscodeTests(t, "point", []testutil.TranscodeTestCase{ { pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Valid: true}, diff --git a/pgtype/polygon_test.go b/pgtype/polygon_test.go index 9c7c0182..6c6fc60d 100644 --- a/pgtype/polygon_test.go +++ b/pgtype/polygon_test.go @@ -27,6 +27,8 @@ func isExpectedEqPolygon(a interface{}) func(interface{}) bool { } func TestPolygonTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support type polygon") + testutil.RunTranscodeTests(t, "polygon", []testutil.TranscodeTestCase{ { pgtype.Polygon{ diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go index 36742f75..0bf781a4 100644 --- a/pgtype/qchar_test.go +++ b/pgtype/qchar_test.go @@ -8,6 +8,8 @@ import ( ) func TestQcharTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support qchar") + var tests []testutil.TranscodeTestCase for i := 0; i <= math.MaxUint8; i++ { tests = append(tests, testutil.TranscodeTestCase{rune(i), new(rune), isExpectedEq(rune(i))}) diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index 84a55a52..b4127769 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -10,6 +10,8 @@ import ( ) func TestRangeCodecTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + testutil.RunTranscodeTests(t, "int4range", []testutil.TranscodeTestCase{ { pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, @@ -37,6 +39,8 @@ func TestRangeCodecTranscode(t *testing.T) { } func TestRangeCodecTranscodeCompatibleRangeElementTypes(t *testing.T) { + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + testutil.RunTranscodeTests(t, "numrange", []testutil.TranscodeTestCase{ { pgtype.Float8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, @@ -64,6 +68,8 @@ func TestRangeCodecTranscodeCompatibleRangeElementTypes(t *testing.T) { } func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) @@ -116,6 +122,8 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { } func TestRangeCodecDecodeValue(t *testing.T) { + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) diff --git a/pgtype/text_test.go b/pgtype/text_test.go index 7a188a67..c80c404b 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -65,6 +65,8 @@ func TestTextCodecName(t *testing.T) { // Test fixed length char types like char(3) func TestTextCodecBPChar(t *testing.T) { + skipCockroachDB(t, "Server does not properly handle bpchar with multi-byte character") + testutil.RunTranscodeTests(t, "char(3)", []testutil.TranscodeTestCase{ { pgtype.Text{String: "a ", Valid: true}, @@ -92,6 +94,8 @@ func TestTextCodecBPChar(t *testing.T) { // // It only supports the text format. func TestTextCodecACLItem(t *testing.T) { + skipCockroachDB(t, "Server does not support type aclitem") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go index 4ff53151..08636aa8 100644 --- a/pgtype/tid_test.go +++ b/pgtype/tid_test.go @@ -8,6 +8,8 @@ import ( ) func TestTIDCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support type tid") + testutil.RunTranscodeTests(t, "tid", []testutil.TranscodeTestCase{ { pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index a33ce78f..764baff1 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -11,6 +11,8 @@ import ( ) func TestTimestampCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + testutil.RunTranscodeTests(t, "timestamp", []testutil.TranscodeTestCase{ {time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC))}, {time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC))}, diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index ec198fa1..678f3013 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -11,6 +11,8 @@ import ( ) func TestTimestamptzCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + testutil.RunTranscodeTests(t, "timestamptz", []testutil.TranscodeTestCase{ {time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local))}, {time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local))}, From 7b31b56de959f65023fe71abd4de54d6708bf02c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Mar 2022 20:33:16 -0500 Subject: [PATCH 0969/1158] Reactivate CI for other DB versions --- .github/workflows/ci.yml | 73 ++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 37 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a62d38e0..af164815 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,41 +15,40 @@ jobs: strategy: matrix: go-version: [1.17] - pg-version: [14] - # pg-version: [10, 11, 12, 13, 14, cockroachdb] + pg-version: [10, 11, 12, 13, 14, cockroachdb] include: - # - pg-version: 10 - # pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" - # pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - # pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - # - pg-version: 11 - # pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" - # pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - # pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - # - pg-version: 12 - # pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" - # pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - # pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - # - pg-version: 13 - # pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" - # pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - # pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - # pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 10 + pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 11 + pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 12 + pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 13 + pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 14 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test @@ -58,9 +57,9 @@ jobs: pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - # - pg-version: cockroachdb - # pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" - # pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" + - pg-version: cockroachdb + pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" + pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" steps: From 103dfe145ee2c97a7179e904d672ddd535ccae75 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Mar 2022 20:41:05 -0500 Subject: [PATCH 0970/1158] Test should always close rows --- pgtype/record_codec_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pgtype/record_codec_test.go b/pgtype/record_codec_test.go index 14018e9e..2bc2524e 100644 --- a/pgtype/record_codec_test.go +++ b/pgtype/record_codec_test.go @@ -58,6 +58,7 @@ func TestRecordCodecDecodeValue(t *testing.T) { t.Run(tt.sql, func(t *testing.T) { rows, err := conn.Query(context.Background(), tt.sql) require.NoError(t, err) + defer rows.Close() for rows.Next() { values, err := rows.Values() From 600c4fd93122572d09e77378c3f20c9c8b31a2a2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 22 Mar 2022 20:44:17 -0500 Subject: [PATCH 0971/1158] Skip test for Cockroach CI --- pgtype/record_codec_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/record_codec_test.go b/pgtype/record_codec_test.go index 2bc2524e..d6fe603c 100644 --- a/pgtype/record_codec_test.go +++ b/pgtype/record_codec_test.go @@ -23,6 +23,8 @@ func TestRecordCodec(t *testing.T) { } func TestRecordCodecDecodeValue(t *testing.T) { + skipCockroachDB(t, "Server converts row int4 to int8") + conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) From 3a6d9490e5150f2ebfb893de14de97f74bc3580f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Mar 2022 11:38:31 -0500 Subject: [PATCH 0972/1158] Only test numeric infinity on PG 14+ --- pgtype/numeric_test.go | 20 ++++++++++++++------ pgtype/pgtype_test.go | 21 +++++++++++++++++++++ 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index d5c60575..281efa0f 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -98,16 +98,10 @@ func TestNumericCodec(t *testing.T) { {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -92, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -92, Valid: true})}, {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -93, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -93, Valid: true})}, {pgtype.Numeric{NaN: true, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{NaN: true, Valid: true})}, - {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true})}, - {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, {longestNumeric, new(pgtype.Numeric), isExpectedEqNumeric(longestNumeric)}, {mustParseNumeric(t, "1"), new(int64), isExpectedEq(int64(1))}, {math.NaN(), new(float64), func(a interface{}) bool { return math.IsNaN(a.(float64)) }}, {float32(math.NaN()), new(float32), func(a interface{}) bool { return math.IsNaN(float64(a.(float32))) }}, - {math.Inf(1), new(float64), isExpectedEq(math.Inf(1))}, - {float32(math.Inf(1)), new(float32), isExpectedEq(float32(math.Inf(1)))}, - {math.Inf(-1), new(float64), isExpectedEq(math.Inf(-1))}, - {float32(math.Inf(-1)), new(float32), isExpectedEq(float32(math.Inf(-1)))}, {int64(-1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "-1"))}, {int64(0), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0"))}, {int64(1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, @@ -120,6 +114,20 @@ func TestNumericCodec(t *testing.T) { }) } +func TestNumericCodecInfinity(t *testing.T) { + skipCockroachDB(t, "server formats numeric text format differently") + skipPostgreSQLVersionLessThan(t, 14) + + testutil.RunTranscodeTests(t, "numeric", []testutil.TranscodeTestCase{ + {math.Inf(1), new(float64), isExpectedEq(math.Inf(1))}, + {float32(math.Inf(1)), new(float32), isExpectedEq(float32(math.Inf(1)))}, + {math.Inf(-1), new(float64), isExpectedEq(math.Inf(-1))}, + {float32(math.Inf(-1)), new(float32), isExpectedEq(float32(math.Inf(-1)))}, + {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true})}, + {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + }) +} + func TestNumericFloat64Valuer(t *testing.T) { for i, tt := range []struct { n pgtype.Numeric diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 2ee32907..dd0150fb 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -4,6 +4,8 @@ import ( "database/sql" "errors" "net" + "regexp" + "strconv" "testing" "github.com/jackc/pgx/v5" @@ -76,6 +78,25 @@ func skipCockroachDB(t testing.TB, msg string) { } } +func skipPostgreSQLVersionLessThan(t testing.TB, minVersion int64) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + serverVersionStr := conn.PgConn().ParameterStatus("server_version") + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) + // if not PostgreSQL do nothing + if serverVersionStr == "" { + return + } + + serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64) + require.NoError(t, err) + + if serverVersion < minVersion { + t.Skipf("Test requires PostgreSQL v%d+", minVersion) + } +} + func TestTypeMapScanNilIsNoOp(t *testing.T) { m := pgtype.NewMap() From 500c0721d7f2a6b6215129f02a39f100d4f99810 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 1 Apr 2022 17:54:56 -0500 Subject: [PATCH 0973/1158] Improve error messages for query argument encoding --- conn.go | 4 ++++ pgtype/pgtype.go | 11 +++++++++-- query_test.go | 2 +- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/conn.go b/conn.go index eb685645..8ccf5043 100644 --- a/conn.go +++ b/conn.go @@ -488,6 +488,7 @@ func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, args for i := range args { err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) if err != nil { + err = fmt.Errorf("failed to encode args[%d]: %v", i, err) return err } } @@ -739,6 +740,7 @@ optionLoop: for i := range args { err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) if err != nil { + err = fmt.Errorf("failed to encode args[%d]: %v", i, err) rows.fatal(err) return rows, rows.err } @@ -895,6 +897,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { for i := range bi.arguments { err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) if err != nil { + err = fmt.Errorf("failed to encode args[%d]: %v", i, err) return &batchResults{ctx: ctx, conn: c, err: err} } } @@ -962,6 +965,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { for i := range bi.arguments { err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) if err != nil { + err = fmt.Errorf("failed to encode args[%d]: %v", i, err) return &batchResults{ctx: ctx, conn: c, err: err} } } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5180e27a..ffc017f8 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1737,7 +1737,14 @@ func (m *Map) Encode(oid uint32, formatCode int16, value interface{}, buf []byte return m.Encode(oid, formatCode, v, buf) } - return nil, fmt.Errorf("unable to encode %#v into OID %d", value, oid) + return nil, fmt.Errorf("unable to encode %#v into format code %d for OID %d", value, formatCode, oid) } - return plan.Encode(value, buf) + + newBuf, err = plan.Encode(value, buf) + if err != nil { + err = fmt.Errorf("unable to encode %#v into format code %d for OID %d: %v", value, formatCode, oid, err) + return nil, err + } + + return newBuf, nil } diff --git a/query_test.go b/query_test.go index 1cff72bf..8e9d3ef9 100644 --- a/query_test.go +++ b/query_test.go @@ -983,7 +983,7 @@ func TestQueryRowErrors(t *testing.T) { {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot scan OID 25 in text format into *int16"}, - {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "unable to encode 705 into OID 600"}, + {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "unable to encode 705 into format code 1 for OID 600"}, } for i, tt := range tests { From e392908c7294ecb258f60c8cc833abd57b1362d6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Apr 2022 08:24:55 -0500 Subject: [PATCH 0974/1158] Remove Int64Valuer implementation from stringWrapper --- pgtype/builtin_wrappers.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index f182f570..30a88465 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -351,15 +351,6 @@ func (w *stringWrapper) ScanInt64(v Int8) error { return nil } -func (w stringWrapper) Int64Value() (Int8, error) { - num, err := strconv.ParseInt(string(w), 10, 64) - if err != nil { - return Int8{}, err - } - - return Int8{Int64: int64(num), Valid: true}, nil -} - type timeWrapper time.Time func (w *timeWrapper) ScanDate(v Date) error { From e18d76b7983e52bc1f5b0206d6fee7fbf7d5406f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Apr 2022 10:26:47 -0500 Subject: [PATCH 0975/1158] Initial extraction of pgxtest - Introduce ConnTestRunner - RunWithQueryExecModes --- batch_test.go | 23 ++++++------ conn_test.go | 23 ++++++------ helper_test.go | 40 +++++---------------- pgxtest/pgxtest.go | 88 ++++++++++++++++++++++++++++++++++++++++++++++ query_test.go | 7 ++-- values_test.go | 57 +++++++++++++++--------------- 6 files changed, 154 insertions(+), 84 deletions(-) create mode 100644 pgxtest/pgxtest.go diff --git a/batch_test.go b/batch_test.go index 3e5a2d46..ffd990fc 100644 --- a/batch_test.go +++ b/batch_test.go @@ -8,6 +8,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -15,7 +16,7 @@ import ( func TestConnSendBatch(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server serial type is incompatible with test") sql := `create temporary table ledger( @@ -149,7 +150,7 @@ func TestConnSendBatch(t *testing.T) { func TestConnSendBatchMany(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { sql := `create temporary table ledger( id serial primary key, description varchar not null, @@ -194,7 +195,7 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) { pgx.QueryExecModeExec, // Don't test simple mode with prepared statements. } - testWithQueryExecModes(t, modes, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") _, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n") if err != nil { @@ -300,7 +301,7 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing. func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { batch := &pgx.Batch{} batch.Queue("select n from generate_series(0,5) n") @@ -359,7 +360,7 @@ func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { func TestConnSendBatchQueryError(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { batch := &pgx.Batch{} batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0") @@ -397,7 +398,7 @@ func TestConnSendBatchQueryError(t *testing.T) { func TestConnSendBatchQuerySyntaxError(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { batch := &pgx.Batch{} batch.Queue("select 1 1") @@ -421,7 +422,7 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) { func TestConnSendBatchQueryRowInsert(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { sql := `create temporary table ledger( id serial primary key, @@ -458,7 +459,7 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) { func TestConnSendBatchQueryPartialReadInsert(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { sql := `create temporary table ledger( id serial primary key, @@ -495,7 +496,7 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) { func TestTxSendBatch(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { sql := `create temporary table ledger1( id serial primary key, @@ -562,7 +563,7 @@ func TestTxSendBatch(t *testing.T) { func TestTxSendBatchRollback(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { sql := `create temporary table ledger1( id serial primary key, @@ -597,7 +598,7 @@ func TestTxSendBatchRollback(t *testing.T) { func TestConnBeginBatchDeferredError(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") diff --git a/conn_test.go b/conn_test.go index 625d9693..31190d7c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -11,6 +11,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -198,7 +199,7 @@ func TestParseConfigExtractsDefaultQueryExecMode(t *testing.T) { func TestExec(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } @@ -232,7 +233,7 @@ func TestExec(t *testing.T) { func TestExecFailure(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { if _, err := conn.Exec(context.Background(), "selct;"); err == nil { t.Fatal("Expected SQL syntax error") } @@ -248,7 +249,7 @@ func TestExecFailure(t *testing.T) { func TestExecFailureWithArguments(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { _, err := conn.Exec(context.Background(), "selct $1;", 1) if err == nil { t.Fatal("Expected SQL syntax error") @@ -263,7 +264,7 @@ func TestExecFailureWithArguments(t *testing.T) { func TestExecContextWithoutCancelation(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -281,7 +282,7 @@ func TestExecContextWithoutCancelation(t *testing.T) { func TestExecContextFailureWithoutCancelation(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -303,7 +304,7 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) { func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -680,7 +681,7 @@ func TestFatalTxError(t *testing.T) { func TestInsertBoolArray(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } @@ -695,7 +696,7 @@ func TestInsertBoolArray(t *testing.T) { func TestInsertTimestampArray(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results.String() != "CREATE TABLE" { t.Error("Unexpected results from Exec") } @@ -819,7 +820,7 @@ func TestConnInitTypeMap(t *testing.T) { } func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) { - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") var n uint64 @@ -835,7 +836,7 @@ func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) { } func TestDomainType(t *testing.T) { - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") // Domain type uint64 is a PostgreSQL domain of underlying type numeric. @@ -1006,7 +1007,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) { } func TestInsertDurationInterval(t *testing.T) { - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { _, err := conn.Exec(context.Background(), "create temporary table t(duration INTERVAL(0) NOT NULL)") require.NoError(t, err) diff --git a/helper_test.go b/helper_test.go index 26509946..e0f04906 100644 --- a/helper_test.go +++ b/helper_test.go @@ -9,40 +9,18 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) -func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) { - modes := []pgx.QueryExecMode{ - pgx.QueryExecModeCacheStatement, - pgx.QueryExecModeCacheDescribe, - pgx.QueryExecModeDescribeExec, - pgx.QueryExecModeExec, - pgx.QueryExecModeSimpleProtocol, - } - testWithQueryExecModes(t, modes, f) -} +var defaultConnTestRunner pgxtest.ConnTestRunner -func testWithQueryExecModes(t *testing.T, modes []pgx.QueryExecMode, f func(t *testing.T, conn *pgx.Conn)) { - for _, mode := range modes { - t.Run(mode.String(), - func(t *testing.T) { - config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.DefaultQueryExecMode = mode - conn, err := pgx.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer func() { - err := conn.Close(context.Background()) - require.NoError(t, err) - }() - - f(t, conn) - - ensureConnValid(t, conn) - }, - ) +func init() { + defaultConnTestRunner = pgxtest.DefaultConnTestRunner() + defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return config } } @@ -84,7 +62,7 @@ func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{} } // Do a simple query to ensure the connection is still usable -func ensureConnValid(t *testing.T, conn *pgx.Conn) { +func ensureConnValid(t testing.TB, conn *pgx.Conn) { var sum, rowCount int32 rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) diff --git a/pgxtest/pgxtest.go b/pgxtest/pgxtest.go new file mode 100644 index 00000000..579bcd92 --- /dev/null +++ b/pgxtest/pgxtest.go @@ -0,0 +1,88 @@ +// Package pgxtest provides utilities for testing pgx and packages that integrate with pgx. +package pgxtest + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" +) + +// ConnTestRunner controls how a *pgx.Conn is created and closed by tests. All fields are required. Use DefaultConnTestRunner to get a +// ConnTestRunner with reasonable default values. +type ConnTestRunner struct { + // CreateConfig returns a *pgx.ConnConfig suitable for use with pgx.ConnectConfig. + CreateConfig func(ctx context.Context, t testing.TB) *pgx.ConnConfig + + // AfterConnect is called after conn is established. It allows for arbitrary connection setup before a test begins. + AfterConnect func(ctx context.Context, t testing.TB, conn *pgx.Conn) + + // AfterTest is called after the test is run. It allows for validating the state of the connection before it is closed. + AfterTest func(ctx context.Context, t testing.TB, conn *pgx.Conn) + + // CloseConn closes conn. + CloseConn func(ctx context.Context, t testing.TB, conn *pgx.Conn) +} + +// DefaultConnTestRunner returns a new ConnTestRunner with all fields set to reasonable default values. +func DefaultConnTestRunner() ConnTestRunner { + return ConnTestRunner{ + CreateConfig: func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig("") + if err != nil { + t.Fatalf("ParseConfig failed: %v", err) + } + return config + }, + AfterConnect: func(ctx context.Context, t testing.TB, conn *pgx.Conn) {}, + AfterTest: func(ctx context.Context, t testing.TB, conn *pgx.Conn) {}, + CloseConn: func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + err := conn.Close(ctx) + if err != nil { + t.Errorf("Close failed: %v", err) + } + }, + } +} + +func (ctr *ConnTestRunner) RunTest(ctx context.Context, t testing.TB, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) { + config := ctr.CreateConfig(ctx, t) + conn, err := pgx.ConnectConfig(ctx, config) + if err != nil { + t.Fatalf("ConnectConfig failed: %v", err) + } + defer ctr.CloseConn(ctx, t, conn) + + ctr.AfterConnect(ctx, t, conn) + f(ctx, t, conn) + ctr.AfterTest(ctx, t, conn) +} + +// RunWithQueryExecModes runs a f in a new test for each element of modes with a new connection created using connector. +// If modes is nil all pgx.QueryExecModes are tested. +func RunWithQueryExecModes(ctx context.Context, t *testing.T, ctr ConnTestRunner, modes []pgx.QueryExecMode, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) { + if modes == nil { + modes = []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + pgx.QueryExecModeExec, + pgx.QueryExecModeSimpleProtocol, + } + } + + for _, mode := range modes { + ctrWithMode := ctr + ctrWithMode.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := ctr.CreateConfig(ctx, t) + config.DefaultQueryExecMode = mode + return config + } + + t.Run(mode.String(), + func(t *testing.T) { + ctrWithMode.RunTest(ctx, t, f) + }, + ) + } +} diff --git a/query_test.go b/query_test.go index 8e9d3ef9..e8790bba 100644 --- a/query_test.go +++ b/query_test.go @@ -15,6 +15,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1866,7 +1867,7 @@ func TestQueryErrorWithDisabledStatementCache(t *testing.T) { func TestConnQueryFunc(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { var actualResults []interface{} var a, b int @@ -1895,7 +1896,7 @@ func TestConnQueryFunc(t *testing.T) { func TestConnQueryFuncScanError(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { var actualResults []interface{} var a, b int @@ -1917,7 +1918,7 @@ func TestConnQueryFuncScanError(t *testing.T) { func TestConnQueryFuncAbort(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { var a, b int ct, err := conn.QueryFunc( context.Background(), diff --git a/values_test.go b/values_test.go index 15c16746..55c577f5 100644 --- a/values_test.go +++ b/values_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,7 +19,7 @@ import ( func TestDateTranscode(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { dates := []time.Time{ time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC), @@ -57,7 +58,7 @@ func TestDateTranscode(t *testing.T) { func TestTimestampTzTranscode(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local) var outputTime time.Time @@ -77,7 +78,7 @@ func TestTimestampTzTranscode(t *testing.T) { func TestJSONAndJSONBTranscode(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { for _, typename := range []string{"json", "jsonb"} { if _, ok := conn.TypeMap().TypeForName(typename); !ok { continue // No JSON/JSONB type -- must be running against old PostgreSQL @@ -109,7 +110,7 @@ func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) { } -func testJSONString(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONString(t testing.TB, conn *pgx.Conn, typename string) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string @@ -125,7 +126,7 @@ func testJSONString(t *testing.T, conn *pgx.Conn, typename string) { } } -func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONStringPointer(t testing.TB, conn *pgx.Conn, typename string) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string @@ -233,7 +234,7 @@ func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) { } } -func mustParseCIDR(t *testing.T, s string) *net.IPNet { +func mustParseCIDR(t testing.TB, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { t.Fatal(err) @@ -245,7 +246,7 @@ func mustParseCIDR(t *testing.T, s string) *net.IPNet { func TestInetCIDRTranscodeIPNet(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { tests := []struct { sql string value *net.IPNet @@ -296,7 +297,7 @@ func TestInetCIDRTranscodeIPNet(t *testing.T) { func TestInetCIDRTranscodeIP(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { tests := []struct { sql string value net.IP @@ -360,7 +361,7 @@ func TestInetCIDRTranscodeIP(t *testing.T) { func TestInetCIDRArrayTranscodeIPNet(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { tests := []struct { sql string value []*net.IPNet @@ -423,7 +424,7 @@ func TestInetCIDRArrayTranscodeIPNet(t *testing.T) { func TestInetCIDRArrayTranscodeIP(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { tests := []struct { sql string value []net.IP @@ -509,7 +510,7 @@ func TestInetCIDRArrayTranscodeIP(t *testing.T) { func TestInetCIDRTranscodeWithJustIP(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { tests := []struct { sql string value string @@ -555,16 +556,16 @@ func TestInetCIDRTranscodeWithJustIP(t *testing.T) { func TestArrayDecoding(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { tests := []struct { sql string query interface{} scan interface{} - assert func(*testing.T, interface{}, interface{}) + assert func(testing.TB, interface{}, interface{}) }{ { "select $1::bool[]", []bool{true, false, true}, &[]bool{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]bool))) { t.Errorf("failed to encode bool[]") } @@ -572,7 +573,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]int16))) { t.Errorf("failed to encode smallint[]") } @@ -580,7 +581,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { t.Errorf("failed to encode smallint[]") } @@ -588,7 +589,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]int32))) { t.Errorf("failed to encode int[]") } @@ -596,7 +597,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { t.Errorf("failed to encode int[]") } @@ -604,7 +605,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]int64))) { t.Errorf("failed to encode bigint[]") } @@ -612,7 +613,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { t.Errorf("failed to encode bigint[]") } @@ -620,7 +621,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]string))) { t.Errorf("failed to encode text[]") } @@ -628,7 +629,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan interface{}) { queryTimeSlice := query.([]time.Time) scanTimeSlice := *(scan.(*[]time.Time)) require.Equal(t, len(queryTimeSlice), len(scanTimeSlice)) @@ -639,7 +640,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{}, - func(t *testing.T, query, scan interface{}) { + func(t testing.TB, query, scan interface{}) { queryBytesSliceSlice := query.([][]byte) scanBytesSliceSlice := *(scan.(*[][]byte)) if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) { @@ -671,7 +672,7 @@ func TestArrayDecoding(t *testing.T) { func TestEmptyArrayDecoding(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { var val []string err := conn.QueryRow(context.Background(), "select array[]::text[]").Scan(&val) @@ -716,7 +717,7 @@ func TestEmptyArrayDecoding(t *testing.T) { func TestPointerPointer(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server auto converts ints to bigint and test relies on exact types") type allTypes struct { @@ -802,7 +803,7 @@ func TestPointerPointer(t *testing.T) { func TestPointerPointerNonZero(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { f := "foo" dest := &f @@ -819,7 +820,7 @@ func TestPointerPointerNonZero(t *testing.T) { func TestEncodeTypeRename(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { type _int int inInt := _int(1) var outInt _int @@ -979,7 +980,7 @@ func TestEncodeTypeRename(t *testing.T) { func TestRowsScanNilThenScanValue(t *testing.T) { t.Parallel() - testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { sql := `select null as a, null as b union select 1, 2 From 83e50f21e8f637e61d91598748f7a8d2d7adac33 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Apr 2022 10:35:13 -0500 Subject: [PATCH 0976/1158] Extract SkipCockroachDB to pgxtest --- batch_test.go | 8 ++++---- conn_test.go | 12 ++++++------ copy_from_test.go | 5 +++-- helper_test.go | 6 ------ large_objects_test.go | 7 ++++--- pgxtest/pgxtest.go | 7 +++++++ query_test.go | 12 ++++++------ tx_test.go | 5 +++-- values_test.go | 2 +- 9 files changed, 34 insertions(+), 30 deletions(-) diff --git a/batch_test.go b/batch_test.go index ffd990fc..6a7abd37 100644 --- a/batch_test.go +++ b/batch_test.go @@ -17,7 +17,7 @@ func TestConnSendBatch(t *testing.T) { t.Parallel() pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - skipCockroachDB(t, conn, "Server serial type is incompatible with test") + pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test") sql := `create temporary table ledger( id serial primary key, @@ -196,7 +196,7 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) { // Don't test simple mode with prepared statements. } pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") + pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") _, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n") if err != nil { t.Fatal(err) @@ -253,7 +253,7 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing. conn := mustConnect(t, config) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") + pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") _, err = conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n") if err != nil { @@ -600,7 +600,7 @@ func TestConnBeginBatchDeferredError(t *testing.T) { pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") mustExec(t, conn, `create temporary table t ( id text primary key, diff --git a/conn_test.go b/conn_test.go index 31190d7c..a618b936 100644 --- a/conn_test.go +++ b/conn_test.go @@ -506,7 +506,7 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) { func() { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") }() listenerDone := make(chan bool) @@ -582,7 +582,7 @@ func TestListenNotifySelfNotification(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") mustExec(t, conn, "listen self") @@ -617,7 +617,7 @@ func TestFatalRxError(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") var wg sync.WaitGroup wg.Add(1) @@ -656,7 +656,7 @@ func TestFatalTxError(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer otherConn.Close(context.Background()) @@ -821,7 +821,7 @@ func TestConnInitTypeMap(t *testing.T) { func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) { pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") + pgxtest.SkipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") var n uint64 err := conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n) @@ -837,7 +837,7 @@ func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) { func TestDomainType(t *testing.T) { pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") + pgxtest.SkipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") // Domain type uint64 is a PostgreSQL domain of underlying type numeric. diff --git a/copy_from_test.go b/copy_from_test.go index 1182cb1e..54e2e52b 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) @@ -134,7 +135,7 @@ func TestConnCopyFromLarge(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)") + pgxtest.SkipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)") mustExec(t, conn, `create temporary table foo( a int2, @@ -416,7 +417,7 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server copy error does not fail fast") + pgxtest.SkipCockroachDB(t, conn, "Server copy error does not fail fast") mustExec(t, conn, `create temporary table foo( a bytea not null diff --git a/helper_test.go b/helper_test.go index e0f04906..461dfcab 100644 --- a/helper_test.go +++ b/helper_test.go @@ -137,9 +137,3 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName } } } - -func skipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) { - if conn.PgConn().ParameterStatus("crdb_version") != "" { - t.Skip(msg) - } -} diff --git a/large_objects_test.go b/large_objects_test.go index f86f35e9..626809e7 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -9,6 +9,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" ) func TestLargeObjects(t *testing.T) { @@ -22,7 +23,7 @@ func TestLargeObjects(t *testing.T) { t.Fatal(err) } - skipCockroachDB(t, conn, "Server does support large objects") + pgxtest.SkipCockroachDB(t, conn, "Server does support large objects") tx, err := conn.Begin(ctx) if err != nil { @@ -50,7 +51,7 @@ func TestLargeObjectsSimpleProtocol(t *testing.T) { t.Fatal(err) } - skipCockroachDB(t, conn, "Server does support large objects") + pgxtest.SkipCockroachDB(t, conn, "Server does support large objects") tx, err := conn.Begin(ctx) if err != nil { @@ -169,7 +170,7 @@ func TestLargeObjectsMultipleTransactions(t *testing.T) { t.Fatal(err) } - skipCockroachDB(t, conn, "Server does support large objects") + pgxtest.SkipCockroachDB(t, conn, "Server does support large objects") tx, err := conn.Begin(ctx) if err != nil { diff --git a/pgxtest/pgxtest.go b/pgxtest/pgxtest.go index 579bcd92..dec6a520 100644 --- a/pgxtest/pgxtest.go +++ b/pgxtest/pgxtest.go @@ -86,3 +86,10 @@ func RunWithQueryExecModes(ctx context.Context, t *testing.T, ctr ConnTestRunner ) } } + +// SkipCockroachDB calls Skip on t with msg if the connection is to a CockroachDB server. +func SkipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) { + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip(msg) + } +} diff --git a/query_test.go b/query_test.go index e8790bba..007d5256 100644 --- a/query_test.go +++ b/query_test.go @@ -269,7 +269,7 @@ func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support point type") + pgxtest.SkipCockroachDB(t, conn, "Server does not support point type") var s string @@ -477,7 +477,7 @@ func TestConnQueryDeferredError(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") mustExec(t, conn, `create temporary table t ( id text primary key, @@ -519,7 +519,7 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server uses numeric instead of int") + pgxtest.SkipCockroachDB(t, conn, "Server uses numeric instead of int") for i := 0; i < 100; i++ { func() { @@ -1267,7 +1267,7 @@ func TestQueryContextErrorWhileReceivingRows(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server uses numeric instead of int") + pgxtest.SkipCockroachDB(t, conn, "Server uses numeric instead of int") ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -1789,7 +1789,7 @@ func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support changing client_encoding (https://www.cockroachlabs.com/docs/stable/set-vars.html)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support changing client_encoding (https://www.cockroachlabs.com/docs/stable/set-vars.html)") mustExec(t, conn, "set client_encoding to 'SQL_ASCII'") @@ -1813,7 +1813,7 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support standard_conforming_strings = off (https://github.com/cockroachdb/cockroach/issues/36215)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support standard_conforming_strings = off (https://github.com/cockroachdb/cockroach/issues/36215)") mustExec(t, conn, "set standard_conforming_strings to off") diff --git a/tx_test.go b/tx_test.go index 23d76663..d45553a2 100644 --- a/tx_test.go +++ b/tx_test.go @@ -9,6 +9,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) @@ -106,7 +107,7 @@ func TestTxCommitWhenDeferredConstraintFailure(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") createSql := ` create temporary table foo( @@ -273,7 +274,7 @@ func TestBeginIsoLevels(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - skipCockroachDB(t, conn, "Server always uses SERIALIZABLE isolation (https://www.cockroachlabs.com/docs/stable/demo-serializable.html)") + pgxtest.SkipCockroachDB(t, conn, "Server always uses SERIALIZABLE isolation (https://www.cockroachlabs.com/docs/stable/demo-serializable.html)") isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} for _, iso := range isoLevels { diff --git a/values_test.go b/values_test.go index 55c577f5..4282880c 100644 --- a/values_test.go +++ b/values_test.go @@ -718,7 +718,7 @@ func TestPointerPointer(t *testing.T) { t.Parallel() pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - skipCockroachDB(t, conn, "Server auto converts ints to bigint and test relies on exact types") + pgxtest.SkipCockroachDB(t, conn, "Server auto converts ints to bigint and test relies on exact types") type allTypes struct { s *string From ee93440ac15326a4dbfedd6b56de685726392de6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Apr 2022 14:34:19 -0500 Subject: [PATCH 0977/1158] pgtype uses pgxtest Added ValueRoundTripTest to pgxtest Removed pgtype/testutil pgtype tests now run with all (applicable) query modes. This gives better coverage than before and revealed several bugs which are also fixed in this commit. --- conn.go | 35 +- pgtype/array_codec_test.go | 269 ++- pgtype/bits_test.go | 7 +- pgtype/bool_test.go | 5 +- pgtype/box_test.go | 5 +- pgtype/bytea_test.go | 127 +- pgtype/circle_test.go | 5 +- pgtype/composite_test.go | 201 +- pgtype/date_test.go | 5 +- pgtype/enum_codec_test.go | 78 +- pgtype/float4_test.go | 5 +- pgtype/float8_test.go | 5 +- pgtype/hstore_test.go | 48 +- pgtype/inet_test.go | 7 +- pgtype/int_test.go | 9 +- pgtype/int_test.go.erb | 2 +- pgtype/integration_benchmark_test.go | 2171 +++++++++++----------- pgtype/integration_benchmark_test.go.erb | 66 +- pgtype/interval_test.go | 5 +- pgtype/json_test.go | 18 +- pgtype/jsonb_test.go | 18 +- pgtype/line_test.go | 36 +- pgtype/lseg_test.go | 5 +- pgtype/macaddr_test.go | 6 +- pgtype/numeric_test.go | 67 +- pgtype/path_test.go | 5 +- pgtype/pgtype.go | 2 +- pgtype/pgtype_test.go | 29 +- pgtype/point_test.go | 5 +- pgtype/polygon_test.go | 5 +- pgtype/qchar_test.go | 16 +- pgtype/range_codec_test.go | 152 +- pgtype/record_codec_test.go | 108 +- pgtype/testutil/testutil.go | 85 - pgtype/text_test.go | 51 +- pgtype/tid_test.go | 5 +- pgtype/time_test.go | 5 +- pgtype/timestamp.go | 28 + pgtype/timestamp_test.go | 31 +- pgtype/timestamptz.go | 33 +- pgtype/timestamptz_test.go | 31 +- pgtype/uint32_test.go | 5 +- pgtype/uuid_test.go | 28 +- pgtype/zeronull/float8_test.go | 5 +- pgtype/zeronull/int_test.go | 9 +- pgtype/zeronull/int_test.go.erb | 2 +- pgtype/zeronull/text_test.go | 5 +- pgtype/zeronull/timestamp_test.go | 5 +- pgtype/zeronull/timestamptz_test.go | 5 +- pgtype/zeronull/uuid_test.go | 5 +- pgtype/zeronull/zeronull.go | 17 + pgtype/zeronull/zeronull_test.go | 26 + pgxtest/pgxtest.go | 72 +- values.go | 2 +- 54 files changed, 2028 insertions(+), 1954 deletions(-) delete mode 100644 pgtype/testutil/testutil.go create mode 100644 pgtype/zeronull/zeronull.go create mode 100644 pgtype/zeronull/zeronull_test.go diff --git a/conn.go b/conn.go index 8ccf5043..e6396cd3 100644 --- a/conn.go +++ b/conn.go @@ -558,18 +558,41 @@ func (c *Conn) execSQLParams(ctx context.Context, sql string, args []interface{} // Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is // no way to safely use binary or to specify the parameter OIDs. func (c *Conn) appendParamsForQueryExecModeExec(args []interface{}) error { - for i := range args { - if args[i] == nil { - err := c.eqb.AppendParamFormat(c.typeMap, 0, TextFormatCode, args[i]) + for _, arg := range args { + if arg == nil { + err := c.eqb.AppendParamFormat(c.typeMap, 0, TextFormatCode, arg) if err != nil { return err } } else { - dt, ok := c.TypeMap().TypeForValue(args[i]) + dt, ok := c.TypeMap().TypeForValue(arg) if !ok { - return &unknownArgumentTypeQueryExecModeExecError{arg: args[i]} + var tv pgtype.TextValuer + if tv, ok = arg.(pgtype.TextValuer); ok { + t, err := tv.TextValue() + if err != nil { + return err + } + + dt, ok = c.TypeMap().TypeForOID(pgtype.TextOID) + if ok { + arg = t + } + } } - err := c.eqb.AppendParamFormat(c.typeMap, dt.OID, TextFormatCode, args[i]) + if !ok { + var str fmt.Stringer + if str, ok = arg.(fmt.Stringer); ok { + dt, ok = c.TypeMap().TypeForOID(pgtype.TextOID) + if ok { + arg = str.String() + } + } + } + if !ok { + return &unknownArgumentTypeQueryExecModeExecError{arg: arg} + } + err := c.eqb.AppendParamFormat(c.typeMap, dt.OID, TextFormatCode, arg) if err != nil { return err } diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index 0cc6d7cb..55ab814e 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -4,191 +4,184 @@ import ( "context" "testing" - "github.com/jackc/pgx/v5/pgtype/testutil" + pgx "github.com/jackc/pgx/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestArrayCodec(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for i, tt := range []struct { + expected interface{} + }{ + {[]int16(nil)}, + {[]int16{}}, + {[]int16{1, 2, 3}}, + } { + var actual []int16 + err := conn.QueryRow( + ctx, + "select $1::smallint[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } - for i, tt := range []struct { - expected interface{} - }{ - {[]int16(nil)}, - {[]int16{}}, - {[]int16{1, 2, 3}}, - } { - var actual []int16 - err := conn.QueryRow( - context.Background(), - "select $1::smallint[]", - tt.expected, - ).Scan(&actual) - assert.NoErrorf(t, err, "%d", i) - assert.Equalf(t, tt.expected, actual, "%d", i) - } + newInt16 := func(n int16) *int16 { return &n } - newInt16 := func(n int16) *int16 { return &n } - - for i, tt := range []struct { - expected interface{} - }{ - {[]*int16{newInt16(1), nil, newInt16(3), nil, newInt16(5)}}, - } { - var actual []*int16 - err := conn.QueryRow( - context.Background(), - "select $1::smallint[]", - tt.expected, - ).Scan(&actual) - assert.NoErrorf(t, err, "%d", i) - assert.Equalf(t, tt.expected, actual, "%d", i) - } + for i, tt := range []struct { + expected interface{} + }{ + {[]*int16{newInt16(1), nil, newInt16(3), nil, newInt16(5)}}, + } { + var actual []*int16 + err := conn.QueryRow( + ctx, + "select $1::smallint[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) } func TestArrayCodecAnySlice(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + type _int16Slice []int16 - type _int16Slice []int16 - - for i, tt := range []struct { - expected interface{} - }{ - {_int16Slice(nil)}, - {_int16Slice{}}, - {_int16Slice{1, 2, 3}}, - } { - var actual _int16Slice - err := conn.QueryRow( - context.Background(), - "select $1::smallint[]", - tt.expected, - ).Scan(&actual) - assert.NoErrorf(t, err, "%d", i) - assert.Equalf(t, tt.expected, actual, "%d", i) - } + for i, tt := range []struct { + expected interface{} + }{ + {_int16Slice(nil)}, + {_int16Slice{}}, + {_int16Slice{1, 2, 3}}, + } { + var actual _int16Slice + err := conn.QueryRow( + ctx, + "select $1::smallint[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) } func TestArrayCodecDecodeValue(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - for _, tt := range []struct { - sql string - expected interface{} - }{ - { - sql: `select '{}'::int4[]`, - expected: []interface{}{}, - }, - { - sql: `select '{1,2}'::int8[]`, - expected: []interface{}{int64(1), int64(2)}, - }, - { - sql: `select '{foo,bar}'::text[]`, - expected: []interface{}{"foo", "bar"}, - }, - } { - t.Run(tt.sql, func(t *testing.T) { - rows, err := conn.Query(context.Background(), tt.sql) - require.NoError(t, err) - - for rows.Next() { - values, err := rows.Values() + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + for _, tt := range []struct { + sql string + expected interface{} + }{ + { + sql: `select '{}'::int4[]`, + expected: []interface{}{}, + }, + { + sql: `select '{1,2}'::int8[]`, + expected: []interface{}{int64(1), int64(2)}, + }, + { + sql: `select '{foo,bar}'::text[]`, + expected: []interface{}{"foo", "bar"}, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(ctx, tt.sql) require.NoError(t, err) - require.Len(t, values, 1) - require.Equal(t, tt.expected, values[0]) - } - require.NoError(t, rows.Err()) - }) - } + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) } func TestArrayCodecScanMultipleDimensions(t *testing.T) { skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - rows, err := conn.Query(context.Background(), `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) - require.NoError(t, err) - - for rows.Next() { - var ss [][]int32 - err := rows.Scan(&ss) + rows, err := conn.Query(ctx, `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) require.NoError(t, err) - require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss) - } - require.NoError(t, rows.Err()) + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss) + } + + require.NoError(t, rows.Err()) + }) } func TestArrayCodecScanMultipleDimensionsEmpty(t *testing.T) { skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - rows, err := conn.Query(context.Background(), `select '{}'::int4[]`) - require.NoError(t, err) - - for rows.Next() { - var ss [][]int32 - err := rows.Scan(&ss) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, err := conn.Query(ctx, `select '{}'::int4[]`) require.NoError(t, err) - require.Equal(t, [][]int32{}, ss) - } - require.NoError(t, rows.Err()) + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{}, ss) + } + + require.NoError(t, rows.Err()) + }) } func TestArrayCodecScanWrongMultipleDimensions(t *testing.T) { skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, err := conn.Query(ctx, `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) + require.NoError(t, err) - rows, err := conn.Query(context.Background(), `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) - require.NoError(t, err) - - for rows.Next() { - var ss [][][]int32 - err := rows.Scan(&ss) - require.Error(t, err, "can't scan into dest[0]: PostgreSQL array has 2 dimensions but slice has 3 dimensions") - } + for rows.Next() { + var ss [][][]int32 + err := rows.Scan(&ss) + require.Error(t, err, "can't scan into dest[0]: PostgreSQL array has 2 dimensions but slice has 3 dimensions") + } + }) } func TestArrayCodecEncodeMultipleDimensions(t *testing.T) { skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - rows, err := conn.Query(context.Background(), `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}) - require.NoError(t, err) - - for rows.Next() { - var ss [][]int32 - err := rows.Scan(&ss) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, err := conn.Query(ctx, `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}) require.NoError(t, err) - require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss) - } - require.NoError(t, rows.Err()) + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss) + } + + require.NoError(t, rows.Err()) + }) } func TestArrayCodecEncodeMultipleDimensionsRagged(t *testing.T) { skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - rows, err := conn.Query(context.Background(), `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5}, {9, 10, 11, 12}}) - require.Error(t, err, "cannot convert [][]int32 to ArrayGetter because it is a ragged multi-dimensional") - defer rows.Close() + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, err := conn.Query(ctx, `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5}, {9, 10, 11, 12}}) + require.Error(t, err, "cannot convert [][]int32 to ArrayGetter because it is a ragged multi-dimensional") + defer rows.Close() + }) } diff --git a/pgtype/bits_test.go b/pgtype/bits_test.go index 5a82743c..3ca3b0c0 100644 --- a/pgtype/bits_test.go +++ b/pgtype/bits_test.go @@ -2,10 +2,11 @@ package pgtype_test import ( "bytes" + "context" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func isExpectedEqBits(a interface{}) func(interface{}) bool { @@ -17,7 +18,7 @@ func isExpectedEqBits(a interface{}) func(interface{}) bool { } func TestBitsCodecBit(t *testing.T) { - testutil.RunTranscodeTests(t, "bit(40)", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bit(40)", []pgxtest.ValueRoundTripTest{ { pgtype.Bits{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}, new(pgtype.Bits), @@ -34,7 +35,7 @@ func TestBitsCodecBit(t *testing.T) { } func TestBitsCodecVarbit(t *testing.T) { - testutil.RunTranscodeTests(t, "varbit", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "varbit", []pgxtest.ValueRoundTripTest{ { pgtype.Bits{Bytes: []byte{}, Len: 0, Valid: true}, new(pgtype.Bits), diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index 57094144..7480471b 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -1,14 +1,15 @@ package pgtype_test import ( + "context" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestBoolCodec(t *testing.T) { - testutil.RunTranscodeTests(t, "bool", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bool", []pgxtest.ValueRoundTripTest{ {true, new(bool), isExpectedEq(true)}, {false, new(bool), isExpectedEq(false)}, {true, new(pgtype.Bool), isExpectedEq(pgtype.Bool{Bool: true, Valid: true})}, diff --git a/pgtype/box_test.go b/pgtype/box_test.go index 173fb1f5..3b54c1f8 100644 --- a/pgtype/box_test.go +++ b/pgtype/box_test.go @@ -1,16 +1,17 @@ package pgtype_test import ( + "context" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestBoxCodec(t *testing.T) { skipCockroachDB(t, "Server does not support box type") - testutil.RunTranscodeTests(t, "box", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "box", []pgxtest.ValueRoundTripTest{ { pgtype.Box{ P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index 443b73ce..7a348ebb 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -5,8 +5,9 @@ import ( "context" "testing" + pgx "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) @@ -28,7 +29,7 @@ func isExpectedEqBytes(a interface{}) func(interface{}) bool { } func TestByteaCodec(t *testing.T) { - testutil.RunTranscodeTests(t, "bytea", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bytea", []pgxtest.ValueRoundTripTest{ {[]byte{1, 2, 3}, new([]byte), isExpectedEqBytes([]byte{1, 2, 3})}, {[]byte{}, new([]byte), isExpectedEqBytes([]byte{})}, {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, @@ -37,91 +38,79 @@ func TestByteaCodec(t *testing.T) { } func TestDriverBytesQueryRow(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - ctx := context.Background() - - var buf []byte - err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf)) - require.EqualError(t, err, "cannot scan into *pgtype.DriverBytes from QueryRow") + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var buf []byte + err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf)) + require.EqualError(t, err, "cannot scan into *pgtype.DriverBytes from QueryRow") + }) } func TestDriverBytes(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - ctx := context.Background() - - argBuf := make([]byte, 128) - for i := range argBuf { - argBuf[i] = byte(i) - } - - rows, err := conn.Query(ctx, `select $1::bytea from generate_series(1, 1000)`, argBuf) - require.NoError(t, err) - defer rows.Close() - - rowCount := 0 - resultBuf := argBuf - detectedResultMutation := false - for rows.Next() { - rowCount++ - - // At some point the buffer should be reused and change. - if bytes.Compare(argBuf, resultBuf) != 0 { - detectedResultMutation = true + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + argBuf := make([]byte, 128) + for i := range argBuf { + argBuf[i] = byte(i) } - err = rows.Scan((*pgtype.DriverBytes)(&resultBuf)) + rows, err := conn.Query(ctx, `select $1::bytea from generate_series(1, 1000)`, argBuf) require.NoError(t, err) + defer rows.Close() - require.Len(t, resultBuf, len(argBuf)) - require.Equal(t, resultBuf, argBuf) - require.Equalf(t, cap(resultBuf), len(resultBuf), "cap(resultBuf) is larger than len(resultBuf)") - } + rowCount := 0 + resultBuf := argBuf + detectedResultMutation := false + for rows.Next() { + rowCount++ - require.True(t, detectedResultMutation) + // At some point the buffer should be reused and change. + if bytes.Compare(argBuf, resultBuf) != 0 { + detectedResultMutation = true + } - err = rows.Err() - require.NoError(t, err) + err = rows.Scan((*pgtype.DriverBytes)(&resultBuf)) + require.NoError(t, err) + + require.Len(t, resultBuf, len(argBuf)) + require.Equal(t, resultBuf, argBuf) + require.Equalf(t, cap(resultBuf), len(resultBuf), "cap(resultBuf) is larger than len(resultBuf)") + } + + require.True(t, detectedResultMutation) + + err = rows.Err() + require.NoError(t, err) + }) } func TestPreallocBytes(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + origBuf := []byte{5, 6, 7, 8} + buf := origBuf + err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.PreallocBytes)(&buf)) + require.NoError(t, err) - ctx := context.Background() + require.Len(t, buf, 2) + require.Equal(t, 4, cap(buf)) + require.Equal(t, buf, []byte{1, 2}) - origBuf := []byte{5, 6, 7, 8} - buf := origBuf - err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.PreallocBytes)(&buf)) - require.NoError(t, err) + require.Equal(t, []byte{1, 2, 7, 8}, origBuf) - require.Len(t, buf, 2) - require.Equal(t, 4, cap(buf)) - require.Equal(t, buf, []byte{1, 2}) + err = conn.QueryRow(ctx, `select $1::bytea`, []byte{3, 4, 5, 6, 7}).Scan((*pgtype.PreallocBytes)(&buf)) + require.NoError(t, err) + require.Len(t, buf, 5) + require.Equal(t, 5, cap(buf)) - require.Equal(t, []byte{1, 2, 7, 8}, origBuf) - - err = conn.QueryRow(ctx, `select $1::bytea`, []byte{3, 4, 5, 6, 7}).Scan((*pgtype.PreallocBytes)(&buf)) - require.NoError(t, err) - require.Len(t, buf, 5) - require.Equal(t, 5, cap(buf)) - - require.Equal(t, []byte{1, 2, 7, 8}, origBuf) + require.Equal(t, []byte{1, 2, 7, 8}, origBuf) + }) } func TestUndecodedBytes(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var buf []byte + err := conn.QueryRow(ctx, `select 1::int4`).Scan((*pgtype.UndecodedBytes)(&buf)) + require.NoError(t, err) - ctx := context.Background() - - var buf []byte - err := conn.QueryRow(ctx, `select 1::int4`).Scan((*pgtype.UndecodedBytes)(&buf)) - require.NoError(t, err) - - require.Len(t, buf, 4) - require.Equal(t, buf, []byte{0, 0, 0, 1}) + require.Len(t, buf, 4) + require.Equal(t, buf, []byte{0, 0, 0, 1}) + }) } diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go index b78d35ba..7b6db777 100644 --- a/pgtype/circle_test.go +++ b/pgtype/circle_test.go @@ -1,16 +1,17 @@ package pgtype_test import ( + "context" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestCircleTranscode(t *testing.T) { skipCockroachDB(t, "Server does not support box type") - testutil.RunTranscodeTests(t, "circle", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "circle", []pgxtest.ValueRoundTripTest{ { pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, new(pgtype.Circle), diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index 0f112ebd..559403d8 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -7,50 +7,49 @@ import ( pgx "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/require" ) func TestCompositeCodecTranscode(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(context.Background(), `drop type if exists ct_test; + _, err := conn.Exec(ctx, `drop type if exists ct_test; create type ct_test as ( a text, b int4 );`) - require.NoError(t, err) - defer conn.Exec(context.Background(), "drop type ct_test") + require.NoError(t, err) + defer conn.Exec(ctx, "drop type ct_test") - dt, err := conn.LoadType(context.Background(), "ct_test") - require.NoError(t, err) - conn.TypeMap().RegisterType(dt) + dt, err := conn.LoadType(ctx, "ct_test") + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) - formats := []struct { - name string - code int16 - }{ - {name: "TextFormat", code: pgx.TextFormatCode}, - {name: "BinaryFormat", code: pgx.BinaryFormatCode}, - } + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } - for _, format := range formats { - var a string - var b int32 + for _, format := range formats { + var a string + var b int32 - err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QueryResultFormats{format.code}, - pgtype.CompositeFields{"hi", int32(42)}, - ).Scan( - pgtype.CompositeFields{&a, &b}, - ) - require.NoErrorf(t, err, "%v", format.name) - require.EqualValuesf(t, "hi", a, "%v", format.name) - require.EqualValuesf(t, 42, b, "%v", format.name) - } + err := conn.QueryRow(ctx, "select $1::ct_test", pgx.QueryResultFormats{format.code}, + pgtype.CompositeFields{"hi", int32(42)}, + ).Scan( + pgtype.CompositeFields{&a, &b}, + ) + require.NoErrorf(t, err, "%v", format.name) + require.EqualValuesf(t, "hi", a, "%v", format.name) + require.EqualValuesf(t, 42, b, "%v", format.name) + } + }) } type point3d struct { @@ -94,118 +93,118 @@ func (p *point3d) ScanIndex(i int) interface{} { func TestCompositeCodecTranscodeStruct(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(context.Background(), `drop type if exists point3d; + _, err := conn.Exec(ctx, `drop type if exists point3d; create type point3d as ( x float8, y float8, z float8 );`) - require.NoError(t, err) - defer conn.Exec(context.Background(), "drop type point3d") + require.NoError(t, err) + defer conn.Exec(ctx, "drop type point3d") - dt, err := conn.LoadType(context.Background(), "point3d") - require.NoError(t, err) - conn.TypeMap().RegisterType(dt) + dt, err := conn.LoadType(ctx, "point3d") + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) - formats := []struct { - name string - code int16 - }{ - {name: "TextFormat", code: pgx.TextFormatCode}, - {name: "BinaryFormat", code: pgx.BinaryFormatCode}, - } + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } - for _, format := range formats { - input := point3d{X: 1, Y: 2, Z: 3} - var output point3d - err := conn.QueryRow(context.Background(), "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output) - require.NoErrorf(t, err, "%v", format.name) - require.Equalf(t, input, output, "%v", format.name) - } + for _, format := range formats { + input := point3d{X: 1, Y: 2, Z: 3} + var output point3d + err := conn.QueryRow(ctx, "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output) + require.NoErrorf(t, err, "%v", format.name) + require.Equalf(t, input, output, "%v", format.name) + } + }) } func TestCompositeCodecTranscodeStructWrapper(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(context.Background(), `drop type if exists point3d; + _, err := conn.Exec(ctx, `drop type if exists point3d; create type point3d as ( x float8, y float8, z float8 );`) - require.NoError(t, err) - defer conn.Exec(context.Background(), "drop type point3d") + require.NoError(t, err) + defer conn.Exec(ctx, "drop type point3d") - dt, err := conn.LoadType(context.Background(), "point3d") - require.NoError(t, err) - conn.TypeMap().RegisterType(dt) + dt, err := conn.LoadType(ctx, "point3d") + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) - formats := []struct { - name string - code int16 - }{ - {name: "TextFormat", code: pgx.TextFormatCode}, - {name: "BinaryFormat", code: pgx.BinaryFormatCode}, - } + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } - type anotherPoint struct { - X, Y, Z float64 - } + type anotherPoint struct { + X, Y, Z float64 + } - for _, format := range formats { - input := anotherPoint{X: 1, Y: 2, Z: 3} - var output anotherPoint - err := conn.QueryRow(context.Background(), "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output) - require.NoErrorf(t, err, "%v", format.name) - require.Equalf(t, input, output, "%v", format.name) - } + for _, format := range formats { + input := anotherPoint{X: 1, Y: 2, Z: 3} + var output anotherPoint + err := conn.QueryRow(ctx, "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output) + require.NoErrorf(t, err, "%v", format.name) + require.Equalf(t, input, output, "%v", format.name) + } + }) } func TestCompositeCodecDecodeValue(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(context.Background(), `drop type if exists point3d; + _, err := conn.Exec(ctx, `drop type if exists point3d; create type point3d as ( x float8, y float8, z float8 );`) - require.NoError(t, err) - defer conn.Exec(context.Background(), "drop type point3d") + require.NoError(t, err) + defer conn.Exec(ctx, "drop type point3d") - dt, err := conn.LoadType(context.Background(), "point3d") - require.NoError(t, err) - conn.TypeMap().RegisterType(dt) + dt, err := conn.LoadType(ctx, "point3d") + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) - formats := []struct { - name string - code int16 - }{ - {name: "TextFormat", code: pgx.TextFormatCode}, - {name: "BinaryFormat", code: pgx.BinaryFormatCode}, - } + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } - for _, format := range formats { - rows, err := conn.Query(context.Background(), "select '(1,2,3)'::point3d", pgx.QueryResultFormats{format.code}) - require.NoErrorf(t, err, "%v", format.name) - require.True(t, rows.Next()) - values, err := rows.Values() - require.NoErrorf(t, err, "%v", format.name) - require.Lenf(t, values, 1, "%v", format.name) - require.Equalf(t, map[string]interface{}{"x": 1.0, "y": 2.0, "z": 3.0}, values[0], "%v", format.name) - require.False(t, rows.Next()) - require.NoErrorf(t, rows.Err(), "%v", format.name) - } + for _, format := range formats { + rows, err := conn.Query(ctx, "select '(1,2,3)'::point3d", pgx.QueryResultFormats{format.code}) + require.NoErrorf(t, err, "%v", format.name) + require.True(t, rows.Next()) + values, err := rows.Values() + require.NoErrorf(t, err, "%v", format.name) + require.Lenf(t, values, 1, "%v", format.name) + require.Equalf(t, map[string]interface{}{"x": 1.0, "y": 2.0, "z": 3.0}, values[0], "%v", format.name) + require.False(t, rows.Next()) + require.NoErrorf(t, rows.Err(), "%v", format.name) + } + }) } diff --git a/pgtype/date_test.go b/pgtype/date_test.go index 06539822..25c6bfc2 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -1,11 +1,12 @@ package pgtype_test import ( + "context" "testing" "time" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func isExpectedEqTime(a interface{}) func(interface{}) bool { @@ -18,7 +19,7 @@ func isExpectedEqTime(a interface{}) func(interface{}) bool { } func TestDateCodec(t *testing.T) { - testutil.RunTranscodeTests(t, "date", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "date", []pgxtest.ValueRoundTripTest{ {time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC))}, {time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC))}, {time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC))}, diff --git a/pgtype/enum_codec_test.go b/pgtype/enum_codec_test.go index afd062a2..633b610b 100644 --- a/pgtype/enum_codec_test.go +++ b/pgtype/enum_codec_test.go @@ -4,66 +4,66 @@ import ( "context" "testing" - "github.com/jackc/pgx/v5/pgtype/testutil" + pgx "github.com/jackc/pgx/v5" "github.com/stretchr/testify/require" ) func TestEnumCodec(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(context.Background(), `drop type if exists enum_test; + _, err := conn.Exec(ctx, `drop type if exists enum_test; create type enum_test as enum ('foo', 'bar', 'baz');`) - require.NoError(t, err) - defer conn.Exec(context.Background(), "drop type enum_test") + require.NoError(t, err) + defer conn.Exec(ctx, "drop type enum_test") - dt, err := conn.LoadType(context.Background(), "enum_test") - require.NoError(t, err) + dt, err := conn.LoadType(ctx, "enum_test") + require.NoError(t, err) - conn.TypeMap().RegisterType(dt) + conn.TypeMap().RegisterType(dt) - var s string - err = conn.QueryRow(context.Background(), `select 'foo'::enum_test`).Scan(&s) - require.NoError(t, err) - require.Equal(t, "foo", s) + var s string + err = conn.QueryRow(ctx, `select 'foo'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "foo", s) - err = conn.QueryRow(context.Background(), `select $1::enum_test`, "bar").Scan(&s) - require.NoError(t, err) - require.Equal(t, "bar", s) + err = conn.QueryRow(ctx, `select $1::enum_test`, "bar").Scan(&s) + require.NoError(t, err) + require.Equal(t, "bar", s) - err = conn.QueryRow(context.Background(), `select 'foo'::enum_test`).Scan(&s) - require.NoError(t, err) - require.Equal(t, "foo", s) + err = conn.QueryRow(ctx, `select 'foo'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "foo", s) - err = conn.QueryRow(context.Background(), `select $1::enum_test`, "bar").Scan(&s) - require.NoError(t, err) - require.Equal(t, "bar", s) + err = conn.QueryRow(ctx, `select $1::enum_test`, "bar").Scan(&s) + require.NoError(t, err) + require.Equal(t, "bar", s) - err = conn.QueryRow(context.Background(), `select 'baz'::enum_test`).Scan(&s) - require.NoError(t, err) - require.Equal(t, "baz", s) + err = conn.QueryRow(ctx, `select 'baz'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "baz", s) + }) } func TestEnumCodecValues(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(context.Background(), `drop type if exists enum_test; + _, err := conn.Exec(ctx, `drop type if exists enum_test; create type enum_test as enum ('foo', 'bar', 'baz');`) - require.NoError(t, err) - defer conn.Exec(context.Background(), "drop type enum_test") + require.NoError(t, err) + defer conn.Exec(ctx, "drop type enum_test") - dt, err := conn.LoadType(context.Background(), "enum_test") - require.NoError(t, err) + dt, err := conn.LoadType(ctx, "enum_test") + require.NoError(t, err) - conn.TypeMap().RegisterType(dt) + conn.TypeMap().RegisterType(dt) - rows, err := conn.Query(context.Background(), `select 'foo'::enum_test`) - require.NoError(t, err) - require.True(t, rows.Next()) - values, err := rows.Values() - require.NoError(t, err) - require.Equal(t, values, []interface{}{"foo"}) + rows, err := conn.Query(ctx, `select 'foo'::enum_test`) + require.NoError(t, err) + require.True(t, rows.Next()) + values, err := rows.Values() + require.NoError(t, err) + require.Equal(t, values, []interface{}{"foo"}) + }) } diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go index 39d7ee75..00b9addf 100644 --- a/pgtype/float4_test.go +++ b/pgtype/float4_test.go @@ -1,14 +1,15 @@ package pgtype_test import ( + "context" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestFloat4Codec(t *testing.T) { - testutil.RunTranscodeTests(t, "float4", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "float4", []pgxtest.ValueRoundTripTest{ {pgtype.Float4{Float32: -1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: -1, Valid: true})}, {pgtype.Float4{Float32: 0, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: 0, Valid: true})}, {pgtype.Float4{Float32: 1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: 1, Valid: true})}, diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go index 29bd6f31..9c269072 100644 --- a/pgtype/float8_test.go +++ b/pgtype/float8_test.go @@ -1,14 +1,15 @@ package pgtype_test import ( + "context" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestFloat8Codec(t *testing.T) { - testutil.RunTranscodeTests(t, "float8", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "float8", []pgxtest.ValueRoundTripTest{ {pgtype.Float8{Float64: -1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: -1, Valid: true})}, {pgtype.Float8{Float64: 0, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: 0, Valid: true})}, {pgtype.Float8{Float64: 1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: 1, Valid: true})}, diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index 8141687a..fb0cc27b 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -6,7 +6,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func isExpectedEqMapStringString(a interface{}) func(interface{}) bool { @@ -52,30 +52,22 @@ func isExpectedEqMapStringPointerString(a interface{}) func(interface{}) bool { } func TestHstoreCodec(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var hstoreOID uint32 + err := conn.QueryRow(context.Background(), `select oid from pg_type where typname = 'hstore'`).Scan(&hstoreOID) + if err != nil { + t.Skipf("Skipping: cannot find hstore OID") + } - var hstoreOID uint32 - err := conn.QueryRow(context.Background(), `select oid from pg_type where typname = 'hstore'`).Scan(&hstoreOID) - if err != nil { - t.Skipf("Skipping: cannot find hstore OID") - } - - conn.TypeMap().RegisterType(&pgtype.Type{Name: "hstore", OID: hstoreOID, Codec: pgtype.HstoreCodec{}}) - - formats := []struct { - name string - code int16 - }{ - {name: "TextFormat", code: pgx.TextFormatCode}, - {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + conn.TypeMap().RegisterType(&pgtype.Type{Name: "hstore", OID: hstoreOID, Codec: pgtype.HstoreCodec{}}) } fs := func(s string) *string { return &s } - tests := []testutil.TranscodeTestCase{ + tests := []pgxtest.ValueRoundTripTest{ { map[string]string{}, new(map[string]string), @@ -134,25 +126,25 @@ func TestHstoreCodec(t *testing.T) { // Special key values // at beginning - tests = append(tests, testutil.TranscodeTestCase{ + tests = append(tests, pgxtest.ValueRoundTripTest{ map[string]string{s + "foo": "bar"}, new(map[string]string), isExpectedEqMapStringString(map[string]string{s + "foo": "bar"}), }) // in middle - tests = append(tests, testutil.TranscodeTestCase{ + tests = append(tests, pgxtest.ValueRoundTripTest{ map[string]string{"foo" + s + "bar": "bar"}, new(map[string]string), isExpectedEqMapStringString(map[string]string{"foo" + s + "bar": "bar"}), }) // at end - tests = append(tests, testutil.TranscodeTestCase{ + tests = append(tests, pgxtest.ValueRoundTripTest{ map[string]string{"foo" + s: "bar"}, new(map[string]string), isExpectedEqMapStringString(map[string]string{"foo" + s: "bar"}), }) // is key - tests = append(tests, testutil.TranscodeTestCase{ + tests = append(tests, pgxtest.ValueRoundTripTest{ map[string]string{s: "bar"}, new(map[string]string), isExpectedEqMapStringString(map[string]string{s: "bar"}), @@ -161,32 +153,30 @@ func TestHstoreCodec(t *testing.T) { // Special value values // at beginning - tests = append(tests, testutil.TranscodeTestCase{ + tests = append(tests, pgxtest.ValueRoundTripTest{ map[string]string{"foo": s + "bar"}, new(map[string]string), isExpectedEqMapStringString(map[string]string{"foo": s + "bar"}), }) // in middle - tests = append(tests, testutil.TranscodeTestCase{ + tests = append(tests, pgxtest.ValueRoundTripTest{ map[string]string{"foo": "foo" + s + "bar"}, new(map[string]string), isExpectedEqMapStringString(map[string]string{"foo": "foo" + s + "bar"}), }) // at end - tests = append(tests, testutil.TranscodeTestCase{ + tests = append(tests, pgxtest.ValueRoundTripTest{ map[string]string{"foo": "foo" + s}, new(map[string]string), isExpectedEqMapStringString(map[string]string{"foo": "foo" + s}), }) // is key - tests = append(tests, testutil.TranscodeTestCase{ + tests = append(tests, pgxtest.ValueRoundTripTest{ map[string]string{"foo": s}, new(map[string]string), isExpectedEqMapStringString(map[string]string{"foo": s}), }) } - for _, format := range formats { - testutil.RunTranscodeTestsFormat(t, "hstore", tests, conn, format.name, format.code) - } + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, pgxtest.KnownOIDQueryExecModes, "hstore", tests) } diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index 249caf3f..8bf11a76 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -1,11 +1,12 @@ package pgtype_test import ( + "context" "net" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func isExpectedEqIPNet(a interface{}) func(interface{}) bool { @@ -18,7 +19,7 @@ func isExpectedEqIPNet(a interface{}) func(interface{}) bool { } func TestInetTranscode(t *testing.T) { - testutil.RunTranscodeTests(t, "inet", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "inet", []pgxtest.ValueRoundTripTest{ {mustParseInet(t, "0.0.0.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "0.0.0.0/32"))}, {mustParseInet(t, "127.0.0.1/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "127.0.0.1/8"))}, {mustParseInet(t, "12.34.56.65/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "12.34.56.65/32"))}, @@ -37,7 +38,7 @@ func TestInetTranscode(t *testing.T) { func TestCidrTranscode(t *testing.T) { skipCockroachDB(t, "Server does not support cidr type (see https://github.com/cockroachdb/cockroach/issues/18846)") - testutil.RunTranscodeTests(t, "cidr", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "cidr", []pgxtest.ValueRoundTripTest{ {mustParseInet(t, "0.0.0.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "0.0.0.0/32"))}, {mustParseInet(t, "127.0.0.1/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "127.0.0.1/32"))}, {mustParseInet(t, "12.34.56.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "12.34.56.0/32"))}, diff --git a/pgtype/int_test.go b/pgtype/int_test.go index 6dc65259..c779bdc9 100644 --- a/pgtype/int_test.go +++ b/pgtype/int_test.go @@ -2,15 +2,16 @@ package pgtype_test import ( + "context" "math" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestInt2Codec(t *testing.T) { - testutil.RunTranscodeTests(t, "int2", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int2", []pgxtest.ValueRoundTripTest{ {int8(1), new(int16), isExpectedEq(int16(1))}, {int16(1), new(int16), isExpectedEq(int16(1))}, {int32(1), new(int16), isExpectedEq(int16(1))}, @@ -91,7 +92,7 @@ func TestInt2UnmarshalJSON(t *testing.T) { } func TestInt4Codec(t *testing.T) { - testutil.RunTranscodeTests(t, "int4", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4", []pgxtest.ValueRoundTripTest{ {int8(1), new(int32), isExpectedEq(int32(1))}, {int16(1), new(int32), isExpectedEq(int32(1))}, {int32(1), new(int32), isExpectedEq(int32(1))}, @@ -172,7 +173,7 @@ func TestInt4UnmarshalJSON(t *testing.T) { } func TestInt8Codec(t *testing.T) { - testutil.RunTranscodeTests(t, "int8", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{ {int8(1), new(int64), isExpectedEq(int64(1))}, {int16(1), new(int64), isExpectedEq(int64(1))}, {int32(1), new(int64), isExpectedEq(int64(1))}, diff --git a/pgtype/int_test.go.erb b/pgtype/int_test.go.erb index d55851c2..799d8c32 100644 --- a/pgtype/int_test.go.erb +++ b/pgtype/int_test.go.erb @@ -10,7 +10,7 @@ import ( <% [2, 4, 8].each do |pg_byte_size| %> <% pg_bit_size = pg_byte_size * 8 %> func TestInt<%= pg_byte_size %>Codec(t *testing.T) { - testutil.RunTranscodeTests(t, "int<%= pg_byte_size %>", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int<%= pg_byte_size %>", []pgxtest.ValueRoundTripTest{ {int8(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {int16(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {int32(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, diff --git a/pgtype/integration_benchmark_test.go b/pgtype/integration_benchmark_test.go index 58934ead..66758d07 100644 --- a/pgtype/integration_benchmark_test.go +++ b/pgtype/integration_benchmark_test.go @@ -7,1405 +7,1334 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int16 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int16 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int16 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int16 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int16 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int16 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int16 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int16 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]uint64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]uint64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]uint64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]uint64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]uint64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]uint64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]uint64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]uint64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]pgtype.Int4 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]pgtype.Int4 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]pgtype.Int4 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]pgtype.Int4 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]pgtype.Int4 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]pgtype.Int4 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]pgtype.Int4 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]pgtype.Int4 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]int64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]float64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]float64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]float64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]float64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]float64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]float64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]float64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]float64 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]pgtype.Numeric - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]pgtype.Numeric - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]pgtype.Numeric - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]pgtype.Numeric - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]pgtype.Numeric - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [1]pgtype.Numeric - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]pgtype.Numeric - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_10_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [10]pgtype.Numeric - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select array_agg(n) from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select array_agg(n) from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select array_agg(n) from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select array_agg(n) from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, 1000) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select array_agg(n) from generate_series(1, 1000) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, 1000) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select array_agg(n) from generate_series(1, 1000) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } } - } + }) } diff --git a/pgtype/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb index de9cabc9..b122c606 100644 --- a/pgtype/integration_benchmark_test.go.erb +++ b/pgtype/integration_benchmark_test.go.erb @@ -18,23 +18,22 @@ import ( <% rows_columns.each do |rows, columns| %> <% [["Text", "pgx.TextFormatCode"], ["Binary", "pgx.BinaryFormatCode"]].each do |format_name, format_code| %> func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go_type.gsub(/\W/, "_") %>_<%= rows %>_rows_<%= columns %>_columns(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v [<%= columns %>]<%= go_type %> - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, - []interface{}{pgx.QueryResultFormats{<%= format_code %>}}, - []interface{}{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) - } - } + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [<%= columns %>]<%= go_type %> + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, + []interface{}{pgx.QueryResultFormats{<%= format_code %>}}, + []interface{}{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } + }) } <% end %> <% end %> @@ -44,23 +43,22 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go <% [10, 100, 1000].each do |array_size| %> <% [["Text", "pgx.TextFormatCode"], ["Binary", "pgx.BinaryFormatCode"]].each do |format_name, format_code| %> func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_Int4Array_<%= array_size %>(b *testing.B) { - conn := testutil.MustConnectPgx(b) - defer testutil.MustCloseContext(b, conn) - - b.ResetTimer() - var v []int32 - for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( - context.Background(), - `select array_agg(n) from generate_series(1, <%= array_size %>) n`, - []interface{}{pgx.QueryResultFormats{<%= format_code %>}}, - []interface{}{&v}, - func(pgx.QueryFuncRow) error { return nil }, - ) - if err != nil { - b.Fatal(err) - } - } + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + ctx, + `select array_agg(n) from generate_series(1, <%= array_size %>) n`, + []interface{}{pgx.QueryResultFormats{<%= format_code %>}}, + []interface{}{&v}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } + }) } <% end %> <% end %> diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go index 310ea6bc..754c44e3 100644 --- a/pgtype/interval_test.go +++ b/pgtype/interval_test.go @@ -1,15 +1,16 @@ package pgtype_test import ( + "context" "testing" "time" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestIntervalCodec(t *testing.T) { - testutil.RunTranscodeTests(t, "interval", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "interval", []pgxtest.ValueRoundTripTest{ { pgtype.Interval{Microseconds: 1, Valid: true}, new(pgtype.Interval), diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 39658bfa..0275b1e6 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -1,9 +1,10 @@ package pgtype_test import ( + "context" "testing" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func isExpectedEqMap(a interface{}) func(interface{}) bool { @@ -39,7 +40,15 @@ func TestJSONCodec(t *testing.T) { Age int `json:"age"` } - testutil.RunTranscodeTests(t, "json", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{ + {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, + {map[string]interface{}(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, + }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{ {[]byte("{}"), new([]byte), isExpectedEqBytes([]byte("{}"))}, {[]byte("null"), new([]byte), isExpectedEqBytes([]byte("null"))}, {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, @@ -47,10 +56,5 @@ func TestJSONCodec(t *testing.T) { {[]byte(`"hello"`), new(string), isExpectedEq(`"hello"`)}, {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, - {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, - {map[string]interface{}(nil), new(*string), isExpectedEq((*string)(nil))}, - {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, - {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, - {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, }) } diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index c26499c6..4a9f7a35 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -1,9 +1,10 @@ package pgtype_test import ( + "context" "testing" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestJSONBTranscode(t *testing.T) { @@ -12,7 +13,15 @@ func TestJSONBTranscode(t *testing.T) { Age int `json:"age"` } - testutil.RunTranscodeTests(t, "jsonb", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "jsonb", []pgxtest.ValueRoundTripTest{ + {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, + {map[string]interface{}(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, + }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{ {[]byte("{}"), new([]byte), isExpectedEqBytes([]byte("{}"))}, {[]byte("null"), new([]byte), isExpectedEqBytes([]byte("null"))}, {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, @@ -20,10 +29,5 @@ func TestJSONBTranscode(t *testing.T) { {[]byte(`"hello"`), new(string), isExpectedEq(`"hello"`)}, {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, - {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, - {map[string]interface{}(nil), new(*string), isExpectedEq((*string)(nil))}, - {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, - {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, - {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, }) } diff --git a/pgtype/line_test.go b/pgtype/line_test.go index 8e3d782c..dc980ce1 100644 --- a/pgtype/line_test.go +++ b/pgtype/line_test.go @@ -4,30 +4,32 @@ import ( "context" "testing" + pgx "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestLineTranscode(t *testing.T) { - skipCockroachDB(t, "Server does not support type line") + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support type line") - conn := testutil.MustConnectPgx(t) - defer conn.Close(context.Background()) - if _, ok := conn.TypeMap().TypeForName("line"); !ok { - t.Skip("Skipping due to no line type") + if _, ok := conn.TypeMap().TypeForName("line"); !ok { + t.Skip("Skipping due to no line type") + } + + // line may exist but not be usable on 9.3 :( + var isPG93 bool + err := conn.QueryRow(context.Background(), "select version() ~ '9.3'").Scan(&isPG93) + if err != nil { + t.Fatal(err) + } + if isPG93 { + t.Skip("Skipping due to unimplemented line type in PG 9.3") + } } - // line may exist but not be usable on 9.3 :( - var isPG93 bool - err := conn.QueryRow(context.Background(), "select version() ~ '9.3'").Scan(&isPG93) - if err != nil { - t.Fatal(err) - } - if isPG93 { - t.Skip("Skipping due to unimplemented line type in PG 9.3") - } - - testutil.RunTranscodeTests(t, "line", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "line", []pgxtest.ValueRoundTripTest{ { pgtype.Line{ A: 1.23, B: 4.56, C: 7.89012345, diff --git a/pgtype/lseg_test.go b/pgtype/lseg_test.go index e754b3b2..04fde0eb 100644 --- a/pgtype/lseg_test.go +++ b/pgtype/lseg_test.go @@ -1,16 +1,17 @@ package pgtype_test import ( + "context" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestLsegTranscode(t *testing.T) { skipCockroachDB(t, "Server does not support type lseg") - testutil.RunTranscodeTests(t, "lseg", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "lseg", []pgxtest.ValueRoundTripTest{ { pgtype.Lseg{ P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go index 06262876..e2463271 100644 --- a/pgtype/macaddr_test.go +++ b/pgtype/macaddr_test.go @@ -2,10 +2,11 @@ package pgtype_test import ( "bytes" + "context" "net" "testing" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func isExpectedEqHardwareAddr(a interface{}) func(interface{}) bool { @@ -28,7 +29,8 @@ func isExpectedEqHardwareAddr(a interface{}) func(interface{}) bool { func TestMacaddrCodec(t *testing.T) { skipCockroachDB(t, "Server does not support type macaddr") - testutil.RunTranscodeTests(t, "macaddr", []testutil.TranscodeTestCase{ + // Only testing known OID query exec modes as net.HardwareAddr could map to macaddr or macaddr8. + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "macaddr", []pgxtest.ValueRoundTripTest{ { mustParseMacaddr(t, "01:23:45:67:89:ab"), new(net.HardwareAddr), diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 281efa0f..8be8ce55 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -9,8 +9,9 @@ import ( "strconv" "testing" + pgx "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -77,7 +78,7 @@ func TestNumericCodec(t *testing.T) { max.Add(max, big.NewInt(1)) longestNumeric := pgtype.Numeric{Int: max, Exp: -16383, Valid: true} - testutil.RunTranscodeTests(t, "numeric", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "numeric", []pgxtest.ValueRoundTripTest{ {mustParseNumeric(t, "1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, {mustParseNumeric(t, "3.14159"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "3.14159"))}, {mustParseNumeric(t, "100010001"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "100010001"))}, @@ -118,7 +119,7 @@ func TestNumericCodecInfinity(t *testing.T) { skipCockroachDB(t, "server formats numeric text format differently") skipPostgreSQLVersionLessThan(t, 14) - testutil.RunTranscodeTests(t, "numeric", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "numeric", []pgxtest.ValueRoundTripTest{ {math.Inf(1), new(float64), isExpectedEq(math.Inf(1))}, {float32(math.Inf(1)), new(float32), isExpectedEq(float32(math.Inf(1)))}, {math.Inf(-1), new(float64), isExpectedEq(math.Inf(-1))}, @@ -159,54 +160,54 @@ func TestNumericCodecFuzz(t *testing.T) { max := &big.Int{} max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) - tests := make([]testutil.TranscodeTestCase, 0, 2000) + tests := make([]pgxtest.ValueRoundTripTest, 0, 2000) for i := 0; i < 10; i++ { for j := -50; j < 50; j++ { num := (&big.Int{}).Rand(r, max) n := pgtype.Numeric{Int: num, Exp: int32(j), Valid: true} - tests = append(tests, testutil.TranscodeTestCase{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) + tests = append(tests, pgxtest.ValueRoundTripTest{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) negNum := &big.Int{} negNum.Neg(num) n = pgtype.Numeric{Int: negNum, Exp: int32(j), Valid: true} - tests = append(tests, testutil.TranscodeTestCase{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) + tests = append(tests, pgxtest.ValueRoundTripTest{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) } } - testutil.RunTranscodeTests(t, "numeric", tests) + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "numeric", tests) } func TestNumericMarshalJSON(t *testing.T) { skipCockroachDB(t, "server formats numeric text format differently") - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - for i, tt := range []struct { - decString string - }{ - {"NaN"}, - {"0"}, - {"1"}, - {"-1"}, - {"1000000000000000000"}, - {"1234.56789"}, - {"1.56789"}, - {"0.00000000000056789"}, - {"0.00123000"}, - {"123e-3"}, - {"243723409723490243842378942378901237502734019231380123e23790"}, - {"3409823409243892349028349023482934092340892390101e-14021"}, - } { - var num pgtype.Numeric - var pgJSON string - err := conn.QueryRow(context.Background(), `select $1::numeric, to_json($1::numeric)`, tt.decString).Scan(&num, &pgJSON) - require.NoErrorf(t, err, "%d", i) + for i, tt := range []struct { + decString string + }{ + {"NaN"}, + {"0"}, + {"1"}, + {"-1"}, + {"1000000000000000000"}, + {"1234.56789"}, + {"1.56789"}, + {"0.00000000000056789"}, + {"0.00123000"}, + {"123e-3"}, + {"243723409723490243842378942378901237502734019231380123e23790"}, + {"3409823409243892349028349023482934092340892390101e-14021"}, + } { + var num pgtype.Numeric + var pgJSON string + err := conn.QueryRow(ctx, `select $1::numeric, to_json($1::numeric)`, tt.decString).Scan(&num, &pgJSON) + require.NoErrorf(t, err, "%d", i) - goJSON, err := json.Marshal(num) - require.NoErrorf(t, err, "%d", i) + goJSON, err := json.Marshal(num) + require.NoErrorf(t, err, "%d", i) - require.Equal(t, pgJSON, string(goJSON)) - } + require.Equal(t, pgJSON, string(goJSON)) + } + }) } diff --git a/pgtype/path_test.go b/pgtype/path_test.go index 40df2bfb..f9e13294 100644 --- a/pgtype/path_test.go +++ b/pgtype/path_test.go @@ -1,10 +1,11 @@ package pgtype_test import ( + "context" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func isExpectedEqPath(a interface{}) func(interface{}) bool { @@ -29,7 +30,7 @@ func isExpectedEqPath(a interface{}) func(interface{}) bool { func TestPathTranscode(t *testing.T) { skipCockroachDB(t, "Server does not support type path") - testutil.RunTranscodeTests(t, "path", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "path", []pgxtest.ValueRoundTripTest{ { pgtype.Path{ P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index ffc017f8..94158cb8 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -374,7 +374,7 @@ func NewMap() *Map { registerDefaultPgTypeVariants("daterange", "_daterange", Daterange{}) registerDefaultPgTypeVariants("float4", "_float4", Float4{}) registerDefaultPgTypeVariants("float8", "_float8", Float8{}) - registerDefaultPgTypeVariants("float8range", "_float8range", Float8range{}) + registerDefaultPgTypeVariants("numrange", "_numrange", Float8range{}) // There is no PostgreSQL builtin float8range so map it to numrange. registerDefaultPgTypeVariants("inet", "_inet", Inet{}) registerDefaultPgTypeVariants("int2", "_int2", Int2{}) registerDefaultPgTypeVariants("int4", "_int4", Int4{}) diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index dd0150fb..9778c335 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -1,21 +1,34 @@ package pgtype_test import ( + "context" "database/sql" "errors" "net" + "os" "regexp" "strconv" "testing" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" _ "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +var defaultConnTestRunner pgxtest.ConnTestRunner + +func init() { + defaultConnTestRunner = pgxtest.DefaultConnTestRunner() + defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return config + } +} + // Test for renamed types type _string string type _bool bool @@ -70,8 +83,11 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { } func skipCockroachDB(t testing.TB, msg string) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + defer conn.Close(context.Background()) if conn.PgConn().ParameterStatus("crdb_version") != "" { t.Skip(msg) @@ -79,8 +95,11 @@ func skipCockroachDB(t testing.TB, msg string) { } func skipPostgreSQLVersionLessThan(t testing.TB, minVersion int64) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + defer conn.Close(context.Background()) serverVersionStr := conn.PgConn().ParameterStatus("server_version") serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) diff --git a/pgtype/point_test.go b/pgtype/point_test.go index 62fcfc51..336f1a47 100644 --- a/pgtype/point_test.go +++ b/pgtype/point_test.go @@ -1,18 +1,19 @@ package pgtype_test import ( + "context" "reflect" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) func TestPointCodec(t *testing.T) { skipCockroachDB(t, "Server does not support type point") - testutil.RunTranscodeTests(t, "point", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "point", []pgxtest.ValueRoundTripTest{ { pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Valid: true}, new(pgtype.Point), diff --git a/pgtype/polygon_test.go b/pgtype/polygon_test.go index 6c6fc60d..a6a60de2 100644 --- a/pgtype/polygon_test.go +++ b/pgtype/polygon_test.go @@ -1,10 +1,11 @@ package pgtype_test import ( + "context" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func isExpectedEqPolygon(a interface{}) func(interface{}) bool { @@ -29,7 +30,7 @@ func isExpectedEqPolygon(a interface{}) func(interface{}) bool { func TestPolygonTranscode(t *testing.T) { skipCockroachDB(t, "Server does not support type polygon") - testutil.RunTranscodeTests(t, "polygon", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "polygon", []pgxtest.ValueRoundTripTest{ { pgtype.Polygon{ P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go index 0bf781a4..da00b89e 100644 --- a/pgtype/qchar_test.go +++ b/pgtype/qchar_test.go @@ -1,22 +1,24 @@ package pgtype_test import ( + "context" "math" "testing" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestQcharTranscode(t *testing.T) { skipCockroachDB(t, "Server does not support qchar") - var tests []testutil.TranscodeTestCase + var tests []pgxtest.ValueRoundTripTest for i := 0; i <= math.MaxUint8; i++ { - tests = append(tests, testutil.TranscodeTestCase{rune(i), new(rune), isExpectedEq(rune(i))}) - tests = append(tests, testutil.TranscodeTestCase{byte(i), new(byte), isExpectedEq(byte(i))}) + tests = append(tests, pgxtest.ValueRoundTripTest{rune(i), new(rune), isExpectedEq(rune(i))}) + tests = append(tests, pgxtest.ValueRoundTripTest{byte(i), new(byte), isExpectedEq(byte(i))}) } - tests = append(tests, testutil.TranscodeTestCase{nil, new(*rune), isExpectedEq((*rune)(nil))}) - tests = append(tests, testutil.TranscodeTestCase{nil, new(*byte), isExpectedEq((*byte)(nil))}) + tests = append(tests, pgxtest.ValueRoundTripTest{nil, new(*rune), isExpectedEq((*rune)(nil))}) + tests = append(tests, pgxtest.ValueRoundTripTest{nil, new(*byte), isExpectedEq((*byte)(nil))}) - testutil.RunTranscodeTests(t, `"char"`, tests) + // Can only test with known OIDs as rune and byte would be considered numbers. + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, `"char"`, tests) } diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index b4127769..8c1116a0 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -4,15 +4,16 @@ import ( "context" "testing" + pgx "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) func TestRangeCodecTranscode(t *testing.T) { skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") - testutil.RunTranscodeTests(t, "int4range", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4range", []pgxtest.ValueRoundTripTest{ { pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, new(pgtype.Int4range), @@ -39,9 +40,12 @@ func TestRangeCodecTranscode(t *testing.T) { } func TestRangeCodecTranscodeCompatibleRangeElementTypes(t *testing.T) { - skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + } - testutil.RunTranscodeTests(t, "numrange", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "numrange", []pgxtest.ValueRoundTripTest{ { pgtype.Float8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, new(pgtype.Float8range), @@ -70,90 +74,90 @@ func TestRangeCodecTranscodeCompatibleRangeElementTypes(t *testing.T) { func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - var r pgtype.Int4range + var r pgtype.Int4range - err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r) - require.NoError(t, err) + err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r) + require.NoError(t, err) - require.Equal( - t, - pgtype.Int4range{ - Lower: pgtype.Int4{Int32: 1, Valid: true}, - Upper: pgtype.Int4{Int32: 5, Valid: true}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, - }, - r, - ) + require.Equal( + t, + pgtype.Int4range{ + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + r, + ) - err = conn.QueryRow(context.Background(), `select '[1,)'::int4range`).Scan(&r) - require.NoError(t, err) + err = conn.QueryRow(ctx, `select '[1,)'::int4range`).Scan(&r) + require.NoError(t, err) - require.Equal( - t, - pgtype.Int4range{ - Lower: pgtype.Int4{Int32: 1, Valid: true}, - Upper: pgtype.Int4{}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Unbounded, - Valid: true, - }, - r, - ) + require.Equal( + t, + pgtype.Int4range{ + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Unbounded, + Valid: true, + }, + r, + ) - err = conn.QueryRow(context.Background(), `select 'empty'::int4range`).Scan(&r) - require.NoError(t, err) + err = conn.QueryRow(ctx, `select 'empty'::int4range`).Scan(&r) + require.NoError(t, err) - require.Equal( - t, - pgtype.Int4range{ - Lower: pgtype.Int4{}, - Upper: pgtype.Int4{}, - LowerType: pgtype.Empty, - UpperType: pgtype.Empty, - Valid: true, - }, - r, - ) + require.Equal( + t, + pgtype.Int4range{ + Lower: pgtype.Int4{}, + Upper: pgtype.Int4{}, + LowerType: pgtype.Empty, + UpperType: pgtype.Empty, + Valid: true, + }, + r, + ) + }) } func TestRangeCodecDecodeValue(t *testing.T) { skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { - for _, tt := range []struct { - sql string - expected interface{} - }{ - { - sql: `select '[1,5)'::int4range`, - expected: pgtype.GenericRange{ - Lower: int32(1), - Upper: int32(5), - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Valid: true, + for _, tt := range []struct { + sql string + expected interface{} + }{ + { + sql: `select '[1,5)'::int4range`, + expected: pgtype.GenericRange{ + Lower: int32(1), + Upper: int32(5), + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, }, - }, - } { - t.Run(tt.sql, func(t *testing.T) { - rows, err := conn.Query(context.Background(), tt.sql) - require.NoError(t, err) - - for rows.Next() { - values, err := rows.Values() + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(ctx, tt.sql) require.NoError(t, err) - require.Len(t, values, 1) - require.Equal(t, tt.expected, values[0]) - } - require.NoError(t, rows.Err()) - }) - } + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) } diff --git a/pgtype/record_codec_test.go b/pgtype/record_codec_test.go index d6fe603c..57fa87ff 100644 --- a/pgtype/record_codec_test.go +++ b/pgtype/record_codec_test.go @@ -4,72 +4,70 @@ import ( "context" "testing" + pgx "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/stretchr/testify/require" ) func TestRecordCodec(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var a string + var b int32 + err := conn.QueryRow(ctx, `select row('foo'::text, 42::int4)`).Scan(pgtype.CompositeFields{&a, &b}) + require.NoError(t, err) - var a string - var b int32 - err := conn.QueryRow(context.Background(), `select row('foo'::text, 42::int4)`).Scan(pgtype.CompositeFields{&a, &b}) - require.NoError(t, err) - - require.Equal(t, "foo", a) - require.Equal(t, int32(42), b) + require.Equal(t, "foo", a) + require.Equal(t, int32(42), b) + }) } func TestRecordCodecDecodeValue(t *testing.T) { skipCockroachDB(t, "Server converts row int4 to int8") - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - for _, tt := range []struct { - sql string - expected interface{} - }{ - { - sql: `select row()`, - expected: []interface{}{}, - }, - { - sql: `select row('foo'::text, 42::int4)`, - expected: []interface{}{"foo", int32(42)}, - }, - { - sql: `select row(100.0::float4, 1.09::float4)`, - expected: []interface{}{float32(100), float32(1.09)}, - }, - { - sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, - expected: []interface{}{"foo", []interface{}{int32(1), int32(2), nil, int32(4)}, int32(42)}, - }, - { - sql: `select row(null)`, - expected: []interface{}{nil}, - }, - { - sql: `select null::record`, - expected: nil, - }, - } { - t.Run(tt.sql, func(t *testing.T) { - rows, err := conn.Query(context.Background(), tt.sql) - require.NoError(t, err) - defer rows.Close() - - for rows.Next() { - values, err := rows.Values() + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + for _, tt := range []struct { + sql string + expected interface{} + }{ + { + sql: `select row()`, + expected: []interface{}{}, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: []interface{}{"foo", int32(42)}, + }, + { + sql: `select row(100.0::float4, 1.09::float4)`, + expected: []interface{}{float32(100), float32(1.09)}, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: []interface{}{"foo", []interface{}{int32(1), int32(2), nil, int32(4)}, int32(42)}, + }, + { + sql: `select row(null)`, + expected: []interface{}{nil}, + }, + { + sql: `select null::record`, + expected: nil, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(context.Background(), tt.sql) require.NoError(t, err) - require.Len(t, values, 1) - require.Equal(t, tt.expected, values[0]) - } + defer rows.Close() - require.NoError(t, rows.Err()) - }) - } + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) } diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go deleted file mode 100644 index 19ed4412..00000000 --- a/pgtype/testutil/testutil.go +++ /dev/null @@ -1,85 +0,0 @@ -package testutil - -import ( - "context" - "fmt" - "os" - "reflect" - "testing" - - "github.com/jackc/pgx/v5" - _ "github.com/jackc/pgx/v5/stdlib" -) - -func MustConnectPgx(t testing.TB) *pgx.Conn { - conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - t.Fatal(err) - } - - return conn -} - -func MustClose(t testing.TB, conn interface { - Close() error -}) { - err := conn.Close() - if err != nil { - t.Fatal(err) - } -} - -func MustCloseContext(t testing.TB, conn interface { - Close(context.Context) error -}) { - err := conn.Close(context.Background()) - if err != nil { - t.Fatal(err) - } -} - -type TranscodeTestCase struct { - Src interface{} - Dst interface{} - Test func(interface{}) bool -} - -func RunTranscodeTests(t testing.TB, pgTypeName string, tests []TranscodeTestCase) { - conn := MustConnectPgx(t) - defer MustCloseContext(t, conn) - - formats := []struct { - name string - code int16 - }{ - {name: "TextFormat", code: pgx.TextFormatCode}, - {name: "BinaryFormat", code: pgx.BinaryFormatCode}, - } - - for _, format := range formats { - RunTranscodeTestsFormat(t, pgTypeName, tests, conn, format.name, format.code) - } -} - -func RunTranscodeTestsFormat(t testing.TB, pgTypeName string, tests []TranscodeTestCase, conn *pgx.Conn, formatName string, formatCode int16) { - _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - for i, tt := range tests { - err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{formatCode}, tt.Src).Scan(tt.Dst) - if err != nil { - t.Errorf("%s %d: %v", formatName, i, err) - } - - dst := reflect.ValueOf(tt.Dst) - if dst.Kind() == reflect.Ptr { - dst = dst.Elem() - } - - if !tt.Test(dst.Interface()) { - t.Errorf("%s %d: unexpected result for %v: %v", formatName, i, tt.Src, dst.Interface()) - } - } -} diff --git a/pgtype/text_test.go b/pgtype/text_test.go index c80c404b..1d717f49 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -4,8 +4,9 @@ import ( "context" "testing" + pgx "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) @@ -17,7 +18,7 @@ func (someFmtStringer) String() string { func TestTextCodec(t *testing.T) { for _, pgTypeName := range []string{"text", "varchar"} { - testutil.RunTranscodeTests(t, pgTypeName, []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, pgTypeName, []pgxtest.ValueRoundTripTest{ { pgtype.Text{String: "", Valid: true}, new(pgtype.Text), @@ -31,6 +32,10 @@ func TestTextCodec(t *testing.T) { {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, {"foo", new(string), isExpectedEq("foo")}, {someFmtStringer{}, new(string), isExpectedEq("some fmt.Stringer")}, + }) + + // rune requires known OID because otherwise it is considered an int32. + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, pgTypeName, []pgxtest.ValueRoundTripTest{ {rune('R'), new(rune), isExpectedEq(rune('R'))}, }) } @@ -47,7 +52,7 @@ func TestTextCodec(t *testing.T) { // // So this is simply a smoke test of the name type. func TestTextCodecName(t *testing.T) { - testutil.RunTranscodeTests(t, "name", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "name", []pgxtest.ValueRoundTripTest{ { pgtype.Text{String: "", Valid: true}, new(pgtype.Text), @@ -67,7 +72,7 @@ func TestTextCodecName(t *testing.T) { func TestTextCodecBPChar(t *testing.T) { skipCockroachDB(t, "Server does not properly handle bpchar with multi-byte character") - testutil.RunTranscodeTests(t, "char(3)", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "char(3)", []pgxtest.ValueRoundTripTest{ { pgtype.Text{String: "a ", Valid: true}, new(pgtype.Text), @@ -94,12 +99,12 @@ func TestTextCodecBPChar(t *testing.T) { // // It only supports the text format. func TestTextCodecACLItem(t *testing.T) { - skipCockroachDB(t, "Server does not support type aclitem") + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support type aclitem") + } - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - testutil.RunTranscodeTestsFormat(t, "aclitem", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "aclitem", []pgxtest.ValueRoundTripTest{ { pgtype.Text{String: "postgres=arwdDxt/postgres", Valid: true}, new(pgtype.Text), @@ -107,33 +112,33 @@ func TestTextCodecACLItem(t *testing.T) { }, {pgtype.Text{}, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, - }, conn, "Text", pgtype.TextFormatCode) + }) } func TestTextCodecACLItemRoleWithSpecialCharacters(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support type aclitem") - ctx := context.Background() + // The tricky test user, below, has to actually exist so that it can be used in a test + // of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. + roleWithSpecialCharacters := ` tricky, ' } " \ test user ` - // The tricky test user, below, has to actually exist so that it can be used in a test - // of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. - roleWithSpecialCharacters := ` tricky, ' } " \ test user ` + commandTag, err := conn.Exec(ctx, `select * from pg_roles where rolname = $1`, roleWithSpecialCharacters) + require.NoError(t, err) - commandTag, err := conn.Exec(ctx, `select * from pg_roles where rolname = $1`, roleWithSpecialCharacters) - require.NoError(t, err) - - if commandTag.RowsAffected() == 0 { - t.Skipf("Role with special characters does not exist.") + if commandTag.RowsAffected() == 0 { + t.Skipf("Role with special characters does not exist.") + } } - testutil.RunTranscodeTestsFormat(t, "aclitem", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "aclitem", []pgxtest.ValueRoundTripTest{ { pgtype.Text{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}, new(pgtype.Text), isExpectedEq(pgtype.Text{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}), }, - }, conn, "Text", pgtype.TextFormatCode) + }) } func TestTextMarshalJSON(t *testing.T) { diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go index 08636aa8..3e7a1a50 100644 --- a/pgtype/tid_test.go +++ b/pgtype/tid_test.go @@ -1,16 +1,17 @@ package pgtype_test import ( + "context" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestTIDCodec(t *testing.T) { skipCockroachDB(t, "Server does not support type tid") - testutil.RunTranscodeTests(t, "tid", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "tid", []pgxtest.ValueRoundTripTest{ { pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, new(pgtype.TID), diff --git a/pgtype/time_test.go b/pgtype/time_test.go index 61f9ef0e..01bcee0f 100644 --- a/pgtype/time_test.go +++ b/pgtype/time_test.go @@ -1,15 +1,16 @@ package pgtype_test import ( + "context" "testing" "time" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestTimeCodec(t *testing.T) { - testutil.RunTranscodeTests(t, "time", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "time", []pgxtest.ValueRoundTripTest{ { pgtype.Time{Microseconds: 0, Valid: true}, new(pgtype.Time), diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 3a0bd275..10525229 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "encoding/binary" "fmt" + "strings" "time" "github.com/jackc/pgx/v5/internal/pgio" @@ -127,12 +128,29 @@ func (encodePlanTimestampCodecText) Encode(value interface{}, buf []byte) (newBu return nil, err } + if !ts.Valid { + return nil, nil + } + var s string switch ts.InfinityModifier { case Finite: t := discardTimeZone(ts.Time) + + // Year 0000 is 1 BC + bc := false + if year := t.Year(); year <= 0 { + year = -year + 1 + t = time.Date(year, t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC) + bc = true + } + s = t.Truncate(time.Microsecond).Format(pgTimestampFormat) + + if bc { + s = s + " BC" + } case Infinity: s = "infinity" case NegativeInfinity: @@ -219,11 +237,21 @@ func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst interface{}) case "-infinity": ts = Timestamp{Valid: true, InfinityModifier: -Infinity} default: + bc := false + if strings.HasSuffix(sbuf, " BC") { + sbuf = sbuf[:len(sbuf)-3] + bc = true + } tim, err := time.Parse(pgTimestampFormat, sbuf) if err != nil { return err } + if bc { + year := -tim.Year() + 1 + tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) + } + ts = Timestamp{Time: tim, Valid: true} } diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 764baff1..849f55f6 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -5,15 +5,21 @@ import ( "testing" "time" + pgx "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) func TestTimestampCodec(t *testing.T) { skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") - testutil.RunTranscodeTests(t, "timestamp", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ + {time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC))}, {time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC))}, {time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC))}, @@ -34,19 +40,18 @@ func TestTimestampCodec(t *testing.T) { // https://github.com/jackc/pgx/v4/pgtype/pull/128 func TestTimestampTranscodeBigTimeBinary(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + in := &pgtype.Timestamp{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} + var out pgtype.Timestamp - in := &pgtype.Timestamp{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} - var out pgtype.Timestamp + err := conn.QueryRow(ctx, "select $1::timestamp", in).Scan(&out) + if err != nil { + t.Fatal(err) + } - err := conn.QueryRow(context.Background(), "select $1::timestamp", in).Scan(&out) - if err != nil { - t.Fatal(err) - } - - require.Equal(t, in.Valid, out.Valid) - require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) + require.Equal(t, in.Valid, out.Valid) + require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) + }) } // https://github.com/jackc/pgtype/issues/74 diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 5069af02..7709e0aa 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "strings" "time" "github.com/jackc/pgx/v5/internal/pgio" @@ -184,11 +185,30 @@ func (encodePlanTimestamptzCodecText) Encode(value interface{}, buf []byte) (new return nil, err } + if !ts.Valid { + return nil, nil + } + var s string switch ts.InfinityModifier { case Finite: - s = ts.Time.UTC().Truncate(time.Microsecond).Format(pgTimestamptzSecondFormat) + + t := ts.Time.UTC().Truncate(time.Microsecond) + + // Year 0000 is 1 BC + bc := false + if year := t.Year(); year <= 0 { + year = -year + 1 + t = time.Date(year, t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC) + bc = true + } + + s = t.Format(pgTimestamptzSecondFormat) + + if bc { + s = s + " BC" + } case Infinity: s = "infinity" case NegativeInfinity: @@ -267,6 +287,12 @@ func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst interfac case "-infinity": tstz = Timestamptz{Valid: true, InfinityModifier: -Infinity} default: + bc := false + if strings.HasSuffix(sbuf, " BC") { + sbuf = sbuf[:len(sbuf)-3] + bc = true + } + var format string if len(sbuf) >= 9 && (sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+') { format = pgTimestamptzSecondFormat @@ -281,6 +307,11 @@ func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst interfac return err } + if bc { + year := -tim.Year() + 1 + tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) + } + tstz = Timestamptz{Time: tim, Valid: true} } diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index 678f3013..0486ecdb 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -5,15 +5,21 @@ import ( "testing" "time" + pgx "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) func TestTimestamptzCodec(t *testing.T) { skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") - testutil.RunTranscodeTests(t, "timestamptz", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{ + {time.Date(-100, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(-100, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(-1, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(-1, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(0, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(0, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(1, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local))}, {time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local))}, {time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local))}, @@ -34,19 +40,18 @@ func TestTimestamptzCodec(t *testing.T) { // https://github.com/jackc/pgx/v4/pgtype/pull/128 func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + in := &pgtype.Timestamptz{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} + var out pgtype.Timestamptz - in := &pgtype.Timestamptz{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} - var out pgtype.Timestamptz + err := conn.QueryRow(ctx, "select $1::timestamptz", in).Scan(&out) + if err != nil { + t.Fatal(err) + } - err := conn.QueryRow(context.Background(), "select $1::timestamptz", in).Scan(&out) - if err != nil { - t.Fatal(err) - } - - require.Equal(t, in.Valid, out.Valid) - require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) + require.Equal(t, in.Valid, out.Valid) + require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) + }) } // https://github.com/jackc/pgtype/issues/74 diff --git a/pgtype/uint32_test.go b/pgtype/uint32_test.go index d6699a03..842de643 100644 --- a/pgtype/uint32_test.go +++ b/pgtype/uint32_test.go @@ -1,14 +1,15 @@ package pgtype_test import ( + "context" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) func TestUint32Codec(t *testing.T) { - testutil.RunTranscodeTests(t, "oid", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "oid", []pgxtest.ValueRoundTripTest{ { pgtype.Uint32{Uint32: pgtype.TextOID, Valid: true}, new(pgtype.Uint32), diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 870e7ae1..06ff38c2 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -1,31 +1,22 @@ package pgtype_test import ( + "context" "reflect" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) func TestUUIDCodec(t *testing.T) { - testutil.RunTranscodeTests(t, "uuid", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "uuid", []pgxtest.ValueRoundTripTest{ { pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, new(pgtype.UUID), isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), }, - { - [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - new(pgtype.UUID), - isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), - }, - { - []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - new(pgtype.UUID), - isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), - }, { "00010203-0405-0607-0809-0a0b0c0d0e0f", new(pgtype.UUID), @@ -40,6 +31,19 @@ func TestUUIDCodec(t *testing.T) { {pgtype.UUID{}, new(pgtype.UUID), isExpectedEq(pgtype.UUID{})}, {nil, new(pgtype.UUID), isExpectedEq(pgtype.UUID{})}, }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "uuid", []pgxtest.ValueRoundTripTest{ + { + [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + }) } func TestUUID_MarshalJSON(t *testing.T) { diff --git a/pgtype/zeronull/float8_test.go b/pgtype/zeronull/float8_test.go index fd683657..a8c1b6a1 100644 --- a/pgtype/zeronull/float8_test.go +++ b/pgtype/zeronull/float8_test.go @@ -1,10 +1,11 @@ package zeronull_test import ( + "context" "testing" - "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" ) func isExpectedEq(a interface{}) func(interface{}) bool { @@ -14,7 +15,7 @@ func isExpectedEq(a interface{}) func(interface{}) bool { } func TestFloat8Transcode(t *testing.T) { - testutil.RunTranscodeTests(t, "float8", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "float8", []pgxtest.ValueRoundTripTest{ { (zeronull.Float8)(1), new(zeronull.Float8), diff --git a/pgtype/zeronull/int_test.go b/pgtype/zeronull/int_test.go index d687a733..30e4808f 100644 --- a/pgtype/zeronull/int_test.go +++ b/pgtype/zeronull/int_test.go @@ -2,14 +2,15 @@ package zeronull_test import ( + "context" "testing" - "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" ) func TestInt2Transcode(t *testing.T) { - testutil.RunTranscodeTests(t, "int2", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int2", []pgxtest.ValueRoundTripTest{ { (zeronull.Int2)(1), new(zeronull.Int2), @@ -29,7 +30,7 @@ func TestInt2Transcode(t *testing.T) { } func TestInt4Transcode(t *testing.T) { - testutil.RunTranscodeTests(t, "int4", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4", []pgxtest.ValueRoundTripTest{ { (zeronull.Int4)(1), new(zeronull.Int4), @@ -49,7 +50,7 @@ func TestInt4Transcode(t *testing.T) { } func TestInt8Transcode(t *testing.T) { - testutil.RunTranscodeTests(t, "int8", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{ { (zeronull.Int8)(1), new(zeronull.Int8), diff --git a/pgtype/zeronull/int_test.go.erb b/pgtype/zeronull/int_test.go.erb index b33cfa4a..2c7ddc46 100644 --- a/pgtype/zeronull/int_test.go.erb +++ b/pgtype/zeronull/int_test.go.erb @@ -10,7 +10,7 @@ import ( <% [2, 4, 8].each do |pg_byte_size| %> <% pg_bit_size = pg_byte_size * 8 %> func TestInt<%= pg_byte_size %>Transcode(t *testing.T) { - testutil.RunTranscodeTests(t, "int<%= pg_byte_size %>", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int<%= pg_byte_size %>", []pgxtest.ValueRoundTripTest{ { (zeronull.Int<%= pg_byte_size %>)(1), new(zeronull.Int<%= pg_byte_size %>), diff --git a/pgtype/zeronull/text_test.go b/pgtype/zeronull/text_test.go index e20ab868..e0d6ec43 100644 --- a/pgtype/zeronull/text_test.go +++ b/pgtype/zeronull/text_test.go @@ -1,14 +1,15 @@ package zeronull_test import ( + "context" "testing" - "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" ) func TestTextTranscode(t *testing.T) { - testutil.RunTranscodeTests(t, "text", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "text", []pgxtest.ValueRoundTripTest{ { (zeronull.Text)("foo"), new(zeronull.Text), diff --git a/pgtype/zeronull/timestamp_test.go b/pgtype/zeronull/timestamp_test.go index 9d8ee7ae..78393e9b 100644 --- a/pgtype/zeronull/timestamp_test.go +++ b/pgtype/zeronull/timestamp_test.go @@ -1,11 +1,12 @@ package zeronull_test import ( + "context" "testing" "time" - "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" ) func isExpectedEqTimestamp(a interface{}) func(interface{}) bool { @@ -18,7 +19,7 @@ func isExpectedEqTimestamp(a interface{}) func(interface{}) bool { } func TestTimestampTranscode(t *testing.T) { - testutil.RunTranscodeTests(t, "timestamp", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ { (zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), new(zeronull.Timestamp), diff --git a/pgtype/zeronull/timestamptz_test.go b/pgtype/zeronull/timestamptz_test.go index 15ac66da..d8273258 100644 --- a/pgtype/zeronull/timestamptz_test.go +++ b/pgtype/zeronull/timestamptz_test.go @@ -1,11 +1,12 @@ package zeronull_test import ( + "context" "testing" "time" - "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" ) func isExpectedEqTimestamptz(a interface{}) func(interface{}) bool { @@ -18,7 +19,7 @@ func isExpectedEqTimestamptz(a interface{}) func(interface{}) bool { } func TestTimestamptzTranscode(t *testing.T) { - testutil.RunTranscodeTests(t, "timestamptz", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{ { (zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), new(zeronull.Timestamptz), diff --git a/pgtype/zeronull/uuid_test.go b/pgtype/zeronull/uuid_test.go index 5be1d22e..0cb169fa 100644 --- a/pgtype/zeronull/uuid_test.go +++ b/pgtype/zeronull/uuid_test.go @@ -1,14 +1,15 @@ package zeronull_test import ( + "context" "testing" - "github.com/jackc/pgx/v5/pgtype/testutil" "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" ) func TestUUIDTranscode(t *testing.T) { - testutil.RunTranscodeTests(t, "uuid", []testutil.TranscodeTestCase{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "uuid", []pgxtest.ValueRoundTripTest{ { (zeronull.UUID)([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), new(zeronull.UUID), diff --git a/pgtype/zeronull/zeronull.go b/pgtype/zeronull/zeronull.go new file mode 100644 index 00000000..bba7b423 --- /dev/null +++ b/pgtype/zeronull/zeronull.go @@ -0,0 +1,17 @@ +package zeronull + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +// Register registers the zeronull types so they can be used in query exec modes that do not know the server OIDs. +func Register(m *pgtype.Map) { + m.RegisterDefaultPgType(Float8(0), "float8") + m.RegisterDefaultPgType(Int2(0), "int2") + m.RegisterDefaultPgType(Int4(0), "int4") + m.RegisterDefaultPgType(Int8(0), "int8") + m.RegisterDefaultPgType(Text(""), "text") + m.RegisterDefaultPgType(Timestamp{}, "timestamp") + m.RegisterDefaultPgType(Timestamptz{}, "timestamptz") + m.RegisterDefaultPgType(UUID{}, "uuid") +} diff --git a/pgtype/zeronull/zeronull_test.go b/pgtype/zeronull/zeronull_test.go new file mode 100644 index 00000000..9ee45cb7 --- /dev/null +++ b/pgtype/zeronull/zeronull_test.go @@ -0,0 +1,26 @@ +package zeronull_test + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +var defaultConnTestRunner pgxtest.ConnTestRunner + +func init() { + defaultConnTestRunner = pgxtest.DefaultConnTestRunner() + defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return config + } + defaultConnTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + zeronull.Register(conn.TypeMap()) + } +} diff --git a/pgxtest/pgxtest.go b/pgxtest/pgxtest.go index dec6a520..3a416dc2 100644 --- a/pgxtest/pgxtest.go +++ b/pgxtest/pgxtest.go @@ -3,11 +3,28 @@ package pgxtest import ( "context" + "fmt" + "reflect" "testing" "github.com/jackc/pgx/v5" ) +var AllQueryExecModes = []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + pgx.QueryExecModeExec, + pgx.QueryExecModeSimpleProtocol, +} + +// KnownOIDQueryExecModes is a slice of all query exec modes where the param and result OIDs are known before sending the query. +var KnownOIDQueryExecModes = []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, +} + // ConnTestRunner controls how a *pgx.Conn is created and closed by tests. All fields are required. Use DefaultConnTestRunner to get a // ConnTestRunner with reasonable default values. type ConnTestRunner struct { @@ -46,6 +63,8 @@ func DefaultConnTestRunner() ConnTestRunner { } func (ctr *ConnTestRunner) RunTest(ctx context.Context, t testing.TB, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) { + t.Helper() + config := ctr.CreateConfig(ctx, t) conn, err := pgx.ConnectConfig(ctx, config) if err != nil { @@ -62,13 +81,7 @@ func (ctr *ConnTestRunner) RunTest(ctx context.Context, t testing.TB, f func(ctx // If modes is nil all pgx.QueryExecModes are tested. func RunWithQueryExecModes(ctx context.Context, t *testing.T, ctr ConnTestRunner, modes []pgx.QueryExecMode, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) { if modes == nil { - modes = []pgx.QueryExecMode{ - pgx.QueryExecModeCacheStatement, - pgx.QueryExecModeCacheDescribe, - pgx.QueryExecModeDescribeExec, - pgx.QueryExecModeExec, - pgx.QueryExecModeSimpleProtocol, - } + modes = AllQueryExecModes } for _, mode := range modes { @@ -87,6 +100,51 @@ func RunWithQueryExecModes(ctx context.Context, t *testing.T, ctr ConnTestRunner } } +type ValueRoundTripTest struct { + Param interface{} + Result interface{} + Test func(interface{}) bool +} + +func RunValueRoundTripTests( + ctx context.Context, + t testing.TB, + ctr ConnTestRunner, + modes []pgx.QueryExecMode, + pgTypeName string, + tests []ValueRoundTripTest, +) { + t.Helper() + + if modes == nil { + modes = AllQueryExecModes + } + + ctr.RunTest(ctx, t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + t.Helper() + + sql := fmt.Sprintf("select $1::%s", pgTypeName) + + for i, tt := range tests { + for _, mode := range modes { + err := conn.QueryRow(ctx, sql, mode, tt.Param).Scan(tt.Result) + if err != nil { + t.Errorf("%d. %v: %v", i, mode, err) + } + + result := reflect.ValueOf(tt.Result) + if result.Kind() == reflect.Ptr { + result = result.Elem() + } + + if !tt.Test(result.Interface()) { + t.Errorf("%d. %v: unexpected result for %v: %v", i, mode, tt.Param, result.Interface()) + } + } + } + }) +} + // SkipCockroachDB calls Skip on t with msg if the connection is to a CockroachDB server. func SkipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) { if conn.PgConn().ParameterStatus("crdb_version") != "" { diff --git a/values.go b/values.go index 595f2b4d..12e5db47 100644 --- a/values.go +++ b/values.go @@ -17,7 +17,7 @@ func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) return nil, nil } - buf, err := m.Encode(0, TextFormatCode, arg, nil) + buf, err := m.Encode(0, TextFormatCode, arg, []byte{}) if err != nil { return nil, err } From 53ec52aa174c7fff458b7ec72dc1ffc8ffcf181d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Apr 2022 14:41:33 -0500 Subject: [PATCH 0978/1158] Fix out of date pgtype/int_test.go.erb --- pgtype/int_test.go.erb | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pgtype/int_test.go.erb b/pgtype/int_test.go.erb index 799d8c32..d72d6bbd 100644 --- a/pgtype/int_test.go.erb +++ b/pgtype/int_test.go.erb @@ -21,8 +21,8 @@ func TestInt<%= pg_byte_size %>Codec(t *testing.T) { {uint64(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {int(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {uint(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, - {pgtype.Int<%= pg_byte_size %>{Int: 1, Valid: true}, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, - {int32(-1), new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{Int: -1, Valid: true})}, + {pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 1, Valid: true}, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int32(-1), new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: -1, Valid: true})}, {1, new(int8), isExpectedEq(int8(1))}, {1, new(int16), isExpectedEq(int16(1))}, {1, new(int32), isExpectedEq(int32(1))}, @@ -43,7 +43,7 @@ func TestInt<%= pg_byte_size %>Codec(t *testing.T) { {0, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(0))}, {1, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {math.MaxInt<%= pg_bit_size %>, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(math.MaxInt<%= pg_bit_size %>))}, - {1, new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{Int: 1, Valid: true})}, + {1, new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 1, Valid: true})}, {pgtype.Int<%= pg_byte_size %>{}, new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{})}, {nil, new(*int<%= pg_bit_size %>), isExpectedEq((*int<%= pg_bit_size %>)(nil))}, }) @@ -54,8 +54,8 @@ func TestInt<%= pg_byte_size %>MarshalJSON(t *testing.T) { source pgtype.Int<%= pg_byte_size %> result string }{ - {source: pgtype.Int<%= pg_byte_size %>{Int: 0}, result: "null"}, - {source: pgtype.Int<%= pg_byte_size %>{Int: 1, Valid: true}, result: "1"}, + {source: pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 0}, result: "null"}, + {source: pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 1, Valid: true}, result: "1"}, } for i, tt := range successfulTests { r, err := tt.source.MarshalJSON() @@ -74,8 +74,8 @@ func TestInt<%= pg_byte_size %>UnmarshalJSON(t *testing.T) { source string result pgtype.Int<%= pg_byte_size %> }{ - {source: "null", result: pgtype.Int<%= pg_byte_size %>{Int: 0}}, - {source: "1", result: pgtype.Int<%= pg_byte_size %>{Int: 1, Valid: true}}, + {source: "null", result: pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 0}}, + {source: "1", result: pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 1, Valid: true}}, } for i, tt := range successfulTests { var r pgtype.Int<%= pg_byte_size %> From 8cf6721d66722439e94f939a4dee58895a071f40 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Apr 2022 16:55:05 -0500 Subject: [PATCH 0979/1158] Better int64 / numeric compat and text fixes --- pgtype/numeric.go | 62 +++++++++++++++++++++++++++++++++--------- pgtype/numeric_test.go | 6 ++++ 2 files changed, 55 insertions(+), 13 deletions(-) diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 426b6782..b9827b63 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -109,6 +109,33 @@ func (n Numeric) Float64Value() (Float8, error) { return Float8{Float64: f, Valid: true}, nil } +func (n *Numeric) ScanInt64(v Int8) error { + if !v.Valid { + *n = Numeric{} + return nil + } + + *n = Numeric{Int: big.NewInt(v.Int64), Valid: true} + return nil +} + +func (n Numeric) Int64Value() (Int8, error) { + if !n.Valid { + return Int8{}, nil + } + + bi, err := n.toBigInt() + if err != nil { + return Int8{}, err + } + + if !bi.IsInt64() { + return Int8{}, fmt.Errorf("cannot convert %v to int64", n) + } + + return Int8{Int64: bi.Int64(), Valid: true}, nil +} + func (n *Numeric) toBigInt() (*big.Int, error) { if n.Exp == 0 { return n.Int, nil @@ -450,18 +477,15 @@ func (encodePlanNumericCodecTextFloat64Valuer) Encode(value interface{}, buf []b } if math.IsNaN(n.Float64) { - return encodeNumericBinary(Numeric{NaN: true, Valid: true}, buf) + buf = append(buf, "NaN"...) } else if math.IsInf(n.Float64, 1) { - return encodeNumericBinary(Numeric{InfinityModifier: Infinity, Valid: true}, buf) + buf = append(buf, "Infinity"...) } else if math.IsInf(n.Float64, -1) { - return encodeNumericBinary(Numeric{InfinityModifier: NegativeInfinity, Valid: true}, buf) + buf = append(buf, "-Infinity"...) + } else { + buf = append(buf, strconv.FormatFloat(n.Float64, 'f', -1, 64)...) } - num, exp, err := parseNumericString(strconv.FormatFloat(n.Float64, 'f', -1, 64)) - if err != nil { - return nil, err - } - - return encodeNumericText(Numeric{Int: num, Exp: exp, Valid: true}, buf) + return buf, nil } type encodePlanNumericCodecTextInt64Valuer struct{} @@ -476,7 +500,8 @@ func (encodePlanNumericCodecTextInt64Valuer) Encode(value interface{}, buf []byt return nil, nil } - return encodeNumericText(Numeric{Int: big.NewInt(n.Int64), Valid: true}, buf) + buf = append(buf, strconv.FormatInt(n.Int64, 10)...) + return buf, nil } func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { @@ -495,9 +520,20 @@ func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { return buf, nil } - buf = append(buf, n.Int.String()...) - buf = append(buf, 'e') - buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...) + digits := n.Int.String() + if n.Exp >= 0 { + buf = append(buf, digits...) + if n.Exp > 0 { + for i := int32(0); i < n.Exp; i++ { + buf = append(buf, '0') + } + } + } else { + buf = append(buf, digits...) + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...) + } + return buf, nil } diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 8be8ce55..3c37ae18 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -113,6 +113,12 @@ func TestNumericCodec(t *testing.T) { {pgtype.Numeric{}, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, {nil, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{ + {mustParseNumeric(t, "-1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "-1"))}, + {mustParseNumeric(t, "0"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0"))}, + {mustParseNumeric(t, "1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, + }) } func TestNumericCodecInfinity(t *testing.T) { From 5ece2efd4c610fe2c0078f49695d91025ae758fa Mon Sep 17 00:00:00 2001 From: WGH Date: Mon, 28 Mar 2022 05:08:30 +0300 Subject: [PATCH 0980/1158] Fix typo in Record type documentation --- record.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/record.go b/record.go index 718c3570..5cf2c93a 100644 --- a/record.go +++ b/record.go @@ -6,7 +6,7 @@ import ( ) // Record is the generic PostgreSQL record type such as is created with the -// "row" function. Record only implements BinaryEncoder and Value. The text +// "row" function. Record only implements BinaryDecoder and Value. The text // format output format from PostgreSQL does not include type information and is // therefore impossible to decode. No encoders are implemented because // PostgreSQL does not support input of generic records. From 71648e3d78faa726cd7c2bd8f18239da276d3f84 Mon Sep 17 00:00:00 2001 From: WGH Date: Mon, 28 Mar 2022 04:57:54 +0300 Subject: [PATCH 0981/1158] Add defaults for typed_array.go.erb template parameters Most of the time binary_format is "true", and text_null is "NULL", so it makes sense to not repeat that. --- typed_array.go.erb | 7 +++++++ typed_array_gen.sh | 48 +++++++++++++++++++++++----------------------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/typed_array.go.erb b/typed_array.go.erb index 5788626b..b89ae164 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -1,5 +1,12 @@ // Code generated by erb. DO NOT EDIT. +<% + # defaults when not explicitly set on command line + + binary_format ||= "true" + text_null ||= "NULL" +%> + package pgtype import ( diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 1f4098c7..a9090cd9 100755 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,28 +1,28 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time,[]*time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go -erb pgtype_array_type=TsrangeArray pgtype_element_type=Tsrange go_array_types=[]Tsrange element_type_name=tsrange text_null=NULL binary_format=true typed_array.go.erb > tsrange_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time,[]*time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32,[]*float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64,[]*float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go -erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr,[]*net.HardwareAddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go -erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string,[]*string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string,[]*string element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go -erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string,[]*string element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go -erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string,[]*string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go -erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go -erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go -erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go -erb pgtype_array_type=JSONBArray pgtype_element_type=JSONB go_array_types=[]string,[][]byte,[]json.RawMessage element_type_name=jsonb text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time,[]*time.Time element_type_name=timestamptz typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange typed_array.go.erb > tstzrange_array.go +erb pgtype_array_type=TsrangeArray pgtype_element_type=Tsrange go_array_types=[]Tsrange element_type_name=tsrange typed_array.go.erb > tsrange_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time,[]*time.Time element_type_name=timestamp typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32,[]*float32 element_type_name=float4 typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64,[]*float64 element_type_name=float8 typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=inet typed_array.go.erb > inet_array.go +erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr,[]*net.HardwareAddr element_type_name=macaddr typed_array.go.erb > macaddr_array.go +erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=cidr typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string,[]*string element_type_name=text typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string,[]*string element_type_name=varchar typed_array.go.erb > varchar_array.go +erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string,[]*string element_type_name=bpchar typed_array.go.erb > bpchar_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea typed_array.go.erb > bytea_array.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string,[]*string element_type_name=aclitem binary_format=false typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore typed_array.go.erb > hstore_array.go +erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric typed_array.go.erb > numeric_array.go +erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid typed_array.go.erb > uuid_array.go +erb pgtype_array_type=JSONBArray pgtype_element_type=JSONB go_array_types=[]string,[][]byte,[]json.RawMessage element_type_name=jsonb typed_array.go.erb > jsonb_array.go # While the binary format is theoretically possible it is only practical to use the text format. -erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go +erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string binary_format=false typed_array.go.erb > enum_array.go goimports -w *_array.go From 5db1de5fc18703ddf645024dff9ff73c98bc9107 Mon Sep 17 00:00:00 2001 From: WGH Date: Mon, 28 Mar 2022 05:00:22 +0300 Subject: [PATCH 0982/1158] Make text format for type_array.go.erb opt-out Some types, like RECORD, don't have sane text format. If we want to have arrays of such types, we don't want to generate text format for such arrays either. --- typed_array.go.erb | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/typed_array.go.erb b/typed_array.go.erb index b89ae164..31debcd8 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -4,6 +4,8 @@ # defaults when not explicitly set on command line binary_format ||= "true" + text_format ||= "true" + text_null ||= "NULL" %> @@ -286,6 +288,7 @@ func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, inde return index, nil } +<% if text_format == "true" %> func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} @@ -321,6 +324,7 @@ func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error return nil } +<% end %> <% if binary_format == "true" %> func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { @@ -366,6 +370,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) erro } <% end %> +<% if text_format == "true" %> func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: @@ -422,6 +427,7 @@ func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte return buf, nil } +<% end %> <% if binary_format == "true" %> func (src <%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { @@ -469,6 +475,7 @@ func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte } <% end %> +<% if text_format == "true" %> // Scan implements the database/sql Scanner interface. func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { if src == nil { @@ -499,3 +506,4 @@ func (src <%= pgtype_array_type %>) Value() (driver.Value, error) { return string(buf), nil } +<% end %> From 3e230ba7313cffe1aa245f913c30c0105509f822 Mon Sep 17 00:00:00 2001 From: WGH Date: Mon, 28 Mar 2022 05:03:18 +0300 Subject: [PATCH 0983/1158] Split encode_binary and decode_binary in typed_array.go.erb Again, RECORD, for example, has binary decoding, but no binary encoding. --- typed_array.go.erb | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/typed_array.go.erb b/typed_array.go.erb index 31debcd8..e8433c04 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -7,6 +7,9 @@ text_format ||= "true" text_null ||= "NULL" + + encode_binary ||= binary_format + decode_binary ||= binary_format %> package pgtype @@ -326,7 +329,7 @@ func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error } <% end %> -<% if binary_format == "true" %> +<% if decode_binary == "true" %> func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} @@ -429,7 +432,7 @@ func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte } <% end %> -<% if binary_format == "true" %> +<% if encode_binary == "true" %> func (src <%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: From ccb207cba5b4dd520b9cc40994eeb161ce140737 Mon Sep 17 00:00:00 2001 From: WGH Date: Mon, 28 Mar 2022 05:06:16 +0300 Subject: [PATCH 0984/1158] Add support for record array Like Record itself, it only implements BinaryDecoder, doesn't implement BinaryEncoder, and has no support for the text protocol. --- record_array.go | 318 +++++++++++++++++++++++++++++++++++++++++++ record_array_test.go | 104 ++++++++++++++ typed_array_gen.sh | 2 + 3 files changed, 424 insertions(+) create mode 100644 record_array.go create mode 100644 record_array_test.go diff --git a/record_array.go b/record_array.go new file mode 100644 index 00000000..2271717a --- /dev/null +++ b/record_array.go @@ -0,0 +1,318 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "encoding/binary" + "fmt" + "reflect" +) + +type RecordArray struct { + Elements []Record + Dimensions []ArrayDimension + Status Status +} + +func (dst *RecordArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = RecordArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case [][]Value: + if value == nil { + *dst = RecordArray{Status: Null} + } else if len(value) == 0 { + *dst = RecordArray{Status: Present} + } else { + elements := make([]Record, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = RecordArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Record: + if value == nil { + *dst = RecordArray{Status: Null} + } else if len(value) == 0 { + *dst = RecordArray{Status: Present} + } else { + *dst = RecordArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = RecordArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for RecordArray", src) + } + if elementsLength == 0 { + *dst = RecordArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to RecordArray", src) + } + + *dst = RecordArray{ + Elements: make([]Record, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Record, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to RecordArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *RecordArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to RecordArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in RecordArray", err) + } + index++ + + return index, nil +} + +func (dst RecordArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *RecordArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[][]Value: + *v = make([][]Value, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *RecordArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from RecordArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from RecordArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *RecordArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = RecordArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = RecordArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Record, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = RecordArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} diff --git a/record_array_test.go b/record_array_test.go new file mode 100644 index 00000000..9c92e333 --- /dev/null +++ b/record_array_test.go @@ -0,0 +1,104 @@ +package pgtype_test + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" +) + +var recordArrayTests = []struct { + sql string + expected pgtype.RecordArray +}{ + { + sql: `select array_agg((x::int4, x+100::int8)) from generate_series(0, 1) x;`, + expected: pgtype.RecordArray{ + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + }, + Elements: []pgtype.Record{ + { + Fields: []pgtype.Value{ + &pgtype.Int4{Int: 0, Status: pgtype.Present}, + &pgtype.Int8{Int: 100, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + { + Fields: []pgtype.Value{ + &pgtype.Int4{Int: 1, Status: pgtype.Present}, + &pgtype.Int8{Int: 101, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + Status: pgtype.Present, + }, + }, +} + +func TestRecordArrayTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for i, tt := range recordArrayTests { + psName := fmt.Sprintf("test%d", i) + _, err := conn.Prepare(context.Background(), psName, tt.sql) + require.NoError(t, err) + + t.Run(tt.sql, func(t *testing.T) { + var result pgtype.RecordArray + err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result) + require.NoError(t, err) + + require.Equal(t, tt.expected, result) + }) + + } +} + +func TestRecordArrayAssignTo(t *testing.T) { + src := pgtype.RecordArray{ + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + }, + Elements: []pgtype.Record{ + { + Fields: []pgtype.Value{ + &pgtype.Int4{Int: 0, Status: pgtype.Present}, + &pgtype.Int8{Int: 100, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + { + Fields: []pgtype.Value{ + &pgtype.Int4{Int: 1, Status: pgtype.Present}, + &pgtype.Int8{Int: 101, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + Status: pgtype.Present, + } + dst := [][]pgtype.Value{} + err := src.AssignTo(&dst) + require.NoError(t, err) + + expected := [][]pgtype.Value{ + { + &pgtype.Int4{Int: 0, Status: pgtype.Present}, + &pgtype.Int8{Int: 100, Status: pgtype.Present}, + }, + { + &pgtype.Int4{Int: 1, Status: pgtype.Present}, + &pgtype.Int8{Int: 101, Status: pgtype.Present}, + }, + } + require.Equal(t, expected, dst) +} diff --git a/typed_array_gen.sh b/typed_array_gen.sh index a9090cd9..d922f1cb 100755 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -25,4 +25,6 @@ erb pgtype_array_type=JSONBArray pgtype_element_type=JSONB go_array_types=[]stri # While the binary format is theoretically possible it is only practical to use the text format. erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string binary_format=false typed_array.go.erb > enum_array.go +erb pgtype_array_type=RecordArray pgtype_element_type=Record go_array_types=[][]Value element_type_name=record text_null=NULL encode_binary=false text_format=false typed_array.go.erb > record_array.go + goimports -w *_array.go From fa2b09640075ecdcb0b6b7ccef5ddcf26456b6b0 Mon Sep 17 00:00:00 2001 From: Mukundan Kavanur Kidambi Date: Tue, 29 Mar 2022 14:05:14 -0700 Subject: [PATCH 0985/1158] fix: Adding overall format before appending ColumnFormatCodes --- copy_both_response.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/copy_both_response.go b/copy_both_response.go index fbd985d8..4a1c3a07 100644 --- a/copy_both_response.go +++ b/copy_both_response.go @@ -48,7 +48,7 @@ func (src *CopyBothResponse) Encode(dst []byte) []byte { dst = append(dst, 'W') sp := len(dst) dst = pgio.AppendInt32(dst, -1) - + dst = append(dst, src.OverallFormat) dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) From 1d7886b01260a464e06d756375aeb70603210e4e Mon Sep 17 00:00:00 2001 From: Mukundan Kavanur Kidambi Date: Wed, 30 Mar 2022 15:45:42 -0700 Subject: [PATCH 0986/1158] Adding UTs --- copy_both_response_test.go | 22 ++++++++++++++++++++++ go.mod | 1 + go.sum | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+) create mode 100644 copy_both_response_test.go diff --git a/copy_both_response_test.go b/copy_both_response_test.go new file mode 100644 index 00000000..fb2c00d0 --- /dev/null +++ b/copy_both_response_test.go @@ -0,0 +1,22 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgproto3/v2" + "gotest.tools/v3/assert" +) + +func TestEncodeDecode(t *testing.T) { + src := pgproto3.CopyBothResponse{ + OverallFormat: byte(1), // Just to differ from defaults + ColumnFormatCodes: []uint16{0, 1}, + } + dstBytes := []byte{} + dstBytes = src.Encode(dstBytes) + dst := pgproto3.CopyBothResponse{} + err := dst.Decode(dstBytes[5:]) + assert.NilError(t, err, "No errors on decode") + assert.Equal(t, dst.OverallFormat, src.OverallFormat, "OverallFormat is decoded successfully") + assert.DeepEqual(t, dst.ColumnFormatCodes, src.ColumnFormatCodes) +} diff --git a/go.mod b/go.mod index 36041a94..ed5b4d5d 100644 --- a/go.mod +++ b/go.mod @@ -6,4 +6,5 @@ require ( github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/stretchr/testify v1.4.0 + gotest.tools/v3 v3.1.0 ) diff --git a/go.sum b/go.sum index dd9cd044..faffff19 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,46 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gotest.tools/v3 v3.1.0 h1:rVV8Tcg/8jHUkPUorwjaMTtemIMVXfIPKiOqnhEhakk= +gotest.tools/v3 v3.1.0/go.mod h1:fHy7eyTmJFO5bQbUsEGQ1v4m2J3Jz9eWL54TP2/ZuYQ= From e145003288619f513827479651048c3ad408bf13 Mon Sep 17 00:00:00 2001 From: Mukundan Kavanur Kidambi Date: Wed, 30 Mar 2022 16:16:03 -0700 Subject: [PATCH 0987/1158] Addressing feedback --- copy_both_response_test.go | 6 +++--- go.mod | 1 - go.sum | 31 ------------------------------- 3 files changed, 3 insertions(+), 35 deletions(-) diff --git a/copy_both_response_test.go b/copy_both_response_test.go index fb2c00d0..7aa7da22 100644 --- a/copy_both_response_test.go +++ b/copy_both_response_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/jackc/pgproto3/v2" - "gotest.tools/v3/assert" + "github.com/stretchr/testify/assert" ) func TestEncodeDecode(t *testing.T) { @@ -16,7 +16,7 @@ func TestEncodeDecode(t *testing.T) { dstBytes = src.Encode(dstBytes) dst := pgproto3.CopyBothResponse{} err := dst.Decode(dstBytes[5:]) - assert.NilError(t, err, "No errors on decode") + assert.NoError(t, err, "No errors on decode") assert.Equal(t, dst.OverallFormat, src.OverallFormat, "OverallFormat is decoded successfully") - assert.DeepEqual(t, dst.ColumnFormatCodes, src.ColumnFormatCodes) + assert.EqualValues(t, dst.ColumnFormatCodes, src.ColumnFormatCodes) } diff --git a/go.mod b/go.mod index ed5b4d5d..36041a94 100644 --- a/go.mod +++ b/go.mod @@ -6,5 +6,4 @@ require ( github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/stretchr/testify v1.4.0 - gotest.tools/v3 v3.1.0 ) diff --git a/go.sum b/go.sum index faffff19..af835086 100644 --- a/go.sum +++ b/go.sum @@ -1,46 +1,15 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gotest.tools/v3 v3.1.0 h1:rVV8Tcg/8jHUkPUorwjaMTtemIMVXfIPKiOqnhEhakk= -gotest.tools/v3 v3.1.0/go.mod h1:fHy7eyTmJFO5bQbUsEGQ1v4m2J3Jz9eWL54TP2/ZuYQ= From c6ccb4b9a3dd75e628009594cd6adf8b47172d1a Mon Sep 17 00:00:00 2001 From: Mukundan Kavanur Kidambi Date: Wed, 30 Mar 2022 16:22:28 -0700 Subject: [PATCH 0988/1158] Addressing feedback --- copy_both_response_test.go | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/copy_both_response_test.go b/copy_both_response_test.go index 7aa7da22..d9816fc8 100644 --- a/copy_both_response_test.go +++ b/copy_both_response_test.go @@ -8,15 +8,11 @@ import ( ) func TestEncodeDecode(t *testing.T) { - src := pgproto3.CopyBothResponse{ - OverallFormat: byte(1), // Just to differ from defaults - ColumnFormatCodes: []uint16{0, 1}, - } - dstBytes := []byte{} - dstBytes = src.Encode(dstBytes) - dst := pgproto3.CopyBothResponse{} - err := dst.Decode(dstBytes[5:]) + srcBytes := []byte{'W', 0x00, 0x00, 0x00, 0x0b, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01} + dstResp := pgproto3.CopyBothResponse{} + err := dstResp.Decode(srcBytes[5:]) assert.NoError(t, err, "No errors on decode") - assert.Equal(t, dst.OverallFormat, src.OverallFormat, "OverallFormat is decoded successfully") - assert.EqualValues(t, dst.ColumnFormatCodes, src.ColumnFormatCodes) + dstBytes := []byte{} + dstBytes = dstResp.Encode(dstBytes) + assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match") } From 5982e4b4f881f4a71c4cacb8f2addf3ac5386921 Mon Sep 17 00:00:00 2001 From: Matthew Gabeler-Lee Date: Mon, 4 Apr 2022 15:27:52 -0400 Subject: [PATCH 0989/1158] fix detection of database does not exist error during connect --- pgconn.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pgconn.go b/pgconn.go index 29889a74..9a496ed0 100644 --- a/pgconn.go +++ b/pgconn.go @@ -154,9 +154,12 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err break } else if pgerr, ok := err.(*PgError); ok { err = &connectError{config: config, msg: "server error", err: pgerr} - ERRCODE_INVALID_PASSWORD := "28P01" // wrong password - ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION := "28000" // db does not exist - if pgerr.Code == ERRCODE_INVALID_PASSWORD || pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION { + const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password + const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings + const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist + if pgerr.Code == ERRCODE_INVALID_PASSWORD || + pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION || + pgerr.Code == ERRCODE_INVALID_CATALOG_NAME { break } } From 829babcea9257c47b9d94b25bf8a907237476617 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Apr 2022 09:09:46 -0500 Subject: [PATCH 0990/1158] Better number to string handling Avoid ambiguity of stringWrapper implementing Int64Scanner and Float64Scanner. --- conn_test.go | 4 ++- pgtype/builtin_wrappers.go | 11 ------ pgtype/float4.go | 21 ++++++++++++ pgtype/float4_test.go | 1 + pgtype/float8.go | 21 ++++++++++++ pgtype/float8_test.go | 1 + pgtype/int.go | 69 ++++++++++++++++++++++++++++++++++++++ pgtype/int.go.erb | 25 ++++++++++++++ pgtype/int_test.go | 3 ++ pgtype/int_test.go.erb | 1 + pgtype/numeric.go | 47 ++++++++++++++++++-------- pgtype/numeric_test.go | 1 + values_test.go | 16 +-------- 13 files changed, 180 insertions(+), 41 deletions(-) diff --git a/conn_test.go b/conn_test.go index a618b936..41d04b55 100644 --- a/conn_test.go +++ b/conn_test.go @@ -841,7 +841,9 @@ func TestDomainType(t *testing.T) { // Domain type uint64 is a PostgreSQL domain of underlying type numeric. - // Unregistered type can be used as string. + // In the extended protocol preparing "select $1::uint64" appears to create a statement that expects a param OID of + // uint64 but a result OID of the underlying numeric. + var s string err := conn.QueryRow(context.Background(), "select $1::uint64", "24").Scan(&s) require.NoError(t, err) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 30a88465..0f12ada3 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -5,7 +5,6 @@ import ( "math" "net" "reflect" - "strconv" "time" ) @@ -341,16 +340,6 @@ func (w stringWrapper) TextValue() (Text, error) { return Text{String: string(w), Valid: true}, nil } -func (w *stringWrapper) ScanInt64(v Int8) error { - if !v.Valid { - return fmt.Errorf("cannot scan NULL into *string") - } - - *w = stringWrapper(strconv.FormatInt(v.Int64, 10)) - - return nil -} - type timeWrapper time.Time func (w *timeWrapper) ScanDate(v Date) error { diff --git a/pgtype/float4.go b/pgtype/float4.go index fb84c124..2c628011 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -156,6 +156,8 @@ func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target interface{} return scanPlanBinaryFloat4ToFloat64Scanner{} case Int64Scanner: return scanPlanBinaryFloat4ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryFloat4ToTextScanner{} } case TextFormatCode: switch target.(type) { @@ -229,6 +231,25 @@ func (scanPlanBinaryFloat4ToInt64Scanner) Scan(src []byte, dst interface{}) erro return s.ScanInt64(Int8{Int64: i64, Valid: true}) } +type scanPlanBinaryFloat4ToTextScanner struct{} + +func (scanPlanBinaryFloat4ToTextScanner) Scan(src []byte, dst interface{}) error { + s := (dst).(TextScanner) + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) + } + + ui32 := int32(binary.BigEndian.Uint32(src)) + f32 := math.Float32frombits(uint32(ui32)) + + return s.ScanText(Text{String: strconv.FormatFloat(float64(f32), 'f', -1, 32), Valid: true}) +} + type scanPlanTextAnyToFloat32 struct{} func (scanPlanTextAnyToFloat32) Scan(src []byte, dst interface{}) error { diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go index 00b9addf..f155ed97 100644 --- a/pgtype/float4_test.go +++ b/pgtype/float4_test.go @@ -17,6 +17,7 @@ func TestFloat4Codec(t *testing.T) { {float32(9999.99), new(float32), isExpectedEq(float32(9999.99))}, {pgtype.Float4{}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{})}, {int64(1), new(int64), isExpectedEq(int64(1))}, + {"1.23", new(string), isExpectedEq("1.23")}, {nil, new(*float32), isExpectedEq((*float32)(nil))}, }) } diff --git a/pgtype/float8.go b/pgtype/float8.go index 664fb9f8..b7c6177e 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -194,6 +194,8 @@ func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target interface{} return scanPlanBinaryFloat8ToFloat64Scanner{} case Int64Scanner: return scanPlanBinaryFloat8ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryFloat8ToTextScanner{} } case TextFormatCode: switch target.(type) { @@ -267,6 +269,25 @@ func (scanPlanBinaryFloat8ToInt64Scanner) Scan(src []byte, dst interface{}) erro return s.ScanInt64(Int8{Int64: i64, Valid: true}) } +type scanPlanBinaryFloat8ToTextScanner struct{} + +func (scanPlanBinaryFloat8ToTextScanner) Scan(src []byte, dst interface{}) error { + s := (dst).(TextScanner) + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for float8: %v", len(src)) + } + + ui64 := int64(binary.BigEndian.Uint64(src)) + f64 := math.Float64frombits(uint64(ui64)) + + return s.ScanText(Text{String: strconv.FormatFloat(f64, 'f', -1, 64), Valid: true}) +} + type scanPlanTextAnyToFloat64 struct{} func (scanPlanTextAnyToFloat64) Scan(src []byte, dst interface{}) error { diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go index 9c269072..496b718b 100644 --- a/pgtype/float8_test.go +++ b/pgtype/float8_test.go @@ -17,6 +17,7 @@ func TestFloat8Codec(t *testing.T) { {float64(9999.99), new(float64), isExpectedEq(float64(9999.99))}, {pgtype.Float8{}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{})}, {int64(1), new(int64), isExpectedEq(int64(1))}, + {"1.23", new(string), isExpectedEq("1.23")}, {nil, new(*float64), isExpectedEq((*float64)(nil))}, }) } diff --git a/pgtype/int.go b/pgtype/int.go index ee4ab932..b3eabceb 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -233,6 +233,8 @@ func (Int2Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) return scanPlanBinaryInt2ToUint{} case Int64Scanner: return scanPlanBinaryInt2ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryInt2ToTextScanner{} } case TextFormatCode: switch target.(type) { @@ -557,6 +559,27 @@ func (scanPlanBinaryInt2ToInt64Scanner) Scan(src []byte, dst interface{}) error return s.ScanInt64(Int8{Int64: n, Valid: true}) } +type scanPlanBinaryInt2ToTextScanner struct{} + +func (scanPlanBinaryInt2ToTextScanner) Scan(src []byte, dst interface{}) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + n := int64(int16(binary.BigEndian.Uint16(src))) + + return s.ScanText(Text{String: strconv.FormatInt(n, 10), Valid: true}) +} + type Int4 struct { Int32 int32 Valid bool @@ -770,6 +793,8 @@ func (Int4Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) return scanPlanBinaryInt4ToUint{} case Int64Scanner: return scanPlanBinaryInt4ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryInt4ToTextScanner{} } case TextFormatCode: switch target.(type) { @@ -1105,6 +1130,27 @@ func (scanPlanBinaryInt4ToInt64Scanner) Scan(src []byte, dst interface{}) error return s.ScanInt64(Int8{Int64: n, Valid: true}) } +type scanPlanBinaryInt4ToTextScanner struct{} + +func (scanPlanBinaryInt4ToTextScanner) Scan(src []byte, dst interface{}) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + n := int64(int32(binary.BigEndian.Uint32(src))) + + return s.ScanText(Text{String: strconv.FormatInt(n, 10), Valid: true}) +} + type Int8 struct { Int64 int64 Valid bool @@ -1318,6 +1364,8 @@ func (Int8Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) return scanPlanBinaryInt8ToUint{} case Int64Scanner: return scanPlanBinaryInt8ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryInt8ToTextScanner{} } case TextFormatCode: switch target.(type) { @@ -1675,6 +1723,27 @@ func (scanPlanBinaryInt8ToInt64Scanner) Scan(src []byte, dst interface{}) error return s.ScanInt64(Int8{Int64: n, Valid: true}) } +type scanPlanBinaryInt8ToTextScanner struct{} + +func (scanPlanBinaryInt8ToTextScanner) Scan(src []byte, dst interface{}) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + n := int64(int64(binary.BigEndian.Uint64(src))) + + return s.ScanText(Text{String: strconv.FormatInt(n, 10), Valid: true}) +} + type scanPlanTextAnyToInt8 struct{} func (scanPlanTextAnyToInt8) Scan(src []byte, dst interface{}) error { diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 81f28bba..aa1db7fc 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -234,6 +234,8 @@ func (Int<%= pg_byte_size %>Codec) PlanScan(m *Map, oid uint32, format int16, ta return scanPlanBinaryInt<%= pg_byte_size %>ToUint{} case Int64Scanner: return scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryInt<%= pg_byte_size %>ToTextScanner{} } case TextFormatCode: switch target.(type) { @@ -443,6 +445,29 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(src []byte, dst i return s.ScanInt64(Int8{Int64: n, Valid: true}) } + +<%# PostgreSQL binary format integer to Go TextScanner %> +type scanPlanBinaryInt<%= pg_byte_size %>ToTextScanner struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToTextScanner) Scan(src []byte, dst interface{}) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for int<%= pg_byte_size %>: %v", len(src)) + } + + + n := int64(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) + + return s.ScanText(Text{String: strconv.FormatInt(n, 10), Valid: true}) +} <% end %> <%# Any text to all integer types %> diff --git a/pgtype/int_test.go b/pgtype/int_test.go index c779bdc9..73294b3c 100644 --- a/pgtype/int_test.go +++ b/pgtype/int_test.go @@ -45,6 +45,7 @@ func TestInt2Codec(t *testing.T) { {1, new(int16), isExpectedEq(int16(1))}, {math.MaxInt16, new(int16), isExpectedEq(int16(math.MaxInt16))}, {1, new(pgtype.Int2), isExpectedEq(pgtype.Int2{Int16: 1, Valid: true})}, + {"1", new(string), isExpectedEq("1")}, {pgtype.Int2{}, new(pgtype.Int2), isExpectedEq(pgtype.Int2{})}, {nil, new(*int16), isExpectedEq((*int16)(nil))}, }) @@ -126,6 +127,7 @@ func TestInt4Codec(t *testing.T) { {1, new(int32), isExpectedEq(int32(1))}, {math.MaxInt32, new(int32), isExpectedEq(int32(math.MaxInt32))}, {1, new(pgtype.Int4), isExpectedEq(pgtype.Int4{Int32: 1, Valid: true})}, + {"1", new(string), isExpectedEq("1")}, {pgtype.Int4{}, new(pgtype.Int4), isExpectedEq(pgtype.Int4{})}, {nil, new(*int32), isExpectedEq((*int32)(nil))}, }) @@ -207,6 +209,7 @@ func TestInt8Codec(t *testing.T) { {1, new(int64), isExpectedEq(int64(1))}, {math.MaxInt64, new(int64), isExpectedEq(int64(math.MaxInt64))}, {1, new(pgtype.Int8), isExpectedEq(pgtype.Int8{Int64: 1, Valid: true})}, + {"1", new(string), isExpectedEq("1")}, {pgtype.Int8{}, new(pgtype.Int8), isExpectedEq(pgtype.Int8{})}, {nil, new(*int64), isExpectedEq((*int64)(nil))}, }) diff --git a/pgtype/int_test.go.erb b/pgtype/int_test.go.erb index d72d6bbd..ac9a3f14 100644 --- a/pgtype/int_test.go.erb +++ b/pgtype/int_test.go.erb @@ -44,6 +44,7 @@ func TestInt<%= pg_byte_size %>Codec(t *testing.T) { {1, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {math.MaxInt<%= pg_bit_size %>, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(math.MaxInt<%= pg_bit_size %>))}, {1, new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 1, Valid: true})}, + {"1", new(string), isExpectedEq("1")}, {pgtype.Int<%= pg_byte_size %>{}, new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{})}, {nil, new(*int<%= pg_bit_size %>), isExpectedEq((*int<%= pg_bit_size %>)(nil))}, }) diff --git a/pgtype/numeric.go b/pgtype/numeric.go index b9827b63..5ca7d077 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -237,6 +237,11 @@ func (n Numeric) MarshalJSON() ([]byte, error) { return []byte(`"NaN"`), nil } + return n.numberTextBytes(), nil +} + +// numberString returns a string of the number. undefined if NaN, infinite, or NULL +func (n Numeric) numberTextBytes() []byte { intStr := n.Int.String() buf := &bytes.Buffer{} exp := int(n.Exp) @@ -263,7 +268,7 @@ func (n Numeric) MarshalJSON() ([]byte, error) { buf.WriteString(intStr) } - return buf.Bytes(), nil + return buf.Bytes() } type NumericCodec struct{} @@ -520,19 +525,7 @@ func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { return buf, nil } - digits := n.Int.String() - if n.Exp >= 0 { - buf = append(buf, digits...) - if n.Exp > 0 { - for i := int32(0); i < n.Exp; i++ { - buf = append(buf, '0') - } - } - } else { - buf = append(buf, digits...) - buf = append(buf, 'e') - buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...) - } + buf = append(buf, n.numberTextBytes()...) return buf, nil } @@ -548,6 +541,8 @@ func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target interface{ return scanPlanBinaryNumericToFloat64Scanner{} case Int64Scanner: return scanPlanBinaryNumericToInt64Scanner{} + case TextScanner: + return scanPlanBinaryNumericToTextScanner{} } case TextFormatCode: switch target.(type) { @@ -721,6 +716,30 @@ func (scanPlanBinaryNumericToInt64Scanner) Scan(src []byte, dst interface{}) err return scanner.ScanInt64(Int8{Int64: bigInt.Int64(), Valid: true}) } +type scanPlanBinaryNumericToTextScanner struct{} + +func (scanPlanBinaryNumericToTextScanner) Scan(src []byte, dst interface{}) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + var n Numeric + + err := scanPlanBinaryNumericToNumericScanner{}.Scan(src, &n) + if err != nil { + return err + } + + sbuf, err := encodeNumericText(n, nil) + if err != nil { + return err + } + + return scanner.ScanText(Text{String: string(sbuf), Valid: true}) +} + type scanPlanTextAnyToNumericScanner struct{} func (scanPlanTextAnyToNumericScanner) Scan(src []byte, dst interface{}) error { diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 3c37ae18..d95deaa5 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -110,6 +110,7 @@ func TestNumericCodec(t *testing.T) { {int64(math.MinInt64 + 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MinInt64+1, 10)))}, {int64(math.MaxInt64), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64, 10)))}, {int64(math.MaxInt64 - 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64-1, 10)))}, + {"1.23", new(string), isExpectedEq("1.23")}, {pgtype.Numeric{}, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, {nil, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, }) diff --git a/values_test.go b/values_test.go index 4282880c..04441b72 100644 --- a/values_test.go +++ b/values_test.go @@ -1022,6 +1022,7 @@ func TestScanIntoByteSlice(t *testing.T) { output []byte }{ {"int - text", "select 42", pgx.TextFormatCode, []byte("42")}, + {"int - binary", "select 42", pgx.BinaryFormatCode, []byte("42")}, {"text - text", "select 'hi'", pgx.TextFormatCode, []byte("hi")}, {"text - binary", "select 'hi'", pgx.BinaryFormatCode, []byte("hi")}, {"json - text", "select '{}'::json", pgx.TextFormatCode, []byte("{}")}, @@ -1036,19 +1037,4 @@ func TestScanIntoByteSlice(t *testing.T) { require.Equal(t, tt.output, buf) }) } - - // Failure cases - for _, tt := range []struct { - name string - sql string - err string - }{ - {"int binary", "select 42::int4", "can't scan into dest[0]: cannot scan OID 23 in binary format into *[]uint8"}, - } { - t.Run(tt.name, func(t *testing.T) { - var buf []byte - err := conn.QueryRow(context.Background(), tt.sql, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&buf) - require.EqualError(t, err, tt.err) - }) - } } From 95265a74214decf31f149ce9b137ca540b5be65a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Apr 2022 09:11:19 -0500 Subject: [PATCH 0991/1158] Use Go 1.18 --- .github/workflows/ci.yml | 2 +- go.mod | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index af164815..a905ad3e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - go-version: [1.17] + go-version: [1.18] pg-version: [10, 11, 12, 13, 14, cockroachdb] include: - pg-version: 10 diff --git a/go.mod b/go.mod index 79cbd50d..f527c7fb 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/jackc/pgx/v5 -go 1.17 +go 1.18 require ( github.com/jackc/pgpassfile v1.0.0 From f14fb3d692783501b3e2bbd7f0130c5a954d7d6b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Apr 2022 09:12:55 -0500 Subject: [PATCH 0992/1158] Replace interface{} with any --- batch.go | 20 +- batch_test.go | 2 +- bench_test.go | 12 +- conn.go | 50 ++-- conn_test.go | 4 +- copy_from.go | 18 +- copy_from_test.go | 56 ++--- doc.go | 12 +- extended_query_builder.go | 8 +- helper_test.go | 2 +- internal/anynil/anynil.go | 6 +- internal/sanitize/sanitize.go | 6 +- internal/sanitize/sanitize_test.go | 30 +-- log/testingadapter/adapter.go | 6 +- logger.go | 6 +- pgproto3/chunkreader.go | 2 +- pgtype/array_codec.go | 22 +- pgtype/array_codec_test.go | 14 +- pgtype/array_getter_setter.go | 16 +- pgtype/array_getter_setter.go.erb | 8 +- pgtype/bits.go | 16 +- pgtype/bits_test.go | 4 +- pgtype/bool.go | 24 +- pgtype/box.go | 16 +- pgtype/builtin_wrappers.go | 24 +- pgtype/bytea.go | 24 +- pgtype/bytea_test.go | 4 +- pgtype/circle.go | 16 +- pgtype/composite.go | 36 +-- pgtype/composite_test.go | 6 +- pgtype/convert.go | 26 +-- pgtype/date.go | 16 +- pgtype/date_test.go | 4 +- pgtype/enum_codec.go | 10 +- pgtype/enum_codec_test.go | 2 +- pgtype/float4.go | 26 +-- pgtype/float8.go | 32 +-- pgtype/hstore.go | 16 +- pgtype/hstore_test.go | 8 +- pgtype/inet.go | 16 +- pgtype/inet_test.go | 4 +- pgtype/int.go | 142 ++++++------ pgtype/int.go.erb | 34 +-- pgtype/integration_benchmark_test.go | 280 +++++++++++------------ pgtype/integration_benchmark_test.go.erb | 8 +- pgtype/interval.go | 16 +- pgtype/json.go | 22 +- pgtype/json_test.go | 14 +- pgtype/jsonb.go | 12 +- pgtype/jsonb_test.go | 6 +- pgtype/line.go | 18 +- pgtype/lseg.go | 16 +- pgtype/macaddr.go | 18 +- pgtype/macaddr_test.go | 4 +- pgtype/numeric.go | 30 +-- pgtype/numeric_test.go | 8 +- pgtype/path.go | 16 +- pgtype/path_test.go | 4 +- pgtype/pgtype.go | 186 +++++++-------- pgtype/pgtype_test.go | 8 +- pgtype/point.go | 16 +- pgtype/polygon.go | 16 +- pgtype/polygon_test.go | 4 +- pgtype/qchar.go | 14 +- pgtype/range_codec.go | 26 +-- pgtype/range_codec_test.go | 2 +- pgtype/range_types.go | 28 +-- pgtype/range_types.go.erb | 4 +- pgtype/record_codec.go | 12 +- pgtype/record_codec_test.go | 12 +- pgtype/text.go | 28 +-- pgtype/tid.go | 18 +- pgtype/time.go | 16 +- pgtype/timestamp.go | 16 +- pgtype/timestamptz.go | 16 +- pgtype/uint32.go | 26 +-- pgtype/uuid.go | 16 +- pgtype/zeronull/float8.go | 2 +- pgtype/zeronull/float8_test.go | 6 +- pgtype/zeronull/int.go | 6 +- pgtype/zeronull/int.go.erb | 2 +- pgtype/zeronull/int_test.go | 6 +- pgtype/zeronull/int_test.go.erb | 2 +- pgtype/zeronull/text.go | 2 +- pgtype/zeronull/text_test.go | 2 +- pgtype/zeronull/timestamp.go | 2 +- pgtype/zeronull/timestamp_test.go | 6 +- pgtype/zeronull/timestamptz.go | 2 +- pgtype/zeronull/timestamptz_test.go | 6 +- pgtype/zeronull/uuid.go | 2 +- pgtype/zeronull/uuid_test.go | 2 +- pgxpool/batch_results.go | 4 +- pgxpool/common_test.go | 10 +- pgxpool/conn.go | 8 +- pgxpool/pool.go | 12 +- pgxpool/pool_test.go | 4 +- pgxpool/rows.go | 12 +- pgxpool/tx.go | 8 +- pgxtest/pgxtest.go | 6 +- query_test.go | 78 +++---- rows.go | 24 +- stdlib/sql.go | 16 +- stdlib/sql_test.go | 10 +- tx.go | 24 +- values.go | 4 +- values_test.go | 80 +++---- 106 files changed, 1045 insertions(+), 1045 deletions(-) diff --git a/batch.go b/batch.go index 689877a9..103d9aed 100644 --- a/batch.go +++ b/batch.go @@ -10,7 +10,7 @@ import ( type batchItem struct { query string - arguments []interface{} + arguments []any } // Batch queries are a way of bundling multiple queries together to avoid @@ -20,7 +20,7 @@ type Batch struct { } // Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. -func (b *Batch) Queue(query string, arguments ...interface{}) { +func (b *Batch) Queue(query string, arguments ...any) { b.items = append(b.items, &batchItem{ query: query, arguments: arguments, @@ -43,7 +43,7 @@ type BatchResults interface { QueryRow() Row // QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. - QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) + QueryFunc(scans []any, f func(QueryFuncRow) error) (pgconn.CommandTag, error) // Close closes the batch operation. This must be called before the underlying connection can be used again. Any error // that occurred during a batch operation may have made it impossible to resyncronize the connection with the server. @@ -78,7 +78,7 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { err = errors.New("no result") } if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{ + br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{ "sql": query, "args": logQueryArgs(arguments), "err": err, @@ -91,14 +91,14 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { if err != nil { if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{ + br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{ "sql": query, "args": logQueryArgs(arguments), "err": err, }) } } else if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]interface{}{ + br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{ "sql": query, "args": logQueryArgs(arguments), "commandTag": commandTag, @@ -134,7 +134,7 @@ func (br *batchResults) Query() (Rows, error) { rows.closed = true if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]interface{}{ + br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{ "sql": query, "args": logQueryArgs(arguments), "err": rows.err, @@ -149,7 +149,7 @@ func (br *batchResults) Query() (Rows, error) { } // QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. -func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { +func (br *batchResults) QueryFunc(scans []any, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { if br.closed { return pgconn.CommandTag{}, fmt.Errorf("batch already closed") } @@ -206,7 +206,7 @@ func (br *batchResults) Close() error { } if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]interface{}{ + br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{ "sql": query, "args": logQueryArgs(args), }) @@ -216,7 +216,7 @@ func (br *batchResults) Close() error { return br.mrr.Close() } -func (br *batchResults) nextQueryAndArgs() (query string, args []interface{}, ok bool) { +func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { if br.b != nil && br.ix < len(br.b.items) { bi := br.b.items[br.ix] query = bi.query diff --git a/batch_test.go b/batch_test.go index 6a7abd37..5558b823 100644 --- a/batch_test.go +++ b/batch_test.go @@ -108,7 +108,7 @@ func TestConnSendBatch(t *testing.T) { } rowCount = 0 - _, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error { + _, err = br.QueryFunc([]any{&id, &description, &amount}, func(pgx.QueryFuncRow) error { if id != selectFromLedgerExpectedRows[rowCount].id { t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) } diff --git a/bench_test.go b/bench_test.go index e5913995..db27491f 100644 --- a/bench_test.go +++ b/bench_test.go @@ -278,7 +278,7 @@ func BenchmarkSelectWithoutLogging(b *testing.B) { type discardLogger struct{} -func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { +func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]any) { } func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) { @@ -438,7 +438,7 @@ const benchmarkWriteTableInsertSQL = `insert into t( type benchmarkWriteTableCopyFromSrc struct { count int idx int - row []interface{} + row []any } func (s *benchmarkWriteTableCopyFromSrc) Next() bool { @@ -446,7 +446,7 @@ func (s *benchmarkWriteTableCopyFromSrc) Next() bool { return s.idx < s.count } -func (s *benchmarkWriteTableCopyFromSrc) Values() ([]interface{}, error) { +func (s *benchmarkWriteTableCopyFromSrc) Values() ([]any, error) { return s.row, nil } @@ -457,7 +457,7 @@ func (s *benchmarkWriteTableCopyFromSrc) Err() error { func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource { return &benchmarkWriteTableCopyFromSrc{ count: count, - row: []interface{}{ + row: []any{ "varchar_1", "varchar_2", &pgtype.Text{}, @@ -509,9 +509,9 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { } } -type queryArgs []interface{} +type queryArgs []any -func (qa *queryArgs) Append(v interface{}) string { +func (qa *queryArgs) Append(v any) string { *qa = append(*qa, v) return "$" + strconv.Itoa(len(*qa)) } diff --git a/conn.go b/conn.go index e6396cd3..ca34cdf3 100644 --- a/conn.go +++ b/conn.go @@ -216,17 +216,17 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { config.Config.OnNotification = c.bufferNotifications } else { if c.shouldLog(LogLevelDebug) { - c.log(ctx, LogLevelDebug, "pgx notification handler disabled by application supplied OnNotification", map[string]interface{}{"host": config.Config.Host}) + c.log(ctx, LogLevelDebug, "pgx notification handler disabled by application supplied OnNotification", map[string]any{"host": config.Config.Host}) } } if c.shouldLog(LogLevelInfo) { - c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host}) + c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]any{"host": config.Config.Host}) } c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config) if err != nil { if c.shouldLog(LogLevelError) { - c.log(ctx, LogLevelError, "connect failed", map[string]interface{}{"err": err}) + c.log(ctx, LogLevelError, "connect failed", map[string]any{"err": err}) } return nil, err } @@ -278,7 +278,7 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.Statem if c.shouldLog(LogLevelError) { defer func() { if err != nil { - c.log(ctx, LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql}) + c.log(ctx, LogLevelError, "Prepare failed", map[string]any{"err": err, "name": name, "sql": sql}) } }() } @@ -345,9 +345,9 @@ func (c *Conn) shouldLog(lvl LogLevel) bool { return c.logger != nil && c.logLevel >= lvl } -func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) { +func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]any) { if data == nil { - data = map[string]interface{}{} + data = map[string]any{} } if c.pgConn != nil && c.pgConn.PID() != 0 { data["pid"] = c.pgConn.PID() @@ -382,26 +382,26 @@ func (c *Conn) Config() *ConnConfig { return c.config.Copy() } // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced // positionally from the sql string as $1, $2, etc. -func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { +func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { startTime := time.Now() commandTag, err := c.exec(ctx, sql, arguments...) if err != nil { if c.shouldLog(LogLevelError) { - c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) + c.log(ctx, LogLevelError, "Exec", map[string]any{"sql": sql, "args": logQueryArgs(arguments), "err": err}) } return commandTag, err } if c.shouldLog(LogLevelInfo) { endTime := time.Now() - c.log(ctx, LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) + c.log(ctx, LogLevelInfo, "Exec", map[string]any{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) } return commandTag, err } -func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { +func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { mode := c.config.DefaultQueryExecMode optionLoop: @@ -460,7 +460,7 @@ optionLoop: } } -func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) { +func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []any) (commandTag pgconn.CommandTag, err error) { if len(arguments) > 0 { sql, err = c.sanitizeForSimpleQuery(sql, arguments...) if err != nil { @@ -476,7 +476,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i return commandTag, err } -func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, args []interface{}) error { +func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, args []any) error { if len(sd.ParamOIDs) != len(args) { return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)) } @@ -500,7 +500,7 @@ func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, args return nil } -func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { +func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) { err := c.execParamsAndPreparedPrefix(sd, arguments) if err != nil { return pgconn.CommandTag{}, err @@ -511,7 +511,7 @@ func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, return result.CommandTag, result.Err } -func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { +func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) { err := c.execParamsAndPreparedPrefix(sd, arguments) if err != nil { return pgconn.CommandTag{}, err @@ -523,14 +523,14 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription } type unknownArgumentTypeQueryExecModeExecError struct { - arg interface{} + arg any } func (e *unknownArgumentTypeQueryExecModeExecError) Error() string { return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg) } -func (c *Conn) execSQLParams(ctx context.Context, sql string, args []interface{}) (pgconn.CommandTag, error) { +func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) { c.eqb.Reset() anynil.NormalizeSlice(args) @@ -557,7 +557,7 @@ func (c *Conn) execSQLParams(ctx context.Context, sql string, args []interface{} // // Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is // no way to safely use binary or to specify the parameter OIDs. -func (c *Conn) appendParamsForQueryExecModeExec(args []interface{}) error { +func (c *Conn) appendParamsForQueryExecModeExec(args []any) error { for _, arg := range args { if arg == nil { err := c.eqb.AppendParamFormat(c.typeMap, 0, TextFormatCode, arg) @@ -602,7 +602,7 @@ func (c *Conn) appendParamsForQueryExecModeExec(args []interface{}) error { return nil } -func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows { +func (c *Conn) getRows(ctx context.Context, sql string, args []any) *connRows { r := &connRows{} r.ctx = ctx @@ -691,7 +691,7 @@ type QueryResultFormatsByOID map[uint32]int16 // For extra control over how the query is executed, the types QueryExecMode, QueryResultFormats, and // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. -func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { +func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) { var resultFormats QueryResultFormats var resultFormatsByOID QueryResultFormatsByOID mode := c.config.DefaultQueryExecMode @@ -829,7 +829,7 @@ optionLoop: // QueryRow is a convenience wrapper over Query. Any error that occurs while // querying is deferred until calling Scan on the returned Row. That Row will // error with ErrNoRows if no rows are returned. -func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { rows, _ := c.Query(ctx, sql, args...) return (*connRow)(rows.(*connRows)) } @@ -850,7 +850,7 @@ type QueryFuncRow interface { // QueryFunc executes sql with args. For each row returned by the query the values will scanned into the elements of // scans and f will be called. If any row fails to scan or f returns an error the query will be aborted and the error // will be returned. -func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { +func (c *Conn) QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { rows, err := c.Query(ctx, sql, args...) if err != nil { return pgconn.CommandTag{}, err @@ -1018,7 +1018,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } } -func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) { +func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) { if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on") } @@ -1028,7 +1028,7 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, } var err error - valueArgs := make([]interface{}, len(args)) + valueArgs := make([]any, len(args)) for i, a := range args { valueArgs[i], err = convertSimpleArgument(c.typeMap, a) if err != nil { @@ -1108,8 +1108,8 @@ func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.Com from pg_attribute where attrelid=$1 order by attnum`, - []interface{}{typrelid}, - []interface{}{&fieldName, &fieldOID}, + []any{typrelid}, + []any{&fieldName, &fieldOID}, func(qfr QueryFuncRow) error { dt, ok := c.TypeMap().TypeForOID(fieldOID) if !ok { diff --git a/conn_test.go b/conn_test.go index 41d04b55..61f6c951 100644 --- a/conn_test.go +++ b/conn_test.go @@ -711,14 +711,14 @@ func TestInsertTimestampArray(t *testing.T) { type testLog struct { lvl pgx.LogLevel msg string - data map[string]interface{} + data map[string]any } type testLogger struct { logs []testLog } -func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { +func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]any) { data["ctxdata"] = ctx.Value("ctxdata") l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) } diff --git a/copy_from.go b/copy_from.go index ef982269..c1a66d52 100644 --- a/copy_from.go +++ b/copy_from.go @@ -13,12 +13,12 @@ import ( // CopyFromRows returns a CopyFromSource interface over the provided rows slice // making it usable by *Conn.CopyFrom. -func CopyFromRows(rows [][]interface{}) CopyFromSource { +func CopyFromRows(rows [][]any) CopyFromSource { return ©FromRows{rows: rows, idx: -1} } type copyFromRows struct { - rows [][]interface{} + rows [][]any idx int } @@ -27,7 +27,7 @@ func (ctr *copyFromRows) Next() bool { return ctr.idx < len(ctr.rows) } -func (ctr *copyFromRows) Values() ([]interface{}, error) { +func (ctr *copyFromRows) Values() ([]any, error) { return ctr.rows[ctr.idx], nil } @@ -37,12 +37,12 @@ func (ctr *copyFromRows) Err() error { // CopyFromSlice returns a CopyFromSource interface over a dynamic func // making it usable by *Conn.CopyFrom. -func CopyFromSlice(length int, next func(int) ([]interface{}, error)) CopyFromSource { +func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource { return ©FromSlice{next: next, idx: -1, len: length} } type copyFromSlice struct { - next func(int) ([]interface{}, error) + next func(int) ([]any, error) idx int len int err error @@ -53,7 +53,7 @@ func (cts *copyFromSlice) Next() bool { return cts.idx < cts.len } -func (cts *copyFromSlice) Values() ([]interface{}, error) { +func (cts *copyFromSlice) Values() ([]any, error) { values, err := cts.next(cts.idx) if err != nil { cts.err = err @@ -73,7 +73,7 @@ type CopyFromSource interface { Next() bool // Values returns the values for the current row. - Values() ([]interface{}, error) + Values() ([]any, error) // Err returns any error that has been encountered by the CopyFromSource. If // this is not nil *Conn.CopyFrom will abort the copy. @@ -156,10 +156,10 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { if err == nil { if ct.conn.shouldLog(LogLevelInfo) { endTime := time.Now() - ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]interface{}{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected}) + ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]any{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected}) } } else if ct.conn.shouldLog(LogLevelError) { - ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]interface{}{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames}) + ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]any{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames}) } return rowsAffected, err diff --git a/copy_from_test.go b/copy_from_test.go index 54e2e52b..d979d2dc 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -32,7 +32,7 @@ func TestConnCopyFromSmall(t *testing.T) { tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) - inputRows := [][]interface{}{ + inputRows := [][]any{ {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } @@ -50,7 +50,7 @@ func TestConnCopyFromSmall(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -88,13 +88,13 @@ func TestConnCopyFromSliceSmall(t *testing.T) { tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) - inputRows := [][]interface{}{ + inputRows := [][]any{ {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, - pgx.CopyFromSlice(len(inputRows), func(i int) ([]interface{}, error) { + pgx.CopyFromSlice(len(inputRows), func(i int) ([]any, error) { return inputRows[i], nil })) if err != nil { @@ -109,7 +109,7 @@ func TestConnCopyFromSliceSmall(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -150,10 +150,10 @@ func TestConnCopyFromLarge(t *testing.T) { tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) - inputRows := [][]interface{}{} + inputRows := [][]any{} for i := 0; i < 10000; i++ { - inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}}) + inputRows = append(inputRows, []any{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}}) } copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) @@ -169,7 +169,7 @@ func TestConnCopyFromLarge(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -230,7 +230,7 @@ func TestConnCopyFromEnum(t *testing.T) { )`) require.NoError(t, err) - inputRows := [][]interface{}{ + inputRows := [][]any{ {"abc", "blue", "grape", "orange", "orange", "def"}, {nil, nil, nil, nil, nil, nil}, } @@ -242,7 +242,7 @@ func TestConnCopyFromEnum(t *testing.T) { rows, err := conn.Query(ctx, "select * from foo") require.NoError(t, err) - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() require.NoError(t, err) @@ -275,8 +275,8 @@ func TestConnCopyFromJSON(t *testing.T) { b jsonb )`) - inputRows := [][]interface{}{ - {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, + inputRows := [][]any{ + {map[string]any{"foo": "bar"}, map[string]any{"bar": "quz"}}, {nil, nil}, } @@ -293,7 +293,7 @@ func TestConnCopyFromJSON(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -323,12 +323,12 @@ func (cfs *clientFailSource) Next() bool { return cfs.count < 100 } -func (cfs *clientFailSource) Values() ([]interface{}, error) { +func (cfs *clientFailSource) Values() ([]any, error) { if cfs.count == 3 { cfs.err = fmt.Errorf("client error") return nil, cfs.err } - return []interface{}{make([]byte, 100000)}, nil + return []any{make([]byte, 100000)}, nil } func (cfs *clientFailSource) Err() error { @@ -346,7 +346,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) { b varchar not null )`) - inputRows := [][]interface{}{ + inputRows := [][]any{ {int32(1), "abc"}, {int32(2), nil}, // this row should trigger a failure {int32(3), "def"}, @@ -368,7 +368,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -400,11 +400,11 @@ func (fs *failSource) Next() bool { return fs.count < 100 } -func (fs *failSource) Values() ([]interface{}, error) { +func (fs *failSource) Values() ([]any, error) { if fs.count == 3 { - return []interface{}{nil}, nil + return []any{nil}, nil } - return []interface{}{make([]byte, 100000)}, nil + return []any{make([]byte, 100000)}, nil } func (fs *failSource) Err() error { @@ -447,7 +447,7 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -477,11 +477,11 @@ func (fs *slowFailRaceSource) Next() bool { return fs.count < 1000 } -func (fs *slowFailRaceSource) Values() ([]interface{}, error) { +func (fs *slowFailRaceSource) Values() ([]any, error) { if fs.count == 500 { - return []interface{}{nil, nil}, nil + return []any{nil, nil}, nil } - return []interface{}{1, make([]byte, 1000)}, nil + return []any{1, make([]byte, 1000)}, nil } func (fs *slowFailRaceSource) Err() error { @@ -536,7 +536,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -565,8 +565,8 @@ func (cfs *clientFinalErrSource) Next() bool { return cfs.count < 5 } -func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { - return []interface{}{make([]byte, 100000)}, nil +func (cfs *clientFinalErrSource) Values() ([]any, error) { + return []any{make([]byte, 100000)}, nil } func (cfs *clientFinalErrSource) Err() error { @@ -596,7 +596,7 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { diff --git a/doc.go b/doc.go index 222f9047..660cf5a3 100644 --- a/doc.go +++ b/doc.go @@ -88,8 +88,8 @@ QueryFunc can be used to execute a callback function for every row. This is ofte _, err = conn.QueryFunc( context.Background(), "select generate_series(1,$1)", - []interface{}{10}, - []interface{}{&n}, + []any{10}, + []any{&n}, func(pgx.QueryFuncRow) error { sum += n return nil @@ -273,10 +273,10 @@ for information on how to customize or disable the statement cache. Copy Protocol Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a -CopyFromSource interface. If the data is already in a [][]interface{} use CopyFromRows to wrap it in a CopyFromSource +CopyFromSource interface. If the data is already in a [][]any use CopyFromRows to wrap it in a CopyFromSource interface. Or implement CopyFromSource to avoid buffering the entire data set in memory. - rows := [][]interface{}{ + rows := [][]any{ {"John", "Smith", int32(36)}, {"Jane", "Doe", int32(29)}, } @@ -299,8 +299,8 @@ When you already have a typed array using CopyFromSlice can be more convenient. context.Background(), pgx.Identifier{"people"}, []string{"first_name", "last_name", "age"}, - pgx.CopyFromSlice(len(rows), func(i int) ([]interface{}, error) { - return []interface{}{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil + pgx.CopyFromSlice(len(rows), func(i int) ([]any, error) { + return []any{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil }), ) diff --git a/extended_query_builder.go b/extended_query_builder.go index 0b6e1962..e69d0b36 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -12,12 +12,12 @@ type extendedQueryBuilder struct { resultFormats []int16 } -func (eqb *extendedQueryBuilder) AppendParam(m *pgtype.Map, oid uint32, arg interface{}) error { +func (eqb *extendedQueryBuilder) AppendParam(m *pgtype.Map, oid uint32, arg any) error { f := eqb.chooseParameterFormatCode(m, oid, arg) return eqb.AppendParamFormat(m, oid, f, arg) } -func (eqb *extendedQueryBuilder) AppendParamFormat(m *pgtype.Map, oid uint32, format int16, arg interface{}) error { +func (eqb *extendedQueryBuilder) AppendParamFormat(m *pgtype.Map, oid uint32, format int16, arg any) error { eqb.paramFormats = append(eqb.paramFormats, format) v, err := eqb.encodeExtendedParamValue(m, oid, format, arg) @@ -56,7 +56,7 @@ func (eqb *extendedQueryBuilder) Reset() { } } -func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg interface{}) ([]byte, error) { +func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) { if anynil.Is(arg) { return nil, nil } @@ -81,7 +81,7 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uin // chooseParameterFormatCode determines the correct format code for an // argument to a prepared statement. It defaults to TextFormatCode if no // determination can be made. -func (eqb *extendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg interface{}) int16 { +func (eqb *extendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 { switch arg.(type) { case string, *string: return TextFormatCode diff --git a/helper_test.go b/helper_test.go index 461dfcab..f091d23e 100644 --- a/helper_test.go +++ b/helper_test.go @@ -53,7 +53,7 @@ func closeConn(t testing.TB, conn *pgx.Conn) { } } -func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag) { +func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...any) (commandTag pgconn.CommandTag) { var err error if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil { t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err) diff --git a/internal/anynil/anynil.go b/internal/anynil/anynil.go index 57a45b95..9a48c1a8 100644 --- a/internal/anynil/anynil.go +++ b/internal/anynil/anynil.go @@ -3,7 +3,7 @@ package anynil import "reflect" // Is returns true if value is any type of nil. e.g. nil or []byte(nil). -func Is(value interface{}) bool { +func Is(value any) bool { if value == nil { return true } @@ -18,7 +18,7 @@ func Is(value interface{}) bool { } // Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified. -func Normalize(v interface{}) interface{} { +func Normalize(v any) any { if Is(v) { return nil } @@ -27,7 +27,7 @@ func Normalize(v interface{}) interface{} { // NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is // mutated in place. -func NormalizeSlice(s []interface{}) { +func NormalizeSlice(s []any) { for i := range s { if Is(s[i]) { s[i] = nil diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 2dba3b81..64e67ca6 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -12,13 +12,13 @@ import ( // Part is either a string or an int. A string is raw SQL. An int is a // argument placeholder. -type Part interface{} +type Part any type Query struct { Parts []Part } -func (q *Query) Sanitize(args ...interface{}) (string, error) { +func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) buf := &bytes.Buffer{} @@ -295,7 +295,7 @@ func multilineCommentState(l *sqlLexer) stateFn { // SanitizeSQL replaces placeholder values with args. It quotes and escapes args // as necessary. This function is only safe when standard_conforming_strings is // on. -func SanitizeSQL(sql string, args ...interface{}) (string, error) { +func SanitizeSQL(sql string, args ...any) (string, error) { query, err := NewQuery(sql) if err != nil { return "", err diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index acfac2ec..7b4c08ef 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -107,57 +107,57 @@ func TestNewQuery(t *testing.T) { func TestQuerySanitize(t *testing.T) { successfulTests := []struct { query sanitize.Query - args []interface{} + args []any expected string }{ { query: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, - args: []interface{}{}, + args: []any{}, expected: `select 42`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{int64(42)}, + args: []any{int64(42)}, expected: `select 42`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{float64(1.23)}, + args: []any{float64(1.23)}, expected: `select 1.23`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{true}, + args: []any{true}, expected: `select true`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{[]byte{0, 1, 2, 3, 255}}, + args: []any{[]byte{0, 1, 2, 3, 255}}, expected: `select '\x00010203ff'`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{nil}, + args: []any{nil}, expected: `select null`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{"foobar"}, + args: []any{"foobar"}, expected: `select 'foobar'`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{"foo'bar"}, + args: []any{"foo'bar"}, expected: `select 'foo''bar'`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{`foo\'bar`}, + args: []any{`foo\'bar`}, expected: `select 'foo\''bar'`, }, { query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}}, - args: []interface{}{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)}, + args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)}, expected: `insert '2020-03-01 23:59:59.999999Z'`, }, } @@ -176,22 +176,22 @@ func TestQuerySanitize(t *testing.T) { errorTests := []struct { query sanitize.Query - args []interface{} + args []any expected string }{ { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}}, - args: []interface{}{int64(42)}, + args: []any{int64(42)}, expected: `insufficient arguments`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}}, - args: []interface{}{int64(42)}, + args: []any{int64(42)}, expected: `unused argument: 0`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{42}, + args: []any{42}, expected: `invalid arg type: int`, }, } diff --git a/log/testingadapter/adapter.go b/log/testingadapter/adapter.go index aa1b4bd6..65c14157 100644 --- a/log/testingadapter/adapter.go +++ b/log/testingadapter/adapter.go @@ -12,7 +12,7 @@ import ( // TestingLogger interface defines the subset of testing.TB methods used by this // adapter. type TestingLogger interface { - Log(args ...interface{}) + Log(args ...any) } type Logger struct { @@ -23,8 +23,8 @@ func NewLogger(l TestingLogger) *Logger { return &Logger{l: l} } -func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - logArgs := make([]interface{}, 0, 2+len(data)) +func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]any) { + logArgs := make([]any, 0, 2+len(data)) logArgs = append(logArgs, level, msg) for k, v := range data { logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v)) diff --git a/logger.go b/logger.go index 89fd5af5..02a1e8e4 100644 --- a/logger.go +++ b/logger.go @@ -44,7 +44,7 @@ func (ll LogLevel) String() string { // Logger is the interface used to get logging from pgx internals. type Logger interface { // Log a message at the given level with data key/value pairs. data may be nil. - Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) + Log(ctx context.Context, level LogLevel, msg string, data map[string]any) } // LogLevelFromString converts log level string to constant @@ -75,8 +75,8 @@ func LogLevelFromString(s string) (LogLevel, error) { } } -func logQueryArgs(args []interface{}) []interface{} { - logArgs := make([]interface{}, 0, len(args)) +func logQueryArgs(args []any) []any { + logArgs := make([]any, 0, len(args)) for _, a := range args { switch v := a.(type) { diff --git a/pgproto3/chunkreader.go b/pgproto3/chunkreader.go index 2d116c91..8834f521 100644 --- a/pgproto3/chunkreader.go +++ b/pgproto3/chunkreader.go @@ -20,7 +20,7 @@ func init() { for i := range bigBufPools { byteSize := bigBufSizes[i] bigBufPools[i] = &bigBufPool{ - pool: sync.Pool{New: func() interface{} { return make([]byte, byteSize) }}, + pool: sync.Pool{New: func() any { return make([]byte, byteSize) }}, byteSize: byteSize, } } diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 84012083..379a9096 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -15,10 +15,10 @@ type ArrayGetter interface { Dimensions() []ArrayDimension // Index returns the element at i. - Index(i int) interface{} + Index(i int) any // IndexType returns a non-nil scan target of the type Index will return. This is used by ArrayCodec.PlanEncode. - IndexType() interface{} + IndexType() any } // ArraySetter is a type can be set from a PostgreSQL array. @@ -29,11 +29,11 @@ type ArraySetter interface { SetDimensions(dimensions []ArrayDimension) error // ScanIndex returns a value usable as a scan target for i. SetDimensions must be called before ScanIndex. - ScanIndex(i int) interface{} + ScanIndex(i int) any // ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by // ArrayCodec.PlanScan. - ScanIndexType() interface{} + ScanIndexType() any } // ArrayCodec is a codec for any array type. @@ -49,7 +49,7 @@ func (c *ArrayCodec) PreferredFormat() int16 { return c.ElementType.Codec.PreferredFormat() } -func (c *ArrayCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (c *ArrayCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { arrayValuer, ok := value.(ArrayGetter) if !ok { return nil @@ -78,7 +78,7 @@ type encodePlanArrayCodecText struct { oid uint32 } -func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (p *encodePlanArrayCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { array := value.(ArrayGetter) dimensions := array.Dimensions() @@ -157,7 +157,7 @@ type encodePlanArrayCodecBinary struct { oid uint32 } -func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (p *encodePlanArrayCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { array := value.(ArrayGetter) dimensions := array.Dimensions() @@ -210,7 +210,7 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB return buf, nil } -func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { arrayScanner, ok := target.(ArraySetter) if !ok { return nil @@ -315,7 +315,7 @@ type scanPlanArrayCodec struct { elementScanPlan ScanPlan } -func (spac *scanPlanArrayCodec) Scan(src []byte, dst interface{}) error { +func (spac *scanPlanArrayCodec) Scan(src []byte, dst any) error { c := spac.arrayCodec m := spac.m oid := spac.oid @@ -354,12 +354,12 @@ func (c *ArrayCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr } } -func (c *ArrayCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c *ArrayCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } - var slice []interface{} + var slice []any err := m.PlanScan(oid, format, &slice).Scan(src, &slice) return slice, err } diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index 55ab814e..e4c00d1e 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -12,7 +12,7 @@ import ( func TestArrayCodec(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { for i, tt := range []struct { - expected interface{} + expected any }{ {[]int16(nil)}, {[]int16{}}, @@ -31,7 +31,7 @@ func TestArrayCodec(t *testing.T) { newInt16 := func(n int16) *int16 { return &n } for i, tt := range []struct { - expected interface{} + expected any }{ {[]*int16{newInt16(1), nil, newInt16(3), nil, newInt16(5)}}, } { @@ -52,7 +52,7 @@ func TestArrayCodecAnySlice(t *testing.T) { type _int16Slice []int16 for i, tt := range []struct { - expected interface{} + expected any }{ {_int16Slice(nil)}, {_int16Slice{}}, @@ -74,19 +74,19 @@ func TestArrayCodecDecodeValue(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { for _, tt := range []struct { sql string - expected interface{} + expected any }{ { sql: `select '{}'::int4[]`, - expected: []interface{}{}, + expected: []any{}, }, { sql: `select '{1,2}'::int8[]`, - expected: []interface{}{int64(1), int64(2)}, + expected: []any{int64(1), int64(2)}, }, { sql: `select '{foo,bar}'::text[]`, - expected: []interface{}{"foo", "bar"}, + expected: []any{"foo", "bar"}, }, } { t.Run(tt.sql, func(t *testing.T) { diff --git a/pgtype/array_getter_setter.go b/pgtype/array_getter_setter.go index 2e20f9ec..b0c6b505 100644 --- a/pgtype/array_getter_setter.go +++ b/pgtype/array_getter_setter.go @@ -11,11 +11,11 @@ func (a int16Array) Dimensions() []ArrayDimension { return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} } -func (a int16Array) Index(i int) interface{} { +func (a int16Array) Index(i int) any { return a[i] } -func (a int16Array) IndexType() interface{} { +func (a int16Array) IndexType() any { var el int16 return el } @@ -31,11 +31,11 @@ func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error { return nil } -func (a int16Array) ScanIndex(i int) interface{} { +func (a int16Array) ScanIndex(i int) any { return &a[i] } -func (a int16Array) ScanIndexType() interface{} { +func (a int16Array) ScanIndexType() any { return new(int16) } @@ -49,11 +49,11 @@ func (a uint16Array) Dimensions() []ArrayDimension { return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} } -func (a uint16Array) Index(i int) interface{} { +func (a uint16Array) Index(i int) any { return a[i] } -func (a uint16Array) IndexType() interface{} { +func (a uint16Array) IndexType() any { var el uint16 return el } @@ -69,10 +69,10 @@ func (a *uint16Array) SetDimensions(dimensions []ArrayDimension) error { return nil } -func (a uint16Array) ScanIndex(i int) interface{} { +func (a uint16Array) ScanIndex(i int) any { return &a[i] } -func (a uint16Array) ScanIndexType() interface{} { +func (a uint16Array) ScanIndexType() any { return new(uint16) } diff --git a/pgtype/array_getter_setter.go.erb b/pgtype/array_getter_setter.go.erb index a9d60d35..1c8cdff4 100644 --- a/pgtype/array_getter_setter.go.erb +++ b/pgtype/array_getter_setter.go.erb @@ -23,11 +23,11 @@ import ( return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} } - func (a <%= array_type %>) Index(i int) interface{} { + func (a <%= array_type %>) Index(i int) any { return a[i] } - func (a <%= array_type %>) IndexType() interface{} { + func (a <%= array_type %>) IndexType() any { var el <%= element_type %> return el } @@ -43,11 +43,11 @@ import ( return nil } - func (a <%= array_type %>) ScanIndex(i int) interface{} { + func (a <%= array_type %>) ScanIndex(i int) any { return &a[i] } - func (a <%= array_type %>) ScanIndexType() interface{} { + func (a <%= array_type %>) ScanIndexType() any { return new(<%= element_type %>) } <% end %> diff --git a/pgtype/bits.go b/pgtype/bits.go index 5b0671ca..30558118 100644 --- a/pgtype/bits.go +++ b/pgtype/bits.go @@ -33,7 +33,7 @@ func (b Bits) BitsValue() (Bits, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Bits) Scan(src interface{}) error { +func (dst *Bits) Scan(src any) error { if src == nil { *dst = Bits{} return nil @@ -70,7 +70,7 @@ func (BitsCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (BitsCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (BitsCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(BitsValuer); !ok { return nil } @@ -87,7 +87,7 @@ func (BitsCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanBitsCodecBinary struct{} -func (encodePlanBitsCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBitsCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { bits, err := value.(BitsValuer).BitsValue() if err != nil { return nil, err @@ -103,7 +103,7 @@ func (encodePlanBitsCodecBinary) Encode(value interface{}, buf []byte) (newBuf [ type encodePlanBitsCodecText struct{} -func (encodePlanBitsCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBitsCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { bits, err := value.(BitsValuer).BitsValue() if err != nil { return nil, err @@ -126,7 +126,7 @@ func (encodePlanBitsCodecText) Encode(value interface{}, buf []byte) (newBuf []b return buf, nil } -func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -148,7 +148,7 @@ func (c BitsCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c BitsCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c BitsCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -163,7 +163,7 @@ func (c BitsCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (in type scanPlanBinaryBitsToBitsScanner struct{} -func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst any) error { scanner := (dst).(BitsScanner) if src == nil { @@ -182,7 +182,7 @@ func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToBitsScanner struct{} -func (scanPlanTextAnyToBitsScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToBitsScanner) Scan(src []byte, dst any) error { scanner := (dst).(BitsScanner) if src == nil { diff --git a/pgtype/bits_test.go b/pgtype/bits_test.go index 3ca3b0c0..767f0d2b 100644 --- a/pgtype/bits_test.go +++ b/pgtype/bits_test.go @@ -9,8 +9,8 @@ import ( "github.com/jackc/pgx/v5/pgxtest" ) -func isExpectedEqBits(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEqBits(a any) func(any) bool { + return func(v any) bool { ab := a.(pgtype.Bits) vb := v.(pgtype.Bits) return bytes.Compare(ab.Bytes, vb.Bytes) == 0 && ab.Len == vb.Len && ab.Valid == vb.Valid diff --git a/pgtype/bool.go b/pgtype/bool.go index 24ae5c5f..6f3ef8ca 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -30,7 +30,7 @@ func (b Bool) BoolValue() (Bool, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Bool) Scan(src interface{}) error { +func (dst *Bool) Scan(src any) error { if src == nil { *dst = Bool{} return nil @@ -106,7 +106,7 @@ func (BoolCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (BoolCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (BoolCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -129,7 +129,7 @@ func (BoolCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanBoolCodecBinaryBool struct{} -func (encodePlanBoolCodecBinaryBool) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBoolCodecBinaryBool) Encode(value any, buf []byte) (newBuf []byte, err error) { v := value.(bool) if v { @@ -143,7 +143,7 @@ func (encodePlanBoolCodecBinaryBool) Encode(value interface{}, buf []byte) (newB type encodePlanBoolCodecTextBoolValuer struct{} -func (encodePlanBoolCodecTextBoolValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBoolCodecTextBoolValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { b, err := value.(BoolValuer).BoolValue() if err != nil { return nil, err @@ -164,7 +164,7 @@ func (encodePlanBoolCodecTextBoolValuer) Encode(value interface{}, buf []byte) ( type encodePlanBoolCodecBinaryBoolValuer struct{} -func (encodePlanBoolCodecBinaryBoolValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBoolCodecBinaryBoolValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { b, err := value.(BoolValuer).BoolValue() if err != nil { return nil, err @@ -185,7 +185,7 @@ func (encodePlanBoolCodecBinaryBoolValuer) Encode(value interface{}, buf []byte) type encodePlanBoolCodecTextBool struct{} -func (encodePlanBoolCodecTextBool) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBoolCodecTextBool) Encode(value any, buf []byte) (newBuf []byte, err error) { v := value.(bool) if v { @@ -197,7 +197,7 @@ func (encodePlanBoolCodecTextBool) Encode(value interface{}, buf []byte) (newBuf return buf, nil } -func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -223,7 +223,7 @@ func (c BoolCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return c.DecodeValue(m, oid, format, src) } -func (c BoolCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c BoolCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -238,7 +238,7 @@ func (c BoolCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (in type scanPlanBinaryBoolToBool struct{} -func (scanPlanBinaryBoolToBool) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryBoolToBool) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -259,7 +259,7 @@ func (scanPlanBinaryBoolToBool) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToBool struct{} -func (scanPlanTextAnyToBool) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -280,7 +280,7 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst interface{}) error { type scanPlanBinaryBoolToBoolScanner struct{} -func (scanPlanBinaryBoolToBoolScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryBoolToBoolScanner) Scan(src []byte, dst any) error { s, ok := (dst).(BoolScanner) if !ok { return ErrScanTargetTypeChanged @@ -299,7 +299,7 @@ func (scanPlanBinaryBoolToBoolScanner) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToBoolScanner struct{} -func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error { s, ok := (dst).(BoolScanner) if !ok { return ErrScanTargetTypeChanged diff --git a/pgtype/box.go b/pgtype/box.go index d6087eab..887d268b 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -34,7 +34,7 @@ func (b Box) BoxValue() (Box, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Box) Scan(src interface{}) error { +func (dst *Box) Scan(src any) error { if src == nil { *dst = Box{} return nil @@ -71,7 +71,7 @@ func (BoxCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (BoxCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (BoxCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(BoxValuer); !ok { return nil } @@ -88,7 +88,7 @@ func (BoxCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanBoxCodecBinary struct{} -func (encodePlanBoxCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBoxCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { box, err := value.(BoxValuer).BoxValue() if err != nil { return nil, err @@ -107,7 +107,7 @@ func (encodePlanBoxCodecBinary) Encode(value interface{}, buf []byte) (newBuf [] type encodePlanBoxCodecText struct{} -func (encodePlanBoxCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBoxCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { box, err := value.(BoxValuer).BoxValue() if err != nil { return nil, err @@ -126,7 +126,7 @@ func (encodePlanBoxCodecText) Encode(value interface{}, buf []byte) (newBuf []by return buf, nil } -func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -146,7 +146,7 @@ func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) S type scanPlanBinaryBoxToBoxScanner struct{} -func (scanPlanBinaryBoxToBoxScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryBoxToBoxScanner) Scan(src []byte, dst any) error { scanner := (dst).(BoxScanner) if src == nil { @@ -173,7 +173,7 @@ func (scanPlanBinaryBoxToBoxScanner) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToBoxScanner struct{} -func (scanPlanTextAnyToBoxScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToBoxScanner) Scan(src []byte, dst any) error { scanner := (dst).(BoxScanner) if src == nil { @@ -224,7 +224,7 @@ func (c BoxCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src [ return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c BoxCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c BoxCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 0f12ada3..b385b80a 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -603,7 +603,7 @@ func (w byteSliceWrapper) UUIDValue() (UUID, error) { // structWrapper implements CompositeIndexGetter for a struct. type structWrapper struct { - s interface{} + s any exportedFields []reflect.Value } @@ -611,7 +611,7 @@ func (w structWrapper) IsNull() bool { return w.s == nil } -func (w structWrapper) Index(i int) interface{} { +func (w structWrapper) Index(i int) any { if i >= len(w.exportedFields) { return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i) } @@ -621,7 +621,7 @@ func (w structWrapper) Index(i int) interface{} { // ptrStructWrapper implements CompositeIndexScanner for a pointer to a struct. type ptrStructWrapper struct { - s interface{} + s any exportedFields []reflect.Value } @@ -629,7 +629,7 @@ func (w *ptrStructWrapper) ScanNull() error { return fmt.Errorf("cannot scan NULL into %#v", w.s) } -func (w *ptrStructWrapper) ScanIndex(i int) interface{} { +func (w *ptrStructWrapper) ScanIndex(i int) any { if i >= len(w.exportedFields) { return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i) } @@ -649,11 +649,11 @@ func (a anySliceArray) Dimensions() []ArrayDimension { return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}} } -func (a anySliceArray) Index(i int) interface{} { +func (a anySliceArray) Index(i int) any { return a.slice.Index(i).Interface() } -func (a anySliceArray) IndexType() interface{} { +func (a anySliceArray) IndexType() any { return reflect.New(a.slice.Type().Elem()).Elem().Interface() } @@ -671,11 +671,11 @@ func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error { return nil } -func (a *anySliceArray) ScanIndex(i int) interface{} { +func (a *anySliceArray) ScanIndex(i int) any { return a.slice.Index(i).Addr().Interface() } -func (a *anySliceArray) ScanIndexType() interface{} { +func (a *anySliceArray) ScanIndexType() any { return reflect.New(a.slice.Type().Elem()).Interface() } @@ -706,7 +706,7 @@ func (a *anyMultiDimSliceArray) Dimensions() []ArrayDimension { return a.dims } -func (a *anyMultiDimSliceArray) Index(i int) interface{} { +func (a *anyMultiDimSliceArray) Index(i int) any { if len(a.dims) == 1 { return a.slice.Index(i).Interface() } @@ -726,7 +726,7 @@ func (a *anyMultiDimSliceArray) Index(i int) interface{} { return v.Interface() } -func (a *anyMultiDimSliceArray) IndexType() interface{} { +func (a *anyMultiDimSliceArray) IndexType() any { lowestSliceType := a.slice.Type() for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { } @@ -794,11 +794,11 @@ func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type return slice } -func (a *anyMultiDimSliceArray) ScanIndex(i int) interface{} { +func (a *anyMultiDimSliceArray) ScanIndex(i int) any { return a.slice.Index(i).Addr().Interface() } -func (a *anyMultiDimSliceArray) ScanIndexType() interface{} { +func (a *anyMultiDimSliceArray) ScanIndexType() any { lowestSliceType := a.slice.Type() for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { } diff --git a/pgtype/bytea.go b/pgtype/bytea.go index f3c33cb9..51994005 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -50,7 +50,7 @@ type UndecodedBytes []byte type scanPlanAnyToUndecodedBytes struct{} -func (scanPlanAnyToUndecodedBytes) Scan(src []byte, dst interface{}) error { +func (scanPlanAnyToUndecodedBytes) Scan(src []byte, dst any) error { dstBuf := dst.(*UndecodedBytes) if src == nil { *dstBuf = nil @@ -72,7 +72,7 @@ func (ByteaCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (ByteaCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (ByteaCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -95,7 +95,7 @@ func (ByteaCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{} type encodePlanBytesCodecBinaryBytes struct{} -func (encodePlanBytesCodecBinaryBytes) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBytesCodecBinaryBytes) Encode(value any, buf []byte) (newBuf []byte, err error) { b := value.([]byte) if b == nil { return nil, nil @@ -106,7 +106,7 @@ func (encodePlanBytesCodecBinaryBytes) Encode(value interface{}, buf []byte) (ne type encodePlanBytesCodecBinaryBytesValuer struct{} -func (encodePlanBytesCodecBinaryBytesValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBytesCodecBinaryBytesValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { b, err := value.(BytesValuer).BytesValue() if err != nil { return nil, err @@ -120,7 +120,7 @@ func (encodePlanBytesCodecBinaryBytesValuer) Encode(value interface{}, buf []byt type encodePlanBytesCodecTextBytes struct{} -func (encodePlanBytesCodecTextBytes) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBytesCodecTextBytes) Encode(value any, buf []byte) (newBuf []byte, err error) { b := value.([]byte) if b == nil { return nil, nil @@ -133,7 +133,7 @@ func (encodePlanBytesCodecTextBytes) Encode(value interface{}, buf []byte) (newB type encodePlanBytesCodecTextBytesValuer struct{} -func (encodePlanBytesCodecTextBytesValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanBytesCodecTextBytesValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { b, err := value.(BytesValuer).BytesValue() if err != nil { return nil, err @@ -147,7 +147,7 @@ func (encodePlanBytesCodecTextBytesValuer) Encode(value interface{}, buf []byte) return buf, nil } -func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -171,7 +171,7 @@ func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) type scanPlanBinaryBytesToBytes struct{} -func (scanPlanBinaryBytesToBytes) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryBytesToBytes) Scan(src []byte, dst any) error { dstBuf := dst.(*[]byte) if src == nil { *dstBuf = nil @@ -185,14 +185,14 @@ func (scanPlanBinaryBytesToBytes) Scan(src []byte, dst interface{}) error { type scanPlanBinaryBytesToBytesScanner struct{} -func (scanPlanBinaryBytesToBytesScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryBytesToBytesScanner) Scan(src []byte, dst any) error { scanner := (dst).(BytesScanner) return scanner.ScanBytes(src) } type scanPlanTextByteaToBytes struct{} -func (scanPlanTextByteaToBytes) Scan(src []byte, dst interface{}) error { +func (scanPlanTextByteaToBytes) Scan(src []byte, dst any) error { dstBuf := dst.(*[]byte) if src == nil { *dstBuf = nil @@ -210,7 +210,7 @@ func (scanPlanTextByteaToBytes) Scan(src []byte, dst interface{}) error { type scanPlanTextByteaToBytesScanner struct{} -func (scanPlanTextByteaToBytesScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextByteaToBytesScanner) Scan(src []byte, dst any) error { scanner := (dst).(BytesScanner) buf, err := decodeHexBytea(src) if err != nil { @@ -241,7 +241,7 @@ func (c ByteaCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c ByteaCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c ByteaCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index 7a348ebb..a0d27369 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -11,8 +11,8 @@ import ( "github.com/stretchr/testify/require" ) -func isExpectedEqBytes(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEqBytes(a any) func(any) bool { + return func(v any) bool { ab := a.([]byte) vb := v.([]byte) diff --git a/pgtype/circle.go b/pgtype/circle.go index 4b499a12..e8f118cc 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -35,7 +35,7 @@ func (c Circle) CircleValue() (Circle, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Circle) Scan(src interface{}) error { +func (dst *Circle) Scan(src any) error { if src == nil { *dst = Circle{} return nil @@ -72,7 +72,7 @@ func (CircleCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (CircleCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (CircleCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(CircleValuer); !ok { return nil } @@ -89,7 +89,7 @@ func (CircleCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{ type encodePlanCircleCodecBinary struct{} -func (encodePlanCircleCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanCircleCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { circle, err := value.(CircleValuer).CircleValue() if err != nil { return nil, err @@ -107,7 +107,7 @@ func (encodePlanCircleCodecBinary) Encode(value interface{}, buf []byte) (newBuf type encodePlanCircleCodecText struct{} -func (encodePlanCircleCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanCircleCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { circle, err := value.(CircleValuer).CircleValue() if err != nil { return nil, err @@ -125,7 +125,7 @@ func (encodePlanCircleCodecText) Encode(value interface{}, buf []byte) (newBuf [ return buf, nil } -func (CircleCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (CircleCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { @@ -146,7 +146,7 @@ func (c CircleCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c CircleCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c CircleCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -161,7 +161,7 @@ func (c CircleCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) ( type scanPlanBinaryCircleToCircleScanner struct{} -func (scanPlanBinaryCircleToCircleScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryCircleToCircleScanner) Scan(src []byte, dst any) error { scanner := (dst).(CircleScanner) if src == nil { @@ -185,7 +185,7 @@ func (scanPlanBinaryCircleToCircleScanner) Scan(src []byte, dst interface{}) err type scanPlanTextAnyToCircleScanner struct{} -func (scanPlanTextAnyToCircleScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToCircleScanner) Scan(src []byte, dst any) error { scanner := (dst).(CircleScanner) if src == nil { diff --git a/pgtype/composite.go b/pgtype/composite.go index af6ed28b..fb372325 100644 --- a/pgtype/composite.go +++ b/pgtype/composite.go @@ -16,7 +16,7 @@ type CompositeIndexGetter interface { IsNull() bool // Index returns the element at i. - Index(i int) interface{} + Index(i int) any } // CompositeIndexScanner is a type accessed by index that can be scanned from a PostgreSQL composite. @@ -25,7 +25,7 @@ type CompositeIndexScanner interface { ScanNull() error // ScanIndex returns a value usable as a scan target for i. - ScanIndex(i int) interface{} + ScanIndex(i int) any } type CompositeCodecField struct { @@ -54,7 +54,7 @@ func (c *CompositeCodec) PreferredFormat() int16 { return TextFormatCode } -func (c *CompositeCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (c *CompositeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(CompositeIndexGetter); !ok { return nil } @@ -74,7 +74,7 @@ type encodePlanCompositeCodecCompositeIndexGetterToBinary struct { m *Map } -func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { getter := value.(CompositeIndexGetter) if getter.IsNull() { @@ -94,7 +94,7 @@ type encodePlanCompositeCodecCompositeIndexGetterToText struct { m *Map } -func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value any, buf []byte) (newBuf []byte, err error) { getter := value.(CompositeIndexGetter) if getter.IsNull() { @@ -109,7 +109,7 @@ func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value int return b.Finish() } -func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { @@ -131,7 +131,7 @@ type scanPlanBinaryCompositeToCompositeIndexScanner struct { m *Map } -func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target interface{}) error { +func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target any) error { targetScanner := (target).(CompositeIndexScanner) if src == nil { @@ -170,7 +170,7 @@ type scanPlanTextCompositeToCompositeIndexScanner struct { m *Map } -func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target interface{}) error { +func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target any) error { targetScanner := (target).(CompositeIndexScanner) if src == nil { @@ -221,7 +221,7 @@ func (c *CompositeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16 } } -func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -229,9 +229,9 @@ func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byt switch format { case TextFormatCode: scanner := NewCompositeTextScanner(m, src) - values := make(map[string]interface{}, len(c.Fields)) + values := make(map[string]any, len(c.Fields)) for i := 0; scanner.Next() && i < len(c.Fields); i++ { - var v interface{} + var v any fieldPlan := m.PlanScan(c.Fields[i].Type.OID, TextFormatCode, &v) if fieldPlan == nil { return nil, fmt.Errorf("unable to scan OID %d in text format into %v", c.Fields[i].Type.OID, v) @@ -252,9 +252,9 @@ func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byt return values, nil case BinaryFormatCode: scanner := NewCompositeBinaryScanner(m, src) - values := make(map[string]interface{}, len(c.Fields)) + values := make(map[string]any, len(c.Fields)) for i := 0; scanner.Next() && i < len(c.Fields); i++ { - var v interface{} + var v any fieldPlan := m.PlanScan(scanner.OID(), BinaryFormatCode, &v) if fieldPlan == nil { return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v) @@ -472,7 +472,7 @@ func NewCompositeBinaryBuilder(m *Map, buf []byte) *CompositeBinaryBuilder { return &CompositeBinaryBuilder{m: m, buf: buf, startIdx: startIdx} } -func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { +func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field any) { if b.err != nil { return } @@ -529,7 +529,7 @@ func NewCompositeTextBuilder(m *Map, buf []byte) *CompositeTextBuilder { return &CompositeTextBuilder{m: m, buf: buf} } -func (b *CompositeTextBuilder) AppendValue(oid uint32, field interface{}) { +func (b *CompositeTextBuilder) AppendValue(oid uint32, field any) { if b.err != nil { return } @@ -581,7 +581,7 @@ func quoteCompositeFieldIfNeeded(src string) string { // CompositeFields represents the values of a composite value. It can be used as an encoding source or as a scan target. // It cannot scan a NULL, but the composite fields can be NULL. -type CompositeFields []interface{} +type CompositeFields []any func (cf CompositeFields) SkipUnderlyingTypePlan() {} @@ -589,7 +589,7 @@ func (cf CompositeFields) IsNull() bool { return cf == nil } -func (cf CompositeFields) Index(i int) interface{} { +func (cf CompositeFields) Index(i int) any { return cf[i] } @@ -597,6 +597,6 @@ func (cf CompositeFields) ScanNull() error { return fmt.Errorf("cannot scan NULL into CompositeFields") } -func (cf CompositeFields) ScanIndex(i int) interface{} { +func (cf CompositeFields) ScanIndex(i int) any { return cf[i] } diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index 559403d8..a6fa8315 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -60,7 +60,7 @@ func (p point3d) IsNull() bool { return false } -func (p point3d) Index(i int) interface{} { +func (p point3d) Index(i int) any { switch i { case 0: return p.X @@ -77,7 +77,7 @@ func (p *point3d) ScanNull() error { return fmt.Errorf("cannot scan NULL into point3d") } -func (p *point3d) ScanIndex(i int) interface{} { +func (p *point3d) ScanIndex(i int) any { switch i { case 0: return &p.X @@ -202,7 +202,7 @@ create type point3d as ( values, err := rows.Values() require.NoErrorf(t, err, "%v", format.name) require.Lenf(t, values, 1, "%v", format.name) - require.Equalf(t, map[string]interface{}{"x": 1.0, "y": 2.0, "z": 3.0}, values[0], "%v", format.name) + require.Equalf(t, map[string]any{"x": 1.0, "y": 2.0, "z": 3.0}, values[0], "%v", format.name) require.False(t, rows.Next()) require.NoErrorf(t, rows.Err(), "%v", format.name) } diff --git a/pgtype/convert.go b/pgtype/convert.go index 21e208f5..31ce11e5 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -15,7 +15,7 @@ const ( ) // underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 -func underlyingNumberType(val interface{}) (interface{}, bool) { +func underlyingNumberType(val any) (any, bool) { refVal := reflect.ValueOf(val) switch refVal.Kind() { @@ -70,7 +70,7 @@ func underlyingNumberType(val interface{}) (interface{}, bool) { } // underlyingBoolType gets the underlying type that can be converted to Bool -func underlyingBoolType(val interface{}) (interface{}, bool) { +func underlyingBoolType(val any) (any, bool) { refVal := reflect.ValueOf(val) switch refVal.Kind() { @@ -89,7 +89,7 @@ func underlyingBoolType(val interface{}) (interface{}, bool) { } // underlyingBytesType gets the underlying type that can be converted to []byte -func underlyingBytesType(val interface{}) (interface{}, bool) { +func underlyingBytesType(val any) (any, bool) { refVal := reflect.ValueOf(val) switch refVal.Kind() { @@ -110,7 +110,7 @@ func underlyingBytesType(val interface{}) (interface{}, bool) { } // underlyingStringType gets the underlying type that can be converted to String -func underlyingStringType(val interface{}) (interface{}, bool) { +func underlyingStringType(val any) (any, bool) { refVal := reflect.ValueOf(val) switch refVal.Kind() { @@ -129,7 +129,7 @@ func underlyingStringType(val interface{}) (interface{}, bool) { } // underlyingPtrType dereferences a pointer -func underlyingPtrType(val interface{}) (interface{}, bool) { +func underlyingPtrType(val any) (any, bool) { refVal := reflect.ValueOf(val) switch refVal.Kind() { @@ -145,7 +145,7 @@ func underlyingPtrType(val interface{}) (interface{}, bool) { } // underlyingTimeType gets the underlying type that can be converted to time.Time -func underlyingTimeType(val interface{}) (interface{}, bool) { +func underlyingTimeType(val any) (any, bool) { refVal := reflect.ValueOf(val) switch refVal.Kind() { @@ -166,7 +166,7 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { } // underlyingUUIDType gets the underlying type that can be converted to [16]byte -func underlyingUUIDType(val interface{}) (interface{}, bool) { +func underlyingUUIDType(val any) (any, bool) { refVal := reflect.ValueOf(val) switch refVal.Kind() { @@ -187,7 +187,7 @@ func underlyingUUIDType(val interface{}) (interface{}, bool) { } // underlyingSliceType gets the underlying slice type -func underlyingSliceType(val interface{}) (interface{}, bool) { +func underlyingSliceType(val any) (any, bool) { refVal := reflect.ValueOf(val) switch refVal.Kind() { @@ -208,7 +208,7 @@ func underlyingSliceType(val interface{}) (interface{}, bool) { return nil, false } -func int64AssignTo(srcVal int64, srcValid bool, dst interface{}) error { +func int64AssignTo(srcVal int64, srcValid bool, dst any) error { if srcValid { switch v := dst.(type) { case *int: @@ -326,7 +326,7 @@ func int64AssignTo(srcVal int64, srcValid bool, dst interface{}) error { return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst) } -func float64AssignTo(srcVal float64, srcValid bool, dst interface{}) error { +func float64AssignTo(srcVal float64, srcValid bool, dst any) error { if srcValid { switch v := dst.(type) { case *float32: @@ -368,7 +368,7 @@ func float64AssignTo(srcVal float64, srcValid bool, dst interface{}) error { return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst) } -func NullAssignTo(dst interface{}) error { +func NullAssignTo(dst any) error { dstPtr := reflect.ValueOf(dst) // AssignTo dst must always be a pointer @@ -389,7 +389,7 @@ func NullAssignTo(dst interface{}) error { var kindTypes map[reflect.Kind]reflect.Type -func toInterface(dst reflect.Value, t reflect.Type) (interface{}, bool) { +func toInterface(dst reflect.Value, t reflect.Type) (any, bool) { nextDst := dst.Convert(t) return nextDst.Interface(), dst.Type() != nextDst.Type() } @@ -401,7 +401,7 @@ func toInterface(dst reflect.Value, t reflect.Type) (interface{}, bool) { // // GetAssignToDstType returns the converted dst and a bool representing if any // change was made. -func GetAssignToDstType(dst interface{}) (interface{}, bool) { +func GetAssignToDstType(dst any) (any, bool) { dstPtr := reflect.ValueOf(dst) // AssignTo dst must always be a pointer diff --git a/pgtype/date.go b/pgtype/date.go index db331e6c..78c5db92 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -40,7 +40,7 @@ const ( ) // Scan implements the database/sql Scanner interface. -func (dst *Date) Scan(src interface{}) error { +func (dst *Date) Scan(src any) error { if src == nil { *dst = Date{} return nil @@ -127,7 +127,7 @@ func (DateCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (DateCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (DateCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(DateValuer); !ok { return nil } @@ -144,7 +144,7 @@ func (DateCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanDateCodecBinary struct{} -func (encodePlanDateCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanDateCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { date, err := value.(DateValuer).DateValue() if err != nil { return nil, err @@ -173,7 +173,7 @@ func (encodePlanDateCodecBinary) Encode(value interface{}, buf []byte) (newBuf [ type encodePlanDateCodecText struct{} -func (encodePlanDateCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanDateCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { date, err := value.(DateValuer).DateValue() if err != nil { return nil, err @@ -211,7 +211,7 @@ func (encodePlanDateCodecText) Encode(value interface{}, buf []byte) (newBuf []b return buf, nil } -func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -231,7 +231,7 @@ func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) type scanPlanBinaryDateToDateScanner struct{} -func (scanPlanBinaryDateToDateScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryDateToDateScanner) Scan(src []byte, dst any) error { scanner := (dst).(DateScanner) if src == nil { @@ -257,7 +257,7 @@ func (scanPlanBinaryDateToDateScanner) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToDateScanner struct{} -func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst any) error { scanner := (dst).(DateScanner) if src == nil { @@ -301,7 +301,7 @@ func (c DateCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c DateCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c DateCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/date_test.go b/pgtype/date_test.go index 25c6bfc2..de61fd72 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -9,8 +9,8 @@ import ( "github.com/jackc/pgx/v5/pgxtest" ) -func isExpectedEqTime(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEqTime(a any) func(any) bool { + return func(v any) bool { at := a.(time.Time) vt := v.(time.Time) diff --git a/pgtype/enum_codec.go b/pgtype/enum_codec.go index ecad0d9f..93513111 100644 --- a/pgtype/enum_codec.go +++ b/pgtype/enum_codec.go @@ -20,7 +20,7 @@ func (EnumCodec) PreferredFormat() int16 { return TextFormatCode } -func (EnumCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (EnumCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case TextFormatCode, BinaryFormatCode: switch value.(type) { @@ -38,7 +38,7 @@ func (EnumCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) return nil } -func (c *EnumCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (c *EnumCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case TextFormatCode, BinaryFormatCode: switch target.(type) { @@ -60,7 +60,7 @@ func (c *EnumCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return c.DecodeValue(m, oid, format, src) } -func (c *EnumCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c *EnumCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -87,7 +87,7 @@ type scanPlanTextAnyToEnumString struct { codec *EnumCodec } -func (plan *scanPlanTextAnyToEnumString) Scan(src []byte, dst interface{}) error { +func (plan *scanPlanTextAnyToEnumString) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -102,7 +102,7 @@ type scanPlanTextAnyToEnumTextScanner struct { codec *EnumCodec } -func (plan *scanPlanTextAnyToEnumTextScanner) Scan(src []byte, dst interface{}) error { +func (plan *scanPlanTextAnyToEnumTextScanner) Scan(src []byte, dst any) error { scanner := (dst).(TextScanner) if src == nil { diff --git a/pgtype/enum_codec_test.go b/pgtype/enum_codec_test.go index 633b610b..d064d49c 100644 --- a/pgtype/enum_codec_test.go +++ b/pgtype/enum_codec_test.go @@ -64,6 +64,6 @@ create type enum_test as enum ('foo', 'bar', 'baz');`) require.True(t, rows.Next()) values, err := rows.Values() require.NoError(t, err) - require.Equal(t, values, []interface{}{"foo"}) + require.Equal(t, values, []any{"foo"}) }) } diff --git a/pgtype/float4.go b/pgtype/float4.go index 2c628011..a68fa7b2 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -35,7 +35,7 @@ func (f Float4) Int64Value() (Int8, error) { } // Scan implements the database/sql Scanner interface. -func (f *Float4) Scan(src interface{}) error { +func (f *Float4) Scan(src any) error { if src == nil { *f = Float4{} return nil @@ -75,7 +75,7 @@ func (Float4Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Float4Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (Float4Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -102,21 +102,21 @@ func (Float4Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{ type encodePlanFloat4CodecBinaryFloat32 struct{} -func (encodePlanFloat4CodecBinaryFloat32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanFloat4CodecBinaryFloat32) Encode(value any, buf []byte) (newBuf []byte, err error) { n := value.(float32) return pgio.AppendUint32(buf, math.Float32bits(n)), nil } type encodePlanTextFloat32 struct{} -func (encodePlanTextFloat32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTextFloat32) Encode(value any, buf []byte) (newBuf []byte, err error) { n := value.(float32) return append(buf, strconv.FormatFloat(float64(n), 'f', -1, 32)...), nil } type encodePlanFloat4CodecBinaryFloat64Valuer struct{} -func (encodePlanFloat4CodecBinaryFloat64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanFloat4CodecBinaryFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Float64Valuer).Float64Value() if err != nil { return nil, err @@ -131,7 +131,7 @@ func (encodePlanFloat4CodecBinaryFloat64Valuer) Encode(value interface{}, buf [] type encodePlanFloat4CodecBinaryInt64Valuer struct{} -func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -145,7 +145,7 @@ func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value interface{}, buf []by return pgio.AppendUint32(buf, math.Float32bits(f)), nil } -func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -175,7 +175,7 @@ func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target interface{} type scanPlanBinaryFloat4ToFloat32 struct{} -func (scanPlanBinaryFloat4ToFloat32) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryFloat4ToFloat32) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -193,7 +193,7 @@ func (scanPlanBinaryFloat4ToFloat32) Scan(src []byte, dst interface{}) error { type scanPlanBinaryFloat4ToFloat64Scanner struct{} -func (scanPlanBinaryFloat4ToFloat64Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryFloat4ToFloat64Scanner) Scan(src []byte, dst any) error { s := (dst).(Float64Scanner) if src == nil { @@ -210,7 +210,7 @@ func (scanPlanBinaryFloat4ToFloat64Scanner) Scan(src []byte, dst interface{}) er type scanPlanBinaryFloat4ToInt64Scanner struct{} -func (scanPlanBinaryFloat4ToInt64Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryFloat4ToInt64Scanner) Scan(src []byte, dst any) error { s := (dst).(Int64Scanner) if src == nil { @@ -233,7 +233,7 @@ func (scanPlanBinaryFloat4ToInt64Scanner) Scan(src []byte, dst interface{}) erro type scanPlanBinaryFloat4ToTextScanner struct{} -func (scanPlanBinaryFloat4ToTextScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryFloat4ToTextScanner) Scan(src []byte, dst any) error { s := (dst).(TextScanner) if src == nil { @@ -252,7 +252,7 @@ func (scanPlanBinaryFloat4ToTextScanner) Scan(src []byte, dst interface{}) error type scanPlanTextAnyToFloat32 struct{} -func (scanPlanTextAnyToFloat32) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToFloat32) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -281,7 +281,7 @@ func (c Float4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr return n, nil } -func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/float8.go b/pgtype/float8.go index b7c6177e..98334dc6 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -43,7 +43,7 @@ func (f Float8) Int64Value() (Int8, error) { } // Scan implements the database/sql Scanner interface. -func (f *Float8) Scan(src interface{}) error { +func (f *Float8) Scan(src any) error { if src == nil { *f = Float8{} return nil @@ -83,7 +83,7 @@ func (Float8Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Float8Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (Float8Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -110,21 +110,21 @@ func (Float8Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{ type encodePlanFloat8CodecBinaryFloat64 struct{} -func (encodePlanFloat8CodecBinaryFloat64) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanFloat8CodecBinaryFloat64) Encode(value any, buf []byte) (newBuf []byte, err error) { n := value.(float64) return pgio.AppendUint64(buf, math.Float64bits(n)), nil } type encodePlanTextFloat64 struct{} -func (encodePlanTextFloat64) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTextFloat64) Encode(value any, buf []byte) (newBuf []byte, err error) { n := value.(float64) return append(buf, strconv.FormatFloat(n, 'f', -1, 64)...), nil } type encodePlanFloat8CodecBinaryFloat64Valuer struct{} -func (encodePlanFloat8CodecBinaryFloat64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanFloat8CodecBinaryFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Float64Valuer).Float64Value() if err != nil { return nil, err @@ -139,7 +139,7 @@ func (encodePlanFloat8CodecBinaryFloat64Valuer) Encode(value interface{}, buf [] type encodePlanTextFloat64Valuer struct{} -func (encodePlanTextFloat64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTextFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Float64Valuer).Float64Value() if err != nil { return nil, err @@ -154,7 +154,7 @@ func (encodePlanTextFloat64Valuer) Encode(value interface{}, buf []byte) (newBuf type encodePlanFloat8CodecBinaryInt64Valuer struct{} -func (encodePlanFloat8CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanFloat8CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -170,7 +170,7 @@ func (encodePlanFloat8CodecBinaryInt64Valuer) Encode(value interface{}, buf []by type encodePlanTextInt64Valuer struct{} -func (encodePlanTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -183,7 +183,7 @@ func (encodePlanTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf [ return append(buf, strconv.FormatInt(n.Int64, 10)...), nil } -func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -213,7 +213,7 @@ func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target interface{} type scanPlanBinaryFloat8ToFloat64 struct{} -func (scanPlanBinaryFloat8ToFloat64) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryFloat8ToFloat64) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -231,7 +231,7 @@ func (scanPlanBinaryFloat8ToFloat64) Scan(src []byte, dst interface{}) error { type scanPlanBinaryFloat8ToFloat64Scanner struct{} -func (scanPlanBinaryFloat8ToFloat64Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryFloat8ToFloat64Scanner) Scan(src []byte, dst any) error { s := (dst).(Float64Scanner) if src == nil { @@ -248,7 +248,7 @@ func (scanPlanBinaryFloat8ToFloat64Scanner) Scan(src []byte, dst interface{}) er type scanPlanBinaryFloat8ToInt64Scanner struct{} -func (scanPlanBinaryFloat8ToInt64Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryFloat8ToInt64Scanner) Scan(src []byte, dst any) error { s := (dst).(Int64Scanner) if src == nil { @@ -271,7 +271,7 @@ func (scanPlanBinaryFloat8ToInt64Scanner) Scan(src []byte, dst interface{}) erro type scanPlanBinaryFloat8ToTextScanner struct{} -func (scanPlanBinaryFloat8ToTextScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryFloat8ToTextScanner) Scan(src []byte, dst any) error { s := (dst).(TextScanner) if src == nil { @@ -290,7 +290,7 @@ func (scanPlanBinaryFloat8ToTextScanner) Scan(src []byte, dst interface{}) error type scanPlanTextAnyToFloat64 struct{} -func (scanPlanTextAnyToFloat64) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToFloat64) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -308,7 +308,7 @@ func (scanPlanTextAnyToFloat64) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToFloat64Scanner struct{} -func (scanPlanTextAnyToFloat64Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToFloat64Scanner) Scan(src []byte, dst any) error { s := (dst).(Float64Scanner) if src == nil { @@ -327,7 +327,7 @@ func (c Float8Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr return c.DecodeValue(m, oid, format, src) } -func (c Float8Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Float8Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 7f0fa8c2..4743643e 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -35,7 +35,7 @@ func (h Hstore) HstoreValue() (Hstore, error) { } // Scan implements the database/sql Scanner interface. -func (h *Hstore) Scan(src interface{}) error { +func (h *Hstore) Scan(src any) error { if src == nil { *h = nil return nil @@ -72,7 +72,7 @@ func (HstoreCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (HstoreCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (HstoreCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(HstoreValuer); !ok { return nil } @@ -89,7 +89,7 @@ func (HstoreCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{ type encodePlanHstoreCodecBinary struct{} -func (encodePlanHstoreCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanHstoreCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { hstore, err := value.(HstoreValuer).HstoreValue() if err != nil { return nil, err @@ -118,7 +118,7 @@ func (encodePlanHstoreCodecBinary) Encode(value interface{}, buf []byte) (newBuf type encodePlanHstoreCodecText struct{} -func (encodePlanHstoreCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { hstore, err := value.(HstoreValuer).HstoreValue() if err != nil { return nil, err @@ -150,7 +150,7 @@ func (encodePlanHstoreCodecText) Encode(value interface{}, buf []byte) (newBuf [ return buf, nil } -func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -170,7 +170,7 @@ func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target interface{} type scanPlanBinaryHstoreToHstoreScanner struct{} -func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { scanner := (dst).(HstoreScanner) if src == nil { @@ -230,7 +230,7 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst interface{}) err type scanPlanTextAnyToHstoreScanner struct{} -func (scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error { scanner := (dst).(HstoreScanner) if src == nil { @@ -258,7 +258,7 @@ func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index fb0cc27b..f8684bf7 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -9,8 +9,8 @@ import ( "github.com/jackc/pgx/v5/pgxtest" ) -func isExpectedEqMapStringString(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEqMapStringString(a any) func(any) bool { + return func(v any) bool { am := a.(map[string]string) vm := v.(map[string]string) @@ -28,8 +28,8 @@ func isExpectedEqMapStringString(a interface{}) func(interface{}) bool { } } -func isExpectedEqMapStringPointerString(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEqMapStringPointerString(a any) func(any) bool { + return func(v any) bool { am := a.(map[string]*string) vm := v.(map[string]*string) diff --git a/pgtype/inet.go b/pgtype/inet.go index a272e00b..f8abeef8 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -38,7 +38,7 @@ func (inet Inet) InetValue() (Inet, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Inet) Scan(src interface{}) error { +func (dst *Inet) Scan(src any) error { if src == nil { *dst = Inet{} return nil @@ -75,7 +75,7 @@ func (InetCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (InetCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (InetCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(InetValuer); !ok { return nil } @@ -92,7 +92,7 @@ func (InetCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanInetCodecBinary struct{} -func (encodePlanInetCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInetCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { inet, err := value.(InetValuer).InetValue() if err != nil { return nil, err @@ -127,7 +127,7 @@ func (encodePlanInetCodecBinary) Encode(value interface{}, buf []byte) (newBuf [ type encodePlanInetCodecText struct{} -func (encodePlanInetCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInetCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { inet, err := value.(InetValuer).InetValue() if err != nil { return nil, err @@ -140,7 +140,7 @@ func (encodePlanInetCodecText) Encode(value interface{}, buf []byte) (newBuf []b return append(buf, inet.IPNet.String()...), nil } -func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -162,7 +162,7 @@ func (c InetCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c InetCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c InetCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -182,7 +182,7 @@ func (c InetCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (in type scanPlanBinaryInetToInetScanner struct{} -func (scanPlanBinaryInetToInetScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInetToInetScanner) Scan(src []byte, dst any) error { scanner := (dst).(InetScanner) if src == nil { @@ -211,7 +211,7 @@ func (scanPlanBinaryInetToInetScanner) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToInetScanner struct{} -func (scanPlanTextAnyToInetScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToInetScanner) Scan(src []byte, dst any) error { scanner := (dst).(InetScanner) if src == nil { diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index 8bf11a76..0a174e1a 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -9,8 +9,8 @@ import ( "github.com/jackc/pgx/v5/pgxtest" ) -func isExpectedEqIPNet(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEqIPNet(a any) func(any) bool { + return func(v any) bool { ap := a.(*net.IPNet) vp := v.(net.IPNet) diff --git a/pgtype/int.go b/pgtype/int.go index b3eabceb..147a3656 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -48,7 +48,7 @@ func (n Int2) Int64Value() (Int8, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Int2) Scan(src interface{}) error { +func (dst *Int2) Scan(src any) error { if src == nil { *dst = Int2{} return nil @@ -127,7 +127,7 @@ func (Int2Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int2Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (Int2Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -150,21 +150,21 @@ func (Int2Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanInt2CodecBinaryInt16 struct{} -func (encodePlanInt2CodecBinaryInt16) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt2CodecBinaryInt16) Encode(value any, buf []byte) (newBuf []byte, err error) { n := value.(int16) return pgio.AppendInt16(buf, int16(n)), nil } type encodePlanInt2CodecTextInt16 struct{} -func (encodePlanInt2CodecTextInt16) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt2CodecTextInt16) Encode(value any, buf []byte) (newBuf []byte, err error) { n := value.(int16) return append(buf, strconv.FormatInt(int64(n), 10)...), nil } type encodePlanInt2CodecBinaryInt64Valuer struct{} -func (encodePlanInt2CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt2CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -186,7 +186,7 @@ func (encodePlanInt2CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte type encodePlanInt2CodecTextInt64Valuer struct{} -func (encodePlanInt2CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt2CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -206,7 +206,7 @@ func (encodePlanInt2CodecTextInt64Valuer) Encode(value interface{}, buf []byte) return append(buf, strconv.FormatInt(n.Int64, 10)...), nil } -func (Int2Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (Int2Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -279,7 +279,7 @@ func (c Int2Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return n, nil } -func (c Int2Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Int2Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -294,7 +294,7 @@ func (c Int2Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (in type scanPlanBinaryInt2ToInt8 struct{} -func (scanPlanBinaryInt2ToInt8) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt8) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -322,7 +322,7 @@ func (scanPlanBinaryInt2ToInt8) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt2ToUint8 struct{} -func (scanPlanBinaryInt2ToUint8) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToUint8) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -352,7 +352,7 @@ func (scanPlanBinaryInt2ToUint8) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt2ToInt16 struct{} -func (scanPlanBinaryInt2ToInt16) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt16) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -373,7 +373,7 @@ func (scanPlanBinaryInt2ToInt16) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt2ToUint16 struct{} -func (scanPlanBinaryInt2ToUint16) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToUint16) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -399,7 +399,7 @@ func (scanPlanBinaryInt2ToUint16) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt2ToInt32 struct{} -func (scanPlanBinaryInt2ToInt32) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt32) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -420,7 +420,7 @@ func (scanPlanBinaryInt2ToInt32) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt2ToUint32 struct{} -func (scanPlanBinaryInt2ToUint32) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToUint32) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -446,7 +446,7 @@ func (scanPlanBinaryInt2ToUint32) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt2ToInt64 struct{} -func (scanPlanBinaryInt2ToInt64) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt64) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -467,7 +467,7 @@ func (scanPlanBinaryInt2ToInt64) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt2ToUint64 struct{} -func (scanPlanBinaryInt2ToUint64) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToUint64) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -493,7 +493,7 @@ func (scanPlanBinaryInt2ToUint64) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt2ToInt struct{} -func (scanPlanBinaryInt2ToInt) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -514,7 +514,7 @@ func (scanPlanBinaryInt2ToInt) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt2ToUint struct{} -func (scanPlanBinaryInt2ToUint) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToUint) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -540,7 +540,7 @@ func (scanPlanBinaryInt2ToUint) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt2ToInt64Scanner struct{} -func (scanPlanBinaryInt2ToInt64Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt64Scanner) Scan(src []byte, dst any) error { s, ok := (dst).(Int64Scanner) if !ok { return ErrScanTargetTypeChanged @@ -561,7 +561,7 @@ func (scanPlanBinaryInt2ToInt64Scanner) Scan(src []byte, dst interface{}) error type scanPlanBinaryInt2ToTextScanner struct{} -func (scanPlanBinaryInt2ToTextScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToTextScanner) Scan(src []byte, dst any) error { s, ok := (dst).(TextScanner) if !ok { return ErrScanTargetTypeChanged @@ -608,7 +608,7 @@ func (n Int4) Int64Value() (Int8, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Int4) Scan(src interface{}) error { +func (dst *Int4) Scan(src any) error { if src == nil { *dst = Int4{} return nil @@ -687,7 +687,7 @@ func (Int4Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int4Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (Int4Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -710,21 +710,21 @@ func (Int4Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanInt4CodecBinaryInt32 struct{} -func (encodePlanInt4CodecBinaryInt32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt4CodecBinaryInt32) Encode(value any, buf []byte) (newBuf []byte, err error) { n := value.(int32) return pgio.AppendInt32(buf, int32(n)), nil } type encodePlanInt4CodecTextInt32 struct{} -func (encodePlanInt4CodecTextInt32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt4CodecTextInt32) Encode(value any, buf []byte) (newBuf []byte, err error) { n := value.(int32) return append(buf, strconv.FormatInt(int64(n), 10)...), nil } type encodePlanInt4CodecBinaryInt64Valuer struct{} -func (encodePlanInt4CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt4CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -746,7 +746,7 @@ func (encodePlanInt4CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte type encodePlanInt4CodecTextInt64Valuer struct{} -func (encodePlanInt4CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt4CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -766,7 +766,7 @@ func (encodePlanInt4CodecTextInt64Valuer) Encode(value interface{}, buf []byte) return append(buf, strconv.FormatInt(n.Int64, 10)...), nil } -func (Int4Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (Int4Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -839,7 +839,7 @@ func (c Int4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return n, nil } -func (c Int4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Int4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -854,7 +854,7 @@ func (c Int4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (in type scanPlanBinaryInt4ToInt8 struct{} -func (scanPlanBinaryInt4ToInt8) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToInt8) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -882,7 +882,7 @@ func (scanPlanBinaryInt4ToInt8) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt4ToUint8 struct{} -func (scanPlanBinaryInt4ToUint8) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToUint8) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -912,7 +912,7 @@ func (scanPlanBinaryInt4ToUint8) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt4ToInt16 struct{} -func (scanPlanBinaryInt4ToInt16) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToInt16) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -940,7 +940,7 @@ func (scanPlanBinaryInt4ToInt16) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt4ToUint16 struct{} -func (scanPlanBinaryInt4ToUint16) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToUint16) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -970,7 +970,7 @@ func (scanPlanBinaryInt4ToUint16) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt4ToInt32 struct{} -func (scanPlanBinaryInt4ToInt32) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToInt32) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -991,7 +991,7 @@ func (scanPlanBinaryInt4ToInt32) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt4ToUint32 struct{} -func (scanPlanBinaryInt4ToUint32) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToUint32) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1017,7 +1017,7 @@ func (scanPlanBinaryInt4ToUint32) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt4ToInt64 struct{} -func (scanPlanBinaryInt4ToInt64) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToInt64) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1038,7 +1038,7 @@ func (scanPlanBinaryInt4ToInt64) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt4ToUint64 struct{} -func (scanPlanBinaryInt4ToUint64) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToUint64) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1064,7 +1064,7 @@ func (scanPlanBinaryInt4ToUint64) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt4ToInt struct{} -func (scanPlanBinaryInt4ToInt) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToInt) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1085,7 +1085,7 @@ func (scanPlanBinaryInt4ToInt) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt4ToUint struct{} -func (scanPlanBinaryInt4ToUint) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToUint) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1111,7 +1111,7 @@ func (scanPlanBinaryInt4ToUint) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt4ToInt64Scanner struct{} -func (scanPlanBinaryInt4ToInt64Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToInt64Scanner) Scan(src []byte, dst any) error { s, ok := (dst).(Int64Scanner) if !ok { return ErrScanTargetTypeChanged @@ -1132,7 +1132,7 @@ func (scanPlanBinaryInt4ToInt64Scanner) Scan(src []byte, dst interface{}) error type scanPlanBinaryInt4ToTextScanner struct{} -func (scanPlanBinaryInt4ToTextScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt4ToTextScanner) Scan(src []byte, dst any) error { s, ok := (dst).(TextScanner) if !ok { return ErrScanTargetTypeChanged @@ -1179,7 +1179,7 @@ func (n Int8) Int64Value() (Int8, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Int8) Scan(src interface{}) error { +func (dst *Int8) Scan(src any) error { if src == nil { *dst = Int8{} return nil @@ -1258,7 +1258,7 @@ func (Int8Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int8Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (Int8Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -1281,21 +1281,21 @@ func (Int8Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanInt8CodecBinaryInt64 struct{} -func (encodePlanInt8CodecBinaryInt64) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt8CodecBinaryInt64) Encode(value any, buf []byte) (newBuf []byte, err error) { n := value.(int64) return pgio.AppendInt64(buf, int64(n)), nil } type encodePlanInt8CodecTextInt64 struct{} -func (encodePlanInt8CodecTextInt64) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt8CodecTextInt64) Encode(value any, buf []byte) (newBuf []byte, err error) { n := value.(int64) return append(buf, strconv.FormatInt(int64(n), 10)...), nil } type encodePlanInt8CodecBinaryInt64Valuer struct{} -func (encodePlanInt8CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt8CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -1317,7 +1317,7 @@ func (encodePlanInt8CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte type encodePlanInt8CodecTextInt64Valuer struct{} -func (encodePlanInt8CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt8CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -1337,7 +1337,7 @@ func (encodePlanInt8CodecTextInt64Valuer) Encode(value interface{}, buf []byte) return append(buf, strconv.FormatInt(n.Int64, 10)...), nil } -func (Int8Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (Int8Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -1410,7 +1410,7 @@ func (c Int8Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return n, nil } -func (c Int8Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Int8Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -1425,7 +1425,7 @@ func (c Int8Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (in type scanPlanBinaryInt8ToInt8 struct{} -func (scanPlanBinaryInt8ToInt8) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToInt8) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1453,7 +1453,7 @@ func (scanPlanBinaryInt8ToInt8) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt8ToUint8 struct{} -func (scanPlanBinaryInt8ToUint8) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToUint8) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1483,7 +1483,7 @@ func (scanPlanBinaryInt8ToUint8) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt8ToInt16 struct{} -func (scanPlanBinaryInt8ToInt16) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToInt16) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1511,7 +1511,7 @@ func (scanPlanBinaryInt8ToInt16) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt8ToUint16 struct{} -func (scanPlanBinaryInt8ToUint16) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToUint16) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1541,7 +1541,7 @@ func (scanPlanBinaryInt8ToUint16) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt8ToInt32 struct{} -func (scanPlanBinaryInt8ToInt32) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToInt32) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1569,7 +1569,7 @@ func (scanPlanBinaryInt8ToInt32) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt8ToUint32 struct{} -func (scanPlanBinaryInt8ToUint32) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToUint32) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1599,7 +1599,7 @@ func (scanPlanBinaryInt8ToUint32) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt8ToInt64 struct{} -func (scanPlanBinaryInt8ToInt64) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToInt64) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1620,7 +1620,7 @@ func (scanPlanBinaryInt8ToInt64) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt8ToUint64 struct{} -func (scanPlanBinaryInt8ToUint64) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToUint64) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1646,7 +1646,7 @@ func (scanPlanBinaryInt8ToUint64) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt8ToInt struct{} -func (scanPlanBinaryInt8ToInt) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToInt) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1674,7 +1674,7 @@ func (scanPlanBinaryInt8ToInt) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt8ToUint struct{} -func (scanPlanBinaryInt8ToUint) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToUint) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1704,7 +1704,7 @@ func (scanPlanBinaryInt8ToUint) Scan(src []byte, dst interface{}) error { type scanPlanBinaryInt8ToInt64Scanner struct{} -func (scanPlanBinaryInt8ToInt64Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToInt64Scanner) Scan(src []byte, dst any) error { s, ok := (dst).(Int64Scanner) if !ok { return ErrScanTargetTypeChanged @@ -1725,7 +1725,7 @@ func (scanPlanBinaryInt8ToInt64Scanner) Scan(src []byte, dst interface{}) error type scanPlanBinaryInt8ToTextScanner struct{} -func (scanPlanBinaryInt8ToTextScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt8ToTextScanner) Scan(src []byte, dst any) error { s, ok := (dst).(TextScanner) if !ok { return ErrScanTargetTypeChanged @@ -1746,7 +1746,7 @@ func (scanPlanBinaryInt8ToTextScanner) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToInt8 struct{} -func (scanPlanTextAnyToInt8) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt8) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1767,7 +1767,7 @@ func (scanPlanTextAnyToInt8) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToUint8 struct{} -func (scanPlanTextAnyToUint8) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint8) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1788,7 +1788,7 @@ func (scanPlanTextAnyToUint8) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToInt16 struct{} -func (scanPlanTextAnyToInt16) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt16) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1809,7 +1809,7 @@ func (scanPlanTextAnyToInt16) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToUint16 struct{} -func (scanPlanTextAnyToUint16) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint16) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1830,7 +1830,7 @@ func (scanPlanTextAnyToUint16) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToInt32 struct{} -func (scanPlanTextAnyToInt32) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt32) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1851,7 +1851,7 @@ func (scanPlanTextAnyToInt32) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToUint32 struct{} -func (scanPlanTextAnyToUint32) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint32) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1872,7 +1872,7 @@ func (scanPlanTextAnyToUint32) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToInt64 struct{} -func (scanPlanTextAnyToInt64) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt64) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1893,7 +1893,7 @@ func (scanPlanTextAnyToInt64) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToUint64 struct{} -func (scanPlanTextAnyToUint64) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint64) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1914,7 +1914,7 @@ func (scanPlanTextAnyToUint64) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToInt struct{} -func (scanPlanTextAnyToInt) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1935,7 +1935,7 @@ func (scanPlanTextAnyToInt) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToUint struct{} -func (scanPlanTextAnyToUint) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -1956,7 +1956,7 @@ func (scanPlanTextAnyToUint) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToInt64Scanner struct{} -func (scanPlanTextAnyToInt64Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt64Scanner) Scan(src []byte, dst any) error { s, ok := (dst).(Int64Scanner) if !ok { return ErrScanTargetTypeChanged diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index aa1db7fc..f46a1dc3 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -49,7 +49,7 @@ func (n Int<%= pg_byte_size %>) Int64Value() (Int8, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Int<%= pg_byte_size %>) Scan(src interface{}) error { +func (dst *Int<%= pg_byte_size %>) Scan(src any) error { if src == nil { *dst = Int<%= pg_byte_size %>{} return nil @@ -128,7 +128,7 @@ func (Int<%= pg_byte_size %>Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int<%= pg_byte_size %>Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (Int<%= pg_byte_size %>Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -151,21 +151,21 @@ func (Int<%= pg_byte_size %>Codec) PlanEncode(m *Map, oid uint32, format int16, type encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %> struct{} -func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %>) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %>) Encode(value any, buf []byte) (newBuf []byte, err error) { n := value.(int<%= pg_bit_size %>) return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n)), nil } type encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %> struct{} -func (encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %>) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %>) Encode(value any, buf []byte) (newBuf []byte, err error) { n := value.(int<%= pg_bit_size %>) return append(buf, strconv.FormatInt(int64(n), 10)...), nil } type encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer struct{} -func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -187,7 +187,7 @@ func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer) Encode(value inter type encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer struct{} -func (encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -207,7 +207,7 @@ func (encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer) Encode(value interfa return append(buf, strconv.FormatInt(n.Int64, 10)...), nil } -func (Int<%= pg_byte_size %>Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (Int<%= pg_byte_size %>Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -280,7 +280,7 @@ func (c Int<%= pg_byte_size %>Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, return n, nil } -func (c Int<%= pg_byte_size %>Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Int<%= pg_byte_size %>Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -297,7 +297,7 @@ func (c Int<%= pg_byte_size %>Codec) DecodeValue(m *Map, oid uint32, format int1 <% [8, 16, 32, 64].each do |dst_bit_size| %> type scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %> struct{} -func (scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %>) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %>) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -331,7 +331,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %>) Scan(src []b type scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %> struct{} -func (scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %>) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %>) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -363,7 +363,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %>) Scan(src [] <%# PostgreSQL binary format integer to Go machine integers %> type scanPlanBinaryInt<%= pg_byte_size %>ToInt struct{} -func (scanPlanBinaryInt<%= pg_byte_size %>ToInt) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -395,7 +395,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt) Scan(src []byte, dst interface{ type scanPlanBinaryInt<%= pg_byte_size %>ToUint struct{} -func (scanPlanBinaryInt<%= pg_byte_size %>ToUint) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt<%= pg_byte_size %>ToUint) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -426,7 +426,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToUint) Scan(src []byte, dst interface <%# PostgreSQL binary format integer to Go Int64Scanner %> type scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner struct{} -func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(src []byte, dst any) error { s, ok := (dst).(Int64Scanner) if !ok { return ErrScanTargetTypeChanged @@ -449,7 +449,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(src []byte, dst i <%# PostgreSQL binary format integer to Go TextScanner %> type scanPlanBinaryInt<%= pg_byte_size %>ToTextScanner struct{} -func (scanPlanBinaryInt<%= pg_byte_size %>ToTextScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryInt<%= pg_byte_size %>ToTextScanner) Scan(src []byte, dst any) error { s, ok := (dst).(TextScanner) if !ok { return ErrScanTargetTypeChanged @@ -480,7 +480,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToTextScanner) Scan(src []byte, dst in ].each do |type_suffix, bit_size| %> type scanPlanTextAnyToInt<%= type_suffix %> struct{} -func (scanPlanTextAnyToInt<%= type_suffix %>) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt<%= type_suffix %>) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -501,7 +501,7 @@ func (scanPlanTextAnyToInt<%= type_suffix %>) Scan(src []byte, dst interface{}) type scanPlanTextAnyToUint<%= type_suffix %> struct{} -func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -523,7 +523,7 @@ func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(src []byte, dst interface{}) type scanPlanTextAnyToInt64Scanner struct{} -func (scanPlanTextAnyToInt64Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt64Scanner) Scan(src []byte, dst any) error { s, ok := (dst).(Int64Scanner) if !ok { return ErrScanTargetTypeChanged diff --git a/pgtype/integration_benchmark_test.go b/pgtype/integration_benchmark_test.go index 66758d07..624c29ea 100644 --- a/pgtype/integration_benchmark_test.go +++ b/pgtype/integration_benchmark_test.go @@ -17,8 +17,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *test _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -36,8 +36,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *te _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -55,8 +55,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *tes _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -74,8 +74,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *t _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -93,8 +93,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *tes _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -112,8 +112,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *t _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -131,8 +131,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b *t _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -150,8 +150,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -169,8 +169,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *test _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -188,8 +188,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *te _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -207,8 +207,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *tes _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -226,8 +226,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *t _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -245,8 +245,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *tes _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -264,8 +264,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *t _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -283,8 +283,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b *t _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -302,8 +302,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -321,8 +321,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *test _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -340,8 +340,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *te _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -359,8 +359,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *tes _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -378,8 +378,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *t _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -397,8 +397,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *tes _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -416,8 +416,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *t _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -435,8 +435,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b *t _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -454,8 +454,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -473,8 +473,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *tes _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -492,8 +492,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *t _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -511,8 +511,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b *te _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -530,8 +530,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b * _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -549,8 +549,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b *te _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -568,8 +568,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b * _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -587,8 +587,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b * _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -606,8 +606,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -625,8 +625,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns(b _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -644,8 +644,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -663,8 +663,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_columns( _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -682,8 +682,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_column _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -701,8 +701,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_columns( _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -720,8 +720,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_column _, err := conn.QueryFunc( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -739,8 +739,8 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_column _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -758,8 +758,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_colu _, err := conn.QueryFunc( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -777,8 +777,8 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b *t _, err := conn.QueryFunc( ctx, `select n::numeric + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -796,8 +796,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b _, err := conn.QueryFunc( ctx, `select n::numeric + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -815,8 +815,8 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b * _, err := conn.QueryFunc( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -834,8 +834,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b _, err := conn.QueryFunc( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -853,8 +853,8 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b * _, err := conn.QueryFunc( ctx, `select n::numeric + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -872,8 +872,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b _, err := conn.QueryFunc( ctx, `select n::numeric + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -891,8 +891,8 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns(b _, err := conn.QueryFunc( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -910,8 +910,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns _, err := conn.QueryFunc( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -929,8 +929,8 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns(b _, err := conn.QueryFunc( ctx, `select n::numeric + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -948,8 +948,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns( _, err := conn.QueryFunc( ctx, `select n::numeric + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -967,8 +967,8 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns(b _, err := conn.QueryFunc( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -986,8 +986,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns _, err := conn.QueryFunc( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1005,8 +1005,8 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns(b _, err := conn.QueryFunc( ctx, `select n::numeric + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1024,8 +1024,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns _, err := conn.QueryFunc( ctx, `select n::numeric + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1043,8 +1043,8 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_100_rows_10_columns _, err := conn.QueryFunc( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1062,8 +1062,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_100_rows_10_colum _, err := conn.QueryFunc( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1081,8 +1081,8 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_col _, err := conn.QueryFunc( ctx, `select n::numeric + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1100,8 +1100,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_c _, err := conn.QueryFunc( ctx, `select n::numeric + 0 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1119,8 +1119,8 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_co _, err := conn.QueryFunc( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1138,8 +1138,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_ _, err := conn.QueryFunc( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1157,8 +1157,8 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_co _, err := conn.QueryFunc( ctx, `select n::numeric + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1176,8 +1176,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_ _, err := conn.QueryFunc( ctx, `select n::numeric + 0 from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1195,8 +1195,8 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_10_ _, err := conn.QueryFunc( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1214,8 +1214,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_1 _, err := conn.QueryFunc( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1233,8 +1233,8 @@ func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testing _, err := conn.QueryFunc( ctx, `select array_agg(n) from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1252,8 +1252,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testi _, err := conn.QueryFunc( ctx, `select array_agg(n) from generate_series(1, 10) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1271,8 +1271,8 @@ func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testin _, err := conn.QueryFunc( ctx, `select array_agg(n) from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1290,8 +1290,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *test _, err := conn.QueryFunc( ctx, `select array_agg(n) from generate_series(1, 100) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1309,8 +1309,8 @@ func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testi _, err := conn.QueryFunc( ctx, `select array_agg(n) from generate_series(1, 1000) n`, - []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []interface{}{&v}, + []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []any{&v}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -1328,8 +1328,8 @@ func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *tes _, err := conn.QueryFunc( ctx, `select array_agg(n) from generate_series(1, 1000) n`, - []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []interface{}{&v}, + []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []any{&v}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { diff --git a/pgtype/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb index b122c606..3459f0cb 100644 --- a/pgtype/integration_benchmark_test.go.erb +++ b/pgtype/integration_benchmark_test.go.erb @@ -25,8 +25,8 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go _, err := conn.QueryFunc( ctx, `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, - []interface{}{pgx.QueryResultFormats{<%= format_code %>}}, - []interface{}{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, + []any{pgx.QueryResultFormats{<%= format_code %>}}, + []any{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { @@ -50,8 +50,8 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_Int4Array _, err := conn.QueryFunc( ctx, `select array_agg(n) from generate_series(1, <%= array_size %>) n`, - []interface{}{pgx.QueryResultFormats{<%= format_code %>}}, - []interface{}{&v}, + []any{pgx.QueryResultFormats{<%= format_code %>}}, + []any{&v}, func(pgx.QueryFuncRow) error { return nil }, ) if err != nil { diff --git a/pgtype/interval.go b/pgtype/interval.go index 882fd6d6..a172ecdb 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -43,7 +43,7 @@ func (interval Interval) IntervalValue() (Interval, error) { } // Scan implements the database/sql Scanner interface. -func (interval *Interval) Scan(src interface{}) error { +func (interval *Interval) Scan(src any) error { if src == nil { *interval = Interval{} return nil @@ -80,7 +80,7 @@ func (IntervalCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (IntervalCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (IntervalCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(IntervalValuer); !ok { return nil } @@ -97,7 +97,7 @@ func (IntervalCodec) PlanEncode(m *Map, oid uint32, format int16, value interfac type encodePlanIntervalCodecBinary struct{} -func (encodePlanIntervalCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanIntervalCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { interval, err := value.(IntervalValuer).IntervalValue() if err != nil { return nil, err @@ -115,7 +115,7 @@ func (encodePlanIntervalCodecBinary) Encode(value interface{}, buf []byte) (newB type encodePlanIntervalCodecText struct{} -func (encodePlanIntervalCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { interval, err := value.(IntervalValuer).IntervalValue() if err != nil { return nil, err @@ -151,7 +151,7 @@ func (encodePlanIntervalCodecText) Encode(value interface{}, buf []byte) (newBuf return buf, nil } -func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -171,7 +171,7 @@ func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target interface type scanPlanBinaryIntervalToIntervalScanner struct{} -func (scanPlanBinaryIntervalToIntervalScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryIntervalToIntervalScanner) Scan(src []byte, dst any) error { scanner := (dst).(IntervalScanner) if src == nil { @@ -191,7 +191,7 @@ func (scanPlanBinaryIntervalToIntervalScanner) Scan(src []byte, dst interface{}) type scanPlanTextAnyToIntervalScanner struct{} -func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst any) error { scanner := (dst).(IntervalScanner) if src == nil { @@ -278,7 +278,7 @@ func (c IntervalCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c IntervalCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c IntervalCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/json.go b/pgtype/json.go index 4d8cf4c4..de2f08df 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -16,7 +16,7 @@ func (JSONCodec) PreferredFormat() int16 { return TextFormatCode } -func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch value.(type) { case string: return encodePlanJSONCodecEitherFormatString{} @@ -43,7 +43,7 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{ type encodePlanJSONCodecEitherFormatString struct{} -func (encodePlanJSONCodecEitherFormatString) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanJSONCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) { jsonString := value.(string) buf = append(buf, jsonString...) return buf, nil @@ -51,7 +51,7 @@ func (encodePlanJSONCodecEitherFormatString) Encode(value interface{}, buf []byt type encodePlanJSONCodecEitherFormatByteSlice struct{} -func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { jsonBytes := value.([]byte) if jsonBytes == nil { return nil, nil @@ -63,7 +63,7 @@ func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value interface{}, buf [] type encodePlanJSONCodecEitherFormatMarshal struct{} -func (encodePlanJSONCodecEitherFormatMarshal) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { jsonBytes, err := json.Marshal(value) if err != nil { return nil, err @@ -73,7 +73,7 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value interface{}, buf []by return buf, nil } -func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch target.(type) { case *string: return scanPlanAnyToString{} @@ -89,7 +89,7 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) type scanPlanAnyToString struct{} -func (scanPlanAnyToString) Scan(src []byte, dst interface{}) error { +func (scanPlanAnyToString) Scan(src []byte, dst any) error { p := dst.(*string) *p = string(src) return nil @@ -97,7 +97,7 @@ func (scanPlanAnyToString) Scan(src []byte, dst interface{}) error { type scanPlanJSONToByteSlice struct{} -func (scanPlanJSONToByteSlice) Scan(src []byte, dst interface{}) error { +func (scanPlanJSONToByteSlice) Scan(src []byte, dst any) error { dstBuf := dst.(*[]byte) if src == nil { *dstBuf = nil @@ -111,14 +111,14 @@ func (scanPlanJSONToByteSlice) Scan(src []byte, dst interface{}) error { type scanPlanJSONToBytesScanner struct{} -func (scanPlanJSONToBytesScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error { scanner := (dst).(BytesScanner) return scanner.ScanBytes(src) } type scanPlanJSONToJSONUnmarshal struct{} -func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst interface{}) error { +func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { if src == nil { dstValue := reflect.ValueOf(dst) if dstValue.Kind() == reflect.Ptr { @@ -144,12 +144,12 @@ func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return dstBuf, nil } -func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } - var dst interface{} + var dst any err := json.Unmarshal(src, &dst) return dst, err } diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 0275b1e6..c349fa24 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -7,10 +7,10 @@ import ( "github.com/jackc/pgx/v5/pgxtest" ) -func isExpectedEqMap(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { - aa := a.(map[string]interface{}) - bb := v.(map[string]interface{}) +func isExpectedEqMap(a any) func(any) bool { + return func(v any) bool { + aa := a.(map[string]any) + bb := v.(map[string]any) if (aa == nil) != (bb == nil) { return false @@ -42,8 +42,8 @@ func TestJSONCodec(t *testing.T) { pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{ {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, - {map[string]interface{}(nil), new(*string), isExpectedEq((*string)(nil))}, - {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {map[string]any(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]any(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, }) @@ -54,7 +54,7 @@ func TestJSONCodec(t *testing.T) { {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, {[]byte(`"hello"`), new([]byte), isExpectedEqBytes([]byte(`"hello"`))}, {[]byte(`"hello"`), new(string), isExpectedEq(`"hello"`)}, - {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, + {map[string]any{"foo": "bar"}, new(map[string]any), isExpectedEqMap(map[string]any{"foo": "bar"})}, {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, }) } diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 7e3d3f8d..25555e7f 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -16,7 +16,7 @@ func (JSONBCodec) PreferredFormat() int16 { return TextFormatCode } -func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value) @@ -34,12 +34,12 @@ type encodePlanJSONBCodecBinaryWrapper struct { textPlan EncodePlan } -func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (newBuf []byte, err error) { buf = append(buf, 1) return plan.textPlan.Encode(value, buf) } -func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target) @@ -57,7 +57,7 @@ type scanPlanJSONBCodecBinaryUnwrapper struct { textPlan ScanPlan } -func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst interface{}) error { +func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error { if src == nil { return plan.textPlan.Scan(src, dst) } @@ -100,7 +100,7 @@ func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src } } -func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -121,7 +121,7 @@ func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (i return nil, fmt.Errorf("unknown format code: %v", format) } - var dst interface{} + var dst any err := json.Unmarshal(src, &dst) return dst, err } diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 4a9f7a35..7dadc6c5 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -15,8 +15,8 @@ func TestJSONBTranscode(t *testing.T) { pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "jsonb", []pgxtest.ValueRoundTripTest{ {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, - {map[string]interface{}(nil), new(*string), isExpectedEq((*string)(nil))}, - {map[string]interface{}(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {map[string]any(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]any(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, }) @@ -27,7 +27,7 @@ func TestJSONBTranscode(t *testing.T) { {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, {[]byte(`"hello"`), new([]byte), isExpectedEqBytes([]byte(`"hello"`))}, {[]byte(`"hello"`), new(string), isExpectedEq(`"hello"`)}, - {map[string]interface{}{"foo": "bar"}, new(map[string]interface{}), isExpectedEqMap(map[string]interface{}{"foo": "bar"})}, + {map[string]any{"foo": "bar"}, new(map[string]any), isExpectedEqMap(map[string]any{"foo": "bar"})}, {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, }) } diff --git a/pgtype/line.go b/pgtype/line.go index 087c7688..4ae8003e 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -33,12 +33,12 @@ func (line Line) LineValue() (Line, error) { return line, nil } -func (line *Line) Set(src interface{}) error { +func (line *Line) Set(src any) error { return fmt.Errorf("cannot convert %v to Line", src) } // Scan implements the database/sql Scanner interface. -func (line *Line) Scan(src interface{}) error { +func (line *Line) Scan(src any) error { if src == nil { *line = Line{} return nil @@ -75,7 +75,7 @@ func (LineCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (LineCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (LineCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(LineValuer); !ok { return nil } @@ -92,7 +92,7 @@ func (LineCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanLineCodecBinary struct{} -func (encodePlanLineCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanLineCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { line, err := value.(LineValuer).LineValue() if err != nil { return nil, err @@ -110,7 +110,7 @@ func (encodePlanLineCodecBinary) Encode(value interface{}, buf []byte) (newBuf [ type encodePlanLineCodecText struct{} -func (encodePlanLineCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanLineCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { line, err := value.(LineValuer).LineValue() if err != nil { return nil, err @@ -128,7 +128,7 @@ func (encodePlanLineCodecText) Encode(value interface{}, buf []byte) (newBuf []b return buf, nil } -func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -148,7 +148,7 @@ func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) type scanPlanBinaryLineToLineScanner struct{} -func (scanPlanBinaryLineToLineScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryLineToLineScanner) Scan(src []byte, dst any) error { scanner := (dst).(LineScanner) if src == nil { @@ -173,7 +173,7 @@ func (scanPlanBinaryLineToLineScanner) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToLineScanner struct{} -func (scanPlanTextAnyToLineScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToLineScanner) Scan(src []byte, dst any) error { scanner := (dst).(LineScanner) if src == nil { @@ -211,7 +211,7 @@ func (c LineCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c LineCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c LineCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/lseg.go b/pgtype/lseg.go index f5cf888e..97f130dc 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -34,7 +34,7 @@ func (lseg Lseg) LsegValue() (Lseg, error) { } // Scan implements the database/sql Scanner interface. -func (lseg *Lseg) Scan(src interface{}) error { +func (lseg *Lseg) Scan(src any) error { if src == nil { *lseg = Lseg{} return nil @@ -71,7 +71,7 @@ func (LsegCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (LsegCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (LsegCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(LsegValuer); !ok { return nil } @@ -88,7 +88,7 @@ func (LsegCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanLsegCodecBinary struct{} -func (encodePlanLsegCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanLsegCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { lseg, err := value.(LsegValuer).LsegValue() if err != nil { return nil, err @@ -107,7 +107,7 @@ func (encodePlanLsegCodecBinary) Encode(value interface{}, buf []byte) (newBuf [ type encodePlanLsegCodecText struct{} -func (encodePlanLsegCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanLsegCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { lseg, err := value.(LsegValuer).LsegValue() if err != nil { return nil, err @@ -126,7 +126,7 @@ func (encodePlanLsegCodecText) Encode(value interface{}, buf []byte) (newBuf []b return buf, nil } -func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -146,7 +146,7 @@ func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) type scanPlanBinaryLsegToLsegScanner struct{} -func (scanPlanBinaryLsegToLsegScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryLsegToLsegScanner) Scan(src []byte, dst any) error { scanner := (dst).(LsegScanner) if src == nil { @@ -173,7 +173,7 @@ func (scanPlanBinaryLsegToLsegScanner) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToLsegScanner struct{} -func (scanPlanTextAnyToLsegScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToLsegScanner) Scan(src []byte, dst any) error { scanner := (dst).(LsegScanner) if src == nil { @@ -224,7 +224,7 @@ func (c LsegCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c LsegCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c LsegCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index 686e759a..e913ec90 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -15,7 +15,7 @@ func (MacaddrCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (MacaddrCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (MacaddrCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -39,7 +39,7 @@ func (MacaddrCodec) PlanEncode(m *Map, oid uint32, format int16, value interface type encodePlanMacaddrCodecBinaryHardwareAddr struct{} -func (encodePlanMacaddrCodecBinaryHardwareAddr) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanMacaddrCodecBinaryHardwareAddr) Encode(value any, buf []byte) (newBuf []byte, err error) { addr := value.(net.HardwareAddr) if addr == nil { return nil, nil @@ -50,7 +50,7 @@ func (encodePlanMacaddrCodecBinaryHardwareAddr) Encode(value interface{}, buf [] type encodePlanMacAddrCodecTextValuer struct{} -func (encodePlanMacAddrCodecTextValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanMacAddrCodecTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { t, err := value.(TextValuer).TextValue() if err != nil { return nil, err @@ -69,7 +69,7 @@ func (encodePlanMacAddrCodecTextValuer) Encode(value interface{}, buf []byte) (n type encodePlanMacaddrCodecTextHardwareAddr struct{} -func (encodePlanMacaddrCodecTextHardwareAddr) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanMacaddrCodecTextHardwareAddr) Encode(value any, buf []byte) (newBuf []byte, err error) { addr := value.(net.HardwareAddr) if addr == nil { return nil, nil @@ -78,7 +78,7 @@ func (encodePlanMacaddrCodecTextHardwareAddr) Encode(value interface{}, buf []by return append(buf, addr.String()...), nil } -func (MacaddrCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (MacaddrCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { @@ -101,7 +101,7 @@ func (MacaddrCodec) PlanScan(m *Map, oid uint32, format int16, target interface{ type scanPlanBinaryMacaddrToHardwareAddr struct{} -func (scanPlanBinaryMacaddrToHardwareAddr) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryMacaddrToHardwareAddr) Scan(src []byte, dst any) error { dstBuf := dst.(*net.HardwareAddr) if src == nil { *dstBuf = nil @@ -115,7 +115,7 @@ func (scanPlanBinaryMacaddrToHardwareAddr) Scan(src []byte, dst interface{}) err type scanPlanBinaryMacaddrToTextScanner struct{} -func (scanPlanBinaryMacaddrToTextScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryMacaddrToTextScanner) Scan(src []byte, dst any) error { scanner := (dst).(TextScanner) if src == nil { return scanner.ScanText(Text{}) @@ -126,7 +126,7 @@ func (scanPlanBinaryMacaddrToTextScanner) Scan(src []byte, dst interface{}) erro type scanPlanTextMacaddrToHardwareAddr struct{} -func (scanPlanTextMacaddrToHardwareAddr) Scan(src []byte, dst interface{}) error { +func (scanPlanTextMacaddrToHardwareAddr) Scan(src []byte, dst any) error { p := dst.(*net.HardwareAddr) if src == nil { @@ -148,7 +148,7 @@ func (c MacaddrCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, s return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c MacaddrCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c MacaddrCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go index e2463271..ef6dae00 100644 --- a/pgtype/macaddr_test.go +++ b/pgtype/macaddr_test.go @@ -9,8 +9,8 @@ import ( "github.com/jackc/pgx/v5/pgxtest" ) -func isExpectedEqHardwareAddr(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEqHardwareAddr(a any) func(any) bool { + return func(v any) bool { aa := a.(net.HardwareAddr) vv := v.(net.HardwareAddr) diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 5ca7d077..a5f4ed3a 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -201,7 +201,7 @@ func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { } // Scan implements the database/sql Scanner interface. -func (n *Numeric) Scan(src interface{}) error { +func (n *Numeric) Scan(src any) error { if src == nil { *n = Numeric{} return nil @@ -281,7 +281,7 @@ func (NumericCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (NumericCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (NumericCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -308,7 +308,7 @@ func (NumericCodec) PlanEncode(m *Map, oid uint32, format int16, value interface type encodePlanNumericCodecBinaryNumericValuer struct{} -func (encodePlanNumericCodecBinaryNumericValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanNumericCodecBinaryNumericValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(NumericValuer).NumericValue() if err != nil { return nil, err @@ -319,7 +319,7 @@ func (encodePlanNumericCodecBinaryNumericValuer) Encode(value interface{}, buf [ type encodePlanNumericCodecBinaryFloat64Valuer struct{} -func (encodePlanNumericCodecBinaryFloat64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanNumericCodecBinaryFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Float64Valuer).Float64Value() if err != nil { return nil, err @@ -346,7 +346,7 @@ func (encodePlanNumericCodecBinaryFloat64Valuer) Encode(value interface{}, buf [ type encodePlanNumericCodecBinaryInt64Valuer struct{} -func (encodePlanNumericCodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanNumericCodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -460,7 +460,7 @@ func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) { type encodePlanNumericCodecTextNumericValuer struct{} -func (encodePlanNumericCodecTextNumericValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanNumericCodecTextNumericValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(NumericValuer).NumericValue() if err != nil { return nil, err @@ -471,7 +471,7 @@ func (encodePlanNumericCodecTextNumericValuer) Encode(value interface{}, buf []b type encodePlanNumericCodecTextFloat64Valuer struct{} -func (encodePlanNumericCodecTextFloat64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanNumericCodecTextFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Float64Valuer).Float64Value() if err != nil { return nil, err @@ -495,7 +495,7 @@ func (encodePlanNumericCodecTextFloat64Valuer) Encode(value interface{}, buf []b type encodePlanNumericCodecTextInt64Valuer struct{} -func (encodePlanNumericCodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanNumericCodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { n, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -530,7 +530,7 @@ func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { return buf, nil } -func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -560,7 +560,7 @@ func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target interface{ type scanPlanBinaryNumericToNumericScanner struct{} -func (scanPlanBinaryNumericToNumericScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryNumericToNumericScanner) Scan(src []byte, dst any) error { scanner := (dst).(NumericScanner) if src == nil { @@ -666,7 +666,7 @@ func (scanPlanBinaryNumericToNumericScanner) Scan(src []byte, dst interface{}) e type scanPlanBinaryNumericToFloat64Scanner struct{} -func (scanPlanBinaryNumericToFloat64Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryNumericToFloat64Scanner) Scan(src []byte, dst any) error { scanner := (dst).(Float64Scanner) if src == nil { @@ -690,7 +690,7 @@ func (scanPlanBinaryNumericToFloat64Scanner) Scan(src []byte, dst interface{}) e type scanPlanBinaryNumericToInt64Scanner struct{} -func (scanPlanBinaryNumericToInt64Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryNumericToInt64Scanner) Scan(src []byte, dst any) error { scanner := (dst).(Int64Scanner) if src == nil { @@ -718,7 +718,7 @@ func (scanPlanBinaryNumericToInt64Scanner) Scan(src []byte, dst interface{}) err type scanPlanBinaryNumericToTextScanner struct{} -func (scanPlanBinaryNumericToTextScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryNumericToTextScanner) Scan(src []byte, dst any) error { scanner := (dst).(TextScanner) if src == nil { @@ -742,7 +742,7 @@ func (scanPlanBinaryNumericToTextScanner) Scan(src []byte, dst interface{}) erro type scanPlanTextAnyToNumericScanner struct{} -func (scanPlanTextAnyToNumericScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToNumericScanner) Scan(src []byte, dst any) error { scanner := (dst).(NumericScanner) if src == nil { @@ -787,7 +787,7 @@ func (c NumericCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, s return string(buf), nil } -func (c NumericCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c NumericCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index d95deaa5..14225e92 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -24,8 +24,8 @@ func mustParseBigInt(t *testing.T, src string) *big.Int { return i } -func isExpectedEqNumeric(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEqNumeric(a any) func(any) bool { + return func(v any) bool { aa := a.(pgtype.Numeric) vv := v.(pgtype.Numeric) @@ -101,8 +101,8 @@ func TestNumericCodec(t *testing.T) { {pgtype.Numeric{NaN: true, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{NaN: true, Valid: true})}, {longestNumeric, new(pgtype.Numeric), isExpectedEqNumeric(longestNumeric)}, {mustParseNumeric(t, "1"), new(int64), isExpectedEq(int64(1))}, - {math.NaN(), new(float64), func(a interface{}) bool { return math.IsNaN(a.(float64)) }}, - {float32(math.NaN()), new(float32), func(a interface{}) bool { return math.IsNaN(float64(a.(float32))) }}, + {math.NaN(), new(float64), func(a any) bool { return math.IsNaN(a.(float64)) }}, + {float32(math.NaN()), new(float32), func(a any) bool { return math.IsNaN(float64(a.(float32))) }}, {int64(-1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "-1"))}, {int64(0), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0"))}, {int64(1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, diff --git a/pgtype/path.go b/pgtype/path.go index 10767404..73e0ec52 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -35,7 +35,7 @@ func (path Path) PathValue() (Path, error) { } // Scan implements the database/sql Scanner interface. -func (path *Path) Scan(src interface{}) error { +func (path *Path) Scan(src any) error { if src == nil { *path = Path{} return nil @@ -73,7 +73,7 @@ func (PathCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (PathCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (PathCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(PathValuer); !ok { return nil } @@ -90,7 +90,7 @@ func (PathCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanPathCodecBinary struct{} -func (encodePlanPathCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanPathCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { path, err := value.(PathValuer).PathValue() if err != nil { return nil, err @@ -118,7 +118,7 @@ func (encodePlanPathCodecBinary) Encode(value interface{}, buf []byte) (newBuf [ type encodePlanPathCodecText struct{} -func (encodePlanPathCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanPathCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { path, err := value.(PathValuer).PathValue() if err != nil { return nil, err @@ -153,7 +153,7 @@ func (encodePlanPathCodecText) Encode(value interface{}, buf []byte) (newBuf []b return buf, nil } -func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -173,7 +173,7 @@ func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) type scanPlanBinaryPathToPathScanner struct{} -func (scanPlanBinaryPathToPathScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryPathToPathScanner) Scan(src []byte, dst any) error { scanner := (dst).(PathScanner) if src == nil { @@ -211,7 +211,7 @@ func (scanPlanBinaryPathToPathScanner) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToPathScanner struct{} -func (scanPlanTextAnyToPathScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToPathScanner) Scan(src []byte, dst any) error { scanner := (dst).(PathScanner) if src == nil { @@ -258,7 +258,7 @@ func (c PathCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c PathCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c PathCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/path_test.go b/pgtype/path_test.go index f9e13294..cfffd22a 100644 --- a/pgtype/path_test.go +++ b/pgtype/path_test.go @@ -8,8 +8,8 @@ import ( "github.com/jackc/pgx/v5/pgxtest" ) -func isExpectedEqPath(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEqPath(a any) func(any) bool { + return func(v any) bool { ap := a.(pgtype.Path) vp := v.(pgtype.Path) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 94158cb8..ffdb7020 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -142,21 +142,21 @@ type Codec interface { // PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be // found then nil is returned. - PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan + PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan // PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If // no plan can be found then nil is returned. - PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan + PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan // DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface. DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) // DecodeValue returns src decoded into its default format. - DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) + DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) } type nullAssignmentError struct { - dst interface{} + dst any } func (e *nullAssignmentError) Error() string { @@ -316,7 +316,7 @@ func NewMap() *Map { m.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarcharOID]}}) m.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[XIDOID]}}) - registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { + registerDefaultPgTypeVariants := func(name, arrayName string, value any) { // T m.RegisterDefaultPgType(value, name) @@ -416,7 +416,7 @@ func (m *Map) RegisterType(t *Type) { // RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be // encoded or decoded is determined by the PostgreSQL OID. But if the OID of a value to be encoded or decoded is // unknown, this additional mapping will be used by TypeForValue to determine a suitable data type. -func (m *Map) RegisterDefaultPgType(value interface{}, name string) { +func (m *Map) RegisterDefaultPgType(value any, name string) { m.reflectTypeToName[reflect.TypeOf(value)] = name // Invalidated by type registration @@ -448,7 +448,7 @@ func (m *Map) buildReflectTypeToType() { // TypeForValue finds a data type suitable for v. Use RegisterType to register types that can encode and decode // themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. -func (m *Map) TypeForValue(v interface{}) (*Type, bool) { +func (m *Map) TypeForValue(v any) (*Type, bool) { if m.reflectTypeToType == nil { m.buildReflectTypeToType() } @@ -472,13 +472,13 @@ type EncodePlan interface { // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. - Encode(value interface{}, buf []byte) (newBuf []byte, err error) + Encode(value any, buf []byte) (newBuf []byte, err error) } // ScanPlan is a precompiled plan to scan into a type of destination. type ScanPlan interface { // Scan scans src into target. - Scan(src []byte, target interface{}) error + Scan(src []byte, target any) error } type scanPlanCodecSQLScanner struct { @@ -488,7 +488,7 @@ type scanPlanCodecSQLScanner struct { formatCode int16 } -func (plan *scanPlanCodecSQLScanner) Scan(src []byte, dst interface{}) error { +func (plan *scanPlanCodecSQLScanner) Scan(src []byte, dst any) error { value, err := plan.c.DecodeDatabaseSQLValue(plan.m, plan.oid, plan.formatCode, src) if err != nil { return err @@ -502,7 +502,7 @@ type scanPlanSQLScanner struct { formatCode int16 } -func (plan *scanPlanSQLScanner) Scan(src []byte, dst interface{}) error { +func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error { scanner := dst.(sql.Scanner) if src == nil { // This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the @@ -517,7 +517,7 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst interface{}) error { type scanPlanString struct{} -func (scanPlanString) Scan(src []byte, dst interface{}) error { +func (scanPlanString) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -529,7 +529,7 @@ func (scanPlanString) Scan(src []byte, dst interface{}) error { type scanPlanAnyTextToBytes struct{} -func (scanPlanAnyTextToBytes) Scan(src []byte, dst interface{}) error { +func (scanPlanAnyTextToBytes) Scan(src []byte, dst any) error { dstBuf := dst.(*[]byte) if src == nil { *dstBuf = nil @@ -546,7 +546,7 @@ type scanPlanFail struct { formatCode int16 } -func (plan *scanPlanFail) Scan(src []byte, dst interface{}) error { +func (plan *scanPlanFail) Scan(src []byte, dst any) error { var format string switch plan.formatCode { case TextFormatCode: @@ -564,7 +564,7 @@ func (plan *scanPlanFail) Scan(src []byte, dst interface{}) error { // that will convert the target passed to Scan and then call the next plan. nextTarget is target as it will be converted // by plan. It must be used to find another suitable ScanPlan. When it is found SetNext must be called on plan for it // to be usabled. ok indicates if a suitable wrapper was found. -type TryWrapScanPlanFunc func(target interface{}) (plan WrappedScanPlanNextSetter, nextTarget interface{}, ok bool) +type TryWrapScanPlanFunc func(target any) (plan WrappedScanPlanNextSetter, nextTarget any, ok bool) type pointerPointerScanPlan struct { dstType reflect.Type @@ -573,7 +573,7 @@ type pointerPointerScanPlan struct { func (plan *pointerPointerScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *pointerPointerScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *pointerPointerScanPlan) Scan(src []byte, dst any) error { el := reflect.ValueOf(dst).Elem() if src == nil { el.Set(reflect.Zero(el.Type())) @@ -586,7 +586,7 @@ func (plan *pointerPointerScanPlan) Scan(src []byte, dst interface{}) error { // TryPointerPointerScanPlan handles a pointer to a pointer by setting the target to nil for SQL NULL and allocating and // scanning for non-NULL. -func TryPointerPointerScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextTarget interface{}, ok bool) { +func TryPointerPointerScanPlan(target any) (plan WrappedScanPlanNextSetter, nextTarget any, ok bool) { if dstValue := reflect.ValueOf(target); dstValue.Kind() == reflect.Ptr { elemValue := dstValue.Elem() if elemValue.Kind() == reflect.Ptr { @@ -627,13 +627,13 @@ type underlyingTypeScanPlan struct { func (plan *underlyingTypeScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *underlyingTypeScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *underlyingTypeScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, reflect.ValueOf(dst).Convert(plan.nextDstType).Interface()) } // TryFindUnderlyingTypeScanPlan tries to convert to a Go builtin type. e.g. If value was of type MyString and // MyString was defined as a string then a wrapper plan would be returned that converts MyString to string. -func TryFindUnderlyingTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) { +func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nextDst any, ok bool) { if _, ok := dst.(SkipUnderlyingTypePlanner); ok { return nil, nil, false } @@ -667,7 +667,7 @@ type WrappedScanPlanNextSetter interface { // TryWrapBuiltinTypeScanPlan tries to wrap a builtin type with a wrapper that provides additional methods. e.g. If // value was of type int32 then a wrapper plan would be returned that converts target to a value that implements // Int64Scanner. -func TryWrapBuiltinTypeScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) { +func TryWrapBuiltinTypeScanPlan(target any) (plan WrappedScanPlanNextSetter, nextDst any, ok bool) { switch target := target.(type) { case *int8: return &wrapInt8ScanPlan{}, (*int8Wrapper)(target), true @@ -722,7 +722,7 @@ type wrapInt8ScanPlan struct { func (plan *wrapInt8ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapInt8ScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapInt8ScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*int8Wrapper)(dst.(*int8))) } @@ -732,7 +732,7 @@ type wrapInt16ScanPlan struct { func (plan *wrapInt16ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapInt16ScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapInt16ScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*int16Wrapper)(dst.(*int16))) } @@ -742,7 +742,7 @@ type wrapInt32ScanPlan struct { func (plan *wrapInt32ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapInt32ScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapInt32ScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*int32Wrapper)(dst.(*int32))) } @@ -752,7 +752,7 @@ type wrapInt64ScanPlan struct { func (plan *wrapInt64ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapInt64ScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapInt64ScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*int64Wrapper)(dst.(*int64))) } @@ -762,7 +762,7 @@ type wrapIntScanPlan struct { func (plan *wrapIntScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapIntScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapIntScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*intWrapper)(dst.(*int))) } @@ -772,7 +772,7 @@ type wrapUint8ScanPlan struct { func (plan *wrapUint8ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapUint8ScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapUint8ScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*uint8Wrapper)(dst.(*uint8))) } @@ -782,7 +782,7 @@ type wrapUint16ScanPlan struct { func (plan *wrapUint16ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapUint16ScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapUint16ScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*uint16Wrapper)(dst.(*uint16))) } @@ -792,7 +792,7 @@ type wrapUint32ScanPlan struct { func (plan *wrapUint32ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapUint32ScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapUint32ScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*uint32Wrapper)(dst.(*uint32))) } @@ -802,7 +802,7 @@ type wrapUint64ScanPlan struct { func (plan *wrapUint64ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapUint64ScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapUint64ScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*uint64Wrapper)(dst.(*uint64))) } @@ -812,7 +812,7 @@ type wrapUintScanPlan struct { func (plan *wrapUintScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapUintScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapUintScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*uintWrapper)(dst.(*uint))) } @@ -822,7 +822,7 @@ type wrapFloat32ScanPlan struct { func (plan *wrapFloat32ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapFloat32ScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapFloat32ScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*float32Wrapper)(dst.(*float32))) } @@ -832,7 +832,7 @@ type wrapFloat64ScanPlan struct { func (plan *wrapFloat64ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapFloat64ScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapFloat64ScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*float64Wrapper)(dst.(*float64))) } @@ -842,7 +842,7 @@ type wrapStringScanPlan struct { func (plan *wrapStringScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapStringScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapStringScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*stringWrapper)(dst.(*string))) } @@ -852,7 +852,7 @@ type wrapTimeScanPlan struct { func (plan *wrapTimeScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapTimeScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapTimeScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*timeWrapper)(dst.(*time.Time))) } @@ -862,7 +862,7 @@ type wrapDurationScanPlan struct { func (plan *wrapDurationScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapDurationScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapDurationScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*durationWrapper)(dst.(*time.Duration))) } @@ -872,7 +872,7 @@ type wrapNetIPNetScanPlan struct { func (plan *wrapNetIPNetScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapNetIPNetScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapNetIPNetScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*netIPNetWrapper)(dst.(*net.IPNet))) } @@ -882,7 +882,7 @@ type wrapNetIPScanPlan struct { func (plan *wrapNetIPScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapNetIPScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapNetIPScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*netIPWrapper)(dst.(*net.IP))) } @@ -892,7 +892,7 @@ type wrapMapStringToPointerStringScanPlan struct { func (plan *wrapMapStringToPointerStringScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapMapStringToPointerStringScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapMapStringToPointerStringScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*mapStringToPointerStringWrapper)(dst.(*map[string]*string))) } @@ -902,7 +902,7 @@ type wrapMapStringToStringScanPlan struct { func (plan *wrapMapStringToStringScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapMapStringToStringScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapMapStringToStringScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*mapStringToStringWrapper)(dst.(*map[string]string))) } @@ -912,7 +912,7 @@ type wrapByte16ScanPlan struct { func (plan *wrapByte16ScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapByte16ScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapByte16ScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*byte16Wrapper)(dst.(*[16]byte))) } @@ -922,7 +922,7 @@ type wrapByteSliceScanPlan struct { func (plan *wrapByteSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapByteSliceScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *wrapByteSliceScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*byteSliceWrapper)(dst.(*[]byte))) } @@ -933,20 +933,20 @@ type pointerEmptyInterfaceScanPlan struct { formatCode int16 } -func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst interface{}) error { +func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst any) error { value, err := plan.codec.DecodeValue(plan.m, plan.oid, plan.formatCode, src) if err != nil { return err } - ptrAny := dst.(*interface{}) + ptrAny := dst.(*any) *ptrAny = value return nil } // TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. -func TryWrapStructScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) { +func TryWrapStructScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { targetValue := reflect.ValueOf(target) if targetValue.Kind() != reflect.Ptr { return nil, nil, false @@ -982,7 +982,7 @@ type wrapAnyPtrStructScanPlan struct { func (plan *wrapAnyPtrStructScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target interface{}) error { +func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target any) error { w := ptrStructWrapper{ s: target, exportedFields: getExportedFieldValues(reflect.ValueOf(target).Elem()), @@ -992,7 +992,7 @@ func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target interface{}) error } // TryWrapPtrSliceScanPlan tries to wrap a pointer to a single dimension slice. -func TryWrapPtrSliceScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) { +func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { targetValue := reflect.ValueOf(target) if targetValue.Kind() != reflect.Ptr { return nil, nil, false @@ -1012,12 +1012,12 @@ type wrapPtrSliceScanPlan struct { func (plan *wrapPtrSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapPtrSliceScanPlan) Scan(src []byte, target interface{}) error { +func (plan *wrapPtrSliceScanPlan) Scan(src []byte, target any) error { return plan.next.Scan(src, &anySliceArray{slice: reflect.ValueOf(target).Elem()}) } // TryWrapPtrMultiDimSliceScanPlan tries to wrap a pointer to a multi-dimension slice. -func TryWrapPtrMultiDimSliceScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) { +func TryWrapPtrMultiDimSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { targetValue := reflect.ValueOf(target) if targetValue.Kind() != reflect.Ptr { return nil, nil, false @@ -1043,12 +1043,12 @@ type wrapPtrMultiDimSliceScanPlan struct { func (plan *wrapPtrMultiDimSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapPtrMultiDimSliceScanPlan) Scan(src []byte, target interface{}) error { +func (plan *wrapPtrMultiDimSliceScanPlan) Scan(src []byte, target any) error { return plan.next.Scan(src, &anyMultiDimSliceArray{slice: reflect.ValueOf(target).Elem()}) } // PlanScan prepares a plan to scan a value into target. -func (m *Map) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan { +func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { oidMemo := m.memoizedScanPlans[oid] if oidMemo == nil { oidMemo = make(map[reflect.Type][2]ScanPlan) @@ -1066,7 +1066,7 @@ func (m *Map) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPla return plan } -func (m *Map) planScan(oid uint32, formatCode int16, target interface{}) ScanPlan { +func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { if _, ok := target.(*UndecodedBytes); ok { return scanPlanAnyToUndecodedBytes{} } @@ -1120,7 +1120,7 @@ func (m *Map) planScan(oid uint32, formatCode int16, target interface{}) ScanPla } if dt != nil { - if _, ok := target.(*interface{}); ok { + if _, ok := target.(*any); ok { return &pointerEmptyInterfaceScanPlan{codec: dt.Codec, m: m, oid: oid, formatCode: formatCode} } @@ -1136,7 +1136,7 @@ func (m *Map) planScan(oid uint32, formatCode int16, target interface{}) ScanPla return &scanPlanFail{oid: oid, formatCode: formatCode} } -func (m *Map) Scan(oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (m *Map) Scan(oid uint32, formatCode int16, src []byte, dst any) error { if dst == nil { return nil } @@ -1145,7 +1145,7 @@ func (m *Map) Scan(oid uint32, formatCode int16, src []byte, dst interface{}) er return plan.Scan(src, dst) } -func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) error { +func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest any) error { switch dest := dest.(type) { case *string: if formatCode == BinaryFormatCode { @@ -1166,7 +1166,7 @@ func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) var ErrScanTargetTypeChanged = errors.New("scan target type changed") -func codecScan(codec Codec, m *Map, oid uint32, format int16, src []byte, dst interface{}) error { +func codecScan(codec Codec, m *Map, oid uint32, format int16, src []byte, dst any) error { scanPlan := codec.PlanScan(m, oid, format, dst) if scanPlan == nil { return fmt.Errorf("PlanScan did not find a plan") @@ -1196,7 +1196,7 @@ func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src // PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be // found then nil is returned. -func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan { +func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { if format == TextFormatCode { switch value.(type) { case string: @@ -1239,14 +1239,14 @@ func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan type encodePlanStringToAnyTextFormat struct{} -func (encodePlanStringToAnyTextFormat) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanStringToAnyTextFormat) Encode(value any, buf []byte) (newBuf []byte, err error) { s := value.(string) return append(buf, s...), nil } type encodePlanTextValuerToAnyTextFormat struct{} -func (encodePlanTextValuerToAnyTextFormat) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTextValuerToAnyTextFormat) Encode(value any, buf []byte) (newBuf []byte, err error) { t, err := value.(TextValuer).TextValue() if err != nil { return nil, err @@ -1262,7 +1262,7 @@ func (encodePlanTextValuerToAnyTextFormat) Encode(value interface{}, buf []byte) // that will convert the value passed to Encode and then call the next plan. nextValue is value as it will be converted // by plan. It must be used to find another suitable EncodePlan. When it is found SetNext must be called on plan for it // to be usabled. ok indicates if a suitable wrapper was found. -type TryWrapEncodePlanFunc func(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) +type TryWrapEncodePlanFunc func(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) type derefPointerEncodePlan struct { next EncodePlan @@ -1270,7 +1270,7 @@ type derefPointerEncodePlan struct { func (plan *derefPointerEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *derefPointerEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *derefPointerEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { ptr := reflect.ValueOf(value) if ptr.IsNil() { @@ -1282,7 +1282,7 @@ func (plan *derefPointerEncodePlan) Encode(value interface{}, buf []byte) (newBu // TryWrapDerefPointerEncodePlan tries to dereference a pointer. e.g. If value was of type *string then a wrapper plan // would be returned that derefences the value. -func TryWrapDerefPointerEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { +func TryWrapDerefPointerEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { if valueType := reflect.TypeOf(value); valueType.Kind() == reflect.Ptr { return &derefPointerEncodePlan{}, reflect.New(valueType.Elem()).Elem().Interface(), true } @@ -1313,13 +1313,13 @@ type underlyingTypeEncodePlan struct { func (plan *underlyingTypeEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *underlyingTypeEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *underlyingTypeEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(reflect.ValueOf(value).Convert(plan.nextValueType).Interface(), buf) } // TryWrapFindUnderlyingTypeEncodePlan tries to convert to a Go builtin type. e.g. If value was of type MyString and // MyString was defined as a string then a wrapper plan would be returned that converts MyString to string. -func TryWrapFindUnderlyingTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { +func TryWrapFindUnderlyingTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { if _, ok := value.(SkipUnderlyingTypePlanner); ok { return nil, nil, false } @@ -1342,7 +1342,7 @@ type WrappedEncodePlanNextSetter interface { // TryWrapBuiltinTypeEncodePlan tries to wrap a builtin type with a wrapper that provides additional methods. e.g. If // value was of type int32 then a wrapper plan would be returned that converts value to a type that implements // Int64Valuer. -func TryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { +func TryWrapBuiltinTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { switch value := value.(type) { case int8: return &wrapInt8EncodePlan{}, int8Wrapper(value), true @@ -1399,7 +1399,7 @@ type wrapInt8EncodePlan struct { func (plan *wrapInt8EncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapInt8EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapInt8EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(int8Wrapper(value.(int8)), buf) } @@ -1409,7 +1409,7 @@ type wrapInt16EncodePlan struct { func (plan *wrapInt16EncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapInt16EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapInt16EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(int16Wrapper(value.(int16)), buf) } @@ -1419,7 +1419,7 @@ type wrapInt32EncodePlan struct { func (plan *wrapInt32EncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapInt32EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapInt32EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(int32Wrapper(value.(int32)), buf) } @@ -1429,7 +1429,7 @@ type wrapInt64EncodePlan struct { func (plan *wrapInt64EncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapInt64EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapInt64EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(int64Wrapper(value.(int64)), buf) } @@ -1439,7 +1439,7 @@ type wrapIntEncodePlan struct { func (plan *wrapIntEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapIntEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapIntEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(intWrapper(value.(int)), buf) } @@ -1449,7 +1449,7 @@ type wrapUint8EncodePlan struct { func (plan *wrapUint8EncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapUint8EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapUint8EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(uint8Wrapper(value.(uint8)), buf) } @@ -1459,7 +1459,7 @@ type wrapUint16EncodePlan struct { func (plan *wrapUint16EncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapUint16EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapUint16EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(uint16Wrapper(value.(uint16)), buf) } @@ -1469,7 +1469,7 @@ type wrapUint32EncodePlan struct { func (plan *wrapUint32EncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapUint32EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapUint32EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(uint32Wrapper(value.(uint32)), buf) } @@ -1479,7 +1479,7 @@ type wrapUint64EncodePlan struct { func (plan *wrapUint64EncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapUint64EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapUint64EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(uint64Wrapper(value.(uint64)), buf) } @@ -1489,7 +1489,7 @@ type wrapUintEncodePlan struct { func (plan *wrapUintEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapUintEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapUintEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(uintWrapper(value.(uint)), buf) } @@ -1499,7 +1499,7 @@ type wrapFloat32EncodePlan struct { func (plan *wrapFloat32EncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapFloat32EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapFloat32EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(float32Wrapper(value.(float32)), buf) } @@ -1509,7 +1509,7 @@ type wrapFloat64EncodePlan struct { func (plan *wrapFloat64EncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapFloat64EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapFloat64EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(float64Wrapper(value.(float64)), buf) } @@ -1519,7 +1519,7 @@ type wrapStringEncodePlan struct { func (plan *wrapStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapStringEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapStringEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(stringWrapper(value.(string)), buf) } @@ -1529,7 +1529,7 @@ type wrapTimeEncodePlan struct { func (plan *wrapTimeEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapTimeEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapTimeEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(timeWrapper(value.(time.Time)), buf) } @@ -1539,7 +1539,7 @@ type wrapDurationEncodePlan struct { func (plan *wrapDurationEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapDurationEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapDurationEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(durationWrapper(value.(time.Duration)), buf) } @@ -1549,7 +1549,7 @@ type wrapNetIPNetEncodePlan struct { func (plan *wrapNetIPNetEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapNetIPNetEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapNetIPNetEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(netIPNetWrapper(value.(net.IPNet)), buf) } @@ -1559,7 +1559,7 @@ type wrapNetIPEncodePlan struct { func (plan *wrapNetIPEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapNetIPEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapNetIPEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(netIPWrapper(value.(net.IP)), buf) } @@ -1569,7 +1569,7 @@ type wrapMapStringToPointerStringEncodePlan struct { func (plan *wrapMapStringToPointerStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapMapStringToPointerStringEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapMapStringToPointerStringEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(mapStringToPointerStringWrapper(value.(map[string]*string)), buf) } @@ -1579,7 +1579,7 @@ type wrapMapStringToStringEncodePlan struct { func (plan *wrapMapStringToStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapMapStringToStringEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapMapStringToStringEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(mapStringToStringWrapper(value.(map[string]string)), buf) } @@ -1589,7 +1589,7 @@ type wrapByte16EncodePlan struct { func (plan *wrapByte16EncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapByte16EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapByte16EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(byte16Wrapper(value.([16]byte)), buf) } @@ -1599,7 +1599,7 @@ type wrapByteSliceEncodePlan struct { func (plan *wrapByteSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapByteSliceEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapByteSliceEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(byteSliceWrapper(value.([]byte)), buf) } @@ -1609,12 +1609,12 @@ type wrapFmtStringerEncodePlan struct { func (plan *wrapFmtStringerEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapFmtStringerEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapFmtStringerEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(fmtStringerWrapper{value.(fmt.Stringer)}, buf) } // TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. -func TryWrapStructEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { +func TryWrapStructEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { if reflect.TypeOf(value).Kind() == reflect.Struct { exportedFields := getExportedFieldValues(reflect.ValueOf(value)) if len(exportedFields) == 0 { @@ -1637,7 +1637,7 @@ type wrapAnyStructEncodePlan struct { func (plan *wrapAnyStructEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapAnyStructEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapAnyStructEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { w := structWrapper{ s: value, exportedFields: getExportedFieldValues(reflect.ValueOf(value)), @@ -1659,7 +1659,7 @@ func getExportedFieldValues(structValue reflect.Value) []reflect.Value { return exportedFields } -func TryWrapSliceEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { +func TryWrapSliceEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { if reflect.TypeOf(value).Kind() == reflect.Slice { w := anySliceArray{ slice: reflect.ValueOf(value), @@ -1676,7 +1676,7 @@ type wrapSliceEncodePlan struct { func (plan *wrapSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapSliceEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapSliceEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { w := anySliceArray{ slice: reflect.ValueOf(value), } @@ -1684,7 +1684,7 @@ func (plan *wrapSliceEncodePlan) Encode(value interface{}, buf []byte) (newBuf [ return plan.next.Encode(w, buf) } -func TryWrapMultiDimSliceEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { +func TryWrapMultiDimSliceEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { sliceValue := reflect.ValueOf(value) if sliceValue.Kind() == reflect.Slice { valueElemType := sliceValue.Type().Elem() @@ -1708,7 +1708,7 @@ type wrapMultiDimSliceEncodePlan struct { func (plan *wrapMultiDimSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapMultiDimSliceEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *wrapMultiDimSliceEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { w := anyMultiDimSliceArray{ slice: reflect.ValueOf(value), } @@ -1719,7 +1719,7 @@ func (plan *wrapMultiDimSliceEncodePlan) Encode(value interface{}, buf []byte) ( // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. -func (m *Map) Encode(oid uint32, formatCode int16, value interface{}, buf []byte) (newBuf []byte, err error) { +func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBuf []byte, err error) { if value == nil { return nil, nil } diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 9778c335..fa4c823b 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -125,7 +125,7 @@ func TestTypeMapScanNilIsNoOp(t *testing.T) { func TestTypeMapScanTextFormatInterfacePtr(t *testing.T) { m := pgtype.NewMap() - var got interface{} + var got any err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &got) require.NoError(t, err) assert.Equal(t, "foo", got) @@ -141,7 +141,7 @@ func TestTypeMapScanTextFormatNonByteaIntoByteSlice(t *testing.T) { func TestTypeMapScanBinaryFormatInterfacePtr(t *testing.T) { m := pgtype.NewMap() - var got interface{} + var got any err := m.Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &got) require.NoError(t, err) assert.Equal(t, "foo", got) @@ -273,8 +273,8 @@ func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { } } -func isExpectedEq(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEq(a any) func(any) bool { + return func(v any) bool { return a == v } } diff --git a/pgtype/point.go b/pgtype/point.go index d2ddaf2f..cfa5a9f1 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -69,7 +69,7 @@ func parsePoint(src []byte) (*Point, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Point) Scan(src interface{}) error { +func (dst *Point) Scan(src any) error { if src == nil { *dst = Point{} return nil @@ -127,7 +127,7 @@ func (PointCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (PointCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (PointCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(PointValuer); !ok { return nil } @@ -144,7 +144,7 @@ func (PointCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{} type encodePlanPointCodecBinary struct{} -func (encodePlanPointCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanPointCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { point, err := value.(PointValuer).PointValue() if err != nil { return nil, err @@ -161,7 +161,7 @@ func (encodePlanPointCodecBinary) Encode(value interface{}, buf []byte) (newBuf type encodePlanPointCodecText struct{} -func (encodePlanPointCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanPointCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { point, err := value.(PointValuer).PointValue() if err != nil { return nil, err @@ -177,7 +177,7 @@ func (encodePlanPointCodecText) Encode(value interface{}, buf []byte) (newBuf [] )...), nil } -func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -199,7 +199,7 @@ func (c PointCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c PointCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c PointCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -214,7 +214,7 @@ func (c PointCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (i type scanPlanBinaryPointToPointScanner struct{} -func (scanPlanBinaryPointToPointScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryPointToPointScanner) Scan(src []byte, dst any) error { scanner := (dst).(PointScanner) if src == nil { @@ -236,7 +236,7 @@ func (scanPlanBinaryPointToPointScanner) Scan(src []byte, dst interface{}) error type scanPlanTextAnyToPointScanner struct{} -func (scanPlanTextAnyToPointScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToPointScanner) Scan(src []byte, dst any) error { scanner := (dst).(PointScanner) if src == nil { diff --git a/pgtype/polygon.go b/pgtype/polygon.go index a7a6d606..04b0ba6b 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -34,7 +34,7 @@ func (p Polygon) PolygonValue() (Polygon, error) { } // Scan implements the database/sql Scanner interface. -func (p *Polygon) Scan(src interface{}) error { +func (p *Polygon) Scan(src any) error { if src == nil { *p = Polygon{} return nil @@ -72,7 +72,7 @@ func (PolygonCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (PolygonCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (PolygonCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(PolygonValuer); !ok { return nil } @@ -89,7 +89,7 @@ func (PolygonCodec) PlanEncode(m *Map, oid uint32, format int16, value interface type encodePlanPolygonCodecBinary struct{} -func (encodePlanPolygonCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanPolygonCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { polygon, err := value.(PolygonValuer).PolygonValue() if err != nil { return nil, err @@ -111,7 +111,7 @@ func (encodePlanPolygonCodecBinary) Encode(value interface{}, buf []byte) (newBu type encodePlanPolygonCodecText struct{} -func (encodePlanPolygonCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanPolygonCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { polygon, err := value.(PolygonValuer).PolygonValue() if err != nil { return nil, err @@ -138,7 +138,7 @@ func (encodePlanPolygonCodecText) Encode(value interface{}, buf []byte) (newBuf return buf, nil } -func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -158,7 +158,7 @@ func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target interface{ type scanPlanBinaryPolygonToPolygonScanner struct{} -func (scanPlanBinaryPolygonToPolygonScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryPolygonToPolygonScanner) Scan(src []byte, dst any) error { scanner := (dst).(PolygonScanner) if src == nil { @@ -193,7 +193,7 @@ func (scanPlanBinaryPolygonToPolygonScanner) Scan(src []byte, dst interface{}) e type scanPlanTextAnyToPolygonScanner struct{} -func (scanPlanTextAnyToPolygonScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToPolygonScanner) Scan(src []byte, dst any) error { scanner := (dst).(PolygonScanner) if src == nil { @@ -239,7 +239,7 @@ func (c PolygonCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, s return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c PolygonCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c PolygonCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/polygon_test.go b/pgtype/polygon_test.go index a6a60de2..5ddbc166 100644 --- a/pgtype/polygon_test.go +++ b/pgtype/polygon_test.go @@ -8,8 +8,8 @@ import ( "github.com/jackc/pgx/v5/pgxtest" ) -func isExpectedEqPolygon(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEqPolygon(a any) func(any) bool { + return func(v any) bool { ap := a.(pgtype.Polygon) vp := v.(pgtype.Polygon) diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 677b9003..0e65041f 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -22,7 +22,7 @@ func (QCharCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (QCharCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (QCharCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case TextFormatCode, BinaryFormatCode: switch value.(type) { @@ -38,7 +38,7 @@ func (QCharCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{} type encodePlanQcharCodecByte struct{} -func (encodePlanQcharCodecByte) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanQcharCodecByte) Encode(value any, buf []byte) (newBuf []byte, err error) { b := value.(byte) buf = append(buf, b) return buf, nil @@ -46,7 +46,7 @@ func (encodePlanQcharCodecByte) Encode(value interface{}, buf []byte) (newBuf [] type encodePlanQcharCodecRune struct{} -func (encodePlanQcharCodecRune) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanQcharCodecRune) Encode(value any, buf []byte) (newBuf []byte, err error) { r := value.(rune) if r > math.MaxUint8 { return nil, fmt.Errorf(`%v cannot be encoded to "char"`, r) @@ -56,7 +56,7 @@ func (encodePlanQcharCodecRune) Encode(value interface{}, buf []byte) (newBuf [] return buf, nil } -func (QCharCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (QCharCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case TextFormatCode, BinaryFormatCode: switch target.(type) { @@ -72,7 +72,7 @@ func (QCharCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) type scanPlanQcharCodecByte struct{} -func (scanPlanQcharCodecByte) Scan(src []byte, dst interface{}) error { +func (scanPlanQcharCodecByte) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -94,7 +94,7 @@ func (scanPlanQcharCodecByte) Scan(src []byte, dst interface{}) error { type scanPlanQcharCodecRune struct{} -func (scanPlanQcharCodecRune) Scan(src []byte, dst interface{}) error { +func (scanPlanQcharCodecRune) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -127,7 +127,7 @@ func (c QCharCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return string(r), nil } -func (c QCharCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c QCharCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go index 207e3b39..6d62e7ff 100644 --- a/pgtype/range_codec.go +++ b/pgtype/range_codec.go @@ -16,7 +16,7 @@ type RangeValuer interface { BoundTypes() (lower, upper BoundType) // Bounds returns the lower and upper range values. - Bounds() (lower, upper interface{}) + Bounds() (lower, upper any) } // RangeScanner is a type can be scanned from a PostgreSQL range. @@ -26,7 +26,7 @@ type RangeScanner interface { // ScanBounds returns values usable as a scan target. The returned values may not be scanned if the range is empty or // the bound type is unbounded. - ScanBounds() (lowerTarget, upperTarget interface{}) + ScanBounds() (lowerTarget, upperTarget any) // SetBoundTypes sets the lower and upper bound types. ScanBounds will be called and the returned values scanned // (if appropriate) before SetBoundTypes is called. If the bound types are unbounded or empty this method must @@ -35,8 +35,8 @@ type RangeScanner interface { } type GenericRange struct { - Lower interface{} - Upper interface{} + Lower any + Upper any LowerType BoundType UpperType BoundType Valid bool @@ -50,7 +50,7 @@ func (r GenericRange) BoundTypes() (lower, upper BoundType) { return r.LowerType, r.UpperType } -func (r GenericRange) Bounds() (lower, upper interface{}) { +func (r GenericRange) Bounds() (lower, upper any) { return &r.Lower, &r.Upper } @@ -59,7 +59,7 @@ func (r *GenericRange) ScanNull() error { return nil } -func (r *GenericRange) ScanBounds() (lowerTarget, upperTarget interface{}) { +func (r *GenericRange) ScanBounds() (lowerTarget, upperTarget any) { return &r.Lower, &r.Upper } @@ -86,7 +86,7 @@ func (c *RangeCodec) PreferredFormat() int16 { return TextFormatCode } -func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(RangeValuer); !ok { return nil } @@ -106,7 +106,7 @@ type encodePlanRangeCodecRangeValuerToBinary struct { m *Map } -func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { getter := value.(RangeValuer) if getter.IsNull() { @@ -197,7 +197,7 @@ type encodePlanRangeCodecRangeValuerToText struct { m *Map } -func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) { getter := value.(RangeValuer) if getter.IsNull() { @@ -270,7 +270,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value interface{}, buf return buf, nil } -func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { @@ -292,7 +292,7 @@ type scanPlanBinaryRangeToRangeScanner struct { m *Map } -func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target interface{}) error { +func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) error { rangeScanner := (target).(RangeScanner) if src == nil { @@ -342,7 +342,7 @@ type scanPlanTextRangeToRangeScanner struct { m *Map } -func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target interface{}) error { +func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error { rangeScanner := (target).(RangeScanner) if src == nil { @@ -404,7 +404,7 @@ func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr } } -func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index 8c1116a0..d467b750 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -132,7 +132,7 @@ func TestRangeCodecDecodeValue(t *testing.T) { for _, tt := range []struct { sql string - expected interface{} + expected any }{ { sql: `select '[1,5)'::int4range`, diff --git a/pgtype/range_types.go b/pgtype/range_types.go index 1496ca30..c101fbdc 100644 --- a/pgtype/range_types.go +++ b/pgtype/range_types.go @@ -17,7 +17,7 @@ func (r Int4range) BoundTypes() (lower, upper BoundType) { return r.LowerType, r.UpperType } -func (r Int4range) Bounds() (lower, upper interface{}) { +func (r Int4range) Bounds() (lower, upper any) { return &r.Lower, &r.Upper } @@ -26,7 +26,7 @@ func (r *Int4range) ScanNull() error { return nil } -func (r *Int4range) ScanBounds() (lowerTarget, upperTarget interface{}) { +func (r *Int4range) ScanBounds() (lowerTarget, upperTarget any) { return &r.Lower, &r.Upper } @@ -59,7 +59,7 @@ func (r Int8range) BoundTypes() (lower, upper BoundType) { return r.LowerType, r.UpperType } -func (r Int8range) Bounds() (lower, upper interface{}) { +func (r Int8range) Bounds() (lower, upper any) { return &r.Lower, &r.Upper } @@ -68,7 +68,7 @@ func (r *Int8range) ScanNull() error { return nil } -func (r *Int8range) ScanBounds() (lowerTarget, upperTarget interface{}) { +func (r *Int8range) ScanBounds() (lowerTarget, upperTarget any) { return &r.Lower, &r.Upper } @@ -101,7 +101,7 @@ func (r Numrange) BoundTypes() (lower, upper BoundType) { return r.LowerType, r.UpperType } -func (r Numrange) Bounds() (lower, upper interface{}) { +func (r Numrange) Bounds() (lower, upper any) { return &r.Lower, &r.Upper } @@ -110,7 +110,7 @@ func (r *Numrange) ScanNull() error { return nil } -func (r *Numrange) ScanBounds() (lowerTarget, upperTarget interface{}) { +func (r *Numrange) ScanBounds() (lowerTarget, upperTarget any) { return &r.Lower, &r.Upper } @@ -143,7 +143,7 @@ func (r Tsrange) BoundTypes() (lower, upper BoundType) { return r.LowerType, r.UpperType } -func (r Tsrange) Bounds() (lower, upper interface{}) { +func (r Tsrange) Bounds() (lower, upper any) { return &r.Lower, &r.Upper } @@ -152,7 +152,7 @@ func (r *Tsrange) ScanNull() error { return nil } -func (r *Tsrange) ScanBounds() (lowerTarget, upperTarget interface{}) { +func (r *Tsrange) ScanBounds() (lowerTarget, upperTarget any) { return &r.Lower, &r.Upper } @@ -185,7 +185,7 @@ func (r Tstzrange) BoundTypes() (lower, upper BoundType) { return r.LowerType, r.UpperType } -func (r Tstzrange) Bounds() (lower, upper interface{}) { +func (r Tstzrange) Bounds() (lower, upper any) { return &r.Lower, &r.Upper } @@ -194,7 +194,7 @@ func (r *Tstzrange) ScanNull() error { return nil } -func (r *Tstzrange) ScanBounds() (lowerTarget, upperTarget interface{}) { +func (r *Tstzrange) ScanBounds() (lowerTarget, upperTarget any) { return &r.Lower, &r.Upper } @@ -227,7 +227,7 @@ func (r Daterange) BoundTypes() (lower, upper BoundType) { return r.LowerType, r.UpperType } -func (r Daterange) Bounds() (lower, upper interface{}) { +func (r Daterange) Bounds() (lower, upper any) { return &r.Lower, &r.Upper } @@ -236,7 +236,7 @@ func (r *Daterange) ScanNull() error { return nil } -func (r *Daterange) ScanBounds() (lowerTarget, upperTarget interface{}) { +func (r *Daterange) ScanBounds() (lowerTarget, upperTarget any) { return &r.Lower, &r.Upper } @@ -269,7 +269,7 @@ func (r Float8range) BoundTypes() (lower, upper BoundType) { return r.LowerType, r.UpperType } -func (r Float8range) Bounds() (lower, upper interface{}) { +func (r Float8range) Bounds() (lower, upper any) { return &r.Lower, &r.Upper } @@ -278,7 +278,7 @@ func (r *Float8range) ScanNull() error { return nil } -func (r *Float8range) ScanBounds() (lowerTarget, upperTarget interface{}) { +func (r *Float8range) ScanBounds() (lowerTarget, upperTarget any) { return &r.Lower, &r.Upper } diff --git a/pgtype/range_types.go.erb b/pgtype/range_types.go.erb index 8b43f7f9..d181548c 100644 --- a/pgtype/range_types.go.erb +++ b/pgtype/range_types.go.erb @@ -27,7 +27,7 @@ func (r <%= range_type %>) BoundTypes() (lower, upper BoundType) { return r.LowerType, r.UpperType } -func (r <%= range_type %>) Bounds() (lower, upper interface{}) { +func (r <%= range_type %>) Bounds() (lower, upper any) { return &r.Lower, &r.Upper } @@ -36,7 +36,7 @@ func (r *<%= range_type %>) ScanNull() error { return nil } -func (r *<%= range_type %>) ScanBounds() (lowerTarget, upperTarget interface{}) { +func (r *<%= range_type %>) ScanBounds() (lowerTarget, upperTarget any) { return &r.Lower, &r.Upper } diff --git a/pgtype/record_codec.go b/pgtype/record_codec.go index a5c72aac..b3b16604 100644 --- a/pgtype/record_codec.go +++ b/pgtype/record_codec.go @@ -21,11 +21,11 @@ func (RecordCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (RecordCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (RecordCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { return nil } -func (RecordCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (RecordCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { if format == BinaryFormatCode { switch target.(type) { case CompositeIndexScanner: @@ -40,7 +40,7 @@ type scanPlanBinaryRecordToCompositeIndexScanner struct { m *Map } -func (plan *scanPlanBinaryRecordToCompositeIndexScanner) Scan(src []byte, target interface{}) error { +func (plan *scanPlanBinaryRecordToCompositeIndexScanner) Scan(src []byte, target any) error { targetScanner := (target).(CompositeIndexScanner) if src == nil { @@ -87,7 +87,7 @@ func (RecordCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src } } -func (RecordCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (RecordCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -97,9 +97,9 @@ func (RecordCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (in return string(src), nil case BinaryFormatCode: scanner := NewCompositeBinaryScanner(m, src) - values := make([]interface{}, scanner.FieldCount()) + values := make([]any, scanner.FieldCount()) for i := 0; scanner.Next(); i++ { - var v interface{} + var v any fieldPlan := m.PlanScan(scanner.OID(), BinaryFormatCode, &v) if fieldPlan == nil { return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v) diff --git a/pgtype/record_codec_test.go b/pgtype/record_codec_test.go index 57fa87ff..2189f99c 100644 --- a/pgtype/record_codec_test.go +++ b/pgtype/record_codec_test.go @@ -27,27 +27,27 @@ func TestRecordCodecDecodeValue(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { for _, tt := range []struct { sql string - expected interface{} + expected any }{ { sql: `select row()`, - expected: []interface{}{}, + expected: []any{}, }, { sql: `select row('foo'::text, 42::int4)`, - expected: []interface{}{"foo", int32(42)}, + expected: []any{"foo", int32(42)}, }, { sql: `select row(100.0::float4, 1.09::float4)`, - expected: []interface{}{float32(100), float32(1.09)}, + expected: []any{float32(100), float32(1.09)}, }, { sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, - expected: []interface{}{"foo", []interface{}{int32(1), int32(2), nil, int32(4)}, int32(42)}, + expected: []any{"foo", []any{int32(1), int32(2), nil, int32(4)}, int32(42)}, }, { sql: `select row(null)`, - expected: []interface{}{nil}, + expected: []any{nil}, }, { sql: `select null::record`, diff --git a/pgtype/text.go b/pgtype/text.go index 53fcb368..0f9df7a6 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -30,7 +30,7 @@ func (t Text) TextValue() (Text, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Text) Scan(src interface{}) error { +func (dst *Text) Scan(src any) error { if src == nil { *dst = Text{} return nil @@ -90,7 +90,7 @@ func (TextCodec) PreferredFormat() int16 { return TextFormatCode } -func (TextCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (TextCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case TextFormatCode, BinaryFormatCode: switch value.(type) { @@ -110,7 +110,7 @@ func (TextCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanTextCodecString struct{} -func (encodePlanTextCodecString) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTextCodecString) Encode(value any, buf []byte) (newBuf []byte, err error) { s := value.(string) buf = append(buf, s...) return buf, nil @@ -118,7 +118,7 @@ func (encodePlanTextCodecString) Encode(value interface{}, buf []byte) (newBuf [ type encodePlanTextCodecByteSlice struct{} -func (encodePlanTextCodecByteSlice) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTextCodecByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { s := value.([]byte) buf = append(buf, s...) return buf, nil @@ -126,7 +126,7 @@ func (encodePlanTextCodecByteSlice) Encode(value interface{}, buf []byte) (newBu type encodePlanTextCodecRune struct{} -func (encodePlanTextCodecRune) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTextCodecRune) Encode(value any, buf []byte) (newBuf []byte, err error) { r := value.(rune) buf = append(buf, string(r)...) return buf, nil @@ -134,7 +134,7 @@ func (encodePlanTextCodecRune) Encode(value interface{}, buf []byte) (newBuf []b type encodePlanTextCodecStringer struct{} -func (encodePlanTextCodecStringer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTextCodecStringer) Encode(value any, buf []byte) (newBuf []byte, err error) { s := value.(fmt.Stringer) buf = append(buf, s.String()...) return buf, nil @@ -142,7 +142,7 @@ func (encodePlanTextCodecStringer) Encode(value interface{}, buf []byte) (newBuf type encodePlanTextCodecTextValuer struct{} -func (encodePlanTextCodecTextValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTextCodecTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { text, err := value.(TextValuer).TextValue() if err != nil { return nil, err @@ -156,7 +156,7 @@ func (encodePlanTextCodecTextValuer) Encode(value interface{}, buf []byte) (newB return buf, nil } -func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case TextFormatCode, BinaryFormatCode: @@ -181,7 +181,7 @@ func (c TextCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return c.DecodeValue(m, oid, format, src) } -func (c TextCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c TextCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -191,7 +191,7 @@ func (c TextCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (in type scanPlanTextAnyToString struct{} -func (scanPlanTextAnyToString) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToString) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -204,7 +204,7 @@ func (scanPlanTextAnyToString) Scan(src []byte, dst interface{}) error { type scanPlanAnyToNewByteSlice struct{} -func (scanPlanAnyToNewByteSlice) Scan(src []byte, dst interface{}) error { +func (scanPlanAnyToNewByteSlice) Scan(src []byte, dst any) error { p := (dst).(*[]byte) if src == nil { *p = nil @@ -218,14 +218,14 @@ func (scanPlanAnyToNewByteSlice) Scan(src []byte, dst interface{}) error { type scanPlanAnyToByteScanner struct{} -func (scanPlanAnyToByteScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanAnyToByteScanner) Scan(src []byte, dst any) error { p := (dst).(BytesScanner) return p.ScanBytes(src) } type scanPlanTextAnyToRune struct{} -func (scanPlanTextAnyToRune) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToRune) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -243,7 +243,7 @@ func (scanPlanTextAnyToRune) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToTextScanner struct{} -func (scanPlanTextAnyToTextScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToTextScanner) Scan(src []byte, dst any) error { scanner := (dst).(TextScanner) if src == nil { diff --git a/pgtype/tid.go b/pgtype/tid.go index 6eefd34e..cb4a9ec4 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -45,7 +45,7 @@ func (b TID) TIDValue() (TID, error) { } // Scan implements the database/sql Scanner interface. -func (dst *TID) Scan(src interface{}) error { +func (dst *TID) Scan(src any) error { if src == nil { *dst = TID{} return nil @@ -82,7 +82,7 @@ func (TIDCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TIDCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (TIDCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(TIDValuer); !ok { return nil } @@ -99,7 +99,7 @@ func (TIDCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanTIDCodecBinary struct{} -func (encodePlanTIDCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTIDCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { tid, err := value.(TIDValuer).TIDValue() if err != nil { return nil, err @@ -116,7 +116,7 @@ func (encodePlanTIDCodecBinary) Encode(value interface{}, buf []byte) (newBuf [] type encodePlanTIDCodecText struct{} -func (encodePlanTIDCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTIDCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { tid, err := value.(TIDValuer).TIDValue() if err != nil { return nil, err @@ -130,7 +130,7 @@ func (encodePlanTIDCodecText) Encode(value interface{}, buf []byte) (newBuf []by return buf, nil } -func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -152,7 +152,7 @@ func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) S type scanPlanBinaryTIDToTIDScanner struct{} -func (scanPlanBinaryTIDToTIDScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryTIDToTIDScanner) Scan(src []byte, dst any) error { scanner := (dst).(TIDScanner) if src == nil { @@ -172,7 +172,7 @@ func (scanPlanBinaryTIDToTIDScanner) Scan(src []byte, dst interface{}) error { type scanPlanBinaryTIDToTextScanner struct{} -func (scanPlanBinaryTIDToTextScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryTIDToTextScanner) Scan(src []byte, dst any) error { scanner := (dst).(TextScanner) if src == nil { @@ -194,7 +194,7 @@ func (scanPlanBinaryTIDToTextScanner) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToTIDScanner struct{} -func (scanPlanTextAnyToTIDScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToTIDScanner) Scan(src []byte, dst any) error { scanner := (dst).(TIDScanner) if src == nil { @@ -227,7 +227,7 @@ func (c TIDCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src [ return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c TIDCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c TIDCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/time.go b/pgtype/time.go index 9005b848..2eb6ace2 100644 --- a/pgtype/time.go +++ b/pgtype/time.go @@ -37,7 +37,7 @@ func (t Time) TimeValue() (Time, error) { } // Scan implements the database/sql Scanner interface. -func (t *Time) Scan(src interface{}) error { +func (t *Time) Scan(src any) error { if src == nil { *t = Time{} return nil @@ -74,7 +74,7 @@ func (TimeCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TimeCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (TimeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(TimeValuer); !ok { return nil } @@ -91,7 +91,7 @@ func (TimeCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanTimeCodecBinary struct{} -func (encodePlanTimeCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTimeCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { t, err := value.(TimeValuer).TimeValue() if err != nil { return nil, err @@ -106,7 +106,7 @@ func (encodePlanTimeCodecBinary) Encode(value interface{}, buf []byte) (newBuf [ type encodePlanTimeCodecText struct{} -func (encodePlanTimeCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTimeCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { t, err := value.(TimeValuer).TimeValue() if err != nil { return nil, err @@ -129,7 +129,7 @@ func (encodePlanTimeCodecText) Encode(value interface{}, buf []byte) (newBuf []b return append(buf, s...), nil } -func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -149,7 +149,7 @@ func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) type scanPlanBinaryTimeToTimeScanner struct{} -func (scanPlanBinaryTimeToTimeScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryTimeToTimeScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimeScanner) if src == nil { @@ -167,7 +167,7 @@ func (scanPlanBinaryTimeToTimeScanner) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToTimeScanner struct{} -func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimeScanner) if src == nil { @@ -219,7 +219,7 @@ func (c TimeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return codecDecodeToTextFormat(c, m, oid, format, src) } -func (c TimeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c TimeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 10525229..9f3de2c5 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -37,7 +37,7 @@ func (ts Timestamp) TimestampValue() (Timestamp, error) { } // Scan implements the database/sql Scanner interface. -func (ts *Timestamp) Scan(src interface{}) error { +func (ts *Timestamp) Scan(src any) error { if src == nil { *ts = Timestamp{} return nil @@ -76,7 +76,7 @@ func (TimestampCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(TimestampValuer); !ok { return nil } @@ -93,7 +93,7 @@ func (TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value interfa type encodePlanTimestampCodecBinary struct{} -func (encodePlanTimestampCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTimestampCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { ts, err := value.(TimestampValuer).TimestampValue() if err != nil { return nil, err @@ -122,7 +122,7 @@ func (encodePlanTimestampCodecBinary) Encode(value interface{}, buf []byte) (new type encodePlanTimestampCodecText struct{} -func (encodePlanTimestampCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTimestampCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { ts, err := value.(TimestampValuer).TimestampValue() if err != nil { return nil, err @@ -170,7 +170,7 @@ func discardTimeZone(t time.Time) time.Time { return t } -func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -190,7 +190,7 @@ func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target interfac type scanPlanBinaryTimestampToTimestampScanner struct{} -func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestampScanner) if src == nil { @@ -222,7 +222,7 @@ func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst interface{ type scanPlanTextTimestampToTimestampScanner struct{} -func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestampScanner) if src == nil { @@ -276,7 +276,7 @@ func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, return ts.Time, nil } -func (c TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 7709e0aa..f568fe30 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -46,7 +46,7 @@ func (tstz Timestamptz) TimestamptzValue() (Timestamptz, error) { } // Scan implements the database/sql Scanner interface. -func (tstz *Timestamptz) Scan(src interface{}) error { +func (tstz *Timestamptz) Scan(src any) error { if src == nil { *tstz = Timestamptz{} return nil @@ -134,7 +134,7 @@ func (TimestamptzCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(TimestamptzValuer); !ok { return nil } @@ -151,7 +151,7 @@ func (TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value inter type encodePlanTimestamptzCodecBinary struct{} -func (encodePlanTimestamptzCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTimestamptzCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { ts, err := value.(TimestamptzValuer).TimestamptzValue() if err != nil { return nil, err @@ -179,7 +179,7 @@ func (encodePlanTimestamptzCodecBinary) Encode(value interface{}, buf []byte) (n type encodePlanTimestamptzCodecText struct{} -func (encodePlanTimestamptzCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { ts, err := value.(TimestamptzValuer).TimestamptzValue() if err != nil { return nil, err @@ -220,7 +220,7 @@ func (encodePlanTimestamptzCodecText) Encode(value interface{}, buf []byte) (new return buf, nil } -func (TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -240,7 +240,7 @@ func (TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target interf type scanPlanBinaryTimestamptzToTimestamptzScanner struct{} -func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestamptzScanner) if src == nil { @@ -272,7 +272,7 @@ func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst interf type scanPlanTextTimestamptzToTimestamptzScanner struct{} -func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestamptzScanner) if src == nil { @@ -336,7 +336,7 @@ func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int1 return tstz.Time, nil } -func (c TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/uint32.go b/pgtype/uint32.go index 25d2423a..37e0c65f 100644 --- a/pgtype/uint32.go +++ b/pgtype/uint32.go @@ -34,7 +34,7 @@ func (n Uint32) Uint32Value() (Uint32, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Uint32) Scan(src interface{}) error { +func (dst *Uint32) Scan(src any) error { if src == nil { *dst = Uint32{} return nil @@ -85,7 +85,7 @@ func (Uint32Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Uint32Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (Uint32Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: switch value.(type) { @@ -110,14 +110,14 @@ func (Uint32Codec) PlanEncode(m *Map, oid uint32, format int16, value interface{ type encodePlanUint32CodecBinaryUint32 struct{} -func (encodePlanUint32CodecBinaryUint32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanUint32CodecBinaryUint32) Encode(value any, buf []byte) (newBuf []byte, err error) { v := value.(uint32) return pgio.AppendUint32(buf, v), nil } type encodePlanUint32CodecBinaryUint32Valuer struct{} -func (encodePlanUint32CodecBinaryUint32Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanUint32CodecBinaryUint32Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { v, err := value.(Uint32Valuer).Uint32Value() if err != nil { return nil, err @@ -132,7 +132,7 @@ func (encodePlanUint32CodecBinaryUint32Valuer) Encode(value interface{}, buf []b type encodePlanUint32CodecBinaryInt64Valuer struct{} -func (encodePlanUint32CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanUint32CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { v, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -154,14 +154,14 @@ func (encodePlanUint32CodecBinaryInt64Valuer) Encode(value interface{}, buf []by type encodePlanUint32CodecTextUint32 struct{} -func (encodePlanUint32CodecTextUint32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanUint32CodecTextUint32) Encode(value any, buf []byte) (newBuf []byte, err error) { v := value.(uint32) return append(buf, strconv.FormatUint(uint64(v), 10)...), nil } type encodePlanUint32CodecTextUint32Valuer struct{} -func (encodePlanUint32CodecTextUint32Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanUint32CodecTextUint32Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { v, err := value.(Uint32Valuer).Uint32Value() if err != nil { return nil, err @@ -176,7 +176,7 @@ func (encodePlanUint32CodecTextUint32Valuer) Encode(value interface{}, buf []byt type encodePlanUint32CodecTextInt64Valuer struct{} -func (encodePlanUint32CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanUint32CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { v, err := value.(Int64Valuer).Int64Value() if err != nil { return nil, err @@ -196,7 +196,7 @@ func (encodePlanUint32CodecTextInt64Valuer) Encode(value interface{}, buf []byte return append(buf, strconv.FormatInt(v.Int64, 10)...), nil } -func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: @@ -231,7 +231,7 @@ func (c Uint32Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr return int64(n), nil } -func (c Uint32Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c Uint32Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -246,7 +246,7 @@ func (c Uint32Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) ( type scanPlanBinaryUint32ToUint32 struct{} -func (scanPlanBinaryUint32ToUint32) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryUint32ToUint32) Scan(src []byte, dst any) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -263,7 +263,7 @@ func (scanPlanBinaryUint32ToUint32) Scan(src []byte, dst interface{}) error { type scanPlanBinaryUint32ToUint32Scanner struct{} -func (scanPlanBinaryUint32ToUint32Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryUint32ToUint32Scanner) Scan(src []byte, dst any) error { s, ok := (dst).(Uint32Scanner) if !ok { return ErrScanTargetTypeChanged @@ -284,7 +284,7 @@ func (scanPlanBinaryUint32ToUint32Scanner) Scan(src []byte, dst interface{}) err type scanPlanTextAnyToUint32Scanner struct{} -func (scanPlanTextAnyToUint32Scanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToUint32Scanner) Scan(src []byte, dst any) error { s, ok := (dst).(Uint32Scanner) if !ok { return ErrScanTargetTypeChanged diff --git a/pgtype/uuid.go b/pgtype/uuid.go index a561bed9..8c3bbba5 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -56,7 +56,7 @@ func encodeUUID(src [16]byte) string { } // Scan implements the database/sql Scanner interface. -func (dst *UUID) Scan(src interface{}) error { +func (dst *UUID) Scan(src any) error { if src == nil { *dst = UUID{} return nil @@ -122,7 +122,7 @@ func (UUIDCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (UUIDCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (UUIDCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(UUIDValuer); !ok { return nil } @@ -139,7 +139,7 @@ func (UUIDCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) type encodePlanUUIDCodecBinaryUUIDValuer struct{} -func (encodePlanUUIDCodecBinaryUUIDValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanUUIDCodecBinaryUUIDValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { uuid, err := value.(UUIDValuer).UUIDValue() if err != nil { return nil, err @@ -154,7 +154,7 @@ func (encodePlanUUIDCodecBinaryUUIDValuer) Encode(value interface{}, buf []byte) type encodePlanUUIDCodecTextUUIDValuer struct{} -func (encodePlanUUIDCodecTextUUIDValuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (encodePlanUUIDCodecTextUUIDValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { uuid, err := value.(UUIDValuer).UUIDValue() if err != nil { return nil, err @@ -167,7 +167,7 @@ func (encodePlanUUIDCodecTextUUIDValuer) Encode(value interface{}, buf []byte) ( return append(buf, encodeUUID(uuid.Bytes)...), nil } -func (UUIDCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) ScanPlan { +func (UUIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { @@ -186,7 +186,7 @@ func (UUIDCodec) PlanScan(m *Map, oid uint32, format int16, target interface{}) type scanPlanBinaryUUIDToUUIDScanner struct{} -func (scanPlanBinaryUUIDToUUIDScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanBinaryUUIDToUUIDScanner) Scan(src []byte, dst any) error { scanner := (dst).(UUIDScanner) if src == nil { @@ -205,7 +205,7 @@ func (scanPlanBinaryUUIDToUUIDScanner) Scan(src []byte, dst interface{}) error { type scanPlanTextAnyToUUIDScanner struct{} -func (scanPlanTextAnyToUUIDScanner) Scan(src []byte, dst interface{}) error { +func (scanPlanTextAnyToUUIDScanner) Scan(src []byte, dst any) error { scanner := (dst).(UUIDScanner) if src == nil { @@ -234,7 +234,7 @@ func (c UUIDCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return encodeUUID(uuid.Bytes), nil } -func (c UUIDCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (interface{}, error) { +func (c UUIDCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/zeronull/float8.go b/pgtype/zeronull/float8.go index 7f3a06b4..08fa169e 100644 --- a/pgtype/zeronull/float8.go +++ b/pgtype/zeronull/float8.go @@ -30,7 +30,7 @@ func (f Float8) Float64Value() (pgtype.Float8, error) { } // Scan implements the database/sql Scanner interface. -func (f *Float8) Scan(src interface{}) error { +func (f *Float8) Scan(src any) error { if src == nil { *f = 0 return nil diff --git a/pgtype/zeronull/float8_test.go b/pgtype/zeronull/float8_test.go index a8c1b6a1..b3c818aa 100644 --- a/pgtype/zeronull/float8_test.go +++ b/pgtype/zeronull/float8_test.go @@ -8,8 +8,8 @@ import ( "github.com/jackc/pgx/v5/pgxtest" ) -func isExpectedEq(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEq(a any) func(any) bool { + return func(v any) bool { return a == v } } @@ -28,7 +28,7 @@ func TestFloat8Transcode(t *testing.T) { }, { (zeronull.Float8)(0), - new(interface{}), + new(any), isExpectedEq(nil), }, }) diff --git a/pgtype/zeronull/int.go b/pgtype/zeronull/int.go index 9a40691f..4fec8a1a 100644 --- a/pgtype/zeronull/int.go +++ b/pgtype/zeronull/int.go @@ -32,7 +32,7 @@ func (dst *Int2) ScanInt64(n int64, valid bool) error { } // Scan implements the database/sql Scanner interface. -func (dst *Int2) Scan(src interface{}) error { +func (dst *Int2) Scan(src any) error { if src == nil { *dst = 0 return nil @@ -80,7 +80,7 @@ func (dst *Int4) ScanInt64(n int64, valid bool) error { } // Scan implements the database/sql Scanner interface. -func (dst *Int4) Scan(src interface{}) error { +func (dst *Int4) Scan(src any) error { if src == nil { *dst = 0 return nil @@ -128,7 +128,7 @@ func (dst *Int8) ScanInt64(n int64, valid bool) error { } // Scan implements the database/sql Scanner interface. -func (dst *Int8) Scan(src interface{}) error { +func (dst *Int8) Scan(src any) error { if src == nil { *dst = 0 return nil diff --git a/pgtype/zeronull/int.go.erb b/pgtype/zeronull/int.go.erb index cdae7597..b51cba12 100644 --- a/pgtype/zeronull/int.go.erb +++ b/pgtype/zeronull/int.go.erb @@ -33,7 +33,7 @@ func (dst *Int<%= pg_byte_size %>) ScanInt64(n int64, valid bool) error { } // Scan implements the database/sql Scanner interface. -func (dst *Int<%= pg_byte_size %>) Scan(src interface{}) error { +func (dst *Int<%= pg_byte_size %>) Scan(src any) error { if src == nil { *dst = 0 return nil diff --git a/pgtype/zeronull/int_test.go b/pgtype/zeronull/int_test.go index 30e4808f..7204cc88 100644 --- a/pgtype/zeronull/int_test.go +++ b/pgtype/zeronull/int_test.go @@ -23,7 +23,7 @@ func TestInt2Transcode(t *testing.T) { }, { (zeronull.Int2)(0), - new(interface{}), + new(any), isExpectedEq(nil), }, }) @@ -43,7 +43,7 @@ func TestInt4Transcode(t *testing.T) { }, { (zeronull.Int4)(0), - new(interface{}), + new(any), isExpectedEq(nil), }, }) @@ -63,7 +63,7 @@ func TestInt8Transcode(t *testing.T) { }, { (zeronull.Int8)(0), - new(interface{}), + new(any), isExpectedEq(nil), }, }) diff --git a/pgtype/zeronull/int_test.go.erb b/pgtype/zeronull/int_test.go.erb index 2c7ddc46..c0f72ef4 100644 --- a/pgtype/zeronull/int_test.go.erb +++ b/pgtype/zeronull/int_test.go.erb @@ -23,7 +23,7 @@ func TestInt<%= pg_byte_size %>Transcode(t *testing.T) { }, { (zeronull.Int<%= pg_byte_size %>)(0), - new(interface{}), + new(any), isExpectedEq(nil), }, }) diff --git a/pgtype/zeronull/text.go b/pgtype/zeronull/text.go index b768e308..4ba51fa9 100644 --- a/pgtype/zeronull/text.go +++ b/pgtype/zeronull/text.go @@ -23,7 +23,7 @@ func (dst *Text) ScanText(v pgtype.Text) error { } // Scan implements the database/sql Scanner interface. -func (dst *Text) Scan(src interface{}) error { +func (dst *Text) Scan(src any) error { if src == nil { *dst = "" return nil diff --git a/pgtype/zeronull/text_test.go b/pgtype/zeronull/text_test.go index e0d6ec43..5a60baf1 100644 --- a/pgtype/zeronull/text_test.go +++ b/pgtype/zeronull/text_test.go @@ -22,7 +22,7 @@ func TestTextTranscode(t *testing.T) { }, { (zeronull.Text)(""), - new(interface{}), + new(any), isExpectedEq(nil), }, }) diff --git a/pgtype/zeronull/timestamp.go b/pgtype/zeronull/timestamp.go index 163af041..1697c420 100644 --- a/pgtype/zeronull/timestamp.go +++ b/pgtype/zeronull/timestamp.go @@ -40,7 +40,7 @@ func (ts Timestamp) TimestampValue() (pgtype.Timestamp, error) { } // Scan implements the database/sql Scanner interface. -func (ts *Timestamp) Scan(src interface{}) error { +func (ts *Timestamp) Scan(src any) error { if src == nil { *ts = Timestamp{} return nil diff --git a/pgtype/zeronull/timestamp_test.go b/pgtype/zeronull/timestamp_test.go index 78393e9b..8a5a5796 100644 --- a/pgtype/zeronull/timestamp_test.go +++ b/pgtype/zeronull/timestamp_test.go @@ -9,8 +9,8 @@ import ( "github.com/jackc/pgx/v5/pgxtest" ) -func isExpectedEqTimestamp(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEqTimestamp(a any) func(any) bool { + return func(v any) bool { at := time.Time(a.(zeronull.Timestamp)) vt := time.Time(v.(zeronull.Timestamp)) @@ -32,7 +32,7 @@ func TestTimestampTranscode(t *testing.T) { }, { (zeronull.Timestamp)(time.Time{}), - new(interface{}), + new(any), isExpectedEq(nil), }, }) diff --git a/pgtype/zeronull/timestamptz.go b/pgtype/zeronull/timestamptz.go index 6cd60c37..55bc0c8e 100644 --- a/pgtype/zeronull/timestamptz.go +++ b/pgtype/zeronull/timestamptz.go @@ -40,7 +40,7 @@ func (ts Timestamptz) TimestamptzValue() (pgtype.Timestamptz, error) { } // Scan implements the database/sql Scanner interface. -func (ts *Timestamptz) Scan(src interface{}) error { +func (ts *Timestamptz) Scan(src any) error { if src == nil { *ts = Timestamptz{} return nil diff --git a/pgtype/zeronull/timestamptz_test.go b/pgtype/zeronull/timestamptz_test.go index d8273258..0a6d380b 100644 --- a/pgtype/zeronull/timestamptz_test.go +++ b/pgtype/zeronull/timestamptz_test.go @@ -9,8 +9,8 @@ import ( "github.com/jackc/pgx/v5/pgxtest" ) -func isExpectedEqTimestamptz(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { +func isExpectedEqTimestamptz(a any) func(any) bool { + return func(v any) bool { at := time.Time(a.(zeronull.Timestamptz)) vt := time.Time(v.(zeronull.Timestamptz)) @@ -32,7 +32,7 @@ func TestTimestamptzTranscode(t *testing.T) { }, { (zeronull.Timestamptz)(time.Time{}), - new(interface{}), + new(any), isExpectedEq(nil), }, }) diff --git a/pgtype/zeronull/uuid.go b/pgtype/zeronull/uuid.go index abe5049e..d88be84d 100644 --- a/pgtype/zeronull/uuid.go +++ b/pgtype/zeronull/uuid.go @@ -30,7 +30,7 @@ func (u UUID) UUIDValue() (pgtype.UUID, error) { } // Scan implements the database/sql Scanner interface. -func (u *UUID) Scan(src interface{}) error { +func (u *UUID) Scan(src any) error { if src == nil { *u = UUID{} return nil diff --git a/pgtype/zeronull/uuid_test.go b/pgtype/zeronull/uuid_test.go index 0cb169fa..c50cb300 100644 --- a/pgtype/zeronull/uuid_test.go +++ b/pgtype/zeronull/uuid_test.go @@ -22,7 +22,7 @@ func TestUUIDTranscode(t *testing.T) { }, { (zeronull.UUID)([16]byte{}), - new(interface{}), + new(any), isExpectedEq(nil), }, }) diff --git a/pgxpool/batch_results.go b/pgxpool/batch_results.go index aa1d609d..fcd10b37 100644 --- a/pgxpool/batch_results.go +++ b/pgxpool/batch_results.go @@ -17,7 +17,7 @@ func (br errBatchResults) Query() (pgx.Rows, error) { return errRows{err: br.err}, br.err } -func (br errBatchResults) QueryFunc(scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { +func (br errBatchResults) QueryFunc(scans []any, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { return pgconn.CommandTag{}, br.err } @@ -42,7 +42,7 @@ func (br *poolBatchResults) Query() (pgx.Rows, error) { return br.br.Query() } -func (br *poolBatchResults) QueryFunc(scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { +func (br *poolBatchResults) QueryFunc(scans []any, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { return br.br.QueryFunc(scans, f) } diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index c0ae07c4..c331b33b 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -21,7 +21,7 @@ func waitForReleaseToComplete() { } type execer interface { - Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) + Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) } func testExec(t *testing.T, db execer) { @@ -31,7 +31,7 @@ func testExec(t *testing.T, db execer) { } type queryer interface { - Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) } func testQuery(t *testing.T, db queryer) { @@ -53,7 +53,7 @@ func testQuery(t *testing.T, db queryer) { } type queryRower interface { - QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row } func testQueryRow(t *testing.T, db queryRower) { @@ -103,7 +103,7 @@ func testCopyFrom(t *testing.T, db interface { tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) - inputRows := [][]interface{}{ + inputRows := [][]any{ {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } @@ -115,7 +115,7 @@ func testCopyFrom(t *testing.T, db interface { rows, err := db.Query(context.Background(), "select * from foo") assert.NoError(t, err) - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { diff --git a/pgxpool/conn.go b/pgxpool/conn.go index 8798db4b..d842779e 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -46,19 +46,19 @@ func (c *Conn) Release() { }() } -func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { +func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { return c.Conn().Exec(ctx, sql, arguments...) } -func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { +func (c *Conn) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { return c.Conn().Query(ctx, sql, args...) } -func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { return c.Conn().QueryRow(ctx, sql, args...) } -func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { +func (c *Conn) QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { return c.Conn().QueryFunc(ctx, sql, args, scans, f) } diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 30d02879..c3abc51a 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -177,7 +177,7 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { } p.p = puddle.NewPool( - func(ctx context.Context) (interface{}, error) { + func(ctx context.Context) (any, error) { connConfig := p.config.ConnConfig if p.beforeConnect != nil { @@ -209,7 +209,7 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { return cr, nil }, - func(value interface{}) { + func(value any) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) conn := value.(*connResource).conn conn.Close(ctx) @@ -467,7 +467,7 @@ func (p *Pool) Stat() *Stat { // SQL can be either a prepared statement name or an SQL string. // Arguments should be referenced positionally from the SQL string as $1, $2, etc. // The acquired connection is returned to the pool when the Exec function returns. -func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { +func (p *Pool) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { c, err := p.Acquire(ctx) if err != nil { return pgconn.CommandTag{}, err @@ -487,7 +487,7 @@ func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) ( // For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. -func (p *Pool) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { +func (p *Pool) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { c, err := p.Acquire(ctx) if err != nil { return errRows{err: err}, err @@ -514,7 +514,7 @@ func (p *Pool) Query(ctx context.Context, sql string, args ...interface{}) (pgx. // For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. -func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { +func (p *Pool) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { c, err := p.Acquire(ctx) if err != nil { return errRow{err: err} @@ -524,7 +524,7 @@ func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pg return c.getPoolRow(row) } -func (p *Pool) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { +func (p *Pool) QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { c, err := p.Acquire(ctx) if err != nil { return pgconn.CommandTag{}, err diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 427e0ea9..8a898720 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -514,7 +514,7 @@ func TestPoolCopyFrom(t *testing.T) { tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) - inputRows := [][]interface{}{ + inputRows := [][]any{ {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } @@ -526,7 +526,7 @@ func TestPoolCopyFrom(t *testing.T) { rows, err := pool.Query(ctx, "select * from poolcopyfromtest") assert.NoError(t, err) - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { diff --git a/pgxpool/rows.go b/pgxpool/rows.go index ff7ad80b..aeb65179 100644 --- a/pgxpool/rows.go +++ b/pgxpool/rows.go @@ -15,15 +15,15 @@ func (e errRows) Err() error { return e.err } func (errRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} } func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil } func (errRows) Next() bool { return false } -func (e errRows) Scan(dest ...interface{}) error { return e.err } -func (e errRows) Values() ([]interface{}, error) { return nil, e.err } +func (e errRows) Scan(dest ...any) error { return e.err } +func (e errRows) Values() ([]any, error) { return nil, e.err } func (e errRows) RawValues() [][]byte { return nil } type errRow struct { err error } -func (e errRow) Scan(dest ...interface{}) error { return e.err } +func (e errRow) Scan(dest ...any) error { return e.err } type poolRows struct { r pgx.Rows @@ -66,7 +66,7 @@ func (rows *poolRows) Next() bool { return n } -func (rows *poolRows) Scan(dest ...interface{}) error { +func (rows *poolRows) Scan(dest ...any) error { err := rows.r.Scan(dest...) if err != nil { rows.Close() @@ -74,7 +74,7 @@ func (rows *poolRows) Scan(dest ...interface{}) error { return err } -func (rows *poolRows) Values() ([]interface{}, error) { +func (rows *poolRows) Values() ([]any, error) { values, err := rows.r.Values() if err != nil { rows.Close() @@ -92,7 +92,7 @@ type poolRow struct { err error } -func (row *poolRow) Scan(dest ...interface{}) error { +func (row *poolRow) Scan(dest ...any) error { if row.err != nil { return row.err } diff --git a/pgxpool/tx.go b/pgxpool/tx.go index a82b2176..79da567c 100644 --- a/pgxpool/tx.go +++ b/pgxpool/tx.go @@ -69,19 +69,19 @@ func (tx *Tx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementD return tx.t.Prepare(ctx, name, sql) } -func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { +func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { return tx.t.Exec(ctx, sql, arguments...) } -func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { +func (tx *Tx) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { return tx.t.Query(ctx, sql, args...) } -func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { +func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { return tx.t.QueryRow(ctx, sql, args...) } -func (tx *Tx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { +func (tx *Tx) QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { return tx.t.QueryFunc(ctx, sql, args, scans, f) } diff --git a/pgxtest/pgxtest.go b/pgxtest/pgxtest.go index 3a416dc2..796f850d 100644 --- a/pgxtest/pgxtest.go +++ b/pgxtest/pgxtest.go @@ -101,9 +101,9 @@ func RunWithQueryExecModes(ctx context.Context, t *testing.T, ctr ConnTestRunner } type ValueRoundTripTest struct { - Param interface{} - Result interface{} - Test func(interface{}) bool + Param any + Result any + Test func(any) bool } func RunValueRoundTripTests( diff --git a/query_test.go b/query_test.go index 007d5256..0e310eef 100644 --- a/query_test.go +++ b/query_test.go @@ -109,7 +109,7 @@ func TestConnQueryScanWithManyColumns(t *testing.T) { defer rows.Close() for rows.Next() { - destPtrs := make([]interface{}, columnCount) + destPtrs := make([]any, columnCount) for i := range destPtrs { destPtrs[i] = &dest[i] } @@ -597,18 +597,18 @@ func TestQueryRowCoreTypes(t *testing.T) { tests := []struct { sql string - queryArgs []interface{} - scanArgs []interface{} + queryArgs []any + scanArgs []any expected allTypes }{ - {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.s}, allTypes{s: "Jack"}}, - {"select $1::float4", []interface{}{float32(1.23)}, []interface{}{&actual.f32}, allTypes{f32: 1.23}}, - {"select $1::float8", []interface{}{float64(1.23)}, []interface{}{&actual.f64}, allTypes{f64: 1.23}}, - {"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}}, - {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, - {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}}, - {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, - {"select $1::oid", []interface{}{uint32(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, + {"select $1::text", []any{"Jack"}, []any{&actual.s}, allTypes{s: "Jack"}}, + {"select $1::float4", []any{float32(1.23)}, []any{&actual.f32}, allTypes{f32: 1.23}}, + {"select $1::float8", []any{float64(1.23)}, []any{&actual.f64}, allTypes{f64: 1.23}}, + {"select $1::bool", []any{true}, []any{&actual.b}, allTypes{b: true}}, + {"select $1::timestamptz", []any{time.Unix(123, 5000)}, []any{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, + {"select $1::timestamp", []any{time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}, []any{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}}, + {"select $1::date", []any{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []any{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, + {"select $1::oid", []any{uint32(42)}, []any{&actual.oid}, allTypes{oid: 42}}, } for i, tt := range tests { @@ -658,8 +658,8 @@ func TestQueryRowCoreIntegerEncoding(t *testing.T) { successfulEncodeTests := []struct { sql string - queryArg interface{} - scanArg interface{} + queryArg any + scanArg any expected allTypes }{ // Check any integer type where value is within int2 range can be encoded @@ -717,7 +717,7 @@ func TestQueryRowCoreIntegerEncoding(t *testing.T) { failedEncodeTests := []struct { sql string - queryArg interface{} + queryArg any }{ // Check any integer type where value is outside pg:int2 range cannot be encoded {"select $1::int2", int(32769)}, @@ -773,7 +773,7 @@ func TestQueryRowCoreIntegerDecoding(t *testing.T) { successfulDecodeTests := []struct { sql string - scanArg interface{} + scanArg any expected allTypes }{ // Check any integer type where value is within Go:int range can be decoded @@ -860,7 +860,7 @@ func TestQueryRowCoreIntegerDecoding(t *testing.T) { failedDecodeTests := []struct { sql string - scanArg interface{} + scanArg any }{ // Check any integer type where value is outside Go:int8 range cannot be decoded {"select 128::int2", &actual.i8}, @@ -932,7 +932,7 @@ func TestQueryRowCoreByteSlice(t *testing.T) { tests := []struct { sql string - queryArg interface{} + queryArg any expected []byte }{ {"select $1::text", "Jack", []byte("Jack")}, @@ -977,14 +977,14 @@ func TestQueryRowErrors(t *testing.T) { tests := []struct { sql string - queryArgs []interface{} - scanArgs []interface{} + queryArgs []any + scanArgs []any err string }{ - {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, - {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, - {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot scan OID 25 in text format into *int16"}, - {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "unable to encode 705 into format code 1 for OID 600"}, + {"select $1::badtype", []any{"Jack"}, []any{&actual.i16}, `type "badtype" does not exist`}, + {"SYNTAX ERROR", []any{}, []any{&actual.i16}, "SQLSTATE 42601"}, + {"select $1::text", []any{"Jack"}, []any{&actual.i16}, "cannot scan OID 25 in text format into *int16"}, + {"select $1::point", []any{int(705)}, []any{&actual.s}, "unable to encode 705 into format code 1 for OID 600"}, } for i, tt := range tests { @@ -1868,25 +1868,25 @@ func TestConnQueryFunc(t *testing.T) { t.Parallel() pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - var actualResults []interface{} + var actualResults []any var a, b int ct, err := conn.QueryFunc( context.Background(), "select n, n * 2 from generate_series(1, $1) n", - []interface{}{3}, - []interface{}{&a, &b}, + []any{3}, + []any{&a, &b}, func(pgx.QueryFuncRow) error { - actualResults = append(actualResults, []interface{}{a, b}) + actualResults = append(actualResults, []any{a, b}) return nil }, ) require.NoError(t, err) - expectedResults := []interface{}{ - []interface{}{1, 2}, - []interface{}{2, 4}, - []interface{}{3, 6}, + expectedResults := []any{ + []any{1, 2}, + []any{2, 4}, + []any{3, 6}, } require.Equal(t, expectedResults, actualResults) require.EqualValues(t, 3, ct.RowsAffected()) @@ -1897,16 +1897,16 @@ func TestConnQueryFuncScanError(t *testing.T) { t.Parallel() pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - var actualResults []interface{} + var actualResults []any var a, b int ct, err := conn.QueryFunc( context.Background(), "select 'foo', 'bar' from generate_series(1, $1) n", - []interface{}{3}, - []interface{}{&a, &b}, + []any{3}, + []any{&a, &b}, func(pgx.QueryFuncRow) error { - actualResults = append(actualResults, []interface{}{a, b}) + actualResults = append(actualResults, []any{a, b}) return nil }, ) @@ -1923,8 +1923,8 @@ func TestConnQueryFuncAbort(t *testing.T) { ct, err := conn.QueryFunc( context.Background(), "select n, n * 2 from generate_series(1, $1) n", - []interface{}{3}, - []interface{}{&a, &b}, + []any{3}, + []any{&a, &b}, func(pgx.QueryFuncRow) error { return errors.New("abort") }, @@ -1945,8 +1945,8 @@ func ExampleConn_QueryFunc() { _, err = conn.QueryFunc( context.Background(), "select n, n * 2 from generate_series(1, $1) n", - []interface{}{3}, - []interface{}{&a, &b}, + []any{3}, + []any{&a, &b}, func(pgx.QueryFuncRow) error { fmt.Printf("%v, %v\n", a, b) return nil diff --git a/rows.go b/rows.go index f3e154bf..da5773e0 100644 --- a/rows.go +++ b/rows.go @@ -44,12 +44,12 @@ type Rows interface { // dest can include pointers to core types, values implementing the Scanner // interface, and nil. nil will skip the value entirely. It is an error to // call Scan without first calling Next() and checking that it returned true. - Scan(dest ...interface{}) error + Scan(dest ...any) error // Values returns the decoded row values. As with Scan(), it is an error to // call Values without first calling Next() and checking that it returned // true. - Values() ([]interface{}, error) + Values() ([]any, error) // RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid until the next Next // call or the Rows is closed. However, the underlying byte data is safe to retain a reference to and mutate. @@ -66,13 +66,13 @@ type Row interface { // Scan works the same as Rows. with the following exceptions. If no // rows were found it returns ErrNoRows. If multiple rows are returned it // ignores all but the first. - Scan(dest ...interface{}) error + Scan(dest ...any) error } // connRow implements the Row interface for Conn.QueryRow. type connRow connRows -func (r *connRow) Scan(dest ...interface{}) (err error) { +func (r *connRow) Scan(dest ...any) (err error) { rows := (*connRows)(r) if rows.Err() != nil { @@ -100,7 +100,7 @@ func (r *connRow) Scan(dest ...interface{}) (err error) { type rowLog interface { shouldLog(lvl LogLevel) bool - log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) + log(ctx context.Context, lvl LogLevel, msg string, data map[string]any) } // connRows implements the Rows interface for Conn.Query. @@ -114,7 +114,7 @@ type connRows struct { commandTag pgconn.CommandTag startTime time.Time sql string - args []interface{} + args []any closed bool conn *Conn @@ -155,11 +155,11 @@ func (rows *connRows) Close() { if rows.err == nil { if rows.logger.shouldLog(LogLevelInfo) { endTime := time.Now() - rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) + rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]any{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) } } else { if rows.logger.shouldLog(LogLevelError) { - rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"err": rows.err, "sql": rows.sql, "args": logQueryArgs(rows.args)}) + rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]any{"err": rows.err, "sql": rows.sql, "args": logQueryArgs(rows.args)}) } if rows.err != nil && rows.conn.statementCache != nil { rows.conn.statementCache.StatementErrored(rows.sql, rows.err) @@ -202,7 +202,7 @@ func (rows *connRows) Next() bool { } } -func (rows *connRows) Scan(dest ...interface{}) error { +func (rows *connRows) Scan(dest ...any) error { m := rows.typeMap fieldDescriptions := rows.FieldDescriptions() values := rows.values @@ -248,12 +248,12 @@ func (rows *connRows) Scan(dest ...interface{}) error { return nil } -func (rows *connRows) Values() ([]interface{}, error) { +func (rows *connRows) Values() ([]any, error) { if rows.closed { return nil, errors.New("rows is closed") } - values := make([]interface{}, 0, len(rows.FieldDescriptions())) + values := make([]any, 0, len(rows.FieldDescriptions())) for i := range rows.FieldDescriptions() { buf := rows.values[i] @@ -314,7 +314,7 @@ func (e ScanArgError) Unwrap() error { // fieldDescriptions - OID and format of values // values - the raw data as returned from the PostgreSQL server // dest - the destination that values will be decoded into -func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error { +func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...any) error { if len(fieldDescriptions) != len(values) { return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) } diff --git a/stdlib/sql.go b/stdlib/sql.go index 7624605c..61fb77d3 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -37,7 +37,7 @@ // // handle error from acquiring connection from DB pool // } // -// err = conn.Raw(func(driverConn interface{}) error { +// err = conn.Raw(func(driverConn any) error { // conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn // // Do pgx specific stuff with conn // conn.CopyFrom(...) @@ -413,7 +413,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na return nil, driver.ErrBadConn } - args := []interface{}{databaseSQLResultFormats} + args := []any{databaseSQLResultFormats} args = append(args, namedValueToInterface(argsV)...) rows, err := c.conn.Query(ctx, query, args...) @@ -746,11 +746,11 @@ func (r *Rows) Next(dest []driver.Value) error { return nil } -func valueToInterface(argsV []driver.Value) []interface{} { - args := make([]interface{}, 0, len(argsV)) +func valueToInterface(argsV []driver.Value) []any { + args := make([]any, 0, len(argsV)) for _, v := range argsV { if v != nil { - args = append(args, v.(interface{})) + args = append(args, v.(any)) } else { args = append(args, nil) } @@ -758,11 +758,11 @@ func valueToInterface(argsV []driver.Value) []interface{} { return args } -func namedValueToInterface(argsV []driver.NamedValue) []interface{} { - args := make([]interface{}, 0, len(argsV)) +func namedValueToInterface(argsV []driver.NamedValue) []any { + args := make([]any, 0, len(argsV)) for _, v := range argsV { if v.Value != nil { - args = append(args, v.Value.(interface{})) + args = append(args, v.Value.(any)) } else { args = append(args, nil) } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 5fe03976..78b2d01f 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -36,7 +36,7 @@ func skipCockroachDB(t testing.TB, db *sql.DB, msg string) { require.NoError(t, err) defer conn.Close() - err = conn.Raw(func(driverConn interface{}) error { + err = conn.Raw(func(driverConn any) error { conn := driverConn.(*stdlib.Conn).Conn() if conn.PgConn().ParameterStatus("crdb_version") != "" { t.Skip(msg) @@ -51,7 +51,7 @@ func skipPostgreSQLVersionLessThan(t testing.TB, db *sql.DB, minVersion int64) { require.NoError(t, err) defer conn.Close() - err = conn.Raw(func(driverConn interface{}) error { + err = conn.Raw(func(driverConn any) error { conn := driverConn.(*stdlib.Conn).Conn() serverVersionStr := conn.PgConn().ParameterStatus("server_version") serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) @@ -639,7 +639,7 @@ func TestConnRaw(t *testing.T) { require.NoError(t, err) var n int - err = conn.Raw(func(driverConn interface{}) error { + err = conn.Raw(func(driverConn any) error { conn := driverConn.(*stdlib.Conn).Conn() return conn.QueryRow(context.Background(), "select 42").Scan(&n) }) @@ -1036,14 +1036,14 @@ func TestScanJSONIntoJSONRawMessage(t *testing.T) { type testLog struct { lvl pgx.LogLevel msg string - data map[string]interface{} + data map[string]any } type testLogger struct { logs []testLog } -func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]interface{}) { +func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]any) { l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data}) } diff --git a/tx.go b/tx.go index 6b85b303..8a078d1a 100644 --- a/tx.go +++ b/tx.go @@ -160,10 +160,10 @@ type Tx interface { Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) - Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) - Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) - QueryRow(ctx context.Context, sql string, args ...interface{}) Row - QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) + Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) + Query(ctx context.Context, sql string, args ...any) (Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) Row + QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(QueryFuncRow) error) (pgconn.CommandTag, error) // Conn returns the underlying *Conn that on which this transaction is executing. Conn() *Conn @@ -263,7 +263,7 @@ func (tx *dbTx) Rollback(ctx context.Context) error { } // Exec delegates to the underlying *Conn -func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { +func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { return tx.conn.Exec(ctx, sql, arguments...) } @@ -277,7 +277,7 @@ func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*pgconn.Statemen } // Query delegates to the underlying *Conn -func (tx *dbTx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { +func (tx *dbTx) Query(ctx context.Context, sql string, args ...any) (Rows, error) { if tx.closed { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed @@ -288,13 +288,13 @@ func (tx *dbTx) Query(ctx context.Context, sql string, args ...interface{}) (Row } // QueryRow delegates to the underlying *Conn -func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { +func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...any) Row { rows, _ := tx.Query(ctx, sql, args...) return (*connRow)(rows.(*connRows)) } // QueryFunc delegates to the underlying *Conn. -func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { +func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { if tx.closed { return pgconn.CommandTag{}, ErrTxClosed } @@ -378,7 +378,7 @@ func (sp *dbSavepoint) Rollback(ctx context.Context) error { } // Exec delegates to the underlying Tx -func (sp *dbSavepoint) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { +func (sp *dbSavepoint) Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { if sp.closed { return pgconn.CommandTag{}, ErrTxClosed } @@ -396,7 +396,7 @@ func (sp *dbSavepoint) Prepare(ctx context.Context, name, sql string) (*pgconn.S } // Query delegates to the underlying Tx -func (sp *dbSavepoint) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { +func (sp *dbSavepoint) Query(ctx context.Context, sql string, args ...any) (Rows, error) { if sp.closed { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed @@ -407,13 +407,13 @@ func (sp *dbSavepoint) Query(ctx context.Context, sql string, args ...interface{ } // QueryRow delegates to the underlying Tx -func (sp *dbSavepoint) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { +func (sp *dbSavepoint) QueryRow(ctx context.Context, sql string, args ...any) Row { rows, _ := sp.Query(ctx, sql, args...) return (*connRow)(rows.(*connRows)) } // QueryFunc delegates to the underlying Tx. -func (sp *dbSavepoint) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { +func (sp *dbSavepoint) QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { if sp.closed { return pgconn.CommandTag{}, ErrTxClosed } diff --git a/values.go b/values.go index 12e5db47..d27e071d 100644 --- a/values.go +++ b/values.go @@ -12,7 +12,7 @@ const ( BinaryFormatCode = 1 ) -func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) { +func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) { if anynil.Is(arg) { return nil, nil } @@ -27,7 +27,7 @@ func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) return string(buf), nil } -func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) { +func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { if anynil.Is(arg) { return pgio.AppendInt32(buf, -1), nil } diff --git a/values_test.go b/values_test.go index 04441b72..39bf1ead 100644 --- a/values_test.go +++ b/values_test.go @@ -158,12 +158,12 @@ func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) } func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) { - input := map[string]interface{}{ + input := map[string]any{ "name": "Uncanny", - "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, - "inventory": []interface{}{"phone", "key"}, + "stats": map[string]any{"hp": float64(107), "maxhp": float64(150)}, + "inventory": []any{"phone", "key"}, } - var output map[string]interface{} + var output map[string]any err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) @@ -171,7 +171,7 @@ func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) { } if !reflect.DeepEqual(input, output) { - t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output) + t.Errorf("%s: Did not transcode map[string]any successfully: %v is not %v", typename, input, output) return } } @@ -559,13 +559,13 @@ func TestArrayDecoding(t *testing.T) { pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { tests := []struct { sql string - query interface{} - scan interface{} - assert func(testing.TB, interface{}, interface{}) + query any + scan any + assert func(testing.TB, any, any) }{ { "select $1::bool[]", []bool{true, false, true}, &[]bool{}, - func(t testing.TB, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]bool))) { t.Errorf("failed to encode bool[]") } @@ -573,7 +573,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, - func(t testing.TB, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]int16))) { t.Errorf("failed to encode smallint[]") } @@ -581,7 +581,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, - func(t testing.TB, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { t.Errorf("failed to encode smallint[]") } @@ -589,7 +589,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, - func(t testing.TB, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]int32))) { t.Errorf("failed to encode int[]") } @@ -597,7 +597,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, - func(t testing.TB, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { t.Errorf("failed to encode int[]") } @@ -605,7 +605,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, - func(t testing.TB, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]int64))) { t.Errorf("failed to encode bigint[]") } @@ -613,7 +613,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, - func(t testing.TB, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { t.Errorf("failed to encode bigint[]") } @@ -621,7 +621,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, - func(t testing.TB, query, scan interface{}) { + func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]string))) { t.Errorf("failed to encode text[]") } @@ -629,7 +629,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, - func(t testing.TB, query, scan interface{}) { + func(t testing.TB, query, scan any) { queryTimeSlice := query.([]time.Time) scanTimeSlice := *(scan.(*[]time.Time)) require.Equal(t, len(queryTimeSlice), len(scanTimeSlice)) @@ -640,7 +640,7 @@ func TestArrayDecoding(t *testing.T) { }, { "select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{}, - func(t testing.TB, query, scan interface{}) { + func(t testing.TB, query, scan any) { queryBytesSliceSlice := query.([][]byte) scanBytesSliceSlice := *(scan.(*[][]byte)) if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) { @@ -754,26 +754,26 @@ func TestPointerPointer(t *testing.T) { tests := []struct { sql string - queryArgs []interface{} - scanArgs []interface{} + queryArgs []any + scanArgs []any expected allTypes }{ - {"select $1::text", []interface{}{expected.s}, []interface{}{&actual.s}, allTypes{s: expected.s}}, - {"select $1::text", []interface{}{zero.s}, []interface{}{&actual.s}, allTypes{}}, - {"select $1::int2", []interface{}{expected.i16}, []interface{}{&actual.i16}, allTypes{i16: expected.i16}}, - {"select $1::int2", []interface{}{zero.i16}, []interface{}{&actual.i16}, allTypes{}}, - {"select $1::int4", []interface{}{expected.i32}, []interface{}{&actual.i32}, allTypes{i32: expected.i32}}, - {"select $1::int4", []interface{}{zero.i32}, []interface{}{&actual.i32}, allTypes{}}, - {"select $1::int8", []interface{}{expected.i64}, []interface{}{&actual.i64}, allTypes{i64: expected.i64}}, - {"select $1::int8", []interface{}{zero.i64}, []interface{}{&actual.i64}, allTypes{}}, - {"select $1::float4", []interface{}{expected.f32}, []interface{}{&actual.f32}, allTypes{f32: expected.f32}}, - {"select $1::float4", []interface{}{zero.f32}, []interface{}{&actual.f32}, allTypes{}}, - {"select $1::float8", []interface{}{expected.f64}, []interface{}{&actual.f64}, allTypes{f64: expected.f64}}, - {"select $1::float8", []interface{}{zero.f64}, []interface{}{&actual.f64}, allTypes{}}, - {"select $1::bool", []interface{}{expected.b}, []interface{}{&actual.b}, allTypes{b: expected.b}}, - {"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}}, - {"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, - {"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, + {"select $1::text", []any{expected.s}, []any{&actual.s}, allTypes{s: expected.s}}, + {"select $1::text", []any{zero.s}, []any{&actual.s}, allTypes{}}, + {"select $1::int2", []any{expected.i16}, []any{&actual.i16}, allTypes{i16: expected.i16}}, + {"select $1::int2", []any{zero.i16}, []any{&actual.i16}, allTypes{}}, + {"select $1::int4", []any{expected.i32}, []any{&actual.i32}, allTypes{i32: expected.i32}}, + {"select $1::int4", []any{zero.i32}, []any{&actual.i32}, allTypes{}}, + {"select $1::int8", []any{expected.i64}, []any{&actual.i64}, allTypes{i64: expected.i64}}, + {"select $1::int8", []any{zero.i64}, []any{&actual.i64}, allTypes{}}, + {"select $1::float4", []any{expected.f32}, []any{&actual.f32}, allTypes{f32: expected.f32}}, + {"select $1::float4", []any{zero.f32}, []any{&actual.f32}, allTypes{}}, + {"select $1::float8", []any{expected.f64}, []any{&actual.f64}, allTypes{f64: expected.f64}}, + {"select $1::float8", []any{zero.f64}, []any{&actual.f64}, allTypes{}}, + {"select $1::bool", []any{expected.b}, []any{&actual.b}, allTypes{b: expected.b}}, + {"select $1::bool", []any{zero.b}, []any{&actual.b}, allTypes{}}, + {"select $1::timestamptz", []any{expected.t}, []any{&actual.t}, allTypes{t: expected.t}}, + {"select $1::timestamptz", []any{zero.t}, []any{&actual.t}, allTypes{}}, } for i, tt := range tests { @@ -939,11 +939,11 @@ func TestEncodeTypeRename(t *testing.T) { // tests := []struct { // sql string -// expected []interface{} +// expected []any // }{ // { // "select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)", -// []interface{}{ +// []any{ // int32(1), // "cat", // time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(), @@ -951,7 +951,7 @@ func TestEncodeTypeRename(t *testing.T) { // }, // { // "select row(100.0::float, 1.09::float)", -// []interface{}{ +// []any{ // float64(100), // float64(1.09), // }, @@ -959,7 +959,7 @@ func TestEncodeTypeRename(t *testing.T) { // } // for i, tt := range tests { -// var actual []interface{} +// var actual []any // err := conn.QueryRow(context.Background(), tt.sql).Scan(&actual) // if err != nil { From c8025fd79a7531d860b72baa3a933d894f093f57 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Apr 2022 09:34:37 -0500 Subject: [PATCH 0993/1158] Use generics for Range values --- Rakefile | 1 - pgtype/pgtype.go | 14 +- pgtype/range.go | 44 ++++++ pgtype/range_codec_test.go | 40 ++--- pgtype/range_types.go | 296 ------------------------------------- pgtype/range_types.go.erb | 56 ------- 6 files changed, 71 insertions(+), 380 deletions(-) delete mode 100644 pgtype/range_types.go delete mode 100644 pgtype/range_types.go.erb diff --git a/Rakefile b/Rakefile index 3fe26cb5..de174fae 100644 --- a/Rakefile +++ b/Rakefile @@ -11,7 +11,6 @@ generated_code_files = [ "pgtype/int.go", "pgtype/int_test.go", "pgtype/integration_benchmark_test.go", - "pgtype/range_types.go", "pgtype/zeronull/int.go", "pgtype/zeronull/int_test.go" ] diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index ffdb7020..e35299e5 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -371,21 +371,21 @@ func NewMap() *Map { registerDefaultPgTypeVariants("box", "_box", Box{}) registerDefaultPgTypeVariants("circle", "_circle", Circle{}) registerDefaultPgTypeVariants("date", "_date", Date{}) - registerDefaultPgTypeVariants("daterange", "_daterange", Daterange{}) + registerDefaultPgTypeVariants("daterange", "_daterange", Range[Date]{}) registerDefaultPgTypeVariants("float4", "_float4", Float4{}) registerDefaultPgTypeVariants("float8", "_float8", Float8{}) - registerDefaultPgTypeVariants("numrange", "_numrange", Float8range{}) // There is no PostgreSQL builtin float8range so map it to numrange. + registerDefaultPgTypeVariants("numrange", "_numrange", Range[Float8]{}) // There is no PostgreSQL builtin float8range so map it to numrange. registerDefaultPgTypeVariants("inet", "_inet", Inet{}) registerDefaultPgTypeVariants("int2", "_int2", Int2{}) registerDefaultPgTypeVariants("int4", "_int4", Int4{}) - registerDefaultPgTypeVariants("int4range", "_int4range", Int4range{}) + registerDefaultPgTypeVariants("int4range", "_int4range", Range[Int4]{}) registerDefaultPgTypeVariants("int8", "_int8", Int8{}) - registerDefaultPgTypeVariants("int8range", "_int8range", Int8range{}) + registerDefaultPgTypeVariants("int8range", "_int8range", Range[Int8]{}) registerDefaultPgTypeVariants("interval", "_interval", Interval{}) registerDefaultPgTypeVariants("line", "_line", Line{}) registerDefaultPgTypeVariants("lseg", "_lseg", Lseg{}) registerDefaultPgTypeVariants("numeric", "_numeric", Numeric{}) - registerDefaultPgTypeVariants("numrange", "_numrange", Numrange{}) + registerDefaultPgTypeVariants("numrange", "_numrange", Range[Numeric]{}) registerDefaultPgTypeVariants("path", "_path", Path{}) registerDefaultPgTypeVariants("point", "_point", Point{}) registerDefaultPgTypeVariants("polygon", "_polygon", Polygon{}) @@ -394,8 +394,8 @@ func NewMap() *Map { registerDefaultPgTypeVariants("time", "_time", Time{}) registerDefaultPgTypeVariants("timestamp", "_timestamp", Timestamp{}) registerDefaultPgTypeVariants("timestamptz", "_timestamptz", Timestamptz{}) - registerDefaultPgTypeVariants("tsrange", "_tsrange", Tsrange{}) - registerDefaultPgTypeVariants("tstzrange", "_tstzrange", Tstzrange{}) + registerDefaultPgTypeVariants("tsrange", "_tsrange", Range[Timestamp]{}) + registerDefaultPgTypeVariants("tstzrange", "_tstzrange", Range[Timestamptz]{}) registerDefaultPgTypeVariants("uuid", "_uuid", UUID{}) return m diff --git a/pgtype/range.go b/pgtype/range.go index e999f6a9..776bc9eb 100644 --- a/pgtype/range.go +++ b/pgtype/range.go @@ -275,3 +275,47 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { return ubr, nil } + +type Range[T any] struct { + Lower T + Upper T + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Range[T]) IsNull() bool { + return !r.Valid +} + +func (r Range[T]) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Range[T]) Bounds() (lower, upper any) { + return &r.Lower, &r.Upper +} + +func (r *Range[T]) ScanNull() error { + *r = Range[T]{} + return nil +} + +func (r *Range[T]) ScanBounds() (lowerTarget, upperTarget any) { + return &r.Lower, &r.Upper +} + +func (r *Range[T]) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + var zero T + r.Lower = zero + } + if upper == Unbounded || upper == Empty { + var zero T + r.Upper = zero + } + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index d467b750..ed91d3e8 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -15,27 +15,27 @@ func TestRangeCodecTranscode(t *testing.T) { pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4range", []pgxtest.ValueRoundTripTest{ { - pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, - new(pgtype.Int4range), - isExpectedEq(pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), + pgtype.Range[pgtype.Int4]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + new(pgtype.Range[pgtype.Int4]), + isExpectedEq(pgtype.Range[pgtype.Int4]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), }, { - pgtype.Int4range{ + pgtype.Range[pgtype.Int4]{ LowerType: pgtype.Inclusive, Lower: pgtype.Int4{Int32: 1, Valid: true}, Upper: pgtype.Int4{Int32: 5, Valid: true}, UpperType: pgtype.Exclusive, Valid: true, }, - new(pgtype.Int4range), - isExpectedEq(pgtype.Int4range{ + new(pgtype.Range[pgtype.Int4]), + isExpectedEq(pgtype.Range[pgtype.Int4]{ LowerType: pgtype.Inclusive, Lower: pgtype.Int4{Int32: 1, Valid: true}, Upper: pgtype.Int4{Int32: 5, Valid: true}, UpperType: pgtype.Exclusive, Valid: true, }), }, - {pgtype.Int4range{}, new(pgtype.Int4range), isExpectedEq(pgtype.Int4range{})}, - {nil, new(pgtype.Int4range), isExpectedEq(pgtype.Int4range{})}, + {pgtype.Range[pgtype.Int4]{}, new(pgtype.Range[pgtype.Int4]), isExpectedEq(pgtype.Range[pgtype.Int4]{})}, + {nil, new(pgtype.Range[pgtype.Int4]), isExpectedEq(pgtype.Range[pgtype.Int4]{})}, }) } @@ -47,27 +47,27 @@ func TestRangeCodecTranscodeCompatibleRangeElementTypes(t *testing.T) { pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "numrange", []pgxtest.ValueRoundTripTest{ { - pgtype.Float8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, - new(pgtype.Float8range), - isExpectedEq(pgtype.Float8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), + pgtype.Range[pgtype.Float8]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + new(pgtype.Range[pgtype.Float8]), + isExpectedEq(pgtype.Range[pgtype.Float8]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), }, { - pgtype.Float8range{ + pgtype.Range[pgtype.Float8]{ LowerType: pgtype.Inclusive, Lower: pgtype.Float8{Float64: 1, Valid: true}, Upper: pgtype.Float8{Float64: 5, Valid: true}, UpperType: pgtype.Exclusive, Valid: true, }, - new(pgtype.Float8range), - isExpectedEq(pgtype.Float8range{ + new(pgtype.Range[pgtype.Float8]), + isExpectedEq(pgtype.Range[pgtype.Float8]{ LowerType: pgtype.Inclusive, Lower: pgtype.Float8{Float64: 1, Valid: true}, Upper: pgtype.Float8{Float64: 5, Valid: true}, UpperType: pgtype.Exclusive, Valid: true, }), }, - {pgtype.Float8range{}, new(pgtype.Float8range), isExpectedEq(pgtype.Float8range{})}, - {nil, new(pgtype.Float8range), isExpectedEq(pgtype.Float8range{})}, + {pgtype.Range[pgtype.Float8]{}, new(pgtype.Range[pgtype.Float8]), isExpectedEq(pgtype.Range[pgtype.Float8]{})}, + {nil, new(pgtype.Range[pgtype.Float8]), isExpectedEq(pgtype.Range[pgtype.Float8]{})}, }) } @@ -76,14 +76,14 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - var r pgtype.Int4range + var r pgtype.Range[pgtype.Int4] err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r) require.NoError(t, err) require.Equal( t, - pgtype.Int4range{ + pgtype.Range[pgtype.Int4]{ Lower: pgtype.Int4{Int32: 1, Valid: true}, Upper: pgtype.Int4{Int32: 5, Valid: true}, LowerType: pgtype.Inclusive, @@ -98,7 +98,7 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { require.Equal( t, - pgtype.Int4range{ + pgtype.Range[pgtype.Int4]{ Lower: pgtype.Int4{Int32: 1, Valid: true}, Upper: pgtype.Int4{}, LowerType: pgtype.Inclusive, @@ -113,7 +113,7 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { require.Equal( t, - pgtype.Int4range{ + pgtype.Range[pgtype.Int4]{ Lower: pgtype.Int4{}, Upper: pgtype.Int4{}, LowerType: pgtype.Empty, diff --git a/pgtype/range_types.go b/pgtype/range_types.go deleted file mode 100644 index c101fbdc..00000000 --- a/pgtype/range_types.go +++ /dev/null @@ -1,296 +0,0 @@ -// Do not edit. Generated from pgtype/range_types.go.erb -package pgtype - -type Int4range struct { - Lower Int4 - Upper Int4 - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Int4range) IsNull() bool { - return !r.Valid -} - -func (r Int4range) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Int4range) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Int4range) ScanNull() error { - *r = Int4range{} - return nil -} - -func (r *Int4range) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Int4range) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Int4{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Int4{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -type Int8range struct { - Lower Int8 - Upper Int8 - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Int8range) IsNull() bool { - return !r.Valid -} - -func (r Int8range) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Int8range) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Int8range) ScanNull() error { - *r = Int8range{} - return nil -} - -func (r *Int8range) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Int8range) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Int8{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Int8{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -type Numrange struct { - Lower Numeric - Upper Numeric - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Numrange) IsNull() bool { - return !r.Valid -} - -func (r Numrange) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Numrange) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Numrange) ScanNull() error { - *r = Numrange{} - return nil -} - -func (r *Numrange) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Numrange) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Numeric{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Numeric{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -type Tsrange struct { - Lower Timestamp - Upper Timestamp - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Tsrange) IsNull() bool { - return !r.Valid -} - -func (r Tsrange) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Tsrange) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Tsrange) ScanNull() error { - *r = Tsrange{} - return nil -} - -func (r *Tsrange) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Tsrange) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Timestamp{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Timestamp{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -type Tstzrange struct { - Lower Timestamptz - Upper Timestamptz - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Tstzrange) IsNull() bool { - return !r.Valid -} - -func (r Tstzrange) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Tstzrange) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Tstzrange) ScanNull() error { - *r = Tstzrange{} - return nil -} - -func (r *Tstzrange) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Tstzrange) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Timestamptz{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Timestamptz{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -type Daterange struct { - Lower Date - Upper Date - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Daterange) IsNull() bool { - return !r.Valid -} - -func (r Daterange) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Daterange) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Daterange) ScanNull() error { - *r = Daterange{} - return nil -} - -func (r *Daterange) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Daterange) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Date{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Date{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -type Float8range struct { - Lower Float8 - Upper Float8 - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Float8range) IsNull() bool { - return !r.Valid -} - -func (r Float8range) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Float8range) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Float8range) ScanNull() error { - *r = Float8range{} - return nil -} - -func (r *Float8range) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Float8range) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Float8{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Float8{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} diff --git a/pgtype/range_types.go.erb b/pgtype/range_types.go.erb deleted file mode 100644 index d181548c..00000000 --- a/pgtype/range_types.go.erb +++ /dev/null @@ -1,56 +0,0 @@ -package pgtype - -<% - [ - ["Int4range", "Int4"], - ["Int8range", "Int8"], - ["Numrange", "Numeric"], - ["Tsrange", "Timestamp"], - ["Tstzrange", "Timestamptz"], - ["Daterange", "Date"], - ["Float8range", "Float8"] - ].each do |range_type, element_type| -%> -type <%= range_type %> struct { - Lower <%= element_type %> - Upper <%= element_type %> - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r <%= range_type %>) IsNull() bool { - return !r.Valid -} - -func (r <%= range_type %>) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r <%= range_type %>) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *<%= range_type %>) ScanNull() error { - *r = <%= range_type %>{} - return nil -} - -func (r *<%= range_type %>) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *<%= range_type %>) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = <%= element_type %>{} - } - if upper == Unbounded || upper == Empty { - r.Upper = <%= element_type %>{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -<% end %> From 976b1e03a9cae87212f5794152abd85366d62e06 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Apr 2022 10:18:51 -0500 Subject: [PATCH 0994/1158] Use generics for RangeCodec This allows DecodeValue to return a more strongly typed value. --- pgtype/pgtype.go | 12 +++--- pgtype/range.go | 1 + pgtype/range_codec.go | 84 +++++++++++--------------------------- pgtype/range_codec_test.go | 6 +-- 4 files changed, 34 insertions(+), 69 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index e35299e5..78ed341e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -263,12 +263,12 @@ func NewMap() *Map { m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) - m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}}) - m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}}) - m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}}) - m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}}) - m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}}) - m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}}) + m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec[Date]{ElementType: m.oidToType[DateOID]}}) + m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec[Int4]{ElementType: m.oidToType[Int4OID]}}) + m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec[Int8]{ElementType: m.oidToType[Int8OID]}}) + m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec[Numeric]{ElementType: m.oidToType[NumericOID]}}) + m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec[Timestamp]{ElementType: m.oidToType[TimestampOID]}}) + m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec[Timestamptz]{ElementType: m.oidToType[TimestamptzOID]}}) m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}}) m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}}) diff --git a/pgtype/range.go b/pgtype/range.go index 776bc9eb..c775239d 100644 --- a/pgtype/range.go +++ b/pgtype/range.go @@ -276,6 +276,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { } +// Range is a generic range type. type Range[T any] struct { Lower T Upper T diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go index 6d62e7ff..49a39a47 100644 --- a/pgtype/range_codec.go +++ b/pgtype/range_codec.go @@ -34,79 +34,43 @@ type RangeScanner interface { SetBoundTypes(lower, upper BoundType) error } -type GenericRange struct { - Lower any - Upper any - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r GenericRange) IsNull() bool { - return !r.Valid -} - -func (r GenericRange) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r GenericRange) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *GenericRange) ScanNull() error { - *r = GenericRange{} - return nil -} - -func (r *GenericRange) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *GenericRange) SetBoundTypes(lower, upper BoundType) error { - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - // RangeCodec is a codec for any range type. -type RangeCodec struct { +type RangeCodec[T any] struct { ElementType *Type } -func (c *RangeCodec) FormatSupported(format int16) bool { +func (c *RangeCodec[T]) FormatSupported(format int16) bool { return c.ElementType.Codec.FormatSupported(format) } -func (c *RangeCodec) PreferredFormat() int16 { +func (c *RangeCodec[T]) PreferredFormat() int16 { if c.FormatSupported(BinaryFormatCode) { return BinaryFormatCode } return TextFormatCode } -func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (c *RangeCodec[T]) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(RangeValuer); !ok { return nil } switch format { case BinaryFormatCode: - return &encodePlanRangeCodecRangeValuerToBinary{rc: c, m: m} + return &encodePlanRangeCodecRangeValuerToBinary[T]{rc: c, m: m} case TextFormatCode: - return &encodePlanRangeCodecRangeValuerToText{rc: c, m: m} + return &encodePlanRangeCodecRangeValuerToText[T]{rc: c, m: m} } return nil } -type encodePlanRangeCodecRangeValuerToBinary struct { - rc *RangeCodec +type encodePlanRangeCodecRangeValuerToBinary[T any] struct { + rc *RangeCodec[T] m *Map } -func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { +func (plan *encodePlanRangeCodecRangeValuerToBinary[T]) Encode(value any, buf []byte) (newBuf []byte, err error) { getter := value.(RangeValuer) if getter.IsNull() { @@ -192,12 +156,12 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt return buf, nil } -type encodePlanRangeCodecRangeValuerToText struct { - rc *RangeCodec +type encodePlanRangeCodecRangeValuerToText[T any] struct { + rc *RangeCodec[T] m *Map } -func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) { +func (plan *encodePlanRangeCodecRangeValuerToText[T]) Encode(value any, buf []byte) (newBuf []byte, err error) { getter := value.(RangeValuer) if getter.IsNull() { @@ -270,29 +234,29 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) return buf, nil } -func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *RangeCodec[T]) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case RangeScanner: - return &scanPlanBinaryRangeToRangeScanner{rc: c, m: m} + return &scanPlanBinaryRangeToRangeScanner[T]{rc: c, m: m} } case TextFormatCode: switch target.(type) { case RangeScanner: - return &scanPlanTextRangeToRangeScanner{rc: c, m: m} + return &scanPlanTextRangeToRangeScanner[T]{rc: c, m: m} } } return nil } -type scanPlanBinaryRangeToRangeScanner struct { - rc *RangeCodec +type scanPlanBinaryRangeToRangeScanner[T any] struct { + rc *RangeCodec[T] m *Map } -func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) error { +func (plan *scanPlanBinaryRangeToRangeScanner[T]) Scan(src []byte, target any) error { rangeScanner := (target).(RangeScanner) if src == nil { @@ -337,12 +301,12 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) erro return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType) } -type scanPlanTextRangeToRangeScanner struct { - rc *RangeCodec +type scanPlanTextRangeToRangeScanner[T any] struct { + rc *RangeCodec[T] m *Map } -func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error { +func (plan *scanPlanTextRangeToRangeScanner[T]) Scan(src []byte, target any) error { rangeScanner := (target).(RangeScanner) if src == nil { @@ -387,7 +351,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType) } -func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *RangeCodec[T]) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -404,12 +368,12 @@ func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr } } -func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *RangeCodec[T]) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } - var r GenericRange + var r Range[T] err := c.PlanScan(m, oid, format, &r).Scan(src, &r) return r, err } diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index ed91d3e8..23e93105 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -136,9 +136,9 @@ func TestRangeCodecDecodeValue(t *testing.T) { }{ { sql: `select '[1,5)'::int4range`, - expected: pgtype.GenericRange{ - Lower: int32(1), - Upper: int32(5), + expected: pgtype.Range[pgtype.Int4]{ + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true, From 1ef2cee36e154dc0934c2999e38a19eabcf191f3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Apr 2022 10:26:45 -0500 Subject: [PATCH 0995/1158] Update changelog --- CHANGELOG.md | 55 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 553fd915..49cc1279 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,9 @@ ## Merged Packages -`github.com/jackc/pgtype`, `github.com/jackc/pgconn`, and `github.com/jackc/pgproto3` are now included in the main `github.com/jackc/pgx` repository. Previously there was confusion as to where issues should be reported, additional release work due to releasing multiple packages, and less clear changelogs. +`github.com/jackc/pgtype`, `github.com/jackc/pgconn`, and `github.com/jackc/pgproto3` are now included in the main +`github.com/jackc/pgx` repository. Previously there was confusion as to where issues should be reported, additional +release work due to releasing multiple packages, and less clear changelogs. ## pgconn @@ -14,19 +16,34 @@ The `pgtype` package has been significantly changed. ### NULL Representation -Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a `Valid` `bool` field to harmonize with how `database/sql` represents NULL and to make the zero value useable. +Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a +`Valid` `bool` field to harmonize with how `database/sql` represents NULL and to make the zero value useable. ### Codec and Value Split -Previously, the type system combined decoding and encoding values with the value types. e.g. Type `Int8` both handled encoding and decoding the PostgreSQL representation and acted as a value object. This caused some difficulties when there was not an exact 1 to 1 relationship between the Go types and the PostgreSQL types For example, scanning a PostgreSQL binary `numeric` into a Go `float64` was awkward (see https://github.com/jackc/pgtype/issues/147). This concepts have been separated. A `Codec` only has responsibility for encoding and decoding values. Value types are generally defined by implementing an interface that a particular `Codec` understands (e.g. `PointScanner` and `PointValuer` for the PostgreSQL `point` type). +Previously, the type system combined decoding and encoding values with the value types. e.g. Type `Int8` both handled +encoding and decoding the PostgreSQL representation and acted as a value object. This caused some difficulties when +there was not an exact 1 to 1 relationship between the Go types and the PostgreSQL types For example, scanning a +PostgreSQL binary `numeric` into a Go `float64` was awkward (see https://github.com/jackc/pgtype/issues/147). This +concepts have been separated. A `Codec` only has responsibility for encoding and decoding values. Value types are +generally defined by implementing an interface that a particular `Codec` understands (e.g. `PointScanner` and +`PointValuer` for the PostgreSQL `point` type). ### Array Types -All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This significantly reduced the amount of code and the compiled binary size. This also means that less common array types such as `point[]` are now supported. +All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This +significantly reduced the amount of code and the compiled binary size. This also means that less common array types such +as `point[]` are now supported. ### Composite Types -Composite types must be registered before use. `CompositeFields` may still be used to construct and destruct composite values, but any type may now implement `CompositeIndexGetter` and `CompositeIndexScanner` to be used as a composite. +Composite types must be registered before use. `CompositeFields` may still be used to construct and destruct composite +values, but any type may now implement `CompositeIndexGetter` and `CompositeIndexScanner` to be used as a composite. + +### Range Types + +Range types are now handled with generic types `RangeCodec[T]` and `Range[T]`. This allows additional user defined range +types to easily be handled. ### pgxtype @@ -43,19 +60,27 @@ The `Bytea` and `GenericBinary` types have been replaced. Use the following inst ### Dropped lib/pq Support -`pgtype` previously supported and was tested against [lib/pq](https://github.com/lib/pq). While it will continue to work in most cases this is no longer supported. +`pgtype` previously supported and was tested against [lib/pq](https://github.com/lib/pq). While it will continue to work +in most cases this is no longer supported. ### database/sql Scan -Previously, most `Scan` implementations would convert `[]byte` to `string` automatically to decode a text value. Now only `string` is handled. This is to allow the possibility of future binary support in `database/sql` mode by considering `[]byte` to be binary format and `string` text format. This change should have no effect for any use with `pgx`. The previous behavior was only necessary for `lib/pq` compatibility. +Previously, most `Scan` implementations would convert `[]byte` to `string` automatically to decode a text value. Now +only `string` is handled. This is to allow the possibility of future binary support in `database/sql` mode by +considering `[]byte` to be binary format and `string` text format. This change should have no effect for any use with +`pgx`. The previous behavior was only necessary for `lib/pq` compatibility. ### Number Type Fields Include Bit size -`Int2`, `Int4`, `Int8`, `Float4`, `Float8`, and `Uint32` fields now include bit size. e.g. `Int` is renamed to `Int64`. This matches the convention set by `database/sql`. In addition, for comparable types like `pgtype.Int8` and `sql.NullInt64` the structures are identical. This means they can be directly converted one to another. +`Int2`, `Int4`, `Int8`, `Float4`, `Float8`, and `Uint32` fields now include bit size. e.g. `Int` is renamed to `Int64`. +This matches the convention set by `database/sql`. In addition, for comparable types like `pgtype.Int8` and +`sql.NullInt64` the structures are identical. This means they can be directly converted one to another. ### 3rd Party Type Integrations -* Extracted integrations with github.com/shopspring/decimal and github.com/gofrs/uuid to https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims the pgx dependency tree. +* Extracted integrations with github.com/shopspring/decimal and github.com/gofrs/uuid to + https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims + the pgx dependency tree. ### Other Changes @@ -72,12 +97,18 @@ Previously, most `Scan` implementations would convert `[]byte` to `string` autom ## Reduced Memory Usage by Reusing Read Buffers -Previously, the connection read buffer would allocate large chunks of memory and never reuse them. This allowed transferring ownership to anything such as scanned values without incurring an additional allocation and memory copy. However, this came at the cost of overall increased memory allocation size. But worse it was also possible to pin large chunks of memory by retaining a reference to a small value that originally came directly from the read buffer. Now ownership remains with the read buffer and anything needing to retain a value must make a copy. +Previously, the connection read buffer would allocate large chunks of memory and never reuse them. This allowed +transferring ownership to anything such as scanned values without incurring an additional allocation and memory copy. +However, this came at the cost of overall increased memory allocation size. But worse it was also possible to pin large +chunks of memory by retaining a reference to a small value that originally came directly from the read buffer. Now +ownership remains with the read buffer and anything needing to retain a value must make a copy. ## Query Execution Modes -Control over automatic prepared statement caching and simple protocol use are now combined into query execution mode. See documentation for `QueryExecMode`. +Control over automatic prepared statement caching and simple protocol use are now combined into query execution mode. +See documentation for `QueryExecMode`. ## 3rd Party Logger Integration -All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency tree. +All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency +tree. From 175856ffd3c8377db2e631b99ef7a7c996fdae77 Mon Sep 17 00:00:00 2001 From: Oliver Tan Date: Tue, 12 Apr 2022 14:26:13 +1000 Subject: [PATCH 0996/1158] add GSS authentication to pgproto3 --- authentication_gss.go | 58 +++++++++++++++++++++++++++++ authentication_gss_continue.go | 67 ++++++++++++++++++++++++++++++++++ backend.go | 11 +++--- frontend.go | 6 ++- gss_response.go | 48 ++++++++++++++++++++++++ json_test.go | 39 ++++++++++++++++++++ 6 files changed, 222 insertions(+), 7 deletions(-) create mode 100644 authentication_gss.go create mode 100644 authentication_gss_continue.go create mode 100644 gss_response.go diff --git a/authentication_gss.go b/authentication_gss.go new file mode 100644 index 00000000..5a3f3b1d --- /dev/null +++ b/authentication_gss.go @@ -0,0 +1,58 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + "github.com/jackc/pgio" +) + +type AuthenticationGSS struct{} + +func (a *AuthenticationGSS) Backend() {} + +func (a *AuthenticationGSS) AuthenticationResponse() {} + +func (a *AuthenticationGSS) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeGSS { + return errors.New("bad auth type") + } + return nil +} + +func (a *AuthenticationGSS) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 4) + dst = pgio.AppendUint32(dst, AuthTypeGSS) + return dst +} + +func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "AuthenticationGSS", + }) +} + +func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + return nil +} diff --git a/authentication_gss_continue.go b/authentication_gss_continue.go new file mode 100644 index 00000000..cf8b1834 --- /dev/null +++ b/authentication_gss_continue.go @@ -0,0 +1,67 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + "github.com/jackc/pgio" +) + +type AuthenticationGSSContinue struct { + Data []byte +} + +func (a *AuthenticationGSSContinue) Backend() {} + +func (a *AuthenticationGSSContinue) AuthenticationResponse() {} + +func (a *AuthenticationGSSContinue) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeGSSCont { + return errors.New("bad auth type") + } + + a.Data = src[4:] + return nil +} + +func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, int32(len(a.Data))+8) + dst = pgio.AppendUint32(dst, AuthTypeGSSCont) + dst = append(dst, a.Data...) + return dst +} + +func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "AuthenticationGSSContinue", + Data: a.Data, + }) +} + +func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Data []byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + a.Data = msg.Data + return nil +} diff --git a/backend.go b/backend.go index 9c42ad02..a48b66f1 100644 --- a/backend.go +++ b/backend.go @@ -30,11 +30,10 @@ type Backend struct { sync Sync terminate Terminate - bodyLen int - msgType byte - partialMsg bool - authType uint32 - + bodyLen int + msgType byte + partialMsg bool + authType uint32 } const ( @@ -147,6 +146,8 @@ func (b *Backend) Receive() (FrontendMessage, error) { msg = &SASLResponse{} case AuthTypeSASLFinal: msg = &SASLResponse{} + case AuthTypeGSS, AuthTypeGSSCont: + msg = &GSSResponse{} case AuthTypeCleartextPassword, AuthTypeMD5Password: fallthrough default: diff --git a/frontend.go b/frontend.go index c33dfb08..f15a3e04 100644 --- a/frontend.go +++ b/frontend.go @@ -16,6 +16,8 @@ type Frontend struct { authenticationOk AuthenticationOk authenticationCleartextPassword AuthenticationCleartextPassword authenticationMD5Password AuthenticationMD5Password + authenticationGSS AuthenticationGSS + authenticationGSSContinue AuthenticationGSSContinue authenticationSASL AuthenticationSASL authenticationSASLContinue AuthenticationSASLContinue authenticationSASLFinal AuthenticationSASLFinal @@ -178,9 +180,9 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er case AuthTypeSCMCreds: return nil, errors.New("AuthTypeSCMCreds is unimplemented") case AuthTypeGSS: - return nil, errors.New("AuthTypeGSS is unimplemented") + return &f.authenticationGSS, nil case AuthTypeGSSCont: - return nil, errors.New("AuthTypeGSSCont is unimplemented") + return &f.authenticationGSSContinue, nil case AuthTypeSSPI: return nil, errors.New("AuthTypeSSPI is unimplemented") case AuthTypeSASL: diff --git a/gss_response.go b/gss_response.go new file mode 100644 index 00000000..62da99c7 --- /dev/null +++ b/gss_response.go @@ -0,0 +1,48 @@ +package pgproto3 + +import ( + "encoding/json" + "github.com/jackc/pgio" +) + +type GSSResponse struct { + Data []byte +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (g *GSSResponse) Frontend() {} + +func (g *GSSResponse) Decode(data []byte) error { + g.Data = data + return nil +} + +func (g *GSSResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(g.Data))) + dst = append(dst, g.Data...) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (g *GSSResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "GSSResponse", + Data: g.Data, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (g *GSSResponse) UnmarshalJSON(data []byte) error { + var msg struct { + Data []byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + g.Data = msg.Data + return nil +} diff --git a/json_test.go b/json_test.go index eab26252..8fad4f88 100644 --- a/json_test.go +++ b/json_test.go @@ -37,6 +37,32 @@ func TestJSONUnmarshalAuthenticationSASL(t *testing.T) { } } +func TestJSONUnmarshalAuthenticationGSS(t *testing.T) { + data := []byte(`{"Type":"AuthenticationGSS"}`) + want := AuthenticationGSS{} + + var got AuthenticationGSS + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationGSS struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationGSSContinue(t *testing.T) { + data := []byte(`{"Type":"AuthenticationGSSContinue","Data":[1,2,3,4]}`) + want := AuthenticationGSSContinue{Data: []byte{1, 2, 3, 4}} + + var got AuthenticationGSSContinue + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationGSSContinue struct doesn't match expected value") + } +} + func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) { data := []byte(`{"Type":"AuthenticationSASLContinue", "Data":"1"}`) want := AuthenticationSASLContinue{ @@ -551,6 +577,19 @@ func TestAuthenticationMD5Password(t *testing.T) { } } +func TestJSONUnmarshalGSSResponse(t *testing.T) { + data := []byte(`{"Type":"GSSResponse","Data":[10,20,30,40]}`) + want := GSSResponse{Data: []byte{10, 20, 30, 40}} + + var got GSSResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled GSSResponse struct doesn't match expected value") + } +} + func TestErrorResponse(t *testing.T) { data := []byte(`{"Type":"ErrorResponse","UnknownFields":{"112":"foo"},"Code": "Fail","Position":1,"Message":"this is an error"}`) want := ErrorResponse{ From b03b1666a640d4a1f8076cc0e16001b043ea5df1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 14 Apr 2022 11:50:12 -0500 Subject: [PATCH 0997/1158] Add Hijack to pgxpool.Conn --- pgxpool/conn.go | 16 ++++++++++++++++ pgxpool/pool_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/pgxpool/conn.go b/pgxpool/conn.go index d842779e..edb4f257 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -46,6 +46,22 @@ func (c *Conn) Release() { }() } +// Hijack assumes ownership of the connection from the pool. Caller is responsible for closing the connection. Hijack +// will panic if called on an already released or hijacked connection. +func (c *Conn) Hijack() *pgx.Conn { + if c.res == nil { + panic("cannot hijack already released or hijacked connection") + } + + conn := c.Conn() + res := c.res + c.res = nil + + res.Hijack() + + return conn +} + func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { return c.Conn().Exec(ctx, sql, arguments...) } diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 8a898720..b58791fd 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -115,6 +115,32 @@ func TestPoolAcquireAndConnRelease(t *testing.T) { c.Release() } +func TestPoolAcquireAndConnHijack(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + + connsBeforeHijack := pool.Stat().TotalConns() + + conn := c.Hijack() + defer conn.Close(ctx) + + connsAfterHijack := pool.Stat().TotalConns() + require.Equal(t, connsBeforeHijack-1, connsAfterHijack) + + var n int32 + err = conn.QueryRow(ctx, `select 1`).Scan(&n) + require.NoError(t, err) + require.Equal(t, int32(1), n) +} + func TestPoolAcquireFunc(t *testing.T) { t.Parallel() From 90ef5bba3fff4eaa6e8d2faf20489bb907876b8b Mon Sep 17 00:00:00 2001 From: Oliver Tan Date: Wed, 13 Apr 2022 07:15:08 +1000 Subject: [PATCH 0998/1158] add GSSAPI authentication This commit adds the GSSAPI authentication to pgx. This roughly follows the lib/pq implementation: * We require registering a provider to avoid mass dependency inclusions that may not be desired (https://github.com/lib/pq/issues/971). * Requires the pgproto3 package be updated. I've included my custom fork for now. --- config.go | 2 ++ go.mod | 2 +- go.sum | 4 +-- krb5.go | 94 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ pgconn.go | 7 ++++- 5 files changed, 105 insertions(+), 4 deletions(-) create mode 100644 krb5.go diff --git a/config.go b/config.go index 5cee9297..6e6930ee 100644 --- a/config.go +++ b/config.go @@ -257,6 +257,8 @@ func ParseConfig(connString string) (*Config, error) { "sslkey": {}, "sslcert": {}, "sslrootcert": {}, + "krbspn": {}, + "krbsrvname": {}, "target_session_attrs": {}, "min_read_buffer_size": {}, "service": {}, diff --git a/go.mod b/go.mod index fb3ed181..2a2d6810 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.1.1 + github.com/jackc/pgproto3/v2 v2.2.1-0.20220412121321-175856ffd3c8 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.7.0 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 diff --git a/go.sum b/go.sum index bdb5ee8c..c558564b 100644 --- a/go.sum +++ b/go.sum @@ -31,8 +31,9 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.2.1-0.20220412121321-175856ffd3c8 h1:KxsCQec+1iwJXtxnbbS/dY0EJ6rJEUlFsrJUnL5A2XI= +github.com/jackc/pgproto3/v2 v2.2.1-0.20220412121321-175856ffd3c8/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= @@ -112,7 +113,6 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= diff --git a/krb5.go b/krb5.go new file mode 100644 index 00000000..1f9ce97c --- /dev/null +++ b/krb5.go @@ -0,0 +1,94 @@ +package pgconn + +import ( + "errors" + "github.com/jackc/pgproto3/v2" +) + +// NewGSSFunc creates a GSS authentication provider, for use with +// RegisterGSSProvider. +type NewGSSFunc func() (GSS, error) + +var newGSS NewGSSFunc + +// RegisterGSSProvider registers a GSS authentication provider. For example, if +// you need to use Kerberos to authenticate with your server, add this to your +// main package: +// +// import "github.com/otan/gopgkrb5" +// +// func init() { +// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() }) +// } +func RegisterGSSProvider(newGSSArg NewGSSFunc) { + newGSS = newGSSArg +} + +// GSS provides GSSAPI authentication (e.g., Kerberos). +type GSS interface { + GetInitToken(host string, service string) ([]byte, error) + GetInitTokenFromSPN(spn string) ([]byte, error) + Continue(inToken []byte) (done bool, outToken []byte, err error) +} + +func (c *PgConn) gssAuth() error { + if newGSS == nil { + return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5") + } + cli, err := newGSS() + if err != nil { + return err + } + + var nextData []byte + if spn, ok := c.config.RuntimeParams["krbspn"]; ok { + // Use the supplied SPN if provided. + nextData, err = cli.GetInitTokenFromSPN(spn) + } else { + // Allow the kerberos service name to be overridden + service := "postgres" + if val, ok := c.config.RuntimeParams["krbsrvname"]; ok { + service = val + } + nextData, err = cli.GetInitToken(c.config.Host, service) + } + if err != nil { + return err + } + + for { + gssResponse := &pgproto3.GSSResponse{ + Data: nextData, + } + _, err = c.conn.Write(gssResponse.Encode(nil)) + if err != nil { + return err + } + resp, err := c.rxGSSContinue() + if err != nil { + return err + } + var done bool + done, nextData, err = cli.Continue(resp.Data) + if err != nil { + return err + } + if done { + break + } + } + return nil +} + +func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + gssContinue, ok := msg.(*pgproto3.AuthenticationGSSContinue) + if ok { + return gssContinue, nil + } + + return nil, errors.New("expected AuthenticationGSSContinue message but received unexpected message") +} diff --git a/pgconn.go b/pgconn.go index 9a496ed0..0d07ac57 100644 --- a/pgconn.go +++ b/pgconn.go @@ -320,7 +320,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.conn.Close() return nil, &connectError{config: config, msg: "failed SASL auth", err: err} } - + case *pgproto3.AuthenticationGSS: + err = pgConn.gssAuth() + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed GSS auth", err: err} + } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle if config.ValidateConnect != nil { From 25558de3bd1bd2441cd3442394502b567ea94fbe Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Apr 2022 07:07:31 -0500 Subject: [PATCH 0999/1158] Add UnmarshalJSON to pgtype.Int2 fixes https://github.com/jackc/pgtype/issues/153 --- int2.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/int2.go b/int2.go index 3eb5aeb5..0775882a 100644 --- a/int2.go +++ b/int2.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "math" "strconv" @@ -302,3 +303,19 @@ func (src Int2) MarshalJSON() ([]byte, error) { return nil, errBadStatus } + +func (dst *Int2) UnmarshalJSON(b []byte) error { + var n *int16 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int2{Status: Null} + } else { + *dst = Int2{Int: *n, Status: Present} + } + + return nil +} From beb4e2cfbcd7f41e6389254496cc5feae7d99d9c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Apr 2022 07:24:25 -0500 Subject: [PATCH 1000/1158] SQLCODE 42501 is fatal connect error Don't try fallback configs. Match libpq behavior. fixes https://github.com/jackc/pgconn/issues/108 --- pgconn.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 0d07ac57..f1304d08 100644 --- a/pgconn.go +++ b/pgconn.go @@ -157,9 +157,11 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist + const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege if pgerr.Code == ERRCODE_INVALID_PASSWORD || pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION || - pgerr.Code == ERRCODE_INVALID_CATALOG_NAME { + pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || + pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { break } } From 8b483e42230f2267d264b80349170106bc4dbe87 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Apr 2022 09:28:46 -0500 Subject: [PATCH 1001/1158] Use generic / type safe puddle for pgxpool --- go.mod | 2 +- go.sum | 2 ++ pgxpool/conn.go | 6 +++--- pgxpool/pool.go | 16 ++++++++-------- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index f527c7fb..6710f0e8 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.18 require ( github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b - github.com/jackc/puddle v1.2.1 + github.com/jackc/puddle v1.2.2-0.20220404125616-4e959849469a github.com/stretchr/testify v1.7.0 golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b golang.org/x/text v0.3.7 diff --git a/go.sum b/go.sum index a9851d43..9dc19a92 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHF github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/puddle v1.2.1 h1:gI8os0wpRXFd4FiAY2dWiqRK037tjj3t7rKFeO4X5iw= github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.2.2-0.20220404125616-4e959849469a h1:oH7y/b+q2BEerCnARr/HZc1NxOYbKSJor4MqQXlhh+s= +github.com/jackc/puddle v1.2.2-0.20220404125616-4e959849469a/go.mod h1:ZQuO1Un86Xpe1ShKl08ERTzYhzWq+OvrvotbpeE3XO0= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= diff --git a/pgxpool/conn.go b/pgxpool/conn.go index edb4f257..3ab8b375 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -6,12 +6,12 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/puddle" + puddle "github.com/jackc/puddle/puddleg" ) // Conn is an acquired *pgx.Conn from a Pool. type Conn struct { - res *puddle.Resource + res *puddle.Resource[*connResource] p *Pool } @@ -113,7 +113,7 @@ func (c *Conn) Conn() *pgx.Conn { } func (c *Conn) connResource() *connResource { - return c.res.Value().(*connResource) + return c.res.Value() } func (c *Conn) getPoolRow(r pgx.Row) *poolRow { diff --git a/pgxpool/pool.go b/pgxpool/pool.go index c3abc51a..6e50198d 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -10,7 +10,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/puddle" + puddle "github.com/jackc/puddle/puddleg" ) var defaultMaxConns = int32(4) @@ -26,7 +26,7 @@ type connResource struct { poolRowss []poolRows } -func (cr *connResource) getConn(p *Pool, res *puddle.Resource) *Conn { +func (cr *connResource) getConn(p *Pool, res *puddle.Resource[*connResource]) *Conn { if len(cr.conns) == 0 { cr.conns = make([]Conn, 128) } @@ -70,7 +70,7 @@ func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows { // Pool allows for connection reuse. type Pool struct { - p *puddle.Pool + p *puddle.Pool[*connResource] config *Config beforeConnect func(context.Context, *pgx.ConnConfig) error afterConnect func(context.Context, *pgx.Conn) error @@ -177,7 +177,7 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { } p.p = puddle.NewPool( - func(ctx context.Context) (any, error) { + func(ctx context.Context) (*connResource, error) { connConfig := p.config.ConnConfig if p.beforeConnect != nil { @@ -209,9 +209,9 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { return cr, nil }, - func(value any) { + func(value *connResource) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - conn := value.(*connResource).conn + conn := value.conn conn.Close(ctx) select { case <-conn.PgConn().CleanupDone(): @@ -416,7 +416,7 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { return nil, err } - cr := res.Value().(*connResource) + cr := res.Value() if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { return cr.getConn(p, res), nil } @@ -444,7 +444,7 @@ func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn { resources := p.p.AcquireAllIdle() conns := make([]*Conn, 0, len(resources)) for _, res := range resources { - cr := res.Value().(*connResource) + cr := res.Value() if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { conns = append(conns, cr.getConn(p, res)) } else { From d4abe83edb3f2793ee889afb5b4f955f8c677652 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Apr 2022 10:39:12 -0500 Subject: [PATCH 1002/1158] Revert use generics for RangeCodec Reverted almost all of 976b1e0. Still may consider a way to get DecodeValue to be strongly typed but that feature isn't worth the complications of generics. Especially in that applying this style to ArrayCodec would make Conn.LoadType impossible for arrays. --- pgtype/pgtype.go | 12 +++++----- pgtype/range_codec.go | 48 +++++++++++++++++++------------------- pgtype/range_codec_test.go | 6 ++--- 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 78ed341e..e35299e5 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -263,12 +263,12 @@ func NewMap() *Map { m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) - m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec[Date]{ElementType: m.oidToType[DateOID]}}) - m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec[Int4]{ElementType: m.oidToType[Int4OID]}}) - m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec[Int8]{ElementType: m.oidToType[Int8OID]}}) - m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec[Numeric]{ElementType: m.oidToType[NumericOID]}}) - m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec[Timestamp]{ElementType: m.oidToType[TimestampOID]}}) - m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec[Timestamptz]{ElementType: m.oidToType[TimestamptzOID]}}) + m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}}) + m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}}) + m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}}) + m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}}) + m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}}) + m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}}) m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}}) m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}}) diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go index 49a39a47..d50bd3cb 100644 --- a/pgtype/range_codec.go +++ b/pgtype/range_codec.go @@ -35,42 +35,42 @@ type RangeScanner interface { } // RangeCodec is a codec for any range type. -type RangeCodec[T any] struct { +type RangeCodec struct { ElementType *Type } -func (c *RangeCodec[T]) FormatSupported(format int16) bool { +func (c *RangeCodec) FormatSupported(format int16) bool { return c.ElementType.Codec.FormatSupported(format) } -func (c *RangeCodec[T]) PreferredFormat() int16 { +func (c *RangeCodec) PreferredFormat() int16 { if c.FormatSupported(BinaryFormatCode) { return BinaryFormatCode } return TextFormatCode } -func (c *RangeCodec[T]) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(RangeValuer); !ok { return nil } switch format { case BinaryFormatCode: - return &encodePlanRangeCodecRangeValuerToBinary[T]{rc: c, m: m} + return &encodePlanRangeCodecRangeValuerToBinary{rc: c, m: m} case TextFormatCode: - return &encodePlanRangeCodecRangeValuerToText[T]{rc: c, m: m} + return &encodePlanRangeCodecRangeValuerToText{rc: c, m: m} } return nil } -type encodePlanRangeCodecRangeValuerToBinary[T any] struct { - rc *RangeCodec[T] +type encodePlanRangeCodecRangeValuerToBinary struct { + rc *RangeCodec m *Map } -func (plan *encodePlanRangeCodecRangeValuerToBinary[T]) Encode(value any, buf []byte) (newBuf []byte, err error) { +func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { getter := value.(RangeValuer) if getter.IsNull() { @@ -156,12 +156,12 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary[T]) Encode(value any, buf [] return buf, nil } -type encodePlanRangeCodecRangeValuerToText[T any] struct { - rc *RangeCodec[T] +type encodePlanRangeCodecRangeValuerToText struct { + rc *RangeCodec m *Map } -func (plan *encodePlanRangeCodecRangeValuerToText[T]) Encode(value any, buf []byte) (newBuf []byte, err error) { +func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) { getter := value.(RangeValuer) if getter.IsNull() { @@ -234,29 +234,29 @@ func (plan *encodePlanRangeCodecRangeValuerToText[T]) Encode(value any, buf []by return buf, nil } -func (c *RangeCodec[T]) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case RangeScanner: - return &scanPlanBinaryRangeToRangeScanner[T]{rc: c, m: m} + return &scanPlanBinaryRangeToRangeScanner{rc: c, m: m} } case TextFormatCode: switch target.(type) { case RangeScanner: - return &scanPlanTextRangeToRangeScanner[T]{rc: c, m: m} + return &scanPlanTextRangeToRangeScanner{rc: c, m: m} } } return nil } -type scanPlanBinaryRangeToRangeScanner[T any] struct { - rc *RangeCodec[T] +type scanPlanBinaryRangeToRangeScanner struct { + rc *RangeCodec m *Map } -func (plan *scanPlanBinaryRangeToRangeScanner[T]) Scan(src []byte, target any) error { +func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) error { rangeScanner := (target).(RangeScanner) if src == nil { @@ -301,12 +301,12 @@ func (plan *scanPlanBinaryRangeToRangeScanner[T]) Scan(src []byte, target any) e return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType) } -type scanPlanTextRangeToRangeScanner[T any] struct { - rc *RangeCodec[T] +type scanPlanTextRangeToRangeScanner struct { + rc *RangeCodec m *Map } -func (plan *scanPlanTextRangeToRangeScanner[T]) Scan(src []byte, target any) error { +func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error { rangeScanner := (target).(RangeScanner) if src == nil { @@ -351,7 +351,7 @@ func (plan *scanPlanTextRangeToRangeScanner[T]) Scan(src []byte, target any) err return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType) } -func (c *RangeCodec[T]) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -368,12 +368,12 @@ func (c *RangeCodec[T]) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, } } -func (c *RangeCodec[T]) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } - var r Range[T] + var r Range[any] err := c.PlanScan(m, oid, format, &r).Scan(src, &r) return r, err } diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index 23e93105..c0628747 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -136,9 +136,9 @@ func TestRangeCodecDecodeValue(t *testing.T) { }{ { sql: `select '[1,5)'::int4range`, - expected: pgtype.Range[pgtype.Int4]{ - Lower: pgtype.Int4{Int32: 1, Valid: true}, - Upper: pgtype.Int4{Int32: 5, Valid: true}, + expected: pgtype.Range[any]{ + Lower: int32(1), + Upper: int32(5), LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true, From f1a4ae307092a3caedb42ec417ff9be4851779ed Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Apr 2022 11:28:37 -0500 Subject: [PATCH 1003/1158] Add Array and FlatArray container types --- Rakefile | 1 - pgtype/array.go | 85 +++++++++++++++++++++++++++++++ pgtype/array_codec.go | 6 +-- pgtype/array_codec_test.go | 48 +++++++++++++++++ pgtype/array_getter_setter.go | 78 ---------------------------- pgtype/array_getter_setter.go.erb | 53 ------------------- pgtype/builtin_wrappers.go | 14 ++--- pgtype/pgtype.go | 82 +++++++++++++++++++++++++---- 8 files changed, 214 insertions(+), 153 deletions(-) delete mode 100644 pgtype/array_getter_setter.go delete mode 100644 pgtype/array_getter_setter.go.erb diff --git a/Rakefile b/Rakefile index de174fae..d957573e 100644 --- a/Rakefile +++ b/Rakefile @@ -7,7 +7,6 @@ rule '.go' => '.go.erb' do |task| end generated_code_files = [ - "pgtype/array_getter_setter.go", "pgtype/int.go", "pgtype/int_test.go", "pgtype/integration_benchmark_test.go", diff --git a/pgtype/array.go b/pgtype/array.go index 8de2b4dd..d34a94e5 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -394,3 +394,88 @@ func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, e } return dimensions, elementsLength, true } + +// Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves +// PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed. +type Array[T any] struct { + Elements []T + Dims []ArrayDimension + Valid bool +} + +func (a Array[T]) Dimensions() []ArrayDimension { + return a.Dims +} + +func (a Array[T]) Index(i int) any { + return a.Elements[i] +} + +func (a Array[T]) IndexType() any { + var el T + return el +} + +func (a *Array[T]) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + *a = Array[T]{} + return nil + } + + elementCount := cardinality(dimensions) + *a = Array[T]{ + Elements: make([]T, elementCount), + Dims: dimensions, + Valid: true, + } + + return nil +} + +func (a Array[T]) ScanIndex(i int) any { + return &a.Elements[i] +} + +func (a Array[T]) ScanIndexType() any { + return new(T) +} + +// FlatArray implements the ArrayGetter and ArraySetter interfaces for any slice of T. It ignores PostgreSQL dimensions +// and custom lower bounds. Use Array to preserve these. +type FlatArray[T any] []T + +func (a FlatArray[T]) Dimensions() []ArrayDimension { + if a == nil { + return nil + } + + return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} +} + +func (a FlatArray[T]) Index(i int) any { + return a[i] +} + +func (a FlatArray[T]) IndexType() any { + var el T + return el +} + +func (a *FlatArray[T]) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + a = nil + return nil + } + + elementCount := cardinality(dimensions) + *a = make(FlatArray[T], elementCount) + return nil +} + +func (a FlatArray[T]) ScanIndex(i int) any { + return &a[i] +} + +func (a FlatArray[T]) ScanIndexType() any { + return new(T) +} diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 379a9096..8aab13bb 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -23,9 +23,9 @@ type ArrayGetter interface { // ArraySetter is a type can be set from a PostgreSQL array. type ArraySetter interface { - // SetDimensions prepares the value such that ScanIndex can be called for each element. dimensions may be nil to - // indicate a NULL array. If unable to exactly preserve dimensions SetDimensions may return an error or silently - // flatten the array dimensions. + // SetDimensions prepares the value such that ScanIndex can be called for each element. This will remove any existing + // elements. dimensions may be nil to indicate a NULL array. If unable to exactly preserve dimensions SetDimensions + // may return an error or silently flatten the array dimensions. SetDimensions(dimensions []ArrayDimension) error // ScanIndex returns a value usable as a scan target for i. SetDimensions must be called before ScanIndex. diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index e4c00d1e..65289d04 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -5,6 +5,7 @@ import ( "testing" pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -47,6 +48,53 @@ func TestArrayCodec(t *testing.T) { }) } +func TestArrayCodecFlatArray(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for i, tt := range []struct { + expected any + }{ + {pgtype.FlatArray[int32](nil)}, + {pgtype.FlatArray[int32]{}}, + {pgtype.FlatArray[int32]{1, 2, 3}}, + } { + var actual pgtype.FlatArray[int32] + err := conn.QueryRow( + ctx, + "select $1::int[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) +} + +func TestArrayCodecArray(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for i, tt := range []struct { + expected any + }{ + {pgtype.Array[int32]{ + Elements: []int32{1, 2, 3, 4}, + Dims: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 2}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }}, + } { + var actual pgtype.Array[int32] + err := conn.QueryRow( + ctx, + "select $1::int[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) +} + func TestArrayCodecAnySlice(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { type _int16Slice []int16 diff --git a/pgtype/array_getter_setter.go b/pgtype/array_getter_setter.go deleted file mode 100644 index b0c6b505..00000000 --- a/pgtype/array_getter_setter.go +++ /dev/null @@ -1,78 +0,0 @@ -// Do not edit. Generated from pgtype/array_getter_setter.go.erb -package pgtype - -type int16Array []int16 - -func (a int16Array) Dimensions() []ArrayDimension { - if a == nil { - return nil - } - - return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} -} - -func (a int16Array) Index(i int) any { - return a[i] -} - -func (a int16Array) IndexType() any { - var el int16 - return el -} - -func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error { - if dimensions == nil { - a = nil - return nil - } - - elementCount := cardinality(dimensions) - *a = make(int16Array, elementCount) - return nil -} - -func (a int16Array) ScanIndex(i int) any { - return &a[i] -} - -func (a int16Array) ScanIndexType() any { - return new(int16) -} - -type uint16Array []uint16 - -func (a uint16Array) Dimensions() []ArrayDimension { - if a == nil { - return nil - } - - return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} -} - -func (a uint16Array) Index(i int) any { - return a[i] -} - -func (a uint16Array) IndexType() any { - var el uint16 - return el -} - -func (a *uint16Array) SetDimensions(dimensions []ArrayDimension) error { - if dimensions == nil { - a = nil - return nil - } - - elementCount := cardinality(dimensions) - *a = make(uint16Array, elementCount) - return nil -} - -func (a uint16Array) ScanIndex(i int) any { - return &a[i] -} - -func (a uint16Array) ScanIndexType() any { - return new(uint16) -} diff --git a/pgtype/array_getter_setter.go.erb b/pgtype/array_getter_setter.go.erb deleted file mode 100644 index 1c8cdff4..00000000 --- a/pgtype/array_getter_setter.go.erb +++ /dev/null @@ -1,53 +0,0 @@ -package pgtype - -import ( - "fmt" - "reflect" -) - -<% - types = [ - ["int16Array", "int16"], - ["uint16Array", "uint16"], - ] -%> - -<% types.each do |array_type, element_type| %> - type <%= array_type %> []<%= element_type %> - - func (a <%= array_type %>) Dimensions() []ArrayDimension { - if a == nil { - return nil - } - - return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} - } - - func (a <%= array_type %>) Index(i int) any { - return a[i] - } - - func (a <%= array_type %>) IndexType() any { - var el <%= element_type %> - return el - } - - func (a *<%= array_type %>) SetDimensions(dimensions []ArrayDimension) error { - if dimensions == nil { - a = nil - return nil - } - - elementCount := cardinality(dimensions) - *a = make(<%= array_type %>, elementCount) - return nil - } - - func (a <%= array_type %>) ScanIndex(i int) any { - return &a[i] - } - - func (a <%= array_type %>) ScanIndexType() any { - return new(<%= element_type %>) - } -<% end %> diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index b385b80a..da9cf0bb 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -637,11 +637,11 @@ func (w *ptrStructWrapper) ScanIndex(i int) any { return w.exportedFields[i].Addr().Interface() } -type anySliceArray struct { +type anySliceArrayReflect struct { slice reflect.Value } -func (a anySliceArray) Dimensions() []ArrayDimension { +func (a anySliceArrayReflect) Dimensions() []ArrayDimension { if a.slice.IsNil() { return nil } @@ -649,15 +649,15 @@ func (a anySliceArray) Dimensions() []ArrayDimension { return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}} } -func (a anySliceArray) Index(i int) any { +func (a anySliceArrayReflect) Index(i int) any { return a.slice.Index(i).Interface() } -func (a anySliceArray) IndexType() any { +func (a anySliceArrayReflect) IndexType() any { return reflect.New(a.slice.Type().Elem()).Elem().Interface() } -func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error { +func (a *anySliceArrayReflect) SetDimensions(dimensions []ArrayDimension) error { sliceType := a.slice.Type() if dimensions == nil { @@ -671,11 +671,11 @@ func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error { return nil } -func (a *anySliceArray) ScanIndex(i int) any { +func (a *anySliceArrayReflect) ScanIndex(i int) any { return a.slice.Index(i).Addr().Interface() } -func (a *anySliceArray) ScanIndexType() any { +func (a *anySliceArrayReflect) ScanIndexType() any { return reflect.New(a.slice.Type().Elem()).Interface() } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index e35299e5..db916220 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -993,6 +993,24 @@ func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target any) error { // TryWrapPtrSliceScanPlan tries to wrap a pointer to a single dimension slice. func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { + // Avoid using reflect path for common types. + switch target := target.(type) { + case *[]int16: + return &wrapPtrSliceScanPlan[int16]{}, (*FlatArray[int16])(target), true + case *[]int32: + return &wrapPtrSliceScanPlan[int32]{}, (*FlatArray[int32])(target), true + case *[]int64: + return &wrapPtrSliceScanPlan[int64]{}, (*FlatArray[int64])(target), true + case *[]float32: + return &wrapPtrSliceScanPlan[float32]{}, (*FlatArray[float32])(target), true + case *[]float64: + return &wrapPtrSliceScanPlan[float64]{}, (*FlatArray[float64])(target), true + case *[]string: + return &wrapPtrSliceScanPlan[string]{}, (*FlatArray[string])(target), true + case *[]time.Time: + return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true + } + targetValue := reflect.ValueOf(target) if targetValue.Kind() != reflect.Ptr { return nil, nil, false @@ -1001,19 +1019,29 @@ func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextVa targetElemValue := targetValue.Elem() if targetElemValue.Kind() == reflect.Slice { - return &wrapPtrSliceScanPlan{}, &anySliceArray{slice: targetElemValue}, true + return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: targetElemValue}, true } return nil, nil, false } -type wrapPtrSliceScanPlan struct { +type wrapPtrSliceScanPlan[T any] struct { next ScanPlan } -func (plan *wrapPtrSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } +func (plan *wrapPtrSliceScanPlan[T]) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapPtrSliceScanPlan) Scan(src []byte, target any) error { - return plan.next.Scan(src, &anySliceArray{slice: reflect.ValueOf(target).Elem()}) +func (plan *wrapPtrSliceScanPlan[T]) Scan(src []byte, target any) error { + return plan.next.Scan(src, (*FlatArray[T])(target.(*[]T))) +} + +type wrapPtrSliceReflectScanPlan struct { + next ScanPlan +} + +func (plan *wrapPtrSliceReflectScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrSliceReflectScanPlan) Scan(src []byte, target any) error { + return plan.next.Scan(src, &anySliceArrayReflect{slice: reflect.ValueOf(target).Elem()}) } // TryWrapPtrMultiDimSliceScanPlan tries to wrap a pointer to a multi-dimension slice. @@ -1660,24 +1688,56 @@ func getExportedFieldValues(structValue reflect.Value) []reflect.Value { } func TryWrapSliceEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + // Avoid using reflect path for common types. + switch value := value.(type) { + case []int16: + return &wrapSliceEncodePlan[int16]{}, (FlatArray[int16])(value), true + case []int32: + return &wrapSliceEncodePlan[int32]{}, (FlatArray[int32])(value), true + case []int64: + return &wrapSliceEncodePlan[int64]{}, (FlatArray[int64])(value), true + case []float32: + return &wrapSliceEncodePlan[float32]{}, (FlatArray[float32])(value), true + case []float64: + return &wrapSliceEncodePlan[float64]{}, (FlatArray[float64])(value), true + case []string: + return &wrapSliceEncodePlan[string]{}, (FlatArray[string])(value), true + case []time.Time: + return &wrapSliceEncodePlan[time.Time]{}, (FlatArray[time.Time])(value), true + } + if reflect.TypeOf(value).Kind() == reflect.Slice { - w := anySliceArray{ + w := anySliceArrayReflect{ slice: reflect.ValueOf(value), } - return &wrapSliceEncodePlan{}, w, true + return &wrapSliceEncodeReflectPlan{}, w, true } return nil, nil, false } -type wrapSliceEncodePlan struct { +type wrapSliceEncodePlan[T any] struct { next EncodePlan } -func (plan *wrapSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } +func (plan *wrapSliceEncodePlan[T]) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapSliceEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { - w := anySliceArray{ +func (plan *wrapSliceEncodePlan[T]) Encode(value any, buf []byte) (newBuf []byte, err error) { + w := anySliceArrayReflect{ + slice: reflect.ValueOf(value), + } + + return plan.next.Encode(w, buf) +} + +type wrapSliceEncodeReflectPlan struct { + next EncodePlan +} + +func (plan *wrapSliceEncodeReflectPlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapSliceEncodeReflectPlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + w := anySliceArrayReflect{ slice: reflect.ValueOf(value), } From fccaebc93dba7a54a636d4af400bd3676629c401 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Apr 2022 13:38:27 -0500 Subject: [PATCH 1004/1158] Add pgtype.Map.SQLScanner This enables compatibility with database/sql for types that cannot implement Scan themselves. --- pgtype/pgtype.go | 38 +++++++++++++++++++++++++ stdlib/sql.go | 70 ++++++++++++++++++++++++++-------------------- stdlib/sql_test.go | 32 +++++++++++++++++++++ 3 files changed, 109 insertions(+), 31 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index db916220..ce06e738 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1808,3 +1808,41 @@ func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBu return newBuf, nil } + +// SQLScanner returns a database/sql.Scanner for v. This is necessary for types like Array[T] and Range[T] where the +// type needs assistance from Map to implement the sql.Scanner interface. It is not necessary for types like Box that +// implement sql.Scanner directly. +// +// This uses the type of v to look up the PostgreSQL OID that v presumably came from. This means v must be registered +// with m by calling RegisterDefaultPgType. +func (m *Map) SQLScanner(v any) sql.Scanner { + if s, ok := v.(sql.Scanner); ok { + return s + } + + return &sqlScannerWrapper{m: m, v: v} +} + +type sqlScannerWrapper struct { + m *Map + v any +} + +func (w *sqlScannerWrapper) Scan(src any) error { + t, ok := w.m.TypeForValue(w.v) + if !ok { + return fmt.Errorf("cannot convert to sql.Scanner: cannot find registered type for %T", w.v) + } + + var bufSrc []byte + switch src := src.(type) { + case string: + bufSrc = []byte(src) + case []byte: + bufSrc = src + default: + bufSrc = []byte(fmt.Sprint(bufSrc)) + } + + return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v) +} diff --git a/stdlib/sql.go b/stdlib/sql.go index 61fb77d3..e4c53ea7 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -2,50 +2,58 @@ // // A database/sql connection can be established through sql.Open. // -// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") -// if err != nil { -// return err -// } +// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") +// if err != nil { +// return err +// } // // Or from a DSN string. // -// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") -// if err != nil { -// return err -// } +// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") +// if err != nil { +// return err +// } // // Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the // pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used // with sql.Open. // -// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) -// connConfig.Logger = myLogger -// connStr := stdlib.RegisterConnConfig(connConfig) -// db, _ := sql.Open("pgx", connStr) +// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) +// connConfig.Logger = myLogger +// connStr := stdlib.RegisterConnConfig(connConfig) +// db, _ := sql.Open("pgx", connStr) // -// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. -// It does not support named parameters. +// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. It does not support named parameters. // -// db.QueryRow("select * from users where id=$1", userID) +// db.QueryRow("select * from users where id=$1", userID) // -// In Go 1.13 and above (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard -// database/sql.DB connection pool. This allows operations that use pgx specific functionality. +// In Go 1.13 and above (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard database/sql.DB connection +// pool. This allows operations that use pgx specific functionality. // -// // Given db is a *sql.DB -// conn, err := db.Conn(context.Background()) -// if err != nil { -// // handle error from acquiring connection from DB pool -// } +// // Given db is a *sql.DB +// conn, err := db.Conn(context.Background()) +// if err != nil { +// // handle error from acquiring connection from DB pool +// } // -// err = conn.Raw(func(driverConn any) error { -// conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn -// // Do pgx specific stuff with conn -// conn.CopyFrom(...) -// return nil -// }) -// if err != nil { -// // handle error that occurred while using *pgx.Conn -// } +// err = conn.Raw(func(driverConn any) error { +// conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn +// // Do pgx specific stuff with conn +// conn.CopyFrom(...) +// return nil +// }) +// if err != nil { +// // handle error that occurred while using *pgx.Conn +// } +// +// PostgreSQL Specific Data Types +// +// The pgtype package provides support for PostgreSQL specific types. *pgtype.Map.SQLScanner is an adapter that makes +// these types usable as a sql.Scanner. +// +// m := pgtype.NewMap() +// var a []int64 +// err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) package stdlib import ( diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 78b2d01f..75f0caf4 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -15,6 +15,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -373,6 +374,37 @@ func TestConnSimpleSlicePassThrough(t *testing.T) { }) } +func TestConnQueryScanArray(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + m := pgtype.NewMap() + + var a []int64 + err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) + require.NoError(t, err) + assert.Equal(t, []int64{1, 2, 3}, a) + }) +} + +func TestConnQueryScanRange(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + m := pgtype.NewMap() + + var r pgtype.Range[pgtype.Int4] + err := db.QueryRow("select int4range(1, 5)").Scan(m.SQLScanner(&r)) + require.NoError(t, err) + assert.Equal( + t, + pgtype.Range[pgtype.Int4]{ + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + r) + }) +} + // Test type that pgx would handle natively in binary, but since it is not a // database/sql native type should be passed through as a string func TestConnQueryRowPgxBinary(t *testing.T) { From a01a9ee6dfde03a8a458ff2872dc31d56e79bd4e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Apr 2022 14:04:25 -0500 Subject: [PATCH 1005/1158] Automatically register Array and FlatArray --- pgtype/pgtype.go | 155 ++++++++++++++++++++++++--------------------- stdlib/sql_test.go | 13 +++- 2 files changed, 94 insertions(+), 74 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index ce06e738..787eaead 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -316,91 +316,100 @@ func NewMap() *Map { m.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarcharOID]}}) m.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[XIDOID]}}) - registerDefaultPgTypeVariants := func(name, arrayName string, value any) { - // T - m.RegisterDefaultPgType(value, name) - - // *T - valueType := reflect.TypeOf(value) - m.RegisterDefaultPgType(reflect.New(valueType).Interface(), name) - - // []T - sliceType := reflect.SliceOf(valueType) - m.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName) - - // *[]T - m.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName) - - // []*T - sliceOfPointerType := reflect.SliceOf(reflect.TypeOf(reflect.New(valueType).Interface())) - m.RegisterDefaultPgType(reflect.MakeSlice(sliceOfPointerType, 0, 0).Interface(), arrayName) - - // *[]*T - m.RegisterDefaultPgType(reflect.New(sliceOfPointerType).Interface(), arrayName) - } - // Integer types that directly map to a PostgreSQL type - registerDefaultPgTypeVariants("int2", "_int2", int16(0)) - registerDefaultPgTypeVariants("int4", "_int4", int32(0)) - registerDefaultPgTypeVariants("int8", "_int8", int64(0)) + registerDefaultPgTypeVariants[int16](m, "int2") + registerDefaultPgTypeVariants[int32](m, "int4") + registerDefaultPgTypeVariants[int64](m, "int8") // Integer types that do not have a direct match to a PostgreSQL type - registerDefaultPgTypeVariants("int8", "_int8", int8(0)) - registerDefaultPgTypeVariants("int8", "_int8", int(0)) - registerDefaultPgTypeVariants("int8", "_int8", uint8(0)) - registerDefaultPgTypeVariants("int8", "_int8", uint16(0)) - registerDefaultPgTypeVariants("int8", "_int8", uint32(0)) - registerDefaultPgTypeVariants("int8", "_int8", uint64(0)) - registerDefaultPgTypeVariants("int8", "_int8", uint(0)) + registerDefaultPgTypeVariants[int8](m, "int8") + registerDefaultPgTypeVariants[int](m, "int8") + registerDefaultPgTypeVariants[uint8](m, "int8") + registerDefaultPgTypeVariants[uint16](m, "int8") + registerDefaultPgTypeVariants[uint32](m, "int8") + registerDefaultPgTypeVariants[uint64](m, "int8") + registerDefaultPgTypeVariants[uint](m, "int8") - registerDefaultPgTypeVariants("float4", "_float4", float32(0)) - registerDefaultPgTypeVariants("float8", "_float8", float64(0)) + registerDefaultPgTypeVariants[float32](m, "float4") + registerDefaultPgTypeVariants[float64](m, "float8") - registerDefaultPgTypeVariants("bool", "_bool", false) - registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{}) - registerDefaultPgTypeVariants("interval", "_interval", time.Duration(0)) - registerDefaultPgTypeVariants("text", "_text", "") - registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil)) + registerDefaultPgTypeVariants[bool](m, "bool") + registerDefaultPgTypeVariants[time.Time](m, "timestamptz") + registerDefaultPgTypeVariants[time.Duration](m, "interval") + registerDefaultPgTypeVariants[string](m, "text") + registerDefaultPgTypeVariants[[]byte](m, "bytea") - registerDefaultPgTypeVariants("inet", "_inet", net.IP{}) - registerDefaultPgTypeVariants("cidr", "_cidr", net.IPNet{}) + registerDefaultPgTypeVariants[net.IP](m, "inet") + registerDefaultPgTypeVariants[net.IPNet](m, "cidr") // pgtype provided structs - registerDefaultPgTypeVariants("varbit", "_varbit", Bits{}) - registerDefaultPgTypeVariants("bool", "_bool", Bool{}) - registerDefaultPgTypeVariants("box", "_box", Box{}) - registerDefaultPgTypeVariants("circle", "_circle", Circle{}) - registerDefaultPgTypeVariants("date", "_date", Date{}) - registerDefaultPgTypeVariants("daterange", "_daterange", Range[Date]{}) - registerDefaultPgTypeVariants("float4", "_float4", Float4{}) - registerDefaultPgTypeVariants("float8", "_float8", Float8{}) - registerDefaultPgTypeVariants("numrange", "_numrange", Range[Float8]{}) // There is no PostgreSQL builtin float8range so map it to numrange. - registerDefaultPgTypeVariants("inet", "_inet", Inet{}) - registerDefaultPgTypeVariants("int2", "_int2", Int2{}) - registerDefaultPgTypeVariants("int4", "_int4", Int4{}) - registerDefaultPgTypeVariants("int4range", "_int4range", Range[Int4]{}) - registerDefaultPgTypeVariants("int8", "_int8", Int8{}) - registerDefaultPgTypeVariants("int8range", "_int8range", Range[Int8]{}) - registerDefaultPgTypeVariants("interval", "_interval", Interval{}) - registerDefaultPgTypeVariants("line", "_line", Line{}) - registerDefaultPgTypeVariants("lseg", "_lseg", Lseg{}) - registerDefaultPgTypeVariants("numeric", "_numeric", Numeric{}) - registerDefaultPgTypeVariants("numrange", "_numrange", Range[Numeric]{}) - registerDefaultPgTypeVariants("path", "_path", Path{}) - registerDefaultPgTypeVariants("point", "_point", Point{}) - registerDefaultPgTypeVariants("polygon", "_polygon", Polygon{}) - registerDefaultPgTypeVariants("tid", "_tid", TID{}) - registerDefaultPgTypeVariants("text", "_text", Text{}) - registerDefaultPgTypeVariants("time", "_time", Time{}) - registerDefaultPgTypeVariants("timestamp", "_timestamp", Timestamp{}) - registerDefaultPgTypeVariants("timestamptz", "_timestamptz", Timestamptz{}) - registerDefaultPgTypeVariants("tsrange", "_tsrange", Range[Timestamp]{}) - registerDefaultPgTypeVariants("tstzrange", "_tstzrange", Range[Timestamptz]{}) - registerDefaultPgTypeVariants("uuid", "_uuid", UUID{}) + registerDefaultPgTypeVariants[Bits](m, "varbit") + registerDefaultPgTypeVariants[Bool](m, "bool") + registerDefaultPgTypeVariants[Box](m, "box") + registerDefaultPgTypeVariants[Circle](m, "circle") + registerDefaultPgTypeVariants[Date](m, "date") + registerDefaultPgTypeVariants[Range[Date]](m, "daterange") + registerDefaultPgTypeVariants[Float4](m, "float4") + registerDefaultPgTypeVariants[Float8](m, "float8") + registerDefaultPgTypeVariants[Range[Float8]](m, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange. + registerDefaultPgTypeVariants[Inet](m, "inet") + registerDefaultPgTypeVariants[Int2](m, "int2") + registerDefaultPgTypeVariants[Int4](m, "int4") + registerDefaultPgTypeVariants[Range[Int4]](m, "int4range") + registerDefaultPgTypeVariants[Int8](m, "int8") + registerDefaultPgTypeVariants[Range[Int8]](m, "int8range") + registerDefaultPgTypeVariants[Interval](m, "interval") + registerDefaultPgTypeVariants[Line](m, "line") + registerDefaultPgTypeVariants[Lseg](m, "lseg") + registerDefaultPgTypeVariants[Numeric](m, "numeric") + registerDefaultPgTypeVariants[Range[Numeric]](m, "numrange") + registerDefaultPgTypeVariants[Path](m, "path") + registerDefaultPgTypeVariants[Point](m, "point") + registerDefaultPgTypeVariants[Polygon](m, "polygon") + registerDefaultPgTypeVariants[TID](m, "tid") + registerDefaultPgTypeVariants[Text](m, "text") + registerDefaultPgTypeVariants[Time](m, "time") + registerDefaultPgTypeVariants[Timestamp](m, "timestamp") + registerDefaultPgTypeVariants[Timestamptz](m, "timestamptz") + registerDefaultPgTypeVariants[Range[Timestamp]](m, "tsrange") + registerDefaultPgTypeVariants[Range[Timestamptz]](m, "tstzrange") + registerDefaultPgTypeVariants[UUID](m, "uuid") return m } +func registerDefaultPgTypeVariants[T any](m *Map, name string) { + arrayName := "_" + name + + var value T + m.RegisterDefaultPgType(value, name) // T + m.RegisterDefaultPgType(&value, name) // *T + + var sliceT []T + m.RegisterDefaultPgType(sliceT, arrayName) // []T + m.RegisterDefaultPgType(&sliceT, arrayName) // *[]T + + var slicePtrT []*T + m.RegisterDefaultPgType(slicePtrT, arrayName) // []*T + m.RegisterDefaultPgType(&slicePtrT, arrayName) // *[]*T + + var arrayOfT Array[T] + m.RegisterDefaultPgType(arrayOfT, arrayName) // Array[T] + m.RegisterDefaultPgType(&arrayOfT, arrayName) // *Array[T] + + var arrayOfPtrT Array[*T] + m.RegisterDefaultPgType(arrayOfPtrT, arrayName) // Array[*T] + m.RegisterDefaultPgType(&arrayOfPtrT, arrayName) // *Array[*T] + + var flatArrayOfT FlatArray[T] + m.RegisterDefaultPgType(flatArrayOfT, arrayName) // FlatArray[T] + m.RegisterDefaultPgType(&flatArrayOfT, arrayName) // *FlatArray[T] + + var flatArrayOfPtrT FlatArray[*T] + m.RegisterDefaultPgType(flatArrayOfPtrT, arrayName) // FlatArray[*T] + m.RegisterDefaultPgType(&flatArrayOfPtrT, arrayName) // *FlatArray[*T] +} + func (m *Map) RegisterType(t *Type) { m.oidToType[t.OID] = t m.nameToType[t.Name] = t diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 75f0caf4..30cea7d6 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -374,7 +374,7 @@ func TestConnSimpleSlicePassThrough(t *testing.T) { }) } -func TestConnQueryScanArray(t *testing.T) { +func TestConnQueryScanGoArray(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { m := pgtype.NewMap() @@ -385,6 +385,17 @@ func TestConnQueryScanArray(t *testing.T) { }) } +func TestConnQueryScanArray(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + m := pgtype.NewMap() + + var a pgtype.Array[int64] + err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) + require.NoError(t, err) + assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a) + }) +} + func TestConnQueryScanRange(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { m := pgtype.NewMap() From e94cf1fbaad0b740226359a7f6ce1a5164277820 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Apr 2022 14:07:59 -0500 Subject: [PATCH 1006/1158] Remove AcquireConn and ReleaseConn Superseded by (*sql.Conn) Raw() --- stdlib/sql.go | 69 ------------------------------ stdlib/sql_test.go | 103 --------------------------------------------- 2 files changed, 172 deletions(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index e4c53ea7..e4565227 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -81,17 +81,10 @@ var databaseSQLResultFormats pgx.QueryResultFormatsByOID var pgxDriver *Driver -type ctxKey int - -var ctxKeyFakeTx ctxKey = 0 - -var ErrNotPgx = errors.New("not pgx *sql.DB") - func init() { pgxDriver = &Driver{ configs: make(map[string]*pgx.ConnConfig), } - fakeTxConns = make(map[*pgx.Conn]*sql.Tx) sql.Register("pgx", pgxDriver) databaseSQLResultFormats = pgx.QueryResultFormatsByOID{ @@ -111,11 +104,6 @@ func init() { } } -var ( - fakeTxMutex sync.Mutex - fakeTxConns map[*pgx.Conn]*sql.Tx -) - // OptionOpenDB options for configuring the driver when opening a new db pool. type OptionOpenDB func(*connector) @@ -367,11 +355,6 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e return nil, driver.ErrBadConn } - if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok { - *pconn = c.conn - return fakeTx{}, nil - } - var pgxOpts pgx.TxOptions switch sql.IsolationLevel(opts.Isolation) { case sql.LevelDefault: @@ -786,55 +769,3 @@ type wrapTx struct { func (wtx wrapTx) Commit() error { return wtx.tx.Commit(wtx.ctx) } func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(wtx.ctx) } - -type fakeTx struct{} - -func (fakeTx) Commit() error { return nil } - -func (fakeTx) Rollback() error { return nil } - -// AcquireConn acquires a *pgx.Conn from database/sql connection pool. It must be released with ReleaseConn. -// -// In Go 1.13 this functionality has been incorporated into the standard library in the db.Conn.Raw() method. -func AcquireConn(db *sql.DB) (*pgx.Conn, error) { - var conn *pgx.Conn - ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn) - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - if conn == nil { - tx.Rollback() - return nil, ErrNotPgx - } - - fakeTxMutex.Lock() - fakeTxConns[conn] = tx - fakeTxMutex.Unlock() - - return conn, nil -} - -// ReleaseConn releases a *pgx.Conn acquired with AcquireConn. -func ReleaseConn(db *sql.DB, conn *pgx.Conn) error { - var tx *sql.Tx - var ok bool - - if conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - conn.Close(ctx) - } - - fakeTxMutex.Lock() - tx, ok = fakeTxConns[conn] - if ok { - delete(fakeTxConns, conn) - fakeTxMutex.Unlock() - } else { - fakeTxMutex.Unlock() - return fmt.Errorf("can't release conn that is not acquired") - } - - return tx.Rollback() -} diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 30cea7d6..faa4a0cb 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -640,42 +640,6 @@ func TestBeginTxContextCancel(t *testing.T) { }) } -func TestAcquireConn(t *testing.T) { - testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { - var conns []*pgx.Conn - - for i := 1; i < 6; i++ { - conn, err := stdlib.AcquireConn(db) - if err != nil { - t.Errorf("%d. AcquireConn failed: %v", i, err) - continue - } - - var n int32 - err = conn.QueryRow(context.Background(), "select 1").Scan(&n) - if err != nil { - t.Errorf("%d. QueryRow failed: %v", i, err) - } - if n != 1 { - t.Errorf("%d. n => %d, want %d", i, n, 1) - } - - stats := db.Stats() - if stats.OpenConnections != i { - t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i) - } - - conns = append(conns, conn) - } - - for i, conn := range conns { - if err := stdlib.ReleaseConn(db, conn); err != nil { - t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err) - } - } - }) -} - func TestConnRaw(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { conn, err := db.Conn(context.Background()) @@ -691,38 +655,6 @@ func TestConnRaw(t *testing.T) { }) } -// https://github.com/jackc/pgx/issues/673 -func TestReleaseConnWithTxInProgress(t *testing.T) { - testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { - skipCockroachDB(t, db, "Server does not support backend PID") - - c1, err := stdlib.AcquireConn(db) - require.NoError(t, err) - - _, err = c1.Exec(context.Background(), "begin") - require.NoError(t, err) - - c1PID := c1.PgConn().PID() - - err = stdlib.ReleaseConn(db, c1) - require.NoError(t, err) - - c2, err := stdlib.AcquireConn(db) - require.NoError(t, err) - - c2PID := c2.PgConn().PID() - - err = stdlib.ReleaseConn(db, c2) - require.NoError(t, err) - - require.NotEqual(t, c1PID, c2PID) - - // Releasing a conn with a tx in progress should close the connection - stats := db.Stats() - require.Equal(t, 1, stats.OpenConnections) - }) -} - func TestConnPingContextSuccess(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { err := db.PingContext(context.Background()) @@ -746,23 +678,6 @@ func TestConnExecContextSuccess(t *testing.T) { }) } -func TestConnExecContextFailureRetry(t *testing.T) { - testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { - // We get a connection, immediately close it, and then get it back; - // DB.Conn along with Conn.ResetSession does the retry for us. - { - conn, err := stdlib.AcquireConn(db) - require.NoError(t, err) - conn.Close(context.Background()) - stdlib.ReleaseConn(db, conn) - } - conn, err := db.Conn(context.Background()) - require.NoError(t, err) - _, err = conn.ExecContext(context.Background(), "select 1") - require.NoError(t, err) - }) -} - func TestConnQueryContextSuccess(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") @@ -777,24 +692,6 @@ func TestConnQueryContextSuccess(t *testing.T) { }) } -func TestConnQueryContextFailureRetry(t *testing.T) { - testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { - // We get a connection, immediately close it, and then get it back; - // DB.Conn along with Conn.ResetSession does the retry for us. - { - conn, err := stdlib.AcquireConn(db) - require.NoError(t, err) - conn.Close(context.Background()) - stdlib.ReleaseConn(db, conn) - } - conn, err := db.Conn(context.Background()) - require.NoError(t, err) - - _, err = conn.QueryContext(context.Background(), "select 1") - require.NoError(t, err) - }) -} - func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { rows, err := db.Query("select 42::bigint") From 1c90746cf5789cbc166892501e09919bc03dd300 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Apr 2022 14:14:59 -0500 Subject: [PATCH 1007/1158] Update CHANGELOG --- CHANGELOG.md | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 49cc1279..0152bed1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,7 +33,7 @@ generally defined by implementing an interface that a particular `Codec` underst All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This significantly reduced the amount of code and the compiled binary size. This also means that less common array types such -as `point[]` are now supported. +as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional arrays. ### Composite Types @@ -42,8 +42,8 @@ values, but any type may now implement `CompositeIndexGetter` and `CompositeInde ### Range Types -Range types are now handled with generic types `RangeCodec[T]` and `Range[T]`. This allows additional user defined range -types to easily be handled. +Range types are now handled with types `RangeCodec` and `Range[T]`. This allows additional user defined range types to +easily be handled. ### pgxtype @@ -70,6 +70,9 @@ only `string` is handled. This is to allow the possibility of future binary supp considering `[]byte` to be binary format and `string` text format. This change should have no effect for any use with `pgx`. The previous behavior was only necessary for `lib/pq` compatibility. +Added `*Map.SQLScanner` to create a `sql.Scanner` for types such as `[]int32` and `Range[T]` that do not implement +`sql.Scanner` directly. + ### Number Type Fields Include Bit size `Int2`, `Int4`, `Int8`, `Float4`, `Float8`, and `Uint32` fields now include bit size. e.g. `Int` is renamed to `Int64`. @@ -95,6 +98,10 @@ This matches the convention set by `database/sql`. In addition, for comparable t * Renamed `pgtype.None` to `pgtype.Finite`. * `RegisterType` now accepts a `*Type` instead of `Type`. +## stdlib + +* Removed `AcquireConn` and `ReleaseConn` as that functionality has been built in since Go 1.13. + ## Reduced Memory Usage by Reusing Read Buffers Previously, the connection read buffer would allocate large chunks of memory and never reuse them. This allowed From cc7de81d3b8aede7edd3d6202007e424eaeb1651 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Apr 2022 14:21:40 -0500 Subject: [PATCH 1008/1158] Make array helpers private --- CHANGELOG.md | 1 + pgtype/array.go | 16 ++++++------- pgtype/array_codec.go | 10 ++++----- pgtype/array_test.go | 52 +++++++++++++++++++++---------------------- pgtype/doc.go | 2 +- 5 files changed, 40 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0152bed1..162a7f68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -97,6 +97,7 @@ This matches the convention set by `database/sql`. In addition, for comparable t * Renamed `pgtype.DataType` to `pgtype.Type`. * Renamed `pgtype.None` to `pgtype.Finite`. * `RegisterType` now accepts a `*Type` instead of `Type`. +* Assorted array helper methods and types made private. ## stdlib diff --git a/pgtype/array.go b/pgtype/array.go index d34a94e5..93f5aa9b 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -17,7 +17,7 @@ import ( // src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of // particular interest is the array_send function. -type ArrayHeader struct { +type arrayHeader struct { ContainsNull bool ElementOID uint32 Dimensions []ArrayDimension @@ -42,7 +42,7 @@ func cardinality(dimensions []ArrayDimension) int { return elementCount } -func (dst *ArrayHeader) DecodeBinary(m *Map, src []byte) (int, error) { +func (dst *arrayHeader) DecodeBinary(m *Map, src []byte) (int, error) { if len(src) < 12 { return 0, fmt.Errorf("array header too short: %d", len(src)) } @@ -73,7 +73,7 @@ func (dst *ArrayHeader) DecodeBinary(m *Map, src []byte) (int, error) { return rp, nil } -func (src ArrayHeader) EncodeBinary(buf []byte) []byte { +func (src arrayHeader) EncodeBinary(buf []byte) []byte { buf = pgio.AppendInt32(buf, int32(len(src.Dimensions))) var containsNull int32 @@ -92,14 +92,14 @@ func (src ArrayHeader) EncodeBinary(buf []byte) []byte { return buf } -type UntypedTextArray struct { +type untypedTextArray struct { Elements []string Quoted []bool Dimensions []ArrayDimension } -func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { - dst := &UntypedTextArray{ +func parseUntypedTextArray(src string) (*untypedTextArray, error) { + dst := &untypedTextArray{ Elements: []string{}, Quoted: []bool{}, Dimensions: []ArrayDimension{}, @@ -333,7 +333,7 @@ func arrayParseInteger(buf *bytes.Buffer) (int32, error) { } } -func EncodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte { +func encodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte { var customDimensions bool for _, dim := range dimensions { if dim.LowerBound != 1 { @@ -367,7 +367,7 @@ func isSpace(ch byte) bool { return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\f' } -func QuoteArrayElementIfNeeded(src string) string { +func quoteArrayElementIfNeeded(src string) string { if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { return quoteArrayElement(src) } diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 8aab13bb..8ed45da5 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -91,7 +91,7 @@ func (p *encodePlanArrayCodecText) Encode(value any, buf []byte) (newBuf []byte, return append(buf, '{', '}'), nil } - buf = EncodeTextArrayDimensions(buf, dimensions) + buf = encodeTextArrayDimensions(buf, dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -138,7 +138,7 @@ func (p *encodePlanArrayCodecText) Encode(value any, buf []byte) (newBuf []byte, if elemBuf == nil { buf = append(buf, `NULL`...) } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + buf = append(buf, quoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { @@ -165,7 +165,7 @@ func (p *encodePlanArrayCodecBinary) Encode(value any, buf []byte) (newBuf []byt return nil, nil } - arrayHeader := ArrayHeader{ + arrayHeader := arrayHeader{ Dimensions: dimensions, ElementOID: p.ac.ElementType.OID, } @@ -232,7 +232,7 @@ func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) Scan } func (c *ArrayCodec) decodeBinary(m *Map, arrayOID uint32, src []byte, array ArraySetter) error { - var arrayHeader ArrayHeader + var arrayHeader arrayHeader rp, err := arrayHeader.DecodeBinary(m, src) if err != nil { return err @@ -272,7 +272,7 @@ func (c *ArrayCodec) decodeBinary(m *Map, arrayOID uint32, src []byte, array Arr } func (c *ArrayCodec) decodeText(m *Map, arrayOID uint32, src []byte, array ArraySetter) error { - uta, err := ParseUntypedTextArray(string(src)) + uta, err := parseUntypedTextArray(string(src)) if err != nil { return err } diff --git a/pgtype/array_test.go b/pgtype/array_test.go index 8043e12f..f246b346 100644 --- a/pgtype/array_test.go +++ b/pgtype/array_test.go @@ -1,79 +1,77 @@ -package pgtype_test +package pgtype import ( "reflect" "testing" - - "github.com/jackc/pgx/v5/pgtype" ) func TestParseUntypedTextArray(t *testing.T) { tests := []struct { source string - result pgtype.UntypedTextArray + result untypedTextArray }{ { source: "{}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{}, Quoted: []bool{}, - Dimensions: []pgtype.ArrayDimension{}, + Dimensions: []ArrayDimension{}, }, }, { source: "{1}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"1"}, Quoted: []bool{false}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, { source: "{a,b}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"a", "b"}, Quoted: []bool{false, false}, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Dimensions: []ArrayDimension{{Length: 2, LowerBound: 1}}, }, }, { source: `{"NULL"}`, - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"NULL"}, Quoted: []bool{true}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, { source: `{""}`, - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{""}, Quoted: []bool{true}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, { source: `{"He said, \"Hello.\""}`, - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{`He said, "Hello."`}, Quoted: []bool{true}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, { source: "{{a,b},{c,d},{e,f}}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"a", "b", "c", "d", "e", "f"}, Quoted: []bool{false, false, false, false, false, false}, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Dimensions: []ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, }, }, { source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"}, Quoted: []bool{false, false, false, false, false, false, false, false, false, false, false, false}, - Dimensions: []pgtype.ArrayDimension{ + Dimensions: []ArrayDimension{ {Length: 2, LowerBound: 1}, {Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}, @@ -82,18 +80,18 @@ func TestParseUntypedTextArray(t *testing.T) { }, { source: "[4:4]={1}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"1"}, Quoted: []bool{false}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 4}}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 4}}, }, }, { source: "[4:5][2:3]={{a,b},{c,d}}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"a", "b", "c", "d"}, Quoted: []bool{false, false, false, false}, - Dimensions: []pgtype.ArrayDimension{ + Dimensions: []ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, @@ -101,16 +99,16 @@ func TestParseUntypedTextArray(t *testing.T) { }, { source: "[-4:-2]={1,2,3}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"1", "2", "3"}, Quoted: []bool{false, false, false}, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: -4}}, + Dimensions: []ArrayDimension{{Length: 3, LowerBound: -4}}, }, }, } for i, tt := range tests { - r, err := pgtype.ParseUntypedTextArray(tt.source) + r, err := parseUntypedTextArray(tt.source) if err != nil { t.Errorf("%d: %v", i, err) continue diff --git a/pgtype/doc.go b/pgtype/doc.go index 1de29bd2..9764aabf 100644 --- a/pgtype/doc.go +++ b/pgtype/doc.go @@ -13,7 +13,7 @@ pgtype automatically marshals and unmarshals data from json and jsonb PostgreSQL Array Support ArrayCodec implements support for arrays. If pgtype supports type T then it can easily support []T by registering an -ArrayCodec for the appropriate PostgreSQL OID. +ArrayCodec for the appropriate PostgreSQL OID. In addition, Array[T] type can support multi-dimensional arrays. Composite Support From c63f912615930b25fd5c496e779d015c3bfb67fe Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 21 Apr 2022 19:19:32 -0500 Subject: [PATCH 1009/1158] Hstore.Set accepts map[string]Text --- hstore.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hstore.go b/hstore.go index f46eeaf6..706a3964 100644 --- a/hstore.go +++ b/hstore.go @@ -50,6 +50,8 @@ func (dst *Hstore) Set(src interface{}) error { } } *dst = Hstore{Map: m, Status: Present} + case map[string]Text: + *dst = Hstore{Map: value, Status: Present} default: return fmt.Errorf("cannot convert %v to Hstore", src) } From 1b244eec5da4234408cc3241e412b23be8541d4b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 21 Apr 2022 19:48:43 -0500 Subject: [PATCH 1010/1158] Upgrade to pgproto3 v2.3.0 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 2a2d6810..aaf7a486 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.2.1-0.20220412121321-175856ffd3c8 + github.com/jackc/pgproto3/v2 v2.3.0 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.7.0 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 diff --git a/go.sum b/go.sum index c558564b..a3834fd2 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwX github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.2.1-0.20220412121321-175856ffd3c8 h1:KxsCQec+1iwJXtxnbbS/dY0EJ6rJEUlFsrJUnL5A2XI= github.com/jackc/pgproto3/v2 v2.2.1-0.20220412121321-175856ffd3c8/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.3.0 h1:brH0pCGBDkBW07HWlN/oSBXrmo3WB0UvZd1pIuDcL8Y= +github.com/jackc/pgproto3/v2 v2.3.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= From 9bb49f990f1e563b88fc0772a6a928b14ebd19f0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 21 Apr 2022 19:49:01 -0500 Subject: [PATCH 1011/1158] Release v1.12.0 --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a37eecfe..6df3ddcf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +# 1.12.0 (April 21, 2022) + +* Add pluggable GSSAPI support (Oliver Tan) +* Fix: Consider any "0A000" error a possible cached plan changed error due to locale +* Better match psql fallback behavior with multiple hosts + # 1.11.0 (February 7, 2022) * Support port in ip from LookupFunc to override config (James Hartig) From c5a0faca99221f74833133537e4b3d80f84d2cc6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 21 Apr 2022 19:58:17 -0500 Subject: [PATCH 1012/1158] Release v1.11.0 --- CHANGELOG.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 73126cf3..253f42c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +# 1.11.0 (April 21, 2022) + +* Add multirange for numeric, int4, and int8 (Vu) +* JSONBArray now supports json.RawMessage (Jens Emil Schulz Østergaard) +* Add RecordArray (WGH) +* Add UnmarshalJSON to pgtype.Int2 +* Hstore.Set accepts map[string]Text + # 1.10.0 (February 7, 2022) * Normalize UTC timestamps to comply with stdlib (Torkel Rogstad) From 791176f4fe13e0bac5b88ddb527be062a792c9f7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Apr 2022 10:26:41 -0500 Subject: [PATCH 1013/1158] Add link to github.com/vgarvardt/pgx-google-uuid --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 159712dd..1f3cd185 100644 --- a/README.md +++ b/README.md @@ -157,6 +157,7 @@ pgerrcode contains constants for the PostgreSQL error codes. * [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid) * [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal) +* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid) ## Adapters for 3rd Party Loggers From 468b7932822617e691a0ce1d93db51dcfce509eb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Apr 2022 10:34:53 -0500 Subject: [PATCH 1014/1158] Skip tests with unsupported types on CockroachDB --- pgtype/array_codec_test.go | 8 +++++++- stdlib/sql_test.go | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index 65289d04..9da027e8 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -6,6 +6,7 @@ import ( pgx "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -70,7 +71,12 @@ func TestArrayCodecFlatArray(t *testing.T) { } func TestArrayCodecArray(t *testing.T) { - defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support multi-dimensional arrays") + } + + ctr.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { for i, tt := range []struct { expected any }{ diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index faa4a0cb..9106df62 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -398,6 +398,8 @@ func TestConnQueryScanArray(t *testing.T) { func TestConnQueryScanRange(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server does not support int4range") + m := pgtype.NewMap() var r pgtype.Range[pgtype.Int4] From 126b582f19125cc7bd07c651b6feec1f9370a425 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Apr 2022 11:10:04 -0500 Subject: [PATCH 1015/1158] Make range helpers private --- pgtype/range.go | 12 +++++------ pgtype/range_codec.go | 4 ++-- pgtype/range_test.go | 50 +++++++++++++++++++++---------------------- 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/pgtype/range.go b/pgtype/range.go index c775239d..8f408f9f 100644 --- a/pgtype/range.go +++ b/pgtype/range.go @@ -19,15 +19,15 @@ func (bt BoundType) String() string { return string(bt) } -type UntypedTextRange struct { +type untypedTextRange struct { Lower string Upper string LowerType BoundType UpperType BoundType } -func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { - utr := &UntypedTextRange{} +func parseUntypedTextRange(src string) (*untypedTextRange, error) { + utr := &untypedTextRange{} if src == "empty" { utr.LowerType = Empty utr.UpperType = Empty @@ -173,7 +173,7 @@ func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { } } -type UntypedBinaryRange struct { +type untypedBinaryRange struct { Lower []byte Upper []byte LowerType BoundType @@ -197,8 +197,8 @@ const upperInclusiveMask = 4 const lowerUnboundedMask = 8 const upperUnboundedMask = 16 -func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { - ubr := &UntypedBinaryRange{} +func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) { + ubr := &untypedBinaryRange{} if len(src) == 0 { return nil, fmt.Errorf("range too short: %v", len(src)) diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go index d50bd3cb..f4ed41b6 100644 --- a/pgtype/range_codec.go +++ b/pgtype/range_codec.go @@ -263,7 +263,7 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) erro return rangeScanner.ScanNull() } - ubr, err := ParseUntypedBinaryRange(src) + ubr, err := parseUntypedBinaryRange(src) if err != nil { return err } @@ -313,7 +313,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error return rangeScanner.ScanNull() } - utr, err := ParseUntypedTextRange(string(src)) + utr, err := parseUntypedTextRange(string(src)) if err != nil { return err } diff --git a/pgtype/range_test.go b/pgtype/range_test.go index 9e16df59..1ee8d553 100644 --- a/pgtype/range_test.go +++ b/pgtype/range_test.go @@ -8,68 +8,68 @@ import ( func TestParseUntypedTextRange(t *testing.T) { tests := []struct { src string - result UntypedTextRange + result untypedTextRange err error }{ { src: `[1,2)`, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `[1,2]`, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, + result: untypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, err: nil, }, { src: `(1,3)`, - result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, err: nil, }, { src: ` [1,2) `, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `[ foo , bar )`, - result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `["foo","bar")`, - result: UntypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `["f""oo","b""ar")`, - result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `["f""oo","b""ar")`, - result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `["","bar")`, - result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `[f\"oo\,,b\\ar\))`, - result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `empty`, - result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, + result: untypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, err: nil, }, } for i, tt := range tests { - r, err := ParseUntypedTextRange(tt.src) + r, err := parseUntypedTextRange(tt.src) if err != tt.err { t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) continue @@ -96,63 +96,63 @@ func TestParseUntypedTextRange(t *testing.T) { func TestParseUntypedBinaryRange(t *testing.T) { tests := []struct { src []byte - result UntypedBinaryRange + result untypedBinaryRange err error }{ { src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, err: nil, }, { src: []byte{1}, - result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, + result: untypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, err: nil, }, { src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, err: nil, }, { src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, err: nil, }, { src: []byte{8, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, + result: untypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, err: nil, }, { src: []byte{12, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, + result: untypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, err: nil, }, { src: []byte{16, 0, 0, 0, 2, 0, 4}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, err: nil, }, { src: []byte{18, 0, 0, 0, 2, 0, 4}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, err: nil, }, { src: []byte{24}, - result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, + result: untypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, err: nil, }, } for i, tt := range tests { - r, err := ParseUntypedBinaryRange(tt.src) + r, err := parseUntypedBinaryRange(tt.src) if err != tt.err { t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) continue From dfb681d716f0562b8cb28320dd762d7f64e8639d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Apr 2022 12:50:18 -0500 Subject: [PATCH 1016/1158] Build / rewrite / port multirange support --- pgtype/multirange.go | 443 ++++++++++++++++++++++++++++++++++++++ pgtype/multirange_test.go | 112 ++++++++++ pgtype/pgtype.go | 214 ++++++++++-------- 3 files changed, 677 insertions(+), 92 deletions(-) create mode 100644 pgtype/multirange.go create mode 100644 pgtype/multirange_test.go diff --git a/pgtype/multirange.go b/pgtype/multirange.go new file mode 100644 index 00000000..34950b34 --- /dev/null +++ b/pgtype/multirange.go @@ -0,0 +1,443 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// MultirangeGetter is a type that can be converted into a PostgreSQL multirange. +type MultirangeGetter interface { + // IsNull returns true if the value is SQL NULL. + IsNull() bool + + // Len returns the number of elements in the multirange. + Len() int + + // Index returns the element at i. + Index(i int) any + + // IndexType returns a non-nil scan target of the type Index will return. This is used by MultirangeCodec.PlanEncode. + IndexType() any +} + +// MultirangeSetter is a type can be set from a PostgreSQL multirange. +type MultirangeSetter interface { + // ScanNull sets the value to SQL NULL. + ScanNull() error + + // SetLen prepares the value such that ScanIndex can be called for each element. This will remove any existing + // elements. + SetLen(n int) error + + // ScanIndex returns a value usable as a scan target for i. SetLen must be called before ScanIndex. + ScanIndex(i int) any + + // ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by + // MultirangeCodec.PlanScan. + ScanIndexType() any +} + +// MultirangeCodec is a codec for any multirange type. +type MultirangeCodec struct { + ElementType *Type +} + +func (c *MultirangeCodec) FormatSupported(format int16) bool { + return c.ElementType.Codec.FormatSupported(format) +} + +func (c *MultirangeCodec) PreferredFormat() int16 { + return c.ElementType.Codec.PreferredFormat() +} + +func (c *MultirangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + multirangeValuer, ok := value.(MultirangeGetter) + if !ok { + return nil + } + + elementType := multirangeValuer.IndexType() + + elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType) + if elementEncodePlan == nil { + return nil + } + + switch format { + case BinaryFormatCode: + return &encodePlanMultirangeCodecBinary{ac: c, m: m, oid: oid} + case TextFormatCode: + return &encodePlanMultirangeCodecText{ac: c, m: m, oid: oid} + } + + return nil +} + +type encodePlanMultirangeCodecText struct { + ac *MultirangeCodec + m *Map + oid uint32 +} + +func (p *encodePlanMultirangeCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + multirange := value.(MultirangeGetter) + + if multirange.IsNull() { + return nil, nil + } + + elementCount := multirange.Len() + + buf = append(buf, '{') + + var encodePlan EncodePlan + var lastElemType reflect.Type + inElemBuf := make([]byte, 0, 32) + for i := 0; i < elementCount; i++ { + if i > 0 { + buf = append(buf, ',') + } + + elem := multirange.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", multirange.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, inElemBuf) + if err != nil { + return nil, err + } + } + + if elemBuf == nil { + return nil, fmt.Errorf("multirange cannot contain NULL element") + } else { + buf = append(buf, elemBuf...) + } + } + + buf = append(buf, '}') + + return buf, nil +} + +type encodePlanMultirangeCodecBinary struct { + ac *MultirangeCodec + m *Map + oid uint32 +} + +func (p *encodePlanMultirangeCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + multirange := value.(MultirangeGetter) + + if multirange.IsNull() { + return nil, nil + } + + elementCount := multirange.Len() + + buf = pgio.AppendInt32(buf, int32(elementCount)) + + var encodePlan EncodePlan + var lastElemType reflect.Type + for i := 0; i < elementCount; i++ { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elem := multirange.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", multirange.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, buf) + if err != nil { + return nil, err + } + } + + if elemBuf == nil { + return nil, fmt.Errorf("multirange cannot contain NULL element") + } else { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +func (c *MultirangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + multirangeScanner, ok := target.(MultirangeSetter) + if !ok { + return nil + } + + elementType := multirangeScanner.ScanIndexType() + + elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType) + if _, ok := elementScanPlan.(*scanPlanFail); ok { + return nil + } + + return &scanPlanMultirangeCodec{ + multirangeCodec: c, + m: m, + oid: oid, + formatCode: format, + } +} + +func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error { + rp := 0 + + elementCount := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + err := multirange.SetLen(elementCount) + if err != nil { + return err + } + + if elementCount == 0 { + return nil + } + + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0)) + if elementScanPlan == nil { + elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0)) + } + + for i := 0; i < elementCount; i++ { + elem := multirange.ScanIndex(i) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elementScanPlan.Scan(elemSrc, elem) + if err != nil { + return fmt.Errorf("failed to scan multirange element %d: %w", i, err) + } + } + + return nil +} + +func (c *MultirangeCodec) decodeText(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error { + elements, err := parseUntypedTextMultirange(src) + if err != nil { + return err + } + + err = multirange.SetLen(len(elements)) + if err != nil { + return err + } + + if len(elements) == 0 { + return nil + } + + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0)) + if elementScanPlan == nil { + elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0)) + } + + for i, s := range elements { + elem := multirange.ScanIndex(i) + err = elementScanPlan.Scan([]byte(s), elem) + if err != nil { + return err + } + } + + return nil +} + +type scanPlanMultirangeCodec struct { + multirangeCodec *MultirangeCodec + m *Map + oid uint32 + formatCode int16 + elementScanPlan ScanPlan +} + +func (spac *scanPlanMultirangeCodec) Scan(src []byte, dst any) error { + c := spac.multirangeCodec + m := spac.m + oid := spac.oid + formatCode := spac.formatCode + + multirange := dst.(MultirangeSetter) + + if src == nil { + return multirange.ScanNull() + } + + switch formatCode { + case BinaryFormatCode: + return c.decodeBinary(m, oid, src, multirange) + case TextFormatCode: + return c.decodeText(m, oid, src, multirange) + default: + return fmt.Errorf("unknown format code %d", formatCode) + } +} + +func (c *MultirangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +func (c *MultirangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var multirange Multirange[Range[any]] + err := m.PlanScan(oid, format, &multirange).Scan(src, &multirange) + return multirange, err +} + +func parseUntypedTextMultirange(src []byte) ([]string, error) { + elements := make([]string, 0) + + buf := bytes.NewBuffer(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != '{' { + return nil, fmt.Errorf("invalid multirange, expected '{': %v", err) + } + +parseValueLoop: + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid multirange: %v", err) + } + + switch r { + case ',': // skip range separator + case '}': + break parseValueLoop + default: + buf.UnreadRune() + value, err := parseRange(buf) + if err != nil { + return nil, fmt.Errorf("invalid multirange value: %v", err) + } + elements = append(elements, value) + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + return elements, nil + +} + +func parseRange(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + boundSepRead := false + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case ',', '}': + if r == ',' && !boundSepRead { + boundSepRead = true + break + } + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +// Multirange is a generic multirange type. +// +// T should implement RangeValuer and *T should implement RangeScanner. However, there does not appear to be a way to +// enforce the RangeScanner constraint. +type Multirange[T RangeValuer] []T + +func (r Multirange[T]) IsNull() bool { + return r == nil +} + +func (r Multirange[T]) Len() int { + return len(r) +} + +func (r Multirange[T]) Index(i int) any { + return r[i] +} + +func (r Multirange[T]) IndexType() any { + var zero T + return zero +} + +func (r *Multirange[T]) ScanNull() error { + *r = nil + return nil +} + +func (r *Multirange[T]) SetLen(n int) error { + *r = make([]T, n) + return nil +} + +func (r Multirange[T]) ScanIndex(i int) any { + return &r[i] +} + +func (r Multirange[T]) ScanIndexType() any { + return new(T) +} diff --git a/pgtype/multirange_test.go b/pgtype/multirange_test.go new file mode 100644 index 00000000..6d669b2b --- /dev/null +++ b/pgtype/multirange_test.go @@ -0,0 +1,112 @@ +package pgtype_test + +import ( + "context" + "reflect" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +func TestMultirangeCodecTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4multirange", []pgxtest.ValueRoundTripTest{ + { + pgtype.Multirange[pgtype.Range[pgtype.Int4]](nil), + new(pgtype.Multirange[pgtype.Range[pgtype.Int4]]), + func(a any) bool { return reflect.DeepEqual(pgtype.Multirange[pgtype.Range[pgtype.Int4]](nil), a) }, + }, + { + pgtype.Multirange[pgtype.Range[pgtype.Int4]]{}, + new(pgtype.Multirange[pgtype.Range[pgtype.Int4]]), + func(a any) bool { return reflect.DeepEqual(pgtype.Multirange[pgtype.Range[pgtype.Int4]]{}, a) }, + }, + { + pgtype.Multirange[pgtype.Range[pgtype.Int4]]{ + { + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + { + Lower: pgtype.Int4{Int32: 7, Valid: true}, + Upper: pgtype.Int4{Int32: 9, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + new(pgtype.Multirange[pgtype.Range[pgtype.Int4]]), + func(a any) bool { + return reflect.DeepEqual(pgtype.Multirange[pgtype.Range[pgtype.Int4]]{ + { + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + { + Lower: pgtype.Int4{Int32: 7, Valid: true}, + Upper: pgtype.Int4{Int32: 9, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, a) + }, + }, + }) +} + +func TestMultirangeCodecDecodeValue(t *testing.T) { + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + + for _, tt := range []struct { + sql string + expected any + }{ + { + sql: `select int4multirange(int4range(1, 5), int4range(7,9))`, + expected: pgtype.Multirange[pgtype.Range[any]]{ + { + Lower: int32(1), + Upper: int32(5), + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + { + Lower: int32(7), + Upper: int32(9), + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(ctx, tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 787eaead..bf58bb23 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -12,97 +12,109 @@ import ( // PostgreSQL oids for common types const ( - BoolOID = 16 - ByteaOID = 17 - QCharOID = 18 - NameOID = 19 - Int8OID = 20 - Int2OID = 21 - Int4OID = 23 - TextOID = 25 - OIDOID = 26 - TIDOID = 27 - XIDOID = 28 - CIDOID = 29 - JSONOID = 114 - JSONArrayOID = 199 - PointOID = 600 - LsegOID = 601 - PathOID = 602 - BoxOID = 603 - PolygonOID = 604 - LineOID = 628 - LineArrayOID = 629 - CIDROID = 650 - CIDRArrayOID = 651 - Float4OID = 700 - Float8OID = 701 - CircleOID = 718 - CircleArrayOID = 719 - UnknownOID = 705 - MacaddrOID = 829 - InetOID = 869 - BoolArrayOID = 1000 - QCharArrayOID = 1003 - NameArrayOID = 1003 - Int2ArrayOID = 1005 - Int4ArrayOID = 1007 - TextArrayOID = 1009 - TIDArrayOID = 1010 - ByteaArrayOID = 1001 - XIDArrayOID = 1011 - CIDArrayOID = 1012 - BPCharArrayOID = 1014 - VarcharArrayOID = 1015 - Int8ArrayOID = 1016 - PointArrayOID = 1017 - LsegArrayOID = 1018 - PathArrayOID = 1019 - BoxArrayOID = 1020 - Float4ArrayOID = 1021 - Float8ArrayOID = 1022 - PolygonArrayOID = 1027 - OIDArrayOID = 1028 - ACLItemOID = 1033 - ACLItemArrayOID = 1034 - MacaddrArrayOID = 1040 - InetArrayOID = 1041 - BPCharOID = 1042 - VarcharOID = 1043 - DateOID = 1082 - TimeOID = 1083 - TimestampOID = 1114 - TimestampArrayOID = 1115 - DateArrayOID = 1182 - TimeArrayOID = 1183 - TimestamptzOID = 1184 - TimestamptzArrayOID = 1185 - IntervalOID = 1186 - IntervalArrayOID = 1187 - NumericArrayOID = 1231 - BitOID = 1560 - BitArrayOID = 1561 - VarbitOID = 1562 - VarbitArrayOID = 1563 - NumericOID = 1700 - RecordOID = 2249 - RecordArrayOID = 2287 - UUIDOID = 2950 - UUIDArrayOID = 2951 - JSONBOID = 3802 - JSONBArrayOID = 3807 - DaterangeOID = 3912 - DaterangeArrayOID = 3913 - Int4rangeOID = 3904 - Int4rangeArrayOID = 3905 - NumrangeOID = 3906 - NumrangeArrayOID = 3907 - TsrangeOID = 3908 - TsrangeArrayOID = 3909 - TstzrangeOID = 3910 - TstzrangeArrayOID = 3911 - Int8rangeOID = 3926 - Int8rangeArrayOID = 3927 + BoolOID = 16 + ByteaOID = 17 + QCharOID = 18 + NameOID = 19 + Int8OID = 20 + Int2OID = 21 + Int4OID = 23 + TextOID = 25 + OIDOID = 26 + TIDOID = 27 + XIDOID = 28 + CIDOID = 29 + JSONOID = 114 + JSONArrayOID = 199 + PointOID = 600 + LsegOID = 601 + PathOID = 602 + BoxOID = 603 + PolygonOID = 604 + LineOID = 628 + LineArrayOID = 629 + CIDROID = 650 + CIDRArrayOID = 651 + Float4OID = 700 + Float8OID = 701 + CircleOID = 718 + CircleArrayOID = 719 + UnknownOID = 705 + MacaddrOID = 829 + InetOID = 869 + BoolArrayOID = 1000 + QCharArrayOID = 1003 + NameArrayOID = 1003 + Int2ArrayOID = 1005 + Int4ArrayOID = 1007 + TextArrayOID = 1009 + TIDArrayOID = 1010 + ByteaArrayOID = 1001 + XIDArrayOID = 1011 + CIDArrayOID = 1012 + BPCharArrayOID = 1014 + VarcharArrayOID = 1015 + Int8ArrayOID = 1016 + PointArrayOID = 1017 + LsegArrayOID = 1018 + PathArrayOID = 1019 + BoxArrayOID = 1020 + Float4ArrayOID = 1021 + Float8ArrayOID = 1022 + PolygonArrayOID = 1027 + OIDArrayOID = 1028 + ACLItemOID = 1033 + ACLItemArrayOID = 1034 + MacaddrArrayOID = 1040 + InetArrayOID = 1041 + BPCharOID = 1042 + VarcharOID = 1043 + DateOID = 1082 + TimeOID = 1083 + TimestampOID = 1114 + TimestampArrayOID = 1115 + DateArrayOID = 1182 + TimeArrayOID = 1183 + TimestamptzOID = 1184 + TimestamptzArrayOID = 1185 + IntervalOID = 1186 + IntervalArrayOID = 1187 + NumericArrayOID = 1231 + BitOID = 1560 + BitArrayOID = 1561 + VarbitOID = 1562 + VarbitArrayOID = 1563 + NumericOID = 1700 + RecordOID = 2249 + RecordArrayOID = 2287 + UUIDOID = 2950 + UUIDArrayOID = 2951 + JSONBOID = 3802 + JSONBArrayOID = 3807 + DaterangeOID = 3912 + DaterangeArrayOID = 3913 + Int4rangeOID = 3904 + Int4rangeArrayOID = 3905 + NumrangeOID = 3906 + NumrangeArrayOID = 3907 + TsrangeOID = 3908 + TsrangeArrayOID = 3909 + TstzrangeOID = 3910 + TstzrangeArrayOID = 3911 + Int8rangeOID = 3926 + Int8rangeArrayOID = 3927 + Int4multirangeOID = 4451 + NummultirangeOID = 4532 + TsmultirangeOID = 4533 + TstzmultirangeOID = 4534 + DatemultirangeOID = 4535 + Int8multirangeOID = 4536 + Int4multirangeArrayOID = 6150 + NummultirangeArrayOID = 6151 + TsmultirangeArrayOID = 6152 + TstzmultirangeArrayOID = 6153 + DatemultirangeArrayOID = 6155 + Int8multirangeArrayOID = 6157 ) type InfinityModifier int8 @@ -222,6 +234,7 @@ func NewMap() *Map { }, } + // Base types m.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) m.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) m.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) @@ -263,6 +276,7 @@ func NewMap() *Map { m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + // Range types m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}}) m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}}) m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}}) @@ -270,6 +284,15 @@ func NewMap() *Map { m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}}) m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}}) + // Multirange types + m.RegisterType(&Type{Name: "datemultirange", OID: DatemultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[DaterangeOID]}}) + m.RegisterType(&Type{Name: "int4multirange", OID: Int4multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int4rangeOID]}}) + m.RegisterType(&Type{Name: "int8multirange", OID: Int8multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int8rangeOID]}}) + m.RegisterType(&Type{Name: "nummultirange", OID: NummultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[NumrangeOID]}}) + m.RegisterType(&Type{Name: "tsmultirange", OID: TsmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TsrangeOID]}}) + m.RegisterType(&Type{Name: "tstzmultirange", OID: TstzmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TstzrangeOID]}}) + + // Array types m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}}) m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}}) m.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoolOID]}}) @@ -349,20 +372,25 @@ func NewMap() *Map { registerDefaultPgTypeVariants[Circle](m, "circle") registerDefaultPgTypeVariants[Date](m, "date") registerDefaultPgTypeVariants[Range[Date]](m, "daterange") + registerDefaultPgTypeVariants[Multirange[Range[Date]]](m, "datemultirange") registerDefaultPgTypeVariants[Float4](m, "float4") registerDefaultPgTypeVariants[Float8](m, "float8") - registerDefaultPgTypeVariants[Range[Float8]](m, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange. + registerDefaultPgTypeVariants[Range[Float8]](m, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange. + registerDefaultPgTypeVariants[Multirange[Range[Float8]]](m, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange. registerDefaultPgTypeVariants[Inet](m, "inet") registerDefaultPgTypeVariants[Int2](m, "int2") registerDefaultPgTypeVariants[Int4](m, "int4") registerDefaultPgTypeVariants[Range[Int4]](m, "int4range") + registerDefaultPgTypeVariants[Multirange[Range[Int4]]](m, "int4multirange") registerDefaultPgTypeVariants[Int8](m, "int8") registerDefaultPgTypeVariants[Range[Int8]](m, "int8range") + registerDefaultPgTypeVariants[Multirange[Range[Int8]]](m, "int8multirange") registerDefaultPgTypeVariants[Interval](m, "interval") registerDefaultPgTypeVariants[Line](m, "line") registerDefaultPgTypeVariants[Lseg](m, "lseg") registerDefaultPgTypeVariants[Numeric](m, "numeric") registerDefaultPgTypeVariants[Range[Numeric]](m, "numrange") + registerDefaultPgTypeVariants[Multirange[Range[Numeric]]](m, "nummultirange") registerDefaultPgTypeVariants[Path](m, "path") registerDefaultPgTypeVariants[Point](m, "point") registerDefaultPgTypeVariants[Polygon](m, "polygon") @@ -372,7 +400,9 @@ func NewMap() *Map { registerDefaultPgTypeVariants[Timestamp](m, "timestamp") registerDefaultPgTypeVariants[Timestamptz](m, "timestamptz") registerDefaultPgTypeVariants[Range[Timestamp]](m, "tsrange") + registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](m, "tsmultirange") registerDefaultPgTypeVariants[Range[Timestamptz]](m, "tstzrange") + registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](m, "tstzmultirange") registerDefaultPgTypeVariants[UUID](m, "uuid") return m From f9857b73d9334cd74af0982cd45ab0e7282b18a1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Apr 2022 16:55:24 -0500 Subject: [PATCH 1017/1158] Skip multirange tests on PG < 14 --- pgtype/multirange_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/multirange_test.go b/pgtype/multirange_test.go index 6d669b2b..77273e59 100644 --- a/pgtype/multirange_test.go +++ b/pgtype/multirange_test.go @@ -12,6 +12,7 @@ import ( ) func TestMultirangeCodecTranscode(t *testing.T) { + skipPostgreSQLVersionLessThan(t, 14) skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4multirange", []pgxtest.ValueRoundTripTest{ @@ -66,6 +67,7 @@ func TestMultirangeCodecTranscode(t *testing.T) { } func TestMultirangeCodecDecodeValue(t *testing.T) { + skipPostgreSQLVersionLessThan(t, 14) skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { From b72b0daa5a8888f69562bab6674bd7f8ac4c99c4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Apr 2022 17:26:42 -0500 Subject: [PATCH 1018/1158] Add QueryRewriter interface --- batch.go | 2 +- batch_test.go | 30 ++++++++++++++++++++++++++++++ conn.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ conn_test.go | 19 +++++++++++++++++++ query_test.go | 20 ++++++++++++++++++++ 5 files changed, 115 insertions(+), 1 deletion(-) diff --git a/batch.go b/batch.go index 103d9aed..98f216dd 100644 --- a/batch.go +++ b/batch.go @@ -14,7 +14,7 @@ type batchItem struct { } // Batch queries are a way of bundling multiple queries together to avoid -// unnecessary network round trips. +// unnecessary network round trips. A Batch must only be sent once. type Batch struct { items []*batchItem } diff --git a/batch_test.go b/batch_test.go index 5558b823..96cf61c2 100644 --- a/batch_test.go +++ b/batch_test.go @@ -239,6 +239,36 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) { }) } +func TestConnSendBatchWithQueryRewriter(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("something to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{1}}) + batch.Queue("something else to be replaced", &testQueryRewriter{sql: "select $1::text", args: []any{"hello"}}) + batch.Queue("more to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{3}}) + + br := conn.SendBatch(context.Background(), batch) + + var n int32 + err := br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + var s string + err = br.QueryRow().Scan(&s) + require.NoError(t, err) + require.Equal(t, "hello", s) + + err = br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 3, n) + + err = br.Close() + require.NoError(t, err) + }) +} + // https://github.com/jackc/pgx/issues/856 func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) { t.Parallel() diff --git a/conn.go b/conn.go index d0c9fe33..dd4a7301 100644 --- a/conn.go +++ b/conn.go @@ -404,6 +404,7 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.C func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { mode := c.config.DefaultQueryExecMode + var queryRewriter QueryRewriter optionLoop: for len(arguments) > 0 { @@ -411,11 +412,18 @@ optionLoop: case QueryExecMode: mode = arg arguments = arguments[1:] + case QueryRewriter: + queryRewriter = arg + arguments = arguments[1:] default: break optionLoop } } + if queryRewriter != nil { + sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments) + } + // Always use simple protocol when there are no arguments. if len(arguments) == 0 { mode = QueryExecModeSimpleProtocol @@ -682,6 +690,11 @@ type QueryResultFormats []int16 // QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID. type QueryResultFormatsByOID map[uint32]int16 +// QueryRewriter rewrites a query when used as the first arguments to a query method. +type QueryRewriter interface { + RewriteQuery(ctx context.Context, conn *Conn, sql string, args ...any) (newSQL string, newArgs []any) +} + // Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The // error will be the available in rows.Err() after rows are closed. So it is allowed to ignore the error returned from // Query and handle it in Rows. @@ -696,6 +709,7 @@ func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) var resultFormats QueryResultFormats var resultFormatsByOID QueryResultFormatsByOID mode := c.config.DefaultQueryExecMode + var queryRewriter QueryRewriter optionLoop: for len(args) > 0 { @@ -709,11 +723,18 @@ optionLoop: case QueryExecMode: mode = arg args = args[1:] + case QueryRewriter: + queryRewriter = arg + args = args[1:] default: break optionLoop } } + if queryRewriter != nil { + sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args) + } + c.eqb.Reset() anynil.NormalizeSlice(args) rows := c.getRows(ctx, sql, args) @@ -883,6 +904,30 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []any, scans []an func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { mode := c.config.DefaultQueryExecMode + for _, bi := range b.items { + var queryRewriter QueryRewriter + sql := bi.query + arguments := bi.arguments + + optionLoop: + for len(arguments) > 0 { + switch arg := arguments[0].(type) { + case QueryRewriter: + queryRewriter = arg + arguments = arguments[1:] + default: + break optionLoop + } + } + + if queryRewriter != nil { + sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments) + } + + bi.query = sql + bi.arguments = arguments + } + if mode == QueryExecModeSimpleProtocol { var sb strings.Builder for i, bi := range b.items { diff --git a/conn_test.go b/conn_test.go index 61f6c951..392ea623 100644 --- a/conn_test.go +++ b/conn_test.go @@ -230,6 +230,25 @@ func TestExec(t *testing.T) { }) } +type testQueryRewriter struct { + sql string + args []any +} + +func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args ...any) (newSQL string, newArgs []any) { + return qr.sql, qr.args +} + +func TestExecWithQueryRewriter(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + qr := testQueryRewriter{sql: "select $1::int", args: []any{42}} + _, err := conn.Exec(ctx, "should be replaced", &qr) + require.NoError(t, err) + }) +} + func TestExecFailure(t *testing.T) { t.Parallel() diff --git a/query_test.go b/query_test.go index 0e310eef..78cacb6c 100644 --- a/query_test.go +++ b/query_test.go @@ -1864,6 +1864,26 @@ func TestQueryErrorWithDisabledStatementCache(t *testing.T) { ensureConnValid(t, conn) } +func TestQueryWithQueryRewriter(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + qr := testQueryRewriter{sql: "select $1::int", args: []any{42}} + rows, err := conn.Query(ctx, "should be replaced", &qr) + require.NoError(t, err) + + var n int32 + var rowCount int + for rows.Next() { + rowCount++ + err = rows.Scan(&n) + require.NoError(t, err) + } + + require.NoError(t, rows.Err()) + }) +} + func TestConnQueryFunc(t *testing.T) { t.Parallel() From 107196ab0c90528a41a9c2910e158763ac320cd4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Apr 2022 18:43:04 -0500 Subject: [PATCH 1019/1158] Add NamedArgs https://github.com/jackc/pgx/issues/1186 https://github.com/jackc/pgx/issues/387 --- conn.go | 2 +- conn_test.go | 2 +- named_args.go | 266 +++++++++++++++++++++++++++++++++++++++++++++ named_args_test.go | 96 ++++++++++++++++ 4 files changed, 364 insertions(+), 2 deletions(-) create mode 100644 named_args.go create mode 100644 named_args_test.go diff --git a/conn.go b/conn.go index dd4a7301..72154325 100644 --- a/conn.go +++ b/conn.go @@ -692,7 +692,7 @@ type QueryResultFormatsByOID map[uint32]int16 // QueryRewriter rewrites a query when used as the first arguments to a query method. type QueryRewriter interface { - RewriteQuery(ctx context.Context, conn *Conn, sql string, args ...any) (newSQL string, newArgs []any) + RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) } // Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The diff --git a/conn_test.go b/conn_test.go index 392ea623..675fba17 100644 --- a/conn_test.go +++ b/conn_test.go @@ -235,7 +235,7 @@ type testQueryRewriter struct { args []any } -func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args ...any) (newSQL string, newArgs []any) { +func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) { return qr.sql, qr.args } diff --git a/named_args.go b/named_args.go new file mode 100644 index 00000000..e6906b3b --- /dev/null +++ b/named_args.go @@ -0,0 +1,266 @@ +package pgx + +import ( + "context" + "strconv" + "strings" + "unicode/utf8" +) + +// NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$' +// ordinal placeholder and construct the appropriate arguments. +// +// For example, the following two queries are equivalent: +// +// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})) +// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2})) +type NamedArgs map[string]any + +// RewriteQuery implements the QueryRewriter interface. +func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) { + l := &sqlLexer{ + src: sql, + stateFn: rawState, + nameToOrdinal: make(map[namedArg]int, len(na)), + } + + for l.stateFn != nil { + l.stateFn = l.stateFn(l) + } + + sb := strings.Builder{} + for _, p := range l.parts { + switch p := p.(type) { + case string: + sb.WriteString(p) + case namedArg: + sb.WriteRune('$') + sb.WriteString(strconv.Itoa(l.nameToOrdinal[p])) + } + } + + newArgs = make([]any, len(l.nameToOrdinal)) + for name, ordinal := range l.nameToOrdinal { + newArgs[ordinal-1] = na[string(name)] + } + + return sb.String(), newArgs +} + +type namedArg string + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. + stateFn stateFn + parts []any + + nameToOrdinal map[namedArg]int +} + +type stateFn func(*sqlLexer) stateFn + +func rawState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case 'e', 'E': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '\'' { + l.pos += width + return escapeStringState + } + case '\'': + return singleQuoteState + case '"': + return doubleQuoteState + case '@': + nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) + if isLetter(nextRune) { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos-width]) + } + l.start = l.pos + return namedArgState + } + case '-': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '-' { + l.pos += width + return oneLineCommentState + } + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + return multilineCommentState + } + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func isLetter(r rune) bool { + return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') +} + +func namedArgState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + if r == utf8.RuneError { + if l.pos-l.start > 0 { + na := namedArg(l.src[l.start:l.pos]) + if _, found := l.nameToOrdinal[na]; !found { + l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 + } + l.parts = append(l.parts, na) + l.start = l.pos + } + return nil + } else if !(isLetter(r) || (r >= '0' && r <= '9')) { + l.pos -= width + na := namedArg(l.src[l.start:l.pos]) + if _, found := l.nameToOrdinal[na]; !found { + l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 + } + l.parts = append(l.parts, namedArg(na)) + l.start = l.pos + return rawState + } + } +} + +func singleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func doubleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '"': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '"' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func escapeStringState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func oneLineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\n', '\r': + return rawState + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func multilineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + l.nested++ + } + case '*': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '/' { + continue + } + + l.pos += width + if l.nested == 0 { + return rawState + } + l.nested-- + + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} diff --git a/named_args_test.go b/named_args_test.go new file mode 100644 index 00000000..fea3b897 --- /dev/null +++ b/named_args_test.go @@ -0,0 +1,96 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/assert" +) + +func TestNamedArgsRewriteQuery(t *testing.T) { + t.Parallel() + + for i, tt := range []struct { + sql string + args []any + namedArgs pgx.NamedArgs + expectedSQL string + expectedArgs []any + }{ + { + sql: "select * from users where id = @id", + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: "select * from users where id = $1", + expectedArgs: []any{int32(42)}, + }, + { + sql: "select * from t where foo < @abc and baz = @def and bar < @abc", + namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)}, + expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1", + expectedArgs: []any{int32(42), int32(1)}, + }, + { + sql: "select @a::int, @b::text", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "select $1::int, $2::text", + expectedArgs: []any{int32(42), "foo"}, + }, + { + sql: "at end @", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "at end @", + expectedArgs: []any{}, + }, + { + sql: "ignores without letter after @ foo bar", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "ignores without letter after @ foo bar", + expectedArgs: []any{}, + }, + { + sql: "name must start with letter @1 foo bar", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "name must start with letter @1 foo bar", + expectedArgs: []any{}, + }, + { + sql: `select *, '@foo' as "@bar" from users where id = @id`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select *, '@foo' as "@bar" from users where id = $1`, + expectedArgs: []any{int32(42)}, + }, + { + sql: `select * -- @foo + from users -- @single line comments + where id = @id;`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select * -- @foo + from users -- @single line comments + where id = $1;`, + expectedArgs: []any{int32(42)}, + }, + { + sql: `select * /* @multi line + @comment + */ + /* /* with @nesting */ */ + from users + where id = @id;`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select * /* @multi line + @comment + */ + /* /* with @nesting */ */ + from users + where id = $1;`, + expectedArgs: []any{int32(42)}, + }, + + // test comments and quotes + } { + sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args) + assert.Equalf(t, tt.expectedSQL, sql, "%d", i) + assert.Equalf(t, tt.expectedArgs, args, "%d", i) + } +} From c093c4af215d4166def41fe2b274e9cf92e44a78 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Apr 2022 18:56:38 -0500 Subject: [PATCH 1020/1158] Update changelog --- CHANGELOG.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 162a7f68..551c4f35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,7 +43,7 @@ values, but any type may now implement `CompositeIndexGetter` and `CompositeInde ### Range Types Range types are now handled with types `RangeCodec` and `Range[T]`. This allows additional user defined range types to -easily be handled. +easily be handled. Multirange types are handled similarly with `MultirangeCodec` and `Multirange[T]`. ### pgxtype @@ -116,6 +116,10 @@ ownership remains with the read buffer and anything needing to retain a value mu Control over automatic prepared statement caching and simple protocol use are now combined into query execution mode. See documentation for `QueryExecMode`. +## QueryRewriter Interface and NamedArgs + +pgx now supports named arguments with the NamedArgs type. This is implemented via the new QueryRewriter interface which +allows arbitrary rewriting of query SQL and arguments. ## 3rd Party Logger Integration All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency From 53266f029fbb23a31220663ac094869ed4701a0f Mon Sep 17 00:00:00 2001 From: Diego Becciolini Date: Mon, 25 Apr 2022 12:53:15 +0100 Subject: [PATCH 1021/1158] Hstore: fix AssignTo Hstore.AssignTo a map of string pointers takes the address of the loop variable, thus setting all the entries to the same string pointer. extend TestHstoreAssignToNullable assert fix --- hstore.go | 3 ++- hstore_test.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/hstore.go b/hstore.go index 706a3964..d21af7bc 100644 --- a/hstore.go +++ b/hstore.go @@ -90,7 +90,8 @@ func (src *Hstore) AssignTo(dst interface{}) error { case Null: (*v)[k] = nil case Present: - (*v)[k] = &val.String + str := val.String + (*v)[k] = &str default: return fmt.Errorf("cannot decode %#v into %T", src, dst) } diff --git a/hstore_test.go b/hstore_test.go index 73ee0612..32a8f015 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -181,13 +181,14 @@ func TestHstoreAssignTo(t *testing.T) { func TestHstoreAssignToNullable(t *testing.T) { var m map[string]*string + strPtr := func(str string) *string { return &str } simpleTests := []struct { src pgtype.Hstore dst *map[string]*string expected map[string]*string }{ - {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {Status: pgtype.Null}}, Status: pgtype.Present}, dst: &m, expected: map[string]*string{"foo": nil}}, + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {Status: pgtype.Null}, "bar": {String: "1", Status: pgtype.Present}, "baz": {String: "2", Status: pgtype.Present}}, Status: pgtype.Present}, dst: &m, expected: map[string]*string{"foo": nil, "bar": strPtr("1"), "baz": strPtr("2")}}, {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]*string)(nil))}, } From d846dbcb75b2ac38a1c6fd0390cd18472fe72dee Mon Sep 17 00:00:00 2001 From: Harmen Date: Sun, 24 Apr 2022 08:03:31 +0200 Subject: [PATCH 1022/1158] allow string values in timestamp[tz].Set() --- date.go | 4 ++-- timestamp.go | 8 ++++++++ timestamp_test.go | 1 + timestamptz.go | 8 ++++++++ timestamptz_test.go | 1 + 5 files changed, 20 insertions(+), 2 deletions(-) diff --git a/date.go b/date.go index e8d21a78..ca84970e 100644 --- a/date.go +++ b/date.go @@ -37,14 +37,14 @@ func (dst *Date) Set(src interface{}) error { switch value := src.(type) { case time.Time: *dst = Date{Time: value, Status: Present} - case string: - return dst.DecodeText(nil, []byte(value)) case *time.Time: if value == nil { *dst = Date{Status: Null} } else { return dst.Set(*value) } + case string: + return dst.DecodeText(nil, []byte(value)) case *string: if value == nil { *dst = Date{Status: Null} diff --git a/timestamp.go b/timestamp.go index 5517acb1..e043726d 100644 --- a/timestamp.go +++ b/timestamp.go @@ -46,6 +46,14 @@ func (dst *Timestamp) Set(src interface{}) error { } else { return dst.Set(*value) } + case string: + return dst.DecodeText(nil, []byte(value)) + case *string: + if value == nil { + *dst = Timestamp{Status: Null} + } else { + return dst.Set(*value) + } case InfinityModifier: *dst = Timestamp{InfinityModifier: value, Status: Present} default: diff --git a/timestamp_test.go b/timestamp_test.go index ea7ef57a..d818d4f6 100644 --- a/timestamp_test.go +++ b/timestamp_test.go @@ -123,6 +123,7 @@ func TestTimestampSet(t *testing.T) { {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: pgtype.Infinity, result: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, {source: pgtype.NegativeInfinity, result: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + {source: "2001-04-05 06:07:08", result: pgtype.Timestamp{Time: time.Date(2001, 4, 5, 6, 7, 8, 0, time.UTC), Status: pgtype.Present}}, } for i, tt := range successfulTests { diff --git a/timestamptz.go b/timestamptz.go index 58701970..72ae4991 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -48,6 +48,14 @@ func (dst *Timestamptz) Set(src interface{}) error { } else { return dst.Set(*value) } + case string: + return dst.DecodeText(nil, []byte(value)) + case *string: + if value == nil { + *dst = Timestamptz{Status: Null} + } else { + return dst.Set(*value) + } case InfinityModifier: *dst = Timestamptz{InfinityModifier: value, Status: Present} default: diff --git a/timestamptz_test.go b/timestamptz_test.go index 2ff326bb..d6a3f518 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -120,6 +120,7 @@ func TestTimestamptzSet(t *testing.T) { {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: pgtype.Infinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, {source: pgtype.NegativeInfinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + {source: "2020-04-05 06:07:08Z", result: pgtype.Timestamptz{Time: time.Date(2020, 4, 5, 6, 7, 8, 0, time.UTC), Status: pgtype.Present}}, } for i, tt := range successfulTests { From d13bdbbd3511a163b5037f1c0310ff4ce302e5c0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 25 Apr 2022 10:16:47 -0500 Subject: [PATCH 1023/1158] NamedArgs allows underscore --- named_args.go | 2 +- named_args_test.go | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/named_args.go b/named_args.go index e6906b3b..3d91367b 100644 --- a/named_args.go +++ b/named_args.go @@ -128,7 +128,7 @@ func namedArgState(l *sqlLexer) stateFn { l.start = l.pos } return nil - } else if !(isLetter(r) || (r >= '0' && r <= '9')) { + } else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') { l.pos -= width na := namedArg(l.src[l.start:l.pos]) if _, found := l.nameToOrdinal[na]; !found { diff --git a/named_args_test.go b/named_args_test.go index fea3b897..116e03dc 100644 --- a/named_args_test.go +++ b/named_args_test.go @@ -36,6 +36,12 @@ func TestNamedArgsRewriteQuery(t *testing.T) { expectedSQL: "select $1::int, $2::text", expectedArgs: []any{int32(42), "foo"}, }, + { + sql: "select @Abc::int, @b_4::text", + namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo"}, + expectedSQL: "select $1::int, $2::text", + expectedArgs: []any{int32(42), "foo"}, + }, { sql: "at end @", namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, From 7427820abac0fa694e0d522882a56e67b78fa5ae Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 26 Apr 2022 08:37:10 -0500 Subject: [PATCH 1024/1158] Scan binary UUID to string https://github.com/jackc/pgx/issues/1191 --- pgtype/uuid.go | 21 +++++++++++++++++++++ pgtype/uuid_test.go | 5 +++++ 2 files changed, 26 insertions(+) diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 8c3bbba5..96a4c32f 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -173,6 +173,8 @@ func (UUIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan switch target.(type) { case UUIDScanner: return scanPlanBinaryUUIDToUUIDScanner{} + case TextScanner: + return scanPlanBinaryUUIDToTextScanner{} } case TextFormatCode: switch target.(type) { @@ -203,6 +205,25 @@ func (scanPlanBinaryUUIDToUUIDScanner) Scan(src []byte, dst any) error { return scanner.ScanUUID(uuid) } +type scanPlanBinaryUUIDToTextScanner struct{} + +func (scanPlanBinaryUUIDToTextScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) + } + + var buf [16]byte + copy(buf[:], src) + + return scanner.ScanText(Text{String: encodeUUID(buf), Valid: true}) +} + type scanPlanTextAnyToUUIDScanner struct{} func (scanPlanTextAnyToUUIDScanner) Scan(src []byte, dst any) error { diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 06ff38c2..2dc258b1 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -27,6 +27,11 @@ func TestUUIDCodec(t *testing.T) { new(pgtype.UUID), isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), }, + { + pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + new(string), + isExpectedEq("00010203-0405-0607-0809-0a0b0c0d0e0f"), + }, {pgtype.UUID{}, new([]byte), isExpectedEqBytes([]byte(nil))}, {pgtype.UUID{}, new(pgtype.UUID), isExpectedEq(pgtype.UUID{})}, {nil, new(pgtype.UUID), isExpectedEq(pgtype.UUID{})}, From 84e8238fa074e88b15e2737f4d8a0a8abb1ae6bf Mon Sep 17 00:00:00 2001 From: sireax Date: Mon, 25 Apr 2022 18:37:41 +0400 Subject: [PATCH 1025/1158] Fix: setting krbspn and krbsrvname did'n work --- config.go | 12 +++++++++++- krb5.go | 8 ++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/config.go b/config.go index 6e6930ee..859672ea 100644 --- a/config.go +++ b/config.go @@ -41,7 +41,9 @@ type Config struct { BuildFrontend BuildFrontendFunc RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) - Fallbacks []*FallbackConfig + KerberosSrvName string + KerberosSpn string + Fallbacks []*FallbackConfig // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next @@ -265,6 +267,14 @@ func ParseConfig(connString string) (*Config, error) { "servicefile": {}, } + // Adding kerberos configuration + if _, present := settings["krbsrvname"]; present { + config.KerberosSrvName = settings["krbsrvname"] + } + if _, present := settings["krbspn"]; present { + config.KerberosSpn = settings["krbspn"] + } + for k, v := range settings { if _, present := notRuntimeParams[k]; present { continue diff --git a/krb5.go b/krb5.go index 1f9ce97c..f2dbe45a 100644 --- a/krb5.go +++ b/krb5.go @@ -41,14 +41,14 @@ func (c *PgConn) gssAuth() error { } var nextData []byte - if spn, ok := c.config.RuntimeParams["krbspn"]; ok { + if c.config.KerberosSpn != "" { // Use the supplied SPN if provided. - nextData, err = cli.GetInitTokenFromSPN(spn) + nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn) } else { // Allow the kerberos service name to be overridden service := "postgres" - if val, ok := c.config.RuntimeParams["krbsrvname"]; ok { - service = val + if c.config.KerberosSrvName != "" { + service = c.config.KerberosSrvName } nextData, err = cli.GetInitToken(c.config.Host, service) } From 0c6266ef3075d65c9ec3883a0251b79f9a7371ac Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 26 Apr 2022 14:52:01 -0500 Subject: [PATCH 1026/1158] Fix scanning null did not overwrite slice --- pgtype/array.go | 2 +- query_test.go | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pgtype/array.go b/pgtype/array.go index 93f5aa9b..0fa4c129 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -463,7 +463,7 @@ func (a FlatArray[T]) IndexType() any { func (a *FlatArray[T]) SetDimensions(dimensions []ArrayDimension) error { if dimensions == nil { - a = nil + *a = nil return nil } diff --git a/query_test.go b/query_test.go index 78cacb6c..7fa507b0 100644 --- a/query_test.go +++ b/query_test.go @@ -1094,6 +1094,21 @@ func TestReadingNullByteArrays(t *testing.T) { } } +func TestQueryNullSliceIsSet(t *testing.T) { + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + a := []int32{1, 2, 3} + err := conn.QueryRow(context.Background(), "select null::int[]").Scan(&a) + if err != nil { + t.Fatalf("conn.QueryRow failed: %v", err) + } + + if a != nil { + t.Errorf("Expected 'a' to be nil, but it was: %v", a) + } +} + func TestConnQueryDatabaseSQLScanner(t *testing.T) { t.Parallel() From 81d55568f650ada744172ca76905c073beb4f168 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 28 Apr 2022 08:13:44 -0500 Subject: [PATCH 1027/1158] Clarify v5 supported Go version plans fixes https://github.com/jackc/pgx/issues/1197 --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4448eb23..3cc5a91b 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,9 @@ In addition, there are tests specific for PgBouncer that will be executed if `PG ## Supported Go and PostgreSQL Versions -pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.17 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). +~~pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.17 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).~~ + +`v5` is targeted at Go 1.18+. The general release of `v5` is not planned until second half of 2022 so it is expected that the policy of supporting the two most recent versions of Go will be maintained or restored soon after its release. ## Version Policy From a89a400b6998f5831bbea9e29301dd54f1203861 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Apr 2022 08:27:57 -0500 Subject: [PATCH 1028/1158] Fix documentation for Rows.RawValues and test new behavior --- conn.go | 4 ++-- conn_test.go | 32 ++++++++++++++++++++++++++++++++ rows.go | 4 ++-- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 5819611a..ec029ace 100644 --- a/conn.go +++ b/conn.go @@ -864,8 +864,8 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { type QueryFuncRow interface { FieldDescriptions() []pgproto3.FieldDescription - // RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid during the current - // function call. However, the underlying byte data is safe to retain a reference to and mutate. + // RawValues returns the unparsed bytes of the row values. The returned data is only valid during the current + // function call. RawValues() [][]byte } diff --git a/conn_test.go b/conn_test.go index 5115d76e..79697cbd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "bytes" "context" "os" "strings" @@ -1053,3 +1054,34 @@ func TestInsertDurationInterval(t *testing.T) { require.EqualValues(t, 1, n) }) } + +func TestRawValuesUnderlyingMemoryReused(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var buf []byte + + rows, err := conn.Query(ctx, `select 1::int`) + require.NoError(t, err) + + for rows.Next() { + buf = rows.RawValues()[0] + } + + require.NoError(t, rows.Err()) + + original := make([]byte, len(buf)) + copy(original, buf) + + for i := 0; i < 1_000_000; i++ { + rows, err := conn.Query(ctx, `select $1::int`, i) + require.NoError(t, err) + rows.Close() + require.NoError(t, rows.Err()) + + if bytes.Compare(original, buf) != 0 { + return + } + } + + t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not") + }) +} diff --git a/rows.go b/rows.go index da5773e0..b757b1c7 100644 --- a/rows.go +++ b/rows.go @@ -51,8 +51,8 @@ type Rows interface { // true. Values() ([]any, error) - // RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid until the next Next - // call or the Rows is closed. However, the underlying byte data is safe to retain a reference to and mutate. + // RawValues returns the unparsed bytes of the row values. The returned data is only valid until the next Next + // call or the Rows is closed. RawValues() [][]byte } From 01190e5d78a45773bb50ed9b7f768b349f6a36c5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Apr 2022 08:29:51 -0500 Subject: [PATCH 1029/1158] Update ScanPlan.Scan documentation --- pgtype/pgtype.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index bf58bb23..636b0954 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -516,7 +516,8 @@ type EncodePlan interface { // ScanPlan is a precompiled plan to scan into a type of destination. type ScanPlan interface { - // Scan scans src into target. + // Scan scans src into target. src is only valid during the call to Scan. The ScanPlan must not retain a reference to + // src. Scan(src []byte, target any) error } From c1495aace02f325b9d318a1ed8d7c8f92285cd1f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Apr 2022 12:49:12 -0500 Subject: [PATCH 1030/1158] Add RowScanner interface --- CHANGELOG.md | 4 ++++ rows.go | 13 +++++++++++++ rows_test.go | 30 ++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+) create mode 100644 rows_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 551c4f35..a7fc57ef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -120,6 +120,10 @@ See documentation for `QueryExecMode`. pgx now supports named arguments with the NamedArgs type. This is implemented via the new QueryRewriter interface which allows arbitrary rewriting of query SQL and arguments. + +## RowScanner Interface + +The `RowScanner` interface allows a single argument to Rows.Scan to scan the entire row. ## 3rd Party Logger Integration All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency diff --git a/rows.go b/rows.go index b757b1c7..4f9c533d 100644 --- a/rows.go +++ b/rows.go @@ -69,6 +69,12 @@ type Row interface { Scan(dest ...any) error } +// RowScanner scans an entire row at a time into the RowScanner. +type RowScanner interface { + // ScanRows scans the row. + ScanRow(rows Rows) error +} + // connRow implements the Row interface for Conn.QueryRow. type connRow connRows @@ -212,6 +218,13 @@ func (rows *connRows) Scan(dest ...any) error { rows.fatal(err) return err } + + if len(dest) == 1 { + if rc, ok := dest[0].(RowScanner); ok { + return rc.ScanRow(rows) + } + } + if len(fieldDescriptions) != len(dest) { err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) rows.fatal(err) diff --git a/rows_test.go b/rows_test.go new file mode 100644 index 00000000..37f8e1de --- /dev/null +++ b/rows_test.go @@ -0,0 +1,30 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" +) + +type testRowScanner struct { + name string + age int32 +} + +func (rs *testRowScanner) ScanRow(rows pgx.Rows) error { + return rows.Scan(&rs.name, &rs.age) +} + +func TestRowScanner(t *testing.T) { + t.Parallel() + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var s testRowScanner + err := conn.QueryRow(ctx, "select 'Adam' as name, 72 as height").Scan(&s) + require.NoError(t, err) + require.Equal(t, "Adam", s.name) + require.Equal(t, int32(72), s.age) + }) +} From 0135721378bac6a9a65c9d01925e903f71c145bf Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Fri, 6 May 2022 14:15:03 -0600 Subject: [PATCH 1031/1158] Add support for Unix sockets on Windows Fixes #1199. --- config.go | 21 ++++++++++++++++- config_test.go | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++ pgconn.go | 2 +- 3 files changed, 82 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 859672ea..e141a2f8 100644 --- a/config.go +++ b/config.go @@ -100,10 +100,29 @@ type FallbackConfig struct { TLSConfig *tls.Config // nil disables TLS } +// isAbsolutePath checks if the provided value is an absolute path either +// beginning with a forward slash (as on Linux-based systems) or with a capital +// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows). +func isAbsolutePath(path string) bool { + isWindowsPath := func(p string) bool { + if len(p) < 3 { + return false + } + drive := p[0] + colon := p[1] + backslash := p[2] + if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' { + return true + } + return false + } + return strings.HasPrefix(path, "/") || isWindowsPath(path) +} + // NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with // net.Dial. func NetworkAddress(host string, port uint16) (network, address string) { - if strings.HasPrefix(host, "/") { + if isAbsolutePath(host) { network = "unix" address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) } else { diff --git a/config_test.go b/config_test.go index da28782d..a28db3d6 100644 --- a/config_test.go +++ b/config_test.go @@ -231,6 +231,18 @@ func TestParseConfig(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + name: "database url unix domain socket host on windows", + connString: "postgres:///foo?host=C:\\tmp", + config: &pgconn.Config{ + User: osUserName, + Host: "C:\\tmp", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "database url dbname", connString: "postgres://localhost/?dbname=foo&sslmode=disable", @@ -703,6 +715,55 @@ func TestConfigCopyCanBeUsedToConnect(t *testing.T) { assert.NoError(t, err) } +func TestNetworkAddress(t *testing.T) { + tests := []struct { + name string + host string + wantNet string + }{ + { + name: "Default Unix socket address", + host: "/var/run/postgresql", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (standard drive name)", + host: "C:\\tmp", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (first drive name)", + host: "A:\\tmp", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (last drive name)", + host: "Z:\\tmp", + wantNet: "unix", + }, + { + name: "Assume TCP for unknown formats", + host: "a/tmp", + wantNet: "tcp", + }, + { + name: "loopback interface", + host: "localhost", + wantNet: "tcp", + }, + { + name: "IP address", + host: "127.0.0.1", + wantNet: "tcp", + }, + } + for i, tt := range tests { + gotNet, _ := pgconn.NetworkAddress(tt.host, 5432) + + assert.Equalf(t, tt.wantNet, gotNet, "Test %d (%s)", i, tt.name) + } +} + func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { if !assert.NotNil(t, expected) { return diff --git a/pgconn.go b/pgconn.go index f1304d08..ef5b76fd 100644 --- a/pgconn.go +++ b/pgconn.go @@ -187,7 +187,7 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba for _, fb := range fallbacks { // skip resolve for unix sockets - if strings.HasPrefix(fb.Host, "/") { + if isAbsolutePath(fb.Host) { configs = append(configs, &FallbackConfig{ Host: fb.Host, Port: fb.Port, From 1d398317ca41b5aa3f8f83ecc99a932d25d8c870 Mon Sep 17 00:00:00 2001 From: Rafi Shamim Date: Fri, 6 May 2022 22:41:08 -0400 Subject: [PATCH 1032/1158] Stop ignoring ErrorResponse during SCRAM auth The server may send back an ErrorResponse during SCRAM auth, and these messages may contain useful information that described why authentication failed. For example, if the password was invalid. --- auth_scram.go | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index 6a143fcd..d8d71116 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -78,12 +78,14 @@ func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) if err != nil { return nil, err } - saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue) - if ok { - return saslContinue, nil + switch m := msg.(type) { + case *pgproto3.AuthenticationSASLContinue: + return m, nil + case *pgproto3.ErrorResponse: + return nil, ErrorResponseToPgError(m) } - return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message") + return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg) } func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { @@ -91,12 +93,14 @@ func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { if err != nil { return nil, err } - saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal) - if ok { - return saslFinal, nil + switch m := msg.(type) { + case *pgproto3.AuthenticationSASLFinal: + return m, nil + case *pgproto3.ErrorResponse: + return nil, ErrorResponseToPgError(m) } - return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message") + return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg) } type scramClient struct { From 831fc211bc3b06eb7b461992c231fea7496dbf29 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 May 2022 07:11:19 -0500 Subject: [PATCH 1033/1158] Release v1.12.1 --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6df3ddcf..a3efb7f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +# 1.12.1 (May 7, 2022) + +* Fix: setting krbspn and krbsrvname in connection string (sireax) +* Add support for Unix sockets on Windows (Eno Compton) +* Stop ignoring ErrorResponse during SCRAM auth (Rafi Shamim) + # 1.12.0 (April 21, 2022) * Add pluggable GSSAPI support (Oliver Tan) From 989a4835de30e552a5aec5f3d700aed30492c080 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 12 May 2022 17:13:49 -0500 Subject: [PATCH 1034/1158] Remove rune to text conversion Because rune is an alias for int32 this caused some very surprising results. e.g. inserting int32(65) into text would insert "A" instead of "65". --- pgtype/enum_codec.go | 4 ---- pgtype/text.go | 31 ------------------------------- pgtype/text_test.go | 5 ----- 3 files changed, 40 deletions(-) diff --git a/pgtype/enum_codec.go b/pgtype/enum_codec.go index 93513111..3d23b12f 100644 --- a/pgtype/enum_codec.go +++ b/pgtype/enum_codec.go @@ -28,8 +28,6 @@ func (EnumCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodeP return encodePlanTextCodecString{} case []byte: return encodePlanTextCodecByteSlice{} - case rune: - return encodePlanTextCodecRune{} case TextValuer: return encodePlanTextCodecTextValuer{} } @@ -48,8 +46,6 @@ func (c *EnumCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP return scanPlanAnyToNewByteSlice{} case TextScanner: return &scanPlanTextAnyToEnumTextScanner{codec: c} - case *rune: - return scanPlanTextAnyToRune{} } } diff --git a/pgtype/text.go b/pgtype/text.go index 0f9df7a6..7f779d11 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/json" "fmt" - "unicode/utf8" ) type TextScanner interface { @@ -98,8 +97,6 @@ func (TextCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodeP return encodePlanTextCodecString{} case []byte: return encodePlanTextCodecByteSlice{} - case rune: - return encodePlanTextCodecRune{} case TextValuer: return encodePlanTextCodecTextValuer{} } @@ -124,14 +121,6 @@ func (encodePlanTextCodecByteSlice) Encode(value any, buf []byte) (newBuf []byte return buf, nil } -type encodePlanTextCodecRune struct{} - -func (encodePlanTextCodecRune) Encode(value any, buf []byte) (newBuf []byte, err error) { - r := value.(rune) - buf = append(buf, string(r)...) - return buf, nil -} - type encodePlanTextCodecStringer struct{} func (encodePlanTextCodecStringer) Encode(value any, buf []byte) (newBuf []byte, err error) { @@ -169,8 +158,6 @@ func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan return scanPlanAnyToByteScanner{} case TextScanner: return scanPlanTextAnyToTextScanner{} - case *rune: - return scanPlanTextAnyToRune{} } } @@ -223,24 +210,6 @@ func (scanPlanAnyToByteScanner) Scan(src []byte, dst any) error { return p.ScanBytes(src) } -type scanPlanTextAnyToRune struct{} - -func (scanPlanTextAnyToRune) Scan(src []byte, dst any) error { - if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) - } - - r, size := utf8.DecodeRune(src) - if size != len(src) { - return fmt.Errorf("cannot scan %v into %T: more than one rune received", src, dst) - } - - p := (dst).(*rune) - *p = r - - return nil -} - type scanPlanTextAnyToTextScanner struct{} func (scanPlanTextAnyToTextScanner) Scan(src []byte, dst any) error { diff --git a/pgtype/text_test.go b/pgtype/text_test.go index 1d717f49..eb5d005e 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -33,11 +33,6 @@ func TestTextCodec(t *testing.T) { {"foo", new(string), isExpectedEq("foo")}, {someFmtStringer{}, new(string), isExpectedEq("some fmt.Stringer")}, }) - - // rune requires known OID because otherwise it is considered an int32. - pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, pgTypeName, []pgxtest.ValueRoundTripTest{ - {rune('R'), new(rune), isExpectedEq(rune('R'))}, - }) } } From 5714896b1047d2415448d32caa87547625509783 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 May 2022 11:06:44 -0500 Subject: [PATCH 1035/1158] Restructure sending messages Use an internal buffer in pgproto3.Frontend and pgproto3.Backend instead of directly writing to the underlying net.Conn. This will allow tracing messages as well as simplify pipeline mode. --- internal/pgmock/pgmock.go | 3 +- pgconn/config.go | 2 +- pgconn/errors.go | 17 ------ pgconn/frontend_test.go | 70 ---------------------- pgconn/pgconn.go | 121 +++++++++++--------------------------- pgconn/pgconn_test.go | 43 -------------- pgproto3/backend.go | 28 +++++++-- pgproto3/frontend.go | 28 +++++++-- pgproto3/pgproto3.go | 19 ++++++ 9 files changed, 105 insertions(+), 226 deletions(-) delete mode 100644 pgconn/frontend_test.go diff --git a/internal/pgmock/pgmock.go b/internal/pgmock/pgmock.go index 97dd024d..c82d7ffc 100644 --- a/internal/pgmock/pgmock.go +++ b/internal/pgmock/pgmock.go @@ -97,7 +97,8 @@ type sendMessageStep struct { } func (e *sendMessageStep) Step(backend *pgproto3.Backend) error { - return backend.Send(e.msg) + backend.Send(e.msg) + return backend.Flush() } func SendMessage(msg pgproto3.BackendMessage) Step { diff --git a/pgconn/config.go b/pgconn/config.go index 8a22d4ce..bfec11d4 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -222,7 +222,7 @@ func ParseConfig(connString string) (*Config, error) { User: settings["user"], Password: settings["password"], RuntimeParams: make(map[string]string), - BuildFrontend: func(r io.Reader, w io.Writer) Frontend { + BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { return pgproto3.NewFrontend(r, w) }, } diff --git a/pgconn/errors.go b/pgconn/errors.go index a32b29c9..030f7e0a 100644 --- a/pgconn/errors.go +++ b/pgconn/errors.go @@ -178,23 +178,6 @@ func newContextAlreadyDoneError(ctx context.Context) (err error) { return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}} } -type writeError struct { - err error - safeToRetry bool -} - -func (e *writeError) Error() string { - return fmt.Sprintf("write failed: %s", e.err.Error()) -} - -func (e *writeError) SafeToRetry() bool { - return e.safeToRetry -} - -func (e *writeError) Unwrap() error { - return e.err -} - func redactPW(connString string) string { if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { if u, err := url.Parse(connString); err == nil { diff --git a/pgconn/frontend_test.go b/pgconn/frontend_test.go deleted file mode 100644 index 439d3251..00000000 --- a/pgconn/frontend_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package pgconn_test - -import ( - "context" - "io" - "os" - "testing" - - "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgproto3" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// frontendWrapper allows to hijack a regular frontend, and inject a specific response -type frontendWrapper struct { - front pgconn.Frontend - - msg pgproto3.BackendMessage -} - -// frontendWrapper implements the pgconn.Frontend interface -var _ pgconn.Frontend = (*frontendWrapper)(nil) - -func (f *frontendWrapper) Receive() (pgproto3.BackendMessage, error) { - if f.msg != nil { - return f.msg, nil - } - - return f.front.Receive() -} - -func TestFrontendFatalErrExec(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - - buildFrontend := config.BuildFrontend - var front *frontendWrapper - - config.BuildFrontend = func(r io.Reader, w io.Writer) pgconn.Frontend { - wrapped := buildFrontend(r, w) - front = &frontendWrapper{wrapped, nil} - - return front - } - - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - require.NotNil(t, conn) - require.NotNil(t, front) - - // set frontend to return a "FATAL" message on next call - front.msg = &pgproto3.ErrorResponse{Severity: "FATAL", Message: "unit testing fatal error"} - - _, err = conn.Exec(context.Background(), "SELECT 1").ReadAll() - assert.Error(t, err) - - err = conn.Close(context.Background()) - assert.NoError(t, err) - - select { - case <-conn.CleanupDone(): - t.Log("ok, CleanupDone() is not blocking") - - default: - assert.Fail(t, "connection closed but CleanupDone() still blocking") - } -} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index a23b1daf..2cbf8c50 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -29,8 +29,6 @@ const ( connStatusBusy ) -const wbufLen = 1024 - // Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from // LISTEN/NOTIFY notification. type Notice PgError @@ -50,7 +48,7 @@ type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. -type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend +type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin @@ -64,11 +62,6 @@ type NoticeHandler func(*PgConn, *Notice) // notice event. type NotificationHandler func(*PgConn, *Notification) -// Frontend used to receive messages from backend. -type Frontend interface { - Receive() (pgproto3.BackendMessage, error) -} - // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection @@ -76,7 +69,7 @@ type PgConn struct { secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server txStatus byte - frontend Frontend + frontend *pgproto3.Frontend config *Config @@ -90,7 +83,6 @@ type PgConn struct { peekedMsg pgproto3.BackendMessage // Reusable / preallocated resources - wbuf []byte // write buffer resultReader ResultReader multiResultReader MultiResultReader contextWatcher *ctxwatch.ContextWatcher @@ -230,7 +222,6 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config - pgConn.wbuf = make([]byte, 0, wbufLen) pgConn.cleanupDone = make(chan struct{}) var err error @@ -282,7 +273,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig startupMsg.Parameters["database"] = config.Database } - if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { + pgConn.frontend.Send(&startupMsg) + if err := pgConn.frontend.Flush(); err != nil { pgConn.conn.Close() return nil, &connectError{config: config, msg: "failed to write startup message", err: err} } @@ -383,9 +375,8 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { } func (pgConn *PgConn) txPasswordMessage(password string) (err error) { - msg := &pgproto3.PasswordMessage{Password: password} - _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) - return err + pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password}) + return pgConn.frontend.Flush() } func hexMD5(s string) string { @@ -412,36 +403,6 @@ func (pgConn *PgConn) signalMessage() chan struct{} { return ch } -// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as -// error to call SendBytes while reading the result of a query. -// -// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. -// See https://www.postgresql.org/docs/current/protocol.html. -func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { - if err := pgConn.lock(); err != nil { - return err - } - defer pgConn.unlock() - - if ctx != context.Background() { - select { - case <-ctx.Done(): - return newContextAlreadyDoneError(ctx) - default: - } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() - } - - n, err := pgConn.conn.Write(buf) - if err != nil { - pgConn.asyncClose() - return &writeError{err: err, safeToRetry: n == 0} - } - - return nil -} - // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger @@ -797,15 +758,13 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ defer pgConn.contextWatcher.Unwatch() } - buf := pgConn.wbuf - buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) - buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) - buf = (&pgproto3.Sync{}).Encode(buf) - - n, err := pgConn.conn.Write(buf) + pgConn.frontend.Send(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) + pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'S', Name: name}) + pgConn.frontend.Send(&pgproto3.Sync{}) + err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() - return nil, &writeError{err: err, safeToRetry: n == 0} + return nil, err } psd := &StatementDescription{Name: name, SQL: sql} @@ -971,15 +930,13 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.contextWatcher.Watch(ctx) } - buf := pgConn.wbuf - buf = (&pgproto3.Query{String: sql}).Encode(buf) - - n, err := pgConn.conn.Write(buf) + pgConn.frontend.Send(&pgproto3.Query{String: sql}) + err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() pgConn.contextWatcher.Unwatch() multiResult.closed = true - multiResult.err = &writeError{err: err, safeToRetry: n == 0} + multiResult.err = err pgConn.unlock() return multiResult } @@ -1045,11 +1002,10 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return result } - buf := pgConn.wbuf - buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) - buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + pgConn.frontend.Send(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) + pgConn.frontend.Send(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) - pgConn.execExtendedSuffix(buf, result) + pgConn.execExtendedSuffix(result) return result } @@ -1072,10 +1028,9 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } - buf := pgConn.wbuf - buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + pgConn.frontend.Send(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) - pgConn.execExtendedSuffix(buf, result) + pgConn.execExtendedSuffix(result) return result } @@ -1115,15 +1070,15 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } -func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { - buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) - buf = (&pgproto3.Execute{}).Encode(buf) - buf = (&pgproto3.Sync{}).Encode(buf) +func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { + pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'P'}) + pgConn.frontend.Send(&pgproto3.Execute{}) + pgConn.frontend.Send(&pgproto3.Sync{}) - n, err := pgConn.conn.Write(buf) + err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() - result.concludeCommand(CommandTag{}, &writeError{err: err, safeToRetry: n == 0}) + result.concludeCommand(CommandTag{}, err) pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() @@ -1151,14 +1106,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm } // Send copy to command - buf := pgConn.wbuf - buf = (&pgproto3.Query{String: sql}).Encode(buf) + pgConn.frontend.Send(&pgproto3.Query{String: sql}) - n, err := pgConn.conn.Write(buf) + err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() pgConn.unlock() - return CommandTag{}, &writeError{err: err, safeToRetry: n == 0} + return CommandTag{}, err } // Read results @@ -1211,13 +1165,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Send copy to command - buf := pgConn.wbuf - buf = (&pgproto3.Query{String: sql}).Encode(buf) + pgConn.frontend.Send(&pgproto3.Query{String: sql}) - n, err := pgConn.conn.Write(buf) + err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() - return CommandTag{}, &writeError{err: err, safeToRetry: n == 0} + return CommandTag{}, err } // Send copy data @@ -1280,15 +1233,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } close(abortCopyChan) - buf = buf[:0] if copyErr == io.EOF || pgErr != nil { - copyDone := &pgproto3.CopyDone{} - buf = copyDone.Encode(buf) + pgConn.frontend.Send(&pgproto3.CopyDone{}) } else { - copyFail := &pgproto3.CopyFail{Message: copyErr.Error()} - buf = copyFail.Encode(buf) + pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()}) } - _, err = pgConn.conn.Write(buf) + err = pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() return CommandTag{}, err @@ -1692,7 +1642,7 @@ type HijackedConn struct { SecretKey uint32 // key to use to send a cancel query message to the server ParameterStatuses map[string]string // parameters that have been reported by the server TxStatus byte - Frontend Frontend + Frontend *pgproto3.Frontend Config *Config } @@ -1736,7 +1686,6 @@ func Construct(hc *HijackedConn) (*PgConn, error) { status: connStatusIdle, - wbuf: make([]byte, 0, wbufLen), cleanupDone: make(chan struct{}), } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 3ae0d1d4..fdce6e7d 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1915,49 +1915,6 @@ func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { } } -func TestConnSendBytesAndReceiveMessage(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the messages we expect. - - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) - - queryMsg := pgproto3.Query{String: "select 42"} - buf := queryMsg.Encode(nil) - - err = pgConn.SendBytes(ctx, buf) - require.NoError(t, err) - - msg, err := pgConn.ReceiveMessage(ctx) - require.NoError(t, err) - _, ok := msg.(*pgproto3.RowDescription) - require.True(t, ok) - - msg, err = pgConn.ReceiveMessage(ctx) - require.NoError(t, err) - _, ok = msg.(*pgproto3.DataRow) - require.True(t, ok) - - msg, err = pgConn.ReceiveMessage(ctx) - require.NoError(t, err) - _, ok = msg.(*pgproto3.CommandComplete) - require.True(t, ok) - - msg, err = pgConn.ReceiveMessage(ctx) - require.NoError(t, err) - _, ok = msg.(*pgproto3.ReadyForQuery) - require.True(t, ok) - - ensureConnValid(t, pgConn) -} - func TestHijackAndConstruct(t *testing.T) { t.Parallel() diff --git a/pgproto3/backend.go b/pgproto3/backend.go index b7db6f76..d619f7e7 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -11,6 +11,8 @@ type Backend struct { cr *chunkReader w io.Writer + wbuf []byte + // Frontend message flyweights bind Bind cancelRequest CancelRequest @@ -47,10 +49,28 @@ func NewBackend(r io.Reader, w io.Writer) *Backend { return &Backend{cr: cr, w: w} } -// Send sends a message to the frontend. -func (b *Backend) Send(msg BackendMessage) error { - _, err := b.w.Write(msg.Encode(nil)) - return err +// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is +// called. +func (b *Backend) Send(msg BackendMessage) { + b.wbuf = msg.Encode(b.wbuf) +} + +// Flush writes any pending messages to the frontend (i.e. the client). +func (b *Backend) Flush() error { + n, err := b.w.Write(b.wbuf) + + const maxLen = 1024 + if len(b.wbuf) > maxLen { + b.wbuf = make([]byte, 0, maxLen) + } else { + b.wbuf = b.wbuf[:0] + } + + if err != nil { + return &writeError{err: err, safeToRetry: n == 0} + } + + return nil } // ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 435275d6..beaaef5f 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -12,6 +12,8 @@ type Frontend struct { cr *chunkReader w io.Writer + wbuf []byte + // Backend message flyweights authenticationOk AuthenticationOk authenticationCleartextPassword AuthenticationCleartextPassword @@ -56,10 +58,28 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend { return &Frontend{cr: cr, w: w} } -// Send sends a message to the backend. -func (f *Frontend) Send(msg FrontendMessage) error { - _, err := f.w.Write(msg.Encode(nil)) - return err +// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is +// called. +func (f *Frontend) Send(msg FrontendMessage) { + f.wbuf = msg.Encode(f.wbuf) +} + +// Flush writes any pending messages to the backend (i.e. the server). +func (f *Frontend) Flush() error { + n, err := f.w.Write(f.wbuf) + + const maxLen = 1024 + if len(f.wbuf) > maxLen { + f.wbuf = make([]byte, 0, maxLen) + } else { + f.wbuf = f.wbuf[:0] + } + + if err != nil { + return &writeError{err: err, safeToRetry: n == 0} + } + + return nil } func translateEOFtoErrUnexpectedEOF(err error) error { diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go index 70c825e3..a0333aa5 100644 --- a/pgproto3/pgproto3.go +++ b/pgproto3/pgproto3.go @@ -17,11 +17,13 @@ type Message interface { Encode(dst []byte) []byte } +// FrontendMessage is a message sent by the frontend (i.e. the client). type FrontendMessage interface { Message Frontend() // no-op method to distinguish frontend from backend methods } +// BackendMessage is a message sent by the backend (i.e. the server). type BackendMessage interface { Message Backend() // no-op method to distinguish frontend from backend methods @@ -50,6 +52,23 @@ func (e *invalidMessageFormatErr) Error() string { return fmt.Sprintf("%s body is invalid", e.messageType) } +type writeError struct { + err error + safeToRetry bool +} + +func (e *writeError) Error() string { + return fmt.Sprintf("write failed: %s", e.err.Error()) +} + +func (e *writeError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *writeError) Unwrap() error { + return e.err +} + // getValueFromJSON gets the value from a protocol message representation in JSON. func getValueFromJSON(v map[string]string) ([]byte, error) { if v == nil { From f2e96156a03e6151c3c598f4e79d13bf27cf9f4c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 May 2022 14:43:04 -0500 Subject: [PATCH 1036/1158] Add message tracing --- pgconn/auth_scram.go | 6 +- pgconn/krb5.go | 3 +- pgconn/pgconn.go | 13 ++- pgproto3/backend.go | 18 +++- pgproto3/frontend.go | 19 +++- pgproto3/trace.go | 191 +++++++++++++++++++++++++++++++++++++++++ pgproto3/trace_test.go | 79 +++++++++++++++++ 7 files changed, 321 insertions(+), 8 deletions(-) create mode 100644 pgproto3/trace.go create mode 100644 pgproto3/trace_test.go diff --git a/pgconn/auth_scram.go b/pgconn/auth_scram.go index de13c687..37b602d6 100644 --- a/pgconn/auth_scram.go +++ b/pgconn/auth_scram.go @@ -41,7 +41,8 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { AuthMechanism: "SCRAM-SHA-256", Data: sc.clientFirstMessage(), } - _, err = c.conn.Write(saslInitialResponse.Encode(nil)) + c.frontend.Send(saslInitialResponse) + err = c.frontend.Flush() if err != nil { return err } @@ -60,7 +61,8 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { saslResponse := &pgproto3.SASLResponse{ Data: []byte(sc.clientFinalMessage()), } - _, err = c.conn.Write(saslResponse.Encode(nil)) + c.frontend.Send(saslResponse) + err = c.frontend.Flush() if err != nil { return err } diff --git a/pgconn/krb5.go b/pgconn/krb5.go index 8dffc879..a4bca01f 100644 --- a/pgconn/krb5.go +++ b/pgconn/krb5.go @@ -61,7 +61,8 @@ func (c *PgConn) gssAuth() error { gssResponse := &pgproto3.GSSResponse{ Data: nextData, } - _, err = c.conn.Write(gssResponse.Encode(nil)) + c.frontend.Send(gssResponse) + err = c.frontend.Flush() if err != nil { return err } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 2cbf8c50..935d3530 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -515,7 +515,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } -// Conn returns the underlying net.Conn. +// Conn returns the underlying net.Conn. This rarely necessary. func (pgConn *PgConn) Conn() net.Conn { return pgConn.conn } @@ -542,6 +542,11 @@ func (pgConn *PgConn) SecretKey() uint32 { return pgConn.secretKey } +// Frontend returns the underlying *pgproto3.Frontend. This rarely necessary. +func (pgConn *PgConn) Frontend() *pgproto3.Frontend { + return pgConn.frontend +} + // Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. @@ -571,7 +576,8 @@ func (pgConn *PgConn) Close(ctx context.Context) error { // ignores errors. // // See https://github.com/jackc/pgx/issues/637 - pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) + pgConn.frontend.Send(&pgproto3.Terminate{}) + pgConn.frontend.Flush() return pgConn.conn.Close() } @@ -597,7 +603,8 @@ func (pgConn *PgConn) asyncClose() { pgConn.conn.SetDeadline(deadline) - pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) + pgConn.frontend.Send(&pgproto3.Terminate{}) + pgConn.frontend.Flush() }() } diff --git a/pgproto3/backend.go b/pgproto3/backend.go index d619f7e7..ba0be3d3 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -11,6 +11,10 @@ type Backend struct { cr *chunkReader w io.Writer + // MessageTracer is used to trace messages when Send or Receive is called. This means an outbound message is traced + // before it is actually transmitted (i.e. before Flush). + MessageTracer MessageTracer + wbuf []byte // Frontend message flyweights @@ -52,7 +56,11 @@ func NewBackend(r io.Reader, w io.Writer) *Backend { // Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is // called. func (b *Backend) Send(msg BackendMessage) { + prevLen := len(b.wbuf) b.wbuf = msg.Encode(b.wbuf) + if b.MessageTracer != nil { + b.MessageTracer.TraceMessage('B', int32(len(b.wbuf)-prevLen), msg) + } } // Flush writes any pending messages to the frontend (i.e. the client). @@ -193,7 +201,15 @@ func (b *Backend) Receive() (FrontendMessage, error) { b.partialMsg = false err = msg.Decode(msgBody) - return msg, err + if err != nil { + return nil, err + } + + if b.MessageTracer != nil { + b.MessageTracer.TraceMessage('F', int32(5+len(msgBody)), msg) + } + + return msg, nil } // SetAuthType sets the authentication type in the backend. diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index beaaef5f..342a0ddd 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -12,6 +12,11 @@ type Frontend struct { cr *chunkReader w io.Writer + // MessageTracer is used to trace messages when Send or Receive is called. This means an outbound message is traced + // before it is actually transmitted (i.e. before Flush). It is safe to change this variable when the Frontend is + // idle. Setting and unsetting MessageTracer provides equivalent functionality to PQtrace and PQuntrace in libpq. + MessageTracer MessageTracer + wbuf []byte // Backend message flyweights @@ -61,7 +66,11 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend { // Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is // called. func (f *Frontend) Send(msg FrontendMessage) { + prevLen := len(f.wbuf) f.wbuf = msg.Encode(f.wbuf) + if f.MessageTracer != nil { + f.MessageTracer.TraceMessage('F', int32(len(f.wbuf)-prevLen), msg) + } } // Flush writes any pending messages to the backend (i.e. the server). @@ -166,7 +175,15 @@ func (f *Frontend) Receive() (BackendMessage, error) { } err = msg.Decode(msgBody) - return msg, err + if err != nil { + return nil, err + } + + if f.MessageTracer != nil { + f.MessageTracer.TraceMessage('B', int32(5+len(msgBody)), msg) + } + + return msg, nil } // Authentication message type constants. diff --git a/pgproto3/trace.go b/pgproto3/trace.go new file mode 100644 index 00000000..b35ecdb6 --- /dev/null +++ b/pgproto3/trace.go @@ -0,0 +1,191 @@ +package pgproto3 + +import ( + "bytes" + "fmt" + "io" + "strings" + "time" +) + +// MessageTracer is an interface that traces the messages send to and from a Backend or Frontend. +type MessageTracer interface { + // TraceMessage tracks the sending or receiving of a message. sender is either 'F' for frontend or 'B' for backend. + TraceMessage(sender byte, encodedLen int32, msg Message) +} + +// LibpqMessageTracer is a MessageTracer that roughly mimics the format produced by the libpq C function PQtrace. +type LibpqMessageTracer struct { + Writer io.Writer + + // SuppressTimestamps prevents printing of timestamps. + SuppressTimestamps bool + + // RegressMode redacts fields that may be vary between executions. + RegressMode bool +} + +func (t *LibpqMessageTracer) TraceMessage(sender byte, encodedLen int32, msg Message) { + buf := &bytes.Buffer{} + + if !t.SuppressTimestamps { + now := time.Now() + buf.WriteString(now.Format("2006-01-02 15:04:05.000000")) + buf.WriteByte('\t') + } + + buf.WriteByte(sender) + buf.WriteByte('\t') + + switch msg := msg.(type) { + case *AuthenticationCleartextPassword: + buf.WriteString("AuthenticationCleartextPassword") + case *AuthenticationGSS: + buf.WriteString("AuthenticationGSS") + case *AuthenticationGSSContinue: + buf.WriteString("AuthenticationGSSContinue") + case *AuthenticationMD5Password: + buf.WriteString("AuthenticationMD5Password") + case *AuthenticationOk: + buf.WriteString("AuthenticationOk") + case *AuthenticationSASL: + buf.WriteString("AuthenticationSASL") + case *AuthenticationSASLContinue: + buf.WriteString("AuthenticationSASLContinue") + case *AuthenticationSASLFinal: + buf.WriteString("AuthenticationSASLFinal") + case *BackendKeyData: + if t.RegressMode { + buf.WriteString("BackendKeyData\t NNNN NNNN") + } else { + fmt.Fprintf(buf, "BackendKeyData\t %d %d", msg.ProcessID, msg.SecretKey) + } + case *Bind: + fmt.Fprintf(buf, "Bind\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) + for _, fc := range msg.ParameterFormatCodes { + fmt.Fprintf(buf, " %d", fc) + } + fmt.Fprintf(buf, " %d", len(msg.Parameters)) + for _, p := range msg.Parameters { + fmt.Fprintf(buf, " %s", traceSingleQuotedString(p)) + } + fmt.Fprintf(buf, " %d", len(msg.ResultFormatCodes)) + for _, fc := range msg.ResultFormatCodes { + fmt.Fprintf(buf, " %d", fc) + } + case *BindComplete: + buf.WriteString("BindComplete") + case *CancelRequest: + buf.WriteString("CancelRequest") + case *Close: + buf.WriteString("Close") + case *CloseComplete: + buf.WriteString("CloseComplete") + case *CommandComplete: + fmt.Fprintf(buf, "CommandComplete\t %s", traceDoubleQuotedString(msg.CommandTag)) + case *CopyBothResponse: + buf.WriteString("CopyBothResponse") + case *CopyData: + buf.WriteString("CopyData") + case *CopyDone: + buf.WriteString("CopyDone") + case *CopyFail: + fmt.Fprintf(buf, "CopyFail\t %s", traceDoubleQuotedString([]byte(msg.Message))) + case *CopyInResponse: + buf.WriteString("CopyInResponse") + case *CopyOutResponse: + buf.WriteString("CopyOutResponse") + case *DataRow: + fmt.Fprintf(buf, "DataRow\t %d", len(msg.Values)) + for _, v := range msg.Values { + if v == nil { + buf.WriteString(" -1") + } else { + fmt.Fprintf(buf, " %d %s", len(v), traceSingleQuotedString(v)) + } + } + case *Describe: + fmt.Fprintf(buf, "Describe\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) + case *EmptyQueryResponse: + buf.WriteString("EmptyQueryResponse") + case *ErrorResponse: + buf.WriteString("ErrorResponse") + case *Execute: + fmt.Fprintf(buf, "Execute\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) + case *Flush: + buf.WriteString("Flush") + case *FunctionCall: + buf.WriteString("FunctionCall") + case *FunctionCallResponse: + buf.WriteString("FunctionCallResponse") + case *GSSEncRequest: + buf.WriteString("GSSEncRequest") + case *NoData: + buf.WriteString("NoData") + case *NoticeResponse: + buf.WriteString("NoticeResponse") + case *NotificationResponse: + fmt.Fprintf(buf, "NotificationResponse\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) + case *ParameterDescription: + buf.WriteString("ParameterDescription") + case *ParameterStatus: + fmt.Fprintf(buf, "ParameterStatus\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) + case *Parse: + fmt.Fprintf(buf, "Parse\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) + for _, oid := range msg.ParameterOIDs { + fmt.Fprintf(buf, " %d", oid) + } + case *ParseComplete: + buf.WriteString("ParseComplete") + case *PortalSuspended: + buf.WriteString("PortalSuspended") + case *Query: + buf.WriteString("Query\t") + fmt.Fprintf(buf, ` "%s"`, msg.String) + case *ReadyForQuery: + fmt.Fprintf(buf, "ReadyForQuery\t %c", msg.TxStatus) + case *RowDescription: + buf.WriteString("RowDescription\t") + fmt.Fprintf(buf, " %d", len(msg.Fields)) + for _, fd := range msg.Fields { + fmt.Fprintf(buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) + } + case *SSLRequest: + buf.WriteString("SSLRequest") + case *StartupMessage: + buf.WriteString("StartupMessage") + case *Sync: + buf.WriteString("Sync") + case *Terminate: + buf.WriteString("Terminate") + default: + buf.WriteString("Unknown") + } + + buf.WriteByte('\n') + buf.WriteTo(t.Writer) +} + +// traceDoubleQuotedString returns buf as a double-quoted string without any escaping. It is roughly equivalent to +// pqTraceOutputString in libpq. +func traceDoubleQuotedString(buf []byte) string { + return `"` + string(buf) + `"` +} + +// traceSingleQuotedString returns buf as a single-quoted string with non-printable characters hex-escaped. It is +// roughly equivalent to pqTraceOutputNchar in libpq. +func traceSingleQuotedString(buf []byte) string { + sb := &strings.Builder{} + + sb.WriteByte('\'') + for _, b := range buf { + if b < 32 || b > 126 { + fmt.Fprintf(sb, `\x%x`, b) + } else { + sb.WriteByte(b) + } + } + sb.WriteByte('\'') + + return sb.String() +} diff --git a/pgproto3/trace_test.go b/pgproto3/trace_test.go new file mode 100644 index 00000000..f78bd346 --- /dev/null +++ b/pgproto3/trace_test.go @@ -0,0 +1,79 @@ +package pgproto3_test + +import ( + "bytes" + "context" + "io" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func TestLibpqMessageTracer(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + traceOutput := &bytes.Buffer{} + + config.BuildFrontend = func(r io.Reader, w io.Writer) *pgproto3.Frontend { + f := pgproto3.NewFrontend(r, w) + f.MessageTracer = &pgproto3.LibpqMessageTracer{ + Writer: traceOutput, + SuppressTimestamps: true, + RegressMode: true, + } + return f + } + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer conn.Close(ctx) + + result := conn.ExecParams(ctx, "select n from generate_series(1,5) n", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + expected := `F StartupMessage +B AuthenticationOk +B ParameterStatus "application_name" "" +B ParameterStatus "client_encoding" "UTF8" +B ParameterStatus "DateStyle" "ISO, MDY" +B ParameterStatus "default_transaction_read_only" "off" +B ParameterStatus "in_hot_standby" "off" +B ParameterStatus "integer_datetimes" "on" +B ParameterStatus "IntervalStyle" "postgres" +B ParameterStatus "is_superuser" "on" +B ParameterStatus "server_encoding" "UTF8" +B ParameterStatus "server_version" "14.3" +B ParameterStatus "session_authorization" "jack" +B ParameterStatus "standard_conforming_strings" "on" +B ParameterStatus "TimeZone" "America/Chicago" +B BackendKeyData NNNN NNNN +B ReadyForQuery I +F Parse "" "select n from generate_series(1,5) n" 0 +F Bind "" "" 0 0 0 +F Describe P "" +F Execute "" 0 +F Sync +B ParseComplete +B BindComplete +B RowDescription 1 "n" 0 0 23 4 -1 0 +B DataRow 1 1 '1' +B DataRow 1 1 '2' +B DataRow 1 1 '3' +B DataRow 1 1 '4' +B DataRow 1 1 '5' +B CommandComplete "SELECT 5" +B ReadyForQuery I +` + + require.Equal(t, expected, traceOutput.String()) +} From b74c109f61fd17a89d2850ee379faf7c74754c8b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 May 2022 17:18:05 -0500 Subject: [PATCH 1037/1158] Optimize tracing The addition of tracing caused messages to escape to the heap. By avoiding interfaces the messages no longer escape. --- pgconn/pgconn.go | 26 +-- pgproto3/backend.go | 28 ++- pgproto3/frontend.go | 120 ++++++++++- pgproto3/trace.go | 446 ++++++++++++++++++++++++++++++++--------- pgproto3/trace_test.go | 7 +- 5 files changed, 497 insertions(+), 130 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 935d3530..af8aeb57 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -765,9 +765,9 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ defer pgConn.contextWatcher.Unwatch() } - pgConn.frontend.Send(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) - pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'S', Name: name}) - pgConn.frontend.Send(&pgproto3.Sync{}) + pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) + pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) + pgConn.frontend.SendSync(&pgproto3.Sync{}) err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() @@ -937,7 +937,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.contextWatcher.Watch(ctx) } - pgConn.frontend.Send(&pgproto3.Query{String: sql}) + pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() @@ -1009,8 +1009,8 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return result } - pgConn.frontend.Send(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) - pgConn.frontend.Send(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + pgConn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) + pgConn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) pgConn.execExtendedSuffix(result) @@ -1035,7 +1035,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } - pgConn.frontend.Send(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) pgConn.execExtendedSuffix(result) @@ -1078,9 +1078,9 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by } func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { - pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'P'}) - pgConn.frontend.Send(&pgproto3.Execute{}) - pgConn.frontend.Send(&pgproto3.Sync{}) + pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + pgConn.frontend.SendExecute(&pgproto3.Execute{}) + pgConn.frontend.SendSync(&pgproto3.Sync{}) err := pgConn.frontend.Flush() if err != nil { @@ -1113,7 +1113,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm } // Send copy to command - pgConn.frontend.Send(&pgproto3.Query{String: sql}) + pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) err := pgConn.frontend.Flush() if err != nil { @@ -1172,7 +1172,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Send copy to command - pgConn.frontend.Send(&pgproto3.Query{String: sql}) + pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) err := pgConn.frontend.Flush() if err != nil { @@ -1196,7 +1196,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) - _, writeErr := pgConn.conn.Write(buf) + writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf) if writeErr != nil { // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. pgConn.conn.Close() diff --git a/pgproto3/backend.go b/pgproto3/backend.go index ba0be3d3..09aeb7c8 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -1,6 +1,7 @@ package pgproto3 import ( + "bytes" "encoding/binary" "fmt" "io" @@ -11,9 +12,9 @@ type Backend struct { cr *chunkReader w io.Writer - // MessageTracer is used to trace messages when Send or Receive is called. This means an outbound message is traced + // tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced // before it is actually transmitted (i.e. before Flush). - MessageTracer MessageTracer + tracer *tracer wbuf []byte @@ -58,8 +59,8 @@ func NewBackend(r io.Reader, w io.Writer) *Backend { func (b *Backend) Send(msg BackendMessage) { prevLen := len(b.wbuf) b.wbuf = msg.Encode(b.wbuf) - if b.MessageTracer != nil { - b.MessageTracer.TraceMessage('B', int32(len(b.wbuf)-prevLen), msg) + if b.tracer != nil { + b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg) } } @@ -81,6 +82,21 @@ func (b *Backend) Flush() error { return nil } +// Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function +// PQtrace. +func (b *Backend) Trace(w io.Writer, options TracerOptions) { + b.tracer = &tracer{ + w: w, + buf: &bytes.Buffer{}, + TracerOptions: options, + } +} + +// Untrace stops tracing. +func (b *Backend) Untrace() { + b.tracer = nil +} + // ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method // because the initial connection message is "special" and does not include the message type as the first byte. This // will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest. @@ -205,8 +221,8 @@ func (b *Backend) Receive() (FrontendMessage, error) { return nil, err } - if b.MessageTracer != nil { - b.MessageTracer.TraceMessage('F', int32(5+len(msgBody)), msg) + if b.tracer != nil { + b.tracer.traceMessage('F', int32(5+len(msgBody)), msg) } return msg, nil diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 342a0ddd..321d0bf9 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -1,6 +1,7 @@ package pgproto3 import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -12,10 +13,10 @@ type Frontend struct { cr *chunkReader w io.Writer - // MessageTracer is used to trace messages when Send or Receive is called. This means an outbound message is traced + // tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced // before it is actually transmitted (i.e. before Flush). It is safe to change this variable when the Frontend is - // idle. Setting and unsetting MessageTracer provides equivalent functionality to PQtrace and PQuntrace in libpq. - MessageTracer MessageTracer + // idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq. + tracer *tracer wbuf []byte @@ -65,16 +66,25 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend { // Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is // called. +// +// Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods +// such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an +// extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden +// behind an interface. func (f *Frontend) Send(msg FrontendMessage) { prevLen := len(f.wbuf) f.wbuf = msg.Encode(f.wbuf) - if f.MessageTracer != nil { - f.MessageTracer.TraceMessage('F', int32(len(f.wbuf)-prevLen), msg) + if f.tracer != nil { + f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg) } } // Flush writes any pending messages to the backend (i.e. the server). func (f *Frontend) Flush() error { + if len(f.wbuf) == 0 { + return nil + } + n, err := f.w.Write(f.wbuf) const maxLen = 1024 @@ -91,6 +101,102 @@ func (f *Frontend) Flush() error { return nil } +// Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function +// PQtrace. +func (f *Frontend) Trace(w io.Writer, options TracerOptions) { + f.tracer = &tracer{ + w: w, + buf: &bytes.Buffer{}, + TracerOptions: options, + } +} + +// Untrace stops tracing. +func (f *Frontend) Untrace() { + f.tracer = nil +} + +// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendBind(msg *Bind) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendParse(msg *Parse) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendDescribe(msg *Describe) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendExecute sends a Execute message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendExecute(msg *Execute) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceExecute('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendSync(msg *Sync) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendQuery(msg *Query) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendUnbufferedEncodedCopyData immediately sends an encoded CopyData message to the backend (i.e. the server). This method +// is more efficient than sending a CopyData message with Send as the message data is not copied to the internal buffer +// before being written out. The internal buffer is flushed before the message is sent. +func (f *Frontend) SendUnbufferedEncodedCopyData(msg []byte) error { + err := f.Flush() + if err != nil { + return err + } + + n, err := f.w.Write(msg) + if err != nil { + return &writeError{err: err, safeToRetry: n == 0} + } + + if f.tracer != nil { + f.tracer.traceCopyData('F', int32(len(msg)-1), &CopyData{}) + } + + return nil +} + func translateEOFtoErrUnexpectedEOF(err error) error { if err == io.EOF { return io.ErrUnexpectedEOF @@ -179,8 +285,8 @@ func (f *Frontend) Receive() (BackendMessage, error) { return nil, err } - if f.MessageTracer != nil { - f.MessageTracer.TraceMessage('B', int32(5+len(msgBody)), msg) + if f.tracer != nil { + f.tracer.traceMessage('B', int32(5+len(msgBody)), msg) } return msg, nil diff --git a/pgproto3/trace.go b/pgproto3/trace.go index b35ecdb6..704b2ee7 100644 --- a/pgproto3/trace.go +++ b/pgproto3/trace.go @@ -8,16 +8,16 @@ import ( "time" ) -// MessageTracer is an interface that traces the messages send to and from a Backend or Frontend. -type MessageTracer interface { - // TraceMessage tracks the sending or receiving of a message. sender is either 'F' for frontend or 'B' for backend. - TraceMessage(sender byte, encodedLen int32, msg Message) +// tracer traces the messages send to and from a Backend or Frontend. The format it produces roughly mimics the +// format produced by the libpq C function PQtrace. +type tracer struct { + w io.Writer + buf *bytes.Buffer + TracerOptions } -// LibpqMessageTracer is a MessageTracer that roughly mimics the format produced by the libpq C function PQtrace. -type LibpqMessageTracer struct { - Writer io.Writer - +// TracerOptions controls tracing behavior. It is roughly equivalent to the libpq function PQsetTraceFlags. +type TracerOptions struct { // SuppressTimestamps prevents printing of timestamps. SuppressTimestamps bool @@ -25,148 +25,394 @@ type LibpqMessageTracer struct { RegressMode bool } -func (t *LibpqMessageTracer) TraceMessage(sender byte, encodedLen int32, msg Message) { - buf := &bytes.Buffer{} - - if !t.SuppressTimestamps { - now := time.Now() - buf.WriteString(now.Format("2006-01-02 15:04:05.000000")) - buf.WriteByte('\t') - } - - buf.WriteByte(sender) - buf.WriteByte('\t') - +func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) { switch msg := msg.(type) { case *AuthenticationCleartextPassword: - buf.WriteString("AuthenticationCleartextPassword") + t.traceAuthenticationCleartextPassword(sender, encodedLen, msg) case *AuthenticationGSS: - buf.WriteString("AuthenticationGSS") + t.traceAuthenticationGSS(sender, encodedLen, msg) case *AuthenticationGSSContinue: - buf.WriteString("AuthenticationGSSContinue") + t.traceAuthenticationGSSContinue(sender, encodedLen, msg) case *AuthenticationMD5Password: - buf.WriteString("AuthenticationMD5Password") + t.traceAuthenticationMD5Password(sender, encodedLen, msg) case *AuthenticationOk: - buf.WriteString("AuthenticationOk") + t.traceAuthenticationOk(sender, encodedLen, msg) case *AuthenticationSASL: - buf.WriteString("AuthenticationSASL") + t.traceAuthenticationSASL(sender, encodedLen, msg) case *AuthenticationSASLContinue: - buf.WriteString("AuthenticationSASLContinue") + t.traceAuthenticationSASLContinue(sender, encodedLen, msg) case *AuthenticationSASLFinal: - buf.WriteString("AuthenticationSASLFinal") + t.traceAuthenticationSASLFinal(sender, encodedLen, msg) case *BackendKeyData: - if t.RegressMode { - buf.WriteString("BackendKeyData\t NNNN NNNN") - } else { - fmt.Fprintf(buf, "BackendKeyData\t %d %d", msg.ProcessID, msg.SecretKey) - } + t.traceBackendKeyData(sender, encodedLen, msg) case *Bind: - fmt.Fprintf(buf, "Bind\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) - for _, fc := range msg.ParameterFormatCodes { - fmt.Fprintf(buf, " %d", fc) - } - fmt.Fprintf(buf, " %d", len(msg.Parameters)) - for _, p := range msg.Parameters { - fmt.Fprintf(buf, " %s", traceSingleQuotedString(p)) - } - fmt.Fprintf(buf, " %d", len(msg.ResultFormatCodes)) - for _, fc := range msg.ResultFormatCodes { - fmt.Fprintf(buf, " %d", fc) - } + t.traceBind(sender, encodedLen, msg) case *BindComplete: - buf.WriteString("BindComplete") + t.traceBindComplete(sender, encodedLen, msg) case *CancelRequest: - buf.WriteString("CancelRequest") + t.traceCancelRequest(sender, encodedLen, msg) case *Close: - buf.WriteString("Close") + t.traceClose(sender, encodedLen, msg) case *CloseComplete: - buf.WriteString("CloseComplete") + t.traceCloseComplete(sender, encodedLen, msg) case *CommandComplete: - fmt.Fprintf(buf, "CommandComplete\t %s", traceDoubleQuotedString(msg.CommandTag)) + t.traceCommandComplete(sender, encodedLen, msg) case *CopyBothResponse: - buf.WriteString("CopyBothResponse") + t.traceCopyBothResponse(sender, encodedLen, msg) case *CopyData: - buf.WriteString("CopyData") + t.traceCopyData(sender, encodedLen, msg) case *CopyDone: - buf.WriteString("CopyDone") + t.traceCopyDone(sender, encodedLen, msg) case *CopyFail: - fmt.Fprintf(buf, "CopyFail\t %s", traceDoubleQuotedString([]byte(msg.Message))) + t.traceCopyFail(sender, encodedLen, msg) case *CopyInResponse: - buf.WriteString("CopyInResponse") + t.traceCopyInResponse(sender, encodedLen, msg) case *CopyOutResponse: - buf.WriteString("CopyOutResponse") + t.traceCopyOutResponse(sender, encodedLen, msg) case *DataRow: - fmt.Fprintf(buf, "DataRow\t %d", len(msg.Values)) - for _, v := range msg.Values { - if v == nil { - buf.WriteString(" -1") - } else { - fmt.Fprintf(buf, " %d %s", len(v), traceSingleQuotedString(v)) - } - } + t.traceDataRow(sender, encodedLen, msg) case *Describe: - fmt.Fprintf(buf, "Describe\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) + t.traceDescribe(sender, encodedLen, msg) case *EmptyQueryResponse: - buf.WriteString("EmptyQueryResponse") + t.traceEmptyQueryResponse(sender, encodedLen, msg) case *ErrorResponse: - buf.WriteString("ErrorResponse") + t.traceErrorResponse(sender, encodedLen, msg) case *Execute: - fmt.Fprintf(buf, "Execute\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) + t.traceExecute(sender, encodedLen, msg) case *Flush: - buf.WriteString("Flush") + t.traceFlush(sender, encodedLen, msg) case *FunctionCall: - buf.WriteString("FunctionCall") + t.traceFunctionCall(sender, encodedLen, msg) case *FunctionCallResponse: - buf.WriteString("FunctionCallResponse") + t.traceFunctionCallResponse(sender, encodedLen, msg) case *GSSEncRequest: - buf.WriteString("GSSEncRequest") + t.traceGSSEncRequest(sender, encodedLen, msg) case *NoData: - buf.WriteString("NoData") + t.traceNoData(sender, encodedLen, msg) case *NoticeResponse: - buf.WriteString("NoticeResponse") + t.traceNoticeResponse(sender, encodedLen, msg) case *NotificationResponse: - fmt.Fprintf(buf, "NotificationResponse\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) + t.traceNotificationResponse(sender, encodedLen, msg) case *ParameterDescription: - buf.WriteString("ParameterDescription") + t.traceParameterDescription(sender, encodedLen, msg) case *ParameterStatus: - fmt.Fprintf(buf, "ParameterStatus\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) + t.traceParameterStatus(sender, encodedLen, msg) case *Parse: - fmt.Fprintf(buf, "Parse\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) - for _, oid := range msg.ParameterOIDs { - fmt.Fprintf(buf, " %d", oid) - } + t.traceParse(sender, encodedLen, msg) case *ParseComplete: - buf.WriteString("ParseComplete") + t.traceParseComplete(sender, encodedLen, msg) case *PortalSuspended: - buf.WriteString("PortalSuspended") + t.tracePortalSuspended(sender, encodedLen, msg) case *Query: - buf.WriteString("Query\t") - fmt.Fprintf(buf, ` "%s"`, msg.String) + t.traceQuery(sender, encodedLen, msg) case *ReadyForQuery: - fmt.Fprintf(buf, "ReadyForQuery\t %c", msg.TxStatus) + t.traceReadyForQuery(sender, encodedLen, msg) case *RowDescription: - buf.WriteString("RowDescription\t") - fmt.Fprintf(buf, " %d", len(msg.Fields)) - for _, fd := range msg.Fields { - fmt.Fprintf(buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) - } + t.traceRowDescription(sender, encodedLen, msg) case *SSLRequest: - buf.WriteString("SSLRequest") + t.traceSSLRequest(sender, encodedLen, msg) case *StartupMessage: - buf.WriteString("StartupMessage") + t.traceStartupMessage(sender, encodedLen, msg) case *Sync: - buf.WriteString("Sync") + t.traceSync(sender, encodedLen, msg) case *Terminate: - buf.WriteString("Terminate") + t.traceTerminate(sender, encodedLen, msg) default: - buf.WriteString("Unknown") + t.beginTrace(sender, encodedLen, "Unknown") + t.finishTrace() } - - buf.WriteByte('\n') - buf.WriteTo(t.Writer) } -// traceDoubleQuotedString returns buf as a double-quoted string without any escaping. It is roughly equivalent to +func (t *tracer) traceAuthenticationCleartextPassword(sender byte, encodedLen int32, msg *AuthenticationCleartextPassword) { + t.beginTrace(sender, encodedLen, "AuthenticationCleartextPassword") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationGSS(sender byte, encodedLen int32, msg *AuthenticationGSS) { + t.beginTrace(sender, encodedLen, "AuthenticationGSS") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationGSSContinue(sender byte, encodedLen int32, msg *AuthenticationGSSContinue) { + t.beginTrace(sender, encodedLen, "AuthenticationGSSContinue") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationMD5Password(sender byte, encodedLen int32, msg *AuthenticationMD5Password) { + t.beginTrace(sender, encodedLen, "AuthenticationMD5Password") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationOk(sender byte, encodedLen int32, msg *AuthenticationOk) { + t.beginTrace(sender, encodedLen, "AuthenticationOk") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationSASL(sender byte, encodedLen int32, msg *AuthenticationSASL) { + t.beginTrace(sender, encodedLen, "AuthenticationSASL") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationSASLContinue(sender byte, encodedLen int32, msg *AuthenticationSASLContinue) { + t.beginTrace(sender, encodedLen, "AuthenticationSASLContinue") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationSASLFinal(sender byte, encodedLen int32, msg *AuthenticationSASLFinal) { + t.beginTrace(sender, encodedLen, "AuthenticationSASLFinal") + t.finishTrace() +} + +func (t *tracer) traceBackendKeyData(sender byte, encodedLen int32, msg *BackendKeyData) { + t.beginTrace(sender, encodedLen, "BackendKeyData") + if t.RegressMode { + t.buf.WriteString("\t NNNN NNNN") + } else { + fmt.Fprintf(t.buf, "\t %d %d", msg.ProcessID, msg.SecretKey) + } + t.finishTrace() +} + +func (t *tracer) traceBind(sender byte, encodedLen int32, msg *Bind) { + t.beginTrace(sender, encodedLen, "Bind") + fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) + for _, fc := range msg.ParameterFormatCodes { + fmt.Fprintf(t.buf, " %d", fc) + } + fmt.Fprintf(t.buf, " %d", len(msg.Parameters)) + for _, p := range msg.Parameters { + fmt.Fprintf(t.buf, " %s", traceSingleQuotedString(p)) + } + fmt.Fprintf(t.buf, " %d", len(msg.ResultFormatCodes)) + for _, fc := range msg.ResultFormatCodes { + fmt.Fprintf(t.buf, " %d", fc) + } + t.finishTrace() +} + +func (t *tracer) traceBindComplete(sender byte, encodedLen int32, msg *BindComplete) { + t.beginTrace(sender, encodedLen, "BindComplete") + t.finishTrace() +} + +func (t *tracer) traceCancelRequest(sender byte, encodedLen int32, msg *CancelRequest) { + t.beginTrace(sender, encodedLen, "CancelRequest") + t.finishTrace() +} + +func (t *tracer) traceClose(sender byte, encodedLen int32, msg *Close) { + t.beginTrace(sender, encodedLen, "Close") + t.finishTrace() +} + +func (t *tracer) traceCloseComplete(sender byte, encodedLen int32, msg *CloseComplete) { + t.beginTrace(sender, encodedLen, "CloseComplete") + t.finishTrace() +} + +func (t *tracer) traceCommandComplete(sender byte, encodedLen int32, msg *CommandComplete) { + t.beginTrace(sender, encodedLen, "CommandComplete") + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString(msg.CommandTag)) + t.finishTrace() +} + +func (t *tracer) traceCopyBothResponse(sender byte, encodedLen int32, msg *CopyBothResponse) { + t.beginTrace(sender, encodedLen, "CopyBothResponse") + t.finishTrace() +} + +func (t *tracer) traceCopyData(sender byte, encodedLen int32, msg *CopyData) { + t.beginTrace(sender, encodedLen, "CopyData") + t.finishTrace() +} + +func (t *tracer) traceCopyDone(sender byte, encodedLen int32, msg *CopyDone) { + t.beginTrace(sender, encodedLen, "CopyDone") + t.finishTrace() +} + +func (t *tracer) traceCopyFail(sender byte, encodedLen int32, msg *CopyFail) { + t.beginTrace(sender, encodedLen, "CopyFail") + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.Message))) + t.finishTrace() +} + +func (t *tracer) traceCopyInResponse(sender byte, encodedLen int32, msg *CopyInResponse) { + t.beginTrace(sender, encodedLen, "CopyInResponse") + t.finishTrace() +} + +func (t *tracer) traceCopyOutResponse(sender byte, encodedLen int32, msg *CopyOutResponse) { + t.beginTrace(sender, encodedLen, "CopyOutResponse") + t.finishTrace() +} + +func (t *tracer) traceDataRow(sender byte, encodedLen int32, msg *DataRow) { + t.beginTrace(sender, encodedLen, "DataRow") + fmt.Fprintf(t.buf, "\t %d", len(msg.Values)) + for _, v := range msg.Values { + if v == nil { + t.buf.WriteString(" -1") + } else { + fmt.Fprintf(t.buf, " %d %s", len(v), traceSingleQuotedString(v)) + } + } + t.finishTrace() +} + +func (t *tracer) traceDescribe(sender byte, encodedLen int32, msg *Describe) { + t.beginTrace(sender, encodedLen, "Describe") + fmt.Fprintf(t.buf, "\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) + t.finishTrace() +} + +func (t *tracer) traceEmptyQueryResponse(sender byte, encodedLen int32, msg *EmptyQueryResponse) { + t.beginTrace(sender, encodedLen, "EmptyQueryResponse") + t.finishTrace() +} + +func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorResponse) { + t.beginTrace(sender, encodedLen, "ErrorResponse") + t.finishTrace() +} + +func (t *tracer) traceExecute(sender byte, encodedLen int32, msg *Execute) { + t.beginTrace(sender, encodedLen, "Execute") + fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) + t.finishTrace() +} + +func (t *tracer) traceFlush(sender byte, encodedLen int32, msg *Flush) { + t.beginTrace(sender, encodedLen, "Flush") + t.finishTrace() +} + +func (t *tracer) traceFunctionCall(sender byte, encodedLen int32, msg *FunctionCall) { + t.beginTrace(sender, encodedLen, "FunctionCall") + t.finishTrace() +} + +func (t *tracer) traceFunctionCallResponse(sender byte, encodedLen int32, msg *FunctionCallResponse) { + t.beginTrace(sender, encodedLen, "FunctionCallResponse") + t.finishTrace() +} + +func (t *tracer) traceGSSEncRequest(sender byte, encodedLen int32, msg *GSSEncRequest) { + t.beginTrace(sender, encodedLen, "GSSEncRequest") + t.finishTrace() +} + +func (t *tracer) traceNoData(sender byte, encodedLen int32, msg *NoData) { + t.beginTrace(sender, encodedLen, "NoData") + t.finishTrace() +} + +func (t *tracer) traceNoticeResponse(sender byte, encodedLen int32, msg *NoticeResponse) { + t.beginTrace(sender, encodedLen, "NoticeResponse") + t.finishTrace() +} + +func (t *tracer) traceNotificationResponse(sender byte, encodedLen int32, msg *NotificationResponse) { + t.beginTrace(sender, encodedLen, "NotificationResponse") + fmt.Fprintf(t.buf, "\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) + t.finishTrace() +} + +func (t *tracer) traceParameterDescription(sender byte, encodedLen int32, msg *ParameterDescription) { + t.beginTrace(sender, encodedLen, "ParameterDescription") + t.finishTrace() +} + +func (t *tracer) traceParameterStatus(sender byte, encodedLen int32, msg *ParameterStatus) { + t.beginTrace(sender, encodedLen, "ParameterStatus") + fmt.Fprintf(t.buf, "\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) + t.finishTrace() +} + +func (t *tracer) traceParse(sender byte, encodedLen int32, msg *Parse) { + t.beginTrace(sender, encodedLen, "Parse") + fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) + for _, oid := range msg.ParameterOIDs { + fmt.Fprintf(t.buf, " %d", oid) + } + t.finishTrace() +} + +func (t *tracer) traceParseComplete(sender byte, encodedLen int32, msg *ParseComplete) { + t.beginTrace(sender, encodedLen, "ParseComplete") + t.finishTrace() +} + +func (t *tracer) tracePortalSuspended(sender byte, encodedLen int32, msg *PortalSuspended) { + t.beginTrace(sender, encodedLen, "PortalSuspended") + t.finishTrace() +} + +func (t *tracer) traceQuery(sender byte, encodedLen int32, msg *Query) { + t.beginTrace(sender, encodedLen, "Query") + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.String))) + t.finishTrace() +} + +func (t *tracer) traceReadyForQuery(sender byte, encodedLen int32, msg *ReadyForQuery) { + t.beginTrace(sender, encodedLen, "ReadyForQuery") + fmt.Fprintf(t.buf, "\t %c", msg.TxStatus) + t.finishTrace() +} + +func (t *tracer) traceRowDescription(sender byte, encodedLen int32, msg *RowDescription) { + t.beginTrace(sender, encodedLen, "RowDescription") + fmt.Fprintf(t.buf, "\t %d", len(msg.Fields)) + for _, fd := range msg.Fields { + fmt.Fprintf(t.buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) + } + t.finishTrace() +} + +func (t *tracer) traceSSLRequest(sender byte, encodedLen int32, msg *SSLRequest) { + t.beginTrace(sender, encodedLen, "SSLRequest") + t.finishTrace() +} + +func (t *tracer) traceStartupMessage(sender byte, encodedLen int32, msg *StartupMessage) { + t.beginTrace(sender, encodedLen, "StartupMessage") + t.finishTrace() +} + +func (t *tracer) traceSync(sender byte, encodedLen int32, msg *Sync) { + t.beginTrace(sender, encodedLen, "Sync") + t.finishTrace() +} + +func (t *tracer) traceTerminate(sender byte, encodedLen int32, msg *Terminate) { + t.beginTrace(sender, encodedLen, "Terminate") + t.finishTrace() +} + +func (t *tracer) beginTrace(sender byte, encodedLen int32, msgType string) { + if !t.SuppressTimestamps { + now := time.Now() + t.buf.WriteString(now.Format("2006-01-02 15:04:05.000000")) + t.buf.WriteByte('\t') + } + + t.buf.WriteByte(sender) + t.buf.WriteByte('\t') + t.buf.WriteString(msgType) +} + +func (t *tracer) finishTrace() { + t.buf.WriteByte('\n') + t.buf.WriteTo(t.w) + + if t.buf.Cap() > 1024 { + t.buf = &bytes.Buffer{} + } else { + t.buf.Reset() + } +} + +// traceDoubleQuotedString returns t.buf as a double-quoted string without any escaping. It is roughly equivalent to // pqTraceOutputString in libpq. func traceDoubleQuotedString(buf []byte) string { return `"` + string(buf) + `"` diff --git a/pgproto3/trace_test.go b/pgproto3/trace_test.go index f78bd346..a4057008 100644 --- a/pgproto3/trace_test.go +++ b/pgproto3/trace_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestLibpqMessageTracer(t *testing.T) { +func TestTrace(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -26,11 +26,10 @@ func TestLibpqMessageTracer(t *testing.T) { config.BuildFrontend = func(r io.Reader, w io.Writer) *pgproto3.Frontend { f := pgproto3.NewFrontend(r, w) - f.MessageTracer = &pgproto3.LibpqMessageTracer{ - Writer: traceOutput, + f.Trace(traceOutput, pgproto3.TracerOptions{ SuppressTimestamps: true, RegressMode: true, - } + }) return f } From 67635f896c52cb0d9be8b2411a794e927e6da0dd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 May 2022 17:30:47 -0500 Subject: [PATCH 1038/1158] Fix output to include message size and add some docs --- pgproto3/doc.go | 7 +++++ pgproto3/trace.go | 3 ++ pgproto3/trace_test.go | 64 +++++++++++++++++++++--------------------- 3 files changed, 42 insertions(+), 32 deletions(-) diff --git a/pgproto3/doc.go b/pgproto3/doc.go index 8226dc98..e0e1cf87 100644 --- a/pgproto3/doc.go +++ b/pgproto3/doc.go @@ -1,4 +1,11 @@ // Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3. // +// The primary interfaces are Frontend and Backend. They correspond to a client and server respectively. Messages are +// sent with Send (or a specialized Send variant). Messages are automatically bufferred to minimize small writes. Call +// Flush to ensure a message has actually been sent. +// +// The Trace method of Frontend and Backend can be used to examine the wire-level message traffic. It outputs in a +// similar format to the PQtrace function in libpq. +// // See https://www.postgresql.org/docs/current/protocol-message-formats.html for meanings of the different messages. package pgproto3 diff --git a/pgproto3/trace.go b/pgproto3/trace.go index 704b2ee7..d3edc4aa 100644 --- a/pgproto3/trace.go +++ b/pgproto3/trace.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "io" + "strconv" "strings" "time" ) @@ -399,6 +400,8 @@ func (t *tracer) beginTrace(sender byte, encodedLen int32, msgType string) { t.buf.WriteByte(sender) t.buf.WriteByte('\t') t.buf.WriteString(msgType) + t.buf.WriteByte('\t') + t.buf.WriteString(strconv.FormatInt(int64(encodedLen), 10)) } func (t *tracer) finishTrace() { diff --git a/pgproto3/trace_test.go b/pgproto3/trace_test.go index a4057008..0ace057b 100644 --- a/pgproto3/trace_test.go +++ b/pgproto3/trace_test.go @@ -40,38 +40,38 @@ func TestTrace(t *testing.T) { result := conn.ExecParams(ctx, "select n from generate_series(1,5) n", nil, nil, nil, nil).Read() require.NoError(t, result.Err) - expected := `F StartupMessage -B AuthenticationOk -B ParameterStatus "application_name" "" -B ParameterStatus "client_encoding" "UTF8" -B ParameterStatus "DateStyle" "ISO, MDY" -B ParameterStatus "default_transaction_read_only" "off" -B ParameterStatus "in_hot_standby" "off" -B ParameterStatus "integer_datetimes" "on" -B ParameterStatus "IntervalStyle" "postgres" -B ParameterStatus "is_superuser" "on" -B ParameterStatus "server_encoding" "UTF8" -B ParameterStatus "server_version" "14.3" -B ParameterStatus "session_authorization" "jack" -B ParameterStatus "standard_conforming_strings" "on" -B ParameterStatus "TimeZone" "America/Chicago" -B BackendKeyData NNNN NNNN -B ReadyForQuery I -F Parse "" "select n from generate_series(1,5) n" 0 -F Bind "" "" 0 0 0 -F Describe P "" -F Execute "" 0 -F Sync -B ParseComplete -B BindComplete -B RowDescription 1 "n" 0 0 23 4 -1 0 -B DataRow 1 1 '1' -B DataRow 1 1 '2' -B DataRow 1 1 '3' -B DataRow 1 1 '4' -B DataRow 1 1 '5' -B CommandComplete "SELECT 5" -B ReadyForQuery I + expected := `F StartupMessage 37 +B AuthenticationOk 9 +B ParameterStatus 23 "application_name" "" +B ParameterStatus 26 "client_encoding" "UTF8" +B ParameterStatus 24 "DateStyle" "ISO, MDY" +B ParameterStatus 39 "default_transaction_read_only" "off" +B ParameterStatus 24 "in_hot_standby" "off" +B ParameterStatus 26 "integer_datetimes" "on" +B ParameterStatus 28 "IntervalStyle" "postgres" +B ParameterStatus 21 "is_superuser" "on" +B ParameterStatus 26 "server_encoding" "UTF8" +B ParameterStatus 25 "server_version" "14.3" +B ParameterStatus 32 "session_authorization" "jack" +B ParameterStatus 36 "standard_conforming_strings" "on" +B ParameterStatus 30 "TimeZone" "America/Chicago" +B BackendKeyData 13 NNNN NNNN +B ReadyForQuery 6 I +F Parse 45 "" "select n from generate_series(1,5) n" 0 +F Bind 13 "" "" 0 0 0 +F Describe 7 P "" +F Execute 10 "" 0 +F Sync 5 +B ParseComplete 5 +B BindComplete 5 +B RowDescription 27 1 "n" 0 0 23 4 -1 0 +B DataRow 12 1 1 '1' +B DataRow 12 1 1 '2' +B DataRow 12 1 1 '3' +B DataRow 12 1 1 '4' +B DataRow 12 1 1 '5' +B CommandComplete 14 "SELECT 5" +B ReadyForQuery 6 I ` require.Equal(t, expected, traceOutput.String()) From b59cd505080fd8912db24fa206a456bf68c6ec4d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 23 May 2022 17:52:10 -0500 Subject: [PATCH 1039/1158] TestTrace enables tracing after connection established This avoids locking to a specific version of the server. --- pgproto3/trace_test.go | 42 ++++++++---------------------------------- 1 file changed, 8 insertions(+), 34 deletions(-) diff --git a/pgproto3/trace_test.go b/pgproto3/trace_test.go index 0ace057b..ee4bb376 100644 --- a/pgproto3/trace_test.go +++ b/pgproto3/trace_test.go @@ -3,7 +3,6 @@ package pgproto3_test import ( "bytes" "context" - "io" "os" "testing" "time" @@ -19,45 +18,20 @@ func TestTrace(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - - traceOutput := &bytes.Buffer{} - - config.BuildFrontend = func(r io.Reader, w io.Writer) *pgproto3.Frontend { - f := pgproto3.NewFrontend(r, w) - f.Trace(traceOutput, pgproto3.TracerOptions{ - SuppressTimestamps: true, - RegressMode: true, - }) - return f - } - - conn, err := pgconn.ConnectConfig(ctx, config) + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer conn.Close(ctx) + traceOutput := &bytes.Buffer{} + conn.Frontend().Trace(traceOutput, pgproto3.TracerOptions{ + SuppressTimestamps: true, + RegressMode: true, + }) + result := conn.ExecParams(ctx, "select n from generate_series(1,5) n", nil, nil, nil, nil).Read() require.NoError(t, result.Err) - expected := `F StartupMessage 37 -B AuthenticationOk 9 -B ParameterStatus 23 "application_name" "" -B ParameterStatus 26 "client_encoding" "UTF8" -B ParameterStatus 24 "DateStyle" "ISO, MDY" -B ParameterStatus 39 "default_transaction_read_only" "off" -B ParameterStatus 24 "in_hot_standby" "off" -B ParameterStatus 26 "integer_datetimes" "on" -B ParameterStatus 28 "IntervalStyle" "postgres" -B ParameterStatus 21 "is_superuser" "on" -B ParameterStatus 26 "server_encoding" "UTF8" -B ParameterStatus 25 "server_version" "14.3" -B ParameterStatus 32 "session_authorization" "jack" -B ParameterStatus 36 "standard_conforming_strings" "on" -B ParameterStatus 30 "TimeZone" "America/Chicago" -B BackendKeyData 13 NNNN NNNN -B ReadyForQuery 6 I -F Parse 45 "" "select n from generate_series(1,5) n" 0 + expected := `F Parse 45 "" "select n from generate_series(1,5) n" 0 F Bind 13 "" "" 0 0 0 F Describe 7 P "" F Execute 10 "" 0 From 55e0b4c30e14fef921cb0a23a9357937dd2de8d6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 23 May 2022 18:15:53 -0500 Subject: [PATCH 1040/1158] Skip CockroachDB in TestTrace --- pgproto3/trace_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pgproto3/trace_test.go b/pgproto3/trace_test.go index ee4bb376..904bbdfb 100644 --- a/pgproto3/trace_test.go +++ b/pgproto3/trace_test.go @@ -22,6 +22,10 @@ func TestTrace(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) + if conn.ParameterStatus("crdb_version") != "" { + t.Skip("Skipping message trace on CockroachDB as it varies slightly from PostgreSQL") + } + traceOutput := &bytes.Buffer{} conn.Frontend().Trace(traceOutput, pgproto3.TracerOptions{ SuppressTimestamps: true, From bfaea9e7ec1ce0741c7bf5e6a014bf2100b758f5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 24 May 2022 08:26:37 -0500 Subject: [PATCH 1041/1158] Fix rare race in CopyFrom --- pgconn/pgconn.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index af8aeb57..c8b41f84 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1184,8 +1184,11 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co abortCopyChan := make(chan struct{}) copyErrChan := make(chan error, 1) signalMessageChan := pgConn.signalMessage() + senderDoneChan := make(chan struct{}) go func() { + defer close(senderDoneChan) + buf := make([]byte, 0, 65536) buf = append(buf, 'd') sp := len(buf) @@ -1239,6 +1242,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } } close(abortCopyChan) + <-senderDoneChan if copyErr == io.EOF || pgErr != nil { pgConn.frontend.Send(&pgproto3.CopyDone{}) From 7ddbd74d5e5a52cd24f750cacfc9be91d4f49f76 Mon Sep 17 00:00:00 2001 From: Oliver Tan Date: Tue, 24 May 2022 10:39:13 -0700 Subject: [PATCH 1042/1158] stop ignoring ErrorResponse during GSS auth --- krb5.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/krb5.go b/krb5.go index f2dbe45a..08427b8e 100644 --- a/krb5.go +++ b/krb5.go @@ -2,6 +2,8 @@ package pgconn import ( "errors" + "fmt" + "github.com/jackc/pgproto3/v2" ) @@ -85,10 +87,13 @@ func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) { if err != nil { return nil, err } - gssContinue, ok := msg.(*pgproto3.AuthenticationGSSContinue) - if ok { - return gssContinue, nil + + switch m := msg.(type) { + case *pgproto3.AuthenticationGSSContinue: + return m, nil + case *pgproto3.ErrorResponse: + return nil, ErrorResponseToPgError(m) } - return nil, errors.New("expected AuthenticationGSSContinue message but received unexpected message") + return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg) } From 824d8ad40daa6ab015603df0dba250769d6c0653 Mon Sep 17 00:00:00 2001 From: James Hartig Date: Wed, 25 May 2022 09:16:02 -0400 Subject: [PATCH 1043/1158] support *sql.Scanner for null handling Fixes jackc/pgx#1211 --- pgtype.go | 54 +++++++++++++++++++++++++++++++++++++++++++++----- pgtype_test.go | 41 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/pgtype.go b/pgtype.go index eba09fa5..4078da7b 100644 --- a/pgtype.go +++ b/pgtype.go @@ -533,8 +533,22 @@ type scanPlanDataTypeSQLScanner DataType func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { scanner, ok := dst.(sql.Scanner) if !ok { - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) + dv := reflect.ValueOf(dst) + if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + if src == nil { + // Ensure the pointer points to a zero version of the value + dv.Elem().Set(reflect.Zero(dv.Type().Elem())) + return nil + } + dv = dv.Elem() + // If the pointer is to a nil pointer then set that before scanning + if dv.Kind() == reflect.Ptr && dv.IsNil() { + dv.Set(reflect.New(dv.Type().Elem())) + } + scanner = dv.Interface().(sql.Scanner) } dt := (*DataType)(plan) @@ -593,7 +607,25 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode type scanPlanSQLScanner struct{} func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - scanner := dst.(sql.Scanner) + scanner, ok := dst.(sql.Scanner) + if !ok { + dv := reflect.ValueOf(dst) + if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + if src == nil { + // Ensure the pointer points to a zero version of the value + dv.Elem().Set(reflect.Zero(dv.Type())) + return nil + } + dv = dv.Elem() + // If the pointer is to a nil pointer then set that before scanning + if dv.Kind() == reflect.Ptr && dv.IsNil() { + dv.Set(reflect.New(dv.Type().Elem())) + } + scanner = dv.Interface().(sql.Scanner) + } if src == nil { // This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the // text format path would be converted to empty string. @@ -761,6 +793,18 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt return newPlan.Scan(ci, oid, formatCode, src, dst) } +var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +func isScanner(dst interface{}) bool { + if _, ok := dst.(sql.Scanner); ok { + return true + } + if t := reflect.TypeOf(dst); t.Kind() == reflect.Ptr && t.Elem().Implements(scannerType) { + return true + } + return false +} + // PlanScan prepares a plan to scan a value into dst. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { switch formatCode { @@ -825,13 +869,13 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan } if dt != nil { - if _, ok := dst.(sql.Scanner); ok { + if isScanner(dst) { return (*scanPlanDataTypeSQLScanner)(dt) } return (*scanPlanDataTypeAssignTo)(dt) } - if _, ok := dst.(sql.Scanner); ok { + if isScanner(dst) { return scanPlanSQLScanner{} } diff --git a/pgtype_test.go b/pgtype_test.go index 85ca55e9..9127766f 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -310,3 +310,44 @@ func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { } } } + +type pgCustomInt int64 + +func (ci *pgCustomInt) Scan(src interface{}) error { + *ci = pgCustomInt(src.(int64)) + return nil +} + +func TestScanPlanBinaryInt32ScanScanner(t *testing.T) { + ci := pgtype.NewConnInfo() + src := []byte{0, 42} + var v pgCustomInt + + plan := ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &v) + err := plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &v) + require.NoError(t, err) + require.EqualValues(t, 42, v) + + ptr := new(pgCustomInt) + plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &ptr) + require.NoError(t, err) + require.EqualValues(t, 42, *ptr) + + ptr = new(pgCustomInt) + err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, nil, &ptr) + require.NoError(t, err) + assert.Nil(t, ptr) + + ptr = nil + plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &ptr) + require.NoError(t, err) + require.EqualValues(t, 42, *ptr) + + ptr = nil + plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, nil, &ptr) + require.NoError(t, err) + assert.Nil(t, ptr) +} From 7d5993d104baf229bde3dea704455403a43f8bc4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 28 May 2022 06:32:39 -0500 Subject: [PATCH 1044/1158] Add BenchmarkConnectClose --- bench_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/bench_test.go b/bench_test.go index db27491f..c441b374 100644 --- a/bench_test.go +++ b/bench_test.go @@ -18,6 +18,20 @@ import ( "github.com/stretchr/testify/require" ) +func BenchmarkConnectClose(b *testing.B) { + for i := 0; i < b.N; i++ { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + b.Fatal(err) + } + + err = conn.Close(context.Background()) + if err != nil { + b.Fatal(err) + } + } +} + func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec From e12ba1b6b90590447e2cc597c2d62be7877c3b72 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 28 May 2022 10:59:54 -0500 Subject: [PATCH 1045/1158] Extract iobufpool --- internal/iobufpool/iobufpool.go | 46 ++++++++++++ internal/iobufpool/iobufpool_internal_test.go | 36 ++++++++++ internal/iobufpool/iobufpool_test.go | 35 +++++++++ pgproto3/chunkreader.go | 71 +++++-------------- pgproto3/chunkreader_test.go | 49 +------------ 5 files changed, 134 insertions(+), 103 deletions(-) create mode 100644 internal/iobufpool/iobufpool.go create mode 100644 internal/iobufpool/iobufpool_internal_test.go create mode 100644 internal/iobufpool/iobufpool_test.go diff --git a/internal/iobufpool/iobufpool.go b/internal/iobufpool/iobufpool.go new file mode 100644 index 00000000..52c52f45 --- /dev/null +++ b/internal/iobufpool/iobufpool.go @@ -0,0 +1,46 @@ +// Package iobufpool implements a global segregated-fit pool of buffers for IO. +package iobufpool + +import "sync" + +const minPoolExpOf2 = 8 + +var pools [18]*sync.Pool + +func init() { + for i := range pools { + bufLen := 1 << (minPoolExpOf2 + i) + pools[i] = &sync.Pool{New: func() any { return make([]byte, bufLen) }} + } +} + +// Get gets a []byte with len >= size and len <= size*2. +func Get(size int) []byte { + i := poolIdx(size) + if i >= len(pools) { + return make([]byte, size) + } + return pools[i].Get().([]byte) +} + +// Put returns buf to the pool. +func Put(buf []byte) { + i := poolIdx(len(buf)) + if i >= len(pools) { + return + } + + pools[i].Put(buf) +} + +func poolIdx(size int) int { + size-- + size >>= minPoolExpOf2 + i := 0 + for size > 0 { + size >>= 1 + i++ + } + + return i +} diff --git a/internal/iobufpool/iobufpool_internal_test.go b/internal/iobufpool/iobufpool_internal_test.go new file mode 100644 index 00000000..38b499f9 --- /dev/null +++ b/internal/iobufpool/iobufpool_internal_test.go @@ -0,0 +1,36 @@ +package iobufpool + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPoolIdx(t *testing.T) { + tests := []struct { + size int + expected int + }{ + {size: 0, expected: 0}, + {size: 1, expected: 0}, + {size: 255, expected: 0}, + {size: 256, expected: 0}, + {size: 257, expected: 1}, + {size: 511, expected: 1}, + {size: 512, expected: 1}, + {size: 513, expected: 2}, + {size: 1023, expected: 2}, + {size: 1024, expected: 2}, + {size: 1025, expected: 3}, + {size: 2047, expected: 3}, + {size: 2048, expected: 3}, + {size: 2049, expected: 4}, + {size: 8388607, expected: 15}, + {size: 8388608, expected: 15}, + {size: 8388609, expected: 16}, + } + for _, tt := range tests { + idx := poolIdx(tt.size) + assert.Equalf(t, tt.expected, idx, "size: %d", tt.size) + } +} diff --git a/internal/iobufpool/iobufpool_test.go b/internal/iobufpool/iobufpool_test.go new file mode 100644 index 00000000..9ad7417d --- /dev/null +++ b/internal/iobufpool/iobufpool_test.go @@ -0,0 +1,35 @@ +package iobufpool_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/internal/iobufpool" + "github.com/stretchr/testify/assert" +) + +func TestGet(t *testing.T) { + tests := []struct { + requestedLen int + expectedLen int + }{ + {requestedLen: 0, expectedLen: 256}, + {requestedLen: 128, expectedLen: 256}, + {requestedLen: 255, expectedLen: 256}, + {requestedLen: 256, expectedLen: 256}, + {requestedLen: 257, expectedLen: 512}, + {requestedLen: 511, expectedLen: 512}, + {requestedLen: 512, expectedLen: 512}, + {requestedLen: 513, expectedLen: 1024}, + {requestedLen: 1023, expectedLen: 1024}, + {requestedLen: 1024, expectedLen: 1024}, + {requestedLen: 33554431, expectedLen: 33554432}, + {requestedLen: 33554432, expectedLen: 33554432}, + + // Above 32 MiB skip the pool and allocate exactly the requested size. + {requestedLen: 33554433, expectedLen: 33554433}, + } + for _, tt := range tests { + buf := iobufpool.Get(tt.requestedLen) + assert.Equalf(t, tt.expectedLen, len(buf), "requestedLen: %d", tt.requestedLen) + } +} diff --git a/pgproto3/chunkreader.go b/pgproto3/chunkreader.go index 8834f521..3c35d0b1 100644 --- a/pgproto3/chunkreader.go +++ b/pgproto3/chunkreader.go @@ -2,48 +2,10 @@ package pgproto3 import ( "io" - "sync" + + "github.com/jackc/pgx/v5/internal/iobufpool" ) -type bigBufPool struct { - pool sync.Pool - byteSize int -} - -var bigBufPools []*bigBufPool - -func init() { - KiB := 1024 - bigBufSizes := []int{64 * KiB, 256 * KiB, 1024 * KiB, 4096 * KiB} - bigBufPools = make([]*bigBufPool, len(bigBufSizes)) - - for i := range bigBufPools { - byteSize := bigBufSizes[i] - bigBufPools[i] = &bigBufPool{ - pool: sync.Pool{New: func() any { return make([]byte, byteSize) }}, - byteSize: byteSize, - } - } -} - -func getBigBuf(size int) []byte { - for _, bigBufPool := range bigBufPools { - if size < bigBufPool.byteSize { - return bigBufPool.pool.Get().([]byte) - } - } - return make([]byte, size) -} - -func releaseBigBuf(buf []byte) { - for _, bigBufPool := range bigBufPools { - if len(buf) == bigBufPool.byteSize { - bigBufPool.pool.Put(buf) - return - } - } -} - // chunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and // will read as much as will fit in the current buffer in a single call regardless of how large a read is actually // requested. The memory returned via Next is only valid until the next call to Next. @@ -55,28 +17,26 @@ type chunkReader struct { buf []byte rp, wp int // buf read position and write position - ownBuf []byte // buf owned by chunkReader + minBufSize int } -// newChunkReader creates and returns a new chunkReader for r with default configuration with bufSize internal buffer. -// If bufSize is <= 0 it uses a default value. -func newChunkReader(r io.Reader, bufSize int) *chunkReader { - if bufSize <= 0 { +// newChunkReader creates and returns a new chunkReader for r with default configuration. If minBufSize is <= 0 it uses +// a default value. +func newChunkReader(r io.Reader, minBufSize int) *chunkReader { + if minBufSize <= 0 { // By historical reasons Postgres currently has 8KB send buffer inside, // so here we want to have at least the same size buffer. // @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134 // @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru // // In addition, testing has found no benefit of any larger buffer. - bufSize = 8192 + minBufSize = 8192 } - buf := make([]byte, bufSize) - return &chunkReader{ - r: r, - buf: buf, - ownBuf: buf, + r: r, + minBufSize: minBufSize, + buf: iobufpool.Get(minBufSize), } } @@ -85,9 +45,9 @@ func newChunkReader(r io.Reader, bufSize int) *chunkReader { func (r *chunkReader) Next(n int) (buf []byte, err error) { // Reset the buffer if it is empty if r.rp == r.wp { - if len(r.buf) != len(r.ownBuf) { - releaseBigBuf(r.buf) - r.buf = r.ownBuf + if len(r.buf) != r.minBufSize { + iobufpool.Put(r.buf) + r.buf = iobufpool.Get(r.minBufSize) } r.rp = 0 r.wp = 0 @@ -102,9 +62,10 @@ func (r *chunkReader) Next(n int) (buf []byte, err error) { // buf is smaller than requested number of bytes if len(r.buf) < n { - bigBuf := getBigBuf(n) + bigBuf := iobufpool.Get(n) r.wp = copy(bigBuf, r.buf[r.rp:r.wp]) r.rp = 0 + iobufpool.Put(r.buf) r.buf = bigBuf } diff --git a/pgproto3/chunkreader_test.go b/pgproto3/chunkreader_test.go index 7d7bac7f..41c8ce65 100644 --- a/pgproto3/chunkreader_test.go +++ b/pgproto3/chunkreader_test.go @@ -29,7 +29,7 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2) } - if bytes.Compare(r.buf, src) != 0 { + if bytes.Compare(r.buf[:len(src)], src) != 0 { t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf) } @@ -46,53 +46,6 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { } } -func TestChunkReaderNextGetsBiggerBufAsNeededFromBigBufPools(t *testing.T) { - server := &bytes.Buffer{} - r := newChunkReader(server, 4) - - src := []byte{1, 2, 3, 4, 5, 6, 7, 8} - server.Write(src) - - n1, err := r.Next(5) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n1, src[0:5]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1) - } - if len(r.buf) != bigBufPools[0].byteSize { - t.Fatalf("Expected len(r.buf) to be %v, but it was %v", bigBufPools[0].byteSize, len(r.buf)) - } -} - -func TestChunkReaderReusesBuf(t *testing.T) { - server := &bytes.Buffer{} - r := newChunkReader(server, 4) - - src := []byte{1, 2, 3, 4, 5, 6, 7, 8} - server.Write(src) - - n1, err := r.Next(4) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n1, src[0:4]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1) - } - - n2, err := r.Next(4) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n2, src[4:8]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) - } - - if bytes.Compare(n1, src[4:8]) != 0 { - t.Fatalf("Expected slice to be reused, expected %v but it was %v", src[4:8], n1) - } -} - type randomReader struct { rnd *rand.Rand } From 2afddedda837064a1c0a42998e27711181884217 Mon Sep 17 00:00:00 2001 From: James Hartig Date: Wed, 1 Jun 2022 10:57:42 -0400 Subject: [PATCH 1046/1158] protect against panic from PlanScan when interface{}(nil) is passed --- pgtype.go | 2 +- pgtype_test.go | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/pgtype.go b/pgtype.go index 4078da7b..08e7395e 100644 --- a/pgtype.go +++ b/pgtype.go @@ -799,7 +799,7 @@ func isScanner(dst interface{}) bool { if _, ok := dst.(sql.Scanner); ok { return true } - if t := reflect.TypeOf(dst); t.Kind() == reflect.Ptr && t.Elem().Implements(scannerType) { + if t := reflect.TypeOf(dst); t != nil && t.Kind() == reflect.Ptr && t.Elem().Implements(scannerType) { return true } return false diff --git a/pgtype_test.go b/pgtype_test.go index 9127766f..67f36373 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -351,3 +351,13 @@ func TestScanPlanBinaryInt32ScanScanner(t *testing.T) { require.NoError(t, err) assert.Nil(t, ptr) } + +// Test for https://github.com/jackc/pgtype/issues/164 +func TestScanPlanInterface(t *testing.T) { + ci := pgtype.NewConnInfo() + src := []byte{0, 42} + var v interface{} + plan := ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, v) + err := plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, v) + assert.Error(t, err) +} From 6fc738ea05eec3bec8b39e568b99c5ad52ce8073 Mon Sep 17 00:00:00 2001 From: William Storey Date: Fri, 3 Jun 2022 18:00:52 +0000 Subject: [PATCH 1047/1158] Use correct test description --- inet_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/inet_test.go b/inet_test.go index 09c6b21f..8d70c0d0 100644 --- a/inet_test.go +++ b/inet_test.go @@ -68,8 +68,8 @@ func TestInetSet(t *testing.T) { assert.Equalf(t, tt.result.Status, r.Status, "%d: Status", i) if tt.result.Status == pgtype.Present { - assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: IP", i) - assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: Mask", i) + assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: Mask", i) + assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: IP", i) } } } From 1e485c1c3b3a757f07124a6a7b71510dd0c008a2 Mon Sep 17 00:00:00 2001 From: William Storey Date: Fri, 3 Jun 2022 18:08:58 +0000 Subject: [PATCH 1048/1158] Do not send IPv4 networks as IPv4-mapped IPv6 Previously if we provided a parameter that was an array of strings such as []string{"0.0.0.0/8"}, we would encode this when sending to Postgres as ::ffff:0.0.0.0/8. From what I can tell, this is because when parsing the IP/network using net functions, we get a byte array that is 16 bytes long, even if it is an IPv4 network. In Inet.EncodeBinary(), we look at the length of the IP to determine what family the input is, and saw it as IPv6 because of this. We now always normalize IPv4 addresses using To4(). --- inet.go | 19 ++++++++++++++----- inet_test.go | 5 ++++- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/inet.go b/inet.go index f35f88ba..25e56170 100644 --- a/inet.go +++ b/inet.go @@ -47,17 +47,26 @@ func (dst *Inet) Set(src interface{}) error { case string: ip, ipnet, err := net.ParseCIDR(value) if err != nil { - ip = net.ParseIP(value) + ip := net.ParseIP(value) if ip == nil { return fmt.Errorf("unable to parse inet address: %s", value) } - ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 - ipnet.Mask = net.CIDRMask(32, 32) + ipnet = &net.IPNet{IP: ipv4, Mask: net.CIDRMask(32, 32)} + } else { + ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + } + } else { + ipnet.IP = ip + if ipv4 := ipnet.IP.To4(); ipv4 != nil { + ipnet.IP = ipv4 + if len(ipnet.Mask) == 16 { + ipnet.Mask = ipnet.Mask[12:] // Needed if input is IPv4-mapped IPv6. + } } } - ipnet.IP = ip + *dst = Inet{IPNet: ipnet, Status: Present} case *net.IPNet: if value == nil { diff --git a/inet_test.go b/inet_test.go index 8d70c0d0..badbf82e 100644 --- a/inet_test.go +++ b/inet_test.go @@ -52,10 +52,12 @@ func TestInetSet(t *testing.T) { {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4"), Mask: net.CIDRMask(24, 32)}, Status: pgtype.Present}}, + {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4").To4(), Mask: net.CIDRMask(24, 32)}, Status: pgtype.Present}}, {source: "10.0.0.1", result: pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Status: pgtype.Present}}, {source: "2607:f8b0:4009:80b::200e", result: pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Status: pgtype.Present}}, {source: net.ParseIP(""), result: pgtype.Inet{Status: pgtype.Null}}, + {source: "0.0.0.0/8", result: pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/8"), Status: pgtype.Present}}, + {source: "::ffff:0.0.0.0/104", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("0.0.0.0").To4(), Mask: net.CIDRMask(8, 32)}, Status: pgtype.Present}}, } for i, tt := range successfulTests { @@ -70,6 +72,7 @@ func TestInetSet(t *testing.T) { if tt.result.Status == pgtype.Present { assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: Mask", i) assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: IP", i) + assert.Equalf(t, len(tt.result.IPNet.IP), len(r.IPNet.IP), "%d: IP length", i) } } } From 4db2a33562c6d2d38da9dbe9b8e29f2d4487cc5b Mon Sep 17 00:00:00 2001 From: William Storey Date: Mon, 6 Jun 2022 16:50:43 +0000 Subject: [PATCH 1049/1158] Do not convert IPv4-mapped IPv6 addresses to IPv4 These addresses behave differently in some cases, so assume if we're given them, we keep them as they are. --- inet.go | 26 +++++++++++++++++++++++--- inet_test.go | 2 +- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/inet.go b/inet.go index 25e56170..a343f5e2 100644 --- a/inet.go +++ b/inet.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "fmt" "net" + "strings" ) // Network address family is dependent on server socket.h value for AF_INET. @@ -52,17 +53,17 @@ func (dst *Inet) Set(src interface{}) error { return fmt.Errorf("unable to parse inet address: %s", value) } - if ipv4 := ip.To4(); ipv4 != nil { + if ipv4 := maybeGetIPv4(value, ip); ipv4 != nil { ipnet = &net.IPNet{IP: ipv4, Mask: net.CIDRMask(32, 32)} } else { ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} } } else { ipnet.IP = ip - if ipv4 := ipnet.IP.To4(); ipv4 != nil { + if ipv4 := maybeGetIPv4(value, ipnet.IP); ipv4 != nil { ipnet.IP = ipv4 if len(ipnet.Mask) == 16 { - ipnet.Mask = ipnet.Mask[12:] // Needed if input is IPv4-mapped IPv6. + ipnet.Mask = ipnet.Mask[12:] // Not sure this is ever needed. } } } @@ -96,6 +97,25 @@ func (dst *Inet) Set(src interface{}) error { return nil } +// Convert the net.IP to IPv4, if appropriate. +// +// When parsing a string to a net.IP using net.ParseIP() and the like, we get a +// 16 byte slice for IPv4 addresses as well as IPv6 addresses. This function +// calls To4() to convert them to a 4 byte slice. This is useful as it allows +// users of the net.IP check for IPv4 addresses based on the length and makes +// it clear we are handling IPv4 as opposed to IPv6 or IPv4-mapped IPv6 +// addresses. +func maybeGetIPv4(input string, ip net.IP) net.IP { + // Do not do this if the provided input looks like IPv6. This is because + // To4() on IPv4-mapped IPv6 addresses converts them to IPv4, which behave + // different in some cases. + if strings.Contains(input, ":") { + return nil + } + + return ip.To4() +} + func (dst Inet) Get() interface{} { switch dst.Status { case Present: diff --git a/inet_test.go b/inet_test.go index badbf82e..52759371 100644 --- a/inet_test.go +++ b/inet_test.go @@ -57,7 +57,7 @@ func TestInetSet(t *testing.T) { {source: "2607:f8b0:4009:80b::200e", result: pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Status: pgtype.Present}}, {source: net.ParseIP(""), result: pgtype.Inet{Status: pgtype.Null}}, {source: "0.0.0.0/8", result: pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/8"), Status: pgtype.Present}}, - {source: "::ffff:0.0.0.0/104", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("0.0.0.0").To4(), Mask: net.CIDRMask(8, 32)}, Status: pgtype.Present}}, + {source: "::ffff:0.0.0.0/104", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("::ffff:0.0.0.0"), Mask: net.CIDRMask(104, 128)}, Status: pgtype.Present}}, } for i, tt := range successfulTests { From 6dd004c8b8f4f938a26778020882139b8f4de1c2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 20 Jun 2022 20:40:25 -0500 Subject: [PATCH 1050/1158] Backport numeric to string from v5 refs https://github.com/jackc/pgx/issues/1230 --- numeric.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/numeric.go b/numeric.go index cd057749..1f32b36b 100644 --- a/numeric.go +++ b/numeric.go @@ -1,6 +1,7 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/binary" "fmt" @@ -375,6 +376,12 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } v.Set(rat) + case *string: + buf, err := encodeNumericText(*src, nil) + if err != nil { + return err + } + *v = string(buf) default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) @@ -792,3 +799,55 @@ func (src Numeric) Value() (driver.Value, error) { return nil, errUndefined } } + +func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { + // if !n.Valid { + // return nil, nil + // } + + if n.NaN { + buf = append(buf, "NaN"...) + return buf, nil + } else if n.InfinityModifier == Infinity { + buf = append(buf, "Infinity"...) + return buf, nil + } else if n.InfinityModifier == NegativeInfinity { + buf = append(buf, "-Infinity"...) + return buf, nil + } + + buf = append(buf, n.numberTextBytes()...) + + return buf, nil +} + +// numberString returns a string of the number. undefined if NaN, infinite, or NULL +func (n Numeric) numberTextBytes() []byte { + intStr := n.Int.String() + buf := &bytes.Buffer{} + exp := int(n.Exp) + if exp > 0 { + buf.WriteString(intStr) + for i := 0; i < exp; i++ { + buf.WriteByte('0') + } + } else if exp < 0 { + if len(intStr) <= -exp { + buf.WriteString("0.") + leadingZeros := -exp - len(intStr) + for i := 0; i < leadingZeros; i++ { + buf.WriteByte('0') + } + buf.WriteString(intStr) + } else if len(intStr) > -exp { + dpPos := len(intStr) + exp + buf.WriteString(intStr[:dpPos]) + buf.WriteByte('.') + buf.WriteString(intStr[dpPos:]) + } + } else { + buf.WriteString(intStr) + } + + return buf.Bytes() +} From c0a4d1b9ce701620652eb27cf73bf86891616976 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 20 Jun 2022 20:43:56 -0500 Subject: [PATCH 1051/1158] Add a few tests --- pgtype/numeric_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 14225e92..693fe6f2 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -113,6 +113,10 @@ func TestNumericCodec(t *testing.T) { {"1.23", new(string), isExpectedEq("1.23")}, {pgtype.Numeric{}, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, {nil, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, + {mustParseNumeric(t, "1"), new(string), isExpectedEq("1")}, + {pgtype.Numeric{NaN: true, Valid: true}, new(string), isExpectedEq("NaN")}, + {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, new(string), isExpectedEq("Infinity")}, + {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(string), isExpectedEq("-Infinity")}, }) pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{ From 12c49ee213fabc092f24b92db6874ed0d319d7b3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 23 Jun 2022 21:01:56 -0500 Subject: [PATCH 1052/1158] shopspring-numeric extension does not panic on NaN https://github.com/jackc/pgtype/issues/169 --- ext/shopspring-numeric/decimal.go | 10 ++++++++++ ext/shopspring-numeric/decimal_test.go | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index ef3ce201..c75efa36 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -263,6 +263,16 @@ func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } + if num.NaN { + return errors.New("cannot decode 'NaN'") + } + if num.InfinityModifier == pgtype.Infinity { + return errors.New("cannot decode 'Infinity'") + } + if num.InfinityModifier == pgtype.NegativeInfinity { + return errors.New("cannot decode '-Infinity'") + } + *dst = Numeric{Decimal: decimal.NewFromBigInt(num.Int, num.Exp), Status: pgtype.Present} return nil diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index e635da41..e3c6d59d 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -1,6 +1,7 @@ package numeric_test import ( + "context" "fmt" "math/big" "math/rand" @@ -93,6 +94,15 @@ func TestNumericNormalize(t *testing.T) { }) } +func TestNumericNaN(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + var n shopspring.Numeric + err := conn.QueryRow(context.Background(), `select 'NaN'::numeric`).Scan(&n) + require.EqualError(t, err, `can't scan into dest[0]: cannot decode 'NaN'`) +} + func TestNumericTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, From 811d855a35bb585cfc1e28df5228241a8d276b14 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jun 2022 13:15:31 -0500 Subject: [PATCH 1053/1158] Add non-blocking IO This eliminates an edge case that can cause a deadlock and is a prerequisite to cheaply testing connection liveness and to recoving a connection after a timeout. https://github.com/jackc/pgconn/issues/27 Squashed commit of the following: commit 0d7b0dddea1575e9fd72592665badb8cbdd581cc Author: Jack Christensen Date: Sat Jun 25 13:15:05 2022 -0500 Add test for non-blocking IO preventing deadlock commit 79d68d23d38bb03ddb8bf13cb45792430eaf959a Author: Jack Christensen Date: Sat Jun 18 18:23:24 2022 -0500 Release CopyFrom buf when done commit 95a43139c7b0b7557898c4480e5b3e42417ee3c0 Author: Jack Christensen Date: Sat Jun 18 18:22:32 2022 -0500 Avoid allocations with non-blocking write commit 6b63ceee076794bc4380495a55dd414dbbd08a43 Author: Jack Christensen Date: Sat Jun 18 17:46:49 2022 -0500 Simplify iobufpool usage commit 60ecdda02e5a24c894df4f58d31c485b90de5d5b Author: Jack Christensen Date: Sat Jun 18 11:51:59 2022 -0500 Add true non-blocking IO commit 7dd26a34a182d4aacaed3bf8c09f9cc48a7b6156 Author: Jack Christensen Date: Sat Jun 4 20:28:23 2022 -0500 Fix block when reading more than buffered commit afa702213f1b6d24c976406448301b2be53b7f70 Author: Jack Christensen Date: Sat Jun 4 20:10:23 2022 -0500 More TLS support commit 51655bf8f40321d5f89bc3c02dd55fba0ac6aa49 Author: Jack Christensen Date: Sat Jun 4 17:46:00 2022 -0500 Steps toward TLS commit 2b80beb1ed75f0f58db8188b87753dbc26b62098 Author: Jack Christensen Date: Sat Jun 4 13:06:29 2022 -0500 Litle more TLS support commit 765b2c6e7b034ff6ffab3974579fd6ee7add593b Author: Jack Christensen Date: Sat Jun 4 12:29:30 2022 -0500 Add testing of TLS commit 5b64432afbed9224f9512cc46624c88e7ebec625 Author: Jack Christensen Date: Sat Jun 4 09:48:19 2022 -0500 Introduce testVariants in prep for TLS commit ecebd7b103d4a9125c61e83f3651b950658b0b84 Author: Jack Christensen Date: Sat Jun 4 09:32:14 2022 -0500 Handle and test read of previously buffered data commit 09c64d8cf3ca5be1a31bef46bf78fa5cb9fae831 Author: Jack Christensen Date: Sat Jun 4 09:04:48 2022 -0500 Rename nbbconn to nbconn commit 73398bc67a7b7bd1aa044fb9b0546f4198ef92d2 Author: Jack Christensen Date: Sat Jun 4 08:59:53 2022 -0500 Remove backup files commit f1df39a29d23ae4e5175b92c69697f2bf9b4e112 Author: Jack Christensen Date: Sat Jun 4 08:58:05 2022 -0500 Initial passing tests commit ea3cdab234343fc9761d9b7966c5346179cd1b01 Author: Jack Christensen Date: Sat Jun 4 08:38:57 2022 -0500 Fix connect timeout commit ca22396789d120ff556f9704f4470268fbc8c0d8 Author: Jack Christensen Date: Thu Jun 2 19:32:55 2022 -0500 wip commit 2e7b46d5d7454daf0859dd48f8a8e190995164c5 Author: Jack Christensen Date: Mon May 30 08:32:43 2022 -0500 Update comments commit 7d04dc5caa80cb147929b6f65bab60a27baaff89 Author: Jack Christensen Date: Sat May 28 19:43:23 2022 -0500 Fix broken test commit bf1edc77d70465b4097a59c08c581033d2033ac6 Author: Jack Christensen Date: Sat May 28 19:40:33 2022 -0500 fixed putting wrong size bufs commit 1f7a855b2e4d1e14f85ac5f5683e2b93db0a4bd9 Author: Jack Christensen Date: Sat May 28 18:13:47 2022 -0500 initial not quite working non-blocking conn --- internal/iobufpool/iobufpool.go | 39 +- internal/iobufpool/iobufpool_internal_test.go | 2 +- internal/iobufpool/iobufpool_test.go | 75 ++- internal/nbconn/bufferqueue.go | 70 +++ internal/nbconn/nbconn.go | 513 ++++++++++++++++ internal/nbconn/nbconn_test.go | 554 ++++++++++++++++++ pgconn/pgconn.go | 188 +++--- pgconn/pgconn_test.go | 43 +- 8 files changed, 1333 insertions(+), 151 deletions(-) create mode 100644 internal/nbconn/bufferqueue.go create mode 100644 internal/nbconn/nbconn.go create mode 100644 internal/nbconn/nbconn_test.go diff --git a/internal/iobufpool/iobufpool.go b/internal/iobufpool/iobufpool.go index 52c52f45..9e55c435 100644 --- a/internal/iobufpool/iobufpool.go +++ b/internal/iobufpool/iobufpool.go @@ -14,26 +14,16 @@ func init() { } } -// Get gets a []byte with len >= size and len <= size*2. +// Get gets a []byte of len size with cap <= size*2. func Get(size int) []byte { - i := poolIdx(size) + i := getPoolIdx(size) if i >= len(pools) { return make([]byte, size) } - return pools[i].Get().([]byte) + return pools[i].Get().([]byte)[:size] } -// Put returns buf to the pool. -func Put(buf []byte) { - i := poolIdx(len(buf)) - if i >= len(pools) { - return - } - - pools[i].Put(buf) -} - -func poolIdx(size int) int { +func getPoolIdx(size int) int { size-- size >>= minPoolExpOf2 i := 0 @@ -44,3 +34,24 @@ func poolIdx(size int) int { return i } + +// Put returns buf to the pool. +func Put(buf []byte) { + i := putPoolIdx(cap(buf)) + if i < 0 { + return + } + + pools[i].Put(buf) +} + +func putPoolIdx(size int) int { + minPoolSize := 1 << minPoolExpOf2 + for i := range pools { + if size == minPoolSize<= len(bq.queue) { + bq.growQueue() + } + bq.queue[bq.w] = buf + bq.w++ +} + +func (bq *bufferQueue) pushFront(buf []byte) { + bq.lock.Lock() + defer bq.lock.Unlock() + + if bq.w >= len(bq.queue) { + bq.growQueue() + } + copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w]) + bq.queue[bq.r] = buf + bq.w++ +} + +func (bq *bufferQueue) popFront() []byte { + bq.lock.Lock() + defer bq.lock.Unlock() + + if bq.r == bq.w { + return nil + } + + buf := bq.queue[bq.r] + bq.queue[bq.r] = nil // Clear reference so it can be garbage collected. + bq.r++ + + if bq.r == bq.w { + bq.r = 0 + bq.w = 0 + if len(bq.queue) > minBufferQueueLen { + bq.queue = make([][]byte, minBufferQueueLen) + } + } + + return buf +} + +func (bq *bufferQueue) growQueue() { + desiredLen := (len(bq.queue) + 1) * 3 / 2 + if desiredLen < minBufferQueueLen { + desiredLen = minBufferQueueLen + } + + newQueue := make([][]byte, desiredLen) + copy(newQueue, bq.queue) + bq.queue = newQueue +} diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go new file mode 100644 index 00000000..00d0e420 --- /dev/null +++ b/internal/nbconn/nbconn.go @@ -0,0 +1,513 @@ +// Package nbconn implements a non-blocking net.Conn wrapper. +// +// It is designed to solve three problems. +// +// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all +// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion. +// +// The second is the inability to use a write deadline with a TLS.Conn without killing the connection. +// +// The third is to efficiently check if a connection has been closed via a non-blocking read. +package nbconn + +import ( + "crypto/tls" + "errors" + "net" + "os" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/jackc/pgx/v5/internal/iobufpool" +) + +var errClosed = errors.New("closed") +var ErrWouldBlock = new(wouldBlockError) + +const fakeNonblockingWaitDuration = 100 * time.Millisecond + +// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read +// mode. +var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC) + +// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to +// ignore all future calls. +var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC) + +// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error. +type wouldBlockError struct{} + +func (*wouldBlockError) Error() string { + return "would block" +} + +func (*wouldBlockError) Timeout() bool { return true } +func (*wouldBlockError) Temporary() bool { return true } + +// Conn is a net.Conn where Write never blocks and always succeeds. Flush must be called to actually write to the +// underlying connection. +type Conn interface { + net.Conn + Flush() error +} + +// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. +type NetConn struct { + conn net.Conn + rawConn syscall.RawConn + + readQueue bufferQueue + writeQueue bufferQueue + + readFlushLock sync.Mutex + // non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the + // callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. + nonblockWriteBuf []byte + nonblockWriteErr error + nonblockWriteN int + + readDeadlineLock sync.Mutex + readDeadline time.Time + readNonblocking bool + + writeDeadlineLock sync.Mutex + writeDeadline time.Time + + // Only access with atomics + closed int64 // 0 = not closed, 1 = closed +} + +func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { + nc := &NetConn{ + conn: conn, + } + + if !fakeNonBlockingIO { + if sc, ok := conn.(syscall.Conn); ok { + if rawConn, err := sc.SyscallConn(); err == nil { + nc.rawConn = rawConn + } + } + } + + return nc +} + +// Read implements io.Reader. +func (c *NetConn) Read(b []byte) (n int, err error) { + if c.isClosed() { + return 0, errClosed + } + + c.readFlushLock.Lock() + defer c.readFlushLock.Unlock() + + err = c.flush() + if err != nil { + return 0, err + } + + for n < len(b) { + buf := c.readQueue.popFront() + if buf == nil { + break + } + copiedN := copy(b[n:], buf) + if copiedN < len(buf) { + buf = buf[copiedN:] + c.readQueue.pushFront(buf) + } else { + iobufpool.Put(buf) + } + n += copiedN + } + + // If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to + // Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block. + if n > 0 { + return n, nil + } + + var readNonblocking bool + c.readDeadlineLock.Lock() + readNonblocking = c.readNonblocking + c.readDeadlineLock.Unlock() + + var readN int + if readNonblocking { + readN, err = c.nonblockingRead(b[n:]) + } else { + readN, err = c.conn.Read(b[n:]) + } + n += readN + return n, err +} + +// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is +// closed. Call Flush to actually write to the underlying connection. +func (c *NetConn) Write(b []byte) (n int, err error) { + if c.isClosed() { + return 0, errClosed + } + + buf := iobufpool.Get(len(b)) + copy(buf, b) + c.writeQueue.pushBack(buf) + return len(b), nil +} + +func (c *NetConn) Close() (err error) { + swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1) + if !swapped { + return errClosed + } + + defer func() { + closeErr := c.conn.Close() + if err == nil { + err = closeErr + } + }() + + c.readFlushLock.Lock() + defer c.readFlushLock.Unlock() + err = c.flush() + if err != nil { + return err + } + + return nil +} + +func (c *NetConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *NetConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). +func (c *NetConn) SetDeadline(t time.Time) error { + err := c.SetReadDeadline(t) + if err != nil { + return err + } + return c.SetWriteDeadline(t) +} + +// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking. +func (c *NetConn) SetReadDeadline(t time.Time) error { + if c.isClosed() { + return errClosed + } + + c.readDeadlineLock.Lock() + defer c.readDeadlineLock.Unlock() + if c.readDeadline == disableSetDeadlineDeadline { + return nil + } + if t == disableSetDeadlineDeadline { + c.readDeadline = t + return nil + } + + if t == NonBlockingDeadline { + c.readNonblocking = true + t = time.Time{} + } else { + c.readNonblocking = false + } + + c.readDeadline = t + + return c.conn.SetReadDeadline(t) +} + +func (c *NetConn) SetWriteDeadline(t time.Time) error { + if c.isClosed() { + return errClosed + } + + c.writeDeadlineLock.Lock() + defer c.writeDeadlineLock.Unlock() + if c.writeDeadline == disableSetDeadlineDeadline { + return nil + } + if t == disableSetDeadlineDeadline { + c.writeDeadline = t + return nil + } + + c.writeDeadline = t + + return c.conn.SetWriteDeadline(t) +} + +func (c *NetConn) Flush() error { + if c.isClosed() { + return errClosed + } + + c.readFlushLock.Lock() + defer c.readFlushLock.Unlock() + return c.flush() +} + +// flush does the actual work of flushing the writeQueue. readFlushLock must already be held. +func (c *NetConn) flush() error { + var stopChan chan struct{} + var errChan chan error + + defer func() { + if stopChan != nil { + select { + case stopChan <- struct{}{}: + case <-errChan: + } + } + }() + + for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() { + remainingBuf := buf + for len(remainingBuf) > 0 { + n, err := c.nonblockingWrite(remainingBuf) + remainingBuf = remainingBuf[n:] + if err != nil { + if !errors.Is(err, ErrWouldBlock) { + buf = buf[:len(remainingBuf)] + copy(buf, remainingBuf) + c.writeQueue.pushFront(buf) + return err + } + + // Writing was blocked. Reading might unblock it. + if stopChan == nil { + stopChan, errChan = c.bufferNonblockingRead() + } + + select { + case err := <-errChan: + stopChan = nil + return err + default: + } + + } + } + iobufpool.Put(buf) + } + + return nil +} + +func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { + stopChan = make(chan struct{}) + errChan = make(chan error, 1) + + go func() { + for { + buf := iobufpool.Get(8 * 1024) + n, err := c.nonblockingRead(buf) + if n > 0 { + buf = buf[:n] + c.readQueue.pushBack(buf) + } + + if err != nil { + if !errors.Is(err, ErrWouldBlock) { + errChan <- err + return + } + } + + select { + case <-stopChan: + return + default: + } + } + }() + + return stopChan, errChan +} + +func (c *NetConn) isClosed() bool { + closed := atomic.LoadInt64(&c.closed) + return closed == 1 +} + +func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) { + if c.rawConn == nil { + return c.fakeNonblockingWrite(b) + } else { + return c.realNonblockingWrite(b) + } +} + +func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { + c.writeDeadlineLock.Lock() + defer c.writeDeadlineLock.Unlock() + + deadline := time.Now().Add(fakeNonblockingWaitDuration) + if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) { + err = c.conn.SetWriteDeadline(deadline) + if err != nil { + return 0, err + } + defer func() { + // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. + c.conn.SetWriteDeadline(c.writeDeadline) + + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + err = ErrWouldBlock + } + } + }() + } + + return c.conn.Write(b) +} + +// realNonblockingWrite does a non-blocking write. readFlushLock must already be held. +func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { + c.nonblockWriteBuf = b + c.nonblockWriteN = 0 + c.nonblockWriteErr = nil + err = c.rawConn.Write(func(fd uintptr) (done bool) { + c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf) + return true + }) + n = c.nonblockWriteN + if err == nil && c.nonblockWriteErr != nil { + if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { + err = ErrWouldBlock + } else { + err = c.nonblockWriteErr + } + } + if err != nil { + // n may be -1 when an error occurs. + if n < 0 { + n = 0 + } + + return n, err + } + + return n, nil +} + +func (c *NetConn) nonblockingRead(b []byte) (n int, err error) { + if c.rawConn == nil { + return c.fakeNonblockingRead(b) + } else { + return c.realNonblockingRead(b) + } +} + +func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { + c.readDeadlineLock.Lock() + defer c.readDeadlineLock.Unlock() + + deadline := time.Now().Add(fakeNonblockingWaitDuration) + if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) { + err = c.conn.SetReadDeadline(deadline) + if err != nil { + return 0, err + } + defer func() { + // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. + c.conn.SetReadDeadline(c.readDeadline) + + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + err = ErrWouldBlock + } + } + }() + } + + return c.conn.Read(b) +} + +func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { + var funcErr error + err = c.rawConn.Read(func(fd uintptr) (done bool) { + n, funcErr = syscall.Read(int(fd), b) + return true + }) + if err == nil && funcErr != nil { + if errors.Is(funcErr, syscall.EWOULDBLOCK) { + err = ErrWouldBlock + } else { + err = funcErr + } + } + if err != nil { + // n may be -1 when an error occurs. + if n < 0 { + n = 0 + } + + return n, err + } + + return n, nil +} + +// syscall.Conn is interface + +// TLSClient establishes a TLS connection as a client over conn using config. +// +// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby +// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the +// *TLSConn is returned. +func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) { + tc := tls.Client(conn, config) + err := tc.Handshake() + if err != nil { + return nil, err + } + + // Ensure last written part of Handshake is actually sent. + err = conn.Flush() + if err != nil { + return nil, err + } + + return &TLSConn{ + tlsConn: tc, + nbConn: conn, + }, nil +} + +// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a +// tls.Conn. +type TLSConn struct { + tlsConn *tls.Conn + nbConn *NetConn +} + +func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) } +func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) } +func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() } +func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() } +func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() } + +func (tc *TLSConn) Close() error { + // tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then + // sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our + // own 5 second deadline then make all set deadlines no-op. + tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5)) + tc.tlsConn.SetDeadline(disableSetDeadlineDeadline) + + return tc.tlsConn.Close() +} + +func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) } +func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) } +func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) } diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go new file mode 100644 index 00000000..2db47039 --- /dev/null +++ b/internal/nbconn/nbconn_test.go @@ -0,0 +1,554 @@ +package nbconn_test + +import ( + "crypto/tls" + "io" + "net" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/internal/nbconn" + "github.com/stretchr/testify/require" +) + +// Test keys generated with: +// +// $ openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -sha256 -nodes -days 20000 -subj '/CN=localhost' + +var testTLSPublicKey = []byte(`-----BEGIN CERTIFICATE----- +MIICpjCCAY4CCQCjQKYdUDQzKDANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls +b2NhbGhvc3QwIBcNMjIwNjA0MTY1MzE2WhgPMjA3NzAzMDcxNjUzMTZaMBQxEjAQ +BgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +ALHbOu80cfSPufKTZsKf3E5rCXHeIHjaIbgHEXA2SW/n77U8oZX518s+27FO0sK5 +yA0WnEIwY34PU359sNR5KelARGnaeh3HdaGm1nuyyxBtwwAqIuM0UxGAMF/mQ4lT +caZPxG+7WlYDqnE3eVXUtG4c+T7t5qKAB3MtfbzKFSjczkWkroi6cTypmHArGghT +0VWWVu0s9oNp5q8iWchY2o9f0aIjmKv6FgtilO+geev+4U+QvtvrziR5BO3/3EgW +c5TUVcf+lwkvp8ziXvargmjjnNTyeF37y4KpFcex0v7z7hSrUK4zU0+xRn7Bp17v +7gzj0xN+HCsUW1cjPFNezX0CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAbEBzewzg +Z5F+BqMSxP3HkMCkLLH0N9q0/DkZaVyZ38vrjcjaDYuabq28kA2d5dc5jxsQpvTw +HTGqSv1ZxJP3pBFv6jLSh8xaM6tUkk482Q6DnZGh97CD4yup/yJzkn5nv9OHtZ9g +TnaQeeXgOz0o5Zq9IpzHJb19ysya3UCIK8oKXbSO4Qd168seCq75V2BFHDpmejjk +D92eT6WODlzzvZbhzA1F3/cUilZdhbQtJMqdecKvD+yrBpzGVqzhWQsXwsRAU1fB +hShx+D14zUGM2l4wlVzOAuGh4ZL7x3AjJsc86TsCavTspS0Xl51j+mRbiULq7G7Y +E7ZYmaKTMOhvkg== +-----END CERTIFICATE-----`) + +// The strings.ReplaceAll is used to placate any secret scanners that would squawk if they saw a private key embedded in +// source code. +var testTLSPrivateKey = []byte(strings.ReplaceAll(`-----BEGIN TESTING KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCx2zrvNHH0j7ny +k2bCn9xOawlx3iB42iG4BxFwNklv5++1PKGV+dfLPtuxTtLCucgNFpxCMGN+D1N+ +fbDUeSnpQERp2nodx3WhptZ7sssQbcMAKiLjNFMRgDBf5kOJU3GmT8Rvu1pWA6px +N3lV1LRuHPk+7eaigAdzLX28yhUo3M5FpK6IunE8qZhwKxoIU9FVllbtLPaDaeav +IlnIWNqPX9GiI5ir+hYLYpTvoHnr/uFPkL7b684keQTt/9xIFnOU1FXH/pcJL6fM +4l72q4Jo45zU8nhd+8uCqRXHsdL+8+4Uq1CuM1NPsUZ+wade7+4M49MTfhwrFFtX +IzxTXs19AgMBAAECggEBAJcHt5ARVQN8WUbobMawwX/F3QtYuPJnKWMAfYpwTwQ8 +TI32orCcrObmxeBXMxowcPTMUnzSYmpV0W0EhvimuzRbYr0Qzcoj6nwPFOuN9GpL +CuBE58NQV4nw9SM6gfdHaKb17bWDvz5zdnUVym9cZKts5yrNEqDDX5Aq/S8n27gJ +/qheXwSxwETVO6kMEW1ndNIWDP8DPQ0E4O//RuMZwxpnZdnjGKkdVNy8I1BpgDgn +lwgkE3H3IciASki1GYXoyvrIiRwMQVzvYD2zcgwK9OZSjZe0TGwAGa+eQdbs3A9I +Ir1kYn6ZMGMRFJA2XHJW3hMZdWB/t2xMBGy75Uv9sAECgYEA1o+oRUYwwQ1MwBo9 +YA6c00KjhFgrjdzyKPQrN14Q0dw5ErqRkhp2cs7BRdCDTDrjAegPc3Otg7uMa1vp +RgU/C72jwzFLYATvn+RLGRYRyqIE+bQ22/lLnXTrp4DCfdMrqWuQbIYouGHqfQrq +MfdtSUpQ6VZCi9zHehXOYwBMvQECgYEA1DTQFpe+tndIFmguxxaBwDltoPh5omzd +3vA7iFct2+UYk5W9shfAekAaZk2WufKmmC3OfBWYyIaJ7QwQpuGDS3zwjy6WFMTE +Otp2CypFCVahwHcvn2jYHmDMT0k0Pt6X2S3GAyWTyEPv7mAfKR1OWUYi7ZgdXpt0 +TtL3Z3JyhH0CgYEAwveHUGuXodUUCPvPCZo9pzrGm1wDN8WtxskY/Bbd8dTLh9lA +riKdv3Vg6q+un3ZjETht0dsrsKib0HKUZqwdve11AcmpVHcnx4MLOqBzSk4vdzfr +IbhGna3A9VRrZyqcYjb75aGDHwjaqwVgCkdrZ03AeEeJ8M2N9cIa6Js9IAECgYBu +nlU24cVdspJWc9qml3ntrUITnlMxs1R5KXuvF9rk/OixzmYDV1RTpeTdHWcL6Yyk +WYSAtHVfWpq9ggOQKpBZonh3+w3rJ6MvFsBgE5nHQ2ywOrENhQbb1xPJ5NwiRcCc +Srsk2srNo3SIK30y3n8AFIqSljABKEIZ8Olc+JDvtQKBgQCiKz43zI6a0HscgZ77 +DCBduWP4nk8BM7QTFxs9VypjrylMDGGtTKHc5BLA5fNZw97Hb7pcicN7/IbUnQUD +pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG +2qWm8jTPeDC3sq+67s2oojHf+Q== +-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY")) + +func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) { + for _, tt := range []struct { + name string + makeConns func(t *testing.T) (local, remote net.Conn) + useTLS bool + fakeNonBlockingIO bool + }{ + { + name: "Pipe", + makeConns: makePipeConns, + useTLS: false, + fakeNonBlockingIO: true, + }, + { + name: "TCP with Fake Non-blocking IO", + makeConns: makeTCPConns, + useTLS: false, + fakeNonBlockingIO: true, + }, + { + name: "TLS over TCP with Fake Non-blocking IO", + makeConns: makeTCPConns, + useTLS: true, + fakeNonBlockingIO: true, + }, + { + name: "TCP with Real Non-blocking IO", + makeConns: makeTCPConns, + useTLS: false, + fakeNonBlockingIO: false, + }, + { + name: "TLS over TCP with Real Non-blocking IO", + makeConns: makeTCPConns, + useTLS: true, + fakeNonBlockingIO: false, + }, + } { + t.Run(tt.name, func(t *testing.T) { + local, remote := tt.makeConns(t) + + // Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get + // garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never + // uses remote it may be garbage collected leading to the connection being closed. + defer local.Close() + defer remote.Close() + + var conn nbconn.Conn + netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO) + + if tt.useTLS { + cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey) + require.NoError(t, err) + + tlsServer := tls.Server(remote, &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + serverTLSHandshakeChan := make(chan error) + go func() { + err := tlsServer.Handshake() + serverTLSHandshakeChan <- err + }() + + tlsConn, err := nbconn.TLSClient(netConn, &tls.Config{InsecureSkipVerify: true}) + require.NoError(t, err) + conn = tlsConn + + err = <-serverTLSHandshakeChan + require.NoError(t, err) + remote = tlsServer + } else { + conn = netConn + } + + f(t, conn, remote) + }) + } +} + +// makePipeConns returns a connected pair of net.Conns created with net.Pipe(). It is entirely synchronous so it is +// useful for testing an exact sequence of reads and writes with the underlying connection blocking. +func makePipeConns(t *testing.T) (local, remote net.Conn) { + local, remote = net.Pipe() + t.Cleanup(func() { + local.Close() + remote.Close() + }) + + return local, remote +} + +// makeTCPConns returns a connected pair of net.Conns running over TCP on localhost. +func makeTCPConns(t *testing.T) (local, remote net.Conn) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + type acceptResultT struct { + conn net.Conn + err error + } + acceptChan := make(chan acceptResultT) + + go func() { + conn, err := ln.Accept() + acceptChan <- acceptResultT{conn: conn, err: err} + }() + + local, err = net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + + acceptResult := <-acceptChan + require.NoError(t, acceptResult.err) + + remote = acceptResult.conn + + return local, remote +} + +func TestWriteIsBuffered(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + // net.Pipe is synchronous so the Write would block if not buffered. + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + + errChan := make(chan error, 1) + go func() { + err := conn.Flush() + errChan <- err + }() + + readBuf := make([]byte, len(writeBuf)) + _, err = remote.Read(readBuf) + require.NoError(t, err) + + require.NoError(t, <-errChan) + }) +} + +func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + err := conn.SetWriteDeadline(time.Now()) + require.NoError(t, err) + + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + }) +} + +func TestReadFlushesWriteBuffer(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + + errChan := make(chan error, 2) + go func() { + readBuf := make([]byte, len(writeBuf)) + _, err := remote.Read(readBuf) + errChan <- err + + _, err = remote.Write([]byte("okay")) + errChan <- err + }() + + readBuf := make([]byte, 4) + _, err = conn.Read(readBuf) + require.NoError(t, err) + require.Equal(t, []byte("okay"), readBuf) + + require.NoError(t, <-errChan) + require.NoError(t, <-errChan) + }) +} + +func TestCloseFlushesWriteBuffer(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + + errChan := make(chan error, 1) + go func() { + readBuf := make([]byte, len(writeBuf)) + _, err := remote.Read(readBuf) + errChan <- err + }() + + err = conn.Close() + require.NoError(t, err) + + require.NoError(t, <-errChan) + }) +} + +// This test exercises the non-blocking write path. Because writes are buffered it is difficult trigger this with +// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing +// large values. +func TestInternalNonBlockingWrite(t *testing.T) { + const deadlockSize = 4 * 1024 * 1024 + + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := make([]byte, deadlockSize) + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, deadlockSize, n) + + errChan := make(chan error, 1) + go func() { + remoteWriteBuf := make([]byte, deadlockSize) + _, err := remote.Write(remoteWriteBuf) + if err != nil { + errChan <- err + return + } + + readBuf := make([]byte, deadlockSize) + _, err = io.ReadFull(remote, readBuf) + errChan <- err + }() + + readBuf := make([]byte, deadlockSize) + _, err = conn.Read(readBuf) + require.NoError(t, err) + + err = conn.Close() + require.NoError(t, err) + + require.NoError(t, <-errChan) + }) +} + +func TestInternalNonBlockingWriteWithDeadline(t *testing.T) { + const deadlockSize = 4 * 1024 * 1024 + + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := make([]byte, deadlockSize) + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, deadlockSize, n) + + err = conn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + require.NoError(t, err) + + err = conn.Flush() + require.Error(t, err) + }) +} + +func TestNonBlockingRead(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + err := conn.SetReadDeadline(nbconn.NonBlockingDeadline) + require.NoError(t, err) + + buf := make([]byte, 4) + n, err := conn.Read(buf) + require.ErrorIs(t, err, nbconn.ErrWouldBlock) + require.EqualValues(t, 0, n) + + errChan := make(chan error, 1) + go func() { + _, err := remote.Write([]byte("okay")) + errChan <- err + }() + + err = conn.SetReadDeadline(time.Time{}) + require.NoError(t, err) + + n, err = conn.Read(buf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + }) +} + +func TestReadPreviouslyBuffered(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 5) + n, err := conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 5, n) + require.Equal(t, []byte("alpha"), readBuf) + }) +} + +func TestReadMoreThanPreviouslyBufferedDoesNotBlock(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 10) + n, err := conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 5, n) + require.Equal(t, []byte("alpha"), readBuf[:n]) + }) +} + +func TestReadPreviouslyBufferedPartialRead(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 2) + n, err := conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 2, n) + require.Equal(t, []byte("al"), readBuf) + + readBuf = make([]byte, 3) + n, err = conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 3, n) + require.Equal(t, []byte("pha"), readBuf) + }) +} + +func TestReadMultiplePreviouslyBuffered(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + _, err = remote.Write([]byte("beta")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 9) + n, err := io.ReadFull(conn, readBuf) + require.NoError(t, err) + require.EqualValues(t, 9, n) + require.Equal(t, []byte("alphabeta"), readBuf) + }) +} + +func TestReadPreviouslyBufferedAndReadMore(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + + flushCompleteChan := make(chan struct{}) + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + <-flushCompleteChan + + _, err = remote.Write([]byte("beta")) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + close(flushCompleteChan) + + readBuf := make([]byte, 9) + + n, err := io.ReadFull(conn, readBuf) + require.NoError(t, err) + require.EqualValues(t, 9, n) + require.Equal(t, []byte("alphabeta"), readBuf) + + err = <-errChan + require.NoError(t, err) + }) +} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index c8b41f84..002db39a 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -13,9 +13,10 @@ import ( "net" "strconv" "strings" - "sync" "time" + "github.com/jackc/pgx/v5/internal/iobufpool" + "github.com/jackc/pgx/v5/internal/nbconn" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" @@ -75,11 +76,6 @@ type PgConn struct { status byte // One of connStatus* constants - bufferingReceive bool - bufferingReceiveMux sync.Mutex - bufferingReceiveMsg pgproto3.BackendMessage - bufferingReceiveErr error - peekedMsg pgproto3.BackendMessage // Reusable / preallocated resources @@ -234,13 +230,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } return nil, &connectError{config: config, msg: "dial error", err: err} } + netConn = nbconn.NewNetConn(netConn, false) pgConn.conn = netConn pgConn.contextWatcher = newContextWatcher(netConn) pgConn.contextWatcher.Watch(ctx) if fallbackConfig.TLSConfig != nil { - tlsConn, err := startTLS(netConn, fallbackConfig.TLSConfig) + tlsConn, err := startTLS(netConn.(*nbconn.NetConn), fallbackConfig.TLSConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() @@ -356,7 +353,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { ) } -func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { +func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return nil, err @@ -371,7 +368,12 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { return nil, errors.New("server refused TLS connection") } - return tls.Client(conn, tlsConfig), nil + tlsConn, err := nbconn.TLSClient(conn, tlsConfig) + if err != nil { + return nil, err + } + + return tlsConn, nil } func (pgConn *PgConn) txPasswordMessage(password string) (err error) { @@ -385,24 +387,6 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } -func (pgConn *PgConn) signalMessage() chan struct{} { - if pgConn.bufferingReceive { - panic("BUG: signalMessage when already in progress") - } - - pgConn.bufferingReceive = true - pgConn.bufferingReceiveMux.Lock() - - ch := make(chan struct{}) - go func() { - pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() - pgConn.bufferingReceiveMux.Unlock() - close(ch) - }() - - return ch -} - // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger @@ -442,25 +426,13 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { return pgConn.peekedMsg, nil } - var msg pgproto3.BackendMessage - var err error - if pgConn.bufferingReceive { - pgConn.bufferingReceiveMux.Lock() - msg = pgConn.bufferingReceiveMsg - err = pgConn.bufferingReceiveErr - pgConn.bufferingReceiveMux.Unlock() - pgConn.bufferingReceive = false - - // If a timeout error happened in the background try the read again. - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - msg, err = pgConn.frontend.Receive() - } - } else { - msg, err = pgConn.frontend.Receive() - } + msg, err := pgConn.frontend.Receive() if err != nil { + if errors.Is(err, nbconn.ErrWouldBlock) { + return nil, err + } + // Close on anything other than timeout error - everything else is fatal var netErr net.Error isNetErr := errors.As(err, &netErr) @@ -479,13 +451,6 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { msg, err := pgConn.peekMessage() if err != nil { - // Close on anything other than timeout error - everything else is fatal - var netErr net.Error - isNetErr := errors.As(err, &netErr) - if !(isNetErr && netErr.Timeout()) { - pgConn.asyncClose() - } - return nil, err } pgConn.peekedMsg = nil @@ -1173,62 +1138,58 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // Send copy to command pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() return CommandTag{}, err } - // Send copy data - abortCopyChan := make(chan struct{}) - copyErrChan := make(chan error, 1) - signalMessageChan := pgConn.signalMessage() - senderDoneChan := make(chan struct{}) - - go func() { - defer close(senderDoneChan) - - buf := make([]byte, 0, 65536) - buf = append(buf, 'd') - sp := len(buf) - - for { - n, readErr := r.Read(buf[5:cap(buf)]) - if n > 0 { - buf = buf[0 : n+5] - pgio.SetInt32(buf[sp:], int32(n+4)) - - writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf) - if writeErr != nil { - // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. - pgConn.conn.Close() - - copyErrChan <- writeErr - return - } - } - if readErr != nil { - copyErrChan <- readErr - return - } - - select { - case <-abortCopyChan: - return - default: - } + err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline) + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + nonblocking := true + defer func() { + if nonblocking { + pgConn.conn.SetReadDeadline(time.Time{}) } }() - var pgErr error - var copyErr error - for copyErr == nil && pgErr == nil { - select { - case copyErr = <-copyErrChan: - case <-signalMessageChan: + buf := iobufpool.Get(65536) + defer iobufpool.Put(buf) + buf[0] = 'd' + + var readErr, pgErr error + for pgErr == nil { + // Read chunk from r. + var n int + n, readErr = r.Read(buf[5:cap(buf)]) + + // Send chunk to PostgreSQL. + if n > 0 { + buf = buf[0 : n+5] + pgio.SetInt32(buf[1:], int32(n+4)) + + writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf) + if writeErr != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + } + + // Abort loop if there was a read error. + if readErr != nil { + break + } + + // Read messages until error or none available. + for pgErr == nil { msg, err := pgConn.receiveMessage() if err != nil { + if errors.Is(err, nbconn.ErrWouldBlock) { + break + } pgConn.asyncClose() return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) } @@ -1236,18 +1197,22 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co switch msg := msg.(type) { case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) - default: - signalMessageChan = pgConn.signalMessage() + break } } } - close(abortCopyChan) - <-senderDoneChan - if copyErr == io.EOF || pgErr != nil { + err = pgConn.conn.SetReadDeadline(time.Time{}) + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + nonblocking = false + + if readErr == io.EOF || pgErr != nil { pgConn.frontend.Send(&pgproto3.CopyDone{}) } else { - pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()}) + pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()}) } err = pgConn.frontend.Flush() if err != nil { @@ -1603,18 +1568,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) - // A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is - // closed. This is all that can be done without introducing a race condition or adding a concurrent safe communication - // channel to relay the error back. The practical effect of this is that the underlying Write error is not reported. - // The error the code reading the batch results receives will be a closed connection error. - // - // See https://github.com/jackc/pgx/issues/374. - go func() { - _, err := pgConn.conn.Write(batch.buf) - if err != nil { - pgConn.conn.Close() - } - }() + _, err := pgConn.conn.Write(batch.buf) + if err != nil { + multiResult.closed = true + multiResult.err = err + pgConn.unlock() + return multiResult + } return multiResult } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index fdce6e7d..07b68995 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1849,13 +1849,14 @@ func TestConnCancelRequest(t *testing.T) { multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") - // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a - // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a - // few milliseconds. - time.Sleep(50 * time.Millisecond) + go func() { + // The query is actually sent when multiResult.NextResult() is called. So wait to ensure it is sent. + // Once Flush is available this could use that instead. + time.Sleep(500 * time.Millisecond) - err = pgConn.CancelRequest(context.Background()) - require.NoError(t, err) + err = pgConn.CancelRequest(context.Background()) + require.NoError(t, err) + }() for multiResult.NextResult() { } @@ -2027,6 +2028,36 @@ func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) { require.Error(t, err) } +// https://github.com/jackc/pgconn/issues/27 +func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), "set client_min_messages = debug5").ReadAll() + require.NoError(t, err) + + // The actual contents of this test aren't important. What's important is a large amount of data to be written and + // because of client_min_messages = debug5 the server will return a large amount of data. + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { From 82ca09e645fbf6f4be0545418917aac161eb5e16 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jun 2022 13:31:41 -0500 Subject: [PATCH 1054/1158] Numeric infinity only supported on PG 14+ Move to PG 14+ specific test --- pgtype/numeric_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 693fe6f2..071f0c24 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -115,8 +115,6 @@ func TestNumericCodec(t *testing.T) { {nil, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, {mustParseNumeric(t, "1"), new(string), isExpectedEq("1")}, {pgtype.Numeric{NaN: true, Valid: true}, new(string), isExpectedEq("NaN")}, - {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, new(string), isExpectedEq("Infinity")}, - {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(string), isExpectedEq("-Infinity")}, }) pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{ @@ -137,6 +135,8 @@ func TestNumericCodecInfinity(t *testing.T) { {float32(math.Inf(-1)), new(float32), isExpectedEq(float32(math.Inf(-1)))}, {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true})}, {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, new(string), isExpectedEq("Infinity")}, + {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(string), isExpectedEq("-Infinity")}, }) } From 125ee9670eb3a779386a409bbbb6a870193b18d9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jun 2022 13:43:16 -0500 Subject: [PATCH 1055/1158] Test TLS connection with pg_stat_ssl Because of the nbconn wrapper it is no longer possible to check if the conn is a *tls.Conn directly. This is actually a more reliable test anyway. --- pgconn/pgconn_test.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 07b68995..357a99c1 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -4,7 +4,6 @@ import ( "bytes" "compress/gzip" "context" - "crypto/tls" "errors" "fmt" "io" @@ -66,9 +65,11 @@ func TestConnectTLS(t *testing.T) { conn, err := pgconn.Connect(context.Background(), connString) require.NoError(t, err) - if _, ok := conn.Conn().(*tls.Conn); !ok { - t.Error("not a TLS connection") - } + result := conn.ExecParams(context.Background(), `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, 1) + require.Len(t, result.Rows[0], 1) + require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection") closeConn(t, conn) } From b068d537532c820a412f9b1b5e582afc23ae930c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jun 2022 14:07:48 -0500 Subject: [PATCH 1056/1158] Fix race in test Goroutine should have it's own err var instead of sharing. --- pgconn/pgconn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 357a99c1..87936225 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1855,7 +1855,7 @@ func TestConnCancelRequest(t *testing.T) { // Once Flush is available this could use that instead. time.Sleep(500 * time.Millisecond) - err = pgConn.CancelRequest(context.Background()) + err := pgConn.CancelRequest(context.Background()) require.NoError(t, err) }() From 72b1dcff2fa42a0c15de9369116a626416f94610 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jun 2022 15:55:09 -0500 Subject: [PATCH 1057/1158] Add pgconn.CheckConn --- bench_test.go | 11 +++++++- internal/nbconn/nbconn.go | 49 +++++++++++++++++++++++++--------- internal/nbconn/nbconn_test.go | 29 ++++++++++++++++++++ pgconn/pgconn.go | 30 ++++++++++++++------- pgconn/pgconn_test.go | 28 +++++++++++++++++++ 5 files changed, 124 insertions(+), 23 deletions(-) diff --git a/bench_test.go b/bench_test.go index c441b374..31b3b38e 100644 --- a/bench_test.go +++ b/bench_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/internal/nbconn" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" @@ -1236,7 +1237,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) { } type queryRecorder struct { - conn net.Conn + conn nbconn.Conn writeBuf []byte readCount int } @@ -1252,6 +1253,14 @@ func (qr *queryRecorder) Write(b []byte) (n int, err error) { return qr.conn.Write(b) } +func (qr *queryRecorder) BufferReadUntilBlock() error { + return qr.conn.BufferReadUntilBlock() +} + +func (qr *queryRecorder) Flush() error { + return qr.conn.Flush() +} + func (qr *queryRecorder) Close() error { return qr.conn.Close() } diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go index 00d0e420..16c4b713 100644 --- a/internal/nbconn/nbconn.go +++ b/internal/nbconn/nbconn.go @@ -13,6 +13,7 @@ package nbconn import ( "crypto/tls" "errors" + "io" "net" "os" "sync" @@ -46,11 +47,16 @@ func (*wouldBlockError) Error() string { func (*wouldBlockError) Timeout() bool { return true } func (*wouldBlockError) Temporary() bool { return true } -// Conn is a net.Conn where Write never blocks and always succeeds. Flush must be called to actually write to the -// underlying connection. +// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to +// the underlying connection. type Conn interface { net.Conn + + // Flush flushes any buffered writes. Flush() error + + // BufferReadUntilBlock reads and buffers any sucessfully read bytes until the read would block. + BufferReadUntilBlock() error } // NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. @@ -303,24 +309,35 @@ func (c *NetConn) flush() error { return nil } +func (c *NetConn) BufferReadUntilBlock() error { + for { + buf := iobufpool.Get(8 * 1024) + n, err := c.nonblockingRead(buf) + if n > 0 { + buf = buf[:n] + c.readQueue.pushBack(buf) + } + + if err != nil { + if errors.Is(err, ErrWouldBlock) { + return nil + } else { + return err + } + } + } +} + func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { stopChan = make(chan struct{}) errChan = make(chan error, 1) go func() { for { - buf := iobufpool.Get(8 * 1024) - n, err := c.nonblockingRead(buf) - if n > 0 { - buf = buf[:n] - c.readQueue.pushBack(buf) - } - + err := c.BufferReadUntilBlock() if err != nil { - if !errors.Is(err, ErrWouldBlock) { - errChan <- err - return - } + errChan <- err + return } select { @@ -456,6 +473,11 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { return n, err } + // syscall read did not return an error and 0 bytes were read means EOF. + if n == 0 { + return 0, io.EOF + } + return n, nil } @@ -494,6 +516,7 @@ type TLSConn struct { func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) } func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) } +func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() } func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() } func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() } func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() } diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go index 2db47039..de32b9c7 100644 --- a/internal/nbconn/nbconn_test.go +++ b/internal/nbconn/nbconn_test.go @@ -2,6 +2,7 @@ package nbconn_test import ( "crypto/tls" + "errors" "io" "net" "strings" @@ -345,6 +346,34 @@ func TestNonBlockingRead(t *testing.T) { }) } +func TestBufferNonBlockingRead(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + err := conn.BufferReadUntilBlock() + require.NoError(t, err) + + errChan := make(chan error, 1) + go func() { + _, err := remote.Write([]byte("okay")) + errChan <- err + }() + + for i := 0; i < 1000; i++ { + err = conn.BufferReadUntilBlock() + if !errors.Is(err, nbconn.ErrWouldBlock) { + break + } + time.Sleep(time.Millisecond) + } + require.NoError(t, err) + + buf := make([]byte, 4) + n, err := conn.Read(buf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + require.Equal(t, []byte("okay"), buf) + }) +} + func TestReadPreviouslyBuffered(t *testing.T) { testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 002db39a..306b2e16 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -65,7 +65,7 @@ type NotificationHandler func(*PgConn, *Notification) // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { - conn net.Conn // the underlying TCP or unix domain socket connection + conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server @@ -230,22 +230,22 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } return nil, &connectError{config: config, msg: "dial error", err: err} } - netConn = nbconn.NewNetConn(netConn, false) + nbNetConn := nbconn.NewNetConn(netConn, false) - pgConn.conn = netConn - pgConn.contextWatcher = newContextWatcher(netConn) + pgConn.conn = nbNetConn + pgConn.contextWatcher = newContextWatcher(nbNetConn) pgConn.contextWatcher.Watch(ctx) if fallbackConfig.TLSConfig != nil { - tlsConn, err := startTLS(netConn.(*nbconn.NetConn), fallbackConfig.TLSConfig) + nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() return nil, &connectError{config: config, msg: "tls error", err: err} } - pgConn.conn = tlsConn - pgConn.contextWatcher = newContextWatcher(tlsConn) + pgConn.conn = nbTLSConn + pgConn.contextWatcher = newContextWatcher(nbTLSConn) pgConn.contextWatcher.Watch(ctx) } @@ -353,7 +353,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { ) } -func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (net.Conn, error) { +func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return nil, err @@ -1596,6 +1596,18 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) { return strings.Replace(s, "'", "''", -1), nil } +// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by reading and +// buffering until the read would block or an error occurs. This can be used to check if the server has closed the +// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails +// without the client knowing whether the server received it or not. +func (pgConn *PgConn) CheckConn() error { + err := pgConn.conn.BufferReadUntilBlock() + if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) { + return err + } + return nil +} + // makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory. func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { ct := make([]byte, len(buf)) @@ -1608,7 +1620,7 @@ func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. type HijackedConn struct { - Conn net.Conn // the underlying TCP or unix domain socket connection + Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection PID uint32 // backend pid SecretKey uint32 // key to use to send a cancel query message to the server ParameterStatuses map[string]string // parameters that have been reported by the server diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 87936225..f517f268 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -2059,6 +2059,34 @@ func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnCheckConn(t *testing.T) { + t.Parallel() + + c1, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_TCP_CONN_STRING")) + require.NoError(t, err) + defer c1.Close(context.Background()) + + if c1.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + } + + err = c1.CheckConn() + require.NoError(t, err) + + c2, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_TCP_CONN_STRING")) + require.NoError(t, err) + defer c2.Close(context.Background()) + + _, err = c2.Exec(context.Background(), fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll() + require.NoError(t, err) + + // Give a little time for the signal to actually kill the backend. + time.Sleep(500 * time.Millisecond) + + err = c1.CheckConn() + require.Error(t, err) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { From 9afd320b9e059ea23a07ba83205f9ed0f7fe288b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jun 2022 16:05:20 -0500 Subject: [PATCH 1058/1158] Fix flickering test in CI While this test always worked on my machine, it flickered in CI. And to be fair the test can't guarantee the condition it is testing. Work around this by trying many times before admitting failure. --- internal/iobufpool/iobufpool_test.go | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/internal/iobufpool/iobufpool_test.go b/internal/iobufpool/iobufpool_test.go index 51b08215..09e258bb 100644 --- a/internal/iobufpool/iobufpool_test.go +++ b/internal/iobufpool/iobufpool_test.go @@ -5,7 +5,6 @@ import ( "github.com/jackc/pgx/v5/internal/iobufpool" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestGetCap(t *testing.T) { @@ -70,9 +69,17 @@ func TestPutHandlesWrongSizedBuffers(t *testing.T) { } func TestPutGetBufferReuse(t *testing.T) { - buf := iobufpool.Get(4) - buf[0] = 1 - iobufpool.Put(buf) - buf = iobufpool.Get(4) - require.Equal(t, byte(1), buf[0]) + // There is no way to guarantee a buffer will be reused. It should be, but a GC between the Put and the Get will cause + // it not to be. So try many times. + for i := 0; i < 100000; i++ { + buf := iobufpool.Get(4) + buf[0] = 1 + iobufpool.Put(buf) + buf = iobufpool.Get(4) + if buf[0] == 1 { + return + } + } + + t.Error("buffer was never reused") } From 26eda0f86d2b9e357be5e66ec1b4421660291d0c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jun 2022 16:55:09 -0500 Subject: [PATCH 1059/1158] Check for ENV conn string and skip test if missing --- pgconn/pgconn_test.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index f517f268..b47f17d6 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -2062,7 +2062,14 @@ func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) { func TestConnCheckConn(t *testing.T) { t.Parallel() - c1, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_TCP_CONN_STRING")) + // Intentionally using TCP connection for more predictable close behavior. (Not sure if Unix domain sockets would behave subtlely different.) + + connString := os.Getenv(os.Getenv("PGX_TEST_TCP_CONN_STRING")) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + c1, err := pgconn.Connect(context.Background(), connString) require.NoError(t, err) defer c1.Close(context.Background()) @@ -2073,7 +2080,7 @@ func TestConnCheckConn(t *testing.T) { err = c1.CheckConn() require.NoError(t, err) - c2, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_TCP_CONN_STRING")) + c2, err := pgconn.Connect(context.Background(), connString) require.NoError(t, err) defer c2.Close(context.Background()) From 03da9fcec609b82d7ebc950069e5963422ef0059 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jun 2022 17:58:53 -0500 Subject: [PATCH 1060/1158] Check conn liveness before using when idle for more than 1 second Implemented in pgxpool.Pool and database/sql. https://github.com/jackc/pgx/issues/672 --- pgxpool/pool.go | 9 +++++++ pgxpool/pool_test.go | 52 +++++++++++++++++++++++++++++++++++ stdlib/sql.go | 19 +++++++++---- stdlib/sql_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 139 insertions(+), 5 deletions(-) diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 0446a245..de4e1066 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -417,6 +417,15 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { } cr := res.Value() + + if res.IdleDuration() > time.Second { + err := cr.conn.PgConn().CheckConn() + if err != nil { + res.Destroy() + continue + } + } + if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { return cr.getConn(p, res), nil } diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index a6d0a083..3e3058d2 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -11,6 +11,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -141,6 +142,57 @@ func TestPoolAcquireAndConnHijack(t *testing.T) { require.Equal(t, int32(1), n) } +func TestPoolAcquireChecksIdleConns(t *testing.T) { + t.Parallel() + + controllerConn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer controllerConn.Close(context.Background()) + pgxtest.SkipCockroachDB(t, controllerConn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + + pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + var conns []*pgxpool.Conn + for i := 0; i < 3; i++ { + c, err := pool.Acquire(context.Background()) + require.NoError(t, err) + conns = append(conns, c) + } + + require.EqualValues(t, 3, pool.Stat().TotalConns()) + + var pids []uint32 + for _, c := range conns { + pids = append(pids, c.Conn().PgConn().PID()) + c.Release() + } + + _, err = controllerConn.Exec(context.Background(), `select pg_terminate_backend(n) from unnest($1::int[]) n`, pids) + require.NoError(t, err) + + // All conns are dead they don't know it and neither does the pool. + require.EqualValues(t, 3, pool.Stat().TotalConns()) + + // Wait long enough so the pool will realize it needs to check the connections. + time.Sleep(time.Second) + + // Pool should try all existing connections and find them dead, then create a new connection which should successfully ping. + err = pool.Ping(context.Background()) + require.NoError(t, err) + + // The original 3 conns should have been terminated and the a new conn established for the ping. + require.EqualValues(t, 1, pool.Stat().TotalConns()) + c, err := pool.Acquire(context.Background()) + require.NoError(t, err) + + cPID := c.Conn().PgConn().PID() + c.Release() + + require.NotContains(t, pids, cPID) +} + func TestPoolAcquireFunc(t *testing.T) { t.Parallel() diff --git a/stdlib/sql.go b/stdlib/sql.go index e4565227..8a24c4c5 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -308,11 +308,12 @@ func UnregisterConnConfig(connStr string) { } type Conn struct { - conn *pgx.Conn - psCount int64 // Counter used for creating unique prepared statement names - driver *Driver - connConfig pgx.ConnConfig - resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused + conn *pgx.Conn + psCount int64 // Counter used for creating unique prepared statement names + driver *Driver + connConfig pgx.ConnConfig + resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused + lastResetSessionTime time.Time } // Conn returns the underlying *pgx.Conn @@ -450,6 +451,14 @@ func (c *Conn) ResetSession(ctx context.Context) error { return driver.ErrBadConn } + now := time.Now() + if now.Sub(c.lastResetSessionTime) > time.Second { + if err := c.conn.PgConn().CheckConn(); err != nil { + return driver.ErrBadConn + } + } + c.lastResetSessionTime = now + return c.resetSessionFunc(ctx, c.conn) } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 9106df62..ee038add 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -1154,3 +1154,67 @@ func TestResetSessionHookCalled(t *testing.T) { require.True(t, mockCalled) } + +func TestCheckIdleConn(t *testing.T) { + controllerConn, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeDB(t, controllerConn) + + skipCockroachDB(t, controllerConn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + + db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeDB(t, db) + + var conns []*sql.Conn + for i := 0; i < 3; i++ { + c, err := db.Conn(context.Background()) + require.NoError(t, err) + conns = append(conns, c) + } + + require.EqualValues(t, 3, db.Stats().OpenConnections) + + var pids []uint32 + for _, c := range conns { + err := c.Raw(func(driverConn any) error { + pids = append(pids, driverConn.(*stdlib.Conn).Conn().PgConn().PID()) + return nil + }) + require.NoError(t, err) + err = c.Close() + require.NoError(t, err) + } + + // The database/sql connection pool seems to automatically close idle connections to only keep 2 alive. + // require.EqualValues(t, 3, db.Stats().OpenConnections) + + _, err = controllerConn.ExecContext(context.Background(), `select pg_terminate_backend(n) from unnest($1::int[]) n`, pids) + require.NoError(t, err) + + // All conns are dead they don't know it and neither does the pool. But because of database/sql automatically closing + // idle connections we can't be sure how many we should have. require.EqualValues(t, 3, db.Stats().OpenConnections) + + // Wait long enough so the pool will realize it needs to check the connections. + time.Sleep(time.Second) + + // Pool should try all existing connections and find them dead, then create a new connection which should successfully ping. + err = db.PingContext(context.Background()) + require.NoError(t, err) + + // The original 3 conns should have been terminated and the a new conn established for the ping. + require.EqualValues(t, 1, db.Stats().OpenConnections) + c, err := db.Conn(context.Background()) + require.NoError(t, err) + + var cPID uint32 + err = c.Raw(func(driverConn any) error { + cPID = driverConn.(*stdlib.Conn).Conn().PgConn().PID() + return nil + }) + require.NoError(t, err) + err = c.Close() + require.NoError(t, err) + + require.NotContains(t, pids, cPID) +} From 585022440b6f015d2c54f80cd46b55e1499472d7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jun 2022 18:11:39 -0500 Subject: [PATCH 1061/1158] Update changelog --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a7fc57ef..51582553 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,12 @@ release work due to releasing multiple packages, and less clear changelogs. `CommandTag` is now an opaque type instead of directly exposing an underlying `[]byte`. +`Trace()` method adds low level message tracing similar to the `PQtrace` function in `libpq`. + +pgconn now uses non-blocking IO. This is a significant internal restructuring, but it should not cause any visible changes on its own. However, it is important in implementing other new features. + +`CheckConn()` checks a connection's liveness by doing a non-blocking read. This can be used to detect database restarts or network interruptions without executing a query or a ping. + ## pgtype The `pgtype` package has been significantly changed. From ed3e9f1dd4dc215d11ff7b615fa76d5a746cc4dd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 1 Jul 2022 15:33:12 -0500 Subject: [PATCH 1062/1158] Check for more specific error --- internal/nbconn/nbconn_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go index de32b9c7..c4a8e65b 100644 --- a/internal/nbconn/nbconn_test.go +++ b/internal/nbconn/nbconn_test.go @@ -318,6 +318,7 @@ func TestInternalNonBlockingWriteWithDeadline(t *testing.T) { err = conn.Flush() require.Error(t, err) + require.Contains(t, err.Error(), "i/o timeout") }) } From 25935a39b6edcd7a885408fca5ffe20df753a82c Mon Sep 17 00:00:00 2001 From: "sergey.bashilov" Date: Fri, 17 Jun 2022 14:51:56 +0300 Subject: [PATCH 1063/1158] add prefer-standby target_session_attrs --- config.go | 22 +++++++++++++++++++++- config_test.go | 45 ++++++++++++++++++++++++--------------------- errors.go | 17 +++++++++++++++++ pgconn.go | 36 ++++++++++++++++++++++++++++-------- 4 files changed, 90 insertions(+), 30 deletions(-) diff --git a/config.go b/config.go index e141a2f8..fc5b3c0f 100644 --- a/config.go +++ b/config.go @@ -50,6 +50,8 @@ type Config struct { // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. ValidateConnect ValidateConnectFunc + HasPreferStandbyTargetSessionAttr bool + // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables // or prepare statements). If this returns an error the connection attempt fails. AfterConnect AfterConnectFunc @@ -367,7 +369,10 @@ func ParseConfig(connString string) (*Config, error) { config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary case "standby": config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby - case "any", "prefer-standby": + case "prefer-standby": + config.ValidateConnect = ValidateConnectTargetSessionAttrsPrefferStandby + config.HasPreferStandbyTargetSessionAttr = true + case "any": // do nothing default: return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} @@ -810,3 +815,18 @@ func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgCon return nil } + +// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible +// target_session_attrs=prefer-standby. +func ValidateConnectTargetSessionAttrsPrefferStandby(ctx context.Context, pgConn *PgConn) error { + result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() + if result.Err != nil { + return result.Err + } + + if string(result.Rows[0][0]) != "t" { + return &preferStanbyNotFoundError{err: errors.New("server is not in hot standby mode")} + } + + return nil +} diff --git a/config_test.go b/config_test.go index a28db3d6..6311f1f1 100644 --- a/config_test.go +++ b/config_test.go @@ -584,13 +584,13 @@ func TestParseConfig(t *testing.T) { name: "target_session_attrs primary", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=primary", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPrimary, }, }, @@ -598,13 +598,13 @@ func TestParseConfig(t *testing.T) { name: "target_session_attrs standby", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=standby", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsStandby, }, }, @@ -612,13 +612,15 @@ func TestParseConfig(t *testing.T) { name: "target_session_attrs prefer-standby", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=prefer-standby", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPrefferStandby, + HasPreferStandbyTargetSessionAttr: true, }, }, { @@ -783,6 +785,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName // Can't test function equality, so just test that they are set or not. assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) + assert.Equalf(t, expected.HasPreferStandbyTargetSessionAttr, actual.HasPreferStandbyTargetSessionAttr, "%s - HasPreferStandbyTargetSessionAttr", testName) if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { if expected.TLSConfig != nil { diff --git a/errors.go b/errors.go index a32b29c9..2bc74df7 100644 --- a/errors.go +++ b/errors.go @@ -219,3 +219,20 @@ func redactURL(u *url.URL) string { } return u.String() } + +type preferStanbyNotFoundError struct { + err error + safeToRetry bool +} + +func (e *preferStanbyNotFoundError) Error() string { + return fmt.Sprintf("standby server not found: %s", e.err.Error()) +} + +func (e *preferStanbyNotFoundError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *preferStanbyNotFoundError) Unwrap() error { + return e.err +} diff --git a/pgconn.go b/pgconn.go index ef5b76fd..8e7ac668 100644 --- a/pgconn.go +++ b/pgconn.go @@ -148,25 +148,34 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} } + foundBestServer := false + var fallbackConfig *FallbackConfig for _, fc := range fallbackConfigs { pgConn, err = connect(ctx, config, fc) if err == nil { + foundBestServer = true break } else if pgerr, ok := err.(*PgError); ok { err = &connectError{config: config, msg: "server error", err: pgerr} - const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password - const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings - const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist - const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege - if pgerr.Code == ERRCODE_INVALID_PASSWORD || - pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION || - pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || - pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { + if checkPgError(pgerr) { break } + } else if cerr, ok := err.(*connectError); ok && config.HasPreferStandbyTargetSessionAttr { + if _, ok := cerr.err.(*preferStanbyNotFoundError); ok { + fallbackConfig = fc + } } } + if !foundBestServer && fallbackConfig != nil { + config.ValidateConnect = nil + pgConn, err = connect(ctx, config, fallbackConfig) + if pgerr, ok := err.(*PgError); ok { + err = &connectError{config: config, msg: "server error", err: pgerr} + } + config.ValidateConnect = ValidateConnectTargetSessionAttrsPrefferStandby + } + if err != nil { return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError } @@ -182,6 +191,17 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err return pgConn, nil } +func checkPgError(pgerr *PgError) bool { + const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password + const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings + const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist + const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege + return pgerr.Code == ERRCODE_INVALID_PASSWORD || + pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION || + pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || + pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE +} + func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { var configs []*FallbackConfig From 1b6543f29c8c08ccfc51a8d426ea44d960ae4d3e Mon Sep 17 00:00:00 2001 From: "sergey.bashilov" Date: Mon, 20 Jun 2022 12:15:15 +0300 Subject: [PATCH 1064/1158] fix typos --- config.go | 8 ++++---- config_test.go | 2 +- errors.go | 8 ++++---- pgconn.go | 24 ++++++++++-------------- 4 files changed, 19 insertions(+), 23 deletions(-) diff --git a/config.go b/config.go index fc5b3c0f..dac7b95b 100644 --- a/config.go +++ b/config.go @@ -370,7 +370,7 @@ func ParseConfig(connString string) (*Config, error) { case "standby": config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby case "prefer-standby": - config.ValidateConnect = ValidateConnectTargetSessionAttrsPrefferStandby + config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby config.HasPreferStandbyTargetSessionAttr = true case "any": // do nothing @@ -816,16 +816,16 @@ func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgCon return nil } -// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible +// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible // target_session_attrs=prefer-standby. -func ValidateConnectTargetSessionAttrsPrefferStandby(ctx context.Context, pgConn *PgConn) error { +func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() if result.Err != nil { return result.Err } if string(result.Rows[0][0]) != "t" { - return &preferStanbyNotFoundError{err: errors.New("server is not in hot standby mode")} + return &preferStandbyNotFoundError{err: errors.New("server is not in hot standby mode")} } return nil diff --git a/config_test.go b/config_test.go index 6311f1f1..c8d8cee6 100644 --- a/config_test.go +++ b/config_test.go @@ -619,7 +619,7 @@ func TestParseConfig(t *testing.T) { Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, - ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPrefferStandby, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPreferStandby, HasPreferStandbyTargetSessionAttr: true, }, }, diff --git a/errors.go b/errors.go index 2bc74df7..7ed8889c 100644 --- a/errors.go +++ b/errors.go @@ -220,19 +220,19 @@ func redactURL(u *url.URL) string { return u.String() } -type preferStanbyNotFoundError struct { +type preferStandbyNotFoundError struct { err error safeToRetry bool } -func (e *preferStanbyNotFoundError) Error() string { +func (e *preferStandbyNotFoundError) Error() string { return fmt.Sprintf("standby server not found: %s", e.err.Error()) } -func (e *preferStanbyNotFoundError) SafeToRetry() bool { +func (e *preferStandbyNotFoundError) SafeToRetry() bool { return e.safeToRetry } -func (e *preferStanbyNotFoundError) Unwrap() error { +func (e *preferStandbyNotFoundError) Unwrap() error { return e.err } diff --git a/pgconn.go b/pgconn.go index 8e7ac668..1a1d3505 100644 --- a/pgconn.go +++ b/pgconn.go @@ -157,11 +157,18 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err break } else if pgerr, ok := err.(*PgError); ok { err = &connectError{config: config, msg: "server error", err: pgerr} - if checkPgError(pgerr) { + const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password + const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings + const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist + const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege + if pgerr.Code == ERRCODE_INVALID_PASSWORD || + pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION || + pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || + pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { break } } else if cerr, ok := err.(*connectError); ok && config.HasPreferStandbyTargetSessionAttr { - if _, ok := cerr.err.(*preferStanbyNotFoundError); ok { + if _, ok := cerr.err.(*preferStandbyNotFoundError); ok { fallbackConfig = fc } } @@ -173,7 +180,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err if pgerr, ok := err.(*PgError); ok { err = &connectError{config: config, msg: "server error", err: pgerr} } - config.ValidateConnect = ValidateConnectTargetSessionAttrsPrefferStandby + config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby } if err != nil { @@ -191,17 +198,6 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err return pgConn, nil } -func checkPgError(pgerr *PgError) bool { - const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password - const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings - const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist - const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege - return pgerr.Code == ERRCODE_INVALID_PASSWORD || - pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION || - pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || - pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE -} - func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { var configs []*FallbackConfig From 618a12a094636d97ec67a2950adc1dd7dac37241 Mon Sep 17 00:00:00 2001 From: "sergey.bashilov" Date: Fri, 24 Jun 2022 14:02:59 +0300 Subject: [PATCH 1065/1158] remove HasPreferStandbyTargetSessionAttr, rename error to indicate server is not standby --- config.go | 5 +---- config_test.go | 18 ++++++++---------- errors.go | 8 ++++---- pgconn.go | 4 ++-- 4 files changed, 15 insertions(+), 20 deletions(-) diff --git a/config.go b/config.go index dac7b95b..4ca09dda 100644 --- a/config.go +++ b/config.go @@ -50,8 +50,6 @@ type Config struct { // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. ValidateConnect ValidateConnectFunc - HasPreferStandbyTargetSessionAttr bool - // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables // or prepare statements). If this returns an error the connection attempt fails. AfterConnect AfterConnectFunc @@ -371,7 +369,6 @@ func ParseConfig(connString string) (*Config, error) { config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby case "prefer-standby": config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby - config.HasPreferStandbyTargetSessionAttr = true case "any": // do nothing default: @@ -825,7 +822,7 @@ func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn } if string(result.Rows[0][0]) != "t" { - return &preferStandbyNotFoundError{err: errors.New("server is not in hot standby mode")} + return &NotStandbyError{err: errors.New("server is not in hot standby mode")} } return nil diff --git a/config_test.go b/config_test.go index c8d8cee6..6b48ea27 100644 --- a/config_test.go +++ b/config_test.go @@ -612,15 +612,14 @@ func TestParseConfig(t *testing.T) { name: "target_session_attrs prefer-standby", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=prefer-standby", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPreferStandby, - HasPreferStandbyTargetSessionAttr: true, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPreferStandby, }, }, { @@ -785,7 +784,6 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName // Can't test function equality, so just test that they are set or not. assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) - assert.Equalf(t, expected.HasPreferStandbyTargetSessionAttr, actual.HasPreferStandbyTargetSessionAttr, "%s - HasPreferStandbyTargetSessionAttr", testName) if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { if expected.TLSConfig != nil { diff --git a/errors.go b/errors.go index 7ed8889c..9f04476d 100644 --- a/errors.go +++ b/errors.go @@ -220,19 +220,19 @@ func redactURL(u *url.URL) string { return u.String() } -type preferStandbyNotFoundError struct { +type NotStandbyError struct { err error safeToRetry bool } -func (e *preferStandbyNotFoundError) Error() string { +func (e *NotStandbyError) Error() string { return fmt.Sprintf("standby server not found: %s", e.err.Error()) } -func (e *preferStandbyNotFoundError) SafeToRetry() bool { +func (e *NotStandbyError) SafeToRetry() bool { return e.safeToRetry } -func (e *preferStandbyNotFoundError) Unwrap() error { +func (e *NotStandbyError) Unwrap() error { return e.err } diff --git a/pgconn.go b/pgconn.go index 1a1d3505..5e436ffe 100644 --- a/pgconn.go +++ b/pgconn.go @@ -167,8 +167,8 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { break } - } else if cerr, ok := err.(*connectError); ok && config.HasPreferStandbyTargetSessionAttr { - if _, ok := cerr.err.(*preferStandbyNotFoundError); ok { + } else if cerr, ok := err.(*connectError); ok { + if _, ok := cerr.err.(*NotStandbyError); ok { fallbackConfig = fc } } From cdc240d920c29140eb8980a6d650b4a570c19a04 Mon Sep 17 00:00:00 2001 From: "sergey.bashilov" Date: Fri, 24 Jun 2022 14:20:36 +0300 Subject: [PATCH 1066/1158] rename error --- config.go | 2 +- errors.go | 8 ++++---- pgconn.go | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/config.go b/config.go index 4ca09dda..8fd7efbf 100644 --- a/config.go +++ b/config.go @@ -822,7 +822,7 @@ func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn } if string(result.Rows[0][0]) != "t" { - return &NotStandbyError{err: errors.New("server is not in hot standby mode")} + return &NotPreferredError{err: errors.New("server is not in hot standby mode")} } return nil diff --git a/errors.go b/errors.go index 9f04476d..66d35584 100644 --- a/errors.go +++ b/errors.go @@ -220,19 +220,19 @@ func redactURL(u *url.URL) string { return u.String() } -type NotStandbyError struct { +type NotPreferredError struct { err error safeToRetry bool } -func (e *NotStandbyError) Error() string { +func (e *NotPreferredError) Error() string { return fmt.Sprintf("standby server not found: %s", e.err.Error()) } -func (e *NotStandbyError) SafeToRetry() bool { +func (e *NotPreferredError) SafeToRetry() bool { return e.safeToRetry } -func (e *NotStandbyError) Unwrap() error { +func (e *NotPreferredError) Unwrap() error { return e.err } diff --git a/pgconn.go b/pgconn.go index 5e436ffe..6093d17b 100644 --- a/pgconn.go +++ b/pgconn.go @@ -168,7 +168,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err break } } else if cerr, ok := err.(*connectError); ok { - if _, ok := cerr.err.(*NotStandbyError); ok { + if _, ok := cerr.err.(*NotPreferredError); ok { fallbackConfig = fc } } From a18df2374a85fcdcf8ccf6ebc2cec5dcc1156c61 Mon Sep 17 00:00:00 2001 From: "sergey.bashilov" Date: Fri, 1 Jul 2022 17:50:25 +0300 Subject: [PATCH 1067/1158] add ignore not preferred err flag in connect func --- pgconn.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pgconn.go b/pgconn.go index 6093d17b..430f4367 100644 --- a/pgconn.go +++ b/pgconn.go @@ -151,7 +151,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err foundBestServer := false var fallbackConfig *FallbackConfig for _, fc := range fallbackConfigs { - pgConn, err = connect(ctx, config, fc) + pgConn, err = connect(ctx, config, fc, false) if err == nil { foundBestServer = true break @@ -175,12 +175,10 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err } if !foundBestServer && fallbackConfig != nil { - config.ValidateConnect = nil - pgConn, err = connect(ctx, config, fallbackConfig) + pgConn, err = connect(ctx, config, fallbackConfig, true) if pgerr, ok := err.(*PgError); ok { err = &connectError{config: config, msg: "server error", err: pgerr} } - config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby } if err != nil { @@ -243,7 +241,8 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba return configs, nil } -func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { +func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig, + ignoreNotPreferredErr bool) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config pgConn.wbuf = make([]byte, 0, wbufLen) @@ -356,6 +355,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err := config.ValidateConnect(ctx, pgConn) if err != nil { + if _, ok := err.(*NotPreferredError); ignoreNotPreferredErr && ok { + return pgConn, nil + } pgConn.conn.Close() return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} } From ae2881a23c66209ca3525000547fb4debbe8baf4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Jul 2022 21:48:16 -0500 Subject: [PATCH 1068/1158] Add pipeline mode to pgconn --- CHANGELOG.md | 2 + pgconn/doc.go | 5 + pgconn/pgconn.go | 272 +++++++++++++++++++++++++++- pgconn/pgconn_test.go | 412 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 685 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 51582553..81b6cb26 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ pgconn now uses non-blocking IO. This is a significant internal restructuring, b `CheckConn()` checks a connection's liveness by doing a non-blocking read. This can be used to detect database restarts or network interruptions without executing a query or a ping. +pgconn now supports pipeline mode. + ## pgtype The `pgtype` package has been significantly changed. diff --git a/pgconn/doc.go b/pgconn/doc.go index cde58cd8..e3242cf4 100644 --- a/pgconn/doc.go +++ b/pgconn/doc.go @@ -18,6 +18,11 @@ Executing Multiple Queries in a Single Round Trip Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query result. The ReadAll method reads all query results into memory. +Pipeline Mode + +Pipeline mode allows sending queries without having read the results of previously sent queries. It allows +control of exactly how many and when network round trips occur. + Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 306b2e16..b386a786 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -81,6 +81,7 @@ type PgConn struct { // Reusable / preallocated resources resultReader ResultReader multiResultReader MultiResultReader + pipeline Pipeline contextWatcher *ctxwatch.ContextWatcher cleanupDone chan struct{} @@ -1242,8 +1243,9 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { - pgConn *PgConn - ctx context.Context + pgConn *PgConn + ctx context.Context + pipeline *Pipeline rr *ResultReader @@ -1276,9 +1278,13 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - mrr.pgConn.contextWatcher.Unwatch() mrr.closed = true - mrr.pgConn.unlock() + if mrr.pipeline != nil { + mrr.pipeline.expectedReadyForQueryCount-- + } else { + mrr.pgConn.contextWatcher.Unwatch() + mrr.pgConn.unlock() + } case *pgproto3.ErrorResponse: mrr.err = ErrorResponseToPgError(msg) } @@ -1341,6 +1347,7 @@ func (mrr *MultiResultReader) Close() error { type ResultReader struct { pgConn *PgConn multiResultReader *MultiResultReader + pipeline *Pipeline ctx context.Context fieldDescriptions []pgproto3.FieldDescription @@ -1429,7 +1436,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { } } - if rr.multiResultReader == nil { + if rr.multiResultReader == nil && rr.pipeline == nil { for { msg, err := rr.receiveMessage() if err != nil { @@ -1539,7 +1546,8 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor } // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a -// transaction is already in progress or SQL contains transaction control statements. +// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing +// multiple queries in a single round trip than using pipeline mode. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ @@ -1676,3 +1684,255 @@ func Construct(hc *HijackedConn) (*PgConn, error) { return pgConn, nil } + +// Pipeline represents a connection in pipeline mode. +// +// SendPrepare, SendQueryParam, and SendQueryPrepared queue requests to the server. These requests are not written until +// pipeline is flushed by Flush or Sync. Sync must be called after the last request is queued. Requests between +// synchronization points are implicitly transactional unless explicit transaction control statements have been issued. +// +// The context the pipeline was started with is in effect for the entire life of the Pipeline. +// +// For a deeper understanding of pipeline mode see the PostgreSQL documentation for the extended query protocol +// (https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY) and the libpq pipeline mode +// (https://www.postgresql.org/docs/current/libpq-pipeline-mode.html). +type Pipeline struct { + conn *PgConn + ctx context.Context + + expectedReadyForQueryCount int + pendingSync bool + + err error + closed bool +} + +// PipelineSync is returned by GetResults when a ReadyForQuery message is received. +type PipelineSync struct{} + +// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent +// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection +// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except +// CancelRequest and Close. ctx is in effect for entire life of the *Pipeline. +// +// Prefer ExecBatch when only sending one group of queries at once. +func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline { + if err := pgConn.lock(); err != nil { + return &Pipeline{ + closed: true, + err: err, + } + } + + pgConn.pipeline = Pipeline{ + conn: pgConn, + ctx: ctx, + } + pipeline := &pgConn.pipeline + + if ctx != context.Background() { + select { + case <-ctx.Done(): + pipeline.closed = true + pipeline.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return pipeline + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + return pipeline +} + +// SendPrepare is the pipeline version of *PgConn.Prepare. +func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) { + if p.closed { + return + } + p.pendingSync = true + + p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) + p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) +} + +// SendQueryParams is the pipeline version of *PgConn.QueryParams. +func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { + if p.closed { + return + } + p.pendingSync = true + + p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) + p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + p.conn.frontend.SendExecute(&pgproto3.Execute{}) +} + +// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared. +func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + if p.closed { + return + } + p.pendingSync = true + + p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + p.conn.frontend.SendExecute(&pgproto3.Execute{}) +} + +// Flush flushes the queued requests without establishing a synchronization point. +func (p *Pipeline) Flush() error { + if p.closed { + if p.err != nil { + return p.err + } + return errors.New("pipeline closed") + } + + err := p.conn.frontend.Flush() + if err != nil { + err = preferContextOverNetTimeoutError(p.ctx, err) + + p.conn.asyncClose() + + p.conn.contextWatcher.Unwatch() + p.conn.unlock() + p.closed = true + p.err = err + return err + } + + return nil +} + +// Sync establishes a synchronization point and flushes the queued requests. +func (p *Pipeline) Sync() error { + p.conn.frontend.SendSync(&pgproto3.Sync{}) + err := p.Flush() + if err != nil { + return err + } + + p.pendingSync = false + p.expectedReadyForQueryCount++ + + return nil +} + +// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or +// *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no +// results are available, results and err will both be nil. +func (p *Pipeline) GetResults() (results any, err error) { + if p.expectedReadyForQueryCount == 0 { + return nil, nil + } + + for { + msg, err := p.conn.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + p.conn.resultReader = ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: msg.Fields, + } + return &p.conn.resultReader, nil + case *pgproto3.CommandComplete: + p.conn.resultReader = ResultReader{ + commandTag: p.conn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + return &p.conn.resultReader, nil + case *pgproto3.ParseComplete: + peekedMsg, err := p.conn.peekMessage() + if err != nil { + return nil, err + } + if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok { + return p.getResultsPrepare() + } + case *pgproto3.ReadyForQuery: + p.expectedReadyForQueryCount-- + return &PipelineSync{}, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + return nil, pgErr + } + + } + +} + +func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { + psd := &StatementDescription{} + + for { + msg, err := p.conn.receiveMessage() + if err != nil { + p.conn.asyncClose() + return nil, preferContextOverNetTimeoutError(p.ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) + case *pgproto3.RowDescription: + psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) + copy(psd.Fields, msg.Fields) + return psd, nil + + // These should never happen here. But don't take chances that could lead to a deadlock. + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + return nil, pgErr + case *pgproto3.CommandComplete: + p.conn.asyncClose() + return nil, errors.New("BUG: received CommandComplete while handling Describe") + case *pgproto3.ReadyForQuery: + p.conn.asyncClose() + return nil, errors.New("BUG: received ReadyForQuery while handling Describe") + } + } +} + +// Close closes the pipeline and returns the connection to normal mode. +func (p *Pipeline) Close() error { + if p.closed { + return p.err + } + p.closed = true + + if p.pendingSync { + p.conn.asyncClose() + p.err = errors.New("pipeline has unsynced requests") + p.conn.contextWatcher.Unwatch() + p.conn.unlock() + + return p.err + } + + for p.expectedReadyForQueryCount > 0 { + _, err := p.GetResults() + if err != nil { + var pgErr *PgError + if !errors.As(err, &pgErr) { + p.conn.asyncClose() + p.err = err + break + } + } + } + + p.conn.contextWatcher.Unwatch() + p.conn.unlock() + + return p.err +} diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index b47f17d6..c72ed6d6 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -20,6 +20,7 @@ import ( "github.com/jackc/pgx/v5/internal/pgmock" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -2094,6 +2095,417 @@ func TestConnCheckConn(t *testing.T) { require.Error(t, err) } +func TestPipelinePrepare(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendPrepare("selectInt", "select $1::int as a", nil) + pipeline.SendPrepare("selectText", "select $1::text as b", nil) + pipeline.SendPrepare("selectNoParams", "select 42 as c", nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "a") + require.Equal(t, []uint32{pgtype.Int4OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "b") + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "c") + require.Equal(t, []uint32{}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelinePrepareError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendPrepare("selectInt", "select $1::int as a", nil) + pipeline.SendPrepare("selectError", "bad", nil) + pipeline.SendPrepare("selectText", "select $1::text as b", nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "a") + require.Equal(t, []uint32{pgtype.Int4OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + require.Nil(t, results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineQuery(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "2", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "3", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "4", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelinePrepareQuery(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendPrepare("ps", "select $1::text as msg", nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "msg") + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "hello", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "goodbye", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 6`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "2", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "3", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + var pgErr *pgconn.PgError + require.ErrorAs(t, readResult.Err, &pgErr) + require.Equal(t, "22012", pgErr.Code) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "6", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineCloseReadsUnreadResults(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + err = pipeline.Close() + require.EqualError(t, err, "pipeline has unsynced requests") +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { From a97ba0c34a638cdd470d03c2a459ed0c7acacd09 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Jul 2022 21:49:16 -0500 Subject: [PATCH 1069/1158] Remove ReceiveResults Pipeline mode should be used instead. --- CHANGELOG.md | 2 ++ pgconn/pgconn.go | 33 --------------------------------- 2 files changed, 2 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 81b6cb26..a51ef9bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ pgconn now uses non-blocking IO. This is a significant internal restructuring, b pgconn now supports pipeline mode. +`*PgConn.ReceiveResults` removed. Use pipeline mode instead. + ## pgtype The `pgtype` package has been significantly changed. diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index b386a786..6f6f0486 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -917,39 +917,6 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { return multiResult } -// ReceiveResults reads the result that might be returned by Postgres after a SendBytes -// (e.a. after sending a CopyDone in a copy-both situation). -// -// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. -// See https://www.postgresql.org/docs/current/protocol.html. -func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader { - if err := pgConn.lock(); err != nil { - return &MultiResultReader{ - closed: true, - err: err, - } - } - - pgConn.multiResultReader = MultiResultReader{ - pgConn: pgConn, - ctx: ctx, - } - multiResult := &pgConn.multiResultReader - if ctx != context.Background() { - select { - case <-ctx.Done(): - multiResult.closed = true - multiResult.err = newContextAlreadyDoneError(ctx) - pgConn.unlock() - return multiResult - default: - } - pgConn.contextWatcher.Watch(ctx) - } - - return multiResult -} - // ExecParams executes a command via the PostgreSQL extended query protocol. // // sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, From f635b43a6badee6033bd4279efe0059d3cd579d6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Jul 2022 22:00:42 -0500 Subject: [PATCH 1070/1158] Use bigint in tests for CockroachDB compatibility CRDB automatically changes int4 to int8. --- pgconn/pgconn_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index c72ed6d6..87837f8b 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -2103,7 +2103,7 @@ func TestPipelinePrepare(t *testing.T) { defer closeConn(t, pgConn) pipeline := pgConn.StartPipeline(context.Background()) - pipeline.SendPrepare("selectInt", "select $1::int as a", nil) + pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil) pipeline.SendPrepare("selectText", "select $1::text as b", nil) pipeline.SendPrepare("selectNoParams", "select 42 as c", nil) err = pipeline.Sync() @@ -2115,7 +2115,7 @@ func TestPipelinePrepare(t *testing.T) { require.Truef(t, ok, "expected StatementDescription, got: %#v", results) require.Len(t, sd.Fields, 1) require.Equal(t, string(sd.Fields[0].Name), "a") - require.Equal(t, []uint32{pgtype.Int4OID}, sd.ParamOIDs) + require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) results, err = pipeline.GetResults() require.NoError(t, err) @@ -2156,7 +2156,7 @@ func TestPipelinePrepareError(t *testing.T) { defer closeConn(t, pgConn) pipeline := pgConn.StartPipeline(context.Background()) - pipeline.SendPrepare("selectInt", "select $1::int as a", nil) + pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil) pipeline.SendPrepare("selectError", "bad", nil) pipeline.SendPrepare("selectText", "select $1::text as b", nil) err = pipeline.Sync() @@ -2168,7 +2168,7 @@ func TestPipelinePrepareError(t *testing.T) { require.Truef(t, ok, "expected StatementDescription, got: %#v", results) require.Len(t, sd.Fields, 1) require.Equal(t, string(sd.Fields[0].Name), "a") - require.Equal(t, []uint32{pgtype.Int4OID}, sd.ParamOIDs) + require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) results, err = pipeline.GetResults() var pgErr *pgconn.PgError From f7433cc5f29f36275c75ff6797586646b36a20c1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 4 Jul 2022 06:20:15 -0500 Subject: [PATCH 1071/1158] Fix typo --- pgconn/pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 6f6f0486..a49e569a 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1654,7 +1654,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) { // Pipeline represents a connection in pipeline mode. // -// SendPrepare, SendQueryParam, and SendQueryPrepared queue requests to the server. These requests are not written until +// SendPrepare, SendQueryParams, and SendQueryPrepared queue requests to the server. These requests are not written until // pipeline is flushed by Flush or Sync. Sync must be called after the last request is queued. Requests between // synchronization points are implicitly transactional unless explicit transaction control statements have been issued. // From 1168b375e44eb5ad5396dbba03cb1782649f0fae Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 4 Jul 2022 08:36:05 -0500 Subject: [PATCH 1072/1158] Expose pgx functionality for manual integration with pgconn This is primarily useful for using pipeline mode. --- batch.go | 6 +- conn.go | 194 ++++++++------------------------------ extended_query_builder.go | 144 +++++++++++++++++++++++----- pipeline_test.go | 79 ++++++++++++++++ rows.go | 73 ++++++++------ tx.go | 8 +- 6 files changed, 289 insertions(+), 215 deletions(-) create mode 100644 pipeline_test.go diff --git a/batch.go b/batch.go index 98f216dd..6fd61295 100644 --- a/batch.go +++ b/batch.go @@ -116,12 +116,12 @@ func (br *batchResults) Query() (Rows, error) { } if br.err != nil { - return &connRows{err: br.err, closed: true}, br.err + return &baseRows{err: br.err, closed: true}, br.err } if br.closed { alreadyClosedErr := fmt.Errorf("batch already closed") - return &connRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr + return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr } rows := br.conn.getRows(br.ctx, query, arguments) @@ -182,7 +182,7 @@ func (br *batchResults) QueryFunc(scans []any, f func(QueryFuncRow) error) (pgco // QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. func (br *batchResults) QueryRow() Row { rows, _ := br.Query() - return (*connRow)(rows.(*connRows)) + return (*connRow)(rows.(*baseRows)) } diff --git a/conn.go b/conn.go index ec029ace..ba2ba578 100644 --- a/conn.go +++ b/conn.go @@ -75,7 +75,7 @@ type Conn struct { typeMap *pgtype.Map wbuf []byte - eqb extendedQueryBuilder + eqb ExtendedQueryBuilder } // Identifier a PostgreSQL identifier or name. Identifiers can be composed of @@ -485,49 +485,25 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []a return commandTag, err } -func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, args []any) error { - if len(sd.ParamOIDs) != len(args) { - return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)) - } - - c.eqb.Reset() - - anynil.NormalizeSlice(args) - - for i := range args { - err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) - if err != nil { - err = fmt.Errorf("failed to encode args[%d]: %v", i, err) - return err - } - } - - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) - } - - return nil -} - func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) { - err := c.execParamsAndPreparedPrefix(sd, arguments) + err := c.eqb.Build(c.typeMap, sd, arguments) if err != nil { return pgconn.CommandTag{}, err } - result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read() - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return result.CommandTag, result.Err } func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) { - err := c.execParamsAndPreparedPrefix(sd, arguments) + err := c.eqb.Build(c.typeMap, sd, arguments) if err != nil { return pgconn.CommandTag{}, err } - result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return result.CommandTag, result.Err } @@ -540,79 +516,18 @@ func (e *unknownArgumentTypeQueryExecModeExecError) Error() string { } func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) { - c.eqb.Reset() - - anynil.NormalizeSlice(args) - err := c.appendParamsForQueryExecModeExec(args) + err := c.eqb.Build(c.typeMap, nil, args) if err != nil { return pgconn.CommandTag{}, err } - result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats).Read() - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + result := c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return result.CommandTag, result.Err } -// appendParamsForQueryExecModeExec appends the args to c.eqb. -// -// Parameters must be encoded in the text format because of differences in type conversion between timestamps and -// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the -// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both -// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL -// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date. -// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion -// before converting it to date. This means that dates can be shifted by one day. In text format without that double -// type conversion it takes the date directly and ignores time zone (i.e. it works). -// -// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is -// no way to safely use binary or to specify the parameter OIDs. -func (c *Conn) appendParamsForQueryExecModeExec(args []any) error { - for _, arg := range args { - if arg == nil { - err := c.eqb.AppendParamFormat(c.typeMap, 0, TextFormatCode, arg) - if err != nil { - return err - } - } else { - dt, ok := c.TypeMap().TypeForValue(arg) - if !ok { - var tv pgtype.TextValuer - if tv, ok = arg.(pgtype.TextValuer); ok { - t, err := tv.TextValue() - if err != nil { - return err - } - - dt, ok = c.TypeMap().TypeForOID(pgtype.TextOID) - if ok { - arg = t - } - } - } - if !ok { - var str fmt.Stringer - if str, ok = arg.(fmt.Stringer); ok { - dt, ok = c.TypeMap().TypeForOID(pgtype.TextOID) - if ok { - arg = str.String() - } - } - } - if !ok { - return &unknownArgumentTypeQueryExecModeExecError{arg: arg} - } - err := c.eqb.AppendParamFormat(c.typeMap, dt.OID, TextFormatCode, arg) - if err != nil { - return err - } - } - } - - return nil -} - -func (c *Conn) getRows(ctx context.Context, sql string, args []any) *connRows { - r := &connRows{} +func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows { + r := &baseRows{} r.ctx = ctx r.logger = c @@ -735,7 +650,7 @@ optionLoop: sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args) } - c.eqb.Reset() + c.eqb.reset() anynil.NormalizeSlice(args) rows := c.getRows(ctx, sql, args) @@ -782,13 +697,10 @@ optionLoop: rows.sql = sd.SQL - for i := range args { - err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) - if err != nil { - err = fmt.Errorf("failed to encode args[%d]: %v", i, err) - rows.fatal(err) - return rows, rows.err - } + err = c.eqb.Build(c.typeMap, sd, args) + if err != nil { + rows.fatal(err) + return rows, rows.err } if resultFormatsByOID != nil { @@ -799,26 +711,22 @@ optionLoop: } if resultFormats == nil { - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) - } - - resultFormats = c.eqb.resultFormats + resultFormats = c.eqb.ResultFormats } if !explicitPreparedStatement && mode == QueryExecModeCacheDescribe { - rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats) + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, resultFormats) } else { - rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) + rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats) } } else if mode == QueryExecModeExec { - err := c.appendParamsForQueryExecModeExec(args) + err := c.eqb.Build(c.typeMap, nil, args) if err != nil { rows.fatal(err) return rows, rows.err } - rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats) + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) } else if mode == QueryExecModeSimpleProtocol { sql, err = c.sanitizeForSimpleQuery(sql, args...) if err != nil { @@ -843,7 +751,7 @@ optionLoop: return rows, rows.err } - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return rows, rows.err } @@ -853,7 +761,7 @@ optionLoop: // error with ErrNoRows if no rows are returned. func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { rows, _ := c.Query(ctx, sql, args...) - return (*connRow)(rows.(*connRows)) + return (*connRow)(rows.(*baseRows)) } // QueryFuncRow is the argument to the QueryFunc callback function. @@ -954,34 +862,23 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { if mode == QueryExecModeExec { for _, bi := range b.items { - c.eqb.Reset() + c.eqb.reset() anynil.NormalizeSlice(bi.arguments) sd := c.preparedStatements[bi.query] if sd != nil { - if len(sd.ParamOIDs) != len(bi.arguments) { - return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} - } - - for i := range bi.arguments { - err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) - if err != nil { - err = fmt.Errorf("failed to encode args[%d]: %v", i, err) - return &batchResults{ctx: ctx, conn: c, err: err} - } - } - - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) - } - - batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) - } else { - err := c.appendParamsForQueryExecModeExec(bi.arguments) + err := c.eqb.Build(c.typeMap, sd, bi.arguments) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } - batch.ExecParams(bi.query, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats) + + batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else { + err := c.eqb.Build(c.typeMap, nil, bi.arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) } } } else { @@ -1014,7 +911,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } for _, bi := range b.items { - c.eqb.Reset() + c.eqb.reset() sd := c.preparedStatements[bi.query] if sd == nil { @@ -1029,29 +926,20 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} } - anynil.NormalizeSlice(bi.arguments) - - for i := range bi.arguments { - err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) - if err != nil { - err = fmt.Errorf("failed to encode args[%d]: %v", i, err) - return &batchResults{ctx: ctx, conn: c, err: err} - } - } - - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) + err := c.eqb.Build(c.typeMap, sd, bi.arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} } if sd.Name == "" { - batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats) + batch.ExecParams(bi.query, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats) } else { - batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) + batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) } } } - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. mrr := c.pgConn.ExecBatch(ctx, batch) diff --git a/extended_query_builder.go b/extended_query_builder.go index e69d0b36..1c47063c 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -1,62 +1,98 @@ package pgx import ( + "fmt" + "github.com/jackc/pgx/v5/internal/anynil" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" ) -type extendedQueryBuilder struct { - paramValues [][]byte +// ExtendedQueryBuilder is used to choose the parameter formats, to format the parameters and to choose the result +// formats for an extended query. +type ExtendedQueryBuilder struct { + ParamValues [][]byte paramValueBytes []byte - paramFormats []int16 - resultFormats []int16 + ParamFormats []int16 + ResultFormats []int16 } -func (eqb *extendedQueryBuilder) AppendParam(m *pgtype.Map, oid uint32, arg any) error { - f := eqb.chooseParameterFormatCode(m, oid, arg) - return eqb.AppendParamFormat(m, oid, f, arg) +// Build sets ParamValues, ParamFormats, and ResultFormats for use with *PgConn.ExecParams or *PgConn.ExecPrepared. If +// sd is nil then QueryExecModeExec behavior will be used. +func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error { + eqb.reset() + + anynil.NormalizeSlice(args) + + if sd == nil { + return eqb.appendParamsForQueryExecModeExec(m, args) + } + + if len(sd.ParamOIDs) != len(args) { + return fmt.Errorf("mismatched param and argument count") + } + + for i := range args { + err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i]) + if err != nil { + err = fmt.Errorf("failed to encode args[%d]: %v", i, err) + return err + } + } + + for i := range sd.Fields { + eqb.appendResultFormat(m.FormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + return nil } -func (eqb *extendedQueryBuilder) AppendParamFormat(m *pgtype.Map, oid uint32, format int16, arg any) error { - eqb.paramFormats = append(eqb.paramFormats, format) +// appendParam appends a parameter to the query. format may be -1 to automatically choose the format. If arg is nil it +// must be an untyped nil. +func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error { + if format == -1 { + format = eqb.chooseParameterFormatCode(m, oid, arg) + } + eqb.ParamFormats = append(eqb.ParamFormats, format) v, err := eqb.encodeExtendedParamValue(m, oid, format, arg) if err != nil { return err } - eqb.paramValues = append(eqb.paramValues, v) + eqb.ParamValues = append(eqb.ParamValues, v) return nil } -func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) { - eqb.resultFormats = append(eqb.resultFormats, f) +// appendResultFormat appends a result format to the query. +func (eqb *ExtendedQueryBuilder) appendResultFormat(format int16) { + eqb.ResultFormats = append(eqb.ResultFormats, format) } -// Reset readies eqb to build another query. -func (eqb *extendedQueryBuilder) Reset() { - eqb.paramValues = eqb.paramValues[0:0] +// reset readies eqb to build another query. +func (eqb *ExtendedQueryBuilder) reset() { + eqb.ParamValues = eqb.ParamValues[0:0] eqb.paramValueBytes = eqb.paramValueBytes[0:0] - eqb.paramFormats = eqb.paramFormats[0:0] - eqb.resultFormats = eqb.resultFormats[0:0] + eqb.ParamFormats = eqb.ParamFormats[0:0] + eqb.ResultFormats = eqb.ResultFormats[0:0] - if cap(eqb.paramValues) > 64 { - eqb.paramValues = make([][]byte, 0, 64) + if cap(eqb.ParamValues) > 64 { + eqb.ParamValues = make([][]byte, 0, 64) } if cap(eqb.paramValueBytes) > 256 { eqb.paramValueBytes = make([]byte, 0, 256) } - if cap(eqb.paramFormats) > 64 { - eqb.paramFormats = make([]int16, 0, 64) + if cap(eqb.ParamFormats) > 64 { + eqb.ParamFormats = make([]int16, 0, 64) } - if cap(eqb.resultFormats) > 64 { - eqb.resultFormats = make([]int16, 0, 64) + if cap(eqb.ResultFormats) > 64 { + eqb.ResultFormats = make([]int16, 0, 64) } } -func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) { +func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) { if anynil.Is(arg) { return nil, nil } @@ -81,7 +117,7 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uin // chooseParameterFormatCode determines the correct format code for an // argument to a prepared statement. It defaults to TextFormatCode if no // determination can be made. -func (eqb *extendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 { +func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 { switch arg.(type) { case string, *string: return TextFormatCode @@ -89,3 +125,61 @@ func (eqb *extendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid ui return m.FormatCodeForOID(oid) } + +// appendParamsForQueryExecModeExec appends the args to eqb. +// +// Parameters must be encoded in the text format because of differences in type conversion between timestamps and +// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the +// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both +// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL +// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date. +// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion +// before converting it to date. This means that dates can be shifted by one day. In text format without that double +// type conversion it takes the date directly and ignores time zone (i.e. it works). +// +// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is +// no way to safely use binary or to specify the parameter OIDs. +func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error { + for _, arg := range args { + if arg == nil { + err := eqb.appendParam(m, 0, TextFormatCode, arg) + if err != nil { + return err + } + } else { + dt, ok := m.TypeForValue(arg) + if !ok { + var tv pgtype.TextValuer + if tv, ok = arg.(pgtype.TextValuer); ok { + t, err := tv.TextValue() + if err != nil { + return err + } + + dt, ok = m.TypeForOID(pgtype.TextOID) + if ok { + arg = t + } + } + } + if !ok { + var str fmt.Stringer + if str, ok = arg.(fmt.Stringer); ok { + dt, ok = m.TypeForOID(pgtype.TextOID) + if ok { + arg = str.String() + } + } + } + if !ok { + return &unknownArgumentTypeQueryExecModeExecError{arg: arg} + } + err := eqb.appendParam(m, dt.OID, TextFormatCode, arg) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/pipeline_test.go b/pipeline_test.go new file mode 100644 index 00000000..b8590bf9 --- /dev/null +++ b/pipeline_test.go @@ -0,0 +1,79 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/require" +) + +func TestPipelineWithoutPreparedOrDescribedStatements(t *testing.T) { + t.Parallel() + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pipeline := conn.PgConn().StartPipeline(ctx) + + eqb := pgx.ExtendedQueryBuilder{} + + err := eqb.Build(conn.TypeMap(), nil, []any{1, 2}) + require.NoError(t, err) + pipeline.SendQueryParams(`select $1::bigint + $2::bigint`, eqb.ParamValues, nil, eqb.ParamFormats, eqb.ResultFormats) + + err = eqb.Build(conn.TypeMap(), nil, []any{3, 4, 5}) + require.NoError(t, err) + pipeline.SendQueryParams(`select $1::bigint + $2::bigint + $3::bigint`, eqb.ParamValues, nil, eqb.ParamFormats, eqb.ResultFormats) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.True(t, ok) + rows := pgx.RowsFromResultReader(conn.TypeMap(), rr) + + rowCount := 0 + var n int64 + for rows.Next() { + err = rows.Scan(&n) + require.NoError(t, err) + rowCount++ + } + require.NoError(t, rows.Err()) + require.Equal(t, 1, rowCount) + require.Equal(t, "SELECT 1", rows.CommandTag().String()) + require.EqualValues(t, 3, n) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.True(t, ok) + rows = pgx.RowsFromResultReader(conn.TypeMap(), rr) + + rowCount = 0 + n = 0 + for rows.Next() { + err = rows.Scan(&n) + require.NoError(t, err) + rowCount++ + } + require.NoError(t, rows.Err()) + require.Equal(t, 1, rowCount) + require.Equal(t, "SELECT 1", rows.CommandTag().String()) + require.EqualValues(t, 12, n) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.True(t, ok) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + }) +} diff --git a/rows.go b/rows.go index 4f9c533d..d9c0ba47 100644 --- a/rows.go +++ b/rows.go @@ -76,10 +76,10 @@ type RowScanner interface { } // connRow implements the Row interface for Conn.QueryRow. -type connRow connRows +type connRow baseRows func (r *connRow) Scan(dest ...any) (err error) { - rows := (*connRows)(r) + rows := (*baseRows)(r) if rows.Err() != nil { return rows.Err() @@ -109,33 +109,36 @@ type rowLog interface { log(ctx context.Context, lvl LogLevel, msg string, data map[string]any) } -// connRows implements the Rows interface for Conn.Query. -type connRows struct { - ctx context.Context - logger rowLog - typeMap *pgtype.Map - values [][]byte - rowCount int - err error - commandTag pgconn.CommandTag - startTime time.Time - sql string - args []any - closed bool - conn *Conn +// baseRows implements the Rows interface for Conn.Query. +type baseRows struct { + typeMap *pgtype.Map + resultReader *pgconn.ResultReader - resultReader *pgconn.ResultReader - multiResultReader *pgconn.MultiResultReader + values [][]byte + + commandTag pgconn.CommandTag + err error + closed bool scanPlans []pgtype.ScanPlan scanTypes []reflect.Type + + conn *Conn + multiResultReader *pgconn.MultiResultReader + + logger rowLog + ctx context.Context + startTime time.Time + sql string + args []any + rowCount int } -func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription { +func (rows *baseRows) FieldDescriptions() []pgproto3.FieldDescription { return rows.resultReader.FieldDescriptions() } -func (rows *connRows) Close() { +func (rows *baseRows) Close() { if rows.closed { return } @@ -167,24 +170,25 @@ func (rows *connRows) Close() { if rows.logger.shouldLog(LogLevelError) { rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]any{"err": rows.err, "sql": rows.sql, "args": logQueryArgs(rows.args)}) } - if rows.err != nil && rows.conn.statementCache != nil { - rows.conn.statementCache.StatementErrored(rows.sql, rows.err) - } } } + + if rows.err != nil && rows.conn != nil && rows.conn.statementCache != nil { + rows.conn.statementCache.StatementErrored(rows.sql, rows.err) + } } -func (rows *connRows) CommandTag() pgconn.CommandTag { +func (rows *baseRows) CommandTag() pgconn.CommandTag { return rows.commandTag } -func (rows *connRows) Err() error { +func (rows *baseRows) Err() error { return rows.err } // fatal signals an error occurred after the query was sent to the server. It // closes the rows automatically. -func (rows *connRows) fatal(err error) { +func (rows *baseRows) fatal(err error) { if rows.err != nil { return } @@ -193,7 +197,7 @@ func (rows *connRows) fatal(err error) { rows.Close() } -func (rows *connRows) Next() bool { +func (rows *baseRows) Next() bool { if rows.closed { return false } @@ -208,7 +212,7 @@ func (rows *connRows) Next() bool { } } -func (rows *connRows) Scan(dest ...any) error { +func (rows *baseRows) Scan(dest ...any) error { m := rows.typeMap fieldDescriptions := rows.FieldDescriptions() values := rows.values @@ -261,7 +265,7 @@ func (rows *connRows) Scan(dest ...any) error { return nil } -func (rows *connRows) Values() ([]any, error) { +func (rows *baseRows) Values() ([]any, error) { if rows.closed { return nil, errors.New("rows is closed") } @@ -304,7 +308,7 @@ func (rows *connRows) Values() ([]any, error) { return values, rows.Err() } -func (rows *connRows) RawValues() [][]byte { +func (rows *baseRows) RawValues() [][]byte { return rows.values } @@ -348,3 +352,12 @@ func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgproto3.FieldDescription, return nil } + +// RowsFromResultReader returns a Rows that will read from values resultReader and decode with typeMap. It can be used +// to read from the lower level pgconn interface. +func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows { + return &baseRows{ + typeMap: typeMap, + resultReader: resultReader, + } +} diff --git a/tx.go b/tx.go index 7254e3dc..76b1768c 100644 --- a/tx.go +++ b/tx.go @@ -281,7 +281,7 @@ func (tx *dbTx) Query(ctx context.Context, sql string, args ...any) (Rows, error if tx.closed { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed - return &connRows{closed: true, err: err}, err + return &baseRows{closed: true, err: err}, err } return tx.conn.Query(ctx, sql, args...) @@ -290,7 +290,7 @@ func (tx *dbTx) Query(ctx context.Context, sql string, args ...any) (Rows, error // QueryRow delegates to the underlying *Conn func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...any) Row { rows, _ := tx.Query(ctx, sql, args...) - return (*connRow)(rows.(*connRows)) + return (*connRow)(rows.(*baseRows)) } // QueryFunc delegates to the underlying *Conn. @@ -400,7 +400,7 @@ func (sp *dbSimulatedNestedTx) Query(ctx context.Context, sql string, args ...an if sp.closed { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed - return &connRows{closed: true, err: err}, err + return &baseRows{closed: true, err: err}, err } return sp.tx.Query(ctx, sql, args...) @@ -409,7 +409,7 @@ func (sp *dbSimulatedNestedTx) Query(ctx context.Context, sql string, args ...an // QueryRow delegates to the underlying Tx func (sp *dbSimulatedNestedTx) QueryRow(ctx context.Context, sql string, args ...any) Row { rows, _ := sp.Query(ctx, sql, args...) - return (*connRow)(rows.(*connRows)) + return (*connRow)(rows.(*baseRows)) } // QueryFunc delegates to the underlying Tx. From a86f4f3db988a5b4c389a4f5b2869d774057c18f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 7 Jul 2022 19:32:01 -0500 Subject: [PATCH 1073/1158] Add deallocate to pipeline mode --- pgconn/pgconn.go | 15 +++++++++++++++ pgconn/pgconn_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ pgproto3/frontend.go | 10 ++++++++++ 3 files changed, 66 insertions(+) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index a49e569a..7426add9 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1677,6 +1677,9 @@ type Pipeline struct { // PipelineSync is returned by GetResults when a ReadyForQuery message is received. type PipelineSync struct{} +// CloseComplete is returned by GetResults when a CloseComplete message is received. +type CloseComplete struct{} + // StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent // to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection // to normal mode. While in pipeline mode, no methods that communicate with the server may be called except @@ -1723,6 +1726,16 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) { p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) } +// SendDeallocate deallocates a prepared statement. +func (p *Pipeline) SendDeallocate(name string) { + if p.closed { + return + } + p.pendingSync = true + + p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name}) +} + // SendQueryParams is the pipeline version of *PgConn.QueryParams. func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { if p.closed { @@ -1825,6 +1838,8 @@ func (p *Pipeline) GetResults() (results any, err error) { if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok { return p.getResultsPrepare() } + case *pgproto3.CloseComplete: + return &CloseComplete{}, nil case *pgproto3.ReadyForQuery: p.expectedReadyForQueryCount-- return &PipelineSync{}, nil diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 87837f8b..0fba9417 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -2190,6 +2190,47 @@ func TestPipelinePrepareError(t *testing.T) { ensureConnValid(t, pgConn) } +func TestPipelinePrepareAndDeallocate(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil) + pipeline.SendDeallocate("selectInt") + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "a") + require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.CloseComplete) + require.Truef(t, ok, "expected CloseComplete, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + func TestPipelineQuery(t *testing.T) { t.Parallel() diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 321d0bf9..eed8dc4f 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -136,6 +136,16 @@ func (f *Frontend) SendParse(msg *Parse) { } } +// SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendClose(msg *Close) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg) + } +} + // SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until // Flush is called. func (f *Frontend) SendDescribe(msg *Describe) { From 76946fb5a39fb222724cbac50701b678ee6d53f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 7 Jul 2022 20:29:04 -0500 Subject: [PATCH 1074/1158] Replace QueryFunc with ForEachScannedRow --- CHANGELOG.md | 5 + batch.go | 34 --- batch_test.go | 3 +- conn.go | 65 +---- doc.go | 29 +- pgtype/bytea.go | 4 +- pgtype/integration_benchmark_test.go | 351 +++++++++-------------- pgtype/integration_benchmark_test.go.erb | 10 +- pgxpool/batch_results.go | 8 - pgxpool/conn.go | 4 - pgxpool/pool.go | 10 - pgxpool/tx.go | 4 - query_test.go | 99 ------- rows.go | 25 ++ rows_test.go | 100 +++++++ tx.go | 19 -- 16 files changed, 301 insertions(+), 469 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a51ef9bd..a0b95203 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -134,6 +134,11 @@ allows arbitrary rewriting of query SQL and arguments. ## RowScanner Interface The `RowScanner` interface allows a single argument to Rows.Scan to scan the entire row. + +## QueryFunc Replaced + +`QueryFunc` has been replaced by using `ForEachScannedRow`. + ## 3rd Party Logger Integration All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency diff --git a/batch.go b/batch.go index 6fd61295..21830a1f 100644 --- a/batch.go +++ b/batch.go @@ -42,9 +42,6 @@ type BatchResults interface { // QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow. QueryRow() Row - // QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. - QueryFunc(scans []any, f func(QueryFuncRow) error) (pgconn.CommandTag, error) - // Close closes the batch operation. This must be called before the underlying connection can be used again. Any error // that occurred during a batch operation may have made it impossible to resyncronize the connection with the server. // In this case the underlying connection will have been closed. Close is safe to call multiple times. @@ -148,37 +145,6 @@ func (br *batchResults) Query() (Rows, error) { return rows, nil } -// QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. -func (br *batchResults) QueryFunc(scans []any, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { - if br.closed { - return pgconn.CommandTag{}, fmt.Errorf("batch already closed") - } - - rows, err := br.Query() - if err != nil { - return pgconn.CommandTag{}, err - } - defer rows.Close() - - for rows.Next() { - err = rows.Scan(scans...) - if err != nil { - return pgconn.CommandTag{}, err - } - - err = f(rows) - if err != nil { - return pgconn.CommandTag{}, err - } - } - - if err := rows.Err(); err != nil { - return pgconn.CommandTag{}, err - } - - return rows.CommandTag(), nil -} - // QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. func (br *batchResults) QueryRow() Row { rows, _ := br.Query() diff --git a/batch_test.go b/batch_test.go index 96cf61c2..abe9f915 100644 --- a/batch_test.go +++ b/batch_test.go @@ -108,7 +108,8 @@ func TestConnSendBatch(t *testing.T) { } rowCount = 0 - _, err = br.QueryFunc([]any{&id, &description, &amount}, func(pgx.QueryFuncRow) error { + rows, _ = br.Query() + _, err = pgx.ForEachScannedRow(rows, []any{&id, &description, &amount}, func() error { if id != selectFromLedgerExpectedRows[rowCount].id { t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) } diff --git a/conn.go b/conn.go index ba2ba578..d8ab21d7 100644 --- a/conn.go +++ b/conn.go @@ -12,7 +12,6 @@ import ( "github.com/jackc/pgx/v5/internal/sanitize" "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" ) @@ -764,48 +763,6 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { return (*connRow)(rows.(*baseRows)) } -// QueryFuncRow is the argument to the QueryFunc callback function. -// -// QueryFuncRow is an interface instead of a struct to allow tests to mock QueryFunc. However, adding a method to an -// interface is technically a breaking change. Because of this the QueryFuncRow interface is partially excluded from -// semantic version requirements. Methods will not be removed or changed, but new methods may be added. -type QueryFuncRow interface { - FieldDescriptions() []pgproto3.FieldDescription - - // RawValues returns the unparsed bytes of the row values. The returned data is only valid during the current - // function call. - RawValues() [][]byte -} - -// QueryFunc executes sql with args. For each row returned by the query the values will scanned into the elements of -// scans and f will be called. If any row fails to scan or f returns an error the query will be aborted and the error -// will be returned. -func (c *Conn) QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { - rows, err := c.Query(ctx, sql, args...) - if err != nil { - return pgconn.CommandTag{}, err - } - defer rows.Close() - - for rows.Next() { - err = rows.Scan(scans...) - if err != nil { - return pgconn.CommandTag{}, err - } - - err = f(rows) - if err != nil { - return pgconn.CommandTag{}, err - } - } - - if err := rows.Err(); err != nil { - return pgconn.CommandTag{}, err - } - - return rows.CommandTag(), nil -} - // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. @@ -1038,20 +995,20 @@ func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.Com var fields []pgtype.CompositeCodecField var fieldName string var fieldOID uint32 - _, err = c.QueryFunc(ctx, `select attname, atttypid + rows, _ := c.Query(ctx, `select attname, atttypid from pg_attribute where attrelid=$1 order by attnum`, - []any{typrelid}, - []any{&fieldName, &fieldOID}, - func(qfr QueryFuncRow) error { - dt, ok := c.TypeMap().TypeForOID(fieldOID) - if !ok { - return fmt.Errorf("unknown composite type field OID: %v", fieldOID) - } - fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt}) - return nil - }) + typrelid, + ) + _, err = ForEachScannedRow(rows, []any{&fieldName, &fieldOID}, func() error { + dt, ok := c.TypeMap().TypeForOID(fieldOID) + if !ok { + return fmt.Errorf("unknown composite type field OID: %v", fieldOID) + } + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt}) + return nil + }) if err != nil { return nil, err } diff --git a/doc.go b/doc.go index 660cf5a3..2e779dbb 100644 --- a/doc.go +++ b/doc.go @@ -63,6 +63,18 @@ pgx implements Query and Scan in the familiar database/sql style. // No errors found - do something with sum +ForEachScannedRow can be used to execute a callback function for every row. This is often easier than iterating over rows directly. + + var sum, n int32 + rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10) + _, err := pgx.ForEachScannedRow(rows, []any{&n}, func(pgx.QueryFuncRow) error { + sum += n + return nil + }) + if err != nil { + return err + } + pgx also implements QueryRow in the same style as database/sql. var name string @@ -82,23 +94,6 @@ Use Exec to execute a query that does not return a result set. return errors.New("No row found to delete") } -QueryFunc can be used to execute a callback function for every row. This is often easier to use than Query. - - var sum, n int32 - _, err = conn.QueryFunc( - context.Background(), - "select generate_series(1,$1)", - []any{10}, - []any{&n}, - func(pgx.QueryFuncRow) error { - sum += n - return nil - }, - ) - if err != nil { - return err - } - Base Type Mapping pgx maps between all common base types directly between Go and PostgreSQL. In particular: diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 51994005..2e067672 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -17,8 +17,8 @@ type BytesValuer interface { } // DriverBytes is a byte slice that holds a reference to memory owned by the driver. It is only valid from the time it -// is scanned until Rows.Next or Rows.Close is called. It is safe to use in a function passed to QueryFunc. It is never -// safe to use DriverBytes with QueryRow as Row.Scan internally calls Rows.Close before returning. +// is scanned until Rows.Next or Rows.Close is called. It is never safe to use DriverBytes with QueryRow as Row.Scan +// internally calls Rows.Close before returning. type DriverBytes []byte func (b *DriverBytes) ScanBytes(v []byte) error { diff --git a/pgtype/integration_benchmark_test.go b/pgtype/integration_benchmark_test.go index 624c29ea..4ba8b9b5 100644 --- a/pgtype/integration_benchmark_test.go +++ b/pgtype/integration_benchmark_test.go @@ -1,4 +1,3 @@ -// Do not edit. Generated from pgtype/integration_benchmark_test.go.erb package pgtype_test import ( @@ -14,13 +13,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *test b.ResetTimer() var v [1]int16 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -33,13 +31,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *te b.ResetTimer() var v [1]int16 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -52,13 +49,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *tes b.ResetTimer() var v [10]int16 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -71,13 +67,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *t b.ResetTimer() var v [10]int16 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -90,13 +85,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *tes b.ResetTimer() var v [1]int16 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -109,13 +103,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *t b.ResetTimer() var v [1]int16 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -128,13 +121,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b *t b.ResetTimer() var v [10]int16 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -147,13 +139,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b b.ResetTimer() var v [10]int16 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -166,13 +157,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *test b.ResetTimer() var v [1]int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -185,13 +175,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *te b.ResetTimer() var v [1]int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -204,13 +193,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *tes b.ResetTimer() var v [10]int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -223,13 +211,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *t b.ResetTimer() var v [10]int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -242,13 +229,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *tes b.ResetTimer() var v [1]int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -261,13 +247,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *t b.ResetTimer() var v [1]int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -280,13 +265,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b *t b.ResetTimer() var v [10]int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -299,13 +283,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b b.ResetTimer() var v [10]int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -318,13 +301,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *test b.ResetTimer() var v [1]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -337,13 +319,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *te b.ResetTimer() var v [1]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -356,13 +337,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *tes b.ResetTimer() var v [10]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -375,13 +355,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *t b.ResetTimer() var v [10]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -394,13 +373,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *tes b.ResetTimer() var v [1]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -413,13 +391,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *t b.ResetTimer() var v [1]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -432,13 +409,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b *t b.ResetTimer() var v [10]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -451,13 +427,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b b.ResetTimer() var v [10]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -470,13 +445,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *tes b.ResetTimer() var v [1]uint64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -489,13 +463,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *t b.ResetTimer() var v [1]uint64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -508,13 +481,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b *te b.ResetTimer() var v [10]uint64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -527,13 +499,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b * b.ResetTimer() var v [10]uint64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -546,13 +517,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b *te b.ResetTimer() var v [1]uint64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -565,13 +535,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b * b.ResetTimer() var v [1]uint64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -584,13 +553,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b * b.ResetTimer() var v [10]uint64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -603,13 +571,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b b.ResetTimer() var v [10]uint64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -622,13 +589,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns(b b.ResetTimer() var v [1]pgtype.Int4 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -641,13 +607,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns b.ResetTimer() var v [1]pgtype.Int4 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -660,13 +625,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_columns( b.ResetTimer() var v [10]pgtype.Int4 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -679,13 +643,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_column b.ResetTimer() var v [10]pgtype.Int4 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -698,13 +661,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_columns( b.ResetTimer() var v [1]pgtype.Int4 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -717,13 +679,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_column b.ResetTimer() var v [1]pgtype.Int4 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -736,13 +697,12 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_column b.ResetTimer() var v [10]pgtype.Int4 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -755,13 +715,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_colu b.ResetTimer() var v [10]pgtype.Int4 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -774,13 +733,12 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b *t b.ResetTimer() var v [1]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -793,13 +751,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b b.ResetTimer() var v [1]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -812,13 +769,12 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b * b.ResetTimer() var v [10]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -831,13 +787,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b b.ResetTimer() var v [10]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -850,13 +805,12 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b * b.ResetTimer() var v [1]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -869,13 +823,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b b.ResetTimer() var v [1]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -888,13 +841,12 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns(b b.ResetTimer() var v [10]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -907,13 +859,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns b.ResetTimer() var v [10]int64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -926,13 +877,12 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns(b b.ResetTimer() var v [1]float64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -945,13 +895,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns( b.ResetTimer() var v [1]float64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -964,13 +913,12 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns(b b.ResetTimer() var v [10]float64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -983,13 +931,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns b.ResetTimer() var v [10]float64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1002,13 +949,12 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns(b b.ResetTimer() var v [1]float64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1021,13 +967,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns b.ResetTimer() var v [1]float64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1040,13 +985,12 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_100_rows_10_columns b.ResetTimer() var v [10]float64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1059,13 +1003,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_100_rows_10_colum b.ResetTimer() var v [10]float64 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1078,13 +1021,12 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_col b.ResetTimer() var v [1]pgtype.Numeric for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1097,13 +1039,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_c b.ResetTimer() var v [1]pgtype.Numeric for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1116,13 +1057,12 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_co b.ResetTimer() var v [10]pgtype.Numeric for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1135,13 +1075,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_ b.ResetTimer() var v [10]pgtype.Numeric for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1154,13 +1093,12 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_co b.ResetTimer() var v [1]pgtype.Numeric for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1173,13 +1111,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_ b.ResetTimer() var v [1]pgtype.Numeric for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1192,13 +1129,12 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_10_ b.ResetTimer() var v [10]pgtype.Numeric for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1211,13 +1147,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_1 b.ResetTimer() var v [10]pgtype.Numeric for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1230,13 +1165,12 @@ func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testing b.ResetTimer() var v []int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select array_agg(n) from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1249,13 +1183,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testi b.ResetTimer() var v []int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select array_agg(n) from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1268,13 +1201,12 @@ func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testin b.ResetTimer() var v []int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select array_agg(n) from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1287,13 +1219,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *test b.ResetTimer() var v []int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select array_agg(n) from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1306,13 +1237,12 @@ func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testi b.ResetTimer() var v []int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select array_agg(n) from generate_series(1, 1000) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, - []any{&v}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1325,13 +1255,12 @@ func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *tes b.ResetTimer() var v []int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select array_agg(n) from generate_series(1, 1000) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, - []any{&v}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } diff --git a/pgtype/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb index 3459f0cb..144d9dd7 100644 --- a/pgtype/integration_benchmark_test.go.erb +++ b/pgtype/integration_benchmark_test.go.erb @@ -22,13 +22,12 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go b.ResetTimer() var v [<%= columns %>]<%= go_type %> for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, []any{pgx.QueryResultFormats{<%= format_code %>}}, - []any{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -47,13 +46,12 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_Int4Array b.ResetTimer() var v []int32 for i := 0; i < b.N; i++ { - _, err := conn.QueryFunc( + rows, _ := conn.Query( ctx, `select array_agg(n) from generate_series(1, <%= array_size %>) n`, []any{pgx.QueryResultFormats{<%= format_code %>}}, - []any{&v}, - func(pgx.QueryFuncRow) error { return nil }, ) + _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } diff --git a/pgxpool/batch_results.go b/pgxpool/batch_results.go index fcd10b37..5d5c681d 100644 --- a/pgxpool/batch_results.go +++ b/pgxpool/batch_results.go @@ -17,10 +17,6 @@ func (br errBatchResults) Query() (pgx.Rows, error) { return errRows{err: br.err}, br.err } -func (br errBatchResults) QueryFunc(scans []any, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - return pgconn.CommandTag{}, br.err -} - func (br errBatchResults) QueryRow() pgx.Row { return errRow{err: br.err} } @@ -42,10 +38,6 @@ func (br *poolBatchResults) Query() (pgx.Rows, error) { return br.br.Query() } -func (br *poolBatchResults) QueryFunc(scans []any, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - return br.br.QueryFunc(scans, f) -} - func (br *poolBatchResults) QueryRow() pgx.Row { return br.br.QueryRow() } diff --git a/pgxpool/conn.go b/pgxpool/conn.go index 3ab8b375..b8711da9 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -74,10 +74,6 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { return c.Conn().QueryRow(ctx, sql, args...) } -func (c *Conn) QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - return c.Conn().QueryFunc(ctx, sql, args, scans, f) -} - func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { return c.Conn().SendBatch(ctx, b) } diff --git a/pgxpool/pool.go b/pgxpool/pool.go index de4e1066..d73b93fb 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -533,16 +533,6 @@ func (p *Pool) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { return c.getPoolRow(row) } -func (p *Pool) QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - c, err := p.Acquire(ctx) - if err != nil { - return pgconn.CommandTag{}, err - } - defer c.Release() - - return c.QueryFunc(ctx, sql, args, scans, f) -} - func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { c, err := p.Acquire(ctx) if err != nil { diff --git a/pgxpool/tx.go b/pgxpool/tx.go index 79da567c..3ddb742c 100644 --- a/pgxpool/tx.go +++ b/pgxpool/tx.go @@ -81,10 +81,6 @@ func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { return tx.t.QueryRow(ctx, sql, args...) } -func (tx *Tx) QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - return tx.t.QueryFunc(ctx, sql, args, scans, f) -} - func (tx *Tx) Conn() *pgx.Conn { return tx.t.Conn() } diff --git a/query_test.go b/query_test.go index 7fa507b0..0c8b5fab 100644 --- a/query_test.go +++ b/query_test.go @@ -1898,102 +1898,3 @@ func TestQueryWithQueryRewriter(t *testing.T) { require.NoError(t, rows.Err()) }) } - -func TestConnQueryFunc(t *testing.T) { - t.Parallel() - - pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - var actualResults []any - - var a, b int - ct, err := conn.QueryFunc( - context.Background(), - "select n, n * 2 from generate_series(1, $1) n", - []any{3}, - []any{&a, &b}, - func(pgx.QueryFuncRow) error { - actualResults = append(actualResults, []any{a, b}) - return nil - }, - ) - require.NoError(t, err) - - expectedResults := []any{ - []any{1, 2}, - []any{2, 4}, - []any{3, 6}, - } - require.Equal(t, expectedResults, actualResults) - require.EqualValues(t, 3, ct.RowsAffected()) - }) -} - -func TestConnQueryFuncScanError(t *testing.T) { - t.Parallel() - - pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - var actualResults []any - - var a, b int - ct, err := conn.QueryFunc( - context.Background(), - "select 'foo', 'bar' from generate_series(1, $1) n", - []any{3}, - []any{&a, &b}, - func(pgx.QueryFuncRow) error { - actualResults = append(actualResults, []any{a, b}) - return nil - }, - ) - require.EqualError(t, err, "can't scan into dest[0]: cannot scan OID 25 in text format into *int") - require.Equal(t, pgconn.CommandTag{}, ct) - }) -} - -func TestConnQueryFuncAbort(t *testing.T) { - t.Parallel() - - pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - var a, b int - ct, err := conn.QueryFunc( - context.Background(), - "select n, n * 2 from generate_series(1, $1) n", - []any{3}, - []any{&a, &b}, - func(pgx.QueryFuncRow) error { - return errors.New("abort") - }, - ) - require.EqualError(t, err, "abort") - require.Equal(t, pgconn.CommandTag{}, ct) - }) -} - -func ExampleConn_QueryFunc() { - conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - fmt.Printf("Unable to establish connection: %v", err) - return - } - - var a, b int - _, err = conn.QueryFunc( - context.Background(), - "select n, n * 2 from generate_series(1, $1) n", - []any{3}, - []any{&a, &b}, - func(pgx.QueryFuncRow) error { - fmt.Printf("%v, %v\n", a, b) - return nil - }, - ) - if err != nil { - fmt.Printf("QueryFunc error: %v", err) - return - } - - // Output: - // 1, 2 - // 2, 4 - // 3, 6 -} diff --git a/rows.go b/rows.go index d9c0ba47..a1492c3e 100644 --- a/rows.go +++ b/rows.go @@ -361,3 +361,28 @@ func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader resultReader: resultReader, } } + +// ForEachScannedRow iterates through rows. For each row it scans into the elements of scans and calls fn. If any row +// fails to scan or fn returns an error the query will be aborted and the error will be returned. Rows will be closed +// when ForEachScannedRow returns. +func ForEachScannedRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, error) { + defer rows.Close() + + for rows.Next() { + err := rows.Scan(scans...) + if err != nil { + return pgconn.CommandTag{}, err + } + + err = fn() + if err != nil { + return pgconn.CommandTag{}, err + } + } + + if err := rows.Err(); err != nil { + return pgconn.CommandTag{}, err + } + + return rows.CommandTag(), nil +} diff --git a/rows_test.go b/rows_test.go index 37f8e1de..63bb77d5 100644 --- a/rows_test.go +++ b/rows_test.go @@ -2,9 +2,14 @@ package pgx_test import ( "context" + "errors" + "fmt" + "os" "testing" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) @@ -28,3 +33,98 @@ func TestRowScanner(t *testing.T) { require.Equal(t, int32(72), s.age) }) } + +func TestForEachScannedRow(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var actualResults []any + + rows, _ := conn.Query( + context.Background(), + "select n, n * 2 from generate_series(1, $1) n", + 3, + ) + var a, b int + ct, err := pgx.ForEachScannedRow(rows, []any{&a, &b}, func() error { + actualResults = append(actualResults, []any{a, b}) + return nil + }) + require.NoError(t, err) + + expectedResults := []any{ + []any{1, 2}, + []any{2, 4}, + []any{3, 6}, + } + require.Equal(t, expectedResults, actualResults) + require.EqualValues(t, 3, ct.RowsAffected()) + }) +} + +func TestForEachScannedRowScanError(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var actualResults []any + + rows, _ := conn.Query( + context.Background(), + "select 'foo', 'bar' from generate_series(1, $1) n", + 3, + ) + var a, b int + ct, err := pgx.ForEachScannedRow(rows, []any{&a, &b}, func() error { + actualResults = append(actualResults, []any{a, b}) + return nil + }) + require.EqualError(t, err, "can't scan into dest[0]: cannot scan OID 25 in text format into *int") + require.Equal(t, pgconn.CommandTag{}, ct) + }) +} + +func TestForEachScannedRowAbort(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query( + context.Background(), + "select n, n * 2 from generate_series(1, $1) n", + 3, + ) + var a, b int + ct, err := pgx.ForEachScannedRow(rows, []any{&a, &b}, func() error { + return errors.New("abort") + }) + require.EqualError(t, err, "abort") + require.Equal(t, pgconn.CommandTag{}, ct) + }) +} + +func ExampleForEachScannedRow() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + rows, _ := conn.Query( + context.Background(), + "select n, n * 2 from generate_series(1, $1) n", + 3, + ) + var a, b int + _, err = pgx.ForEachScannedRow(rows, []any{&a, &b}, func() error { + fmt.Printf("%v, %v\n", a, b) + return nil + }) + if err != nil { + fmt.Printf("ForEachScannedRow error: %v", err) + return + } + + // Output: + // 1, 2 + // 2, 4 + // 3, 6 +} diff --git a/tx.go b/tx.go index 76b1768c..2a05b70d 100644 --- a/tx.go +++ b/tx.go @@ -163,7 +163,6 @@ type Tx interface { Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) Query(ctx context.Context, sql string, args ...any) (Rows, error) QueryRow(ctx context.Context, sql string, args ...any) Row - QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(QueryFuncRow) error) (pgconn.CommandTag, error) // Conn returns the underlying *Conn that on which this transaction is executing. Conn() *Conn @@ -293,15 +292,6 @@ func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...any) Row { return (*connRow)(rows.(*baseRows)) } -// QueryFunc delegates to the underlying *Conn. -func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { - if tx.closed { - return pgconn.CommandTag{}, ErrTxClosed - } - - return tx.conn.QueryFunc(ctx, sql, args, scans, f) -} - // CopyFrom delegates to the underlying *Conn func (tx *dbTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { if tx.closed { @@ -412,15 +402,6 @@ func (sp *dbSimulatedNestedTx) QueryRow(ctx context.Context, sql string, args .. return (*connRow)(rows.(*baseRows)) } -// QueryFunc delegates to the underlying Tx. -func (sp *dbSimulatedNestedTx) QueryFunc(ctx context.Context, sql string, args []any, scans []any, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { - if sp.closed { - return pgconn.CommandTag{}, ErrTxClosed - } - - return sp.tx.QueryFunc(ctx, sql, args, scans, f) -} - // CopyFrom delegates to the underlying *Conn func (sp *dbSimulatedNestedTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { if sp.closed { From ba58e3d5d2c34d47254c3faf49445eb46e81f908 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Jul 2022 08:32:12 -0500 Subject: [PATCH 1075/1158] Fix pipeline prepare query without row results --- pgconn/pgconn.go | 5 +++++ pgconn/pgconn_test.go | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 7426add9..bb4d35a9 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1871,6 +1871,11 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { copy(psd.Fields, msg.Fields) return psd, nil + // NoData is returned instead of RowDescription when there is no expected result. e.g. An INSERT without a RETURNING + // clause. + case *pgproto3.NoData: + return psd, nil + // These should never happen here. But don't take chances that could lead to a deadlock. case *pgproto3.ErrorResponse: pgErr := ErrorResponseToPgError(msg) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 0fba9417..598d6629 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -2102,10 +2102,15 @@ func TestPipelinePrepare(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + result := pgConn.ExecParams(context.Background(), `create temporary table t (id text primary key)`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + pipeline := pgConn.StartPipeline(context.Background()) pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil) pipeline.SendPrepare("selectText", "select $1::text as b", nil) pipeline.SendPrepare("selectNoParams", "select 42 as c", nil) + pipeline.SendPrepare("insertNoResults", "insert into t (id) values ($1)", nil) + pipeline.SendPrepare("insertNoParamsOrResults", "insert into t (id) values ('foo')", nil) err = pipeline.Sync() require.NoError(t, err) @@ -2133,6 +2138,20 @@ func TestPipelinePrepare(t *testing.T) { require.Equal(t, string(sd.Fields[0].Name), "c") require.Equal(t, []uint32{}, sd.ParamOIDs) + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 0) + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 0) + require.Len(t, sd.ParamOIDs, 0) + results, err = pipeline.GetResults() require.NoError(t, err) _, ok = results.(*pgconn.PipelineSync) From e7aa76ccf9f3bf1eb2050237adba9cbe04f1ae27 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Jul 2022 09:28:11 -0500 Subject: [PATCH 1076/1158] SendBatch now uses pipeline mode to prepare and describe statements Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1 for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements in a single network round trip. So it would only take 2 round trips. --- CHANGELOG.md | 6 + batch.go | 163 +++++++++++ batch_test.go | 2 +- conn.go | 404 +++++++++++++++++++------- conn_test.go | 17 +- internal/stmtcache/lru.go | 169 ----------- internal/stmtcache/lru_cache.go | 98 +++++++ internal/stmtcache/lru_test.go | 292 ------------------- internal/stmtcache/stmtcache.go | 69 +++-- internal/stmtcache/unlimited_cache.go | 71 +++++ pgconn/pgconn.go | 2 +- rows.go | 13 +- 12 files changed, 694 insertions(+), 612 deletions(-) delete mode 100644 internal/stmtcache/lru.go create mode 100644 internal/stmtcache/lru_cache.go delete mode 100644 internal/stmtcache/lru_test.go create mode 100644 internal/stmtcache/unlimited_cache.go diff --git a/CHANGELOG.md b/CHANGELOG.md index a0b95203..8b6c3a96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -139,6 +139,12 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent `QueryFunc` has been replaced by using `ForEachScannedRow`. +## SendBatch Uses Pipeline Mode When Appropriate + +Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1 +for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements +in a single network round trip. So it would only take 2 round trips. + ## 3rd Party Logger Integration All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency diff --git a/batch.go b/batch.go index 21830a1f..f2a9b4c8 100644 --- a/batch.go +++ b/batch.go @@ -11,6 +11,7 @@ import ( type batchItem struct { query string arguments []any + sd *pgconn.StatementDescription } // Batch queries are a way of bundling multiple queries together to avoid @@ -192,3 +193,165 @@ func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { } return } + +type pipelineBatchResults struct { + ctx context.Context + conn *Conn + pipeline *pgconn.Pipeline + lastRows *baseRows + err error + b *Batch + ix int + closed bool +} + +// Exec reads the results from the next query in the batch as if the query has been sent with Exec. +func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { + if br.err != nil { + return pgconn.CommandTag{}, br.err + } + if br.closed { + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") + } + if br.lastRows != nil && br.lastRows.err != nil { + return pgconn.CommandTag{}, br.err + } + + query, arguments, _ := br.nextQueryAndArgs() + + results, err := br.pipeline.GetResults() + if err != nil { + br.err = err + return pgconn.CommandTag{}, err + } + var commandTag pgconn.CommandTag + switch results := results.(type) { + case *pgconn.ResultReader: + commandTag, err = results.Close() + default: + return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results) + } + + if err != nil { + br.err = err + if br.conn.shouldLog(LogLevelError) { + br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{ + "sql": query, + "args": logQueryArgs(arguments), + "err": err, + }) + } + } else if br.conn.shouldLog(LogLevelInfo) { + br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{ + "sql": query, + "args": logQueryArgs(arguments), + "commandTag": commandTag, + }) + } + + return commandTag, err +} + +// Query reads the results from the next query in the batch as if the query has been sent with Query. +func (br *pipelineBatchResults) Query() (Rows, error) { + if br.err != nil { + return &baseRows{err: br.err, closed: true}, br.err + } + + if br.closed { + alreadyClosedErr := fmt.Errorf("batch already closed") + return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr + } + + if br.lastRows != nil && br.lastRows.err != nil { + br.err = br.lastRows.err + return &baseRows{err: br.err, closed: true}, br.err + } + + query, arguments, ok := br.nextQueryAndArgs() + if !ok { + query = "batch query" + } + + rows := br.conn.getRows(br.ctx, query, arguments) + br.lastRows = rows + + results, err := br.pipeline.GetResults() + if err != nil { + br.err = err + rows.err = err + rows.closed = true + if br.conn.shouldLog(LogLevelError) { + br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{ + "sql": query, + "args": logQueryArgs(arguments), + "err": rows.err, + }) + } + } else { + switch results := results.(type) { + case *pgconn.ResultReader: + rows.resultReader = results + default: + err = fmt.Errorf("unexpected pipeline result: %T", results) + br.err = err + rows.err = err + rows.closed = true + } + } + + return rows, rows.err +} + +// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. +func (br *pipelineBatchResults) QueryRow() Row { + rows, _ := br.Query() + return (*connRow)(rows.(*baseRows)) + +} + +// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to +// resyncronize the connection with the server. In this case the underlying connection will have been closed. +func (br *pipelineBatchResults) Close() error { + if br.err != nil { + return br.err + } + + if br.lastRows != nil && br.lastRows.err != nil { + br.err = br.lastRows.err + return br.err + } + + if br.closed { + return nil + } + br.closed = true + + // log any queries that haven't yet been logged by Exec or Query + for { + query, args, ok := br.nextQueryAndArgs() + if !ok { + break + } + + if br.conn.shouldLog(LogLevelInfo) { + br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{ + "sql": query, + "args": logQueryArgs(args), + }) + } + } + + return br.pipeline.Close() +} + +func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) { + if br.b != nil && br.ix < len(br.b.items) { + bi := br.b.items[br.ix] + query = bi.query + args = bi.arguments + ok = true + br.ix++ + } + return +} diff --git a/batch_test.go b/batch_test.go index abe9f915..f5409d25 100644 --- a/batch_test.go +++ b/batch_test.go @@ -420,7 +420,7 @@ func TestConnSendBatchQueryError(t *testing.T) { err = br.Close() if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { - t.Errorf("rows.Err() => %v, want error code %v", err, 22012) + t.Errorf("br.Close() => %v, want error code %v", err, 22012) } }) diff --git a/conn.go b/conn.go index d8ab21d7..997a84d3 100644 --- a/conn.go +++ b/conn.go @@ -236,11 +236,11 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { c.wbuf = make([]byte, 0, 1024) if c.config.StatementCacheCapacity > 0 { - c.statementCache = stmtcache.New(c.pgConn, stmtcache.ModePrepare, c.config.StatementCacheCapacity) + c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity) } if c.config.DescriptionCacheCapacity > 0 { - c.descriptionCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, c.config.DescriptionCacheCapacity) + c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity) } return c, nil @@ -382,6 +382,10 @@ func (c *Conn) Config() *ConnConfig { return c.config.Copy() } // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced // positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + return pgconn.CommandTag{}, err + } + startTime := time.Now() commandTag, err := c.exec(ctx, sql, arguments...) @@ -437,9 +441,13 @@ optionLoop: if c.statementCache == nil { return pgconn.CommandTag{}, errDisabledStatementCache } - sd, err := c.statementCache.Get(ctx, sql) - if err != nil { - return pgconn.CommandTag{}, err + sd := c.statementCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql) + if err != nil { + return pgconn.CommandTag{}, err + } + c.statementCache.Put(sd) } return c.execPrepared(ctx, sd, arguments) @@ -447,9 +455,12 @@ optionLoop: if c.descriptionCache == nil { return pgconn.CommandTag{}, errDisabledDescriptionCache } - sd, err := c.descriptionCache.Get(ctx, sql) - if err != nil { - return pgconn.CommandTag{}, err + sd := c.descriptionCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + return pgconn.CommandTag{}, err + } } return c.execParams(ctx, sd, arguments) @@ -620,6 +631,10 @@ type QueryRewriter interface { // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) { + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + return &baseRows{err: err, closed: true}, err + } + var resultFormats QueryResultFormats var resultFormatsByOID QueryResultFormatsByOID mode := c.config.DefaultQueryExecMode @@ -649,6 +664,11 @@ optionLoop: sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args) } + // Bypass any statement caching. + if sql == "" { + mode = QueryExecModeSimpleProtocol + } + c.eqb.reset() anynil.NormalizeSlice(args) rows := c.getRows(ctx, sql, args) @@ -664,10 +684,14 @@ optionLoop: rows.fatal(err) return rows, err } - sd, err = c.statementCache.Get(ctx, sql) - if err != nil { - rows.fatal(err) - return rows, err + sd = c.statementCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql) + if err != nil { + rows.fatal(err) + return rows, err + } + c.statementCache.Put(sd) } case QueryExecModeCacheDescribe: if c.descriptionCache == nil { @@ -675,10 +699,14 @@ optionLoop: rows.fatal(err) return rows, err } - sd, err = c.descriptionCache.Get(ctx, sql) - if err != nil { - rows.fatal(err) - return rows, err + sd = c.descriptionCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + rows.fatal(err) + return rows, err + } + c.descriptionCache.Put(sd) } case QueryExecModeDescribeExec: sd, err = c.Prepare(ctx, "", sql) @@ -767,6 +795,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + mode := c.config.DefaultQueryExecMode for _, bi := range b.items { @@ -794,105 +826,70 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } if mode == QueryExecModeSimpleProtocol { - var sb strings.Builder - for i, bi := range b.items { - if i > 0 { - sb.WriteByte(';') - } - sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - sb.WriteString(sql) - } - mrr := c.pgConn.Exec(ctx, sb.String()) - return &batchResults{ - ctx: ctx, - conn: c, - mrr: mrr, - b: b, - ix: 0, + return c.sendBatchQueryExecModeSimpleProtocol(ctx, b) + } + + // All other modes use extended protocol and thus can use prepared statements. + for _, bi := range b.items { + if sd, ok := c.preparedStatements[bi.query]; ok { + bi.sd = sd } } + switch mode { + case QueryExecModeExec: + return c.sendBatchQueryExecModeExec(ctx, b) + case QueryExecModeCacheStatement: + return c.sendBatchQueryExecModeCacheStatement(ctx, b) + case QueryExecModeCacheDescribe: + return c.sendBatchQueryExecModeCacheDescribe(ctx, b) + case QueryExecModeDescribeExec: + return c.sendBatchQueryExecModeDescribeExec(ctx, b) + default: + panic("unknown QueryExecMode") + } +} + +func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults { + var sb strings.Builder + for i, bi := range b.items { + if i > 0 { + sb.WriteByte(';') + } + sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + sb.WriteString(sql) + } + mrr := c.pgConn.Exec(ctx, sb.String()) + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + ix: 0, + } +} + +func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults { batch := &pgconn.Batch{} - if mode == QueryExecModeExec { - for _, bi := range b.items { - c.eqb.reset() - anynil.NormalizeSlice(bi.arguments) - - sd := c.preparedStatements[bi.query] - if sd != nil { - err := c.eqb.Build(c.typeMap, sd, bi.arguments) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - - batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) - } else { - err := c.eqb.Build(c.typeMap, nil, bi.arguments) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) - } - } - } else { - - distinctUnpreparedQueries := map[string]struct{}{} - - for _, bi := range b.items { - if _, ok := c.preparedStatements[bi.query]; ok { - continue - } - distinctUnpreparedQueries[bi.query] = struct{}{} - } - - var stmtCache stmtcache.Cache - if len(distinctUnpreparedQueries) > 0 { - if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) { - stmtCache = c.statementCache - } else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) { - stmtCache = c.descriptionCache - } else { - stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) - } - - for sql, _ := range distinctUnpreparedQueries { - _, err := stmtCache.Get(ctx, sql) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - } - } - - for _, bi := range b.items { - c.eqb.reset() - - sd := c.preparedStatements[bi.query] - if sd == nil { - var err error - sd, err = stmtCache.Get(ctx, bi.query) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - } - - if len(sd.ParamOIDs) != len(bi.arguments) { - return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} - } - + for _, bi := range b.items { + sd := bi.sd + if sd != nil { err := c.eqb.Build(c.typeMap, sd, bi.arguments) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } - if sd.Name == "" { - batch.ExecParams(bi.query, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats) - } else { - batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else { + err := c.eqb.Build(c.typeMap, nil, bi.arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} } + batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) } } @@ -909,6 +906,171 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } } +func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + if c.statementCache == nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache} + } + + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) + + for _, bi := range b.items { + if bi.sd == nil { + sd := c.statementCache.Get(bi.query) + if sd != nil { + bi.sd = sd + } else { + if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd = &pgconn.StatementDescription{ + Name: stmtcache.NextStatementName(), + SQL: bi.query, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } + } + } + + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.statementCache) +} + +func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + if c.descriptionCache == nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache} + } + + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) + + for _, bi := range b.items { + if bi.sd == nil { + sd := c.descriptionCache.Get(bi.query) + if sd != nil { + bi.sd = sd + } else { + if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd = &pgconn.StatementDescription{ + SQL: bi.query, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } + } + } + + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.descriptionCache) +} + +func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) + + for _, bi := range b.items { + if bi.sd == nil { + if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd := &pgconn.StatementDescription{ + SQL: bi.query, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } + } + + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, nil) +} + +func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) { + pipeline := c.pgConn.StartPipeline(context.Background()) + defer func() { + if pbr.err != nil { + pipeline.Close() + } + }() + + // Prepare any needed queries + if len(distinctNewQueries) > 0 { + for _, sd := range distinctNewQueries { + pipeline.SendPrepare(sd.Name, sd.SQL, nil) + } + + err := pipeline.Sync() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + for _, sd := range distinctNewQueries { + results, err := pipeline.GetResults() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + resultSD, ok := results.(*pgconn.StatementDescription) + if !ok { + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)} + } + + // Fill in the previously empty / pending statement descriptions. + sd.ParamOIDs = resultSD.ParamOIDs + sd.Fields = resultSD.Fields + } + + results, err := pipeline.GetResults() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + _, ok := results.(*pgconn.PipelineSync) + if !ok { + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)} + } + } + + // Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later. + if sdCache != nil { + for _, sd := range distinctNewQueries { + c.statementCache.Put(sd) + } + } + + // Queue the queries. + for _, bi := range b.items { + err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments) + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + if bi.sd.Name == "" { + pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else { + pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + } + } + + err := pipeline.Sync() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + return &pipelineBatchResults{ + ctx: ctx, + conn: c, + pipeline: pipeline, + b: b, + } +} + func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) { if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on") @@ -1015,3 +1177,37 @@ order by attnum`, return fields, nil } + +func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error { + if c.descriptionCache != nil { + c.descriptionCache.HandleInvalidated() + } + + var invalidatedStatements []*pgconn.StatementDescription + if c.statementCache != nil { + invalidatedStatements = c.statementCache.HandleInvalidated() + } + + if len(invalidatedStatements) == 0 { + return nil + } + + pipeline := c.pgConn.StartPipeline(ctx) + defer pipeline.Close() + + for _, sd := range invalidatedStatements { + pipeline.SendDeallocate(sd.Name) + } + + err := pipeline.Sync() + if err != nil { + return fmt.Errorf("failed to deallocate cached statement(s): %w", err) + } + + err = pipeline.Close() + if err != nil { + return fmt.Errorf("failed to deallocate cached statement(s): %w", err) + } + + return nil +} diff --git a/conn_test.go b/conn_test.go index 79697cbd..a023a7d7 100644 --- a/conn_test.go +++ b/conn_test.go @@ -931,6 +931,7 @@ func TestStmtCacheInvalidationConn(t *testing.T) { rows, err := conn.Query(ctx, getSQL, 1) require.NoError(t, err) rows.Close() + require.NoError(t, rows.Err()) // Now, change the schema of the table out from under the statement, making it invalid. _, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") @@ -948,10 +949,10 @@ func TestStmtCacheInvalidationConn(t *testing.T) { rows.Close() for _, err := range []error{nextErr, rows.Err()} { if err == nil { - t.Fatal("expected InvalidCachedStatementPlanError: no error") + t.Fatal(`expected "cached plan must not change result type": no error`) } if !strings.Contains(err.Error(), "cached plan must not change result type") { - t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error()) + t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error()) } } @@ -995,6 +996,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) { rows, err := tx.Query(ctx, getSQL, 1) require.NoError(t, err) rows.Close() + require.NoError(t, rows.Err()) // Now, change the schema of the table out from under the statement, making it invalid. _, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") @@ -1012,18 +1014,17 @@ func TestStmtCacheInvalidationTx(t *testing.T) { rows.Close() for _, err := range []error{nextErr, rows.Err()} { if err == nil { - t.Fatal("expected InvalidCachedStatementPlanError: no error") + t.Fatal(`expected "cached plan must not change result type": no error`) } if !strings.Contains(err.Error(), "cached plan must not change result type") { - t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error()) + t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error()) } } - rows, err = tx.Query(ctx, getSQL, 1) - require.NoError(t, err) // error does not pop up immediately - rows.Next() + rows, _ = tx.Query(ctx, getSQL, 1) + rows.Close() err = rows.Err() - // Retries within the same transaction are errors (really anything except a rollbakc + // Retries within the same transaction are errors (really anything except a rollback // will be an error in this transaction). require.Error(t, err) rows.Close() diff --git a/internal/stmtcache/lru.go b/internal/stmtcache/lru.go deleted file mode 100644 index a3378c86..00000000 --- a/internal/stmtcache/lru.go +++ /dev/null @@ -1,169 +0,0 @@ -package stmtcache - -import ( - "container/list" - "context" - "fmt" - "sync/atomic" - - "github.com/jackc/pgx/v5/pgconn" -) - -var lruCount uint64 - -// LRU implements Cache with a Least Recently Used (LRU) cache. -type LRU struct { - conn *pgconn.PgConn - mode int - cap int - prepareCount int - m map[string]*list.Element - l *list.List - psNamePrefix string - stmtsToClear []string -} - -// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. -func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { - mustBeValidMode(mode) - mustBeValidCap(cap) - - n := atomic.AddUint64(&lruCount, 1) - - return &LRU{ - conn: conn, - mode: mode, - cap: cap, - m: make(map[string]*list.Element), - l: list.New(), - psNamePrefix: fmt.Sprintf("lrupsc_%d", n), - } -} - -// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. -func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { - if ctx != context.Background() { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - } - - // flush an outstanding bad statements - txStatus := c.conn.TxStatus() - if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 { - for _, stmt := range c.stmtsToClear { - err := c.clearStmt(ctx, stmt) - if err != nil { - return nil, err - } - } - } - - if el, ok := c.m[sql]; ok { - c.l.MoveToFront(el) - return el.Value.(*pgconn.StatementDescription), nil - } - - if c.l.Len() == c.cap { - err := c.removeOldest(ctx) - if err != nil { - return nil, err - } - } - - psd, err := c.prepare(ctx, sql) - if err != nil { - return nil, err - } - - el := c.l.PushFront(psd) - c.m[sql] = el - - return psd, nil -} - -// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. -func (c *LRU) Clear(ctx context.Context) error { - for c.l.Len() > 0 { - err := c.removeOldest(ctx) - if err != nil { - return err - } - } - - return nil -} - -func (c *LRU) StatementErrored(sql string, err error) { - pgErr, ok := err.(*pgconn.PgError) - if !ok { - return - } - - // https://github.com/jackc/pgx/issues/1162 - // - // We used to look for the message "cached plan must not change result type". However, that message can be localized. - // Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to - // tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't - // have so it should be safe. - possibleInvalidCachedPlanError := pgErr.Code == "0A000" - if possibleInvalidCachedPlanError { - c.stmtsToClear = append(c.stmtsToClear, sql) - } -} - -func (c *LRU) clearStmt(ctx context.Context, sql string) error { - elem, inMap := c.m[sql] - if !inMap { - // The statement probably fell off the back of the list. In that case, we've - // ensured that it isn't in the cache, so we can declare victory. - return nil - } - - c.l.Remove(elem) - - psd := elem.Value.(*pgconn.StatementDescription) - delete(c.m, psd.SQL) - if c.mode == ModePrepare { - return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() - } - return nil -} - -// Len returns the number of cached prepared statement descriptions. -func (c *LRU) Len() int { - return c.l.Len() -} - -// Cap returns the maximum number of cached prepared statement descriptions. -func (c *LRU) Cap() int { - return c.cap -} - -// Mode returns the mode of the cache (ModePrepare or ModeDescribe) -func (c *LRU) Mode() int { - return c.mode -} - -func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { - var name string - if c.mode == ModePrepare { - name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) - c.prepareCount += 1 - } - - return c.conn.Prepare(ctx, name, sql, nil) -} - -func (c *LRU) removeOldest(ctx context.Context) error { - oldest := c.l.Back() - c.l.Remove(oldest) - psd := oldest.Value.(*pgconn.StatementDescription) - delete(c.m, psd.SQL) - if c.mode == ModePrepare { - return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() - } - return nil -} diff --git a/internal/stmtcache/lru_cache.go b/internal/stmtcache/lru_cache.go new file mode 100644 index 00000000..a25cc8b1 --- /dev/null +++ b/internal/stmtcache/lru_cache.go @@ -0,0 +1,98 @@ +package stmtcache + +import ( + "container/list" + + "github.com/jackc/pgx/v5/pgconn" +) + +// LRUCache implements Cache with a Least Recently Used (LRU) cache. +type LRUCache struct { + cap int + m map[string]*list.Element + l *list.List + invalidStmts []*pgconn.StatementDescription +} + +// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache. +func NewLRUCache(cap int) *LRUCache { + return &LRUCache{ + cap: cap, + m: make(map[string]*list.Element), + l: list.New(), + } +} + +// Get returns the statement description for sql. Returns nil if not found. +func (c *LRUCache) Get(key string) *pgconn.StatementDescription { + if el, ok := c.m[key]; ok { + c.l.MoveToFront(el) + return el.Value.(*pgconn.StatementDescription) + } + + return nil + +} + +// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. +func (c *LRUCache) Put(sd *pgconn.StatementDescription) { + if sd.SQL == "" { + panic("cannot store statement description with empty SQL") + } + + if _, present := c.m[sd.SQL]; present { + return + } + + if c.l.Len() == c.cap { + c.invalidateOldest() + } + + el := c.l.PushFront(sd) + c.m[sd.SQL] = el +} + +// Invalidate invalidates statement description identified by sql. Does nothing if not found. +func (c *LRUCache) Invalidate(sql string) { + if el, ok := c.m[sql]; ok { + delete(c.m, sql) + c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) + c.l.Remove(el) + } +} + +// InvalidateAll invalidates all statement descriptions. +func (c *LRUCache) InvalidateAll() { + el := c.l.Front() + for el != nil { + c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) + el = el.Next() + } + + c.m = make(map[string]*list.Element) + c.l = list.New() +} + +func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription { + invalidStmts := c.invalidStmts + c.invalidStmts = nil + return invalidStmts +} + +// Len returns the number of cached prepared statement descriptions. +func (c *LRUCache) Len() int { + return c.l.Len() +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *LRUCache) Cap() int { + return c.cap +} + +func (c *LRUCache) invalidateOldest() { + oldest := c.l.Back() + sd := oldest.Value.(*pgconn.StatementDescription) + c.invalidStmts = append(c.invalidStmts, sd) + delete(c.m, sd.SQL) + c.l.Remove(oldest) +} diff --git a/internal/stmtcache/lru_test.go b/internal/stmtcache/lru_test.go deleted file mode 100644 index 7690a2b0..00000000 --- a/internal/stmtcache/lru_test.go +++ /dev/null @@ -1,292 +0,0 @@ -package stmtcache_test - -import ( - "context" - "fmt" - "math/rand" - "os" - "regexp" - "testing" - "time" - - "github.com/jackc/pgx/v5/internal/stmtcache" - "github.com/jackc/pgx/v5/pgconn" - - "github.com/stretchr/testify/require" -) - -func TestLRUModePrepare(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer conn.Close(ctx) - - cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) - require.EqualValues(t, 0, cache.Len()) - require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) - - psd, err := cache.Get(ctx, "select 1") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 1, cache.Len()) - require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) - - psd, err = cache.Get(ctx, "select 1") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 1, cache.Len()) - require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) - - psd, err = cache.Get(ctx, "select 2") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 2, cache.Len()) - require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn)) - - psd, err = cache.Get(ctx, "select 3") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 2, cache.Len()) - require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn)) - - err = cache.Clear(ctx) - require.NoError(t, err) - require.EqualValues(t, 0, cache.Len()) - require.Empty(t, fetchServerStatements(t, ctx, conn)) -} - -func TestLRUStmtInvalidation(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer conn.Close(ctx) - - // we construct a fake error because its not super straightforward to actually call - // a prepared statement from the LRU cache without the helper routines which live - // in pgx proper. - fakeInvalidCachePlanError := &pgconn.PgError{ - Severity: "ERROR", - Code: "0A000", - Message: "cached plan must not change result type", - } - - cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) - - // - // outside of a transaction, we eagerly flush the statement - // - - _, err = cache.Get(ctx, "select 1") - require.NoError(t, err) - require.EqualValues(t, 1, cache.Len()) - require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) - - cache.StatementErrored("select 1", fakeInvalidCachePlanError) - _, err = cache.Get(ctx, "select 2") - require.NoError(t, err) - require.EqualValues(t, 1, cache.Len()) - require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) - - err = cache.Clear(ctx) - require.NoError(t, err) - - // - // within an errored transaction, we defer the flush to after the first get - // that happens after the transaction is rolled back - // - - _, err = cache.Get(ctx, "select 1") - require.NoError(t, err) - require.EqualValues(t, 1, cache.Len()) - require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) - - res := conn.Exec(ctx, "begin") - require.NoError(t, res.Close()) - require.Equal(t, byte('T'), conn.TxStatus()) - - res = conn.Exec(ctx, "selec") - require.Error(t, res.Close()) - require.Equal(t, byte('E'), conn.TxStatus()) - - cache.StatementErrored("select 1", fakeInvalidCachePlanError) - require.EqualValues(t, 1, cache.Len()) - - res = conn.Exec(ctx, "rollback") - require.NoError(t, res.Close()) - - _, err = cache.Get(ctx, "select 2") - require.EqualValues(t, 1, cache.Len()) - require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) -} - -func TestLRUStmtInvalidationIntegration(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer conn.Close(ctx) - - cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) - - result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a text)", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - sql := "select * from stmtcache_table" - sd1, err := cache.Get(ctx, sql) - require.NoError(t, err) - - result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() - require.NoError(t, result.Err) - - result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() - require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)") - - cache.StatementErrored(sql, result.Err) - - sd2, err := cache.Get(ctx, sql) - require.NoError(t, err) - require.NotEqual(t, sd1.Name, sd2.Name) - - result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read() - require.NoError(t, result.Err) -} - -func TestLRUModePrepareStress(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer conn.Close(ctx) - - cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8) - require.EqualValues(t, 0, cache.Len()) - require.EqualValues(t, 8, cache.Cap()) - require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) - - for i := 0; i < 1000; i++ { - psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50))) - require.NoError(t, err) - require.NotNil(t, psd) - result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read() - require.NoError(t, result.Err) - } -} - -func TestLRUModeDescribe(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer conn.Close(ctx) - - cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) - require.EqualValues(t, 0, cache.Len()) - require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode()) - - psd, err := cache.Get(ctx, "select 1") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 1, cache.Len()) - require.Empty(t, fetchServerStatements(t, ctx, conn)) - - psd, err = cache.Get(ctx, "select 1") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 1, cache.Len()) - require.Empty(t, fetchServerStatements(t, ctx, conn)) - - psd, err = cache.Get(ctx, "select 2") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 2, cache.Len()) - require.Empty(t, fetchServerStatements(t, ctx, conn)) - - psd, err = cache.Get(ctx, "select 3") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 2, cache.Len()) - require.Empty(t, fetchServerStatements(t, ctx, conn)) - - err = cache.Clear(ctx) - require.NoError(t, err) - require.EqualValues(t, 0, cache.Len()) - require.Empty(t, fetchServerStatements(t, ctx, conn)) -} - -func TestLRUContext(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer conn.Close(ctx) - - cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) - - // test 1 : getting a value for the first time with a cancelled context returns an error - ctx1, cancel1 := context.WithCancel(ctx) - cancel1() - - desc, err := cache.Get(ctx1, "SELECT 1") - require.Error(t, err) - require.Nil(t, desc) - - // test 2 : when querying for the 2nd time a cached value, if the context is canceled return an error - ctx2, cancel2 := context.WithCancel(ctx) - - desc, err = cache.Get(ctx2, "SELECT 2") - require.NoError(t, err) - require.NotNil(t, desc) - - cancel2() - - desc, err = cache.Get(ctx2, "SELECT 2") - require.Error(t, err) - require.Nil(t, desc) -} - -func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string { - result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - var statements []string - for _, r := range result.Rows { - statement := string(r[0]) - if conn.ParameterStatus("crdb_version") != "" { - if statement == "PREPARE AS select statement from pg_prepared_statements" { - // CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it. - continue - } - - // CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended - // protocol will PostgreSQL does not. Normalize the statement. - re := regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `) - statement = re.ReplaceAllString(statement, "") - } - statements = append(statements, statement) - } - return statements -} diff --git a/internal/stmtcache/stmtcache.go b/internal/stmtcache/stmtcache.go index a2582019..f975273e 100644 --- a/internal/stmtcache/stmtcache.go +++ b/internal/stmtcache/stmtcache.go @@ -2,57 +2,56 @@ package stmtcache import ( - "context" + "strconv" + "sync/atomic" "github.com/jackc/pgx/v5/pgconn" ) -const ( - ModePrepare = iota // Cache should prepare named statements. - ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. -) +var stmtCounter int64 -// Cache prepares and caches prepared statement descriptions. +// NextStatementName returns a statement name that will be unique for the lifetime of the program. +func NextStatementName() string { + n := atomic.AddInt64(&stmtCounter, 1) + return "stmtcache_" + strconv.FormatInt(n, 10) +} + +// Cache caches statement descriptions. type Cache interface { - // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. - Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) + // Get returns the statement description for sql. Returns nil if not found. + Get(sql string) *pgconn.StatementDescription - // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. - Clear(ctx context.Context) error + // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. + Put(sd *pgconn.StatementDescription) - // StatementErrored informs the cache that the given statement resulted in an error when it - // was last used against the database. In some cases, this will cause the cache to maer that - // statement as bad. The bad statement will instead be flushed during the next call to Get - // that occurs outside of a failed transaction. - StatementErrored(sql string, err error) + // Invalidate invalidates statement description identified by sql. Does nothing if not found. + Invalidate(sql string) + + // InvalidateAll invalidates all statement descriptions. + InvalidateAll() + + // HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated. + HandleInvalidated() []*pgconn.StatementDescription // Len returns the number of cached prepared statement descriptions. Len() int // Cap returns the maximum number of cached prepared statement descriptions. Cap() int - - // Mode returns the mode of the cache (ModePrepare or ModeDescribe) - Mode() int } -// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is -// the maximum size of the cache. -func New(conn *pgconn.PgConn, mode int, cap int) Cache { - mustBeValidMode(mode) - mustBeValidCap(cap) - - return NewLRU(conn, mode, cap) -} - -func mustBeValidMode(mode int) { - if mode != ModePrepare && mode != ModeDescribe { - panic("mode must be ModePrepare or ModeDescribe") +func IsStatementInvalid(err error) bool { + pgErr, ok := err.(*pgconn.PgError) + if !ok { + return false } -} -func mustBeValidCap(cap int) { - if cap < 1 { - panic("cache must have cap of >= 1") - } + // https://github.com/jackc/pgx/issues/1162 + // + // We used to look for the message "cached plan must not change result type". However, that message can be localized. + // Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to + // tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't + // have so it should be safe. + possibleInvalidCachedPlanError := pgErr.Code == "0A000" + return possibleInvalidCachedPlanError } diff --git a/internal/stmtcache/unlimited_cache.go b/internal/stmtcache/unlimited_cache.go new file mode 100644 index 00000000..f5f59396 --- /dev/null +++ b/internal/stmtcache/unlimited_cache.go @@ -0,0 +1,71 @@ +package stmtcache + +import ( + "math" + + "github.com/jackc/pgx/v5/pgconn" +) + +// UnlimitedCache implements Cache with no capacity limit. +type UnlimitedCache struct { + m map[string]*pgconn.StatementDescription + invalidStmts []*pgconn.StatementDescription +} + +// NewUnlimitedCache creates a new UnlimitedCache. +func NewUnlimitedCache() *UnlimitedCache { + return &UnlimitedCache{ + m: make(map[string]*pgconn.StatementDescription), + } +} + +// Get returns the statement description for sql. Returns nil if not found. +func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription { + return c.m[sql] +} + +// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. +func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) { + if sd.SQL == "" { + panic("cannot store statement description with empty SQL") + } + + if _, present := c.m[sd.SQL]; present { + return + } + + c.m[sd.SQL] = sd +} + +// Invalidate invalidates statement description identified by sql. Does nothing if not found. +func (c *UnlimitedCache) Invalidate(sql string) { + if sd, ok := c.m[sql]; ok { + delete(c.m, sql) + c.invalidStmts = append(c.invalidStmts, sd) + } +} + +// InvalidateAll invalidates all statement descriptions. +func (c *UnlimitedCache) InvalidateAll() { + for _, sd := range c.m { + c.invalidStmts = append(c.invalidStmts, sd) + } + + c.m = make(map[string]*pgconn.StatementDescription) +} + +func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription { + invalidStmts := c.invalidStmts + c.invalidStmts = nil + return invalidStmts +} + +// Len returns the number of cached prepared statement descriptions. +func (c *UnlimitedCache) Len() int { + return len(c.m) +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *UnlimitedCache) Cap() int { + return math.MaxInt +} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index bb4d35a9..65fb015a 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1909,10 +1909,10 @@ func (p *Pipeline) Close() error { for p.expectedReadyForQueryCount > 0 { _, err := p.GetResults() if err != nil { + p.err = err var pgErr *PgError if !errors.As(err, &pgErr) { p.conn.asyncClose() - p.err = err break } } diff --git a/rows.go b/rows.go index a1492c3e..4d4c5ec6 100644 --- a/rows.go +++ b/rows.go @@ -7,6 +7,7 @@ import ( "reflect" "time" + "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" @@ -173,8 +174,16 @@ func (rows *baseRows) Close() { } } - if rows.err != nil && rows.conn != nil && rows.conn.statementCache != nil { - rows.conn.statementCache.StatementErrored(rows.sql, rows.err) + if rows.err != nil && rows.conn != nil && rows.sql != "" { + if stmtcache.IsStatementInvalid(rows.err) { + if sc := rows.conn.statementCache; sc != nil { + sc.Invalidate(rows.sql) + } + + if sc := rows.conn.descriptionCache; sc != nil { + sc.Invalidate(rows.sql) + } + } } } From c31b89a3f2816d834ef52a038e4ebe2b887f7fb0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Jul 2022 10:20:54 -0500 Subject: [PATCH 1077/1158] Delay handling invalidated statements when in transaction --- conn.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/conn.go b/conn.go index 997a84d3..ea4229d7 100644 --- a/conn.go +++ b/conn.go @@ -1179,6 +1179,10 @@ order by attnum`, } func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error { + if c.pgConn.TxStatus() != 'I' { + return nil + } + if c.descriptionCache != nil { c.descriptionCache.HandleInvalidated() } From 3dafb5d4ee34145e26172dd272428eba4b52a164 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Jul 2022 10:21:17 -0500 Subject: [PATCH 1078/1158] Skip test with non-standard CRDB behavior --- conn_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/conn_test.go b/conn_test.go index a023a7d7..596b0d18 100644 --- a/conn_test.go +++ b/conn_test.go @@ -974,6 +974,10 @@ func TestStmtCacheInvalidationTx(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip("Server has non-standard prepare in errored transaction behavior (https://github.com/cockroachdb/cockroach/issues/84140)") + } + // create a table and fill it with some data _, err := conn.Exec(ctx, ` DROP TABLE IF EXISTS drop_cols; From da192291f75eb9effbb1ab8e6e2e182c3929d1bb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Jul 2022 16:38:34 -0500 Subject: [PATCH 1079/1158] Add CollectRows and RowTo* functions Collect functionality was originally developed in pgxutil --- CHANGELOG.md | 5 ++- rows.go | 121 +++++++++++++++++++++++++++++++++++++++++++++++++++ rows_test.go | 96 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 220 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b6c3a96..6d2d8ee8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -135,9 +135,10 @@ allows arbitrary rewriting of query SQL and arguments. The `RowScanner` interface allows a single argument to Rows.Scan to scan the entire row. -## QueryFunc Replaced +## Rows Result Helpers -`QueryFunc` has been replaced by using `ForEachScannedRow`. +* `CollectRows` and `RowTo*` functions simplify collecting results into a slice. +* `QueryFunc` has been replaced by using `ForEachScannedRow`. ## SendBatch Uses Pipeline Mode When Appropriate diff --git a/rows.go b/rows.go index 4d4c5ec6..0c630bc4 100644 --- a/rows.go +++ b/rows.go @@ -395,3 +395,124 @@ func ForEachScannedRow(rows Rows, scans []any, fn func() error) (pgconn.CommandT return rows.CommandTag(), nil } + +// CollectableRow is the subset of Rows methods that a RowToFunc is allowed to call. +type CollectableRow interface { + FieldDescriptions() []pgproto3.FieldDescription + Scan(dest ...any) error + Values() ([]any, error) + RawValues() [][]byte +} + +// RowToFunc is a function that scans or otherwise converts row to a T. +type RowToFunc[T any] func(row CollectableRow) (T, error) + +// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. +func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { + defer rows.Close() + + slice := []T{} + + for rows.Next() { + value, err := fn(rows) + if err != nil { + return nil, err + } + slice = append(slice, value) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return slice, nil +} + +// RowTo returns a T scanned from row. +func RowTo[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&value) + return value, err +} + +// RowTo returns a the address of a T scanned from row. +func RowToAddrOf[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&value) + return &value, err +} + +// RowToMap returns a map scanned from row. +func RowToMap(row CollectableRow) (map[string]any, error) { + var value map[string]any + err := row.Scan((*mapRowScanner)(&value)) + return value, err +} + +type mapRowScanner map[string]any + +func (rs *mapRowScanner) ScanRow(rows Rows) error { + values, err := rows.Values() + if err != nil { + return err + } + + *rs = make(mapRowScanner, len(values)) + + for i := range values { + (*rs)[string(rows.FieldDescriptions()[i].Name)] = values[i] + } + + return nil +} + +// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row +// has fields. The row and T fields will by matched by position. +func RowToStructByPos[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + return value, err +} + +// RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a +// public fields as row has fields. The row and T fields will by matched by position. +func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + return &value, err +} + +type positionalStructRowScanner struct { + ptrToStruct any +} + +func (rs *positionalStructRowScanner) ScanRow(rows Rows) error { + dst := rs.ptrToStruct + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() != reflect.Ptr { + return fmt.Errorf("dst not a pointer") + } + + dstElemValue := dstValue.Elem() + dstElemType := dstElemValue.Type() + + exportedFields := make([]int, 0, dstElemType.NumField()) + for i := 0; i < dstElemType.NumField(); i++ { + sf := dstElemType.Field(i) + if sf.PkgPath == "" { + exportedFields = append(exportedFields, i) + } + } + + rowFieldCount := len(rows.RawValues()) + if rowFieldCount > len(exportedFields) { + return fmt.Errorf("got %d values, but dst struct has only %d fields", rowFieldCount, len(exportedFields)) + } + + scanTargets := make([]any, rowFieldCount) + for i := 0; i < rowFieldCount; i++ { + scanTargets[i] = dstElemValue.Field(exportedFields[i]).Addr().Interface() + } + + return rows.Scan(scanTargets...) +} diff --git a/rows_test.go b/rows_test.go index 63bb77d5..9f07ee2e 100644 --- a/rows_test.go +++ b/rows_test.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -128,3 +129,98 @@ func ExampleForEachScannedRow() { // 2, 4 // 3, 6 } + +func TestCollectRows(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) + numbers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + require.NoError(t, err) + + assert.Len(t, numbers, 100) + for i := range numbers { + assert.Equal(t, int32(i), numbers[i]) + } + }) +} + +func TestRowTo(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) + numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32]) + require.NoError(t, err) + + assert.Len(t, numbers, 100) + for i := range numbers { + assert.Equal(t, int32(i), numbers[i]) + } + }) +} + +func TestRowToAddrOf(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) + numbers, err := pgx.CollectRows(rows, pgx.RowToAddrOf[int32]) + require.NoError(t, err) + + assert.Len(t, numbers, 100) + for i := range numbers { + assert.Equal(t, int32(i), *numbers[i]) + } + }) +} + +func TestRowToMap(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToMap) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Joe", slice[i]["name"]) + assert.EqualValues(t, i, slice[i]["age"]) + } + }) +} + +func TestRowToStructPos(t *testing.T) { + type person struct { + Name string + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Joe", slice[i].Name) + assert.EqualValues(t, i, slice[i].Age) + } + }) +} + +func TestRowToAddrOfStructPos(t *testing.T) { + type person struct { + Name string + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToAddrOfStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Joe", slice[i].Name) + assert.EqualValues(t, i, slice[i].Age) + } + }) +} From 90c2dc6f68f8eaece05203f8a98503df6336bb23 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Jul 2022 16:47:28 -0500 Subject: [PATCH 1080/1158] Rename ForEachScannedRow to ForEachRow --- CHANGELOG.md | 2 +- batch_test.go | 2 +- conn.go | 2 +- doc.go | 4 +- pgtype/integration_benchmark_test.go | 140 +++++++++++------------ pgtype/integration_benchmark_test.go.erb | 4 +- rows.go | 6 +- rows_test.go | 18 +-- 8 files changed, 89 insertions(+), 89 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d2d8ee8..44964ac7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -138,7 +138,7 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent ## Rows Result Helpers * `CollectRows` and `RowTo*` functions simplify collecting results into a slice. -* `QueryFunc` has been replaced by using `ForEachScannedRow`. +* `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`. ## SendBatch Uses Pipeline Mode When Appropriate diff --git a/batch_test.go b/batch_test.go index f5409d25..e8d6f677 100644 --- a/batch_test.go +++ b/batch_test.go @@ -109,7 +109,7 @@ func TestConnSendBatch(t *testing.T) { rowCount = 0 rows, _ = br.Query() - _, err = pgx.ForEachScannedRow(rows, []any{&id, &description, &amount}, func() error { + _, err = pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error { if id != selectFromLedgerExpectedRows[rowCount].id { t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) } diff --git a/conn.go b/conn.go index ea4229d7..3f6bee0d 100644 --- a/conn.go +++ b/conn.go @@ -1163,7 +1163,7 @@ where attrelid=$1 order by attnum`, typrelid, ) - _, err = ForEachScannedRow(rows, []any{&fieldName, &fieldOID}, func() error { + _, err = ForEachRow(rows, []any{&fieldName, &fieldOID}, func() error { dt, ok := c.TypeMap().TypeForOID(fieldOID) if !ok { return fmt.Errorf("unknown composite type field OID: %v", fieldOID) diff --git a/doc.go b/doc.go index 2e779dbb..48971110 100644 --- a/doc.go +++ b/doc.go @@ -63,11 +63,11 @@ pgx implements Query and Scan in the familiar database/sql style. // No errors found - do something with sum -ForEachScannedRow can be used to execute a callback function for every row. This is often easier than iterating over rows directly. +ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows directly. var sum, n int32 rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10) - _, err := pgx.ForEachScannedRow(rows, []any{&n}, func(pgx.QueryFuncRow) error { + _, err := pgx.ForEachRow(rows, []any{&n}, func(pgx.QueryFuncRow) error { sum += n return nil }) diff --git a/pgtype/integration_benchmark_test.go b/pgtype/integration_benchmark_test.go index 4ba8b9b5..22ac3344 100644 --- a/pgtype/integration_benchmark_test.go +++ b/pgtype/integration_benchmark_test.go @@ -18,7 +18,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *test `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -36,7 +36,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *te `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -54,7 +54,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *tes `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -72,7 +72,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *t `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -90,7 +90,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *tes `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -108,7 +108,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *t `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -126,7 +126,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b *t `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -144,7 +144,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -162,7 +162,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *test `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -180,7 +180,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *te `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -198,7 +198,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *tes `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -216,7 +216,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *t `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -234,7 +234,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *tes `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -252,7 +252,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *t `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -270,7 +270,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b *t `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -288,7 +288,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -306,7 +306,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *test `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -324,7 +324,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *te `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -342,7 +342,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *tes `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -360,7 +360,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *t `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -378,7 +378,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *tes `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -396,7 +396,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *t `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -414,7 +414,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b *t `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -432,7 +432,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -450,7 +450,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *tes `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -468,7 +468,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *t `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -486,7 +486,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b *te `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -504,7 +504,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b * `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -522,7 +522,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b *te `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -540,7 +540,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b * `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -558,7 +558,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b * `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -576,7 +576,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -594,7 +594,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns(b `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -612,7 +612,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns `select n::int4 + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -630,7 +630,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_columns( `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -648,7 +648,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_column `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -666,7 +666,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_columns( `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -684,7 +684,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_column `select n::int4 + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -702,7 +702,7 @@ func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_column `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -720,7 +720,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_colu `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -738,7 +738,7 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b *t `select n::numeric + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -756,7 +756,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b `select n::numeric + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -774,7 +774,7 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b * `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -792,7 +792,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -810,7 +810,7 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b * `select n::numeric + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -828,7 +828,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b `select n::numeric + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -846,7 +846,7 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns(b `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -864,7 +864,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -882,7 +882,7 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns(b `select n::numeric + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -900,7 +900,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns( `select n::numeric + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -918,7 +918,7 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns(b `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -936,7 +936,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -954,7 +954,7 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns(b `select n::numeric + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -972,7 +972,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns `select n::numeric + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -990,7 +990,7 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_100_rows_10_columns `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1008,7 +1008,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_100_rows_10_colum `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1026,7 +1026,7 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_col `select n::numeric + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1044,7 +1044,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_c `select n::numeric + 0 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1062,7 +1062,7 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_co `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1080,7 +1080,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_ `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1098,7 +1098,7 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_co `select n::numeric + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1116,7 +1116,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_ `select n::numeric + 0 from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1134,7 +1134,7 @@ func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_10_ `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1152,7 +1152,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_1 `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1170,7 +1170,7 @@ func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testing `select array_agg(n) from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1188,7 +1188,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testi `select array_agg(n) from generate_series(1, 10) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1206,7 +1206,7 @@ func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testin `select array_agg(n) from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1224,7 +1224,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *test `select array_agg(n) from generate_series(1, 100) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1242,7 +1242,7 @@ func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testi `select array_agg(n) from generate_series(1, 1000) n`, []any{pgx.QueryResultFormats{pgx.TextFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -1260,7 +1260,7 @@ func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *tes `select array_agg(n) from generate_series(1, 1000) n`, []any{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } diff --git a/pgtype/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb index 144d9dd7..0175700a 100644 --- a/pgtype/integration_benchmark_test.go.erb +++ b/pgtype/integration_benchmark_test.go.erb @@ -27,7 +27,7 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, []any{pgx.QueryResultFormats{<%= format_code %>}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, func() error { return nil }) if err != nil { b.Fatal(err) } @@ -51,7 +51,7 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_Int4Array `select array_agg(n) from generate_series(1, <%= array_size %>) n`, []any{pgx.QueryResultFormats{<%= format_code %>}}, ) - _, err := pgx.ForEachScannedRow(rows, []any{&v}, func() error { return nil }) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) if err != nil { b.Fatal(err) } diff --git a/rows.go b/rows.go index 0c630bc4..2afb31df 100644 --- a/rows.go +++ b/rows.go @@ -371,10 +371,10 @@ func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader } } -// ForEachScannedRow iterates through rows. For each row it scans into the elements of scans and calls fn. If any row +// ForEachRow iterates through rows. For each row it scans into the elements of scans and calls fn. If any row // fails to scan or fn returns an error the query will be aborted and the error will be returned. Rows will be closed -// when ForEachScannedRow returns. -func ForEachScannedRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, error) { +// when ForEachRow returns. +func ForEachRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, error) { defer rows.Close() for rows.Next() { diff --git a/rows_test.go b/rows_test.go index 9f07ee2e..8806b1f6 100644 --- a/rows_test.go +++ b/rows_test.go @@ -35,7 +35,7 @@ func TestRowScanner(t *testing.T) { }) } -func TestForEachScannedRow(t *testing.T) { +func TestForEachRow(t *testing.T) { t.Parallel() pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { @@ -47,7 +47,7 @@ func TestForEachScannedRow(t *testing.T) { 3, ) var a, b int - ct, err := pgx.ForEachScannedRow(rows, []any{&a, &b}, func() error { + ct, err := pgx.ForEachRow(rows, []any{&a, &b}, func() error { actualResults = append(actualResults, []any{a, b}) return nil }) @@ -63,7 +63,7 @@ func TestForEachScannedRow(t *testing.T) { }) } -func TestForEachScannedRowScanError(t *testing.T) { +func TestForEachRowScanError(t *testing.T) { t.Parallel() pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { @@ -75,7 +75,7 @@ func TestForEachScannedRowScanError(t *testing.T) { 3, ) var a, b int - ct, err := pgx.ForEachScannedRow(rows, []any{&a, &b}, func() error { + ct, err := pgx.ForEachRow(rows, []any{&a, &b}, func() error { actualResults = append(actualResults, []any{a, b}) return nil }) @@ -84,7 +84,7 @@ func TestForEachScannedRowScanError(t *testing.T) { }) } -func TestForEachScannedRowAbort(t *testing.T) { +func TestForEachRowAbort(t *testing.T) { t.Parallel() pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { @@ -94,7 +94,7 @@ func TestForEachScannedRowAbort(t *testing.T) { 3, ) var a, b int - ct, err := pgx.ForEachScannedRow(rows, []any{&a, &b}, func() error { + ct, err := pgx.ForEachRow(rows, []any{&a, &b}, func() error { return errors.New("abort") }) require.EqualError(t, err, "abort") @@ -102,7 +102,7 @@ func TestForEachScannedRowAbort(t *testing.T) { }) } -func ExampleForEachScannedRow() { +func ExampleForEachRow() { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { fmt.Printf("Unable to establish connection: %v", err) @@ -115,12 +115,12 @@ func ExampleForEachScannedRow() { 3, ) var a, b int - _, err = pgx.ForEachScannedRow(rows, []any{&a, &b}, func() error { + _, err = pgx.ForEachRow(rows, []any{&a, &b}, func() error { fmt.Printf("%v, %v\n", a, b) return nil }) if err != nil { - fmt.Printf("ForEachScannedRow error: %v", err) + fmt.Printf("ForEachRow error: %v", err) return } From 62f034758660c08051586a030f8404390c99efd7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Jul 2022 16:59:29 -0500 Subject: [PATCH 1081/1158] Add CollectOneRow --- CHANGELOG.md | 1 + rows.go | 21 +++++++++++++++++++++ rows_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 44964ac7..4438e0bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -138,6 +138,7 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent ## Rows Result Helpers * `CollectRows` and `RowTo*` functions simplify collecting results into a slice. +* `CollectOneRow` collects one row using `RowTo*` functions. * `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`. ## SendBatch Uses Pipeline Mode When Appropriate diff --git a/rows.go b/rows.go index 2afb31df..90a24d28 100644 --- a/rows.go +++ b/rows.go @@ -428,6 +428,27 @@ func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { return slice, nil } +// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. +// CollectOneRow is to CollectRows as QueryRow is to Query. +func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { + defer rows.Close() + + var value T + var err error + + if !rows.Next() { + return value, ErrNoRows + } + + value, err = fn(rows) + if err != nil { + return value, err + } + + rows.Close() + return value, rows.Err() +} + // RowTo returns a T scanned from row. func RowTo[T any](row CollectableRow) (T, error) { var value T diff --git a/rows_test.go b/rows_test.go index 8806b1f6..0e3bc179 100644 --- a/rows_test.go +++ b/rows_test.go @@ -147,6 +147,47 @@ func TestCollectRows(t *testing.T) { }) } +func TestCollectOneRow(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 42`) + n, err := pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.NoError(t, err) + assert.Equal(t, int32(42), n) + }) +} + +func TestCollectOneRowNotFound(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 42 where false`) + n, err := pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.ErrorIs(t, err, pgx.ErrNoRows) + assert.Equal(t, int32(0), n) + }) +} + +func TestCollectOneRowIgnoresExtraRows(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(42, 99) n`) + n, err := pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + require.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, int32(42), n) + }) +} + func TestRowTo(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) From 31ec18cc650b3fc2320563b9801f42a1b1cef45e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Jul 2022 17:25:55 -0500 Subject: [PATCH 1082/1158] Replace Begin and BeginTx methods with functions --- CHANGELOG.md | 5 ++ doc.go | 4 +- pgxpool/conn.go | 8 --- pgxpool/pool.go | 14 ----- pgxpool/pool_test.go | 10 ++-- pgxpool/tx.go | 4 -- tx.go | 140 +++++++++++++++++++------------------------ tx_test.go | 14 ++--- 8 files changed, 82 insertions(+), 117 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4438e0bc..4e97434f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -141,6 +141,11 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent * `CollectOneRow` collects one row using `RowTo*` functions. * `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`. +## Tx Helpers + +Rather than every type that implemented `Begin` or `BeginTx` methods also needing to implement `BeginFunc` and +`BeginTxFunc` these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`. + ## SendBatch Uses Pipeline Mode When Appropriate Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1 diff --git a/doc.go b/doc.go index 48971110..0fd3713b 100644 --- a/doc.go +++ b/doc.go @@ -247,10 +247,10 @@ These are internally implemented with savepoints. Use BeginTx to control the transaction mode. -BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or rollback the +BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the transaction depending on the return value of the function. These can be simpler and less error prone to use. - err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") return err }) diff --git a/pgxpool/conn.go b/pgxpool/conn.go index b8711da9..b9ff29dc 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -92,14 +92,6 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er return c.Conn().BeginTx(ctx, txOptions) } -func (c *Conn) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { - return c.Conn().BeginFunc(ctx, f) -} - -func (c *Conn) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error { - return c.Conn().BeginTxFunc(ctx, txOptions, f) -} - func (c *Conn) Ping(ctx context.Context) error { return c.Conn().Ping(ctx) } diff --git a/pgxpool/pool.go b/pgxpool/pool.go index d73b93fb..7027e282 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -570,20 +570,6 @@ func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er return &Tx{t: t, c: c}, err } -func (p *Pool) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { - return p.BeginTxFunc(ctx, pgx.TxOptions{}, f) -} - -func (p *Pool) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error { - c, err := p.Acquire(ctx) - if err != nil { - return err - } - defer c.Release() - - return c.BeginTxFunc(ctx, txOptions, f) -} - func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { c, err := p.Acquire(ctx) if err != nil { diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 3e3058d2..5cd943d7 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -806,15 +806,15 @@ func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { db.Exec(context.Background(), "drop table pgxpooltx") }() - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)") require.NoError(t, err) return nil @@ -853,11 +853,11 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { db.Exec(context.Background(), "drop table pgxpooltx") }() - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)") require.NoError(t, err) return errors.New("do a rollback") diff --git a/pgxpool/tx.go b/pgxpool/tx.go index 3ddb742c..74df8593 100644 --- a/pgxpool/tx.go +++ b/pgxpool/tx.go @@ -18,10 +18,6 @@ func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) { return tx.t.Begin(ctx) } -func (tx *Tx) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { - return tx.t.BeginFunc(ctx, f) -} - // Commit commits the transaction and returns the associated connection back to the Pool. Commit will return ErrTxClosed // if the Tx is already closed, but is otherwise safe to call multiple times. If the commit fails with a rollback status // (e.g. the transaction was already in a broken state) then ErrTxCommitRollback will be returned. diff --git a/tx.go b/tx.go index 2a05b70d..24daf0f8 100644 --- a/tx.go +++ b/tx.go @@ -94,39 +94,6 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) { return &dbTx{conn: c}, nil } -// BeginFunc starts a transaction and calls f. If f does not return an error the transaction is committed. If f returns -// an error the transaction is rolled back. The context will be used when executing the transaction control statements -// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of f. -func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { - return c.BeginTxFunc(ctx, TxOptions{}, f) -} - -// BeginTxFunc starts a transaction with txOptions determining the transaction mode and calls f. If f does not return -// an error the transaction is committed. If f returns an error the transaction is rolled back. The context will be -// used when executing the transaction control statements (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect -// the execution of f. -func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) { - var tx Tx - tx, err = c.BeginTx(ctx, txOptions) - if err != nil { - return err - } - defer func() { - rollbackErr := tx.Rollback(ctx) - if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) { - err = rollbackErr - } - }() - - fErr := f(tx) - if fErr != nil { - _ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return - return fErr - } - - return tx.Commit(ctx) -} - // Tx represents a database transaction. // // Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx @@ -138,20 +105,17 @@ type Tx interface { // Begin starts a pseudo nested transaction. Begin(ctx context.Context) (Tx, error) - // BeginFunc starts a pseudo nested transaction and executes f. If f does not return an err the pseudo nested - // transaction will be committed. If it does then it will be rolled back. - BeginFunc(ctx context.Context, f func(Tx) error) (err error) - // Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested - // transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple - // times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then - // ErrTxCommitRollback will be returned. + // transaction. Commit will return an error where errors.Is(ErrTxClosed) is true if the Tx is already closed, but is + // otherwise safe to call multiple times. If the commit fails with a rollback status (e.g. the transaction was already + // in a broken state) then an error where errors.Is(ErrTxCommitRollback) is true will be returned. Commit(ctx context.Context) error // Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a - // pseudo nested transaction. Rollback will return ErrTxClosed if the Tx is already closed, but is otherwise safe to - // call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will be called first in a non-error - // condition. Any other failure of a real transaction will result in the connection being closed. + // pseudo nested transaction. Rollback will return an error where errors.Is(ErrTxClosed) is true if the Tx is already + // closed, but is otherwise safe to call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will + // be called first in a non-error condition. Any other failure of a real transaction will result in the connection + // being closed. Rollback(ctx context.Context) error CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) @@ -194,32 +158,6 @@ func (tx *dbTx) Begin(ctx context.Context) (Tx, error) { return &dbSimulatedNestedTx{tx: tx, savepointNum: tx.savepointNum}, nil } -func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { - if tx.closed { - return ErrTxClosed - } - - var savepoint Tx - savepoint, err = tx.Begin(ctx) - if err != nil { - return err - } - defer func() { - rollbackErr := savepoint.Rollback(ctx) - if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) { - err = rollbackErr - } - }() - - fErr := f(savepoint) - if fErr != nil { - _ = savepoint.Rollback(ctx) // ignore rollback error as there is already an error to return - return fErr - } - - return savepoint.Commit(ctx) -} - // Commit commits the transaction. func (tx *dbTx) Commit(ctx context.Context) error { if tx.closed { @@ -335,14 +273,6 @@ func (sp *dbSimulatedNestedTx) Begin(ctx context.Context) (Tx, error) { return sp.tx.Begin(ctx) } -func (sp *dbSimulatedNestedTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { - if sp.closed { - return ErrTxClosed - } - - return sp.tx.BeginFunc(ctx, f) -} - // Commit releases the savepoint essentially committing the pseudo nested transaction. func (sp *dbSimulatedNestedTx) Commit(ctx context.Context) error { if sp.closed { @@ -427,3 +357,59 @@ func (sp *dbSimulatedNestedTx) LargeObjects() LargeObjects { func (sp *dbSimulatedNestedTx) Conn() *Conn { return sp.tx.Conn() } + +// BeginFunc calls Begin on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn +// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements +// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn. +func BeginFunc( + ctx context.Context, + db interface { + Begin(ctx context.Context) (Tx, error) + }, + fn func(Tx) error, +) (err error) { + var tx Tx + tx, err = db.Begin(ctx) + if err != nil { + return err + } + + return beginFuncExec(ctx, tx, fn) +} + +// BeginTxFunc calls BeginTx on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn +// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements +// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn. +func BeginTxFunc( + ctx context.Context, + db interface { + BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) + }, + txOptions TxOptions, + fn func(Tx) error, +) (err error) { + var tx Tx + tx, err = db.BeginTx(ctx, txOptions) + if err != nil { + return err + } + + return beginFuncExec(ctx, tx, fn) +} + +func beginFuncExec(ctx context.Context, tx Tx, fn func(Tx) error) (err error) { + defer func() { + rollbackErr := tx.Rollback(ctx) + if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) { + err = rollbackErr + } + }() + + fErr := fn(tx) + if fErr != nil { + _ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return + return fErr + } + + return tx.Commit(ctx) +} diff --git a/tx_test.go b/tx_test.go index d45553a2..9c1c70d3 100644 --- a/tx_test.go +++ b/tx_test.go @@ -312,7 +312,7 @@ func TestBeginFunc(t *testing.T) { _, err := conn.Exec(context.Background(), createSql) require.NoError(t, err) - err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) return nil @@ -341,7 +341,7 @@ func TestBeginFuncRollbackOnError(t *testing.T) { _, err := conn.Exec(context.Background(), createSql) require.NoError(t, err) - err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) return errors.New("some error") @@ -522,15 +522,15 @@ func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { _, err := db.Exec(context.Background(), createSql) require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (3)") require.NoError(t, err) return nil @@ -565,11 +565,11 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { _, err := db.Exec(context.Background(), createSql) require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") require.NoError(t, err) return errors.New("do a rollback") From 80a529fcb74c0ebdbdf77d92c2aeb95ab9373a57 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Jul 2022 17:48:46 -0500 Subject: [PATCH 1083/1158] Test LoadType disambiguate name by schema --- conn_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/conn_test.go b/conn_test.go index 596b0d18..93e0cd84 100644 --- a/conn_test.go +++ b/conn_test.go @@ -906,6 +906,46 @@ func TestDomainType(t *testing.T) { }) } +func TestLoadTypeSameNameInDifferentSchemas(t *testing.T) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Exec(ctx, `create schema pgx_a; +create type pgx_a.point as (a text, b text); +create schema pgx_b; +create type pgx_b.point as (c text); +`) + require.NoError(t, err) + + // Register types + for _, typename := range []string{"pgx_a.point", "pgx_b.point"} { + // Obviously using conn while a tx is in use and registering a type after the connection has been established are + // really bad practices, but for the sake of convenience we do it in the test here. + dt, err := conn.LoadType(ctx, typename) + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) + } + + type aPoint struct { + A string + B string + } + + type bPoint struct { + C string + } + + var a aPoint + var b bPoint + err = tx.QueryRow(ctx, `select '(foo,bar)'::pgx_a.point, '(baz)'::pgx_b.point`).Scan(&a, &b) + require.NoError(t, err) + require.Equal(t, aPoint{"foo", "bar"}, a) + require.Equal(t, bPoint{"baz"}, b) + }) +} + func TestStmtCacheInvalidationConn(t *testing.T) { ctx := context.Background() From 731daea58638c2fec840d5be348b8d0477758c61 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Jul 2022 21:08:15 -0500 Subject: [PATCH 1084/1158] Skip test on CockroachDB --- conn_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/conn_test.go b/conn_test.go index 93e0cd84..6d7f8434 100644 --- a/conn_test.go +++ b/conn_test.go @@ -908,6 +908,8 @@ func TestDomainType(t *testing.T) { func TestLoadTypeSameNameInDifferentSchemas(t *testing.T) { pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does support composite types (https://github.com/cockroachdb/cockroach/issues/27792)") + tx, err := conn.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) From b662ab67677b11e1b31c4c159574ff57c8a65484 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Jul 2022 21:26:38 -0500 Subject: [PATCH 1085/1158] Better encode error message --- pgtype/pgtype.go | 26 +++++++++++++++++++++++--- query_test.go | 2 +- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 636b0954..50483fa7 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1816,6 +1816,27 @@ func (plan *wrapMultiDimSliceEncodePlan) Encode(value any, buf []byte) (newBuf [ return plan.next.Encode(&w, buf) } +func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error) error { + var format string + switch formatCode { + case TextFormatCode: + format = "text" + case BinaryFormatCode: + format = "binary" + default: + format = fmt.Sprintf("unknown (%d)", formatCode) + } + + var dataTypeName string + if t, ok := m.oidToType[oid]; ok { + dataTypeName = t.Name + } else { + dataTypeName = "unknown type" + } + + return fmt.Errorf("unable to encode %#v into %s format for %s (OID %d): %s", value, format, dataTypeName, oid, err) +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. @@ -1837,13 +1858,12 @@ func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBu return m.Encode(oid, formatCode, v, buf) } - return nil, fmt.Errorf("unable to encode %#v into format code %d for OID %d", value, formatCode, oid) + return nil, newEncodeError(value, m, oid, formatCode, errors.New("cannot find encode plan")) } newBuf, err = plan.Encode(value, buf) if err != nil { - err = fmt.Errorf("unable to encode %#v into format code %d for OID %d: %v", value, formatCode, oid, err) - return nil, err + return nil, newEncodeError(value, m, oid, formatCode, err) } return newBuf, nil diff --git a/query_test.go b/query_test.go index 0c8b5fab..dad93eaf 100644 --- a/query_test.go +++ b/query_test.go @@ -984,7 +984,7 @@ func TestQueryRowErrors(t *testing.T) { {"select $1::badtype", []any{"Jack"}, []any{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []any{}, []any{&actual.i16}, "SQLSTATE 42601"}, {"select $1::text", []any{"Jack"}, []any{&actual.i16}, "cannot scan OID 25 in text format into *int16"}, - {"select $1::point", []any{int(705)}, []any{&actual.s}, "unable to encode 705 into format code 1 for OID 600"}, + {"select $1::point", []any{int(705)}, []any{&actual.s}, "unable to encode 705 into binary format for point (OID 600)"}, } for i, tt := range tests { From 7974a102fc711b29537a01dc29330373c1f00364 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Jul 2022 21:47:39 -0500 Subject: [PATCH 1086/1158] Improve Scan error messages --- pgtype/pgtype.go | 12 ++++++++++-- query_test.go | 8 ++++---- rows_test.go | 2 +- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 50483fa7..904c99fa 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -582,6 +582,7 @@ func (scanPlanAnyTextToBytes) Scan(src []byte, dst any) error { } type scanPlanFail struct { + m *Map oid uint32 formatCode int16 } @@ -597,7 +598,14 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error { format = fmt.Sprintf("unknown %d", plan.formatCode) } - return fmt.Errorf("cannot scan OID %v in %v format into %T", plan.oid, format, dst) + var dataTypeName string + if t, ok := plan.m.oidToType[plan.oid]; ok { + dataTypeName = t.Name + } else { + dataTypeName = "unknown type" + } + + return fmt.Errorf("cannot scan %s (OID %d) in %v format into %T", dataTypeName, plan.oid, format, dst) } // TryWrapScanPlanFunc is a function that tries to create a wrapper plan for target. If successful it returns a plan @@ -1201,7 +1209,7 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { return &scanPlanSQLScanner{formatCode: formatCode} } - return &scanPlanFail{oid: oid, formatCode: formatCode} + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} } func (m *Map) Scan(oid uint32, formatCode int16, src []byte, dst any) error { diff --git a/query_test.go b/query_test.go index dad93eaf..59cf9355 100644 --- a/query_test.go +++ b/query_test.go @@ -274,8 +274,8 @@ func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) var s string err := conn.QueryRow(context.Background(), "select point(1,2)").Scan(&s) - if err == nil || !(strings.Contains(err.Error(), "cannot scan OID 600 in binary format into *string")) { - t.Fatalf("Expected Scan to fail to encode binary value into string but: %v", err) + if err == nil || !(strings.Contains(err.Error(), "cannot scan point (OID 600) in binary format into *string")) { + t.Fatalf("Expected Scan to fail to scan binary value into string but: %v", err) } ensureConnValid(t, conn) @@ -397,7 +397,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { t.Fatal("Expected Rows to have an error after an improper read but it didn't") } - if rows.Err().Error() != "can't scan into dest[0]: cannot scan OID 23 in binary format into *time.Time" { + if rows.Err().Error() != "can't scan into dest[0]: cannot scan int4 (OID 23) in binary format into *time.Time" { t.Fatalf("Expected different Rows.Err(): %v", rows.Err()) } @@ -983,7 +983,7 @@ func TestQueryRowErrors(t *testing.T) { }{ {"select $1::badtype", []any{"Jack"}, []any{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []any{}, []any{&actual.i16}, "SQLSTATE 42601"}, - {"select $1::text", []any{"Jack"}, []any{&actual.i16}, "cannot scan OID 25 in text format into *int16"}, + {"select $1::text", []any{"Jack"}, []any{&actual.i16}, "cannot scan text (OID 25) in text format into *int16"}, {"select $1::point", []any{int(705)}, []any{&actual.s}, "unable to encode 705 into binary format for point (OID 600)"}, } diff --git a/rows_test.go b/rows_test.go index 0e3bc179..cbc26887 100644 --- a/rows_test.go +++ b/rows_test.go @@ -79,7 +79,7 @@ func TestForEachRowScanError(t *testing.T) { actualResults = append(actualResults, []any{a, b}) return nil }) - require.EqualError(t, err, "can't scan into dest[0]: cannot scan OID 25 in text format into *int") + require.EqualError(t, err, "can't scan into dest[0]: cannot scan text (OID 25) in text format into *int") require.Equal(t, pgconn.CommandTag{}, ct) }) } From e7eb8a3250fa1d6d4c9c57fc896e62b0d44247b0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 10 Jul 2022 14:29:44 -0500 Subject: [PATCH 1087/1158] Use netip package for representing inet and cidr types --- CHANGELOG.md | 1 + README.md | 2 +- doc.go | 3 +- pgtype/builtin_wrappers.go | 81 +++++++++++++++---- pgtype/inet.go | 162 ++++++++++++++----------------------- pgtype/inet_test.go | 52 +++++++++++- pgtype/pgtype.go | 52 +++++++++++- pgtype/pgtype_test.go | 2 + 8 files changed, 231 insertions(+), 124 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e97434f..5e605f68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -102,6 +102,7 @@ This matches the convention set by `database/sql`. In addition, for comparable t * `Hstore` is now defined as `map[string]*string`. * `JSON` and `JSONB` types removed. Use `[]byte` or `string` directly. * `QChar` type removed. Use `rune` or `byte` directly. +* `Inet` and `Cidr` types removed. Use `netip.Addr` and `netip.Prefix` directly. These types are more memory efficient than the previous `net.IPNet`. * `Macaddr` type removed. Use `net.HardwareAddr` directly. * Renamed `pgtype.ConnInfo` to `pgtype.Map`. * Renamed `pgtype.DataType` to `pgtype.Type`. diff --git a/README.md b/README.md index 3cc5a91b..c4f7239f 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ pgx supports many features beyond what is available through `database/sql`: * Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings * Hstore support * JSON and JSONB support -* Maps `inet` and `cidr` PostgreSQL types to `net.IPNet` and `net.IP` +* Maps `inet` and `cidr` PostgreSQL types to `netip.Addr` and `netip.Prefix` * Large object support * NULL mapping to Null* struct or pointer to pointer * Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types diff --git a/doc.go b/doc.go index 0fd3713b..cfc2af85 100644 --- a/doc.go +++ b/doc.go @@ -152,8 +152,7 @@ pgx includes built-in support to marshal and unmarshal between Go types and the Inet and CIDR Mapping -pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In addition, as a convenience pgx will encode -from a net.IP; it will assume a /32 netmask for IPv4 and a /128 for IPv6. +pgx converts netip.Prefix and netip.Addr to and from inet and cidr PostgreSQL types. Custom Type Support diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index da9cf0bb..7a992b09 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -1,9 +1,11 @@ package pgtype import ( + "errors" "fmt" "math" "net" + "net/netip" "reflect" "time" ) @@ -460,44 +462,95 @@ func (w durationWrapper) IntervalValue() (Interval, error) { type netIPNetWrapper net.IPNet -func (w *netIPNetWrapper) ScanInet(v Inet) error { - if !v.Valid { +func (w *netIPNetWrapper) ScanNetipPrefix(v netip.Prefix) error { + if !v.IsValid() { return fmt.Errorf("cannot scan NULL into *net.IPNet") } - *w = (netIPNetWrapper)(*v.IPNet) + *w = netIPNetWrapper{ + IP: v.Addr().AsSlice(), + Mask: net.CIDRMask(v.Bits(), v.Addr().BitLen()), + } + return nil } +func (w netIPNetWrapper) NetipPrefixValue() (netip.Prefix, error) { + ip, ok := netip.AddrFromSlice(w.IP) + if !ok { + return netip.Prefix{}, errors.New("invalid net.IPNet") + } -func (w netIPNetWrapper) InetValue() (Inet, error) { - return Inet{IPNet: (*net.IPNet)(&w), Valid: true}, nil + ones, _ := w.Mask.Size() + + return netip.PrefixFrom(ip, ones), nil } type netIPWrapper net.IP func (w netIPWrapper) SkipUnderlyingTypePlan() {} -func (w *netIPWrapper) ScanInet(v Inet) error { - if !v.Valid { +func (w *netIPWrapper) ScanNetipPrefix(v netip.Prefix) error { + if !v.IsValid() { *w = nil return nil } - if oneCount, bitCount := v.IPNet.Mask.Size(); oneCount != bitCount { + if v.Addr().BitLen() != v.Bits() { return fmt.Errorf("cannot scan %v to *net.IP", v) } - *w = netIPWrapper(v.IPNet.IP) + + *w = netIPWrapper(v.Addr().AsSlice()) return nil } -func (w netIPWrapper) InetValue() (Inet, error) { +func (w netIPWrapper) NetipPrefixValue() (netip.Prefix, error) { if w == nil { - return Inet{}, nil + return netip.Prefix{}, nil } - bitCount := len(w) * 8 - mask := net.CIDRMask(bitCount, bitCount) - return Inet{IPNet: &net.IPNet{Mask: mask, IP: net.IP(w)}, Valid: true}, nil + addr, ok := netip.AddrFromSlice([]byte(w)) + if !ok { + return netip.Prefix{}, errors.New("invalid net.IP") + } + + return netip.PrefixFrom(addr, addr.BitLen()), nil +} + +type netipPrefixWrapper netip.Prefix + +func (w *netipPrefixWrapper) ScanNetipPrefix(v netip.Prefix) error { + *w = netipPrefixWrapper(v) + return nil +} + +func (w netipPrefixWrapper) NetipPrefixValue() (netip.Prefix, error) { + return netip.Prefix(w), nil +} + +type netipAddrWrapper netip.Addr + +func (w *netipAddrWrapper) ScanNetipPrefix(v netip.Prefix) error { + if !v.IsValid() { + *w = netipAddrWrapper(netip.Addr{}) + return nil + } + + if v.Addr().BitLen() != v.Bits() { + return fmt.Errorf("cannot scan %v to netip.Addr", v) + } + + *w = netipAddrWrapper(v.Addr()) + + return nil +} + +func (w netipAddrWrapper) NetipPrefixValue() (netip.Prefix, error) { + addr := (netip.Addr)(w) + if !addr.IsValid() { + return netip.Prefix{}, nil + } + + return netip.PrefixFrom(addr, addr.BitLen()), nil } type mapStringToPointerStringWrapper map[string]*string diff --git a/pgtype/inet.go b/pgtype/inet.go index f8abeef8..f094ed2f 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -1,9 +1,11 @@ package pgtype import ( + "bytes" "database/sql/driver" + "errors" "fmt" - "net" + "net/netip" ) // Network address family is dependent on server socket.h value for AF_INET. @@ -14,57 +16,16 @@ const ( defaultAFInet6 = 3 ) -type InetScanner interface { - ScanInet(v Inet) error +type NetipPrefixScanner interface { + ScanNetipPrefix(v netip.Prefix) error } -type InetValuer interface { - InetValue() (Inet, error) -} - -// Inet represents both inet and cidr PostgreSQL types. -type Inet struct { - IPNet *net.IPNet - Valid bool -} - -func (inet *Inet) ScanInet(v Inet) error { - *inet = v - return nil -} - -func (inet Inet) InetValue() (Inet, error) { - return inet, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Inet) Scan(src any) error { - if src == nil { - *dst = Inet{} - return nil - } - - switch src := src.(type) { - case string: - return scanPlanTextAnyToInetScanner{}.Scan([]byte(src), dst) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Inet) Value() (driver.Value, error) { - if !src.Valid { - return nil, nil - } - - buf, err := InetCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) - if err != nil { - return nil, err - } - return string(buf), err +type NetipPrefixValuer interface { + NetipPrefixValue() (netip.Prefix, error) } +// InetCodec handles both inet and cidr PostgreSQL types. The preferred Go types are netip.Prefix and netip.Addr. If +// IsValid() is false then they are treated as SQL NULL. type InetCodec struct{} func (InetCodec) FormatSupported(format int16) bool { @@ -76,7 +37,7 @@ func (InetCodec) PreferredFormat() int16 { } func (InetCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { - if _, ok := value.(InetValuer); !ok { + if _, ok := value.(NetipPrefixValuer); !ok { return nil } @@ -93,51 +54,56 @@ func (InetCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodeP type encodePlanInetCodecBinary struct{} func (encodePlanInetCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { - inet, err := value.(InetValuer).InetValue() + prefix, err := value.(NetipPrefixValuer).NetipPrefixValue() if err != nil { return nil, err } - if !inet.Valid { + if !prefix.IsValid() { return nil, nil } var family byte - switch len(inet.IPNet.IP) { - case net.IPv4len: + if prefix.Addr().Is4() { family = defaultAFInet - case net.IPv6len: + } else { family = defaultAFInet6 - default: - return nil, fmt.Errorf("Unexpected IP length: %v", len(inet.IPNet.IP)) } buf = append(buf, family) - ones, _ := inet.IPNet.Mask.Size() + ones := prefix.Bits() buf = append(buf, byte(ones)) // is_cidr is ignored on server buf = append(buf, 0) - buf = append(buf, byte(len(inet.IPNet.IP))) + if family == defaultAFInet { + buf = append(buf, byte(4)) + b := prefix.Addr().As4() + buf = append(buf, b[:]...) + } else { + buf = append(buf, byte(16)) + b := prefix.Addr().As16() + buf = append(buf, b[:]...) + } - return append(buf, inet.IPNet.IP...), nil + return buf, nil } type encodePlanInetCodecText struct{} func (encodePlanInetCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { - inet, err := value.(InetValuer).InetValue() + prefix, err := value.(NetipPrefixValuer).NetipPrefixValue() if err != nil { return nil, err } - if !inet.Valid { + if !prefix.IsValid() { return nil, nil } - return append(buf, inet.IPNet.String()...), nil + return append(buf, prefix.String()...), nil } func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { @@ -145,13 +111,13 @@ func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan switch format { case BinaryFormatCode: switch target.(type) { - case InetScanner: - return scanPlanBinaryInetToInetScanner{} + case NetipPrefixScanner: + return scanPlanBinaryInetToNetipPrefixScanner{} } case TextFormatCode: switch target.(type) { - case InetScanner: - return scanPlanTextAnyToInetScanner{} + case NetipPrefixScanner: + return scanPlanTextAnyToNetipPrefixScanner{} } } @@ -167,26 +133,26 @@ func (c InetCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (an return nil, nil } - var inet Inet - err := codecScan(c, m, oid, format, src, &inet) + var prefix netip.Prefix + err := codecScan(c, m, oid, format, src, &prefix) if err != nil { return nil, err } - if !inet.Valid { + if !prefix.IsValid() { return nil, nil } - return inet.IPNet, nil + return prefix, nil } -type scanPlanBinaryInetToInetScanner struct{} +type scanPlanBinaryInetToNetipPrefixScanner struct{} -func (scanPlanBinaryInetToInetScanner) Scan(src []byte, dst any) error { - scanner := (dst).(InetScanner) +func (scanPlanBinaryInetToNetipPrefixScanner) Scan(src []byte, dst any) error { + scanner := (dst).(NetipPrefixScanner) if src == nil { - return scanner.ScanInet(Inet{}) + return scanner.ScanNetipPrefix(netip.Prefix{}) } if len(src) != 8 && len(src) != 20 { @@ -196,49 +162,39 @@ func (scanPlanBinaryInetToInetScanner) Scan(src []byte, dst any) error { // ignore family bits := src[1] // ignore is_cidr - addressLength := src[3] + // ignore addressLength - implicit in length of message - var ipnet net.IPNet - ipnet.IP = make(net.IP, int(addressLength)) - copy(ipnet.IP, src[4:]) - if ipv4 := ipnet.IP.To4(); ipv4 != nil { - ipnet.IP = ipv4 + addr, ok := netip.AddrFromSlice(src[4:]) + if !ok { + return errors.New("netip.AddrFromSlice failed") } - ipnet.Mask = net.CIDRMask(int(bits), len(ipnet.IP)*8) - return scanner.ScanInet(Inet{IPNet: &ipnet, Valid: true}) + return scanner.ScanNetipPrefix(netip.PrefixFrom(addr, int(bits))) } -type scanPlanTextAnyToInetScanner struct{} +type scanPlanTextAnyToNetipPrefixScanner struct{} -func (scanPlanTextAnyToInetScanner) Scan(src []byte, dst any) error { - scanner := (dst).(InetScanner) +func (scanPlanTextAnyToNetipPrefixScanner) Scan(src []byte, dst any) error { + scanner := (dst).(NetipPrefixScanner) if src == nil { - return scanner.ScanInet(Inet{}) + return scanner.ScanNetipPrefix(netip.Prefix{}) } - var ipnet *net.IPNet - var err error - - if ip := net.ParseIP(string(src)); ip != nil { - if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 - } - bitCount := len(ip) * 8 - mask := net.CIDRMask(bitCount, bitCount) - ipnet = &net.IPNet{Mask: mask, IP: ip} - } else { - ip, ipnet, err = net.ParseCIDR(string(src)) + var prefix netip.Prefix + if bytes.IndexByte(src, '/') == -1 { + addr, err := netip.ParseAddr(string(src)) if err != nil { return err } - if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 + prefix = netip.PrefixFrom(addr, addr.BitLen()) + } else { + var err error + prefix, err = netip.ParsePrefix(string(src)) + if err != nil { + return err } - ones, _ := ipnet.Mask.Size() - *ipnet = net.IPNet{IP: ip, Mask: net.CIDRMask(ones, len(ip)*8)} } - return scanner.ScanInet(Inet{IPNet: ipnet, Valid: true}) + return scanner.ScanNetipPrefix(prefix) } diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index 0a174e1a..f4b43daf 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -3,9 +3,9 @@ package pgtype_test import ( "context" "net" + "net/netip" "testing" - "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxtest" ) @@ -31,7 +31,42 @@ func TestInetTranscode(t *testing.T) { {mustParseInet(t, "::/0"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/0"))}, {mustParseInet(t, "::1/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/128"))}, {mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e/64"))}, - {nil, new(pgtype.Inet), isExpectedEq(pgtype.Inet{})}, + + {mustParseInet(t, "0.0.0.0/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("0.0.0.0/32"))}, + {mustParseInet(t, "127.0.0.1/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("127.0.0.1/8"))}, + {mustParseInet(t, "12.34.56.65/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("12.34.56.65/32"))}, + {mustParseInet(t, "192.168.1.16/24"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("192.168.1.16/24"))}, + {mustParseInet(t, "255.0.0.0/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.0.0.0/8"))}, + {mustParseInet(t, "255.255.255.255/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.255.255.255/32"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("2607:f8b0:4009:80b::200e/128"))}, + {mustParseInet(t, "::1/64"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/64"))}, + {mustParseInet(t, "::/0"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::/0"))}, + {mustParseInet(t, "::1/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/128"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("2607:f8b0:4009:80b::200e/64"))}, + + {netip.MustParsePrefix("0.0.0.0/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("0.0.0.0/32"))}, + {netip.MustParsePrefix("127.0.0.1/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("127.0.0.1/8"))}, + {netip.MustParsePrefix("12.34.56.65/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("12.34.56.65/32"))}, + {netip.MustParsePrefix("192.168.1.16/24"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("192.168.1.16/24"))}, + {netip.MustParsePrefix("255.0.0.0/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.0.0.0/8"))}, + {netip.MustParsePrefix("255.255.255.255/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.255.255.255/32"))}, + {netip.MustParsePrefix("::1/64"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/64"))}, + {netip.MustParsePrefix("::/0"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::/0"))}, + {netip.MustParsePrefix("::1/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/128"))}, + {netip.MustParsePrefix("2607:f8b0:4009:80b::200e/64"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("2607:f8b0:4009:80b::200e/64"))}, + + {netip.MustParseAddr("0.0.0.0"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("0.0.0.0"))}, + {netip.MustParseAddr("127.0.0.1"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("127.0.0.1"))}, + {netip.MustParseAddr("12.34.56.65"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("12.34.56.65"))}, + {netip.MustParseAddr("192.168.1.16"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("192.168.1.16"))}, + {netip.MustParseAddr("255.0.0.0"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("255.0.0.0"))}, + {netip.MustParseAddr("255.255.255.255"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("255.255.255.255"))}, + {netip.MustParseAddr("2607:f8b0:4009:80b::200e"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("2607:f8b0:4009:80b::200e"))}, + {netip.MustParseAddr("::1"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("::1"))}, + {netip.MustParseAddr("::"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("::"))}, + {netip.MustParseAddr("2607:f8b0:4009:80b::200e"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("2607:f8b0:4009:80b::200e"))}, + + {nil, new(netip.Prefix), isExpectedEq(netip.Prefix{})}, }) } @@ -48,6 +83,17 @@ func TestCidrTranscode(t *testing.T) { {mustParseInet(t, "::/0"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/0"))}, {mustParseInet(t, "::1/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/128"))}, {mustParseInet(t, "2607:f8b0:4009:80b::200e/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e/128"))}, - {nil, new(pgtype.Inet), isExpectedEq(pgtype.Inet{})}, + + {netip.MustParsePrefix("0.0.0.0/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("0.0.0.0/32"))}, + {netip.MustParsePrefix("127.0.0.1/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("127.0.0.1/32"))}, + {netip.MustParsePrefix("12.34.56.0/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("12.34.56.0/32"))}, + {netip.MustParsePrefix("192.168.1.0/24"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("192.168.1.0/24"))}, + {netip.MustParsePrefix("255.0.0.0/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.0.0.0/8"))}, + {netip.MustParsePrefix("::/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::/128"))}, + {netip.MustParsePrefix("::/0"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::/0"))}, + {netip.MustParsePrefix("::1/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/128"))}, + {netip.MustParsePrefix("2607:f8b0:4009:80b::200e/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("2607:f8b0:4009:80b::200e/128"))}, + + {nil, new(netip.Prefix), isExpectedEq(netip.Prefix{})}, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 904c99fa..160f18af 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "net/netip" "reflect" "time" ) @@ -364,6 +365,8 @@ func NewMap() *Map { registerDefaultPgTypeVariants[net.IP](m, "inet") registerDefaultPgTypeVariants[net.IPNet](m, "cidr") + registerDefaultPgTypeVariants[netip.Addr](m, "inet") + registerDefaultPgTypeVariants[netip.Prefix](m, "cidr") // pgtype provided structs registerDefaultPgTypeVariants[Bits](m, "varbit") @@ -377,7 +380,6 @@ func NewMap() *Map { registerDefaultPgTypeVariants[Float8](m, "float8") registerDefaultPgTypeVariants[Range[Float8]](m, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange. registerDefaultPgTypeVariants[Multirange[Range[Float8]]](m, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange. - registerDefaultPgTypeVariants[Inet](m, "inet") registerDefaultPgTypeVariants[Int2](m, "int2") registerDefaultPgTypeVariants[Int4](m, "int4") registerDefaultPgTypeVariants[Range[Int4]](m, "int4range") @@ -751,6 +753,10 @@ func TryWrapBuiltinTypeScanPlan(target any) (plan WrappedScanPlanNextSetter, nex return &wrapNetIPNetScanPlan{}, (*netIPNetWrapper)(target), true case *net.IP: return &wrapNetIPScanPlan{}, (*netIPWrapper)(target), true + case *netip.Prefix: + return &wrapNetipPrefixScanPlan{}, (*netipPrefixWrapper)(target), true + case *netip.Addr: + return &wrapNetipAddrScanPlan{}, (*netipAddrWrapper)(target), true case *map[string]*string: return &wrapMapStringToPointerStringScanPlan{}, (*mapStringToPointerStringWrapper)(target), true case *map[string]string: @@ -934,6 +940,26 @@ func (plan *wrapNetIPScanPlan) Scan(src []byte, dst any) error { return plan.next.Scan(src, (*netIPWrapper)(dst.(*net.IP))) } +type wrapNetipPrefixScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetipPrefixScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetipPrefixScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*netipPrefixWrapper)(dst.(*netip.Prefix))) +} + +type wrapNetipAddrScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetipAddrScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetipAddrScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*netipAddrWrapper)(dst.(*netip.Addr))) +} + type wrapMapStringToPointerStringScanPlan struct { next ScanPlan } @@ -1454,6 +1480,10 @@ func TryWrapBuiltinTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter, return &wrapNetIPNetEncodePlan{}, netIPNetWrapper(value), true case net.IP: return &wrapNetIPEncodePlan{}, netIPWrapper(value), true + case netip.Prefix: + return &wrapNetipPrefixEncodePlan{}, netipPrefixWrapper(value), true + case netip.Addr: + return &wrapNetipAddrEncodePlan{}, netipAddrWrapper(value), true case map[string]*string: return &wrapMapStringToPointerStringEncodePlan{}, mapStringToPointerStringWrapper(value), true case map[string]string: @@ -1639,6 +1669,26 @@ func (plan *wrapNetIPEncodePlan) Encode(value any, buf []byte) (newBuf []byte, e return plan.next.Encode(netIPWrapper(value.(net.IP)), buf) } +type wrapNetipPrefixEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetipPrefixEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetipPrefixEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netipPrefixWrapper(value.(netip.Prefix)), buf) +} + +type wrapNetipAddrEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetipAddrEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetipAddrEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netipAddrWrapper(value.(netip.Addr)), buf) +} + type wrapMapStringToPointerStringEncodePlan struct { next EncodePlan } diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index fa4c823b..e2465afc 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -55,6 +55,8 @@ func mustParseInet(t testing.TB, s string) *net.IPNet { if err == nil { if ipv4 := ip.To4(); ipv4 != nil { ipnet.IP = ipv4 + } else { + ipnet.IP = ip } return ipnet } From ca41a6a22222fd0c17e51f8caa2176147893c11c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 10 Jul 2022 14:32:08 -0500 Subject: [PATCH 1088/1158] Update docs --- internal/stmtcache/stmtcache.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/stmtcache/stmtcache.go b/internal/stmtcache/stmtcache.go index f975273e..e1bdcba5 100644 --- a/internal/stmtcache/stmtcache.go +++ b/internal/stmtcache/stmtcache.go @@ -1,4 +1,4 @@ -// Package stmtcache is a cache that can be used to implement lazy prepared statements. +// Package stmtcache is a cache for statement descriptions. package stmtcache import ( From a059d1099fe4f13dbf1b90ff8c30c754b18c028f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 10 Jul 2022 14:58:30 -0500 Subject: [PATCH 1089/1158] pgxpool pools always connect lazily Rename constructor functions now that they don't actually connect. --- CHANGELOG.md | 4 + examples/chat/main.go | 2 +- examples/url_shortener/main.go | 2 +- pgxpool/bench_test.go | 6 +- pgxpool/common_test.go | 1 - pgxpool/conn_test.go | 10 +- pgxpool/doc.go | 11 +- pgxpool/pool.go | 38 ++--- pgxpool/pool_test.go | 263 ++++++--------------------------- pgxpool/tx_test.go | 10 +- 10 files changed, 79 insertions(+), 268 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e605f68..69c0faad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,10 @@ pgconn now supports pipeline mode. `*PgConn.ReceiveResults` removed. Use pipeline mode instead. +## pgxpool + +`Connect` and `ConnectConfig` have been renamed to `New` and `NewConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect. + ## pgtype The `pgtype` package has been significantly changed. diff --git a/examples/chat/main.go b/examples/chat/main.go index 6e705fb6..5adbb3b6 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -13,7 +13,7 @@ var pool *pgxpool.Pool func main() { var err error - pool, err = pgxpool.Connect(context.Background(), os.Getenv("DATABASE_URL")) + pool, err = pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { fmt.Fprintln(os.Stderr, "Unable to connect to database:", err) os.Exit(1) diff --git a/examples/url_shortener/main.go b/examples/url_shortener/main.go index bcee235e..092de1fb 100644 --- a/examples/url_shortener/main.go +++ b/examples/url_shortener/main.go @@ -75,7 +75,7 @@ func main() { log.Fatalln("Unable to parse DATABASE_URL:", err) } - db, err = pgxpool.ConnectConfig(context.Background(), poolConfig) + db, err = pgxpool.NewConfig(context.Background(), poolConfig) if err != nil { log.Fatalln("Unable to create connection pool:", err) } diff --git a/pgxpool/bench_test.go b/pgxpool/bench_test.go index 704371db..588d104c 100644 --- a/pgxpool/bench_test.go +++ b/pgxpool/bench_test.go @@ -11,7 +11,7 @@ import ( ) func BenchmarkAcquireAndRelease(b *testing.B) { - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(b, err) defer pool.Close() @@ -34,7 +34,7 @@ func BenchmarkMinimalPreparedSelectBaseline(b *testing.B) { return err } - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewConfig(context.Background(), config) require.NoError(b, err) conn, err := db.Acquire(context.Background()) @@ -65,7 +65,7 @@ func BenchmarkMinimalPreparedSelect(b *testing.B) { return err } - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewConfig(context.Background(), config) require.NoError(b, err) var n int64 diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index c331b33b..eabc0e3c 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -148,7 +148,6 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName) assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName) assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName) - assert.Equalf(t, expected.LazyConnect, actual.LazyConnect, "%s - LazyConnect", testName) assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName) } diff --git a/pgxpool/conn_test.go b/pgxpool/conn_test.go index ff34f969..175981b7 100644 --- a/pgxpool/conn_test.go +++ b/pgxpool/conn_test.go @@ -12,7 +12,7 @@ import ( func TestConnExec(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -26,7 +26,7 @@ func TestConnExec(t *testing.T) { func TestConnQuery(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -40,7 +40,7 @@ func TestConnQuery(t *testing.T) { func TestConnQueryRow(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -54,7 +54,7 @@ func TestConnQueryRow(t *testing.T) { func TestConnSendBatch(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -68,7 +68,7 @@ func TestConnSendBatch(t *testing.T) { func TestConnCopyFrom(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() diff --git a/pgxpool/doc.go b/pgxpool/doc.go index e8239a6f..07f6359d 100644 --- a/pgxpool/doc.go +++ b/pgxpool/doc.go @@ -2,11 +2,11 @@ /* pgxpool implements a nearly identical interface to pgx connections. -Establishing a Connection +Creating a Pool -The primary way of establishing a connection is with `pgxpool.Connect`. +The primary way of creating a pool is with `pgxpool.New`. - pool, err := pgxpool.Connect(context.Background(), os.Getenv("DATABASE_URL")) + pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) The database connection string can be in URL or DSN format. PostgreSQL settings, pgx settings, and pool settings can be specified here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the @@ -20,6 +20,9 @@ connection with `ConnectConfig`. // do something with every new connection } - pool, err := pgxpool.ConnectConfig(context.Background(), config) + pool, err := pgxpool.NewConfig(context.Background(), config) + +A pool returns without waiting for any connections to be established. Acquire a connection immediately after creating +the pool to check if a connection can successfully be established. */ package pgxpool diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 7027e282..65f1cb42 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -122,11 +122,6 @@ type Config struct { // HealthCheckPeriod is the duration between checks of the health of idle connections. HealthCheckPeriod time.Duration - // If set to true, pool doesn't do any I/O operation on initialization. - // And connects to the server only when the pool starts to be used. - // The default is false. - LazyConnect bool - createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -143,20 +138,18 @@ func (c *Config) Copy() *Config { // ConnString returns the connection string as parsed by pgxpool.ParseConfig into pgxpool.Config. func (c *Config) ConnString() string { return c.ConnConfig.ConnString() } -// Connect creates a new Pool and immediately establishes one connection. ctx can be used to cancel this initial -// connection. See ParseConfig for information on connString format. -func Connect(ctx context.Context, connString string) (*Pool, error) { +// New creates a new Pool. See ParseConfig for information on connString format. +func New(ctx context.Context, connString string) (*Pool, error) { config, err := ParseConfig(connString) if err != nil { return nil, err } - return ConnectConfig(ctx, config) + return NewConfig(ctx, config) } -// ConnectConfig creates a new Pool and immediately establishes one connection. ctx can be used to cancel this initial -// connection. config must have been created by ParseConfig. -func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { +// NewConfig creates a new Pool. config must have been created by ParseConfig. +func NewConfig(ctx context.Context, config *Config) (*Pool, error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { @@ -222,23 +215,10 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { config.MaxConns, ) - if !config.LazyConnect { - if err := p.createIdleResources(ctx, int(p.minConns)); err != nil { - // Couldn't create resources for minpool size. Close unhealthy pool. - p.Close() - return nil, err - } - - // Initially establish one connection - res, err := p.p.Acquire(ctx) - if err != nil { - p.Close() - return nil, err - } - res.Release() - } - - go p.backgroundHealthCheck() + go func() { + p.checkMinConns() // reach min conns as soon as possible + p.backgroundHealthCheck() + }() return p, nil } diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 5cd943d7..f296d819 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -19,7 +19,7 @@ import ( func TestConnect(t *testing.T) { t.Parallel() connString := os.Getenv("PGX_TEST_DATABASE") - pool, err := pgxpool.Connect(context.Background(), connString) + pool, err := pgxpool.New(context.Background(), connString) require.NoError(t, err) assert.Equal(t, connString, pool.Config().ConnString()) pool.Close() @@ -30,7 +30,7 @@ func TestConnectConfig(t *testing.T) { connString := os.Getenv("PGX_TEST_DATABASE") config, err := pgxpool.ParseConfig(connString) require.NoError(t, err) - pool, err := pgxpool.ConnectConfig(context.Background(), config) + pool, err := pgxpool.NewConfig(context.Background(), config) require.NoError(t, err) assertConfigsEqual(t, config, pool.Config(), "Pool.Config() returns original config") pool.Close() @@ -47,39 +47,12 @@ func TestParseConfigExtractsPoolArguments(t *testing.T) { assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_min_conns") } -func TestConnectCancel(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) - assert.Nil(t, pool) - assert.Equal(t, context.Canceled, err) -} - -func TestLazyConnect(t *testing.T) { - t.Parallel() - - config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - assert.NoError(t, err) - config.LazyConnect = true - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - pool, err := pgxpool.ConnectConfig(ctx, config) - assert.NoError(t, err) - - _, err = pool.Exec(ctx, "SELECT 1") - assert.Equal(t, context.Canceled, err) -} - func TestConnectConfigRequiresConnConfigFromParseConfig(t *testing.T) { t.Parallel() config := &pgxpool.Config{} - require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgxpool.ConnectConfig(context.Background(), config) }) + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgxpool.NewConfig(context.Background(), config) }) } func TestConfigCopyReturnsEqualConfig(t *testing.T) { @@ -99,7 +72,7 @@ func TestConfigCopyCanBeUsedToConnect(t *testing.T) { copied := original.Copy() assert.NotPanics(t, func() { - _, err = pgxpool.ConnectConfig(context.Background(), copied) + _, err = pgxpool.NewConfig(context.Background(), copied) }) assert.NoError(t, err) } @@ -107,7 +80,7 @@ func TestConfigCopyCanBeUsedToConnect(t *testing.T) { func TestPoolAcquireAndConnRelease(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -121,7 +94,7 @@ func TestPoolAcquireAndConnHijack(t *testing.T) { ctx := context.Background() - pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -150,7 +123,7 @@ func TestPoolAcquireChecksIdleConns(t *testing.T) { defer controllerConn.Close(context.Background()) pgxtest.SkipCockroachDB(t, controllerConn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -196,7 +169,7 @@ func TestPoolAcquireChecksIdleConns(t *testing.T) { func TestPoolAcquireFunc(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -211,7 +184,7 @@ func TestPoolAcquireFunc(t *testing.T) { func TestPoolAcquireFuncReturnsFnError(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -232,7 +205,7 @@ func TestPoolBeforeConnect(t *testing.T) { return nil } - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -253,7 +226,7 @@ func TestPoolAfterConnect(t *testing.T) { return err } - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -276,7 +249,7 @@ func TestPoolBeforeAcquire(t *testing.T) { return acquireAttempts%2 == 0 } - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -308,7 +281,7 @@ func TestPoolAfterRelease(t *testing.T) { t.Parallel() func() { - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -331,7 +304,7 @@ func TestPoolAfterRelease(t *testing.T) { return afterReleaseCount%2 == 1 } - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -351,19 +324,11 @@ func TestPoolAfterRelease(t *testing.T) { func TestPoolAcquireAllIdle(t *testing.T) { t.Parallel() - db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + db, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() - conns := db.AcquireAllIdle(context.Background()) - assert.Len(t, conns, 1) - - for _, c := range conns { - c.Release() - } - waitForReleaseToComplete() - - conns = make([]*pgxpool.Conn, 3) + conns := make([]*pgxpool.Conn, 3) for i := range conns { conns[i], err = db.Acquire(context.Background()) assert.NoError(t, err) @@ -392,7 +357,7 @@ func TestConnReleaseChecksMaxConnLifetime(t *testing.T) { config.MaxConnLifetime = 250 * time.Millisecond - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -411,7 +376,7 @@ func TestConnReleaseChecksMaxConnLifetime(t *testing.T) { func TestConnReleaseClosesBusyConn(t *testing.T) { t.Parallel() - db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + db, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() @@ -437,7 +402,7 @@ func TestPoolBackgroundChecksMaxConnLifetime(t *testing.T) { config.MaxConnLifetime = 100 * time.Millisecond config.HealthCheckPeriod = 100 * time.Millisecond - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -460,7 +425,7 @@ func TestPoolBackgroundChecksMaxConnIdleTime(t *testing.T) { config.MaxConnIdleTime = 100 * time.Millisecond config.HealthCheckPeriod = 150 * time.Millisecond - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -487,7 +452,7 @@ func TestPoolBackgroundChecksMinConns(t *testing.T) { config.HealthCheckPeriod = 100 * time.Millisecond config.MinConns = 2 - db, err := pgxpool.ConnectConfig(context.Background(), config) + db, err := pgxpool.NewConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -500,7 +465,7 @@ func TestPoolBackgroundChecksMinConns(t *testing.T) { func TestPoolExec(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -510,7 +475,7 @@ func TestPoolExec(t *testing.T) { func TestPoolQuery(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -539,7 +504,7 @@ func TestPoolQuery(t *testing.T) { func TestPoolQueryRow(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -555,7 +520,7 @@ func TestPoolQueryRow(t *testing.T) { func TestPoolQueryRowErrNoRows(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -566,7 +531,7 @@ func TestPoolQueryRowErrNoRows(t *testing.T) { func TestPoolSendBatch(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -586,7 +551,7 @@ func TestPoolCopyFrom(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -629,7 +594,7 @@ func TestConnReleaseClosesConnInFailedTransaction(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -675,7 +640,7 @@ func TestConnReleaseClosesConnInTransaction(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -716,7 +681,7 @@ func TestConnReleaseDestroysClosedConn(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -737,7 +702,7 @@ func TestConnReleaseDestroysClosedConn(t *testing.T) { func TestConnPoolQueryConcurrentLoad(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -763,7 +728,7 @@ func TestConnReleaseWhenBeginFail(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - db, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + db, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() @@ -787,7 +752,7 @@ func TestConnReleaseWhenBeginFail(t *testing.T) { } func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { - db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + db, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() @@ -834,7 +799,7 @@ func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { } func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { - db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + db, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() @@ -877,7 +842,7 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { } func TestIdempotentPoolClose(t *testing.T) { - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) // Close the open pool. @@ -887,7 +852,7 @@ func TestIdempotentPoolClose(t *testing.T) { require.NotPanics(t, func() { pool.Close() }) } -func TestConnectCreatesMinPool(t *testing.T) { +func TestConnectEagerlyReachesMinPoolSize(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) @@ -895,7 +860,6 @@ func TestConnectCreatesMinPool(t *testing.T) { config.MinConns = int32(12) config.MaxConns = int32(15) - config.LazyConnect = false acquireAttempts := int64(0) connectAttempts := int64(0) @@ -909,166 +873,27 @@ func TestConnectCreatesMinPool(t *testing.T) { return nil } - pool, err := pgxpool.ConnectConfig(context.Background(), config) + pool, err := pgxpool.NewConfig(context.Background(), config) require.NoError(t, err) defer pool.Close() - stat := pool.Stat() - require.Equal(t, int32(12), stat.IdleConns()) - require.Equal(t, int64(1), stat.AcquireCount()) - require.Equal(t, int32(12), stat.TotalConns()) - require.Equal(t, int64(0), acquireAttempts) - require.Equal(t, int64(12), connectAttempts) -} -func TestConnectSkipMinPoolWithLazy(t *testing.T) { - t.Parallel() + for i := 0; i < 500; i++ { + time.Sleep(10 * time.Millisecond) - config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.MinConns = int32(12) - config.MaxConns = int32(15) - config.LazyConnect = true - - acquireAttempts := int64(0) - connectAttempts := int64(0) - - config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { - atomic.AddInt64(&acquireAttempts, 1) - return true - } - config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { - atomic.AddInt64(&connectAttempts, 1) - return nil - } - - pool, err := pgxpool.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer pool.Close() - - stat := pool.Stat() - require.Equal(t, int32(0), stat.IdleConns()) - require.Equal(t, int64(0), stat.AcquireCount()) - require.Equal(t, int32(0), stat.TotalConns()) - require.Equal(t, int64(0), acquireAttempts) - require.Equal(t, int64(0), connectAttempts) -} - -func TestConnectMinPoolZero(t *testing.T) { - t.Parallel() - - config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.MinConns = int32(0) - config.MaxConns = int32(15) - config.LazyConnect = false - - acquireAttempts := int64(0) - connectAttempts := int64(0) - - config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { - atomic.AddInt64(&acquireAttempts, 1) - return true - } - config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { - atomic.AddInt64(&connectAttempts, 1) - return nil - } - - pool, err := pgxpool.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer pool.Close() - - stat := pool.Stat() - require.Equal(t, int32(1), stat.IdleConns()) - require.Equal(t, int64(1), stat.AcquireCount()) - require.Equal(t, int32(1), stat.TotalConns()) - require.Equal(t, int64(0), acquireAttempts) - require.Equal(t, int64(1), connectAttempts) -} - -func TestCreateMinPoolClosesConnectionsOnError(t *testing.T) { - t.Parallel() - - config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.MinConns = int32(12) - config.MaxConns = int32(15) - config.LazyConnect = false - - acquireAttempts := int64(0) - madeConnections := int64(0) - conns := make(chan *pgx.Conn, 15) - - config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { - atomic.AddInt64(&acquireAttempts, 1) - return true - } - config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { - conns <- conn - - atomic.AddInt64(&madeConnections, 1) - mc := atomic.LoadInt64(&madeConnections) - if mc == 10 { - return errors.New("mock error") + stat := pool.Stat() + if stat.IdleConns() == 12 && stat.AcquireCount() == 0 && stat.TotalConns() == 12 && atomic.LoadInt64(&acquireAttempts) == 0 && atomic.LoadInt64(&connectAttempts) == 12 { + return } - return nil - } - pool, err := pgxpool.ConnectConfig(context.Background(), config) - require.Error(t, err) - require.Nil(t, pool) - - close(conns) - for conn := range conns { - require.True(t, conn.IsClosed()) } - require.Equal(t, int64(0), acquireAttempts) - require.True(t, madeConnections >= 10, "Expected %d got %d", 10, madeConnections) -} + t.Fatal("did not reach min pool size") -func TestCreateMinPoolReturnsFirstError(t *testing.T) { - t.Parallel() - - config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.MinConns = int32(12) - config.MaxConns = int32(15) - config.LazyConnect = false - - acquireAttempts := int64(0) - connectAttempts := int64(0) - - mockErr := errors.New("mock connect error") - - config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { - atomic.AddInt64(&acquireAttempts, 1) - return true - } - config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { - atomic.AddInt64(&connectAttempts, 1) - ca := atomic.LoadInt64(&connectAttempts) - if ca >= 5 { - return mockErr - } - return nil - } - - pool, err := pgxpool.ConnectConfig(context.Background(), config) - require.Nil(t, pool) - require.Error(t, err) - - require.True(t, connectAttempts >= 5, "Expected %d got %d", 5, connectAttempts) - require.ErrorIs(t, err, mockErr) } func TestPoolSendBatchBatchCloseTwice(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() diff --git a/pgxpool/tx_test.go b/pgxpool/tx_test.go index e32d3efe..8e140bf5 100644 --- a/pgxpool/tx_test.go +++ b/pgxpool/tx_test.go @@ -12,7 +12,7 @@ import ( func TestTxExec(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -26,7 +26,7 @@ func TestTxExec(t *testing.T) { func TestTxQuery(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -40,7 +40,7 @@ func TestTxQuery(t *testing.T) { func TestTxQueryRow(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -54,7 +54,7 @@ func TestTxQueryRow(t *testing.T) { func TestTxSendBatch(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() @@ -68,7 +68,7 @@ func TestTxSendBatch(t *testing.T) { func TestTxCopyFrom(t *testing.T) { t.Parallel() - pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() From 224393188d5516ed7805a1f32a627b6973d735fa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 11 Jul 2022 08:07:23 -0500 Subject: [PATCH 1090/1158] Fix InetCodec.DecodeValue --- pgtype/inet.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgtype/inet.go b/pgtype/inet.go index f094ed2f..a85646d7 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -134,7 +134,7 @@ func (c InetCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (an } var prefix netip.Prefix - err := codecScan(c, m, oid, format, src, &prefix) + err := codecScan(c, m, oid, format, src, (*netipPrefixWrapper)(&prefix)) if err != nil { return nil, err } From 786de2bda864972f9b87d0e34ea0f47c3778dbd2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 11 Jul 2022 20:42:55 -0500 Subject: [PATCH 1091/1158] Use correct cache --- conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn.go b/conn.go index 3f6bee0d..f89dce7b 100644 --- a/conn.go +++ b/conn.go @@ -1040,7 +1040,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d // Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later. if sdCache != nil { for _, sd := range distinctNewQueries { - c.statementCache.Put(sd) + sdCache.Put(sd) } } From aaacdbf3ea42e3f977f4134ce13fa15a9a1e6c54 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 11 Jul 2022 21:09:03 -0500 Subject: [PATCH 1092/1158] Use string internally for CommandTag --- pgconn/benchmark_private_test.go | 6 ++--- pgconn/pgconn.go | 46 +++++++------------------------- pgconn/pgconn_private_test.go | 22 +++++++-------- 3 files changed, 24 insertions(+), 50 deletions(-) diff --git a/pgconn/benchmark_private_test.go b/pgconn/benchmark_private_test.go index e074c75c..9ea036ec 100644 --- a/pgconn/benchmark_private_test.go +++ b/pgconn/benchmark_private_test.go @@ -17,7 +17,7 @@ func BenchmarkCommandTagRowsAffected(b *testing.B) { } for _, bm := range benchmarks { - ct := CommandTag{buf: []byte(bm.commandTag)} + ct := CommandTag{s: bm.commandTag} b.Run(bm.commandTag, func(b *testing.B) { var n int64 for i := 0; i < b.N; i++ { @@ -31,7 +31,7 @@ func BenchmarkCommandTagRowsAffected(b *testing.B) { } func BenchmarkCommandTagTypeFromString(b *testing.B) { - ct := CommandTag{buf: []byte("UPDATE 1")} + ct := CommandTag{s: "UPDATE 1"} var update bool for i := 0; i < b.N; i++ { @@ -59,7 +59,7 @@ func BenchmarkCommandTagInsert(b *testing.B) { } for _, bm := range benchmarks { - ct := CommandTag{buf: []byte(bm.commandTag)} + ct := CommandTag{s: bm.commandTag} b.Run(bm.commandTag, func(b *testing.B) { var is bool for i := 0; i < b.N; i++ { diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 65fb015a..37928ed7 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -630,7 +630,7 @@ func (pgConn *PgConn) ParameterStatus(key string) string { // CommandTag is the result of an Exec function type CommandTag struct { - buf []byte + s string } // RowsAffected returns the number of rows affected. If the CommandTag was not @@ -638,8 +638,8 @@ type CommandTag struct { func (ct CommandTag) RowsAffected() int64 { // Find last non-digit idx := -1 - for i := len(ct.buf) - 1; i >= 0; i-- { - if ct.buf[i] >= '0' && ct.buf[i] <= '9' { + for i := len(ct.s) - 1; i >= 0; i-- { + if ct.s[i] >= '0' && ct.s[i] <= '9' { idx = i } else { break @@ -651,7 +651,7 @@ func (ct CommandTag) RowsAffected() int64 { } var n int64 - for _, b := range ct.buf[idx:] { + for _, b := range ct.s[idx:] { n = n*10 + int64(b-'0') } @@ -659,51 +659,27 @@ func (ct CommandTag) RowsAffected() int64 { } func (ct CommandTag) String() string { - return string(ct.buf) + return ct.s } // Insert is true if the command tag starts with "INSERT". func (ct CommandTag) Insert() bool { - return len(ct.buf) >= 6 && - ct.buf[0] == 'I' && - ct.buf[1] == 'N' && - ct.buf[2] == 'S' && - ct.buf[3] == 'E' && - ct.buf[4] == 'R' && - ct.buf[5] == 'T' + return strings.HasPrefix(ct.s, "INSERT") } // Update is true if the command tag starts with "UPDATE". func (ct CommandTag) Update() bool { - return len(ct.buf) >= 6 && - ct.buf[0] == 'U' && - ct.buf[1] == 'P' && - ct.buf[2] == 'D' && - ct.buf[3] == 'A' && - ct.buf[4] == 'T' && - ct.buf[5] == 'E' + return strings.HasPrefix(ct.s, "UPDATE") } // Delete is true if the command tag starts with "DELETE". func (ct CommandTag) Delete() bool { - return len(ct.buf) >= 6 && - ct.buf[0] == 'D' && - ct.buf[1] == 'E' && - ct.buf[2] == 'L' && - ct.buf[3] == 'E' && - ct.buf[4] == 'T' && - ct.buf[5] == 'E' + return strings.HasPrefix(ct.s, "DELETE") } // Select is true if the command tag starts with "SELECT". func (ct CommandTag) Select() bool { - return len(ct.buf) >= 6 && - ct.buf[0] == 'S' && - ct.buf[1] == 'E' && - ct.buf[2] == 'L' && - ct.buf[3] == 'E' && - ct.buf[4] == 'C' && - ct.buf[5] == 'T' + return strings.HasPrefix(ct.s, "SELECT") } type StatementDescription struct { @@ -1585,9 +1561,7 @@ func (pgConn *PgConn) CheckConn() error { // makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory. func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { - ct := make([]byte, len(buf)) - copy(ct, buf) - return CommandTag{buf: ct} + return CommandTag{s: string(buf)} } // HijackedConn is the result of hijacking a connection. diff --git a/pgconn/pgconn_private_test.go b/pgconn/pgconn_private_test.go index 4368f717..5659bc9e 100644 --- a/pgconn/pgconn_private_test.go +++ b/pgconn/pgconn_private_test.go @@ -17,17 +17,17 @@ func TestCommandTag(t *testing.T) { isDelete bool isSelect bool }{ - {commandTag: CommandTag{buf: []byte("INSERT 0 5")}, rowsAffected: 5, isInsert: true}, - {commandTag: CommandTag{buf: []byte("UPDATE 0")}, rowsAffected: 0, isUpdate: true}, - {commandTag: CommandTag{buf: []byte("UPDATE 1")}, rowsAffected: 1, isUpdate: true}, - {commandTag: CommandTag{buf: []byte("DELETE 0")}, rowsAffected: 0, isDelete: true}, - {commandTag: CommandTag{buf: []byte("DELETE 1")}, rowsAffected: 1, isDelete: true}, - {commandTag: CommandTag{buf: []byte("DELETE 1234567890")}, rowsAffected: 1234567890, isDelete: true}, - {commandTag: CommandTag{buf: []byte("SELECT 1")}, rowsAffected: 1, isSelect: true}, - {commandTag: CommandTag{buf: []byte("SELECT 99999999999")}, rowsAffected: 99999999999, isSelect: true}, - {commandTag: CommandTag{buf: []byte("CREATE TABLE")}, rowsAffected: 0}, - {commandTag: CommandTag{buf: []byte("ALTER TABLE")}, rowsAffected: 0}, - {commandTag: CommandTag{buf: []byte("DROP TABLE")}, rowsAffected: 0}, + {commandTag: CommandTag{s: "INSERT 0 5"}, rowsAffected: 5, isInsert: true}, + {commandTag: CommandTag{s: "UPDATE 0"}, rowsAffected: 0, isUpdate: true}, + {commandTag: CommandTag{s: "UPDATE 1"}, rowsAffected: 1, isUpdate: true}, + {commandTag: CommandTag{s: "DELETE 0"}, rowsAffected: 0, isDelete: true}, + {commandTag: CommandTag{s: "DELETE 1"}, rowsAffected: 1, isDelete: true}, + {commandTag: CommandTag{s: "DELETE 1234567890"}, rowsAffected: 1234567890, isDelete: true}, + {commandTag: CommandTag{s: "SELECT 1"}, rowsAffected: 1, isSelect: true}, + {commandTag: CommandTag{s: "SELECT 99999999999"}, rowsAffected: 99999999999, isSelect: true}, + {commandTag: CommandTag{s: "CREATE TABLE"}, rowsAffected: 0}, + {commandTag: CommandTag{s: "ALTER TABLE"}, rowsAffected: 0}, + {commandTag: CommandTag{s: "DROP TABLE"}, rowsAffected: 0}, } for i, tt := range tests { From f0cd9cb8676434037df3b337feb192177879e074 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 11 Jul 2022 21:09:55 -0500 Subject: [PATCH 1093/1158] Update CommandTag comment --- pgconn/pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 37928ed7..e52051b3 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -628,7 +628,7 @@ func (pgConn *PgConn) ParameterStatus(key string) string { return pgConn.parameterStatuses[key] } -// CommandTag is the result of an Exec function +// CommandTag is the status text returned by PostgreSQL for a query. type CommandTag struct { s string } From 3dc9d17757eb0693ffea9d73507cc316a63c969a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 11 Jul 2022 21:15:37 -0500 Subject: [PATCH 1094/1158] Document new ResultReader.Values behavior --- CHANGELOG.md | 2 ++ pgconn/pgconn.go | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 69c0faad..387150c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ release work due to releasing multiple packages, and less clear changelogs. `CommandTag` is now an opaque type instead of directly exposing an underlying `[]byte`. +The return value `ResultReader.Values()` is no longer safe to retain a reference to after a subsequent call to `NextRow()` or `Close()`. + `Trace()` method adds low level message tracing similar to the `PQtrace` function in `libpq`. pgconn now uses non-blocking IO. This is a significant internal restructuring, but it should not cause any visible changes on its own. However, it is important in implementing other new features. diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index e52051b3..ab2e2df9 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1358,8 +1358,7 @@ func (rr *ResultReader) FieldDescriptions() []pgproto3.FieldDescription { } // Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only -// valid until the next NextRow call or the ResultReader is closed. However, the underlying byte data is safe to -// retain a reference to and mutate. +// valid until the next NextRow call or the ResultReader is closed. func (rr *ResultReader) Values() [][]byte { return rr.rowValues } From d5807f01ed61cc92a4914beab6a031bc7c553083 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 Jul 2022 06:57:56 -0500 Subject: [PATCH 1095/1158] Restore test from v4 --- pgtype/pgtype_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index e2465afc..b1b7e951 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -203,6 +203,57 @@ func TestTypeMapScanUnknownOIDIntoSQLScanner(t *testing.T) { assert.False(t, s.Valid) } +type pgCustomInt int64 + +func (ci *pgCustomInt) Scan(src interface{}) error { + *ci = pgCustomInt(src.(int64)) + return nil +} + +func TestScanPlanBinaryInt32ScanScanner(t *testing.T) { + ci := pgtype.NewMap() + src := []byte{0, 42} + var v pgCustomInt + + plan := ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &v) + err := plan.Scan(src, &v) + require.NoError(t, err) + require.EqualValues(t, 42, v) + + ptr := new(pgCustomInt) + plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(src, &ptr) + require.NoError(t, err) + require.EqualValues(t, 42, *ptr) + + ptr = new(pgCustomInt) + err = plan.Scan(nil, &ptr) + require.NoError(t, err) + assert.Nil(t, ptr) + + ptr = nil + plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(src, &ptr) + require.NoError(t, err) + require.EqualValues(t, 42, *ptr) + + ptr = nil + plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(nil, &ptr) + require.NoError(t, err) + assert.Nil(t, ptr) +} + +// Test for https://github.com/jackc/pgtype/issues/164 +func TestScanPlanInterface(t *testing.T) { + ci := pgtype.NewMap() + src := []byte{0, 42} + var v interface{} + plan := ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, v) + err := plan.Scan(src, v) + assert.Error(t, err) +} + func BenchmarkTypeMapScanInt4IntoBinaryDecoder(b *testing.B) { m := pgtype.NewMap() src := []byte{0, 0, 0, 42} From 9201cc0341050940ec44d8de45dfc5db198f80da Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Jul 2022 08:58:38 -0500 Subject: [PATCH 1096/1158] ConnectConfig copies config --- conn.go | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/conn.go b/conn.go index f89dce7b..1f87e9e7 100644 --- a/conn.go +++ b/conn.go @@ -113,6 +113,10 @@ func Connect(ctx context.Context, connString string) (*Conn, error) { // ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct. // connConfig must have been created by ParseConfig. func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { + // In general this improves safety. In particular avoid the config.Config.OnNotification mutation from affecting other + // connections with the same config. See https://github.com/jackc/pgx/issues/618. + connConfig = connConfig.Copy() + return connect(ctx, connConfig) } @@ -188,23 +192,16 @@ func ParseConfig(connString string) (*ConnConfig, error) { return connConfig, nil } +// connect connects to a database. connect takes ownership of config. The caller must not use or access it again. func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { panic("config must be created by ParseConfig") } - originalConfig := config - - // This isn't really a deep copy. But it is enough to avoid the config.Config.OnNotification mutation from affecting - // other connections with the same config. See https://github.com/jackc/pgx/issues/618. - { - configCopy := *config - config = &configCopy - } c = &Conn{ - config: originalConfig, + config: config, typeMap: pgtype.NewMap(), logLevel: config.LogLevel, logger: config.Logger, From 78875bb95ab29bf9b4eb653e6ee0db903155e07c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Jul 2022 12:27:10 -0500 Subject: [PATCH 1097/1158] Add tracing support Replaces existing logging support. Package tracelog provides adapter for old style logging. https://github.com/jackc/pgx/issues/1061 --- CHANGELOG.md | 4 +- README.md | 2 + batch.go | 186 +++++++------ batch_test.go | 88 ------- bench_test.go | 117 --------- conn.go | 119 +++++---- conn_test.go | 74 ------ copy_from.go | 25 +- doc.go | 18 +- helper_test.go | 3 +- pgproto3/frontend.go | 2 +- pgproto3/trace.go | 4 +- pgxpool/common_test.go | 3 +- rows.go | 38 +-- stdlib/sql_test.go | 7 +- tracelog/tracelog.go | 295 +++++++++++++++++++++ tracelog/tracelog_test.go | 301 +++++++++++++++++++++ tracer.go | 107 ++++++++ tracer_test.go | 538 ++++++++++++++++++++++++++++++++++++++ 19 files changed, 1446 insertions(+), 485 deletions(-) create mode 100644 tracelog/tracelog.go create mode 100644 tracelog/tracelog_test.go create mode 100644 tracer.go create mode 100644 tracer_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 387150c2..9a402c40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -159,7 +159,9 @@ Previously, a batch with 10 unique parameterized statements executed 100 times w for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements in a single network round trip. So it would only take 2 round trips. -## 3rd Party Logger Integration +## Tracing and Logging + +Internal logging support has been replaced with tracing hooks. This allows custom tracing integration with tools like OpenTelemetry. Package tracelog provides an adapter for pgx v4 loggers to act as a tracer. All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency tree. diff --git a/README.md b/README.md index c4f7239f..c7224b22 100644 --- a/README.md +++ b/README.md @@ -163,6 +163,8 @@ pgerrcode contains constants for the PostgreSQL error codes. ## Adapters for 3rd Party Loggers +These adapters can be used with the tracelog package. + * [github.com/jackc/pgx-go-kit-log](https://github.com/jackc/pgx-go-kit-log) * [github.com/jackc/pgx-log15](https://github.com/jackc/pgx-log15) * [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus) diff --git a/batch.go b/batch.go index f2a9b4c8..a6951096 100644 --- a/batch.go +++ b/batch.go @@ -50,13 +50,14 @@ type BatchResults interface { } type batchResults struct { - ctx context.Context - conn *Conn - mrr *pgconn.MultiResultReader - err error - b *Batch - ix int - closed bool + ctx context.Context + conn *Conn + mrr *pgconn.MultiResultReader + err error + b *Batch + ix int + closed bool + endTraced bool } // Exec reads the results from the next query in the batch as if the query has been sent with Exec. @@ -75,35 +76,29 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { if err == nil { err = errors.New("no result") } - if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "err": err, + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + Err: err, }) } return pgconn.CommandTag{}, err } commandTag, err := br.mrr.ResultReader().Close() + br.err = err - if err != nil { - if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "err": err, - }) - } - } else if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "commandTag": commandTag, + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + CommandTag: commandTag, + Err: br.err, }) } - return commandTag, err + return commandTag, br.err } // Query reads the results from the next query in the batch as if the query has been sent with Query. @@ -123,6 +118,7 @@ func (br *batchResults) Query() (Rows, error) { } rows := br.conn.getRows(br.ctx, query, arguments) + rows.batchTracer = br.conn.batchTracer if !br.mrr.NextResult() { rows.err = br.mrr.Close() @@ -131,11 +127,11 @@ func (br *batchResults) Query() (Rows, error) { } rows.closed = true - if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "err": rows.err, + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + Err: rows.err, }) } @@ -156,6 +152,15 @@ func (br *batchResults) QueryRow() Row { // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to // resyncronize the connection with the server. In this case the underlying connection will have been closed. func (br *batchResults) Close() error { + defer func() { + if !br.endTraced { + if br.conn != nil && br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err}) + } + br.endTraced = true + } + }() + if br.err != nil { return br.err } @@ -163,24 +168,26 @@ func (br *batchResults) Close() error { if br.closed { return nil } - br.closed = true - // log any queries that haven't yet been logged by Exec or Query - for { - query, args, ok := br.nextQueryAndArgs() - if !ok { - break - } - - if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{ - "sql": query, - "args": logQueryArgs(args), - }) + // consume and log any queries that haven't yet been logged by Exec or Query + if br.conn.batchTracer != nil { + for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) { + br.Exec() } } - return br.mrr.Close() + br.closed = true + + err := br.mrr.Close() + if br.err == nil { + br.err = err + } + + return br.err +} + +func (br *batchResults) earlyError() error { + return br.err } func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { @@ -195,14 +202,15 @@ func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { } type pipelineBatchResults struct { - ctx context.Context - conn *Conn - pipeline *pgconn.Pipeline - lastRows *baseRows - err error - b *Batch - ix int - closed bool + ctx context.Context + conn *Conn + pipeline *pgconn.Pipeline + lastRows *baseRows + err error + b *Batch + ix int + closed bool + endTraced bool } // Exec reads the results from the next query in the batch as if the query has been sent with Exec. @@ -227,25 +235,17 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { var commandTag pgconn.CommandTag switch results := results.(type) { case *pgconn.ResultReader: - commandTag, err = results.Close() + commandTag, br.err = results.Close() default: return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results) } - if err != nil { - br.err = err - if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "err": err, - }) - } - } else if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "commandTag": commandTag, + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + CommandTag: commandTag, + Err: br.err, }) } @@ -274,6 +274,7 @@ func (br *pipelineBatchResults) Query() (Rows, error) { } rows := br.conn.getRows(br.ctx, query, arguments) + rows.batchTracer = br.conn.batchTracer br.lastRows = rows results, err := br.pipeline.GetResults() @@ -281,11 +282,12 @@ func (br *pipelineBatchResults) Query() (Rows, error) { br.err = err rows.err = err rows.closed = true - if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "err": rows.err, + + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + Err: err, }) } } else { @@ -313,6 +315,15 @@ func (br *pipelineBatchResults) QueryRow() Row { // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to // resyncronize the connection with the server. In this case the underlying connection will have been closed. func (br *pipelineBatchResults) Close() error { + defer func() { + if !br.endTraced { + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err}) + } + br.endTraced = true + } + }() + if br.err != nil { return br.err } @@ -325,24 +336,25 @@ func (br *pipelineBatchResults) Close() error { if br.closed { return nil } - br.closed = true - // log any queries that haven't yet been logged by Exec or Query - for { - query, args, ok := br.nextQueryAndArgs() - if !ok { - break - } - - if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{ - "sql": query, - "args": logQueryArgs(args), - }) + // consume and log any queries that haven't yet been logged by Exec or Query + if br.conn.batchTracer != nil { + for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) { + br.Exec() } } + br.closed = true - return br.pipeline.Close() + err := br.pipeline.Close() + if br.err == nil { + br.err = err + } + + return br.err +} + +func (br *pipelineBatchResults) earlyError() error { + return br.err } func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) { diff --git a/batch_test.go b/batch_test.go index e8d6f677..156e8f8f 100644 --- a/batch_test.go +++ b/batch_test.go @@ -733,94 +733,6 @@ func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) { require.NoError(t, err) } -func TestLogBatchStatementsOnExec(t *testing.T) { - l1 := &testLogger{} - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = l1 - - conn := mustConnect(t, config) - defer closeConn(t, conn) - - l1.logs = l1.logs[0:0] // Clear logs written when establishing connection - - batch := &pgx.Batch{} - batch.Queue("create table foo (id bigint)") - batch.Queue("drop table foo") - - br := conn.SendBatch(context.Background(), batch) - - _, err := br.Exec() - if err != nil { - t.Fatalf("Unexpected error creating table: %v", err) - } - - _, err = br.Exec() - if err != nil { - t.Fatalf("Unexpected error dropping table: %v", err) - } - - if len(l1.logs) != 2 { - t.Fatalf("Expected two log entries but got %d", len(l1.logs)) - } - - if l1.logs[0].msg != "BatchResult.Exec" { - t.Errorf("Expected first log message to be 'BatchResult.Exec' but was '%s", l1.logs[0].msg) - } - - if l1.logs[0].data["sql"] != "create table foo (id bigint)" { - t.Errorf("Expected the first query to be 'create table foo (id bigint)' but was '%s'", l1.logs[0].data["sql"]) - } - - if l1.logs[1].msg != "BatchResult.Exec" { - t.Errorf("Expected second log message to be 'BatchResult.Exec' but was '%s", l1.logs[1].msg) - } - - if l1.logs[1].data["sql"] != "drop table foo" { - t.Errorf("Expected the second query to be 'drop table foo' but was '%s'", l1.logs[1].data["sql"]) - } -} - -func TestLogBatchStatementsOnBatchResultClose(t *testing.T) { - l1 := &testLogger{} - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = l1 - - conn := mustConnect(t, config) - defer closeConn(t, conn) - - l1.logs = l1.logs[0:0] // Clear logs written when establishing connection - - batch := &pgx.Batch{} - batch.Queue("select generate_series(1,$1)", 100) - batch.Queue("select 1 = 1;") - - br := conn.SendBatch(context.Background(), batch) - - if err := br.Close(); err != nil { - t.Fatalf("Unexpected batch error: %v", err) - } - - if len(l1.logs) != 2 { - t.Fatalf("Expected 2 log statements but found %d", len(l1.logs)) - } - - if l1.logs[0].msg != "BatchResult.Close" { - t.Errorf("Expected first log statement to be 'BatchResult.Close' but was %s", l1.logs[0].msg) - } - - if l1.logs[0].data["sql"] != "select generate_series(1,$1)" { - t.Errorf("Expected first query to be 'select generate_series(1,$1)' but was '%s'", l1.logs[0].data["sql"]) - } - - if l1.logs[1].msg != "BatchResult.Close" { - t.Errorf("Expected second log statement to be 'BatchResult.Close' but was %s", l1.logs[1].msg) - } - - if l1.logs[1].data["sql"] != "select 1 = 1;" { - t.Errorf("Expected second query to be 'select 1 = 1;' but was '%s'", l1.logs[1].data["sql"]) - } -} - func TestSendBatchSimpleProtocol(t *testing.T) { t.Parallel() diff --git a/bench_test.go b/bench_test.go index 31b3b38e..73e1b258 100644 --- a/bench_test.go +++ b/bench_test.go @@ -284,123 +284,6 @@ func BenchmarkPointerPointerWithPresentValues(b *testing.B) { } } -func BenchmarkSelectWithoutLogging(b *testing.B) { - conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -type discardLogger struct{} - -func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]any) { -} - -func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) { - var logger discardLogger - config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = logger - config.LogLevel = pgx.LogLevelTrace - - conn := mustConnect(b, config) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) { - var logger discardLogger - config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = logger - config.LogLevel = pgx.LogLevelDebug - - conn := mustConnect(b, config) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) { - var logger discardLogger - config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = logger - config.LogLevel = pgx.LogLevelInfo - - conn := mustConnect(b, config) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) { - var logger discardLogger - config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = logger - config.LogLevel = pgx.LogLevelError - - conn := mustConnect(b, config) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) { - _, err := conn.Prepare(context.Background(), "test", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") - if err != nil { - b.Fatal(err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - var record struct { - id int32 - userName string - email string - name string - sex string - birthDate time.Time - lastLoginTime time.Time - } - - err = conn.QueryRow(context.Background(), "test").Scan( - &record.id, - &record.userName, - &record.email, - &record.name, - &record.sex, - &record.birthDate, - &record.lastLoginTime, - ) - if err != nil { - b.Fatal(err) - } - - // These checks both ensure that the correct data was returned - // and provide a benchmark of accessing the returned values. - if record.id != 1 { - b.Fatalf("bad value for id: %v", record.id) - } - if record.userName != "johnsmith" { - b.Fatalf("bad value for userName: %v", record.userName) - } - if record.email != "johnsmith@example.com" { - b.Fatalf("bad value for email: %v", record.email) - } - if record.name != "John Smith" { - b.Fatalf("bad value for name: %v", record.name) - } - if record.sex != "male" { - b.Fatalf("bad value for sex: %v", record.sex) - } - if record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) { - b.Fatalf("bad value for birthDate: %v", record.birthDate) - } - if record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { - b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) - } - } -} - const benchmarkWriteTableCreateSQL = `drop table if exists t; create table t( diff --git a/conn.go b/conn.go index 1f87e9e7..b8e0b232 100644 --- a/conn.go +++ b/conn.go @@ -19,8 +19,8 @@ import ( // then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic. type ConnConfig struct { pgconn.Config - Logger Logger - LogLevel LogLevel + + Tracer QueryTracer // Original connection string that was parsed into config. connString string @@ -63,8 +63,11 @@ type Conn struct { preparedStatements map[string]*pgconn.StatementDescription statementCache stmtcache.Cache descriptionCache stmtcache.Cache - logger Logger - logLevel LogLevel + + queryTracer QueryTracer + batchTracer BatchTracer + copyFromTracer CopyFromTracer + prepareTracer PrepareTracer notifications []*pgconn.Notification @@ -94,9 +97,6 @@ func (ident Identifier) Sanitize() string { // ErrNoRows occurs when rows are expected but none are returned. var ErrNoRows = errors.New("no rows in result set") -// ErrInvalidLogLevel occurs on attempt to set an invalid log level. -var ErrInvalidLogLevel = errors.New("invalid log level") - var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") @@ -182,7 +182,6 @@ func ParseConfig(connString string) (*ConnConfig, error) { connConfig := &ConnConfig{ Config: *config, createdByParseConfig: true, - LogLevel: LogLevelInfo, StatementCacheCapacity: statementCacheCapacity, DescriptionCacheCapacity: descriptionCacheCapacity, DefaultQueryExecMode: defaultQueryExecMode, @@ -194,6 +193,13 @@ func ParseConfig(connString string) (*ConnConfig, error) { // connect connects to a database. connect takes ownership of config. The caller must not use or access it again. func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { + if connectTracer, ok := config.Tracer.(ConnectTracer); ok { + ctx = connectTracer.TraceConnectStart(ctx, TraceConnectStartData{ConnConfig: config}) + defer func() { + connectTracer.TraceConnectEnd(ctx, TraceConnectEndData{Conn: c, Err: err}) + }() + } + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { @@ -201,29 +207,28 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { } c = &Conn{ - config: config, - typeMap: pgtype.NewMap(), - logLevel: config.LogLevel, - logger: config.Logger, + config: config, + typeMap: pgtype.NewMap(), + queryTracer: config.Tracer, + } + + if t, ok := c.queryTracer.(BatchTracer); ok { + c.batchTracer = t + } + if t, ok := c.queryTracer.(CopyFromTracer); ok { + c.copyFromTracer = t + } + if t, ok := c.queryTracer.(PrepareTracer); ok { + c.prepareTracer = t } // Only install pgx notification system if no other callback handler is present. if config.Config.OnNotification == nil { config.Config.OnNotification = c.bufferNotifications - } else { - if c.shouldLog(LogLevelDebug) { - c.log(ctx, LogLevelDebug, "pgx notification handler disabled by application supplied OnNotification", map[string]any{"host": config.Config.Host}) - } } - if c.shouldLog(LogLevelInfo) { - c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]any{"host": config.Config.Host}) - } c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config) if err != nil { - if c.shouldLog(LogLevelError) { - c.log(ctx, LogLevelError, "connect failed", map[string]any{"err": err}) - } return nil, err } @@ -251,9 +256,6 @@ func (c *Conn) Close(ctx context.Context) error { } err := c.pgConn.Close(ctx) - if c.shouldLog(LogLevelInfo) { - c.log(ctx, LogLevelInfo, "closed connection", nil) - } return err } @@ -264,18 +266,23 @@ func (c *Conn) Close(ctx context.Context) error { // name and sql arguments. This allows a code path to Prepare and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) { + if c.prepareTracer != nil { + ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql}) + } + if name != "" { var ok bool if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql { + if c.prepareTracer != nil { + c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{AlreadyPrepared: true}) + } return sd, nil } } - if c.shouldLog(LogLevelError) { + if c.prepareTracer != nil { defer func() { - if err != nil { - c.log(ctx, LogLevelError, "Prepare failed", map[string]any{"err": err, "name": name, "sql": sql}) - } + c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{Err: err}) }() } @@ -337,21 +344,6 @@ func (c *Conn) die(err error) { c.pgConn.Close(ctx) } -func (c *Conn) shouldLog(lvl LogLevel) bool { - return c.logger != nil && c.logLevel >= lvl -} - -func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]any) { - if data == nil { - data = map[string]any{} - } - if c.pgConn != nil && c.pgConn.PID() != 0 { - data["pid"] = c.pgConn.PID() - } - - c.logger.Log(ctx, lvl, msg, data) -} - func quoteIdentifier(s string) string { return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` } @@ -379,24 +371,18 @@ func (c *Conn) Config() *ConnConfig { return c.config.Copy() } // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced // positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + if c.queryTracer != nil { + ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: arguments}) + } + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { return pgconn.CommandTag{}, err } - startTime := time.Now() - commandTag, err := c.exec(ctx, sql, arguments...) - if err != nil { - if c.shouldLog(LogLevelError) { - endTime := time.Now() - c.log(ctx, LogLevelError, "Exec", map[string]any{"sql": sql, "args": logQueryArgs(arguments), "err": err, "time": endTime.Sub(startTime)}) - } - return commandTag, err - } - if c.shouldLog(LogLevelInfo) { - endTime := time.Now() - c.log(ctx, LogLevelInfo, "Exec", map[string]any{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) + if c.queryTracer != nil { + c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{CommandTag: commandTag, Err: err}) } return commandTag, err @@ -537,7 +523,7 @@ func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows { r := &baseRows{} r.ctx = ctx - r.logger = c + r.queryTracer = c.queryTracer r.typeMap = c.typeMap r.startTime = time.Now() r.sql = sql @@ -628,7 +614,14 @@ type QueryRewriter interface { // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) { + if c.queryTracer != nil { + ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: args}) + } + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + if c.queryTracer != nil { + c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{Err: err}) + } return &baseRows{err: err, closed: true}, err } @@ -791,7 +784,17 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. -func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { +func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { + if c.batchTracer != nil { + ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b}) + defer func() { + err := br.(interface{ earlyError() error }).earlyError() + if err != nil { + c.batchTracer.TraceBatchEnd(ctx, c, TraceBatchEndData{Err: err}) + } + }() + } + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } diff --git a/conn_test.go b/conn_test.go index 2ead63ce..b84093f4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,7 +3,6 @@ package pgx_test import ( "bytes" "context" - "log" "os" "strings" "sync" @@ -743,79 +742,6 @@ func TestInsertTimestampArray(t *testing.T) { }) } -type testLog struct { - lvl pgx.LogLevel - msg string - data map[string]any -} - -type testLogger struct { - logs []testLog -} - -func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]any) { - data["ctxdata"] = ctx.Value("ctxdata") - l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) -} - -func TestLogPassesContext(t *testing.T) { - t.Parallel() - - l1 := &testLogger{} - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = l1 - - conn := mustConnect(t, config) - defer closeConn(t, conn) - - l1.logs = l1.logs[0:0] // Clear logs written when establishing connection - - ctx := context.WithValue(context.Background(), "ctxdata", "foo") - - if _, err := conn.Exec(ctx, ";"); err != nil { - t.Fatal(err) - } - - if len(l1.logs) != 1 { - t.Fatal("Expected logger to be called once, but it wasn't") - } - - if l1.logs[0].data["ctxdata"] != "foo" { - t.Fatal("Expected context data to be passed to logger, but it wasn't") - } -} - -func TestLoggerFunc(t *testing.T) { - t.Parallel() - - const testMsg = "foo" - - buf := bytes.Buffer{} - logger := log.New(&buf, "", 0) - - createAdapterFn := func(logger *log.Logger) pgx.LoggerFunc { - return func(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - logger.Printf("%s", testMsg) - } - } - - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = createAdapterFn(logger) - - conn := mustConnect(t, config) - defer closeConn(t, conn) - - buf.Reset() // Clear logs written when establishing connection - - if _, err := conn.Exec(context.TODO(), ";"); err != nil { - t.Fatal(err) - } - - if strings.TrimSpace(buf.String()) != testMsg { - t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String()) - } -} - func TestIdentifierSanitize(t *testing.T) { t.Parallel() diff --git a/copy_from.go b/copy_from.go index c5e9aae8..c8b98c57 100644 --- a/copy_from.go +++ b/copy_from.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "io" - "time" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgconn" @@ -89,6 +88,13 @@ type copyFrom struct { } func (ct *copyFrom) run(ctx context.Context) (int64, error) { + if ct.conn.copyFromTracer != nil { + ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{ + TableName: ct.tableName, + ColumnNames: ct.columnNames, + }) + } + quotedTableName := ct.tableName.Sanitize() cbuf := &bytes.Buffer{} for i, cn := range ct.columnNames { @@ -145,24 +151,19 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { w.Close() }() - startTime := time.Now() - commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) r.Close() <-doneChan - rowsAffected := commandTag.RowsAffected() - endTime := time.Now() - if err == nil { - if ct.conn.shouldLog(LogLevelInfo) { - ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]any{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected}) - } - } else if ct.conn.shouldLog(LogLevelError) { - ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]any{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime)}) + if ct.conn.copyFromTracer != nil { + ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{ + CommandTag: commandTag, + Err: err, + }) } - return rowsAffected, err + return commandTag.RowsAffected(), err } func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { diff --git a/doc.go b/doc.go index cfc2af85..b10ab1df 100644 --- a/doc.go +++ b/doc.go @@ -12,15 +12,7 @@ The primary way of establishing a connection is with `pgx.Connect`. The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with -`ConnectConfig`. - - config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL")) - if err != nil { - // ... - } - config.Logger = log15adapter.NewLogger(log.New("module", "pgx")) - - conn, err := pgx.ConnectConfig(context.Background(), config) +`ConnectConfig` to configure settings such as tracing that cannot be configured with a connection string. Connection Pool @@ -315,11 +307,11 @@ notification is received or the context is canceled. } -Logging +Tracing and Logging -pgx defines a simple logger interface. Connections optionally accept a logger that satisfies this interface. Set -LogLevel to control logging verbosity. Adapters for github.com/inconshreveable/log15, github.com/sirupsen/logrus, -go.uber.org/zap, github.com/rs/zerolog, and the testing log are provided in the log directory. +pgx supports tracing by setting ConnConfig.Tracer. + +In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer. Lower Level PostgreSQL Functionality diff --git a/helper_test.go b/helper_test.go index f091d23e..26e54621 100644 --- a/helper_test.go +++ b/helper_test.go @@ -98,8 +98,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName return } - assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) - assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) + assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName) assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName) assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName) diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index eed8dc4f..09f04141 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -162,7 +162,7 @@ func (f *Frontend) SendExecute(msg *Execute) { prevLen := len(f.wbuf) f.wbuf = msg.Encode(f.wbuf) if f.tracer != nil { - f.tracer.traceExecute('F', int32(len(f.wbuf)-prevLen), msg) + f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg) } } diff --git a/pgproto3/trace.go b/pgproto3/trace.go index d3edc4aa..c09f68d1 100644 --- a/pgproto3/trace.go +++ b/pgproto3/trace.go @@ -79,7 +79,7 @@ func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) { case *ErrorResponse: t.traceErrorResponse(sender, encodedLen, msg) case *Execute: - t.traceExecute(sender, encodedLen, msg) + t.TraceQueryute(sender, encodedLen, msg) case *Flush: t.traceFlush(sender, encodedLen, msg) case *FunctionCall: @@ -277,7 +277,7 @@ func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorRes t.finishTrace() } -func (t *tracer) traceExecute(sender byte, encodedLen int32, msg *Execute) { +func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) { t.beginTrace(sender, encodedLen, "Execute") fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) t.finishTrace() diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index eabc0e3c..16f4f553 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -160,8 +160,7 @@ func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, test return } - assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) - assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) + assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName) assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName) assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName) diff --git a/rows.go b/rows.go index c91f3aff..ca5533d9 100644 --- a/rows.go +++ b/rows.go @@ -105,11 +105,6 @@ func (r *connRow) Scan(dest ...any) (err error) { return rows.Err() } -type rowLog interface { - shouldLog(lvl LogLevel) bool - log(ctx context.Context, lvl LogLevel, msg string, data map[string]any) -} - // baseRows implements the Rows interface for Conn.Query. type baseRows struct { typeMap *pgtype.Map @@ -127,12 +122,13 @@ type baseRows struct { conn *Conn multiResultReader *pgconn.MultiResultReader - logger rowLog - ctx context.Context - startTime time.Time - sql string - args []any - rowCount int + queryTracer QueryTracer + batchTracer BatchTracer + ctx context.Context + startTime time.Time + sql string + args []any + rowCount int } func (rows *baseRows) FieldDescriptions() []pgproto3.FieldDescription { @@ -161,20 +157,6 @@ func (rows *baseRows) Close() { } } - if rows.logger != nil { - endTime := time.Now() - - if rows.err == nil { - if rows.logger.shouldLog(LogLevelInfo) { - rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]any{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) - } - } else { - if rows.logger.shouldLog(LogLevelError) { - rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]any{"err": rows.err, "sql": rows.sql, "time": endTime.Sub(rows.startTime), "args": logQueryArgs(rows.args)}) - } - } - } - if rows.err != nil && rows.conn != nil && rows.sql != "" { if stmtcache.IsStatementInvalid(rows.err) { if sc := rows.conn.statementCache; sc != nil { @@ -186,6 +168,12 @@ func (rows *baseRows) Close() { } } } + + if rows.batchTracer != nil { + rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err}) + } else if rows.queryTracer != nil { + rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err}) + } } func (rows *baseRows) CommandTag() pgconn.CommandTag { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index ee038add..ca2dccf3 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -17,6 +17,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/stdlib" + "github.com/jackc/pgx/v5/tracelog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -976,7 +977,7 @@ func TestScanJSONIntoJSONRawMessage(t *testing.T) { } type testLog struct { - lvl pgx.LogLevel + lvl tracelog.LogLevel msg string data map[string]any } @@ -985,7 +986,7 @@ type testLogger struct { logs []testLog } -func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]any) { +func (l *testLogger) Log(ctx context.Context, lvl tracelog.LogLevel, msg string, data map[string]any) { l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data}) } @@ -994,7 +995,7 @@ func TestRegisterConnConfig(t *testing.T) { require.NoError(t, err) logger := &testLogger{} - connConfig.Logger = logger + connConfig.Tracer = &tracelog.TraceLog{Logger: logger, LogLevel: tracelog.LogLevelInfo} // Issue 947: Register and unregister a ConnConfig and ensure that the // returned connection string is not reused. diff --git a/tracelog/tracelog.go b/tracelog/tracelog.go new file mode 100644 index 00000000..d51b9b95 --- /dev/null +++ b/tracelog/tracelog.go @@ -0,0 +1,295 @@ +// Package tracelog provides a tracer that acts as a traditional logger. +package tracelog + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "time" + + "github.com/jackc/pgx/v5" +) + +// LogLevel represents the pgx logging level. See LogLevel* constants for +// possible values. +type LogLevel int + +// The values for log levels are chosen such that the zero value means that no +// log level was specified. +const ( + LogLevelTrace = LogLevel(6) + LogLevelDebug = LogLevel(5) + LogLevelInfo = LogLevel(4) + LogLevelWarn = LogLevel(3) + LogLevelError = LogLevel(2) + LogLevelNone = LogLevel(1) +) + +func (ll LogLevel) String() string { + switch ll { + case LogLevelTrace: + return "trace" + case LogLevelDebug: + return "debug" + case LogLevelInfo: + return "info" + case LogLevelWarn: + return "warn" + case LogLevelError: + return "error" + case LogLevelNone: + return "none" + default: + return fmt.Sprintf("invalid level %d", ll) + } +} + +// Logger is the interface used to get log output from pgx. +type Logger interface { + // Log a message at the given level with data key/value pairs. data may be nil. + Log(ctx context.Context, level LogLevel, msg string, data map[string]any) +} + +// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface +type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) + +// Log delegates the logging request to the wrapped function +func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { + f(ctx, level, msg, data) +} + +// LogLevelFromString converts log level string to constant +// +// Valid levels: +// trace +// debug +// info +// warn +// error +// none +func LogLevelFromString(s string) (LogLevel, error) { + switch s { + case "trace": + return LogLevelTrace, nil + case "debug": + return LogLevelDebug, nil + case "info": + return LogLevelInfo, nil + case "warn": + return LogLevelWarn, nil + case "error": + return LogLevelError, nil + case "none": + return LogLevelNone, nil + default: + return 0, errors.New("invalid log level") + } +} + +func logQueryArgs(args []any) []any { + logArgs := make([]any, 0, len(args)) + + for _, a := range args { + switch v := a.(type) { + case []byte: + if len(v) < 64 { + a = hex.EncodeToString(v) + } else { + a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64) + } + case string: + if len(v) > 64 { + a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64) + } + } + logArgs = append(logArgs, a) + } + + return logArgs +} + +// TraceLog implements pgx.QueryTracer, pgx.BatchTracer, pgx.ConnectTracer, and pgx.CopyFromTracer. All fields are +// required. +type TraceLog struct { + Logger Logger + LogLevel LogLevel +} + +type ctxKey int + +const ( + _ ctxKey = iota + tracelogQueryCtxKey + tracelogBatchCtxKey + tracelogCopyFromCtxKey + tracelogConnectCtxKey +) + +type traceQueryData struct { + startTime time.Time + sql string + args []any +} + +func (tl *TraceLog) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + return context.WithValue(ctx, tracelogQueryCtxKey, &traceQueryData{ + startTime: time.Now(), + sql: data.SQL, + args: data.Args, + }) +} + +func (tl *TraceLog) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + queryData := ctx.Value(tracelogQueryCtxKey).(*traceQueryData) + + endTime := time.Now() + interval := endTime.Sub(queryData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "err": data.Err, "time": interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "time": interval, "commandTag": data.CommandTag.String()}) + } +} + +type traceBatchData struct { + startTime time.Time +} + +func (tl *TraceLog) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + return context.WithValue(ctx, tracelogBatchCtxKey, &traceBatchData{ + startTime: time.Now(), + }) +} + +func (tl *TraceLog) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "err": data.Err}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "commandTag": data.CommandTag.String()}) + } +} + +func (tl *TraceLog) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + queryData := ctx.Value(tracelogBatchCtxKey).(*traceBatchData) + + endTime := time.Now() + interval := endTime.Sub(queryData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "BatchClose", map[string]any{"err": data.Err, "time": interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "BatchClose", map[string]any{"time": interval}) + } +} + +type traceCopyFromData struct { + startTime time.Time + TableName pgx.Identifier + ColumnNames []string +} + +func (tl *TraceLog) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + return context.WithValue(ctx, tracelogCopyFromCtxKey, &traceCopyFromData{ + startTime: time.Now(), + TableName: data.TableName, + ColumnNames: data.ColumnNames, + }) +} + +func (tl *TraceLog) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + copyFromData := ctx.Value(tracelogCopyFromCtxKey).(*traceCopyFromData) + + endTime := time.Now() + interval := endTime.Sub(copyFromData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval, "rowCount": data.CommandTag.RowsAffected()}) + } +} + +type traceConnectData struct { + startTime time.Time + connConfig *pgx.ConnConfig +} + +func (tl *TraceLog) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + return context.WithValue(ctx, tracelogConnectCtxKey, &traceConnectData{ + startTime: time.Now(), + connConfig: data.ConnConfig, + }) +} + +func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { + connectData := ctx.Value(tracelogConnectCtxKey).(*traceConnectData) + + endTime := time.Now() + interval := endTime.Sub(connectData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.Logger.Log(ctx, LogLevelError, "Connect", map[string]any{ + "host": connectData.connConfig.Host, + "port": connectData.connConfig.Port, + "database": connectData.connConfig.Database, + "time": interval, + "err": data.Err, + }) + } + return + } + + if data.Conn != nil { + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, data.Conn, LogLevelInfo, "Connect", map[string]any{ + "host": connectData.connConfig.Host, + "port": connectData.connConfig.Port, + "database": connectData.connConfig.Database, + "time": interval, + }) + } + } +} + +func (tl *TraceLog) shouldLog(lvl LogLevel) bool { + return tl.LogLevel >= lvl +} + +func (tl *TraceLog) log(ctx context.Context, conn *pgx.Conn, lvl LogLevel, msg string, data map[string]any) { + if data == nil { + data = map[string]any{} + } + + pgConn := conn.PgConn() + if pgConn != nil { + pid := pgConn.PID() + if pid != 0 { + data["pid"] = pid + } + } + + tl.Logger.Log(ctx, lvl, msg, data) +} diff --git a/tracelog/tracelog_test.go b/tracelog/tracelog_test.go new file mode 100644 index 00000000..ed0f8eab --- /dev/null +++ b/tracelog/tracelog_test.go @@ -0,0 +1,301 @@ +package tracelog_test + +import ( + "bytes" + "context" + "log" + "os" + "strings" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/jackc/pgx/v5/tracelog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var defaultConnTestRunner pgxtest.ConnTestRunner + +func init() { + defaultConnTestRunner = pgxtest.DefaultConnTestRunner() + defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return config + } +} + +type testLog struct { + lvl tracelog.LogLevel + msg string + data map[string]any +} + +type testLogger struct { + logs []testLog +} + +func (l *testLogger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) { + data["ctxdata"] = ctx.Value("ctxdata") + l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) +} + +func TestContextGetsPassedToLogMethod(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + + ctx = context.WithValue(context.Background(), "ctxdata", "foo") + _, err := conn.Exec(ctx, `;`) + require.NoError(t, err) + require.Len(t, logger.logs, 1) + require.Equal(t, "foo", logger.logs[0].data["ctxdata"]) + }) +} + +func TestLoggerFunc(t *testing.T) { + t.Parallel() + + const testMsg = "foo" + + buf := bytes.Buffer{} + logger := log.New(&buf, "", 0) + + createAdapterFn := func(logger *log.Logger) tracelog.LoggerFunc { + return func(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]interface{}) { + logger.Printf("%s", testMsg) + } + } + + config := defaultConnTestRunner.CreateConfig(context.Background(), t) + config.Tracer = &tracelog.TraceLog{ + Logger: createAdapterFn(logger), + LogLevel: tracelog.LogLevelTrace, + } + + conn, err := pgx.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer conn.Close(context.Background()) + + buf.Reset() // Clear logs written when establishing connection + + if _, err := conn.Exec(context.TODO(), ";"); err != nil { + t.Fatal(err) + } + + if strings.TrimSpace(buf.String()) != testMsg { + t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String()) + } +} + +func TestLogQuery(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + + _, err := conn.Exec(ctx, `select $1::text`, "testing") + require.NoError(t, err) + require.Len(t, logger.logs, 1) + require.Equal(t, "Query", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) + + _, err = conn.Exec(ctx, `foo`, "testing") + require.Error(t, err) + require.Len(t, logger.logs, 2) + require.Equal(t, "Query", logger.logs[1].msg) + require.Equal(t, tracelog.LogLevelError, logger.logs[1].lvl) + require.Equal(t, err, logger.logs[1].data["err"]) + }) +} + +func TestLogCopyFrom(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`) + require.NoError(t, err) + + logger.logs = logger.logs[0:0] + + inputRows := [][]any{ + {int32(1)}, + {nil}, + } + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + require.Len(t, logger.logs, 1) + require.Equal(t, "CopyFrom", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) + + logger.logs = logger.logs[0:0] + + inputRows = [][]any{ + {"not an integer"}, + {nil}, + } + + copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.Error(t, err) + require.EqualValues(t, 0, copyCount) + require.Len(t, logger.logs, 1) + require.Equal(t, "CopyFrom", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelError, logger.logs[0].lvl) + }) +} + +func TestLogConnect(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + config := defaultConnTestRunner.CreateConfig(context.Background(), t) + config.Tracer = tracer + + conn1, err := pgx.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer conn1.Close(context.Background()) + require.Len(t, logger.logs, 1) + require.Equal(t, "Connect", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) + + logger.logs = logger.logs[0:0] + + config, err = pgx.ParseConfig("host=/invalid") + require.NoError(t, err) + config.Tracer = tracer + + conn2, err := pgx.ConnectConfig(context.Background(), config) + require.Nil(t, conn2) + require.Error(t, err) + require.Len(t, logger.logs, 1) + require.Equal(t, "Connect", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelError, logger.logs[0].lvl) +} + +func TestLogBatchStatementsOnExec(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + + batch := &pgx.Batch{} + batch.Queue("create table foo (id bigint)") + batch.Queue("drop table foo") + + br := conn.SendBatch(context.Background(), batch) + + _, err := br.Exec() + require.NoError(t, err) + + _, err = br.Exec() + require.NoError(t, err) + + err = br.Close() + require.NoError(t, err) + + require.Len(t, logger.logs, 3) + assert.Equal(t, "BatchQuery", logger.logs[0].msg) + assert.Equal(t, "create table foo (id bigint)", logger.logs[0].data["sql"]) + assert.Equal(t, "BatchQuery", logger.logs[1].msg) + assert.Equal(t, "drop table foo", logger.logs[1].data["sql"]) + assert.Equal(t, "BatchClose", logger.logs[2].msg) + + }) +} + +func TestLogBatchStatementsOnBatchResultClose(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + + batch := &pgx.Batch{} + batch.Queue("select generate_series(1,$1)", 100) + batch.Queue("select 1 = 1;") + + br := conn.SendBatch(context.Background(), batch) + err := br.Close() + require.NoError(t, err) + + require.Len(t, logger.logs, 3) + assert.Equal(t, "BatchQuery", logger.logs[0].msg) + assert.Equal(t, "select generate_series(1,$1)", logger.logs[0].data["sql"]) + assert.Equal(t, "BatchQuery", logger.logs[1].msg) + assert.Equal(t, "select 1 = 1;", logger.logs[1].data["sql"]) + assert.Equal(t, "BatchClose", logger.logs[2].msg) + }) +} diff --git a/tracer.go b/tracer.go new file mode 100644 index 00000000..58ca99f7 --- /dev/null +++ b/tracer.go @@ -0,0 +1,107 @@ +package pgx + +import ( + "context" + + "github.com/jackc/pgx/v5/pgconn" +) + +// QueryTracer traces Query, QueryRow, and Exec. +type QueryTracer interface { + // TraceQueryStart is called at the beginning of Query, QueryRow, and Exec calls. The returned context is used for the + // rest of the call and will be passed to TraceQueryEnd. + TraceQueryStart(ctx context.Context, conn *Conn, data TraceQueryStartData) context.Context + + TraceQueryEnd(ctx context.Context, conn *Conn, data TraceQueryEndData) +} + +type TraceQueryStartData struct { + SQL string + Args []any +} + +type TraceQueryEndData struct { + CommandTag pgconn.CommandTag + Err error +} + +// BatchTracer traces SendBatch. +type BatchTracer interface { + // TraceBatchStart is called at the beginning of SendBatch calls. The returned context is used for the + // rest of the call and will be passed to TraceBatchQuery and TraceBatchEnd. + TraceBatchStart(ctx context.Context, conn *Conn, data TraceBatchStartData) context.Context + + TraceBatchQuery(ctx context.Context, conn *Conn, data TraceBatchQueryData) + TraceBatchEnd(ctx context.Context, conn *Conn, data TraceBatchEndData) +} + +type TraceBatchStartData struct { + Batch *Batch +} + +type TraceBatchQueryData struct { + SQL string + Args []any + CommandTag pgconn.CommandTag + Err error +} + +type TraceBatchEndData struct { + Err error +} + +// CopyFromTracer traces CopyFrom. +type CopyFromTracer interface { + // TraceCopyFromStart is called at the beginning of CopyFrom calls. The returned context is used for the + // rest of the call and will be passed to TraceCopyFromEnd. + TraceCopyFromStart(ctx context.Context, conn *Conn, data TraceCopyFromStartData) context.Context + + TraceCopyFromEnd(ctx context.Context, conn *Conn, data TraceCopyFromEndData) +} + +type TraceCopyFromStartData struct { + TableName Identifier + ColumnNames []string +} + +type TraceCopyFromEndData struct { + CommandTag pgconn.CommandTag + Err error +} + +// PrepareTracer traces Prepare. +type PrepareTracer interface { + // TracePrepareStart is called at the beginning of Prepare calls. The returned context is used for the + // rest of the call and will be passed to TracePrepareEnd. + TracePrepareStart(ctx context.Context, conn *Conn, data TracePrepareStartData) context.Context + + TracePrepareEnd(ctx context.Context, conn *Conn, data TracePrepareEndData) +} + +type TracePrepareStartData struct { + Name string + SQL string +} + +type TracePrepareEndData struct { + AlreadyPrepared bool + Err error +} + +// ConnectTracer traces Connect and ConnectConfig. +type ConnectTracer interface { + // TraceConnectStart is called at the beginning of Connect and ConnectConfig calls. The returned context is used for + // the rest of the call and will be passed to TraceConnectEnd. + TraceConnectStart(ctx context.Context, data TraceConnectStartData) context.Context + + TraceConnectEnd(ctx context.Context, data TraceConnectEndData) +} + +type TraceConnectStartData struct { + ConnConfig *ConnConfig +} + +type TraceConnectEndData struct { + Conn *Conn + Err error +} diff --git a/tracer_test.go b/tracer_test.go new file mode 100644 index 00000000..86375b34 --- /dev/null +++ b/tracer_test.go @@ -0,0 +1,538 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +type testTracer struct { + traceQueryStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context + traceQueryEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) + traceBatchStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context + traceBatchQuery func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) + traceBatchEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) + traceCopyFromStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context + traceCopyFromEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) + tracePrepareStart func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context + tracePrepareEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) + traceConnectStart func(ctx context.Context, data pgx.TraceConnectStartData) context.Context + traceConnectEnd func(ctx context.Context, data pgx.TraceConnectEndData) +} + +func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + if tt.traceQueryStart != nil { + return tt.traceQueryStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + if tt.traceQueryEnd != nil { + tt.traceQueryEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + if tt.traceBatchStart != nil { + return tt.traceBatchStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + if tt.traceBatchQuery != nil { + tt.traceBatchQuery(ctx, conn, data) + } +} + +func (tt *testTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + if tt.traceBatchEnd != nil { + tt.traceBatchEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + if tt.traceCopyFromStart != nil { + return tt.traceCopyFromStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + if tt.traceCopyFromEnd != nil { + tt.traceCopyFromEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + if tt.tracePrepareStart != nil { + return tt.tracePrepareStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + if tt.tracePrepareEnd != nil { + tt.tracePrepareEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + if tt.traceConnectStart != nil { + return tt.traceConnectStart(ctx, data) + } + return ctx +} + +func (tt *testTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { + if tt.traceConnectEnd != nil { + tt.traceConnectEnd(ctx, data) + } +} + +func TestTraceExec(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceQueryStartCalled := false + tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + traceQueryStartCalled = true + require.Equal(t, `select $1::text`, data.SQL) + require.Len(t, data.Args, 1) + require.Equal(t, `testing`, data.Args[0]) + return context.WithValue(ctx, "fromTraceQueryStart", "foo") + } + + traceQueryEndCalled := false + tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + traceQueryEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceQueryStart")) + require.Equal(t, `SELECT 1`, data.CommandTag.String()) + require.NoError(t, data.Err) + } + + _, err := conn.Exec(ctx, `select $1::text`, "testing") + require.NoError(t, err) + require.True(t, traceQueryStartCalled) + require.True(t, traceQueryEndCalled) + }) +} + +func TestTraceQuery(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceQueryStartCalled := false + tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + traceQueryStartCalled = true + require.Equal(t, `select $1::text`, data.SQL) + require.Len(t, data.Args, 1) + require.Equal(t, `testing`, data.Args[0]) + return context.WithValue(ctx, "fromTraceQueryStart", "foo") + } + + traceQueryEndCalled := false + tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + traceQueryEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceQueryStart")) + require.Equal(t, `SELECT 1`, data.CommandTag.String()) + require.NoError(t, data.Err) + } + + var s string + err := conn.QueryRow(ctx, `select $1::text`, "testing").Scan(&s) + require.NoError(t, err) + require.Equal(t, "testing", s) + require.True(t, traceQueryStartCalled) + require.True(t, traceQueryEndCalled) + }) +} + +func TestTraceBatchNormal(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 2, data.Batch.Len()) + return context.WithValue(ctx, "fromTraceBatchStart", "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.NoError(t, data.Err) + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.NoError(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + + var n int32 + err := br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + require.EqualValues(t, 1, traceBatchQueryCalledCount) + + err = br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 2, n) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + + err = br.Close() + require.NoError(t, err) + + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceBatchClose(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 2, data.Batch.Len()) + return context.WithValue(ctx, "fromTraceBatchStart", "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.NoError(t, data.Err) + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.NoError(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + err := br.Close() + require.NoError(t, err) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceBatchErrorWhileReadingResults(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 3, data.Batch.Len()) + return context.WithValue(ctx, "fromTraceBatchStart", "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + if traceBatchQueryCalledCount == 2 { + require.Error(t, data.Err) + } else { + require.NoError(t, data.Err) + } + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.Error(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2/n-2 from generate_series(0,10) n`) + batch.Queue(`select 3`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + + commandTag, err := br.Exec() + require.NoError(t, err) + require.Equal(t, "SELECT 1", commandTag.String()) + + commandTag, err = br.Exec() + require.Error(t, err) + require.Equal(t, "", commandTag.String()) + + commandTag, err = br.Exec() + require.Error(t, err) + require.Equal(t, "", commandTag.String()) + + err = br.Close() + require.Error(t, err) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 3, data.Batch.Len()) + return context.WithValue(ctx, "fromTraceBatchStart", "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + if traceBatchQueryCalledCount == 2 { + require.Error(t, data.Err) + } else { + require.NoError(t, data.Err) + } + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.Error(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2/n-2 from generate_series(0,10) n`) + batch.Queue(`select 3`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + err := br.Close() + require.Error(t, err) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceCopyFrom(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceCopyFromStartCalled := false + tracer.traceCopyFromStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + traceCopyFromStartCalled = true + require.Equal(t, pgx.Identifier{"foo"}, data.TableName) + require.Equal(t, []string{"a"}, data.ColumnNames) + return context.WithValue(ctx, "fromTraceCopyFromStart", "foo") + } + + traceCopyFromEndCalled := false + tracer.traceCopyFromEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + traceCopyFromEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceCopyFromStart")) + require.Equal(t, `COPY 2`, data.CommandTag.String()) + require.NoError(t, data.Err) + } + + _, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`) + require.NoError(t, err) + + inputRows := [][]any{ + {int32(1)}, + {nil}, + } + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + require.True(t, traceCopyFromStartCalled) + require.True(t, traceCopyFromEndCalled) + }) +} + +func TestTracePrepare(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tracePrepareStartCalled := false + tracer.tracePrepareStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + tracePrepareStartCalled = true + require.Equal(t, `ps`, data.Name) + require.Equal(t, `select $1::text`, data.SQL) + return context.WithValue(ctx, "fromTracePrepareStart", "foo") + } + + tracePrepareEndCalled := false + tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + tracePrepareEndCalled = true + require.False(t, data.AlreadyPrepared) + require.NoError(t, data.Err) + } + + _, err := conn.Prepare(ctx, "ps", `select $1::text`) + require.NoError(t, err) + require.True(t, tracePrepareStartCalled) + require.True(t, tracePrepareEndCalled) + + tracePrepareStartCalled = false + tracePrepareEndCalled = false + tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + tracePrepareEndCalled = true + require.True(t, data.AlreadyPrepared) + require.NoError(t, data.Err) + } + + _, err = conn.Prepare(ctx, "ps", `select $1::text`) + require.NoError(t, err) + require.True(t, tracePrepareStartCalled) + require.True(t, tracePrepareEndCalled) + }) +} + +func TestTraceConnect(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + config := defaultConnTestRunner.CreateConfig(context.Background(), t) + config.Tracer = tracer + + traceConnectStartCalled := false + tracer.traceConnectStart = func(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + traceConnectStartCalled = true + require.NotNil(t, data.ConnConfig) + return context.WithValue(ctx, "fromTraceConnectStart", "foo") + } + + traceConnectEndCalled := false + tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) { + traceConnectEndCalled = true + require.NotNil(t, data.Conn) + require.NoError(t, data.Err) + } + + conn1, err := pgx.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer conn1.Close(context.Background()) + require.True(t, traceConnectStartCalled) + require.True(t, traceConnectEndCalled) + + config, err = pgx.ParseConfig("host=/invalid") + require.NoError(t, err) + config.Tracer = tracer + + traceConnectStartCalled = false + traceConnectEndCalled = false + tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) { + traceConnectEndCalled = true + require.Nil(t, data.Conn) + require.Error(t, data.Err) + } + + conn2, err := pgx.ConnectConfig(context.Background(), config) + require.Nil(t, conn2) + require.Error(t, err) + require.True(t, traceConnectStartCalled) + require.True(t, traceConnectEndCalled) +} From 29254180ca93e3f137dcebcfb14340c6d3d02695 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Jul 2022 17:46:47 -0500 Subject: [PATCH 1098/1158] Add callback functions to queued queries Improve batch query ergonomics by allowing the code to handle the results of a query to be right next to the query. --- CHANGELOG.md | 8 +++ batch.go | 116 +++++++++++++++++++++++++++++--------- batch_test.go | 150 ++++++++++++++++++++++++++++++++++++++++++++++++++ conn.go | 36 ++++++------ 4 files changed, 267 insertions(+), 43 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a402c40..83c32783 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -153,6 +153,14 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent Rather than every type that implemented `Begin` or `BeginTx` methods also needing to implement `BeginFunc` and `BeginTxFunc` these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`. +## Improved Batch Query Ergonomics + +Previously, the code for building a batch went in one place before the call to `SendBatch`, and the code for reading the +results went in one place after the call to `SendBatch`. This could make it difficult to match up the query and the code +to handle the results. Now `Queue` returns a `QueuedQuery` which has methods `Query`, `QueryRow`, and `Exec` which can +be used to register a callback function that will handle the result. Callback functions are called automatically when +`BatchResults.Close` is called. + ## SendBatch Uses Pipeline Mode When Appropriate Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1 diff --git a/batch.go b/batch.go index a6951096..af62039f 100644 --- a/batch.go +++ b/batch.go @@ -8,44 +8,99 @@ import ( "github.com/jackc/pgx/v5/pgconn" ) -type batchItem struct { +// QueuedQuery is a query that has been queued for execution via a Batch. +type QueuedQuery struct { query string arguments []any + fn batchItemFunc sd *pgconn.StatementDescription } +type batchItemFunc func(br BatchResults) error + +// Query sets fn to be called when the response to qq is received. +func (qq *QueuedQuery) Query(fn func(rows Rows) error) { + qq.fn = func(br BatchResults) error { + rows, err := br.Query() + if err != nil { + return err + } + defer rows.Close() + + err = fn(rows) + if err != nil { + return err + } + rows.Close() + + return rows.Err() + } +} + +// Query sets fn to be called when the response to qq is received. +func (qq *QueuedQuery) QueryRow(fn func(row Row) error) { + qq.fn = func(br BatchResults) error { + row := br.QueryRow() + return fn(row) + } +} + +// Exec sets fn to be called when the response to qq is received. +func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) { + qq.fn = func(br BatchResults) error { + ct, err := br.Exec() + if err != nil { + return err + } + + return fn(ct) + } +} + // Batch queries are a way of bundling multiple queries together to avoid // unnecessary network round trips. A Batch must only be sent once. type Batch struct { - items []*batchItem + queuedQueries []*QueuedQuery } // Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. -func (b *Batch) Queue(query string, arguments ...any) { - b.items = append(b.items, &batchItem{ +func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery { + qq := &QueuedQuery{ query: query, arguments: arguments, - }) + } + b.queuedQueries = append(b.queuedQueries, qq) + return qq } // Len returns number of queries that have been queued so far. func (b *Batch) Len() int { - return len(b.items) + return len(b.queuedQueries) } type BatchResults interface { - // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. + // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer + // calling Exec on the QueuedQuery. Exec() (pgconn.CommandTag, error) - // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. + // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer + // calling Query on the QueuedQuery. Query() (Rows, error) // QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow. + // Prefer calling QueryRow on the QueuedQuery. QueryRow() Row - // Close closes the batch operation. This must be called before the underlying connection can be used again. Any error - // that occurred during a batch operation may have made it impossible to resyncronize the connection with the server. - // In this case the underlying connection will have been closed. Close is safe to call multiple times. + // Close closes the batch operation. All unread results are read and any callback functions registered with + // QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an + // error or the batch encounters an error subsequent callback functions will not be called. + // + // Close must be called before the underlying connection can be used again. Any error that occurred during a batch + // operation may have made it impossible to resyncronize the connection with the server. In this case the underlying + // connection will have been closed. + // + // Close is safe to call multiple times. If it returns an error subsequent calls will return the same error. Callback + // functions will not be rerun. Close() error } @@ -55,7 +110,7 @@ type batchResults struct { mrr *pgconn.MultiResultReader err error b *Batch - ix int + qqIdx int closed bool endTraced bool } @@ -169,9 +224,14 @@ func (br *batchResults) Close() error { return nil } - // consume and log any queries that haven't yet been logged by Exec or Query - if br.conn.batchTracer != nil { - for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) { + // Read and run fn for all remaining items + for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { + if br.b.queuedQueries[br.qqIdx].fn != nil { + err := br.b.queuedQueries[br.qqIdx].fn(br) + if err != nil && br.err == nil { + br.err = err + } + } else { br.Exec() } } @@ -191,12 +251,12 @@ func (br *batchResults) earlyError() error { } func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { - if br.b != nil && br.ix < len(br.b.items) { - bi := br.b.items[br.ix] + if br.b != nil && br.qqIdx < len(br.b.queuedQueries) { + bi := br.b.queuedQueries[br.qqIdx] query = bi.query args = bi.arguments ok = true - br.ix++ + br.qqIdx++ } return } @@ -208,7 +268,7 @@ type pipelineBatchResults struct { lastRows *baseRows err error b *Batch - ix int + qqIdx int closed bool endTraced bool } @@ -337,12 +397,18 @@ func (br *pipelineBatchResults) Close() error { return nil } - // consume and log any queries that haven't yet been logged by Exec or Query - if br.conn.batchTracer != nil { - for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) { + // Read and run fn for all remaining items + for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { + if br.b.queuedQueries[br.qqIdx].fn != nil { + err := br.b.queuedQueries[br.qqIdx].fn(br) + if err != nil && br.err == nil { + br.err = err + } + } else { br.Exec() } } + br.closed = true err := br.pipeline.Close() @@ -358,12 +424,12 @@ func (br *pipelineBatchResults) earlyError() error { } func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) { - if br.b != nil && br.ix < len(br.b.items) { - bi := br.b.items[br.ix] + if br.b != nil && br.qqIdx < len(br.b.queuedQueries) { + bi := br.b.queuedQueries[br.qqIdx] query = bi.query args = bi.arguments ok = true - br.ix++ + br.qqIdx++ } return } diff --git a/batch_test.go b/batch_test.go index 156e8f8f..2ade0d4a 100644 --- a/batch_test.go +++ b/batch_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "context" "errors" + "fmt" "os" "testing" @@ -148,6 +149,99 @@ func TestConnSendBatch(t *testing.T) { }) } +func TestConnSendBatchQueuedQuery(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test") + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null + );` + mustExec(t, conn, sql) + + batch := &pgx.Batch{} + + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1).Exec(func(ct pgconn.CommandTag) error { + assert.EqualValues(t, 1, ct.RowsAffected()) + return nil + }) + + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2).Exec(func(ct pgconn.CommandTag) error { + assert.EqualValues(t, 1, ct.RowsAffected()) + return nil + }) + + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3).Exec(func(ct pgconn.CommandTag) error { + assert.EqualValues(t, 1, ct.RowsAffected()) + return nil + }) + + selectFromLedgerExpectedRows := []struct { + id int32 + description string + amount int32 + }{ + {1, "q1", 1}, + {2, "q2", 2}, + {3, "q3", 3}, + } + + batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error { + rowCount := 0 + var id int32 + var description string + var amount int32 + _, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error { + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount) + rowCount++ + + return nil + }) + assert.NoError(t, err) + return nil + }) + + batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error { + rowCount := 0 + var id int32 + var description string + var amount int32 + _, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error { + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount) + rowCount++ + + return nil + }) + assert.NoError(t, err) + return nil + }) + + batch.Queue("select * from ledger where false").QueryRow(func(row pgx.Row) error { + err := row.Scan(nil, nil, nil) + assert.ErrorIs(t, err, pgx.ErrNoRows) + return nil + }) + + batch.Queue("select sum(amount) from ledger").QueryRow(func(row pgx.Row) error { + var sumAmount int32 + err := row.Scan(&sumAmount) + assert.NoError(t, err) + assert.EqualValues(t, 6, sumAmount) + return nil + }) + + err := conn.SendBatch(context.Background(), batch).Close() + assert.NoError(t, err) + }) +} + func TestConnSendBatchMany(t *testing.T) { t.Parallel() @@ -773,3 +867,59 @@ func TestSendBatchSimpleProtocol(t *testing.T) { assert.EqualValues(t, 3, values[0]) assert.False(t, rows.Next()) } + +func ExampleConn_SendBatch() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + batch := &pgx.Batch{} + batch.Queue("select 1 + 1").QueryRow(func(row pgx.Row) error { + var n int32 + err := row.Scan(&n) + if err != nil { + return err + } + + fmt.Println(n) + + return err + }) + + batch.Queue("select 1 + 2").QueryRow(func(row pgx.Row) error { + var n int32 + err := row.Scan(&n) + if err != nil { + return err + } + + fmt.Println(n) + + return err + }) + + batch.Queue("select 2 + 3").QueryRow(func(row pgx.Row) error { + var n int32 + err := row.Scan(&n) + if err != nil { + return err + } + + fmt.Println(n) + + return err + }) + + err = conn.SendBatch(context.Background(), batch).Close() + if err != nil { + fmt.Printf("SendBatch error: %v", err) + return + } + + // Output: + // 2 + // 3 + // 5 +} diff --git a/conn.go b/conn.go index b8e0b232..1a43a3ca 100644 --- a/conn.go +++ b/conn.go @@ -801,7 +801,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { mode := c.config.DefaultQueryExecMode - for _, bi := range b.items { + for _, bi := range b.queuedQueries { var queryRewriter QueryRewriter sql := bi.query arguments := bi.arguments @@ -830,7 +830,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { } // All other modes use extended protocol and thus can use prepared statements. - for _, bi := range b.items { + for _, bi := range b.queuedQueries { if sd, ok := c.preparedStatements[bi.query]; ok { bi.sd = sd } @@ -852,7 +852,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults { var sb strings.Builder - for i, bi := range b.items { + for i, bi := range b.queuedQueries { if i > 0 { sb.WriteByte(';') } @@ -864,18 +864,18 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batc } mrr := c.pgConn.Exec(ctx, sb.String()) return &batchResults{ - ctx: ctx, - conn: c, - mrr: mrr, - b: b, - ix: 0, + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + qqIdx: 0, } } func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults { batch := &pgconn.Batch{} - for _, bi := range b.items { + for _, bi := range b.queuedQueries { sd := bi.sd if sd != nil { err := c.eqb.Build(c.typeMap, sd, bi.arguments) @@ -898,11 +898,11 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR mrr := c.pgConn.ExecBatch(ctx, batch) return &batchResults{ - ctx: ctx, - conn: c, - mrr: mrr, - b: b, - ix: 0, + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + qqIdx: 0, } } @@ -914,7 +914,7 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueriesIdxMap := make(map[string]int) - for _, bi := range b.items { + for _, bi := range b.queuedQueries { if bi.sd == nil { sd := c.statementCache.Get(bi.query) if sd != nil { @@ -946,7 +946,7 @@ func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueriesIdxMap := make(map[string]int) - for _, bi := range b.items { + for _, bi := range b.queuedQueries { if bi.sd == nil { sd := c.descriptionCache.Get(bi.query) if sd != nil { @@ -973,7 +973,7 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueriesIdxMap := make(map[string]int) - for _, bi := range b.items { + for _, bi := range b.queuedQueries { if bi.sd == nil { if idx, present := distinctNewQueriesIdxMap[bi.query]; present { bi.sd = distinctNewQueries[idx] @@ -1045,7 +1045,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d } // Queue the queries. - for _, bi := range b.items { + for _, bi := range b.queuedQueries { err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments) if err != nil { return &pipelineBatchResults{ctx: ctx, conn: c, err: err} From b6f5cbd15e1a11f13074f1996041fe310961a6d7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Jul 2022 17:56:24 -0500 Subject: [PATCH 1099/1158] Add Conn to Rows interface https://github.com/jackc/pgx/issues/1191 --- pgxpool/rows.go | 5 +++++ rows.go | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/pgxpool/rows.go b/pgxpool/rows.go index aeb65179..0c0a7382 100644 --- a/pgxpool/rows.go +++ b/pgxpool/rows.go @@ -18,6 +18,7 @@ func (errRows) Next() bool { return false } func (e errRows) Scan(dest ...any) error { return e.err } func (e errRows) Values() ([]any, error) { return nil, e.err } func (e errRows) RawValues() [][]byte { return nil } +func (e errRows) Conn() *pgx.Conn { return nil } type errRow struct { err error @@ -86,6 +87,10 @@ func (rows *poolRows) RawValues() [][]byte { return rows.r.RawValues() } +func (rows *poolRows) Conn() *pgx.Conn { + return rows.r.Conn() +} + type poolRow struct { r pgx.Row c *Conn diff --git a/rows.go b/rows.go index ca5533d9..c4fd283c 100644 --- a/rows.go +++ b/rows.go @@ -55,6 +55,10 @@ type Rows interface { // RawValues returns the unparsed bytes of the row values. The returned data is only valid until the next Next // call or the Rows is closed. RawValues() [][]byte + + // Conn returns the underlying *Conn on which the query was executed. This may return nil if Rows did not come from a + // *Conn (e.g. if it was created by RowsFromResultReader) + Conn() *Conn } // Row is a convenience wrapper over Rows that is returned by QueryRow. @@ -310,6 +314,10 @@ func (rows *baseRows) RawValues() [][]byte { return rows.values } +func (rows *baseRows) Conn() *Conn { + return rows.conn +} + type ScanArgError struct { ColumnIndex int Err error From a5b4f888c221566c3034bda659afa460701a256e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Jul 2022 18:16:19 -0500 Subject: [PATCH 1100/1158] Fix flickering test on CI Ensure the conn reads everything expected before closing. --- internal/nbconn/nbconn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go index c4a8e65b..8b672e4e 100644 --- a/internal/nbconn/nbconn_test.go +++ b/internal/nbconn/nbconn_test.go @@ -294,7 +294,7 @@ func TestInternalNonBlockingWrite(t *testing.T) { }() readBuf := make([]byte, deadlockSize) - _, err = conn.Read(readBuf) + _, err = io.ReadFull(conn, readBuf) require.NoError(t, err) err = conn.Close() From 32ec44f726e1376620ad746a226cd691c4b8d528 Mon Sep 17 00:00:00 2001 From: Eric McCormack Date: Tue, 21 Jun 2022 11:19:46 -0400 Subject: [PATCH 1101/1158] Add support for SslPassword --- .idea/pgconn.iml | 9 ++++++++ config.go | 56 +++++++++++++++++++++++++++++++++++++++++++----- pgconn.go | 2 ++ pgconn_test.go | 55 ++++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 116 insertions(+), 6 deletions(-) create mode 100644 .idea/pgconn.iml diff --git a/.idea/pgconn.iml b/.idea/pgconn.iml new file mode 100644 index 00000000..5e764c4f --- /dev/null +++ b/.idea/pgconn.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/config.go b/config.go index 8fd7efbf..12a48288 100644 --- a/config.go +++ b/config.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/pem" "errors" "fmt" "io" @@ -60,6 +61,9 @@ type Config struct { // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler + // SslPasswordCallback is a callback function to handle Auth callback for SSL Password + SslPasswordCallback SslPasswordCallbackHandler + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -132,6 +136,11 @@ func NetworkAddress(host string, port uint16) (network, address string) { return network, address } +// ParseConfig builds a *Config when sslpasswordcallback function is not provided +func ParseConfig(connString string) (*Config, error) { + return ParseConfigWithSslPasswordCallback(connString, nil) +} + // ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same // defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches // the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See @@ -171,6 +180,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGSSLCERT // PGSSLKEY // PGSSLROOTCERT +// PGSSLPASSWORD // PGAPPNAME // PGCONNECT_TIMEOUT // PGTARGETSESSIONATTRS @@ -194,6 +204,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually // changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting // TLCConfig. +// sslPasswordCallback function provide a callback function for sslpassword // // Other known differences with libpq: // @@ -207,7 +218,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // servicefile // libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a // part of the connection string. -func ParseConfig(connString string) (*Config, error) { +func ParseConfigWithSslPasswordCallback(connString string, sslPasswordCallback SslPasswordCallbackHandler) (*Config, error) { defaultSettings := defaultSettings() envSettings := parseEnvSettings() @@ -278,6 +289,7 @@ func ParseConfig(connString string) (*Config, error) { "sslkey": {}, "sslcert": {}, "sslrootcert": {}, + "sslpassword": {}, "krbspn": {}, "krbsrvname": {}, "target_session_attrs": {}, @@ -326,7 +338,7 @@ func ParseConfig(connString string) (*Config, error) { tlsConfigs = append(tlsConfigs, nil) } else { var err error - tlsConfigs, err = configTLS(settings, host) + tlsConfigs, err = configTLS(settings, host, sslPasswordCallback) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} } @@ -406,6 +418,7 @@ func parseEnvSettings() map[string]string { "PGSSLKEY": "sslkey", "PGSSLCERT": "sslcert", "PGSSLROOTCERT": "sslrootcert", + "PGSSLPASSWORD": "sslpassword", "PGTARGETSESSIONATTRS": "target_session_attrs", "PGSERVICE": "service", "PGSERVICEFILE": "servicefile", @@ -592,12 +605,13 @@ func parseServiceSettings(servicefilePath, serviceName string) (map[string]strin // configTLS uses libpq's TLS parameters to construct []*tls.Config. It is // necessary to allow returning multiple TLS configs as sslmode "allow" and // "prefer" allow fallback. -func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, error) { +func configTLS(settings map[string]string, thisHost string, sslPasswordCallback SslPasswordCallbackHandler) ([]*tls.Config, error) { host := thisHost sslmode := settings["sslmode"] sslrootcert := settings["sslrootcert"] sslcert := settings["sslcert"] sslkey := settings["sslkey"] + sslpassword := settings["sslpassword"] // Match libpq default behavior if sslmode == "" { @@ -685,11 +699,43 @@ func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, erro } if sslcert != "" && sslkey != "" { - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + buf, err := ioutil.ReadFile(sslkey) + if err != nil { + return nil, fmt.Errorf("unable to read sslkey: %w", err) + } + block, _ := pem.Decode(buf) + var pemKey []byte + // If PEM is encrypted, attempt to decrypt using pass phrase + if x509.IsEncryptedPEMBlock(block) { + if sslpassword == "" { + if sslPasswordCallback == nil { + return nil, fmt.Errorf("unable to find sslpassword: %w", err) + } + sslpassword = sslPasswordCallback() + } + // Attempt decryption with pass phrase + // NOTE: only supports RSA (PKCS#1) + decryptedKey, err := x509.DecryptPEMBlock(block, []byte(sslpassword)) + // Should we also provide warning for PKCS#1 needed? + if err != nil { + return nil, fmt.Errorf("unable to decrypt key: %w", err) + } + pemBytes := pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: decryptedKey, + } + pemKey = pem.EncodeToMemory(&pemBytes) + } else { + pemKey = pem.EncodeToMemory(block) + } + certfile, err := ioutil.ReadFile(sslcert) if err != nil { return nil, fmt.Errorf("unable to read cert: %w", err) } - + cert, err := tls.X509KeyPair(certfile, pemKey) + if err != nil { + return nil, fmt.Errorf("unable to load cert: %w", err) + } tlsConfig.Certificates = []tls.Certificate{cert} } diff --git a/pgconn.go b/pgconn.go index 430f4367..67d6af38 100644 --- a/pgconn.go +++ b/pgconn.go @@ -64,6 +64,8 @@ type NoticeHandler func(*PgConn, *Notice) // notice event. type NotificationHandler func(*PgConn, *Notification) +type SslPasswordCallbackHandler func() (string) + // Frontend used to receive messages from backend. type Frontend interface { Receive() (pgproto3.BackendMessage, error) diff --git a/pgconn_test.go b/pgconn_test.go index 32186fc6..d9adda99 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1,6 +1,7 @@ package pgconn_test import ( + "bufio" "bytes" "compress/gzip" "context" @@ -63,7 +64,59 @@ func TestConnectTLS(t *testing.T) { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") } - conn, err := pgconn.Connect(context.Background(), connString) + var conn *pgconn.PgConn + var err error + + isSslPasswrodEmpty := strings.HasSuffix(connString, "sslpassword=") + + if isSslPasswrodEmpty { + config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword) + require.Nil(t, err) + + conn, err = pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + } else { + conn, err = pgconn.Connect(context.Background(), connString) + require.NoError(t, err) + } + + if _, ok := conn.Conn().(*tls.Conn); !ok { + t.Error("not a TLS connection") + } + + closeConn(t, conn) +} + +func GetSslPassword() string { + readFile, err := os.Open("data.txt") + if err != nil { + fmt.Println(err) + } + fileScanner := bufio.NewScanner(readFile) + fileScanner.Split(bufio.ScanLines) + for fileScanner.Scan() { + line := fileScanner.Text() + if strings.HasPrefix(line, "sslpassword=") { + index := len("sslpassword=") + line := line[index:] + return line + } + } + return "" +} + +func TestConnectTLSCallback(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } + + config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword) + require.Nil(t, err) + + conn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) if _, ok := conn.Conn().(*tls.Conn); !ok { From c56b38c1f64d21b05026682c3c6a468ac4a998d3 Mon Sep 17 00:00:00 2001 From: Eric McCormack Date: Fri, 8 Jul 2022 14:24:59 -0400 Subject: [PATCH 1102/1158] SSL password - changes based on community feedback --- config.go | 27 ++++++++++++++++----------- pgconn.go | 2 -- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/config.go b/config.go index 12a48288..fa9e3801 100644 --- a/config.go +++ b/config.go @@ -26,6 +26,7 @@ import ( type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error +type GetSSLPasswordFunc func(ctx context.Context) string // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A // manually initialized Config will cause ConnectConfig to panic. @@ -61,12 +62,14 @@ type Config struct { // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler - // SslPasswordCallback is a callback function to handle Auth callback for SSL Password - SslPasswordCallback SslPasswordCallbackHandler - createdByParseConfig bool // Used to enforce created by ParseConfig rule. } +//Congig Options such as getsslpassword function +type ParseConfigOptions struct { + GetSSLPassword GetSSLPasswordFunc +} + // Copy returns a deep copy of the config that is safe to use and modify. // The only exception is the TLSConfig field: // according to the tls.Config docs it must not be modified after creation. @@ -138,7 +141,8 @@ func NetworkAddress(host string, port uint16) (network, address string) { // ParseConfig builds a *Config when sslpasswordcallback function is not provided func ParseConfig(connString string) (*Config, error) { - return ParseConfigWithSslPasswordCallback(connString, nil) + var parseConfigOptions ParseConfigOptions + return ParseConfigWithOptions(connString, parseConfigOptions) } // ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same @@ -204,7 +208,7 @@ func ParseConfig(connString string) (*Config, error) { // which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually // changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting // TLCConfig. -// sslPasswordCallback function provide a callback function for sslpassword +// ParseConfigOptions options for parse config // // Other known differences with libpq: // @@ -218,7 +222,7 @@ func ParseConfig(connString string) (*Config, error) { // servicefile // libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a // part of the connection string. -func ParseConfigWithSslPasswordCallback(connString string, sslPasswordCallback SslPasswordCallbackHandler) (*Config, error) { +func ParseConfigWithOptions(connString string, parseConfigOptions ParseConfigOptions) (*Config, error) { defaultSettings := defaultSettings() envSettings := parseEnvSettings() @@ -338,7 +342,7 @@ func ParseConfigWithSslPasswordCallback(connString string, sslPasswordCallback S tlsConfigs = append(tlsConfigs, nil) } else { var err error - tlsConfigs, err = configTLS(settings, host, sslPasswordCallback) + tlsConfigs, err = configTLS(settings, host, parseConfigOptions) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} } @@ -605,7 +609,7 @@ func parseServiceSettings(servicefilePath, serviceName string) (map[string]strin // configTLS uses libpq's TLS parameters to construct []*tls.Config. It is // necessary to allow returning multiple TLS configs as sslmode "allow" and // "prefer" allow fallback. -func configTLS(settings map[string]string, thisHost string, sslPasswordCallback SslPasswordCallbackHandler) ([]*tls.Config, error) { +func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) { host := thisHost sslmode := settings["sslmode"] sslrootcert := settings["sslrootcert"] @@ -708,10 +712,11 @@ func configTLS(settings map[string]string, thisHost string, sslPasswordCallback // If PEM is encrypted, attempt to decrypt using pass phrase if x509.IsEncryptedPEMBlock(block) { if sslpassword == "" { - if sslPasswordCallback == nil { - return nil, fmt.Errorf("unable to find sslpassword: %w", err) + if(parseConfigOptions.GetSSLPassword != nil){ + sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) + }else{ + return nil, fmt.Errorf("unable to find sslpassword") } - sslpassword = sslPasswordCallback() } // Attempt decryption with pass phrase // NOTE: only supports RSA (PKCS#1) diff --git a/pgconn.go b/pgconn.go index 67d6af38..430f4367 100644 --- a/pgconn.go +++ b/pgconn.go @@ -64,8 +64,6 @@ type NoticeHandler func(*PgConn, *Notice) // notice event. type NotificationHandler func(*PgConn, *Notification) -type SslPasswordCallbackHandler func() (string) - // Frontend used to receive messages from backend. type Frontend interface { Receive() (pgproto3.BackendMessage, error) From 7402796e02f1112dd0a3c972c2f8e0580b762988 Mon Sep 17 00:00:00 2001 From: Eric McCormack Date: Mon, 11 Jul 2022 12:20:10 -0400 Subject: [PATCH 1103/1158] Delete pgconn.iml --- .idea/pgconn.iml | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 .idea/pgconn.iml diff --git a/.idea/pgconn.iml b/.idea/pgconn.iml deleted file mode 100644 index 5e764c4f..00000000 --- a/.idea/pgconn.iml +++ /dev/null @@ -1,9 +0,0 @@ - - - - - - - - - \ No newline at end of file From cdd2cc41244843d1aaaf47efa89ff1f8dce6e3c1 Mon Sep 17 00:00:00 2001 From: "yun.xu" Date: Tue, 19 Jul 2022 10:36:38 -0400 Subject: [PATCH 1104/1158] EC-2198 change for sslpassword --- config.go | 21 ++++++++---- pgconn.go | 12 +++++++ pgconn_test.go | 89 ++++++++++++++++++++++---------------------------- 3 files changed, 66 insertions(+), 56 deletions(-) diff --git a/config.go b/config.go index fa9e3801..2e038304 100644 --- a/config.go +++ b/config.go @@ -709,22 +709,31 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P } block, _ := pem.Decode(buf) var pemKey []byte + var decryptedKey []byte + var decryptedError error // If PEM is encrypted, attempt to decrypt using pass phrase if x509.IsEncryptedPEMBlock(block) { - if sslpassword == "" { + // Attempt decryption with pass phrase + // NOTE: only supports RSA (PKCS#1) + if(sslpassword != ""){ + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) + } + //if sslpassword not provided or has decryption error when use it + //try to find sslpassword with callback function + if (sslpassword == "" || decryptedError!= nil) { if(parseConfigOptions.GetSSLPassword != nil){ sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) - }else{ + } + if(sslpassword == ""){ return nil, fmt.Errorf("unable to find sslpassword") } } - // Attempt decryption with pass phrase - // NOTE: only supports RSA (PKCS#1) - decryptedKey, err := x509.DecryptPEMBlock(block, []byte(sslpassword)) + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) // Should we also provide warning for PKCS#1 needed? - if err != nil { + if decryptedError != nil { return nil, fmt.Errorf("unable to decrypt key: %w", err) } + pemBytes := pem.Block{ Type: "RSA PRIVATE KEY", Bytes: decryptedKey, diff --git a/pgconn.go b/pgconn.go index 430f4367..f582f5b8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -109,6 +109,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { return ConnectConfig(ctx, config) } +// Connect establishes a connection to a PostgreSQL server using the environment +// and connString (in URL or DSN format) and ParseConfigOptions +// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt. +func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { + config, err := ParseConfigWithOptions(connString, parseConfigOptions) + if err != nil { + return nil, err + } + + return ConnectConfig(ctx, config) +} + // Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with // ParseConfig. ctx can be used to cancel a connect attempt. // diff --git a/pgconn_test.go b/pgconn_test.go index d9adda99..9a52abf6 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1,7 +1,6 @@ package pgconn_test import ( - "bufio" "bytes" "compress/gzip" "context" @@ -54,6 +53,35 @@ func TestConnect(t *testing.T) { } } +func TestConnectWithOption(t *testing.T) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } + var sslOptions pgconn.ParseConfigOptions + sslOptions.GetSSLPassword = GetSSLPassword + conn, err := pgconn.ConnectWithOptions(context.Background(), connString, sslOptions) + require.NoError(t, err) + + closeConn(t, conn) + }) + } +} + // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure // connection. func TestConnectTLS(t *testing.T) { @@ -67,58 +95,14 @@ func TestConnectTLS(t *testing.T) { var conn *pgconn.PgConn var err error - isSslPasswrodEmpty := strings.HasSuffix(connString, "sslpassword=") - - if isSslPasswrodEmpty { - config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword) - require.Nil(t, err) - - conn, err = pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - } else { - conn, err = pgconn.Connect(context.Background(), connString) - require.NoError(t, err) - } - - if _, ok := conn.Conn().(*tls.Conn); !ok { - t.Error("not a TLS connection") - } - - closeConn(t, conn) -} - -func GetSslPassword() string { - readFile, err := os.Open("data.txt") - if err != nil { - fmt.Println(err) - } - fileScanner := bufio.NewScanner(readFile) - fileScanner.Split(bufio.ScanLines) - for fileScanner.Scan() { - line := fileScanner.Text() - if strings.HasPrefix(line, "sslpassword=") { - index := len("sslpassword=") - line := line[index:] - return line - } - } - return "" -} - -func TestConnectTLSCallback(t *testing.T) { - t.Parallel() - - connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") - } - - config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword) + var sslOptions pgconn.ParseConfigOptions + sslOptions.GetSSLPassword = GetSSLPassword + config, err := pgconn.ParseConfigWithOptions(connString, sslOptions) require.Nil(t, err) - conn, err := pgconn.ConnectConfig(context.Background(), config) + conn, err = pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) - + if _, ok := conn.Conn().(*tls.Conn); !ok { t.Error("not a TLS connection") } @@ -2180,3 +2164,8 @@ func Example() { // 3 // SELECT 3 } + +func GetSSLPassword(ctx context.Context) string { + connString := os.Getenv("PGX_SSL_PASSWORD") + return connString +} \ No newline at end of file From 69b99209fb3add083fca198fe026294def5312a5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 20 Jul 2022 06:06:54 -0500 Subject: [PATCH 1105/1158] Run go fmt --- config.go | 20 ++++++++++---------- defaults.go | 1 + pgconn.go | 4 ++-- pgconn_test.go | 14 +++++++------- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/config.go b/config.go index 2e038304..dd023a96 100644 --- a/config.go +++ b/config.go @@ -67,7 +67,7 @@ type Config struct { //Congig Options such as getsslpassword function type ParseConfigOptions struct { - GetSSLPassword GetSSLPasswordFunc + GetSSLPassword GetSSLPasswordFunc } // Copy returns a deep copy of the config that is safe to use and modify. @@ -715,25 +715,25 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P if x509.IsEncryptedPEMBlock(block) { // Attempt decryption with pass phrase // NOTE: only supports RSA (PKCS#1) - if(sslpassword != ""){ - decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) + if sslpassword != "" { + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) } //if sslpassword not provided or has decryption error when use it //try to find sslpassword with callback function - if (sslpassword == "" || decryptedError!= nil) { - if(parseConfigOptions.GetSSLPassword != nil){ - sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) + if sslpassword == "" || decryptedError != nil { + if parseConfigOptions.GetSSLPassword != nil { + sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) } - if(sslpassword == ""){ - return nil, fmt.Errorf("unable to find sslpassword") + if sslpassword == "" { + return nil, fmt.Errorf("unable to find sslpassword") } } decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) // Should we also provide warning for PKCS#1 needed? - if decryptedError != nil { + if decryptedError != nil { return nil, fmt.Errorf("unable to decrypt key: %w", err) } - + pemBytes := pem.Block{ Type: "RSA PRIVATE KEY", Bytes: decryptedKey, diff --git a/defaults.go b/defaults.go index f69cad31..c7209fdd 100644 --- a/defaults.go +++ b/defaults.go @@ -1,3 +1,4 @@ +//go:build !windows // +build !windows package pgconn diff --git a/pgconn.go b/pgconn.go index f582f5b8..c817de49 100644 --- a/pgconn.go +++ b/pgconn.go @@ -109,8 +109,8 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { return ConnectConfig(ctx, config) } -// Connect establishes a connection to a PostgreSQL server using the environment -// and connString (in URL or DSN format) and ParseConfigOptions +// Connect establishes a connection to a PostgreSQL server using the environment +// and connString (in URL or DSN format) and ParseConfigOptions // to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt. func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { config, err := ParseConfigWithOptions(connString, parseConfigOptions) diff --git a/pgconn_test.go b/pgconn_test.go index 9a52abf6..c08fa54a 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -72,8 +72,8 @@ func TestConnectWithOption(t *testing.T) { if connString == "" { t.Skipf("Skipping due to missing environment variable %v", tt.env) } - var sslOptions pgconn.ParseConfigOptions - sslOptions.GetSSLPassword = GetSSLPassword + var sslOptions pgconn.ParseConfigOptions + sslOptions.GetSSLPassword = GetSSLPassword conn, err := pgconn.ConnectWithOptions(context.Background(), connString, sslOptions) require.NoError(t, err) @@ -97,12 +97,12 @@ func TestConnectTLS(t *testing.T) { var sslOptions pgconn.ParseConfigOptions sslOptions.GetSSLPassword = GetSSLPassword - config, err := pgconn.ParseConfigWithOptions(connString, sslOptions) + config, err := pgconn.ParseConfigWithOptions(connString, sslOptions) require.Nil(t, err) - conn, err = pgconn.ConnectConfig(context.Background(), config) + conn, err = pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) - + if _, ok := conn.Conn().(*tls.Conn); !ok { t.Error("not a TLS connection") } @@ -2167,5 +2167,5 @@ func Example() { func GetSSLPassword(ctx context.Context) string { connString := os.Getenv("PGX_SSL_PASSWORD") - return connString -} \ No newline at end of file + return connString +} From fe0fb3b24dca7c395038dc81bdb84e6706e50979 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 20 Jul 2022 06:28:08 -0500 Subject: [PATCH 1106/1158] Clean up docs for new ParseConfigOptions feature --- config.go | 73 ++++++++++++++++++++++++++------------------------ pgconn.go | 6 ++--- pgconn_test.go | 2 +- 3 files changed, 42 insertions(+), 39 deletions(-) diff --git a/config.go b/config.go index dd023a96..2277dc1d 100644 --- a/config.go +++ b/config.go @@ -65,8 +65,10 @@ type Config struct { createdByParseConfig bool // Used to enforce created by ParseConfig rule. } -//Congig Options such as getsslpassword function +// ParseConfigOptions contains options that control how a config is built such as getsslpassword. type ParseConfigOptions struct { + // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function + // PQsetSSLKeyPassHook_OpenSSL. GetSSLPassword GetSSLPasswordFunc } @@ -139,16 +141,10 @@ func NetworkAddress(host string, port uint16) (network, address string) { return network, address } -// ParseConfig builds a *Config when sslpasswordcallback function is not provided -func ParseConfig(connString string) (*Config, error) { - var parseConfigOptions ParseConfigOptions - return ParseConfigWithOptions(connString, parseConfigOptions) -} - -// ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same -// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches -// the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See -// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be +// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It +// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely +// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). +// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be // empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. // // # Example DSN @@ -172,22 +168,22 @@ func ParseConfig(connString string) (*Config, error) { // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed // via database URL or DSN: // -// PGHOST -// PGPORT -// PGDATABASE -// PGUSER -// PGPASSWORD -// PGPASSFILE -// PGSERVICE -// PGSERVICEFILE -// PGSSLMODE -// PGSSLCERT -// PGSSLKEY -// PGSSLROOTCERT +// PGHOST +// PGPORT +// PGDATABASE +// PGUSER +// PGPASSWORD +// PGPASSFILE +// PGSERVICE +// PGSERVICEFILE +// PGSSLMODE +// PGSSLCERT +// PGSSLKEY +// PGSSLROOTCERT // PGSSLPASSWORD -// PGAPPNAME -// PGCONNECT_TIMEOUT -// PGTARGETSESSIONATTRS +// PGAPPNAME +// PGCONNECT_TIMEOUT +// PGTARGETSESSIONATTRS // // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. // @@ -207,8 +203,7 @@ func ParseConfig(connString string) (*Config, error) { // sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback // which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually // changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting -// TLCConfig. -// ParseConfigOptions options for parse config +// TLSConfig. // // Other known differences with libpq: // @@ -217,12 +212,20 @@ func ParseConfig(connString string) (*Config, error) { // // In addition, ParseConfig accepts the following options: // -// min_read_buffer_size -// The minimum size of the internal read buffer. Default 8192. -// servicefile -// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a -// part of the connection string. -func ParseConfigWithOptions(connString string, parseConfigOptions ParseConfigOptions) (*Config, error) { +// min_read_buffer_size +// The minimum size of the internal read buffer. Default 8192. +// servicefile +// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a +// part of the connection string. +func ParseConfig(connString string) (*Config, error) { + var parseConfigOptions ParseConfigOptions + return ParseConfigWithOptions(connString, parseConfigOptions) +} + +// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard +// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to +// get the SSL password. +func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) { defaultSettings := defaultSettings() envSettings := parseEnvSettings() @@ -342,7 +345,7 @@ func ParseConfigWithOptions(connString string, parseConfigOptions ParseConfigOpt tlsConfigs = append(tlsConfigs, nil) } else { var err error - tlsConfigs, err = configTLS(settings, host, parseConfigOptions) + tlsConfigs, err = configTLS(settings, host, options) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} } diff --git a/pgconn.go b/pgconn.go index c817de49..17f19e95 100644 --- a/pgconn.go +++ b/pgconn.go @@ -109,9 +109,9 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { return ConnectConfig(ctx, config) } -// Connect establishes a connection to a PostgreSQL server using the environment -// and connString (in URL or DSN format) and ParseConfigOptions -// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt. +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) +// and ParseConfigOptions to provide additional configuration. See documentation for ParseConfig for details. ctx can be +// used to cancel a connect attempt. func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { config, err := ParseConfigWithOptions(connString, parseConfigOptions) if err != nil { diff --git a/pgconn_test.go b/pgconn_test.go index c08fa54a..a4f0ec63 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -53,7 +53,7 @@ func TestConnect(t *testing.T) { } } -func TestConnectWithOption(t *testing.T) { +func TestConnectWithOptions(t *testing.T) { tests := []struct { name string env string From 7c819729386d5e6eacebe396fcf17d9e04fa5a75 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 07:04:23 -0500 Subject: [PATCH 1107/1158] Update line wrapping in docs --- doc.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/doc.go b/doc.go index b10ab1df..debaedb3 100644 --- a/doc.go +++ b/doc.go @@ -55,7 +55,8 @@ pgx implements Query and Scan in the familiar database/sql style. // No errors found - do something with sum -ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows directly. +ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows +directly. var sum, n int32 rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10) @@ -164,9 +165,10 @@ is recommended that this situation be avoided by implementing pgx interfaces on Composite types and row values -Row values and composite types are represented as pgtype.Record (https://pkg.go.dev/github.com/jackc/pgtype?tab=doc#Record). -It is possible to get values of your custom type by implementing DecodeBinary interface. Decoding into -pgtype.Record first can simplify process by avoiding dealing with raw protocol directly. +Row values and composite types are represented as pgtype.Record +(https://pkg.go.dev/github.com/jackc/pgtype?tab=doc#Record). It is possible to get values of your custom type by +implementing DecodeBinary interface. Decoding into pgtype.Record first can simplify process by avoiding dealing with raw +protocol directly. For example: @@ -259,8 +261,8 @@ for information on how to customize or disable the statement cache. Copy Protocol Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a -CopyFromSource interface. If the data is already in a [][]any use CopyFromRows to wrap it in a CopyFromSource -interface. Or implement CopyFromSource to avoid buffering the entire data set in memory. +CopyFromSource interface. If the data is already in a [][]any use CopyFromRows to wrap it in a CopyFromSource interface. +Or implement CopyFromSource to avoid buffering the entire data set in memory. rows := [][]any{ {"John", "Smith", int32(36)}, From f07ad22f148d269805f7e9474417f118c04f7bd8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 07:04:32 -0500 Subject: [PATCH 1108/1158] Update PgBouncer docs --- doc.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc.go b/doc.go index debaedb3..1edef5cc 100644 --- a/doc.go +++ b/doc.go @@ -322,7 +322,7 @@ used to access this lower layer. PgBouncer -pgx is compatible with PgBouncer in two modes. One is when the connection has a statement cache in "describe" mode. The -other is when the connection is using the simple protocol. This can be set with the PreferSimpleProtocol config option. +By default pgx automatically uses prepared statements. Prepared statements are incompaptible with PgBouncer. This can be +disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode. */ package pgx From d433545662aa5a7351e610f31bf800994a185fc2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 07:06:22 -0500 Subject: [PATCH 1109/1158] Remove obsolete doc --- doc.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/doc.go b/doc.go index 1edef5cc..12d89dca 100644 --- a/doc.go +++ b/doc.go @@ -209,10 +209,6 @@ For example: result := MyType{} err := conn.QueryRow(context.Background(), "select row(1, 'foo'::text)", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&r) -Raw Bytes Mapping - -[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified to PostgreSQL. - Transactions Transactions are started by calling Begin. From cb48716c67cd6d1341ef565fff12467993ae9b40 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 07:31:14 -0500 Subject: [PATCH 1110/1158] Update to new package path --- doc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc.go b/doc.go index 12d89dca..98eea261 100644 --- a/doc.go +++ b/doc.go @@ -150,7 +150,7 @@ pgx converts netip.Prefix and netip.Addr to and from inet and cidr PostgreSQL ty Custom Type Support pgx includes support for the common data types like integers, floats, strings, dates, and times that have direct -mappings between Go and SQL. In addition, pgx uses the github.com/jackc/pgtype library to support more types. See +mappings between Go and SQL. In addition, pgx uses the github.com/jackc/pgx/v5/pgtype library to support more types. See documention for that library for instructions on how to implement custom types. See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type. From 178a84261faae68cc8b71e1a35389c7354e0f7b7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 07:53:02 -0500 Subject: [PATCH 1111/1158] Improve Query docs --- conn.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/conn.go b/conn.go index 1a43a3ca..53004327 100644 --- a/conn.go +++ b/conn.go @@ -603,12 +603,16 @@ type QueryRewriter interface { RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) } -// Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The -// error will be the available in rows.Err() after rows are closed. So it is allowed to ignore the error returned from -// Query and handle it in Rows. +// Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query +// and initializing Rows will be returned. Err() on the returned Rows must be checked after the Rows is closed to +// determine if the query executed successfully. // -// Err() on the returned Rows must be checked after the Rows is closed to determine if the query executed successfully -// as some errors can only be detected by reading the entire response. e.g. A divide by zero error on the last row. +// The returned Rows must be closed before the connection can be used again. It is safe to attempt to read from the +// returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It +// is allowed to ignore the error returned from Query and handle it in Rows. +// +// An implementor of QueryRewriter may be passed as the first element of args. It can rewrite the sql and change or +// replace args. For example, NamedArgs is QueryRewriter that implements named arguments. // // For extra control over how the query is executed, the types QueryExecMode, QueryResultFormats, and // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely From 4087119005d5da178eab1538958ac9a33f896b56 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 08:24:44 -0500 Subject: [PATCH 1112/1158] Add Conn.Query example --- query_test.go | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/query_test.go b/query_test.go index 59cf9355..f10078bf 100644 --- a/query_test.go +++ b/query_test.go @@ -1898,3 +1898,74 @@ func TestQueryWithQueryRewriter(t *testing.T) { require.NoError(t, rows.Err()) }) } + +// This example uses Query without using any helpers to read the results. Normally CollectRows, ForEachRow, or another +// helper function should be used. +func ExampleConn_Query() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + // Setup example schema and data. + _, err = conn.Exec(ctx, ` +create temporary table products ( + id int primary key generated by default as identity, + name varchar(100) not null, + price int not null +); + +insert into products (name, price) values + ('Cheeseburger', 10), + ('Double Cheeseburger', 14), + ('Fries', 5), + ('Soft Drink', 3); +`) + if err != nil { + fmt.Printf("Unable to setup example schema and data: %v", err) + return + } + + rows, err := conn.Query(ctx, "select name, price from products where price < $1 order by price desc", 12) + + // It is unnecessary to check err. If an error occurred it will be returned by rows.Err() later. But in rare + // cases it may be useful to detect the error as early as possible. + if err != nil { + fmt.Printf("Query error: %v", err) + return + } + + // Ensure rows is closed. It is safe to close rows multiple times. + defer rows.Close() + + // Iterate through the result set + for rows.Next() { + var name string + var price int32 + + err = rows.Scan(&name, &price) + if err != nil { + fmt.Printf("Scan error: %v", err) + return + } + + fmt.Printf("%s: $%d\n", name, price) + } + + // rows is closed automatically when rows.Next() returns false so it is not necessary to manually close rows. + + // The first error encountered by the original Query call, rows.Next or rows.Scan will be returned here. + if rows.Err() != nil { + fmt.Printf("rows error: %v", err) + return + } + + // Output: + // Cheeseburger: $10 + // Fries: $5 + // Soft Drink: $3 +} From 9a61fc250f304200a241b0bd502a7f8ce041b021 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 08:31:37 -0500 Subject: [PATCH 1113/1158] Recommend CollectRows in ConnQuery docs --- conn.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/conn.go b/conn.go index 53004327..6b84e13d 100644 --- a/conn.go +++ b/conn.go @@ -611,6 +611,10 @@ type QueryRewriter interface { // returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It // is allowed to ignore the error returned from Query and handle it in Rows. // +// It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be +// collected before processing rather than processed while receiving each row. This avoids the possibility of the +// application processing rows from a query that the server rejected. The CollectRows function is useful here. +// // An implementor of QueryRewriter may be passed as the first element of args. It can rewrite the sql and change or // replace args. For example, NamedArgs is QueryRewriter that implements named arguments. // From 68b7e12df2a6587b3962fbcf53d28d4ffde5956d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 08:52:01 -0500 Subject: [PATCH 1114/1158] Add examples --- rows_test.go | 136 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 135 insertions(+), 1 deletion(-) diff --git a/rows_test.go b/rows_test.go index cbc26887..e25ceeea 100644 --- a/rows_test.go +++ b/rows_test.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "testing" + "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" @@ -147,6 +148,35 @@ func TestCollectRows(t *testing.T) { }) } +// This example uses CollectRows with a manually written collector function. In most cases RowTo, RowToAddrOf, +// RowToStructByPos, RowToAddrOfStructByPos, or another generic function would be used. +func ExampleCollectRows() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + rows, _ := conn.Query(ctx, `select n from generate_series(1, 5) n`) + numbers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + fmt.Println(numbers) + + // Output: + // [1 2 3 4 5] +} + func TestCollectOneRow(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { rows, _ := conn.Query(ctx, `select 42`) @@ -201,6 +231,29 @@ func TestRowTo(t *testing.T) { }) } +func ExampleRowTo() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + rows, _ := conn.Query(ctx, `select n from generate_series(1, 5) n`) + numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + fmt.Println(numbers) + + // Output: + // [1 2 3 4 5] +} + func TestRowToAddrOf(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) @@ -214,6 +267,35 @@ func TestRowToAddrOf(t *testing.T) { }) } +func ExampleRowToAddrOf() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + rows, _ := conn.Query(ctx, `select n from generate_series(1, 5) n`) + pNumbers, err := pgx.CollectRows(rows, pgx.RowToAddrOf[int32]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + for _, p := range pNumbers { + fmt.Println(*p) + } + + // Output: + // 1 + // 2 + // 3 + // 4 + // 5 +} + func TestRowToMap(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`) @@ -228,7 +310,7 @@ func TestRowToMap(t *testing.T) { }) } -func TestRowToStructPos(t *testing.T) { +func TestRowToStructByPos(t *testing.T) { type person struct { Name string Age int32 @@ -247,6 +329,58 @@ func TestRowToStructPos(t *testing.T) { }) } +func ExampleRowToStructByPos() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + // Setup example schema and data. + _, err = conn.Exec(ctx, ` +create temporary table products ( + id int primary key generated by default as identity, + name varchar(100) not null, + price int not null +); + +insert into products (name, price) values + ('Cheeseburger', 10), + ('Double Cheeseburger', 14), + ('Fries', 5), + ('Soft Drink', 3); +`) + if err != nil { + fmt.Printf("Unable to setup example schema and data: %v", err) + return + } + + type product struct { + ID int32 + Name string + Price int32 + } + + rows, _ := conn.Query(ctx, "select * from products where price < $1 order by price desc", 12) + products, err := pgx.CollectRows(rows, pgx.RowToStructByPos[product]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + for _, p := range products { + fmt.Printf("%s: $%d\n", p.Name, p.Price) + } + + // Output: + // Cheeseburger: $10 + // Fries: $5 + // Soft Drink: $3 +} + func TestRowToAddrOfStructPos(t *testing.T) { type person struct { Name string From 83780b85b5e09fed5fadd3c94f30e9ec2bb5c574 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 08:54:59 -0500 Subject: [PATCH 1115/1158] Remove pgx logging code moved to tracelog --- log/testingadapter/adapter.go | 4 +- logger.go | 106 ---------------------------------- 2 files changed, 2 insertions(+), 108 deletions(-) delete mode 100644 logger.go diff --git a/log/testingadapter/adapter.go b/log/testingadapter/adapter.go index 65c14157..c901a6a6 100644 --- a/log/testingadapter/adapter.go +++ b/log/testingadapter/adapter.go @@ -6,7 +6,7 @@ import ( "context" "fmt" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/tracelog" ) // TestingLogger interface defines the subset of testing.TB methods used by this @@ -23,7 +23,7 @@ func NewLogger(l TestingLogger) *Logger { return &Logger{l: l} } -func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]any) { +func (l *Logger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) { logArgs := make([]any, 0, 2+len(data)) logArgs = append(logArgs, level, msg) for k, v := range data { diff --git a/logger.go b/logger.go deleted file mode 100644 index cd1255d1..00000000 --- a/logger.go +++ /dev/null @@ -1,106 +0,0 @@ -package pgx - -import ( - "context" - "encoding/hex" - "errors" - "fmt" -) - -// The values for log levels are chosen such that the zero value means that no -// log level was specified. -const ( - LogLevelTrace = 6 - LogLevelDebug = 5 - LogLevelInfo = 4 - LogLevelWarn = 3 - LogLevelError = 2 - LogLevelNone = 1 -) - -// LogLevel represents the pgx logging level. See LogLevel* constants for -// possible values. -type LogLevel int - -func (ll LogLevel) String() string { - switch ll { - case LogLevelTrace: - return "trace" - case LogLevelDebug: - return "debug" - case LogLevelInfo: - return "info" - case LogLevelWarn: - return "warn" - case LogLevelError: - return "error" - case LogLevelNone: - return "none" - default: - return fmt.Sprintf("invalid level %d", ll) - } -} - -// Logger is the interface used to get logging from pgx internals. -type Logger interface { - // Log a message at the given level with data key/value pairs. data may be nil. - Log(ctx context.Context, level LogLevel, msg string, data map[string]any) -} - -// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface -type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) - -// Log delegates the logging request to the wrapped function -func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { - f(ctx, level, msg, data) -} - -// LogLevelFromString converts log level string to constant -// -// Valid levels: -// trace -// debug -// info -// warn -// error -// none -func LogLevelFromString(s string) (LogLevel, error) { - switch s { - case "trace": - return LogLevelTrace, nil - case "debug": - return LogLevelDebug, nil - case "info": - return LogLevelInfo, nil - case "warn": - return LogLevelWarn, nil - case "error": - return LogLevelError, nil - case "none": - return LogLevelNone, nil - default: - return 0, errors.New("invalid log level") - } -} - -func logQueryArgs(args []any) []any { - logArgs := make([]any, 0, len(args)) - - for _, a := range args { - switch v := a.(type) { - case []byte: - if len(v) < 64 { - a = hex.EncodeToString(v) - } else { - a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64) - } - case string: - if len(v) > 64 { - a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64) - } - } - logArgs = append(logArgs, a) - } - - return logArgs -} From e487ab08860cd01ac8c8987b63b68030edb8c970 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 09:04:03 -0500 Subject: [PATCH 1116/1158] Docs should emphasize CollectRows and ForEachRow --- doc.go | 37 ++++++++----------------------------- 1 file changed, 8 insertions(+), 29 deletions(-) diff --git a/doc.go b/doc.go index 98eea261..3b0ac90b 100644 --- a/doc.go +++ b/doc.go @@ -21,39 +21,18 @@ concurrency safe connection pool. Query Interface -pgx implements Query and Scan in the familiar database/sql style. +pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and +ForEachRow that are a simpler and safer way of processing rows than manually calling rows.Next(), rows.Scan, and +rows.Err(). - var sum int32 +CollectRows can be used collect all returned rows into a slice. - // Send the query to the server. The returned rows MUST be closed - // before conn can be used again. - rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) + rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 5) + numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32]) if err != nil { - return err + return err } - - // rows.Close is called by rows.Next when all rows are read - // or an error occurs in Next or Scan. So it may optionally be - // omitted if nothing in the rows.Next loop can panic. It is - // safe to close rows multiple times. - defer rows.Close() - - // Iterate through the result set - for rows.Next() { - var n int32 - err = rows.Scan(&n) - if err != nil { - return err - } - sum += n - } - - // Any errors encountered by rows.Next or rows.Scan will be returned here - if rows.Err() != nil { - return rows.Err() - } - - // No errors found - do something with sum + // numbers => [1 2 3 4 5] ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows directly. From 3595561d9a533cb4e57124674b697b0e2d96a929 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 09:29:25 -0500 Subject: [PATCH 1117/1158] More doc improvements --- doc.go | 132 +++----------------------------------------------- pgtype/doc.go | 63 +++++++++++++++++++++++- stdlib/sql.go | 4 +- 3 files changed, 72 insertions(+), 127 deletions(-) diff --git a/doc.go b/doc.go index 3b0ac90b..4cd3861b 100644 --- a/doc.go +++ b/doc.go @@ -1,8 +1,9 @@ // Package pgx is a PostgreSQL database driver. /* -pgx provides lower level access to PostgreSQL than the standard database/sql. It remains as similar to the database/sql -interface as possible while providing better speed and access to PostgreSQL specific features. Import -github.com/jackc/pgx/v4/stdlib to use pgx as a database/sql compatible driver. +pgx provides a native PostgreSQL driver and can act as a database/sql driver. The native PostgreSQL interface is similar +to the database/sql interface while providing better speed and access to PostgreSQL specific features. Use +github.com/jackc/pgx/v5/stdlib to use pgx as a database/sql compatible driver. See that package's documentation for +details. Establishing a Connection @@ -32,7 +33,7 @@ CollectRows can be used collect all returned rows into a slice. if err != nil { return err } - // numbers => [1 2 3 4 5] + // numbers => [1 2 3 4 5] ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows directly. @@ -66,127 +67,10 @@ Use Exec to execute a query that does not return a result set. return errors.New("No row found to delete") } -Base Type Mapping +PostgreSQL Data Types -pgx maps between all common base types directly between Go and PostgreSQL. In particular: - - Go PostgreSQL - ----------------------- - string varchar - text - - // Integers are automatically be converted to any other integer type if - // it can be done without overflow or underflow. - int8 - int16 smallint - int32 int - int64 bigint - int - uint8 - uint16 - uint32 - uint64 - uint - - // Floats are strict and do not automatically convert like integers. - float32 float4 - float64 float8 - - time.Time date - timestamp - timestamptz - - []byte bytea - - -Null Mapping - -pgx can map nulls in two ways. The first is package pgtype provides types that have a data field and a status field. -They work in a similar fashion to database/sql. The second is to use a pointer to a pointer. - - var foo pgtype.Varchar - var bar *string - err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar) - if err != nil { - return err - } - -Array Mapping - -pgx maps between int16, int32, int64, float32, float64, and string Go slices and the equivalent PostgreSQL array type. -Go slices of native types do not support nulls, so if a PostgreSQL array that contains a null is read into a native Go -slice an error will occur. The pgtype package includes many more array types for PostgreSQL types that do not directly -map to native Go types. - -JSON and JSONB Mapping - -pgx includes built-in support to marshal and unmarshal between Go types and the PostgreSQL JSON and JSONB. - -Inet and CIDR Mapping - -pgx converts netip.Prefix and netip.Addr to and from inet and cidr PostgreSQL types. - -Custom Type Support - -pgx includes support for the common data types like integers, floats, strings, dates, and times that have direct -mappings between Go and SQL. In addition, pgx uses the github.com/jackc/pgx/v5/pgtype library to support more types. See -documention for that library for instructions on how to implement custom types. - -See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type. - -pgx also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer -interfaces. - -If pgx does cannot natively encode a type and that type is a renamed type (e.g. type MyTime time.Time) pgx will attempt -to encode the underlying type. While this is usually desired behavior it can produce surprising behavior if one the -underlying type and the renamed type each implement database/sql interfaces and the other implements pgx interfaces. It -is recommended that this situation be avoided by implementing pgx interfaces on the renamed type. - -Composite types and row values - -Row values and composite types are represented as pgtype.Record -(https://pkg.go.dev/github.com/jackc/pgtype?tab=doc#Record). It is possible to get values of your custom type by -implementing DecodeBinary interface. Decoding into pgtype.Record first can simplify process by avoiding dealing with raw -protocol directly. - -For example: - - type MyType struct { - a int // NULL will cause decoding error - b *string // there can be NULL in this position in SQL - } - - func (t *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - r := pgtype.Record{ - Fields: []pgtype.Value{&pgtype.Int4{}, &pgtype.Text{}}, - } - - if err := r.DecodeBinary(ci, src); err != nil { - return err - } - - if r.Status != pgtype.Present { - return errors.New("BUG: decoding should not be called on NULL value") - } - - a := r.Fields[0].(*pgtype.Int4) - b := r.Fields[1].(*pgtype.Text) - - // type compatibility is checked by AssignTo - // only lossless assignments will succeed - if err := a.AssignTo(&t.a); err != nil { - return err - } - - // AssignTo also deals with null value handling - if err := b.AssignTo(&t.b); err != nil { - return err - } - return nil - } - - result := MyType{} - err := conn.QueryRow(context.Background(), "select row(1, 'foo'::text)", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&r) +The package pgtype provides extensive and customizable support for converting Go values to and from PostgreSQL values +including array and composite types. See that package's documentation for details. Transactions diff --git a/pgtype/doc.go b/pgtype/doc.go index 9764aabf..62d73ed2 100644 --- a/pgtype/doc.go +++ b/pgtype/doc.go @@ -6,6 +6,53 @@ types already registered. Additional types can be registered with Map.RegisterTy Use Map.Scan and Map.Encode to decode PostgreSQL values to Go and encode Go values to PostgreSQL respectively. +Base Type Mapping + +pgtype maps between all common base types directly between Go and PostgreSQL. In particular: + + Go PostgreSQL + ----------------------- + string varchar + text + + // Integers are automatically be converted to any other integer type if + // it can be done without overflow or underflow. + int8 + int16 smallint + int32 int + int64 bigint + int + uint8 + uint16 + uint32 + uint64 + uint + + // Floats are strict and do not automatically convert like integers. + float32 float4 + float64 float8 + + time.Time date + timestamp + timestamptz + + netip.Addr inet + netip.Prefix cidr + + []byte bytea + +Null Values + +pgtype can map NULLs in two ways. The first is types that can directly represent NULL such as Int4. They work in a +similar fashion to database/sql. The second is to use a pointer to a pointer. + + var foo pgtype.Text + var bar *string + err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar) + if err != nil { + return err + } + JSON Support pgtype automatically marshals and unmarshals data from json and jsonb PostgreSQL types. @@ -23,7 +70,8 @@ CompositeIndexGetter. Enum Support -PostgreSQL enums can usually be treated as text. However, EnumCodec implements support for interning strings which can reduce memory usage. +PostgreSQL enums can usually be treated as text. However, EnumCodec implements support for interning strings which can +reduce memory usage. Array, Composite, and Enum Type Registration @@ -35,6 +83,8 @@ Generally, all Codecs will support interfaces that can be implemented to enable PointCodec can use any Go type that implements the PointScanner and PointValuer interfaces. So rather than use pgtype.Point and application can directly use its own point type with pgtype as long as it implements those interfaces. +See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type. + Sometimes pgx supports a PostgreSQL type such as numeric but the Go type is in an external package that does not have pgx support such as github.com/shopspring/decimal. These types can be registered with pgtype with custom conversion logic. See https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid for a example @@ -51,6 +101,17 @@ Encoding Unknown Types pgtype works best when the OID of the PostgreSQL type is known. But in some cases such as using the simple protocol the OID is unknown. In this case Map.RegisterDefaultPgType can be used to register an assumed OID for a particular Go type. +Renamed Types + +If pgtype does not recognize a type and that type is a renamed simple type simple (e.g. type MyInt32 int32) pgtype acts +as if it is the underlying type. It currently cannot automatically detect the underlying type of renamed structs (eg.g. +type MyTime time.Time). + +Compatibility with database/sql + +pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer +interfaces. + Overview of Scanning Implementation The first step is to use the OID to lookup the correct Codec. If the OID is unavailable, Map will try to find the OID diff --git a/stdlib/sql.go b/stdlib/sql.go index 8a24c4c5..fc0b0239 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -27,8 +27,8 @@ // // db.QueryRow("select * from users where id=$1", userID) // -// In Go 1.13 and above (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard database/sql.DB connection -// pool. This allows operations that use pgx specific functionality. +// (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard database/sql.DB connection pool. This allows +// operations that use pgx specific functionality. // // // Given db is a *sql.DB // conn, err := db.Conn(context.Background()) From 4739f79fca82f5461411d1df7ef984143125c403 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 09:42:46 -0500 Subject: [PATCH 1118/1158] More doc tweaks --- doc.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/doc.go b/doc.go index 4cd3861b..9f3f9774 100644 --- a/doc.go +++ b/doc.go @@ -17,8 +17,8 @@ here. In addition, a config struct can be created by `ParseConfig` and modified Connection Pool -`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use sub-package pgxpool for a -concurrency safe connection pool. +`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use package +github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool. Query Interface @@ -174,10 +174,12 @@ pgx supports tracing by setting ConnConfig.Tracer. In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer. +For debug tracing of the actual PostgreSQL wire protocol messages see github.com/jackc/pgx/v5/pgproto3. + Lower Level PostgreSQL Functionality -pgx is implemented on top of github.com/jackc/pgconn a lower level PostgreSQL driver. The Conn.PgConn() method can be -used to access this lower layer. +github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn in +implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer. PgBouncer From 5cee04a0262fe032847443979f24d9b047d81b8c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 10:11:13 -0500 Subject: [PATCH 1119/1158] Add child records docs and examples --- pgtype/doc.go | 5 ++ pgtype/example_child_records_test.go | 90 ++++++++++++++++++++++++++++ pgtype/example_custom_type_test.go | 2 +- pgtype/example_json_test.go | 2 +- 4 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 pgtype/example_child_records_test.go diff --git a/pgtype/doc.go b/pgtype/doc.go index 62d73ed2..7b4ed409 100644 --- a/pgtype/doc.go +++ b/pgtype/doc.go @@ -112,6 +112,11 @@ Compatibility with database/sql pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer interfaces. +Child Records + +pgtype's support for arrays and composite records can be used to load records and their children in a single query. See +example_child_records_test.go for an example. + Overview of Scanning Implementation The first step is to use the OID to lookup the correct Codec. If the OID is unavailable, Map will try to find the OID diff --git a/pgtype/example_child_records_test.go b/pgtype/example_child_records_test.go new file mode 100644 index 00000000..0b1f6d43 --- /dev/null +++ b/pgtype/example_child_records_test.go @@ -0,0 +1,90 @@ +package pgtype_test + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/jackc/pgx/v5" +) + +// This example uses a single query to return parent and child records. +func Example_childRecords() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + // Setup example schema and data. + _, err = conn.Exec(ctx, ` +create temporary table teams ( + name text primary key +); + +create temporary table players ( + name text primary key, + team_name text, + position text +); + +insert into teams (name) values + ('Alpha'), + ('Beta'); + +insert into players (name, team_name, position) values + ('Adam', 'Alpha', 'wing'), + ('Bill', 'Alpha', 'halfback'), + ('Charlie', 'Alpha', 'fullback'), + ('Don', 'Beta', 'halfback'), + ('Edgar', 'Beta', 'halfback'), + ('Frank', 'Beta', 'fullback') +`) + if err != nil { + fmt.Printf("Unable to setup example schema and data: %v", err) + return + } + + type Player struct { + Name string + Position string + } + + type Team struct { + Name string + Players []Player + } + + rows, _ := conn.Query(ctx, ` +select t.name, + (select array_agg(row(p.name, position) order by p.name) from players p where p.team_name = t.name) +from teams t +order by t.name +`) + teams, err := pgx.CollectRows(rows, pgx.RowToStructByPos[Team]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + for _, team := range teams { + fmt.Println(team.Name) + for _, player := range team.Players { + fmt.Printf(" %s: %s\n", player.Name, player.Position) + } + } + + // Output: + // Alpha + // Adam: wing + // Bill: halfback + // Charlie: fullback + // Beta + // Don: halfback + // Edgar: halfback + // Frank: fullback +} diff --git a/pgtype/example_custom_type_test.go b/pgtype/example_custom_type_test.go index 2fd63bcc..ceb9a0aa 100644 --- a/pgtype/example_custom_type_test.go +++ b/pgtype/example_custom_type_test.go @@ -39,7 +39,7 @@ func (src *Point) String() string { return fmt.Sprintf("%.1f, %.1f", src.X, src.Y) } -func Example_CustomType() { +func Example_customType() { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { fmt.Printf("Unable to establish connection: %v", err) diff --git a/pgtype/example_json_test.go b/pgtype/example_json_test.go index c11348b7..98fb675a 100644 --- a/pgtype/example_json_test.go +++ b/pgtype/example_json_test.go @@ -8,7 +8,7 @@ import ( "github.com/jackc/pgx/v5" ) -func Example_JSON() { +func Example_json() { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { fmt.Printf("Unable to establish connection: %v", err) From ce378b4d9c817e0a9c3dc641a7bbe2d51951c02d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 10:21:01 -0500 Subject: [PATCH 1120/1158] Skip example on Cockroach DB --- pgtype/example_child_records_test.go | 34 ++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/pgtype/example_child_records_test.go b/pgtype/example_child_records_test.go index 0b1f6d43..29ae7ef3 100644 --- a/pgtype/example_child_records_test.go +++ b/pgtype/example_child_records_test.go @@ -9,6 +9,16 @@ import ( "github.com/jackc/pgx/v5" ) +type Player struct { + Name string + Position string +} + +type Team struct { + Name string + Players []Player +} + // This example uses a single query to return parent and child records. func Example_childRecords() { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) @@ -20,6 +30,20 @@ func Example_childRecords() { return } + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB which doesn't support the point type. Since an example can't be + // skipped fake success instead. + fmt.Println(`Alpha + Adam: wing + Bill: halfback + Charlie: fullback +Beta + Don: halfback + Edgar: halfback + Frank: fullback`) + return + } + // Setup example schema and data. _, err = conn.Exec(ctx, ` create temporary table teams ( @@ -49,16 +73,6 @@ insert into players (name, team_name, position) values return } - type Player struct { - Name string - Position string - } - - type Team struct { - Name string - Players []Player - } - rows, _ := conn.Query(ctx, ` select t.name, (select array_agg(row(p.name, position) order by p.name) from players p where p.team_name = t.name) From 2da0a11c52b1c62db65df8ed4cb3f2fa544fa9f0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 10:52:35 -0500 Subject: [PATCH 1121/1158] Skip some examples on CockroachDB --- pgtype/example_child_records_test.go | 3 +-- query_test.go | 8 ++++++++ rows_test.go | 8 ++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/pgtype/example_child_records_test.go b/pgtype/example_child_records_test.go index 29ae7ef3..9a4218ba 100644 --- a/pgtype/example_child_records_test.go +++ b/pgtype/example_child_records_test.go @@ -31,8 +31,7 @@ func Example_childRecords() { } if conn.PgConn().ParameterStatus("crdb_version") != "" { - // Skip test / example when running on CockroachDB which doesn't support the point type. Since an example can't be - // skipped fake success instead. + // Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead. fmt.Println(`Alpha Adam: wing Bill: halfback diff --git a/query_test.go b/query_test.go index f10078bf..317e4f60 100644 --- a/query_test.go +++ b/query_test.go @@ -1911,6 +1911,14 @@ func ExampleConn_Query() { return } + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead. + fmt.Println(`Cheeseburger: $10 +Fries: $5 +Soft Drink: $3`) + return + } + // Setup example schema and data. _, err = conn.Exec(ctx, ` create temporary table products ( diff --git a/rows_test.go b/rows_test.go index e25ceeea..6771469f 100644 --- a/rows_test.go +++ b/rows_test.go @@ -339,6 +339,14 @@ func ExampleRowToStructByPos() { return } + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead. + fmt.Println(`Cheeseburger: $10 +Fries: $5 +Soft Drink: $3`) + return + } + // Setup example schema and data. _, err = conn.Exec(ctx, ` create temporary table products ( From 9d0f27bc4bffd471592b01f58e7c244c6b4f96ce Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 15:21:34 -0500 Subject: [PATCH 1122/1158] Initial fuzz testing and fix Initial fuzz testing of pgproto3 found a panic --- pgproto3/frontend.go | 8 ++++- pgproto3/fuzz_test.go | 29 +++++++++++++++++++ ...5e392b0501847a9b35d3857e67872046dbdc04913e | 2 ++ 3 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 pgproto3/fuzz_test.go create mode 100644 pgproto3/testdata/fuzz/FuzzFrontend/65d91093341a68b16f04605e392b0501847a9b35d3857e67872046dbdc04913e diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 09f04141..83dea963 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -223,7 +223,13 @@ func (f *Frontend) Receive() (BackendMessage, error) { } f.msgType = header[0] - f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + + msgLength := int(binary.BigEndian.Uint32(header[1:])) + if msgLength < 4 { + return nil, fmt.Errorf("invalid message length: %d", msgLength) + } + + f.bodyLen = msgLength - 4 f.partialMsg = true } diff --git a/pgproto3/fuzz_test.go b/pgproto3/fuzz_test.go new file mode 100644 index 00000000..84ea8430 --- /dev/null +++ b/pgproto3/fuzz_test.go @@ -0,0 +1,29 @@ +package pgproto3_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func FuzzFrontend(f *testing.F) { + testcases := [][]byte{ + {'Z', 0, 0, 0, 5}, + } + for _, tc := range testcases { + f.Add(tc) + } + f.Fuzz(func(t *testing.T, encodedMsg []byte) { + r := &bytes.Buffer{} + w := &bytes.Buffer{} + fe := pgproto3.NewFrontend(r, w) + + _, err := r.Write(encodedMsg) + require.NoError(t, err) + + // Not checking anything other than no panic. + fe.Receive() + }) +} diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/65d91093341a68b16f04605e392b0501847a9b35d3857e67872046dbdc04913e b/pgproto3/testdata/fuzz/FuzzFrontend/65d91093341a68b16f04605e392b0501847a9b35d3857e67872046dbdc04913e new file mode 100644 index 00000000..4db40929 --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/65d91093341a68b16f04605e392b0501847a9b35d3857e67872046dbdc04913e @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("0\x00\x00\x00\x02") From 7f382f5190f58c16f5bd9d60f4443b658a5a3a22 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Jul 2022 16:13:06 -0500 Subject: [PATCH 1123/1158] Better fuzz testing and fix several bugs it found Fix infinite loop in AuthenticationSASL.Decode Fix panic in CommandComplete.Decode Fix panic in DataRow.Decode Fix panic in NotificationResponse.Decode --- pgproto3/authentication_sasl.go | 7 ++-- pgproto3/command_complete.go | 5 ++- pgproto3/data_row.go | 10 +++--- pgproto3/fuzz_test.go | 36 ++++++++++++++++--- pgproto3/notification_response.go | 4 +++ pgproto3/pgproto3.go | 3 +- ...8f7062d6a07796fdc43b33e0ba9dbd7074a0211fa6 | 4 +++ ...5e392b0501847a9b35d3857e67872046dbdc04913e | 2 -- ...04d526ae14326c8573b7409032caac8461e83065f7 | 4 +++ ...160fbc6e28794612a411d00bde104364ee281c4214 | 4 +++ ...5f7dd023933f3a86ab566e3f2b091eb36248107eb4 | 4 +++ 11 files changed, 67 insertions(+), 16 deletions(-) create mode 100644 pgproto3/testdata/fuzz/FuzzFrontend/39c5e864da4707fc15fea48f7062d6a07796fdc43b33e0ba9dbd7074a0211fa6 delete mode 100644 pgproto3/testdata/fuzz/FuzzFrontend/65d91093341a68b16f04605e392b0501847a9b35d3857e67872046dbdc04913e create mode 100644 pgproto3/testdata/fuzz/FuzzFrontend/9b06792b1aaac8a907dbfa04d526ae14326c8573b7409032caac8461e83065f7 create mode 100644 pgproto3/testdata/fuzz/FuzzFrontend/a661fb98e802839f0a7361160fbc6e28794612a411d00bde104364ee281c4214 create mode 100644 pgproto3/testdata/fuzz/FuzzFrontend/fc98dcd487a5173b38763a5f7dd023933f3a86ab566e3f2b091eb36248107eb4 diff --git a/pgproto3/authentication_sasl.go b/pgproto3/authentication_sasl.go index 996b97d3..59650d4c 100644 --- a/pgproto3/authentication_sasl.go +++ b/pgproto3/authentication_sasl.go @@ -36,10 +36,11 @@ func (dst *AuthenticationSASL) Decode(src []byte) error { authMechanisms := src[4:] for len(authMechanisms) > 1 { idx := bytes.IndexByte(authMechanisms, 0) - if idx > 0 { - dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx])) - authMechanisms = authMechanisms[idx+1:] + if idx == -1 { + return &invalidMessageFormatErr{messageType: "AuthenticationSASL", details: "unterminated string"} } + dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx])) + authMechanisms = authMechanisms[idx+1:] } return nil diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go index a19b906c..814027ca 100644 --- a/pgproto3/command_complete.go +++ b/pgproto3/command_complete.go @@ -18,8 +18,11 @@ func (*CommandComplete) Backend() {} // type identifier and 4 byte message length. func (dst *CommandComplete) Decode(src []byte) error { idx := bytes.IndexByte(src, 0) + if idx == -1 { + return &invalidMessageFormatErr{messageType: "CommandComplete", details: "unterminated string"} + } if idx != len(src)-1 { - return &invalidMessageFormatErr{messageType: "CommandComplete"} + return &invalidMessageFormatErr{messageType: "CommandComplete", details: "string terminated too early"} } dst.CommandTag = src[:idx] diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go index 0bfe9a0d..4de77977 100644 --- a/pgproto3/data_row.go +++ b/pgproto3/data_row.go @@ -43,19 +43,19 @@ func (dst *DataRow) Decode(src []byte) error { return &invalidMessageFormatErr{messageType: "DataRow"} } - msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 // null - if msgSize == -1 { + if valueLen == -1 { dst.Values[i] = nil } else { - if len(src[rp:]) < msgSize { + if len(src[rp:]) < valueLen || valueLen < 0 { return &invalidMessageFormatErr{messageType: "DataRow"} } - dst.Values[i] = src[rp : rp+msgSize : rp+msgSize] - rp += msgSize + dst.Values[i] = src[rp : rp+valueLen : rp+valueLen] + rp += valueLen } } diff --git a/pgproto3/fuzz_test.go b/pgproto3/fuzz_test.go index 84ea8430..332596ab 100644 --- a/pgproto3/fuzz_test.go +++ b/pgproto3/fuzz_test.go @@ -4,22 +4,50 @@ import ( "bytes" "testing" + "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgproto3" "github.com/stretchr/testify/require" ) func FuzzFrontend(f *testing.F) { - testcases := [][]byte{ - {'Z', 0, 0, 0, 5}, + testcases := []struct { + msgType byte + msgLen uint32 + msgBody []byte + }{ + { + msgType: 'Z', + msgLen: 2, + msgBody: []byte{'I'}, + }, + { + msgType: 'Z', + msgLen: 5, + msgBody: []byte{'I'}, + }, } for _, tc := range testcases { - f.Add(tc) + f.Add(tc.msgType, tc.msgLen, tc.msgBody) } - f.Fuzz(func(t *testing.T, encodedMsg []byte) { + f.Fuzz(func(t *testing.T, msgType byte, msgLen uint32, msgBody []byte) { + // Prune any msgLen > len(msgBody) because they would hang the test waiting for more input. + if int(msgLen) > len(msgBody)+4 { + return + } + + // Prune any messages that are too long. + if msgLen > 128 || len(msgBody) > 128 { + return + } + r := &bytes.Buffer{} w := &bytes.Buffer{} fe := pgproto3.NewFrontend(r, w) + var encodedMsg []byte + encodedMsg = append(encodedMsg, msgType) + encodedMsg = pgio.AppendUint32(encodedMsg, msgLen) + encodedMsg = append(encodedMsg, msgBody...) _, err := r.Write(encodedMsg) require.NoError(t, err) diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go index 03ce51e5..228e0dac 100644 --- a/pgproto3/notification_response.go +++ b/pgproto3/notification_response.go @@ -22,6 +22,10 @@ func (*NotificationResponse) Backend() {} func (dst *NotificationResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "NotificationResponse", details: "too short"} + } + pid := binary.BigEndian.Uint32(buf.Next(4)) b, err := buf.ReadBytes(0) diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go index a0333aa5..ef5a5489 100644 --- a/pgproto3/pgproto3.go +++ b/pgproto3/pgproto3.go @@ -46,10 +46,11 @@ func (e *invalidMessageLenErr) Error() string { type invalidMessageFormatErr struct { messageType string + details string } func (e *invalidMessageFormatErr) Error() string { - return fmt.Sprintf("%s body is invalid", e.messageType) + return fmt.Sprintf("%s body is invalid %s", e.messageType, e.details) } type writeError struct { diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/39c5e864da4707fc15fea48f7062d6a07796fdc43b33e0ba9dbd7074a0211fa6 b/pgproto3/testdata/fuzz/FuzzFrontend/39c5e864da4707fc15fea48f7062d6a07796fdc43b33e0ba9dbd7074a0211fa6 new file mode 100644 index 00000000..d1c612d3 --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/39c5e864da4707fc15fea48f7062d6a07796fdc43b33e0ba9dbd7074a0211fa6 @@ -0,0 +1,4 @@ +go test fuzz v1 +byte('A') +uint32(5) +[]byte("0") diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/65d91093341a68b16f04605e392b0501847a9b35d3857e67872046dbdc04913e b/pgproto3/testdata/fuzz/FuzzFrontend/65d91093341a68b16f04605e392b0501847a9b35d3857e67872046dbdc04913e deleted file mode 100644 index 4db40929..00000000 --- a/pgproto3/testdata/fuzz/FuzzFrontend/65d91093341a68b16f04605e392b0501847a9b35d3857e67872046dbdc04913e +++ /dev/null @@ -1,2 +0,0 @@ -go test fuzz v1 -[]byte("0\x00\x00\x00\x02") diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/9b06792b1aaac8a907dbfa04d526ae14326c8573b7409032caac8461e83065f7 b/pgproto3/testdata/fuzz/FuzzFrontend/9b06792b1aaac8a907dbfa04d526ae14326c8573b7409032caac8461e83065f7 new file mode 100644 index 00000000..763b70ae --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/9b06792b1aaac8a907dbfa04d526ae14326c8573b7409032caac8461e83065f7 @@ -0,0 +1,4 @@ +go test fuzz v1 +byte('D') +uint32(21) +[]byte("00\xb300000000000000") diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/a661fb98e802839f0a7361160fbc6e28794612a411d00bde104364ee281c4214 b/pgproto3/testdata/fuzz/FuzzFrontend/a661fb98e802839f0a7361160fbc6e28794612a411d00bde104364ee281c4214 new file mode 100644 index 00000000..3d995c28 --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/a661fb98e802839f0a7361160fbc6e28794612a411d00bde104364ee281c4214 @@ -0,0 +1,4 @@ +go test fuzz v1 +byte('C') +uint32(4) +[]byte("0") diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/fc98dcd487a5173b38763a5f7dd023933f3a86ab566e3f2b091eb36248107eb4 b/pgproto3/testdata/fuzz/FuzzFrontend/fc98dcd487a5173b38763a5f7dd023933f3a86ab566e3f2b091eb36248107eb4 new file mode 100644 index 00000000..45f0ba81 --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/fc98dcd487a5173b38763a5f7dd023933f3a86ab566e3f2b091eb36248107eb4 @@ -0,0 +1,4 @@ +go test fuzz v1 +byte('R') +uint32(13) +[]byte("\x00\x00\x00\n0\x12\xebG\x8dI']G\xdac\x95\xb7\x18\xb0\x02\xe8m\xc2\x00\xef\x03\x12\x1b\xbdj\x10\x9f\xf9\xeb\xb8") From c3258b7f52a3144ecee26228491d7513cd673715 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jul 2022 09:10:50 -0500 Subject: [PATCH 1124/1158] Fix scan pointer to pointer to nil slice https://github.com/jackc/pgx/issues/1263 --- pgtype/array_codec.go | 7 +++++++ pgtype/pgtype_test.go | 11 +++++++++++ 2 files changed, 18 insertions(+) diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 8ed45da5..dae12039 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" + "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -216,6 +217,12 @@ func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) Scan return nil } + // target / arrayScanner might be a pointer to a nil. If it is create one so we can call ScanIndexType to plan the + // scan of the elements. + if anynil.Is(target) { + arrayScanner = reflect.New(reflect.TypeOf(target).Elem()).Interface().(ArraySetter) + } + elementType := arrayScanner.ScanIndexType() elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType) diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 685cbb62..6325829e 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -254,6 +254,17 @@ func TestScanPlanInterface(t *testing.T) { assert.Error(t, err) } +// https://github.com/jackc/pgx/issues/1263 +func TestMapScanPtrToPtrToSlice(t *testing.T) { + m := pgtype.NewMap() + src := []byte("{foo,bar}") + var v *[]string + plan := m.PlanScan(pgtype.TextArrayOID, pgtype.TextFormatCode, &v) + err := plan.Scan(src, &v) + require.NoError(t, err) + require.Equal(t, []string{"foo", "bar"}, *v) +} + func BenchmarkTypeMapScanInt4IntoBinaryDecoder(b *testing.B) { m := pgtype.NewMap() src := []byte{0, 0, 0, 42} From 033fc6f62a899f9ee152d3f2b5d06a9df200a8e2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jul 2022 09:16:42 -0500 Subject: [PATCH 1125/1158] Rename pgxpool.NewConfig to NewWithConfig https://github.com/jackc/pgx/issues/1264 --- examples/url_shortener/main.go | 2 +- pgxpool/bench_test.go | 4 ++-- pgxpool/pool.go | 6 +++--- pgxpool/pool_test.go | 24 ++++++++++++------------ 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/examples/url_shortener/main.go b/examples/url_shortener/main.go index 092de1fb..12195922 100644 --- a/examples/url_shortener/main.go +++ b/examples/url_shortener/main.go @@ -75,7 +75,7 @@ func main() { log.Fatalln("Unable to parse DATABASE_URL:", err) } - db, err = pgxpool.NewConfig(context.Background(), poolConfig) + db, err = pgxpool.NewWithConfig(context.Background(), poolConfig) if err != nil { log.Fatalln("Unable to create connection pool:", err) } diff --git a/pgxpool/bench_test.go b/pgxpool/bench_test.go index 588d104c..c2d58a38 100644 --- a/pgxpool/bench_test.go +++ b/pgxpool/bench_test.go @@ -34,7 +34,7 @@ func BenchmarkMinimalPreparedSelectBaseline(b *testing.B) { return err } - db, err := pgxpool.NewConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(b, err) conn, err := db.Acquire(context.Background()) @@ -65,7 +65,7 @@ func BenchmarkMinimalPreparedSelect(b *testing.B) { return err } - db, err := pgxpool.NewConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(b, err) var n int64 diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 9872e670..5fc77014 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -160,11 +160,11 @@ func New(ctx context.Context, connString string) (*Pool, error) { return nil, err } - return NewConfig(ctx, config) + return NewWithConfig(ctx, config) } -// NewConfig creates a new Pool. config must have been created by ParseConfig. -func NewConfig(ctx context.Context, config *Config) (*Pool, error) { +// NewWithConfig creates a new Pool. config must have been created by ParseConfig. +func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 0e4d8acf..b5ce9ad7 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -30,7 +30,7 @@ func TestConnectConfig(t *testing.T) { connString := os.Getenv("PGX_TEST_DATABASE") config, err := pgxpool.ParseConfig(connString) require.NoError(t, err) - pool, err := pgxpool.NewConfig(context.Background(), config) + pool, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) assertConfigsEqual(t, config, pool.Config(), "Pool.Config() returns original config") pool.Close() @@ -52,7 +52,7 @@ func TestConnectConfigRequiresConnConfigFromParseConfig(t *testing.T) { config := &pgxpool.Config{} - require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgxpool.NewConfig(context.Background(), config) }) + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgxpool.NewWithConfig(context.Background(), config) }) } func TestConfigCopyReturnsEqualConfig(t *testing.T) { @@ -72,7 +72,7 @@ func TestConfigCopyCanBeUsedToConnect(t *testing.T) { copied := original.Copy() assert.NotPanics(t, func() { - _, err = pgxpool.NewConfig(context.Background(), copied) + _, err = pgxpool.NewWithConfig(context.Background(), copied) }) assert.NoError(t, err) } @@ -205,7 +205,7 @@ func TestPoolBeforeConnect(t *testing.T) { return nil } - db, err := pgxpool.NewConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -226,7 +226,7 @@ func TestPoolAfterConnect(t *testing.T) { return err } - db, err := pgxpool.NewConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -249,7 +249,7 @@ func TestPoolBeforeAcquire(t *testing.T) { return acquireAttempts%2 == 0 } - db, err := pgxpool.NewConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -304,7 +304,7 @@ func TestPoolAfterRelease(t *testing.T) { return afterReleaseCount%2 == 1 } - db, err := pgxpool.NewConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -357,7 +357,7 @@ func TestConnReleaseChecksMaxConnLifetime(t *testing.T) { config.MaxConnLifetime = 250 * time.Millisecond - db, err := pgxpool.NewConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -410,7 +410,7 @@ func TestPoolBackgroundChecksMaxConnLifetime(t *testing.T) { config.MaxConnLifetime = 100 * time.Millisecond config.HealthCheckPeriod = 100 * time.Millisecond - db, err := pgxpool.NewConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -435,7 +435,7 @@ func TestPoolBackgroundChecksMaxConnIdleTime(t *testing.T) { config.MaxConnIdleTime = 100 * time.Millisecond config.HealthCheckPeriod = 150 * time.Millisecond - db, err := pgxpool.NewConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -464,7 +464,7 @@ func TestPoolBackgroundChecksMinConns(t *testing.T) { config.HealthCheckPeriod = 100 * time.Millisecond config.MinConns = 2 - db, err := pgxpool.NewConfig(context.Background(), config) + db, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer db.Close() @@ -909,7 +909,7 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) { return nil } - pool, err := pgxpool.NewConfig(context.Background(), config) + pool, err := pgxpool.NewWithConfig(context.Background(), config) require.NoError(t, err) defer pool.Close() From 957671a6ec5681f66026ce56ec3ac49bb55a5f4b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jul 2022 12:16:16 -0500 Subject: [PATCH 1126/1158] Use puddle v2 --- go.mod | 6 +++--- go.sum | 23 ++++++++--------------- pgxpool/conn.go | 2 +- pgxpool/pool.go | 2 +- pgxpool/stat.go | 2 +- 5 files changed, 14 insertions(+), 21 deletions(-) diff --git a/go.mod b/go.mod index 6710f0e8..428b85f5 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,8 @@ go 1.18 require ( github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b - github.com/jackc/puddle v1.2.2-0.20220404125616-4e959849469a - github.com/stretchr/testify v1.7.0 + github.com/jackc/puddle/v2 v2.0.0-beta.1 + github.com/stretchr/testify v1.8.0 golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b golang.org/x/text v0.3.7 ) @@ -16,5 +16,5 @@ require ( github.com/kr/pretty v0.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect - gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9dc19a92..74d73902 100644 --- a/go.sum +++ b/go.sum @@ -5,10 +5,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= -github.com/jackc/puddle v1.2.1 h1:gI8os0wpRXFd4FiAY2dWiqRK037tjj3t7rKFeO4X5iw= -github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.2.2-0.20220404125616-4e959849469a h1:oH7y/b+q2BEerCnARr/HZc1NxOYbKSJor4MqQXlhh+s= -github.com/jackc/puddle v1.2.2-0.20220404125616-4e959849469a/go.mod h1:ZQuO1Un86Xpe1ShKl08ERTzYhzWq+OvrvotbpeE3XO0= +github.com/jackc/puddle/v2 v2.0.0-beta.1 h1:Y4Ao+kFWANtDhWUkdw1JcbH+x84/aq6WUfhVQ1wdib8= +github.com/jackc/puddle/v2 v2.0.0-beta.1/go.mod h1:itE7ZJY8xnoo0JqJEpSMprN0f+NQkMCuEV/N9j8h0oc= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -17,25 +15,20 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b h1:QAqMVf3pSa6eeTsuklijukjXBlj7Es2QQplab+/RbQ4= golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pgxpool/conn.go b/pgxpool/conn.go index 802026e2..36f90969 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -6,7 +6,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - puddle "github.com/jackc/puddle/puddleg" + "github.com/jackc/puddle/v2" ) // Conn is an acquired *pgx.Conn from a Pool. diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 5fc77014..c9d79696 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -12,7 +12,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - puddle "github.com/jackc/puddle/puddleg" + "github.com/jackc/puddle/v2" ) var defaultMaxConns = int32(4) diff --git a/pgxpool/stat.go b/pgxpool/stat.go index 47342be4..cfa0c4c5 100644 --- a/pgxpool/stat.go +++ b/pgxpool/stat.go @@ -3,7 +3,7 @@ package pgxpool import ( "time" - "github.com/jackc/puddle" + "github.com/jackc/puddle/v2" ) // Stat is a snapshot of Pool statistics. From 83670d675d125bc3b3e57bf53fd17083f5dc5a0c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jul 2022 12:17:00 -0500 Subject: [PATCH 1127/1158] Upgrade golang.org/x/crypto --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 428b85f5..7893a50a 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/jackc/puddle/v2 v2.0.0-beta.1 github.com/stretchr/testify v1.8.0 - golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa golang.org/x/text v0.3.7 ) diff --git a/go.sum b/go.sum index 74d73902..d4cafbf2 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,8 @@ github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PK github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b h1:QAqMVf3pSa6eeTsuklijukjXBlj7Es2QQplab+/RbQ4= golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 0eda0109cad12658a03874462f88c1f55b5effd9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jul 2022 12:22:29 -0500 Subject: [PATCH 1128/1158] Add Pool.Reset() --- pgxpool/pool.go | 9 +++++++++ pgxpool/pool_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/pgxpool/pool.go b/pgxpool/pool.go index c9d79696..c7601a38 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -534,6 +534,15 @@ func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn { return conns } +// Reset closes all connections, but leaves the pool open. It is intended for use when an error is detected that would +// disrupt all connections (such as a network interruption or a server state change). +// +// It is safe to reset a pool while connections are checked out. Those connections will be closed when they are returned +// to the pool. +func (p *Pool) Reset() { + p.p.Reset() +} + // Config returns a copy of config that was used to initialize this pool. func (p *Pool) Config() *Config { return p.config.Copy() } diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index b5ce9ad7..cfebca7c 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -349,6 +349,31 @@ func TestPoolAcquireAllIdle(t *testing.T) { } } +func TestPoolReset(t *testing.T) { + t.Parallel() + + db, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer db.Close() + + conns := make([]*pgxpool.Conn, 3) + for i := range conns { + conns[i], err = db.Acquire(context.Background()) + assert.NoError(t, err) + } + + db.Reset() + + for _, c := range conns { + if c != nil { + c.Release() + } + } + waitForReleaseToComplete() + + require.EqualValues(t, 0, db.Stat().TotalConns()) +} + func TestConnReleaseChecksMaxConnLifetime(t *testing.T) { t.Parallel() From 0a539a9d92131c8353435393670f4d10b8b36ebb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Aug 2022 05:58:55 -0500 Subject: [PATCH 1129/1158] Upgrade pgproto3 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index aaf7a486..a2eeecff 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.3.0 + github.com/jackc/pgproto3/v2 v2.3.1 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.7.0 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 diff --git a/go.sum b/go.sum index a3834fd2..95b6bbca 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ github.com/jackc/pgproto3/v2 v2.2.1-0.20220412121321-175856ffd3c8 h1:KxsCQec+1iw github.com/jackc/pgproto3/v2 v2.2.1-0.20220412121321-175856ffd3c8/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.3.0 h1:brH0pCGBDkBW07HWlN/oSBXrmo3WB0UvZd1pIuDcL8Y= github.com/jackc/pgproto3/v2 v2.3.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.3.1 h1:nwj7qwf0S+Q7ISFfBndqeLwSwxs+4DPsbRFjECT1Y4Y= +github.com/jackc/pgproto3/v2 v2.3.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= From 5192d9acc15feea530a46d397effe01a44d4647c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Aug 2022 06:00:03 -0500 Subject: [PATCH 1130/1158] Upgrade 3rd party dependencies --- go.mod | 4 ++-- go.sum | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index a2eeecff..a1bec1f4 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.3.1 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b - github.com/stretchr/testify v1.7.0 - golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 + github.com/stretchr/testify v1.8.0 + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa golang.org/x/text v0.3.7 ) diff --git a/go.sum b/go.sum index 95b6bbca..e89d701d 100644 --- a/go.sum +++ b/go.sum @@ -75,12 +75,17 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= @@ -95,11 +100,14 @@ golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWP golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -110,6 +118,7 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -134,3 +143,5 @@ gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:a gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 4c048d40d859bbb702bff9b46cfefd44b89d82c1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Aug 2022 06:07:40 -0500 Subject: [PATCH 1131/1158] Update changelog --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a3efb7f2..f6a6807f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +# 1.13.0 (August 6, 2022) + +* Add sslpassword support (Eric McCormack and yun.xu) +* Add prefer-standby target_session_attrs support (sergey.bashilov) +* Fix GSS ErrorResponse handling (Oliver Tan) + # 1.12.1 (May 7, 2022) * Fix: setting krbspn and krbsrvname in connection string (sireax) From 1f64122c421b7d681b6973b61d53eeb4731660ce Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Aug 2022 06:27:32 -0500 Subject: [PATCH 1132/1158] Tweak changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 83c32783..1fe8bdf9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -135,7 +135,7 @@ See documentation for `QueryExecMode`. ## QueryRewriter Interface and NamedArgs -pgx now supports named arguments with the NamedArgs type. This is implemented via the new QueryRewriter interface which +pgx now supports named arguments with the `NamedArgs` type. This is implemented via the new `QueryRewriter` interface which allows arbitrary rewriting of query SQL and arguments. ## RowScanner Interface From 1453cd4b97024bf0452ae12bbc2ff1dceefd471e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Aug 2022 07:11:11 -0500 Subject: [PATCH 1133/1158] Update v5 status --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c7224b22..2bfa6a2b 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ # pgx - PostgreSQL Driver and Toolkit -*This is the v5 development branch. It is still in active development and testing.* +*This is the v5 development branch. It is still in beta testing.* pgx is a pure Go driver and toolkit for PostgreSQL. From 33b782a96d0fb1c82985c8bd5192f7a144b8877d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 11 Aug 2022 20:55:50 -0500 Subject: [PATCH 1134/1158] Potential fix for Windows https://github.com/jackc/pgx/issues/1274 --- internal/nbconn/nbconn.go | 60 -------------------- internal/nbconn/nbconn_fake_non_block.go | 13 +++++ internal/nbconn/nbconn_real_non_block.go | 70 ++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 60 deletions(-) create mode 100644 internal/nbconn/nbconn_fake_non_block.go create mode 100644 internal/nbconn/nbconn_real_non_block.go diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go index 16c4b713..c091ea3f 100644 --- a/internal/nbconn/nbconn.go +++ b/internal/nbconn/nbconn.go @@ -13,7 +13,6 @@ package nbconn import ( "crypto/tls" "errors" - "io" "net" "os" "sync" @@ -389,35 +388,6 @@ func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { return c.conn.Write(b) } -// realNonblockingWrite does a non-blocking write. readFlushLock must already be held. -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { - c.nonblockWriteBuf = b - c.nonblockWriteN = 0 - c.nonblockWriteErr = nil - err = c.rawConn.Write(func(fd uintptr) (done bool) { - c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf) - return true - }) - n = c.nonblockWriteN - if err == nil && c.nonblockWriteErr != nil { - if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = c.nonblockWriteErr - } - } - if err != nil { - // n may be -1 when an error occurs. - if n < 0 { - n = 0 - } - - return n, err - } - - return n, nil -} - func (c *NetConn) nonblockingRead(b []byte) (n int, err error) { if c.rawConn == nil { return c.fakeNonblockingRead(b) @@ -451,36 +421,6 @@ func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { return c.conn.Read(b) } -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { - var funcErr error - err = c.rawConn.Read(func(fd uintptr) (done bool) { - n, funcErr = syscall.Read(int(fd), b) - return true - }) - if err == nil && funcErr != nil { - if errors.Is(funcErr, syscall.EWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = funcErr - } - } - if err != nil { - // n may be -1 when an error occurs. - if n < 0 { - n = 0 - } - - return n, err - } - - // syscall read did not return an error and 0 bytes were read means EOF. - if n == 0 { - return 0, io.EOF - } - - return n, nil -} - // syscall.Conn is interface // TLSClient establishes a TLS connection as a client over conn using config. diff --git a/internal/nbconn/nbconn_fake_non_block.go b/internal/nbconn/nbconn_fake_non_block.go new file mode 100644 index 00000000..7e8b7634 --- /dev/null +++ b/internal/nbconn/nbconn_fake_non_block.go @@ -0,0 +1,13 @@ +//go:build !(aix || android || darwin || dragonfly || freebsd || hurd || illumos || ios || linux || netbsd || openbsd || solaris) + +package nbconn + +// Not using unix build tag for support on Go 1.18. + +func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { + return fakeNonblockingWrite(b) +} + +func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { + return c.fakeNonblockingRead(b) +} diff --git a/internal/nbconn/nbconn_real_non_block.go b/internal/nbconn/nbconn_real_non_block.go new file mode 100644 index 00000000..ee48d129 --- /dev/null +++ b/internal/nbconn/nbconn_real_non_block.go @@ -0,0 +1,70 @@ +//go:build aix || android || darwin || dragonfly || freebsd || hurd || illumos || ios || linux || netbsd || openbsd || solaris + +package nbconn + +// Not using unix build tag for support on Go 1.18. + +import ( + "errors" + "io" + "syscall" +) + +// realNonblockingWrite does a non-blocking write. readFlushLock must already be held. +func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { + c.nonblockWriteBuf = b + c.nonblockWriteN = 0 + c.nonblockWriteErr = nil + err = c.rawConn.Write(func(fd uintptr) (done bool) { + c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf) + return true + }) + n = c.nonblockWriteN + if err == nil && c.nonblockWriteErr != nil { + if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { + err = ErrWouldBlock + } else { + err = c.nonblockWriteErr + } + } + if err != nil { + // n may be -1 when an error occurs. + if n < 0 { + n = 0 + } + + return n, err + } + + return n, nil +} + +func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { + var funcErr error + err = c.rawConn.Read(func(fd uintptr) (done bool) { + n, funcErr = syscall.Read(int(fd), b) + return true + }) + if err == nil && funcErr != nil { + if errors.Is(funcErr, syscall.EWOULDBLOCK) { + err = ErrWouldBlock + } else { + err = funcErr + } + } + if err != nil { + // n may be -1 when an error occurs. + if n < 0 { + n = 0 + } + + return n, err + } + + // syscall read did not return an error and 0 bytes were read means EOF. + if n == 0 { + return 0, io.EOF + } + + return n, nil +} From 906f709e0c779acdc3904b3127dd6be2a9cf1cee Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 11 Aug 2022 20:59:37 -0500 Subject: [PATCH 1135/1158] Fix typo in Windows code https://github.com/jackc/pgx/issues/1274 --- internal/nbconn/nbconn_fake_non_block.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/nbconn/nbconn_fake_non_block.go b/internal/nbconn/nbconn_fake_non_block.go index 7e8b7634..cf05df1c 100644 --- a/internal/nbconn/nbconn_fake_non_block.go +++ b/internal/nbconn/nbconn_fake_non_block.go @@ -5,7 +5,7 @@ package nbconn // Not using unix build tag for support on Go 1.18. func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { - return fakeNonblockingWrite(b) + return c.fakeNonblockingWrite(b) } func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { From 8256ab147f449cecd197021dfee2ee8e7e4d25cc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Aug 2022 08:09:44 -0500 Subject: [PATCH 1136/1158] Add build tag to skip default PG type registration https://github.com/jackc/pgx/issues/1273#issuecomment-1207338136 --- pgtype/doc.go | 8 +++++ pgtype/pgtype.go | 32 ------------------ pgtype/register_default_pg_types.go | 35 ++++++++++++++++++++ pgtype/register_default_pg_types_disabled.go | 6 ++++ 4 files changed, 49 insertions(+), 32 deletions(-) create mode 100644 pgtype/register_default_pg_types.go create mode 100644 pgtype/register_default_pg_types_disabled.go diff --git a/pgtype/doc.go b/pgtype/doc.go index 7b4ed409..834aa0b6 100644 --- a/pgtype/doc.go +++ b/pgtype/doc.go @@ -139,5 +139,13 @@ implementing NumericScanner and passes that to the Codec. Map.Scan and Map.Encode are convenience methods that wrap Map.PlanScan and Map.PlanEncode. Determining how to scan or encode a particular type may be a time consuming operation. Hence the planning and execution steps of a conversion are internally separated. + +Reducing Compiled Binary Size + +pgx.QueryExecModeExec and pgx.QueryExecModeSimpleProtocol require the default PostgreSQL type to be registered for each +Go type used as a query parameter. By default pgx does this for all supported types and their array variants. If an +application does not use those query execution modes or manually registers the default PostgreSQL type for the types it +uses as query parameters it can use the build tag nopgxregisterdefaulttypes. This omits the default type registration +and reduces the compiled binary size by ~2MB. */ package pgtype diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 160f18af..793ec28a 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -410,38 +410,6 @@ func NewMap() *Map { return m } -func registerDefaultPgTypeVariants[T any](m *Map, name string) { - arrayName := "_" + name - - var value T - m.RegisterDefaultPgType(value, name) // T - m.RegisterDefaultPgType(&value, name) // *T - - var sliceT []T - m.RegisterDefaultPgType(sliceT, arrayName) // []T - m.RegisterDefaultPgType(&sliceT, arrayName) // *[]T - - var slicePtrT []*T - m.RegisterDefaultPgType(slicePtrT, arrayName) // []*T - m.RegisterDefaultPgType(&slicePtrT, arrayName) // *[]*T - - var arrayOfT Array[T] - m.RegisterDefaultPgType(arrayOfT, arrayName) // Array[T] - m.RegisterDefaultPgType(&arrayOfT, arrayName) // *Array[T] - - var arrayOfPtrT Array[*T] - m.RegisterDefaultPgType(arrayOfPtrT, arrayName) // Array[*T] - m.RegisterDefaultPgType(&arrayOfPtrT, arrayName) // *Array[*T] - - var flatArrayOfT FlatArray[T] - m.RegisterDefaultPgType(flatArrayOfT, arrayName) // FlatArray[T] - m.RegisterDefaultPgType(&flatArrayOfT, arrayName) // *FlatArray[T] - - var flatArrayOfPtrT FlatArray[*T] - m.RegisterDefaultPgType(flatArrayOfPtrT, arrayName) // FlatArray[*T] - m.RegisterDefaultPgType(&flatArrayOfPtrT, arrayName) // *FlatArray[*T] -} - func (m *Map) RegisterType(t *Type) { m.oidToType[t.OID] = t m.nameToType[t.Name] = t diff --git a/pgtype/register_default_pg_types.go b/pgtype/register_default_pg_types.go new file mode 100644 index 00000000..be1ca4a1 --- /dev/null +++ b/pgtype/register_default_pg_types.go @@ -0,0 +1,35 @@ +//go:build !nopgxregisterdefaulttypes + +package pgtype + +func registerDefaultPgTypeVariants[T any](m *Map, name string) { + arrayName := "_" + name + + var value T + m.RegisterDefaultPgType(value, name) // T + m.RegisterDefaultPgType(&value, name) // *T + + var sliceT []T + m.RegisterDefaultPgType(sliceT, arrayName) // []T + m.RegisterDefaultPgType(&sliceT, arrayName) // *[]T + + var slicePtrT []*T + m.RegisterDefaultPgType(slicePtrT, arrayName) // []*T + m.RegisterDefaultPgType(&slicePtrT, arrayName) // *[]*T + + var arrayOfT Array[T] + m.RegisterDefaultPgType(arrayOfT, arrayName) // Array[T] + m.RegisterDefaultPgType(&arrayOfT, arrayName) // *Array[T] + + var arrayOfPtrT Array[*T] + m.RegisterDefaultPgType(arrayOfPtrT, arrayName) // Array[*T] + m.RegisterDefaultPgType(&arrayOfPtrT, arrayName) // *Array[*T] + + var flatArrayOfT FlatArray[T] + m.RegisterDefaultPgType(flatArrayOfT, arrayName) // FlatArray[T] + m.RegisterDefaultPgType(&flatArrayOfT, arrayName) // *FlatArray[T] + + var flatArrayOfPtrT FlatArray[*T] + m.RegisterDefaultPgType(flatArrayOfPtrT, arrayName) // FlatArray[*T] + m.RegisterDefaultPgType(&flatArrayOfPtrT, arrayName) // *FlatArray[*T] +} diff --git a/pgtype/register_default_pg_types_disabled.go b/pgtype/register_default_pg_types_disabled.go new file mode 100644 index 00000000..56fe7c22 --- /dev/null +++ b/pgtype/register_default_pg_types_disabled.go @@ -0,0 +1,6 @@ +//go:build nopgxregisterdefaulttypes + +package pgtype + +func registerDefaultPgTypeVariants[T any](m *Map, name string) { +} From 02d9a5acd8a1c62ff3b33f0b44bb16425cb5e8bc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Aug 2022 08:41:06 -0500 Subject: [PATCH 1137/1158] Fix naming of some tests --- pgtype/pgtype_test.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 6325829e..11ae39e0 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -118,14 +118,14 @@ func skipPostgreSQLVersionLessThan(t testing.TB, minVersion int64) { } } -func TestTypeMapScanNilIsNoOp(t *testing.T) { +func TestMapScanNilIsNoOp(t *testing.T) { m := pgtype.NewMap() err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), nil) assert.NoError(t, err) } -func TestTypeMapScanTextFormatInterfacePtr(t *testing.T) { +func TestMapScanTextFormatInterfacePtr(t *testing.T) { m := pgtype.NewMap() var got any err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &got) @@ -133,7 +133,7 @@ func TestTypeMapScanTextFormatInterfacePtr(t *testing.T) { assert.Equal(t, "foo", got) } -func TestTypeMapScanTextFormatNonByteaIntoByteSlice(t *testing.T) { +func TestMapScanTextFormatNonByteaIntoByteSlice(t *testing.T) { m := pgtype.NewMap() var got []byte err := m.Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &got) @@ -141,7 +141,7 @@ func TestTypeMapScanTextFormatNonByteaIntoByteSlice(t *testing.T) { assert.Equal(t, []byte("{}"), got) } -func TestTypeMapScanBinaryFormatInterfacePtr(t *testing.T) { +func TestMapScanBinaryFormatInterfacePtr(t *testing.T) { m := pgtype.NewMap() var got any err := m.Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &got) @@ -149,7 +149,7 @@ func TestTypeMapScanBinaryFormatInterfacePtr(t *testing.T) { assert.Equal(t, "foo", got) } -func TestTypeMapScanUnknownOIDToStringsAndBytes(t *testing.T) { +func TestMapScanUnknownOIDToStringsAndBytes(t *testing.T) { unknownOID := uint32(999999) srcBuf := []byte("foo") m := pgtype.NewMap() @@ -175,7 +175,7 @@ func TestTypeMapScanUnknownOIDToStringsAndBytes(t *testing.T) { assert.Equal(t, []byte("foo"), []byte(rb)) } -func TestTypeMapScanPointerToNilStructDoesNotCrash(t *testing.T) { +func TestMapScanPointerToNilStructDoesNotCrash(t *testing.T) { m := pgtype.NewMap() type myStruct struct{} @@ -184,7 +184,7 @@ func TestTypeMapScanPointerToNilStructDoesNotCrash(t *testing.T) { require.NotNil(t, err) } -func TestTypeMapScanUnknownOIDTextFormat(t *testing.T) { +func TestMapScanUnknownOIDTextFormat(t *testing.T) { m := pgtype.NewMap() var n int32 @@ -193,7 +193,7 @@ func TestTypeMapScanUnknownOIDTextFormat(t *testing.T) { assert.EqualValues(t, 123, n) } -func TestTypeMapScanUnknownOIDIntoSQLScanner(t *testing.T) { +func TestMapScanUnknownOIDIntoSQLScanner(t *testing.T) { m := pgtype.NewMap() var s sql.NullString @@ -265,7 +265,7 @@ func TestMapScanPtrToPtrToSlice(t *testing.T) { require.Equal(t, []string{"foo", "bar"}, *v) } -func BenchmarkTypeMapScanInt4IntoBinaryDecoder(b *testing.B) { +func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) { m := pgtype.NewMap() src := []byte{0, 0, 0, 42} var v pgtype.Int4 @@ -282,7 +282,7 @@ func BenchmarkTypeMapScanInt4IntoBinaryDecoder(b *testing.B) { } } -func BenchmarkTypeMapScanInt4IntoGoInt32(b *testing.B) { +func BenchmarkMapScanInt4IntoGoInt32(b *testing.B) { m := pgtype.NewMap() src := []byte{0, 0, 0, 42} var v int32 From 7c6a31f9d271942b4ab4f66f7799e83c0d479ffd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Aug 2022 09:30:29 -0500 Subject: [PATCH 1138/1158] CopyFrom parses strings to encode into binary format https://github.com/jackc/pgx/issues/1277 https://github.com/jackc/pgx/issues/1267 --- copy_from_test.go | 29 +++++++++++++++++++++++++++++ values.go | 24 +++++++++++++++++++++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/copy_from_test.go b/copy_from_test.go index d979d2dc..49bfcb34 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -615,3 +615,32 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { ensureConnValid(t, conn) } + +func TestConnCopyFromAutomaticStringConversion(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int8 + )`) + + inputRows := [][]interface{}{ + {"42"}, + {"7"}, + {8}, + } + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + rows, _ := conn.Query(context.Background(), "select * from foo") + nums, err := pgx.CollectRows(rows, pgx.RowTo[int64]) + require.NoError(t, err) + + require.Equal(t, []int64{42, 7, 8}, nums) + + ensureConnValid(t, conn) +} diff --git a/values.go b/values.go index d27e071d..19c642fa 100644 --- a/values.go +++ b/values.go @@ -1,6 +1,8 @@ package pgx import ( + "errors" + "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgtype" @@ -36,11 +38,31 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, er buf = pgio.AppendInt32(buf, -1) argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) if err != nil { - return nil, err + if argBuf2, err2 := tryScanStringCopyValueThenEncode(m, buf, oid, arg); err2 == nil { + argBuf = argBuf2 + } else { + return nil, err + } } + if argBuf != nil { buf = argBuf pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } return buf, nil } + +func tryScanStringCopyValueThenEncode(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { + s, ok := arg.(string) + if !ok { + return nil, errors.New("not a string") + } + + var v any + err := m.Scan(oid, TextFormatCode, []byte(s), &v) + if err != nil { + return nil, err + } + + return m.Encode(oid, BinaryFormatCode, v, buf) +} From c842802d65459822e7f2e84ee28db606965e5cb2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Aug 2022 09:49:06 -0500 Subject: [PATCH 1139/1158] Failsafe timeout for background pool connections Do not override existing connect timeout. --- pgxpool/pool.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pgxpool/pool.go b/pgxpool/pool.go index df017892..e053904c 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -208,14 +208,14 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { // see https://github.com/jackc/pgx/issues/1259 ctx = detachedCtx{ctx} - // But we do want to ensure that a connect won't hang forever. - ctx, cancel := context.WithTimeout(ctx, 2*time.Minute) - defer cancel() + connConfig := p.config.ConnConfig.Copy() - connConfig := p.config.ConnConfig + // But we do want to ensure that a connect won't hang forever. + if connConfig.ConnectTimeout <= 0 { + connConfig.ConnectTimeout = 2 * time.Minute + } if p.beforeConnect != nil { - connConfig = p.config.ConnConfig.Copy() if err := p.beforeConnect(ctx, connConfig); err != nil { return nil, err } From faabb0696f92988c815fc675c61c1961c87a7047 Mon Sep 17 00:00:00 2001 From: Nathan Giardina Date: Fri, 12 Aug 2022 20:40:44 +0000 Subject: [PATCH 1140/1158] Fix for timeout when a single node has timed out, created a new context to allow for each db node to timeout individually --- pgconn.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pgconn.go b/pgconn.go index 17f19e95..b23d7ded 100644 --- a/pgconn.go +++ b/pgconn.go @@ -128,19 +128,13 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio // authentication error will terminate the chain of attempts (like libpq: // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, // if all attempts fail the last error is returned. -func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { +func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { panic("config must be created by ParseConfig") } - // ConnectTimeout restricts the whole connection process. - if config.ConnectTimeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, config.ConnectTimeout) - defer cancel() - } // Simplify usage by treating primary config and fallbacks the same. fallbackConfigs := []*FallbackConfig{ { @@ -150,7 +144,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err }, } fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) - + ctx := octx fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) if err != nil { return nil, &connectError{config: config, msg: "hostname resolving error", err: err} @@ -163,6 +157,14 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err foundBestServer := false var fallbackConfig *FallbackConfig for _, fc := range fallbackConfigs { + // ConnectTimeout restricts the whole connection process. + if config.ConnectTimeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) + defer cancel() + } else { + ctx = octx + } pgConn, err = connect(ctx, config, fc, false) if err == nil { foundBestServer = true From 067771b2e67a789dfc4844e54e3c3f5443e58fca Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Mon, 15 Aug 2022 23:24:34 +0300 Subject: [PATCH 1141/1158] Set SNI for SSL connections This allows an SNI-aware proxy to route connections. Patch adds a new connection option (`sslsni`) to opt out of the SNI, to have the same behavior as `libpq` does. See more in `sslsni` sections at . --- config.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/config.go b/config.go index 2277dc1d..0a276c6b 100644 --- a/config.go +++ b/config.go @@ -297,6 +297,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con "sslcert": {}, "sslrootcert": {}, "sslpassword": {}, + "sslsni": {}, "krbspn": {}, "krbsrvname": {}, "target_session_attrs": {}, @@ -424,6 +425,7 @@ func parseEnvSettings() map[string]string { "PGSSLMODE": "sslmode", "PGSSLKEY": "sslkey", "PGSSLCERT": "sslcert", + "PGSSLSNI": "sslsni", "PGSSLROOTCERT": "sslrootcert", "PGSSLPASSWORD": "sslpassword", "PGTARGETSESSIONATTRS": "target_session_attrs", @@ -619,11 +621,15 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P sslcert := settings["sslcert"] sslkey := settings["sslkey"] sslpassword := settings["sslpassword"] + sslsni := settings["sslsni"] // Match libpq default behavior if sslmode == "" { sslmode = "prefer" } + if sslsni == "" { + sslsni = "1" + } tlsConfig := &tls.Config{} @@ -756,6 +762,10 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P tlsConfig.Certificates = []tls.Certificate{cert} } + if sslsni == "1" { + tlsConfig.ServerName = host + } + switch sslmode { case "allow": return []*tls.Config{nil, tlsConfig}, nil From e3406d95f9be53211a295ee4b8f7208a27f48f05 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Tue, 16 Aug 2022 02:03:11 +0300 Subject: [PATCH 1142/1158] Add test coverage for client SNI --- pgconn_test.go | 172 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index a4f0ec63..302dffd5 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -2169,3 +2169,175 @@ func GetSSLPassword(ctx context.Context) string { connString := os.Getenv("PGX_SSL_PASSWORD") return connString } + +var rsaCertPEM = `-----BEGIN CERTIFICATE----- +MIIDCTCCAfGgAwIBAgIUQDlN1g1bzxIJ8KWkayNcQY5gzMEwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgxNTIxNDgyNloXDTIzMDgx +NTIxNDgyNlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEA0vOppiT8zE+076acRORzD5JVbRYKMK3XlWLVrHua4+ct +Rm54WyP+3XsYU4JGGGKgb8E+u2UosGJYcSM+b+U1/5XPTcpuumS+pCiD9WP++A39 +tsukYwR7m65cgpiI4dlLEZI3EWpAW+Bb3230KiYW4sAmQ0Ih4PrN+oPvzcs86F4d +9Y03CqVUxRKLBLaClZQAg8qz2Pawwj1FKKjDX7u2fRVR0wgOugpCMOBJMcCgz9pp +0HSa4x3KZDHEZY7Pah5XwWrCfAEfRWsSTGcNaoN8gSxGFM1JOEJa8SAuPGjFcYIv +MmVWdw0FXCgYlSDL02fzLE0uyvXBDibzSqOk770JhQIDAQABo1MwUTAdBgNVHQ4E +FgQUiJ8JLENJ+2k1Xl4o6y2Lc/qHTh0wHwYDVR0jBBgwFoAUiJ8JLENJ+2k1Xl4o +6y2Lc/qHTh0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAwjn2 +gnNAhFvh58VqLIjU6ftvn6rhz5B9dg2+XyY8sskLhhkO1nL9339BVZsRt+eI3a7I +81GNIm9qHVM3MUAcQv3SZy+0UPVUT8DNH2LwHT3CHnYTBP8U+8n8TDNGSTMUhIBB +Rx+6KwODpwLdI79VGT3IkbU9bZwuepB9I9nM5t/tt5kS4gHmJFlO0aLJFCTO4Scf +hp/WLPv4XQUH+I3cPfaJRxz2j0Kc8iOzMhFmvl1XOGByjX6X33LnOzY/LVeTSGyS +VgC32BGtnMwuy5XZYgFAeUx9HKy4tG4OH2Ux6uPF/WAhsug6PXSjV7BK6wYT5i27 +MlascjupnaptKX/wMA== +-----END CERTIFICATE----- +` + +var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDS86mmJPzMT7Tv +ppxE5HMPklVtFgowrdeVYtWse5rj5y1GbnhbI/7dexhTgkYYYqBvwT67ZSiwYlhx +Iz5v5TX/lc9Nym66ZL6kKIP1Y/74Df22y6RjBHubrlyCmIjh2UsRkjcRakBb4Fvf +bfQqJhbiwCZDQiHg+s36g+/NyzzoXh31jTcKpVTFEosEtoKVlACDyrPY9rDCPUUo +qMNfu7Z9FVHTCA66CkIw4EkxwKDP2mnQdJrjHcpkMcRljs9qHlfBasJ8AR9FaxJM +Zw1qg3yBLEYUzUk4QlrxIC48aMVxgi8yZVZ3DQVcKBiVIMvTZ/MsTS7K9cEOJvNK +o6TvvQmFAgMBAAECggEAKzTK54Ol33bn2TnnwdiElIjlRE2CUswYXrl6iDRc2hbs +WAOiVRB/T/+5UMla7/2rXJhY7+rdNZs/ABU24ZYxxCJ77jPrD/Q4c8j0lhsgCtBa +ycjV543wf0dsHTd+ubtWu8eVzdRUUD0YtB+CJevdPh4a+CWgaMMV0xyYzi61T+Yv +Z7Uc3awIAiT4Kw9JRmJiTnyMJg5vZqW3BBAX4ZIvS/54ipwEU+9sWLcuH7WmCR0B +QCTqS6hfJDLm//dGC89Iyno57zfYuiT3PYCWH5crr/DH3LqnwlNaOGSBkhkXuIL+ +QvOaUMe2i0pjqxDrkBx05V554vyy9jEvK7i330HL4QKBgQDUJmouEr0+o7EMBApC +CPPu58K04qY5t9aGciG/pOurN42PF99yNZ1CnynH6DbcnzSl8rjc6Y65tzTlWods +bjwVfcmcokG7sPcivJvVjrjKpSQhL8xdZwSAjcqjN4yoJ/+ghm9w+SRmZr6oCQZ3 +1jREfJKT+PGiWTEjYcExPWUD2QKBgQD+jdgq4c3tFavU8Hjnlf75xbStr5qu+fp2 +SGLRRbX+msQwVbl2ZM9AJLoX9MTCl7D9zaI3ONhheMmfJ77lDTa3VMFtr3NevGA6 +MxbiCEfRtQpNkJnsqCixLckx3bskj5+IF9BWzw7y7nOzdhoWVFv/+TltTm3RB51G +McdlmmVjjQKBgQDSFAw2/YV6vtu2O1XxGC591/Bd8MaMBziev+wde3GHhaZfGVPC +I8dLTpMwCwowpFKdNeLLl1gnHX161I+f1vUWjw4TVjVjaBUBx+VEr2Tb/nXtiwiD +QV0a883CnGJjreAblKRMKdpasMmBWhaWmn39h6Iad3zHuCzJjaaiXNpn2QKBgQCf +k1Q8LanmQnuh1c41f7aD5gjKCRezMUpt9BrejhD1NxheJJ9LNQ8nat6uPedLBcUS +lmJms+AR2qKqf0QQWyQ98YgAtshgTz8TvQtPT1mWgSOgVFHqJdC8obNK63FyDgc4 +TZVxlgQNDqbBjfv0m5XA9f+mIlB9hYR2iKYzb4K30QKBgQC+LEJYZh00zsXttGHr +5wU1RzbgDIEsNuu+nZ4MxsaCik8ILNRHNXdeQbnADKuo6ATfhdmDIQMVZLG8Mivi +UwnwLd1GhizvqvLHa3ULnFphRyMGFxaLGV48axTT2ADoMX67ILrIY/yjycLqRZ3T +z3w+CgS20UrbLIR1YXfqUXge1g== +-----END TESTING KEY----- +`) + +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } + +func TestSNISupport(t *testing.T) { + t.Parallel() + tests := []struct { + name string + sni_param string + sni_set bool + }{ + { + name: "SNI is passed by default", + sni_param: "", + sni_set: true, + }, + { + name: "SNI is passed when asked for", + sni_param: "sslsni=1", + sni_set: true, + }, + { + name: "SNI is not passed when disabled", + sni_param: "sslsni=0", + sni_set: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + serverSNINameChan := make(chan string, 1) + defer close(serverErrChan) + defer close(serverSNINameChan) + + go func() { + var sniHost string + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn) + startupMessage, err := backend.ReceiveStartupMessage() + if err != nil { + serverErrChan <- err + return + } + + switch startupMessage.(type) { + case *pgproto3.SSLRequest: + _, err = conn.Write([]byte("S")) + if err != nil { + serverErrChan <- err + return + } + default: + serverErrChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage) + return + } + + cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM)) + if err != nil { + serverErrChan <- err + return + } + + srv := tls.Server(conn, &tls.Config{ + Certificates: []tls.Certificate{cert}, + GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { + sniHost = argHello.ServerName + return nil, nil + }, + }) + defer srv.Close() + + if err := srv.Handshake(); err != nil { + serverErrChan <- fmt.Errorf("handshake: %v", err) + return + } + + srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil)) + srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)) + srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)) + + serverSNINameChan <- sniHost + }() + + port := strings.Split(ln.Addr().String(), ":")[1] + connStr := fmt.Sprintf("sslmode=require host=localhost port=%s %s", port, tt.sni_param) + _, err = pgconn.Connect(context.Background(), connStr) + + select { + case sniHost := <-serverSNINameChan: + if tt.sni_set { + require.Equal(t, sniHost, "localhost") + } else { + require.Equal(t, sniHost, "") + } + case err = <-serverErrChan: + t.Fatalf("server failed with error: %+v", err) + case <-time.After(time.Millisecond * 100): + t.Fatal("exceeded connection timeout without erroring out") + } + }) + } +} From 15f8e6323e113feb44678e5cb16cfbeba4630bf6 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Wed, 17 Aug 2022 12:04:42 +0300 Subject: [PATCH 1143/1158] Fix tests that check tls.Config.ServerName -- with SNI this field is filled, unless SNI is delibaretely disabled. Also, do not set SNI when host is an IP address as per RFC 6066. --- config.go | 5 ++- config_test.go | 105 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 108 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 0a276c6b..4080f2c6 100644 --- a/config.go +++ b/config.go @@ -762,7 +762,10 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P tlsConfig.Certificates = []tls.Certificate{cert} } - if sslsni == "1" { + // Set Server Name Indication (SNI), if enabled by connection parameters. + // Per RFC 6066, do not set it if the host is a literal IP address (IPv4 + // or IPv6). + if sslsni == "1" && net.ParseIP(host) == nil { tlsConfig.ServerName = host } diff --git a/config_test.go b/config_test.go index 6b48ea27..629b5c0f 100644 --- a/config_test.go +++ b/config_test.go @@ -53,6 +53,7 @@ func TestParseConfig(t *testing.T) { Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, + ServerName: "localhost", }, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ @@ -94,6 +95,7 @@ func TestParseConfig(t *testing.T) { Port: 5432, TLSConfig: &tls.Config{ InsecureSkipVerify: true, + ServerName: "localhost", }, }, }, @@ -111,6 +113,7 @@ func TestParseConfig(t *testing.T) { Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, + ServerName: "localhost", }, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ @@ -133,6 +136,7 @@ func TestParseConfig(t *testing.T) { Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, + ServerName: "localhost", }, RuntimeParams: map[string]string{}, }, @@ -148,6 +152,7 @@ func TestParseConfig(t *testing.T) { Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, + ServerName: "localhost", }, RuntimeParams: map[string]string{}, }, @@ -519,6 +524,7 @@ func TestParseConfig(t *testing.T) { Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, + ServerName: "foo", }, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ @@ -532,6 +538,7 @@ func TestParseConfig(t *testing.T) { Port: 5432, TLSConfig: &tls.Config{ InsecureSkipVerify: true, + ServerName: "bar", }}, &pgconn.FallbackConfig{ Host: "bar", @@ -543,6 +550,7 @@ func TestParseConfig(t *testing.T) { Port: 5432, TLSConfig: &tls.Config{ InsecureSkipVerify: true, + ServerName: "baz", }}, &pgconn.FallbackConfig{ Host: "baz", @@ -648,6 +656,82 @@ func TestParseConfig(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + name: "SNI is set by default", + connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "sni.test", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "sni.test", + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is not set for IPv4", + connString: "postgres://jack:secret@1.1.1.1:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "1.1.1.1", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is not set for IPv6", + connString: "postgres://jack:secret@[::1]:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "::1", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is not set when disabled (URL-style)", + connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require&sslsni=0", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "sni.test", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is not set when disabled (key/value style)", + connString: "user=jack password=secret host=sni.test dbname=mydb sslmode=require sslsni=0", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "sni.test", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, } for i, tt := range tests { @@ -820,7 +904,7 @@ func TestParseConfigEnvLibpq(t *testing.T) { } } - pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT", "PGSSLSNI"} savedEnv := make(map[string]string) for _, n := range pgEnvvars { @@ -884,6 +968,23 @@ func TestParseConfigEnvLibpq(t *testing.T) { RuntimeParams: map[string]string{"application_name": "pgxtest"}, }, }, + { + name: "SNI can be disabled via environment variable", + envvars: map[string]string{ + "PGHOST": "test.foo", + "PGSSLMODE": "require", + "PGSSLSNI": "0", + }, + config: &pgconn.Config{ + User: osUserName, + Host: "test.foo", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, } for i, tt := range tests { @@ -974,6 +1075,7 @@ application_name = spaced string Port: 9999, TLSConfig: &tls.Config{ InsecureSkipVerify: true, + ServerName: "abc.example.com", }, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ @@ -995,6 +1097,7 @@ application_name = spaced string User: "defuser", TLSConfig: &tls.Config{ InsecureSkipVerify: true, + ServerName: "def.example.com", }, RuntimeParams: map[string]string{"application_name": "spaced string"}, Fallbacks: []*pgconn.FallbackConfig{ From dbee461dc9226a78a4fdba3fa8892650e46f3527 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 Aug 2022 17:42:04 -0500 Subject: [PATCH 1144/1158] Update previous pgconn merge for v5 --- pgconn/pgconn_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 0935aa3c..153072c2 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -4,6 +4,7 @@ import ( "bytes" "compress/gzip" "context" + "crypto/tls" "errors" "fmt" "io" @@ -2738,7 +2739,7 @@ func TestSNISupport(t *testing.T) { return } - backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn) + backend := pgproto3.NewBackend(conn, conn) startupMessage, err := backend.ReceiveStartupMessage() if err != nil { serverErrChan <- err From ae65a8007b3078346959e4b80d1f37d3730b92df Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Aug 2022 09:54:30 -0500 Subject: [PATCH 1145/1158] Use higher pgconn.FieldDescription with string Name Instead of using pgproto3.FieldDescription through pgconn and pgx. This lets the lowest level pgproto3 still be as memory efficient as possible. https://github.com/jackc/pgx/pull/1281 --- pgconn/pgconn.go | 61 ++++++++++++++++++++++++++++++++----------- pgconn/pgconn_test.go | 10 +++---- pgxpool/rows.go | 21 +++++++-------- query_test.go | 2 +- rows.go | 9 +++---- 5 files changed, 66 insertions(+), 37 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 546a4bd0..e29283f4 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -83,6 +83,7 @@ type PgConn struct { multiResultReader MultiResultReader pipeline Pipeline contextWatcher *ctxwatch.ContextWatcher + fieldDescriptions [16]FieldDescription cleanupDone chan struct{} } @@ -526,9 +527,10 @@ func (pgConn *PgConn) PID() uint32 { // TxStatus returns the current TxStatus as reported by the server in the ReadyForQuery message. // // Possible return values: -// 'I' - idle / not in transaction -// 'T' - in a transaction -// 'E' - in a failed transaction +// +// 'I' - idle / not in transaction +// 'T' - in a transaction +// 'E' - in a failed transaction // // See https://www.postgresql.org/docs/current/protocol-message-formats.html. func (pgConn *PgConn) TxStatus() byte { @@ -714,11 +716,41 @@ func (ct CommandTag) Select() bool { return strings.HasPrefix(ct.s, "SELECT") } +type FieldDescription struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 +} + +func (pgConn *PgConn) convertRowDescription(dst []FieldDescription, rd *pgproto3.RowDescription) []FieldDescription { + if cap(dst) >= len(rd.Fields) { + dst = dst[:len(rd.Fields):len(rd.Fields)] + } else { + dst = make([]FieldDescription, len(rd.Fields)) + } + + for i := range rd.Fields { + dst[i].Name = string(rd.Fields[i].Name) + dst[i].TableOID = rd.Fields[i].TableOID + dst[i].TableAttributeNumber = rd.Fields[i].TableAttributeNumber + dst[i].DataTypeOID = rd.Fields[i].DataTypeOID + dst[i].DataTypeSize = rd.Fields[i].DataTypeSize + dst[i].TypeModifier = rd.Fields[i].TypeModifier + dst[i].Format = rd.Fields[i].Format + } + + return dst +} + type StatementDescription struct { Name string SQL string ParamOIDs []uint32 - Fields []pgproto3.FieldDescription + Fields []FieldDescription } // Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This @@ -765,8 +797,7 @@ readloop: psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: - psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) - copy(psd.Fields, msg.Fields) + psd.Fields = pgConn.convertRowDescription(nil, msg) case *pgproto3.ErrorResponse: parseErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: @@ -1281,8 +1312,9 @@ func (mrr *MultiResultReader) NextResult() bool { pgConn: mrr.pgConn, multiResultReader: mrr, ctx: mrr.ctx, - fieldDescriptions: msg.Fields, + fieldDescriptions: mrr.pgConn.convertRowDescription(mrr.pgConn.fieldDescriptions[:], msg), } + mrr.rr = &mrr.pgConn.resultReader return true case *pgproto3.CommandComplete: @@ -1325,7 +1357,7 @@ type ResultReader struct { pipeline *Pipeline ctx context.Context - fieldDescriptions []pgproto3.FieldDescription + fieldDescriptions []FieldDescription rowValues [][]byte commandTag CommandTag commandConcluded bool @@ -1335,7 +1367,7 @@ type ResultReader struct { // Result is the saved query response that is returned by calling Read on a ResultReader. type Result struct { - FieldDescriptions []pgproto3.FieldDescription + FieldDescriptions []FieldDescription Rows [][][]byte CommandTag CommandTag Err error @@ -1347,7 +1379,7 @@ func (rr *ResultReader) Read() *Result { for rr.NextRow() { if br.FieldDescriptions == nil { - br.FieldDescriptions = make([]pgproto3.FieldDescription, len(rr.FieldDescriptions())) + br.FieldDescriptions = make([]FieldDescription, len(rr.FieldDescriptions())) copy(br.FieldDescriptions, rr.FieldDescriptions()) } @@ -1385,7 +1417,7 @@ func (rr *ResultReader) NextRow() bool { // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until // the ResultReader is closed. -func (rr *ResultReader) FieldDescriptions() []pgproto3.FieldDescription { +func (rr *ResultReader) FieldDescriptions() []FieldDescription { return rr.fieldDescriptions } @@ -1473,7 +1505,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error switch msg := msg.(type) { case *pgproto3.RowDescription: - rr.fieldDescriptions = msg.Fields + rr.fieldDescriptions = rr.pgConn.convertRowDescription(rr.pgConn.fieldDescriptions[:], msg) case *pgproto3.CommandComplete: rr.concludeCommand(rr.pgConn.makeCommandTag(msg.CommandTag), nil) case *pgproto3.EmptyQueryResponse: @@ -1825,7 +1857,7 @@ func (p *Pipeline) GetResults() (results any, err error) { pgConn: p.conn, pipeline: p, ctx: p.ctx, - fieldDescriptions: msg.Fields, + fieldDescriptions: p.conn.convertRowDescription(p.conn.fieldDescriptions[:], msg), } return &p.conn.resultReader, nil case *pgproto3.CommandComplete: @@ -1872,8 +1904,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: - psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) - copy(psd.Fields, msg.Fields) + psd.Fields = p.conn.convertRowDescription(nil, msg) return psd, nil // NoData is returned instead of RowDescription when there is no expected result. e.g. An INSERT without a RETURNING diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 153072c2..601cbc8e 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -642,13 +642,13 @@ func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) { require.True(t, mrr.NextResult()) require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) - assert.Equal(t, []byte("msg"), mrr.ResultReader().FieldDescriptions()[0].Name) + assert.Equal(t, "msg", mrr.ResultReader().FieldDescriptions()[0].Name) _, err = mrr.ResultReader().Close() require.NoError(t, err) require.True(t, mrr.NextResult()) require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) - assert.Equal(t, []byte("num"), mrr.ResultReader().FieldDescriptions()[0].Name) + assert.Equal(t, "num", mrr.ResultReader().FieldDescriptions()[0].Name) _, err = mrr.ResultReader().Close() require.NoError(t, err) @@ -772,7 +772,7 @@ func TestConnExecParams(t *testing.T) { result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil) require.Len(t, result.FieldDescriptions(), 1) - assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + assert.Equal(t, "msg", result.FieldDescriptions()[0].Name) rowCount := 0 for result.NextRow() { @@ -937,7 +937,7 @@ func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) { result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil) require.Len(t, result.FieldDescriptions(), 1) - assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + assert.Equal(t, "msg", result.FieldDescriptions()[0].Name) rowCount := 0 for result.NextRow() { @@ -968,7 +968,7 @@ func TestConnExecPrepared(t *testing.T) { result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) require.Len(t, result.FieldDescriptions(), 1) - assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + assert.Equal(t, "msg", result.FieldDescriptions()[0].Name) rowCount := 0 for result.NextRow() { diff --git a/pgxpool/rows.go b/pgxpool/rows.go index 0c0a7382..2b11ecd3 100644 --- a/pgxpool/rows.go +++ b/pgxpool/rows.go @@ -3,22 +3,21 @@ package pgxpool import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgproto3" ) type errRows struct { err error } -func (errRows) Close() {} -func (e errRows) Err() error { return e.err } -func (errRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} } -func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil } -func (errRows) Next() bool { return false } -func (e errRows) Scan(dest ...any) error { return e.err } -func (e errRows) Values() ([]any, error) { return nil, e.err } -func (e errRows) RawValues() [][]byte { return nil } -func (e errRows) Conn() *pgx.Conn { return nil } +func (errRows) Close() {} +func (e errRows) Err() error { return e.err } +func (errRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} } +func (errRows) FieldDescriptions() []pgconn.FieldDescription { return nil } +func (errRows) Next() bool { return false } +func (e errRows) Scan(dest ...any) error { return e.err } +func (e errRows) Values() ([]any, error) { return nil, e.err } +func (e errRows) RawValues() [][]byte { return nil } +func (e errRows) Conn() *pgx.Conn { return nil } type errRow struct { err error @@ -51,7 +50,7 @@ func (rows *poolRows) CommandTag() pgconn.CommandTag { return rows.r.CommandTag() } -func (rows *poolRows) FieldDescriptions() []pgproto3.FieldDescription { +func (rows *poolRows) FieldDescriptions() []pgconn.FieldDescription { return rows.r.FieldDescriptions() } diff --git a/query_test.go b/query_test.go index 317e4f60..720a1911 100644 --- a/query_test.go +++ b/query_test.go @@ -66,7 +66,7 @@ func TestConnQueryRowsFieldDescriptionsBeforeNext(t *testing.T) { defer rows.Close() require.Len(t, rows.FieldDescriptions(), 1) - assert.Equal(t, []byte("msg"), rows.FieldDescriptions()[0].Name) + assert.Equal(t, "msg", rows.FieldDescriptions()[0].Name) } func TestConnQueryWithoutResultSetCommandTag(t *testing.T) { diff --git a/rows.go b/rows.go index c4fd283c..80df4bb2 100644 --- a/rows.go +++ b/rows.go @@ -9,7 +9,6 @@ import ( "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" ) @@ -34,7 +33,7 @@ type Rows interface { // CommandTag returns the command tag from this query. It is only available after Rows is closed. CommandTag() pgconn.CommandTag - FieldDescriptions() []pgproto3.FieldDescription + FieldDescriptions() []pgconn.FieldDescription // Next prepares the next row for reading. It returns true if there is another // row and false if no more rows are available. It automatically closes rows @@ -135,7 +134,7 @@ type baseRows struct { rowCount int } -func (rows *baseRows) FieldDescriptions() []pgproto3.FieldDescription { +func (rows *baseRows) FieldDescriptions() []pgconn.FieldDescription { return rows.resultReader.FieldDescriptions() } @@ -337,7 +336,7 @@ func (e ScanArgError) Unwrap() error { // fieldDescriptions - OID and format of values // values - the raw data as returned from the PostgreSQL server // dest - the destination that values will be decoded into -func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...any) error { +func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, values [][]byte, dest ...any) error { if len(fieldDescriptions) != len(values) { return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) } @@ -395,7 +394,7 @@ func ForEachRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, err // CollectableRow is the subset of Rows methods that a RowToFunc is allowed to call. type CollectableRow interface { - FieldDescriptions() []pgproto3.FieldDescription + FieldDescriptions() []pgconn.FieldDescription Scan(dest ...any) error Values() ([]any, error) RawValues() [][]byte From 0d5d8e013747f33b6798d474ca6604ce3e1f5615 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 22 Aug 2022 20:06:42 -0500 Subject: [PATCH 1146/1158] Fallback to other format when encoding query arguments The preferred format may not be possible for certain arguments. For example, the preferred format for numeric is binary. But if shopspring/decimal is being used without jackc/pgx-shopspring-decimal then it will use the database/sql/driver.Valuer interface. This will return a string. That string should be sent in the text format. A similar case occurs when encoding a []string into a non-text PostgreSQL array such as uuid[]. --- extended_query_builder.go | 23 +++++++++++++++++++++-- pgtype/array_codec_test.go | 33 +++++++++++++++++++++++++++++++++ query_test.go | 22 +++++++++++++++++++++- 3 files changed, 75 insertions(+), 3 deletions(-) diff --git a/extended_query_builder.go b/extended_query_builder.go index 1c47063c..b0c0e02b 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -51,14 +51,33 @@ func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescri // must be an untyped nil. func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error { if format == -1 { - format = eqb.chooseParameterFormatCode(m, oid, arg) + preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg) + preferredErr := eqb.appendParam(m, oid, preferredFormat, arg) + if preferredErr == nil { + return nil + } + + var otherFormat int16 + if preferredFormat == TextFormatCode { + otherFormat = BinaryFormatCode + } else { + otherFormat = TextFormatCode + } + + otherErr := eqb.appendParam(m, oid, otherFormat, arg) + if otherErr == nil { + return nil + } + + return preferredErr // return the error from the preferred format } - eqb.ParamFormats = append(eqb.ParamFormats, format) v, err := eqb.encodeExtendedParamValue(m, oid, format, arg) if err != nil { return err } + + eqb.ParamFormats = append(eqb.ParamFormats, format) eqb.ParamValues = append(eqb.ParamValues, v) return nil diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index 9da027e8..a558d0fc 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -2,6 +2,8 @@ package pgtype_test import ( "context" + "encoding/hex" + "strings" "testing" pgx "github.com/jackc/pgx/v5" @@ -124,6 +126,37 @@ func TestArrayCodecAnySlice(t *testing.T) { }) } +// https://github.com/jackc/pgx/issues/1273#issuecomment-1218262703 +func TestArrayCodecSliceArgConversion(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + arg := []string{ + "3ad95bfd-ecea-4032-83c3-0c823cafb372", + "951baf11-c0cc-4afc-a779-abff0611dbf1", + "8327f244-7e2f-45e7-a10b-fbdc9d6f3378", + } + + var expected []pgtype.UUID + + for _, s := range arg { + buf, err := hex.DecodeString(strings.ReplaceAll(s, "-", "")) + require.NoError(t, err) + var u pgtype.UUID + copy(u.Bytes[:], buf) + u.Valid = true + expected = append(expected, u) + } + + var actual []pgtype.UUID + err := conn.QueryRow( + ctx, + "select $1::uuid[]", + arg, + ).Scan(&actual) + require.NoError(t, err) + require.Equal(t, expected, actual) + }) +} + func TestArrayCodecDecodeValue(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { for _, tt := range []struct { diff --git a/query_test.go b/query_test.go index 720a1911..b2aa5d10 100644 --- a/query_test.go +++ b/query_test.go @@ -1165,7 +1165,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes ensureConnValid(t, conn) } -func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) { +func TestConnQueryDatabaseSQLDriverScannerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -1181,6 +1181,26 @@ func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t * ensureConnValid(t, conn) } +// https://github.com/jackc/pgx/issues/1273#issuecomment-1221672175 +func TestConnQueryDatabaseSQLDriverValuerTextWhenBinaryIsPreferred(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + arg := sql.NullString{String: "1.234", Valid: true} + var result pgtype.Numeric + err := conn.QueryRow(context.Background(), "select $1::numeric", arg).Scan(&result) + require.NoError(t, err) + + require.True(t, result.Valid) + f64, err := result.Float64Value() + require.NoError(t, err) + require.Equal(t, pgtype.Float8{Float64: 1.234, Valid: true}, f64) + + ensureConnValid(t, conn) +} + func TestConnQueryDatabaseSQLNullX(t *testing.T) { t.Parallel() From 2e73d1e8eeccb2882ca43d09345c6d0c1048b406 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 22 Aug 2022 20:56:36 -0500 Subject: [PATCH 1147/1158] Improve error message when failing to scan a NULL::json --- pgtype/json.go | 3 +++ pgtype/json_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/pgtype/json.go b/pgtype/json.go index de2f08df..51f2509f 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/json" + "fmt" "reflect" ) @@ -129,6 +130,8 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { return nil } } + + return fmt.Errorf("cannot scan null into %T", dst) } return json.Unmarshal(src, dst) diff --git a/pgtype/json_test.go b/pgtype/json_test.go index c349fa24..04964e8c 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -4,7 +4,9 @@ import ( "context" "testing" + pgx "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) func isExpectedEqMap(a any) func(any) bool { @@ -58,3 +60,36 @@ func TestJSONCodec(t *testing.T) { {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, }) } + +// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648 +func TestJSONCodecUnmarshalSQLNull(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + // Slices are nilified + slice := []string{"foo", "bar", "baz"} + err := conn.QueryRow(ctx, "select null::json").Scan(&slice) + require.NoError(t, err) + require.Nil(t, slice) + + // Maps are nilified + m := map[string]any{"foo": "bar"} + err = conn.QueryRow(ctx, "select null::json").Scan(&m) + require.NoError(t, err) + require.Nil(t, m) + + // Pointer to pointer are nilified + n := 42 + p := &n + err = conn.QueryRow(ctx, "select null::json").Scan(&p) + require.NoError(t, err) + require.Nil(t, p) + + // A string cannot scan a NULL. + str := "foobar" + err = conn.QueryRow(ctx, "select null::json").Scan(&str) + require.EqualError(t, err, "can't scan into dest[0]: cannot scan null into *string") + + // A non-string cannot scan a NULL. + err = conn.QueryRow(ctx, "select null::json").Scan(&n) + require.EqualError(t, err, "can't scan into dest[0]: cannot scan null into *int") + }) +} From fe3a4f31503fd23b9e560ed372bd365368b7f880 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 22 Aug 2022 20:58:22 -0500 Subject: [PATCH 1148/1158] Standardize casing for NULL in error messages --- pgtype/bool.go | 4 +-- pgtype/enum_codec.go | 2 +- pgtype/float4.go | 4 +-- pgtype/float8.go | 4 +-- pgtype/int.go | 80 +++++++++++++++++++++---------------------- pgtype/int.go.erb | 12 +++---- pgtype/json.go | 2 +- pgtype/json_test.go | 4 +-- pgtype/pgtype.go | 2 +- pgtype/qchar.go | 4 +-- pgtype/range_codec.go | 16 ++++----- pgtype/text.go | 2 +- pgtype/uint32.go | 2 +- 13 files changed, 69 insertions(+), 69 deletions(-) diff --git a/pgtype/bool.go b/pgtype/bool.go index 6f3ef8ca..e7be27e2 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -240,7 +240,7 @@ type scanPlanBinaryBoolToBool struct{} func (scanPlanBinaryBoolToBool) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 1 { @@ -261,7 +261,7 @@ type scanPlanTextAnyToBool struct{} func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 1 { diff --git a/pgtype/enum_codec.go b/pgtype/enum_codec.go index 3d23b12f..5e787c1e 100644 --- a/pgtype/enum_codec.go +++ b/pgtype/enum_codec.go @@ -85,7 +85,7 @@ type scanPlanTextAnyToEnumString struct { func (plan *scanPlanTextAnyToEnumString) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p := (dst).(*string) diff --git a/pgtype/float4.go b/pgtype/float4.go index a68fa7b2..2540f9e5 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -177,7 +177,7 @@ type scanPlanBinaryFloat4ToFloat32 struct{} func (scanPlanBinaryFloat4ToFloat32) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 4 { @@ -254,7 +254,7 @@ type scanPlanTextAnyToFloat32 struct{} func (scanPlanTextAnyToFloat32) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } n, err := strconv.ParseFloat(string(src), 32) diff --git a/pgtype/float8.go b/pgtype/float8.go index 98334dc6..6af27d6f 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -215,7 +215,7 @@ type scanPlanBinaryFloat8ToFloat64 struct{} func (scanPlanBinaryFloat8ToFloat64) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 8 { @@ -292,7 +292,7 @@ type scanPlanTextAnyToFloat64 struct{} func (scanPlanTextAnyToFloat64) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } n, err := strconv.ParseFloat(string(src), 64) diff --git a/pgtype/int.go b/pgtype/int.go index 147a3656..1cda0ba3 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -296,7 +296,7 @@ type scanPlanBinaryInt2ToInt8 struct{} func (scanPlanBinaryInt2ToInt8) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 2 { @@ -324,7 +324,7 @@ type scanPlanBinaryInt2ToUint8 struct{} func (scanPlanBinaryInt2ToUint8) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 2 { @@ -354,7 +354,7 @@ type scanPlanBinaryInt2ToInt16 struct{} func (scanPlanBinaryInt2ToInt16) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 2 { @@ -375,7 +375,7 @@ type scanPlanBinaryInt2ToUint16 struct{} func (scanPlanBinaryInt2ToUint16) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 2 { @@ -401,7 +401,7 @@ type scanPlanBinaryInt2ToInt32 struct{} func (scanPlanBinaryInt2ToInt32) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 2 { @@ -422,7 +422,7 @@ type scanPlanBinaryInt2ToUint32 struct{} func (scanPlanBinaryInt2ToUint32) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 2 { @@ -448,7 +448,7 @@ type scanPlanBinaryInt2ToInt64 struct{} func (scanPlanBinaryInt2ToInt64) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 2 { @@ -469,7 +469,7 @@ type scanPlanBinaryInt2ToUint64 struct{} func (scanPlanBinaryInt2ToUint64) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 2 { @@ -495,7 +495,7 @@ type scanPlanBinaryInt2ToInt struct{} func (scanPlanBinaryInt2ToInt) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 2 { @@ -516,7 +516,7 @@ type scanPlanBinaryInt2ToUint struct{} func (scanPlanBinaryInt2ToUint) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 2 { @@ -856,7 +856,7 @@ type scanPlanBinaryInt4ToInt8 struct{} func (scanPlanBinaryInt4ToInt8) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 4 { @@ -884,7 +884,7 @@ type scanPlanBinaryInt4ToUint8 struct{} func (scanPlanBinaryInt4ToUint8) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 4 { @@ -914,7 +914,7 @@ type scanPlanBinaryInt4ToInt16 struct{} func (scanPlanBinaryInt4ToInt16) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 4 { @@ -942,7 +942,7 @@ type scanPlanBinaryInt4ToUint16 struct{} func (scanPlanBinaryInt4ToUint16) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 4 { @@ -972,7 +972,7 @@ type scanPlanBinaryInt4ToInt32 struct{} func (scanPlanBinaryInt4ToInt32) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 4 { @@ -993,7 +993,7 @@ type scanPlanBinaryInt4ToUint32 struct{} func (scanPlanBinaryInt4ToUint32) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 4 { @@ -1019,7 +1019,7 @@ type scanPlanBinaryInt4ToInt64 struct{} func (scanPlanBinaryInt4ToInt64) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 4 { @@ -1040,7 +1040,7 @@ type scanPlanBinaryInt4ToUint64 struct{} func (scanPlanBinaryInt4ToUint64) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 4 { @@ -1066,7 +1066,7 @@ type scanPlanBinaryInt4ToInt struct{} func (scanPlanBinaryInt4ToInt) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 4 { @@ -1087,7 +1087,7 @@ type scanPlanBinaryInt4ToUint struct{} func (scanPlanBinaryInt4ToUint) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 4 { @@ -1427,7 +1427,7 @@ type scanPlanBinaryInt8ToInt8 struct{} func (scanPlanBinaryInt8ToInt8) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 8 { @@ -1455,7 +1455,7 @@ type scanPlanBinaryInt8ToUint8 struct{} func (scanPlanBinaryInt8ToUint8) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 8 { @@ -1485,7 +1485,7 @@ type scanPlanBinaryInt8ToInt16 struct{} func (scanPlanBinaryInt8ToInt16) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 8 { @@ -1513,7 +1513,7 @@ type scanPlanBinaryInt8ToUint16 struct{} func (scanPlanBinaryInt8ToUint16) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 8 { @@ -1543,7 +1543,7 @@ type scanPlanBinaryInt8ToInt32 struct{} func (scanPlanBinaryInt8ToInt32) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 8 { @@ -1571,7 +1571,7 @@ type scanPlanBinaryInt8ToUint32 struct{} func (scanPlanBinaryInt8ToUint32) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 8 { @@ -1601,7 +1601,7 @@ type scanPlanBinaryInt8ToInt64 struct{} func (scanPlanBinaryInt8ToInt64) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 8 { @@ -1622,7 +1622,7 @@ type scanPlanBinaryInt8ToUint64 struct{} func (scanPlanBinaryInt8ToUint64) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 8 { @@ -1648,7 +1648,7 @@ type scanPlanBinaryInt8ToInt struct{} func (scanPlanBinaryInt8ToInt) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 8 { @@ -1676,7 +1676,7 @@ type scanPlanBinaryInt8ToUint struct{} func (scanPlanBinaryInt8ToUint) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 8 { @@ -1748,7 +1748,7 @@ type scanPlanTextAnyToInt8 struct{} func (scanPlanTextAnyToInt8) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p, ok := (dst).(*int8) @@ -1769,7 +1769,7 @@ type scanPlanTextAnyToUint8 struct{} func (scanPlanTextAnyToUint8) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p, ok := (dst).(*uint8) @@ -1790,7 +1790,7 @@ type scanPlanTextAnyToInt16 struct{} func (scanPlanTextAnyToInt16) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p, ok := (dst).(*int16) @@ -1811,7 +1811,7 @@ type scanPlanTextAnyToUint16 struct{} func (scanPlanTextAnyToUint16) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p, ok := (dst).(*uint16) @@ -1832,7 +1832,7 @@ type scanPlanTextAnyToInt32 struct{} func (scanPlanTextAnyToInt32) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p, ok := (dst).(*int32) @@ -1853,7 +1853,7 @@ type scanPlanTextAnyToUint32 struct{} func (scanPlanTextAnyToUint32) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p, ok := (dst).(*uint32) @@ -1874,7 +1874,7 @@ type scanPlanTextAnyToInt64 struct{} func (scanPlanTextAnyToInt64) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p, ok := (dst).(*int64) @@ -1895,7 +1895,7 @@ type scanPlanTextAnyToUint64 struct{} func (scanPlanTextAnyToUint64) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p, ok := (dst).(*uint64) @@ -1916,7 +1916,7 @@ type scanPlanTextAnyToInt struct{} func (scanPlanTextAnyToInt) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p, ok := (dst).(*int) @@ -1937,7 +1937,7 @@ type scanPlanTextAnyToUint struct{} func (scanPlanTextAnyToUint) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p, ok := (dst).(*uint) diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index f46a1dc3..572408e1 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -299,7 +299,7 @@ type scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %> struct{} func (scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %>) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != <%= pg_byte_size %> { @@ -333,7 +333,7 @@ type scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %> struct{} func (scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %>) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != <%= pg_byte_size %> { @@ -365,7 +365,7 @@ type scanPlanBinaryInt<%= pg_byte_size %>ToInt struct{} func (scanPlanBinaryInt<%= pg_byte_size %>ToInt) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != <%= pg_byte_size %> { @@ -397,7 +397,7 @@ type scanPlanBinaryInt<%= pg_byte_size %>ToUint struct{} func (scanPlanBinaryInt<%= pg_byte_size %>ToUint) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != <%= pg_byte_size %> { @@ -482,7 +482,7 @@ type scanPlanTextAnyToInt<%= type_suffix %> struct{} func (scanPlanTextAnyToInt<%= type_suffix %>) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p, ok := (dst).(*int<%= type_suffix %>) @@ -503,7 +503,7 @@ type scanPlanTextAnyToUint<%= type_suffix %> struct{} func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p, ok := (dst).(*uint<%= type_suffix %>) diff --git a/pgtype/json.go b/pgtype/json.go index 51f2509f..0a089059 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -131,7 +131,7 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { } } - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } return json.Unmarshal(src, dst) diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 04964e8c..d9d28404 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -86,10 +86,10 @@ func TestJSONCodecUnmarshalSQLNull(t *testing.T) { // A string cannot scan a NULL. str := "foobar" err = conn.QueryRow(ctx, "select null::json").Scan(&str) - require.EqualError(t, err, "can't scan into dest[0]: cannot scan null into *string") + require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string") // A non-string cannot scan a NULL. err = conn.QueryRow(ctx, "select null::json").Scan(&n) - require.EqualError(t, err, "can't scan into dest[0]: cannot scan null into *int") + require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *int") }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 793ec28a..f8ad2bf3 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -529,7 +529,7 @@ type scanPlanString struct{} func (scanPlanString) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p := (dst).(*string) diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 0e65041f..fc40a5b2 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -74,7 +74,7 @@ type scanPlanQcharCodecByte struct{} func (scanPlanQcharCodecByte) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) > 1 { @@ -96,7 +96,7 @@ type scanPlanQcharCodecRune struct{} func (scanPlanQcharCodecRune) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) > 1 { diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go index f4ed41b6..8cfb3a63 100644 --- a/pgtype/range_codec.go +++ b/pgtype/range_codec.go @@ -107,7 +107,7 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt if lowerType != Unbounded { if lower == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") } sp := len(buf) @@ -123,7 +123,7 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err) } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -131,7 +131,7 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt if upperType != Unbounded { if upper == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") } sp := len(buf) @@ -147,7 +147,7 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err) } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -184,7 +184,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) if lowerType != Unbounded { if lower == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") } lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, lower) @@ -197,7 +197,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err) } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") } } @@ -205,7 +205,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) if upperType != Unbounded { if upper == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") } upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, upper) @@ -218,7 +218,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err) } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") } } diff --git a/pgtype/text.go b/pgtype/text.go index 7f779d11..021ee331 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -180,7 +180,7 @@ type scanPlanTextAnyToString struct{} func (scanPlanTextAnyToString) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } p := (dst).(*string) diff --git a/pgtype/uint32.go b/pgtype/uint32.go index 37e0c65f..098c516c 100644 --- a/pgtype/uint32.go +++ b/pgtype/uint32.go @@ -248,7 +248,7 @@ type scanPlanBinaryUint32ToUint32 struct{} func (scanPlanBinaryUint32ToUint32) Scan(src []byte, dst any) error { if src == nil { - return fmt.Errorf("cannot scan null into %T", dst) + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 4 { From bb6c9971023beed746bce04a8f979fc99688a09c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 23 Aug 2022 19:39:15 -0500 Subject: [PATCH 1149/1158] Add NewCommandTag Useful for mocking and testing. https://github.com/jackc/pgx/issues/1273#issuecomment-1224154013 --- pgconn/pgconn.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index e29283f4..44de2897 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -667,6 +667,11 @@ type CommandTag struct { s string } +// NewCommandTag makes a CommandTag from s. +func NewCommandTag(s string) CommandTag { + return CommandTag{s: s} +} + // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. func (ct CommandTag) RowsAffected() int64 { From f8d088cfb613ce0d72c28179b08fdedbce07f2a7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 2 Sep 2022 18:37:02 -0500 Subject: [PATCH 1150/1158] Fix JSON scan not completely overwriting destination See https://github.com/jackc/pgtype/pull/185 for original report in pgx v4 / pgtype. --- pgtype/json.go | 3 +++ pgtype/json_test.go | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/pgtype/json.go b/pgtype/json.go index 0a089059..d0d98fc9 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -134,6 +134,9 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { return fmt.Errorf("cannot scan NULL into %T", dst) } + elem := reflect.ValueOf(dst).Elem() + elem.Set(reflect.Zero(elem.Type())) + return json.Unmarshal(src, dst) } diff --git a/pgtype/json_test.go b/pgtype/json_test.go index d9d28404..db20e576 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -93,3 +93,16 @@ func TestJSONCodecUnmarshalSQLNull(t *testing.T) { require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *int") }) } + +func TestJSONCodecClearExistingValueBeforeUnmarshal(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + m := map[string]any{} + err := conn.QueryRow(ctx, `select '{"foo": "bar"}'::json`).Scan(&m) + require.NoError(t, err) + require.Equal(t, map[string]any{"foo": "bar"}, m) + + err = conn.QueryRow(ctx, `select '{"baz": "quz"}'::json`).Scan(&m) + require.NoError(t, err) + require.Equal(t, map[string]any{"baz": "quz"}, m) + }) +} From 782133158f8c3bcd99aed28674053a82574c8d4c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Sep 2022 09:28:42 -0500 Subject: [PATCH 1151/1158] Test sending CopyData before CopyFrom responds with error --- pgconn/pgconn_test.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 601cbc8e..5b6ca284 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1779,7 +1779,18 @@ func TestConnCopyFromQuerySyntaxError(t *testing.T) { srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + // Send data even though the COPY FROM command will be rejected with a syntax error. This ensures that this does not + // break the connection. See https://github.com/jackc/pgconn/pull/127 for context. + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo FROM STDIN WITH (FORMAT csv)") require.Error(t, err) assert.IsType(t, &pgconn.PgError{}, err) assert.Equal(t, int64(0), res.RowsAffected()) From f015ced1bf367cb1f28613dfbfd729157a2de342 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Sep 2022 13:20:19 -0500 Subject: [PATCH 1152/1158] Use puddle v2.0.0-beta.2 for Acquire in background after cancel --- go.mod | 2 +- go.sum | 2 + pgxpool/pool.go | 101 +++++++++++++++++++++--------------------------- 3 files changed, 47 insertions(+), 58 deletions(-) diff --git a/go.mod b/go.mod index 7893a50a..8797ab9a 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.18 require ( github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b - github.com/jackc/puddle/v2 v2.0.0-beta.1 + github.com/jackc/puddle/v2 v2.0.0-beta.2 github.com/stretchr/testify v1.8.0 golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa golang.org/x/text v0.3.7 diff --git a/go.sum b/go.sum index d4cafbf2..8c34d004 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHF github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/puddle/v2 v2.0.0-beta.1 h1:Y4Ao+kFWANtDhWUkdw1JcbH+x84/aq6WUfhVQ1wdib8= github.com/jackc/puddle/v2 v2.0.0-beta.1/go.mod h1:itE7ZJY8xnoo0JqJEpSMprN0f+NQkMCuEV/N9j8h0oc= +github.com/jackc/puddle/v2 v2.0.0-beta.2 h1:xhhtVfiDyh29TTvZPIvY5zld5YYMmA9ErRr+fjMkmE0= +github.com/jackc/puddle/v2 v2.0.0-beta.2/go.mod h1:itE7ZJY8xnoo0JqJEpSMprN0f+NQkMCuEV/N9j8h0oc= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= diff --git a/pgxpool/pool.go b/pgxpool/pool.go index e053904c..cd13be55 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -70,16 +70,6 @@ func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows { return pr } -// detachedCtx wraps a context and will never be canceled, regardless of if -// the wrapped one is cancelled. The Err() method will never return any errors. -type detachedCtx struct { - context.Context -} - -func (detachedCtx) Done() <-chan struct{} { return nil } -func (detachedCtx) Deadline() (time.Time, bool) { return time.Time{}, false } -func (detachedCtx) Err() error { return nil } - // Pool allows for connection reuse. type Pool struct { p *puddle.Pool[*connResource] @@ -197,64 +187,61 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { closeChan: make(chan struct{}), } - p.p = puddle.NewPool( - func(ctx context.Context) (*connResource, error) { - // we ignore cancellation on the original context because its either from - // the health check or its from a query and we don't want to cancel creating - // a connection just because the original query was cancelled since that - // could end up stampeding the server - // this will keep any Values in the original context and will just ignore - // cancellation - // see https://github.com/jackc/pgx/issues/1259 - ctx = detachedCtx{ctx} + var err error + p.p, err = puddle.NewPool( + &puddle.Config[*connResource]{ + Constructor: func(ctx context.Context) (*connResource, error) { + connConfig := p.config.ConnConfig.Copy() - connConfig := p.config.ConnConfig.Copy() - - // But we do want to ensure that a connect won't hang forever. - if connConfig.ConnectTimeout <= 0 { - connConfig.ConnectTimeout = 2 * time.Minute - } - - if p.beforeConnect != nil { - if err := p.beforeConnect(ctx, connConfig); err != nil { - return nil, err + // Connection will continue in background even if Acquire is canceled. Ensure that a connect won't hang forever. + if connConfig.ConnectTimeout <= 0 { + connConfig.ConnectTimeout = 2 * time.Minute } - } - conn, err := pgx.ConnectConfig(ctx, connConfig) - if err != nil { - return nil, err - } + if p.beforeConnect != nil { + if err := p.beforeConnect(ctx, connConfig); err != nil { + return nil, err + } + } - if p.afterConnect != nil { - err = p.afterConnect(ctx, conn) + conn, err := pgx.ConnectConfig(ctx, connConfig) if err != nil { - conn.Close(ctx) return nil, err } - } - cr := &connResource{ - conn: conn, - conns: make([]Conn, 64), - poolRows: make([]poolRow, 64), - poolRowss: make([]poolRows, 64), - } + if p.afterConnect != nil { + err = p.afterConnect(ctx, conn) + if err != nil { + conn.Close(ctx) + return nil, err + } + } - return cr, nil + cr := &connResource{ + conn: conn, + conns: make([]Conn, 64), + poolRows: make([]poolRow, 64), + poolRowss: make([]poolRows, 64), + } + + return cr, nil + }, + Destructor: func(value *connResource) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + conn := value.conn + conn.Close(ctx) + select { + case <-conn.PgConn().CleanupDone(): + case <-ctx.Done(): + } + cancel() + }, + MaxSize: config.MaxConns, }, - func(value *connResource) { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - conn := value.conn - conn.Close(ctx) - select { - case <-conn.PgConn().CleanupDone(): - case <-ctx.Done(): - } - cancel() - }, - config.MaxConns, ) + if err != nil { + return nil, err + } go func() { p.createIdleResources(ctx, int(p.minConns)) From ee2622a8e699a7b052f272b08a6354bfdb90dbcd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 6 Sep 2022 18:32:10 -0500 Subject: [PATCH 1153/1158] RowToStructByPos supports embedded structs https://github.com/jackc/pgx/issues/1273#issuecomment-1236966785 --- rows.go | 42 +++++++++++++++++++++++++----------------- rows_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 17 deletions(-) diff --git a/rows.go b/rows.go index 80df4bb2..33d8ab09 100644 --- a/rows.go +++ b/rows.go @@ -511,25 +511,33 @@ func (rs *positionalStructRowScanner) ScanRow(rows Rows) error { } dstElemValue := dstValue.Elem() - dstElemType := dstElemValue.Type() + scanTargets := rs.appendScanTargets(dstElemValue, nil) - exportedFields := make([]int, 0, dstElemType.NumField()) - for i := 0; i < dstElemType.NumField(); i++ { - sf := dstElemType.Field(i) - if sf.PkgPath == "" { - exportedFields = append(exportedFields, i) - } - } - - rowFieldCount := len(rows.RawValues()) - if rowFieldCount > len(exportedFields) { - return fmt.Errorf("got %d values, but dst struct has only %d fields", rowFieldCount, len(exportedFields)) - } - - scanTargets := make([]any, rowFieldCount) - for i := 0; i < rowFieldCount; i++ { - scanTargets[i] = dstElemValue.Field(exportedFields[i]).Addr().Interface() + if len(rows.RawValues()) > len(scanTargets) { + return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets)) } return rows.Scan(scanTargets...) } + +func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any { + dstElemType := dstElemValue.Type() + + if scanTargets == nil { + scanTargets = make([]any, 0, dstElemType.NumField()) + } + + for i := 0; i < dstElemType.NumField(); i++ { + sf := dstElemType.Field(i) + if sf.PkgPath == "" { + // Handle anoymous struct embedding, but do not try to handle embedded pointers. + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + scanTargets = append(scanTargets, rs.appendScanTargets(dstElemValue.Field(i), scanTargets)...) + } else { + scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) + } + } + } + + return scanTargets +} diff --git a/rows_test.go b/rows_test.go index 6771469f..7aeafac8 100644 --- a/rows_test.go +++ b/rows_test.go @@ -329,6 +329,50 @@ func TestRowToStructByPos(t *testing.T) { }) } +func TestRowToStructByPosEmbeddedStruct(t *testing.T) { + type Name struct { + First string + Last string + } + + type person struct { + Name + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "John", slice[i].Name.First) + assert.Equal(t, "Smith", slice[i].Name.Last) + assert.EqualValues(t, i, slice[i].Age) + } + }) +} + +// Pointer to struct is not supported. But check that we don't panic. +func TestRowToStructByPosEmbeddedPointerToStruct(t *testing.T) { + type Name struct { + First string + Last string + } + + type person struct { + *Name + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`) + _, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.EqualError(t, err, "got 3 values, but dst struct has only 2 fields") + }) +} + func ExampleRowToStructByPos() { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() From 90b69c0ee006e7b43097cd393d0ccf31aaa02766 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Aug 2022 09:23:17 -0500 Subject: [PATCH 1154/1158] Fix atomic alignment on 32-bit platforms refs #1288 --- pgxpool/pool.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pgxpool/pool.go b/pgxpool/pool.go index cd13be55..236ba000 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -72,6 +72,12 @@ func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows { // Pool allows for connection reuse. type Pool struct { + // 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit + // architectures. See BUGS section of https://pkg.go.dev/sync/atomic and https://github.com/jackc/pgx/issues/1288. + newConnsCount int64 + lifetimeDestroyCount int64 + idleDestroyCount int64 + p *puddle.Pool[*connResource] config *Config beforeConnect func(context.Context, *pgx.ConnConfig) error @@ -87,10 +93,6 @@ type Pool struct { healthCheckChan chan struct{} - newConnsCount int64 - lifetimeDestroyCount int64 - idleDestroyCount int64 - closeOnce sync.Once closeChan chan struct{} } From a05fb80b8af95f7dbad426aa6faab393df105290 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 16 Sep 2022 18:16:36 -0500 Subject: [PATCH 1155/1158] Update docs and changelog for renamed pgxpool.NewWithConfig fixes https://github.com/jackc/pgx/issues/1306 --- CHANGELOG.md | 2 +- pgxpool/doc.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fe8bdf9..b0f28d00 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ pgconn now supports pipeline mode. ## pgxpool -`Connect` and `ConnectConfig` have been renamed to `New` and `NewConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect. +`Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect. ## pgtype diff --git a/pgxpool/doc.go b/pgxpool/doc.go index 07f6359d..38e49795 100644 --- a/pgxpool/doc.go +++ b/pgxpool/doc.go @@ -20,7 +20,7 @@ connection with `ConnectConfig`. // do something with every new connection } - pool, err := pgxpool.NewConfig(context.Background(), config) + pool, err := pgxpool.NewWithConfig(context.Background(), config) A pool returns without waiting for any connections to be established. Acquire a connection immediately after creating the pool to check if a connection can successfully be established. From 4f1a8084f1d625e0cc943e2ecc044c7ef1cb4400 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 17 Sep 2022 09:03:48 -0500 Subject: [PATCH 1156/1158] Various doc and changelog tweaks --- CHANGELOG.md | 14 ++++++------- README.md | 58 ++++++++++++++++++---------------------------------- doc.go | 3 ++- 3 files changed, 29 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b0f28d00..2755a4ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -# Unreleased v5 +# v5.0.0 ## Merged Packages @@ -33,7 +33,7 @@ The `pgtype` package has been significantly changed. ### NULL Representation Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a -`Valid` `bool` field to harmonize with how `database/sql` represents NULL and to make the zero value useable. +`Valid` `bool` field to harmonize with how `database/sql` represents `NULL` and to make the zero value useable. ### Codec and Value Split @@ -47,9 +47,9 @@ generally defined by implementing an interface that a particular `Codec` underst ### Array Types -All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This -significantly reduced the amount of code and the compiled binary size. This also means that less common array types such -as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional arrays. +All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This also +means that less common array types such as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional +arrays. ### Composite Types @@ -63,7 +63,7 @@ easily be handled. Multirange types are handled similarly with `MultirangeCodec` ### pgxtype -load data type moved to conn +`LoadDataType` moved to `*Conn` as `LoadType`. ### Bytea @@ -97,7 +97,7 @@ This matches the convention set by `database/sql`. In addition, for comparable t ### 3rd Party Type Integrations -* Extracted integrations with github.com/shopspring/decimal and github.com/gofrs/uuid to +* Extracted integrations with https://github.com/shopspring/decimal and https://github.com/gofrs/uuid to https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims the pgx dependency tree. diff --git a/README.md b/README.md index 2bfa6a2b..fbffcf7e 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,12 @@ [![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5) -[![Build Status](https://travis-ci.org/jackc/pgx.svg)](https://travis-ci.org/jackc/pgx) +![Build Status](https://github.com/jackc/pgx/actions/workflows/ci.yml/badge.svg) # pgx - PostgreSQL Driver and Toolkit -*This is the v5 development branch. It is still in beta testing.* - pgx is a pure Go driver and toolkit for PostgreSQL. -pgx aims to be low-level, fast, and performant, while also enabling PostgreSQL-specific features that the standard `database/sql` package does not allow for. - -The driver component of pgx can be used alongside the standard `database/sql` package. +The pgx driver is a low-level, high performance interface that exposes PostgreSQL-specific features such as `LISTEN` / +`NOTIFY` and `COPY`. It also includes an adapter for the standard `database/sql` interface. The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers, @@ -51,52 +48,39 @@ func main() { See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information. -## Choosing Between the pgx and database/sql Interfaces - -It is recommended to use the pgx interface if: -1. The application only targets PostgreSQL. -2. No other libraries that require `database/sql` are in use. - -The pgx interface is faster and exposes more features. - -The `database/sql` interface only allows the underlying driver to return or receive the following types: `int64`, -`float64`, `bool`, `[]byte`, `string`, `time.Time`, or `nil`. Handling other types requires implementing the -`database/sql.Scanner` and the `database/sql/driver/driver.Valuer` interfaces which require transmission of values in text format. The binary format can be substantially faster, which is what the pgx interface uses. - ## Features -pgx supports many features beyond what is available through `database/sql`: - * Support for approximately 70 different PostgreSQL types * Automatic statement preparation and caching * Batch queries * Single-round trip query mode * Full TLS connection control * Binary format support for custom types (allows for much quicker encoding/decoding) -* COPY protocol support for faster bulk data loads -* Extendable logging support +* `COPY` protocol support for faster bulk data loads +* Tracing and logging support * Connection pool with after-connect hook for arbitrary connection setup -* Listen / notify +* `LISTEN` / `NOTIFY` * Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings -* Hstore support -* JSON and JSONB support +* `hstore` support +* `json` and `jsonb` support * Maps `inet` and `cidr` PostgreSQL types to `netip.Addr` and `netip.Prefix` * Large object support -* NULL mapping to Null* struct or pointer to pointer +* NULL mapping to pointer to pointer * Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types * Notice response handling * Simulated nested transactions with savepoints -## Performance +## Choosing Between the pgx and database/sql Interfaces -There are three areas in particular where pgx can provide a significant performance advantage over the standard -`database/sql` interface and other drivers: +The pgx interface is faster. Many PostgreSQL specific features such as `LISTEN` / `NOTIFY` and `COPY` are not available +through the `database/sql` interface. -1. PostgreSQL specific types - Types such as arrays can be parsed much quicker because pgx uses the binary format. -2. Automatic statement preparation and caching - pgx will prepare and cache statements by default. This can provide an - significant free improvement to code that does not explicitly use prepared statements. Under certain workloads, it can - perform nearly 3x the number of queries per second. -3. Batched queries - Multiple queries can be batched together to minimize network round trips. +The pgx interface is recommended when: + +1. The application only targets PostgreSQL. +2. No other libraries that require `database/sql` are in use. + +It is also possible to use the `database/sql` interface and convert a connection to the lower-level pgx interface as needed. ## Testing @@ -129,13 +113,11 @@ In addition, there are tests specific for PgBouncer that will be executed if `PG ## Supported Go and PostgreSQL Versions -~~pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.17 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).~~ - -`v5` is targeted at Go 1.18+. The general release of `v5` is not planned until second half of 2022 so it is expected that the policy of supporting the two most recent versions of Go will be maintained or restored soon after its release. +pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.18 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). ## Version Policy -pgx follows semantic versioning for the documented public API on stable releases. `v4` is the latest stable major version. +pgx follows semantic versioning for the documented public API on stable releases. `v5` is the latest stable major version. ## PGX Family Libraries diff --git a/doc.go b/doc.go index 9f3f9774..497ab660 100644 --- a/doc.go +++ b/doc.go @@ -97,7 +97,8 @@ Transactions are started by calling Begin. The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions. These are internally implemented with savepoints. -Use BeginTx to control the transaction mode. +Use BeginTx to control the transaction mode. BeginTx also can be used to ensure a new transaction is created instead of +a pseudo nested transaction. BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the transaction depending on the return value of the function. These can be simpler and less error prone to use. From 1a314bda3b604205092e4f48df90bbcbfaaef960 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 17 Sep 2022 10:18:06 -0500 Subject: [PATCH 1157/1158] pgconn.Timeout() no longer considers `context.Canceled` as a timeout error. https://github.com/jackc/pgconn/issues/81 --- CHANGELOG.md | 2 ++ pgconn/errors.go | 17 +++++++++++------ pgconn/pgconn.go | 28 ++++++++++++---------------- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2755a4ca..32acfdda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ pgconn now supports pipeline mode. `*PgConn.ReceiveResults` removed. Use pipeline mode instead. +`Timeout()` no longer considers `context.Canceled` as a timeout error. `context.DeadlineExceeded` still is considered a timeout error. + ## pgxpool `Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect. diff --git a/pgconn/errors.go b/pgconn/errors.go index 4254535e..3c54bbec 100644 --- a/pgconn/errors.go +++ b/pgconn/errors.go @@ -19,7 +19,7 @@ func SafeToRetry(err error) bool { } // Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a -// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. +// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. func Timeout(err error) bool { var timeoutErr *errTimeout return errors.As(err, &timeoutErr) @@ -106,11 +106,16 @@ func (e *parseConfigError) Unwrap() error { return e.err } -// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == -// true. Otherwise returns err. -func preferContextOverNetTimeoutError(ctx context.Context, err error) error { - if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { - return &errTimeout{err: ctx.Err()} +func normalizeTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() { + if ctx.Err() == context.Canceled { + // Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error. + return context.Canceled + } else if ctx.Err() == context.DeadlineExceeded { + return &errTimeout{err: ctx.Err()} + } else { + return &errTimeout{err: err} + } } return err } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 44de2897..59fa35c6 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -255,11 +255,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) netConn, err := config.DialFunc(ctx, network, address) if err != nil { - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - err = &errTimeout{err: err} - } - return nil, &connectError{config: config, msg: "dial error", err: err} + return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} } nbNetConn := nbconn.NewNetConn(netConn, false) @@ -314,7 +310,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err, ok := err.(*PgError); ok { return nil, err } - return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)} + return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} } switch msg := msg.(type) { @@ -448,7 +444,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa if err != nil { err = &pgconnError{ msg: "receive message failed", - err: preferContextOverNetTimeoutError(ctx, err), + err: normalizeTimeoutError(ctx, err), safeToRetry: true} } return msg, err @@ -794,7 +790,7 @@ readloop: msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, normalizeTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -907,7 +903,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { for { msg, err := pgConn.receiveMessage() if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return normalizeTimeoutError(ctx, err) } switch msg.(type) { @@ -1106,7 +1102,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) + return CommandTag{}, normalizeTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1203,7 +1199,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co break } pgConn.asyncClose() - return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) + return CommandTag{}, normalizeTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1238,7 +1234,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) + return CommandTag{}, normalizeTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1281,7 +1277,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) if err != nil { mrr.pgConn.contextWatcher.Unwatch() - mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) + mrr.err = normalizeTimeoutError(mrr.ctx, err) mrr.closed = true mrr.pgConn.asyncClose() return nil, mrr.err @@ -1497,7 +1493,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } if err != nil { - err = preferContextOverNetTimeoutError(rr.ctx, err) + err = normalizeTimeoutError(rr.ctx, err) rr.concludeCommand(CommandTag{}, err) rr.pgConn.contextWatcher.Unwatch() rr.closed = true @@ -1814,7 +1810,7 @@ func (p *Pipeline) Flush() error { err := p.conn.frontend.Flush() if err != nil { - err = preferContextOverNetTimeoutError(p.ctx, err) + err = normalizeTimeoutError(p.ctx, err) p.conn.asyncClose() @@ -1901,7 +1897,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { msg, err := p.conn.receiveMessage() if err != nil { p.conn.asyncClose() - return nil, preferContextOverNetTimeoutError(p.ctx, err) + return nil, normalizeTimeoutError(p.ctx, err) } switch msg := msg.(type) { From 5a055434f2dc61fcf8f33d7585fdb53dc1fd442d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 17 Sep 2022 10:24:19 -0500 Subject: [PATCH 1158/1158] Upgrade dependencies --- go.mod | 8 ++++---- go.sum | 26 +++++++++++++++----------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 8797ab9a..5b7109ac 100644 --- a/go.mod +++ b/go.mod @@ -5,16 +5,16 @@ go 1.18 require ( github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b - github.com/jackc/puddle/v2 v2.0.0-beta.2 + github.com/jackc/puddle/v2 v2.0.0 github.com/stretchr/testify v1.8.0 - golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa + golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 golang.org/x/text v0.3.7 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/kr/pretty v0.1.0 // indirect + github.com/kr/pretty v0.3.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 8c34d004..0f1a952c 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -5,17 +6,20 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= -github.com/jackc/puddle/v2 v2.0.0-beta.1 h1:Y4Ao+kFWANtDhWUkdw1JcbH+x84/aq6WUfhVQ1wdib8= -github.com/jackc/puddle/v2 v2.0.0-beta.1/go.mod h1:itE7ZJY8xnoo0JqJEpSMprN0f+NQkMCuEV/N9j8h0oc= -github.com/jackc/puddle/v2 v2.0.0-beta.2 h1:xhhtVfiDyh29TTvZPIvY5zld5YYMmA9ErRr+fjMkmE0= -github.com/jackc/puddle/v2 v2.0.0-beta.2/go.mod h1:itE7ZJY8xnoo0JqJEpSMprN0f+NQkMCuEV/N9j8h0oc= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/jackc/puddle/v2 v2.0.0 h1:Kwk/AlLigcnZsDssc3Zun1dk1tAtQNPaBBxBHWn0Mjc= +github.com/jackc/puddle/v2 v2.0.0/go.mod h1:itE7ZJY8xnoo0JqJEpSMprN0f+NQkMCuEV/N9j8h0oc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -23,15 +27,15 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b h1:QAqMVf3pSa6eeTsuklijukjXBlj7Es2QQplab+/RbQ4= -golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= -golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 h1:Y/gsMcFOcR+6S6f3YeMKl5g+dZMEWqcz5Czj/GWYbkM= +golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=